mas_data_model/upstream_oauth2/
session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use chrono::{DateTime, Utc};
8use serde::Serialize;
9use ulid::Ulid;
10
11use super::UpstreamOAuthLink;
12use crate::InvalidTransitionError;
13
/// The lifecycle state of an upstream OAuth 2.0 authorization session.
///
/// A session starts out [`Pending`] (also the [`Default`]). It becomes
/// [`Completed`] via [`complete`], and then [`Consumed`] via [`consume`].
/// Any other transition is rejected with an [`InvalidTransitionError`].
///
/// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
/// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
/// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
/// [`complete`]: UpstreamOAuthAuthorizationSessionState::complete
/// [`consume`]: UpstreamOAuthAuthorizationSessionState::consume
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
pub enum UpstreamOAuthAuthorizationSessionState {
    /// The session has been created but not completed yet.
    #[default]
    Pending,
    /// The session has been completed against an upstream OAuth link.
    Completed {
        /// When the session was completed.
        completed_at: DateTime<Utc>,
        /// ID of the [`UpstreamOAuthLink`] the session was completed with.
        link_id: Ulid,
        /// ID token stored on completion, if any.
        id_token: Option<String>,
        /// Extra callback parameters stored on completion, if any.
        // NOTE(review): presumably extra parameters from the upstream
        // provider's callback — confirm against the caller.
        extra_callback_parameters: Option<serde_json::Value>,
        /// Userinfo stored on completion, if any.
        // NOTE(review): presumably the upstream userinfo response — confirm.
        userinfo: Option<serde_json::Value>,
    },
    /// The completed session has been consumed. No further transition is
    /// allowed.
    Consumed {
        /// When the session was completed (carried over from `Completed`).
        completed_at: DateTime<Utc>,
        /// When the session was consumed.
        consumed_at: DateTime<Utc>,
        /// ID of the [`UpstreamOAuthLink`] the session was completed with.
        link_id: Ulid,
        /// ID token stored on completion, if any.
        id_token: Option<String>,
        /// Extra callback parameters stored on completion, if any.
        extra_callback_parameters: Option<serde_json::Value>,
        /// Userinfo stored on completion, if any.
        userinfo: Option<serde_json::Value>,
    },
}
34
35impl UpstreamOAuthAuthorizationSessionState {
36    /// Mark the upstream OAuth 2.0 authorization session as completed.
37    ///
38    /// # Errors
39    ///
40    /// Returns an error if the upstream OAuth 2.0 authorization session state
41    /// is not [`Pending`].
42    ///
43    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
44    pub fn complete(
45        self,
46        completed_at: DateTime<Utc>,
47        link: &UpstreamOAuthLink,
48        id_token: Option<String>,
49        extra_callback_parameters: Option<serde_json::Value>,
50        userinfo: Option<serde_json::Value>,
51    ) -> Result<Self, InvalidTransitionError> {
52        match self {
53            Self::Pending => Ok(Self::Completed {
54                completed_at,
55                link_id: link.id,
56                id_token,
57                extra_callback_parameters,
58                userinfo,
59            }),
60            Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
61        }
62    }
63
64    /// Mark the upstream OAuth 2.0 authorization session as consumed.
65    ///
66    /// # Errors
67    ///
68    /// Returns an error if the upstream OAuth 2.0 authorization session state
69    /// is not [`Completed`].
70    ///
71    /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
72    pub fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
73        match self {
74            Self::Completed {
75                completed_at,
76                link_id,
77                id_token,
78                extra_callback_parameters,
79                userinfo,
80            } => Ok(Self::Consumed {
81                completed_at,
82                link_id,
83                consumed_at,
84                id_token,
85                extra_callback_parameters,
86                userinfo,
87            }),
88            Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
89        }
90    }
91
92    /// Get the link ID for the upstream OAuth 2.0 authorization session.
93    ///
94    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
95    /// [`Pending`].
96    ///
97    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
98    #[must_use]
99    pub fn link_id(&self) -> Option<Ulid> {
100        match self {
101            Self::Pending => None,
102            Self::Completed { link_id, .. } | Self::Consumed { link_id, .. } => Some(*link_id),
103        }
104    }
105
106    /// Get the time at which the upstream OAuth 2.0 authorization session was
107    /// completed.
108    ///
109    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
110    /// [`Pending`].
111    ///
112    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
113    #[must_use]
114    pub fn completed_at(&self) -> Option<DateTime<Utc>> {
115        match self {
116            Self::Pending => None,
117            Self::Completed { completed_at, .. } | Self::Consumed { completed_at, .. } => {
118                Some(*completed_at)
119            }
120        }
121    }
122
123    /// Get the ID token for the upstream OAuth 2.0 authorization session.
124    ///
125    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
126    /// [`Pending`].
127    ///
128    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
129    #[must_use]
130    pub fn id_token(&self) -> Option<&str> {
131        match self {
132            Self::Pending => None,
133            Self::Completed { id_token, .. } | Self::Consumed { id_token, .. } => {
134                id_token.as_deref()
135            }
136        }
137    }
138
139    /// Get the extra query parameters that were sent to the upstream provider.
140    ///
141    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
142    /// not [`Pending`].
143    ///
144    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
145    #[must_use]
146    pub fn extra_callback_parameters(&self) -> Option<&serde_json::Value> {
147        match self {
148            Self::Pending => None,
149            Self::Completed {
150                extra_callback_parameters,
151                ..
152            }
153            | Self::Consumed {
154                extra_callback_parameters,
155                ..
156            } => extra_callback_parameters.as_ref(),
157        }
158    }
159
160    #[must_use]
161    pub fn userinfo(&self) -> Option<&serde_json::Value> {
162        match self {
163            Self::Pending => None,
164            Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => userinfo.as_ref(),
165        }
166    }
167
168    /// Get the time at which the upstream OAuth 2.0 authorization session was
169    /// consumed.
170    ///
171    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
172    /// not [`Consumed`].
173    ///
174    /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
175    #[must_use]
176    pub fn consumed_at(&self) -> Option<DateTime<Utc>> {
177        match self {
178            Self::Pending | Self::Completed { .. } => None,
179            Self::Consumed { consumed_at, .. } => Some(*consumed_at),
180        }
181    }
182
183    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
184    /// [`Pending`].
185    ///
186    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
187    #[must_use]
188    pub fn is_pending(&self) -> bool {
189        matches!(self, Self::Pending)
190    }
191
192    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
193    /// [`Completed`].
194    ///
195    /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
196    #[must_use]
197    pub fn is_completed(&self) -> bool {
198        matches!(self, Self::Completed { .. })
199    }
200
201    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
202    /// [`Consumed`].
203    ///
204    /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
205    #[must_use]
206    pub fn is_consumed(&self) -> bool {
207        matches!(self, Self::Consumed { .. })
208    }
209}
210
/// An upstream OAuth 2.0 authorization session.
///
/// The session dereferences to its [`UpstreamOAuthAuthorizationSessionState`].
/// State accessors such as `is_pending()` or `link_id()` can therefore be
/// called directly on the session.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthAuthorizationSession {
    /// Unique ID of this session.
    pub id: Ulid,
    /// Current lifecycle state of the session.
    pub state: UpstreamOAuthAuthorizationSessionState,
    /// ID of the upstream provider this session belongs to.
    pub provider_id: Ulid,
    /// State string associated with this authorization request.
    // NOTE(review): presumably the OAuth 2.0 `state` parameter — confirm.
    pub state_str: String,
    /// Code challenge verifier, if one was generated for this request.
    // NOTE(review): presumably the PKCE `code_verifier` — confirm.
    pub code_challenge_verifier: Option<String>,
    /// Nonce associated with this authorization request.
    // NOTE(review): presumably the OpenID Connect `nonce` — confirm.
    pub nonce: String,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
}
221
222impl std::ops::Deref for UpstreamOAuthAuthorizationSession {
223    type Target = UpstreamOAuthAuthorizationSessionState;
224
225    fn deref(&self) -> &Self::Target {
226        &self.state
227    }
228}
229
230impl UpstreamOAuthAuthorizationSession {
231    /// Mark the upstream OAuth 2.0 authorization session as completed. Returns
232    /// the updated session.
233    ///
234    /// # Errors
235    ///
236    /// Returns an error if the upstream OAuth 2.0 authorization session state
237    /// is not [`Pending`].
238    ///
239    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
240    pub fn complete(
241        mut self,
242        completed_at: DateTime<Utc>,
243        link: &UpstreamOAuthLink,
244        id_token: Option<String>,
245        extra_callback_parameters: Option<serde_json::Value>,
246        userinfo: Option<serde_json::Value>,
247    ) -> Result<Self, InvalidTransitionError> {
248        self.state = self.state.complete(
249            completed_at,
250            link,
251            id_token,
252            extra_callback_parameters,
253            userinfo,
254        )?;
255        Ok(self)
256    }
257
258    /// Mark the upstream OAuth 2.0 authorization session as consumed. Returns
259    /// the updated session.
260    ///
261    /// # Errors
262    ///
263    /// Returns an error if the upstream OAuth 2.0 authorization session state
264    /// is not [`Completed`].
265    ///
266    /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
267    pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
268        self.state = self.state.consume(consumed_at)?;
269        Ok(self)
270    }
271}