mas_data_model/upstream_oauth2/session.rs
1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use chrono::{DateTime, Utc};
8use serde::Serialize;
9use ulid::Ulid;
10
11use super::UpstreamOAuthLink;
12use crate::InvalidTransitionError;
13
14#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
15pub enum UpstreamOAuthAuthorizationSessionState {
16 #[default]
17 Pending,
18 Completed {
19 completed_at: DateTime<Utc>,
20 link_id: Ulid,
21 id_token: Option<String>,
22 extra_callback_parameters: Option<serde_json::Value>,
23 userinfo: Option<serde_json::Value>,
24 },
25 Consumed {
26 completed_at: DateTime<Utc>,
27 consumed_at: DateTime<Utc>,
28 link_id: Ulid,
29 id_token: Option<String>,
30 extra_callback_parameters: Option<serde_json::Value>,
31 userinfo: Option<serde_json::Value>,
32 },
33}
34
35impl UpstreamOAuthAuthorizationSessionState {
36 /// Mark the upstream OAuth 2.0 authorization session as completed.
37 ///
38 /// # Errors
39 ///
40 /// Returns an error if the upstream OAuth 2.0 authorization session state
41 /// is not [`Pending`].
42 ///
43 /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
44 pub fn complete(
45 self,
46 completed_at: DateTime<Utc>,
47 link: &UpstreamOAuthLink,
48 id_token: Option<String>,
49 extra_callback_parameters: Option<serde_json::Value>,
50 userinfo: Option<serde_json::Value>,
51 ) -> Result<Self, InvalidTransitionError> {
52 match self {
53 Self::Pending => Ok(Self::Completed {
54 completed_at,
55 link_id: link.id,
56 id_token,
57 extra_callback_parameters,
58 userinfo,
59 }),
60 Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
61 }
62 }
63
64 /// Mark the upstream OAuth 2.0 authorization session as consumed.
65 ///
66 /// # Errors
67 ///
68 /// Returns an error if the upstream OAuth 2.0 authorization session state
69 /// is not [`Completed`].
70 ///
71 /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
72 pub fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
73 match self {
74 Self::Completed {
75 completed_at,
76 link_id,
77 id_token,
78 extra_callback_parameters,
79 userinfo,
80 } => Ok(Self::Consumed {
81 completed_at,
82 link_id,
83 consumed_at,
84 id_token,
85 extra_callback_parameters,
86 userinfo,
87 }),
88 Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
89 }
90 }
91
92 /// Get the link ID for the upstream OAuth 2.0 authorization session.
93 ///
94 /// Returns `None` if the upstream OAuth 2.0 authorization session state is
95 /// [`Pending`].
96 ///
97 /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
98 #[must_use]
99 pub fn link_id(&self) -> Option<Ulid> {
100 match self {
101 Self::Pending => None,
102 Self::Completed { link_id, .. } | Self::Consumed { link_id, .. } => Some(*link_id),
103 }
104 }
105
106 /// Get the time at which the upstream OAuth 2.0 authorization session was
107 /// completed.
108 ///
109 /// Returns `None` if the upstream OAuth 2.0 authorization session state is
110 /// [`Pending`].
111 ///
112 /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
113 #[must_use]
114 pub fn completed_at(&self) -> Option<DateTime<Utc>> {
115 match self {
116 Self::Pending => None,
117 Self::Completed { completed_at, .. } | Self::Consumed { completed_at, .. } => {
118 Some(*completed_at)
119 }
120 }
121 }
122
123 /// Get the ID token for the upstream OAuth 2.0 authorization session.
124 ///
125 /// Returns `None` if the upstream OAuth 2.0 authorization session state is
126 /// [`Pending`].
127 ///
128 /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
129 #[must_use]
130 pub fn id_token(&self) -> Option<&str> {
131 match self {
132 Self::Pending => None,
133 Self::Completed { id_token, .. } | Self::Consumed { id_token, .. } => {
134 id_token.as_deref()
135 }
136 }
137 }
138
139 /// Get the extra query parameters that were sent to the upstream provider.
140 ///
141 /// Returns `None` if the upstream OAuth 2.0 authorization session state is
142 /// not [`Pending`].
143 ///
144 /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
145 #[must_use]
146 pub fn extra_callback_parameters(&self) -> Option<&serde_json::Value> {
147 match self {
148 Self::Pending => None,
149 Self::Completed {
150 extra_callback_parameters,
151 ..
152 }
153 | Self::Consumed {
154 extra_callback_parameters,
155 ..
156 } => extra_callback_parameters.as_ref(),
157 }
158 }
159
160 #[must_use]
161 pub fn userinfo(&self) -> Option<&serde_json::Value> {
162 match self {
163 Self::Pending => None,
164 Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => userinfo.as_ref(),
165 }
166 }
167
168 /// Get the time at which the upstream OAuth 2.0 authorization session was
169 /// consumed.
170 ///
171 /// Returns `None` if the upstream OAuth 2.0 authorization session state is
172 /// not [`Consumed`].
173 ///
174 /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
175 #[must_use]
176 pub fn consumed_at(&self) -> Option<DateTime<Utc>> {
177 match self {
178 Self::Pending | Self::Completed { .. } => None,
179 Self::Consumed { consumed_at, .. } => Some(*consumed_at),
180 }
181 }
182
183 /// Returns `true` if the upstream OAuth 2.0 authorization session state is
184 /// [`Pending`].
185 ///
186 /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
187 #[must_use]
188 pub fn is_pending(&self) -> bool {
189 matches!(self, Self::Pending)
190 }
191
192 /// Returns `true` if the upstream OAuth 2.0 authorization session state is
193 /// [`Completed`].
194 ///
195 /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
196 #[must_use]
197 pub fn is_completed(&self) -> bool {
198 matches!(self, Self::Completed { .. })
199 }
200
201 /// Returns `true` if the upstream OAuth 2.0 authorization session state is
202 /// [`Consumed`].
203 ///
204 /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
205 #[must_use]
206 pub fn is_consumed(&self) -> bool {
207 matches!(self, Self::Consumed { .. })
208 }
209}
210
211#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
212pub struct UpstreamOAuthAuthorizationSession {
213 pub id: Ulid,
214 pub state: UpstreamOAuthAuthorizationSessionState,
215 pub provider_id: Ulid,
216 pub state_str: String,
217 pub code_challenge_verifier: Option<String>,
218 pub nonce: String,
219 pub created_at: DateTime<Utc>,
220}
221
222impl std::ops::Deref for UpstreamOAuthAuthorizationSession {
223 type Target = UpstreamOAuthAuthorizationSessionState;
224
225 fn deref(&self) -> &Self::Target {
226 &self.state
227 }
228}
229
230impl UpstreamOAuthAuthorizationSession {
231 /// Mark the upstream OAuth 2.0 authorization session as completed. Returns
232 /// the updated session.
233 ///
234 /// # Errors
235 ///
236 /// Returns an error if the upstream OAuth 2.0 authorization session state
237 /// is not [`Pending`].
238 ///
239 /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
240 pub fn complete(
241 mut self,
242 completed_at: DateTime<Utc>,
243 link: &UpstreamOAuthLink,
244 id_token: Option<String>,
245 extra_callback_parameters: Option<serde_json::Value>,
246 userinfo: Option<serde_json::Value>,
247 ) -> Result<Self, InvalidTransitionError> {
248 self.state = self.state.complete(
249 completed_at,
250 link,
251 id_token,
252 extra_callback_parameters,
253 userinfo,
254 )?;
255 Ok(self)
256 }
257
258 /// Mark the upstream OAuth 2.0 authorization session as consumed. Returns
259 /// the updated session.
260 ///
261 /// # Errors
262 ///
263 /// Returns an error if the upstream OAuth 2.0 authorization session state
264 /// is not [`Completed`].
265 ///
266 /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
267 pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
268 self.state = self.state.consume(consumed_at)?;
269 Ok(self)
270 }
271}