1use std::num::NonZeroU32;
8
9use chrono::{DateTime, Duration, Utc};
10use mas_iana::oauth::PkceCodeChallengeMethod;
11use oauth2_types::{
12 pkce::{CodeChallengeError, CodeChallengeMethodExt},
13 requests::ResponseMode,
14 scope::{OPENID, PROFILE, Scope},
15};
16use rand::{
17 RngCore,
18 distributions::{Alphanumeric, DistString},
19};
20use ruma_common::UserId;
21use serde::Serialize;
22use ulid::Ulid;
23use url::Url;
24
25use super::session::Session;
26use crate::InvalidTransitionError;
27
/// PKCE (Proof Key for Code Exchange) parameters bound to an
/// [`AuthorizationCode`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Pkce {
    /// Method used by [`Pkce::verify`] to check a verifier against `challenge`.
    pub challenge_method: PkceCodeChallengeMethod,
    /// The code challenge that a later verifier is checked against.
    pub challenge: String,
}
33
34impl Pkce {
35 #[must_use]
37 pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
38 Pkce {
39 challenge_method,
40 challenge,
41 }
42 }
43
44 pub fn verify(&self, verifier: &str) -> Result<(), CodeChallengeError> {
50 self.challenge_method.verify(&self.challenge, verifier)
51 }
52}
53
/// An authorization code attached to an [`AuthorizationGrant`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthorizationCode {
    /// The code value itself.
    pub code: String,
    /// PKCE parameters bound to this code; `None` when no PKCE is attached.
    pub pkce: Option<Pkce>,
}
59
/// Lifecycle stage of an [`AuthorizationGrant`].
///
/// The transition methods on this type only allow
/// `Pending -> Fulfilled -> Exchanged` and `Pending -> Cancelled`; any other
/// transition yields [`InvalidTransitionError`].
///
/// Serialized as an internally-tagged enum: a `"stage"` field carries the
/// lowercase variant name alongside the variant's own fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
#[serde(tag = "stage", rename_all = "lowercase")]
pub enum AuthorizationGrantStage {
    /// Initial stage; also the [`Default`] value.
    #[default]
    Pending,
    /// A session has been attached to the grant.
    Fulfilled {
        /// ID of the session the grant was fulfilled with.
        session_id: Ulid,
        /// When the grant was fulfilled.
        fulfilled_at: DateTime<Utc>,
    },
    /// The fulfilled grant has been exchanged (presumably its code was
    /// redeemed at the token endpoint — confirm against callers).
    Exchanged {
        /// ID of the session, carried over from the `Fulfilled` stage.
        session_id: Ulid,
        /// When the grant was fulfilled, carried over from `Fulfilled`.
        fulfilled_at: DateTime<Utc>,
        /// When the grant was exchanged.
        exchanged_at: DateTime<Utc>,
    },
    /// The grant was cancelled while still pending.
    Cancelled {
        /// When the grant was cancelled.
        cancelled_at: DateTime<Utc>,
    },
}
78
79impl AuthorizationGrantStage {
80 #[must_use]
81 pub fn new() -> Self {
82 Self::Pending
83 }
84
85 fn fulfill(
86 self,
87 fulfilled_at: DateTime<Utc>,
88 session: &Session,
89 ) -> Result<Self, InvalidTransitionError> {
90 match self {
91 Self::Pending => Ok(Self::Fulfilled {
92 fulfilled_at,
93 session_id: session.id,
94 }),
95 _ => Err(InvalidTransitionError),
96 }
97 }
98
99 fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
100 match self {
101 Self::Fulfilled {
102 fulfilled_at,
103 session_id,
104 } => Ok(Self::Exchanged {
105 fulfilled_at,
106 exchanged_at,
107 session_id,
108 }),
109 _ => Err(InvalidTransitionError),
110 }
111 }
112
113 fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
114 match self {
115 Self::Pending => Ok(Self::Cancelled { cancelled_at }),
116 _ => Err(InvalidTransitionError),
117 }
118 }
119
120 #[must_use]
124 pub fn is_pending(&self) -> bool {
125 matches!(self, Self::Pending)
126 }
127
128 #[must_use]
132 pub fn is_fulfilled(&self) -> bool {
133 matches!(self, Self::Fulfilled { .. })
134 }
135
136 #[must_use]
140 pub fn is_exchanged(&self) -> bool {
141 matches!(self, Self::Exchanged { .. })
142 }
143}
144
145pub enum LoginHint<'a> {
146 MXID(&'a UserId),
147 None,
148}
149
/// An authorization grant together with its current lifecycle stage.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthorizationGrant {
    /// Unique identifier of the grant.
    pub id: Ulid,
    /// Current lifecycle stage. Flattened into this struct when serialized,
    /// and also reachable directly on the grant through `Deref`.
    #[serde(flatten)]
    pub stage: AuthorizationGrantStage,
    /// The authorization code attached to this grant, if any.
    pub code: Option<AuthorizationCode>,
    /// ID of the client this grant belongs to.
    pub client_id: Ulid,
    /// Redirect URI associated with the grant.
    pub redirect_uri: Url,
    /// Scope associated with the grant.
    pub scope: Scope,
    /// Client-provided `state` value, if any (presumably echoed back to the
    /// client on redirect — confirm).
    pub state: Option<String>,
    /// Client-provided `nonce`, if any (presumably bound into the ID token —
    /// confirm).
    pub nonce: Option<String>,
    /// Requested maximum authentication age, in seconds; read by
    /// `max_auth_time`.
    pub max_age: Option<NonZeroU32>,
    /// How the authorization response is delivered to the client.
    pub response_mode: ResponseMode,
    /// NOTE(review): presumably whether `id_token` was part of the requested
    /// response type — confirm.
    pub response_type_id_token: bool,
    /// When the grant was created; the base for `max_auth_time`.
    pub created_at: DateTime<Utc>,
    /// NOTE(review): presumably whether explicit user consent is required —
    /// confirm.
    pub requires_consent: bool,
    /// Raw login hint; `parse_login_hint` understands `mxid:<user id>`.
    pub login_hint: Option<String>,
}
168
169impl std::ops::Deref for AuthorizationGrant {
170 type Target = AuthorizationGrantStage;
171
172 fn deref(&self) -> &Self::Target {
173 &self.stage
174 }
175}
176
177const DEFAULT_MAX_AGE: Duration = Duration::microseconds(3600 * 24 * 365 * 1000 * 1000);
178
179impl AuthorizationGrant {
180 #[must_use]
181 pub fn max_auth_time(&self) -> DateTime<Utc> {
182 let max_age = self
183 .max_age
184 .and_then(|x| Duration::try_seconds(x.get().into()))
185 .unwrap_or(DEFAULT_MAX_AGE);
186 self.created_at - max_age
187 }
188
189 #[must_use]
190 pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint {
191 let Some(login_hint) = &self.login_hint else {
192 return LoginHint::None;
193 };
194
195 let Some((prefix, value)) = login_hint.split_once(':') else {
197 return LoginHint::None;
198 };
199
200 match prefix {
201 "mxid" => {
202 let Ok(mxid) = <&UserId>::try_from(value) else {
204 return LoginHint::None;
205 };
206
207 if mxid.server_name() != homeserver {
209 return LoginHint::None;
210 }
211
212 LoginHint::MXID(mxid)
213 }
214 _ => LoginHint::None,
216 }
217 }
218
219 pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
227 self.stage = self.stage.exchange(exchanged_at)?;
228 Ok(self)
229 }
230
231 pub fn fulfill(
239 mut self,
240 fulfilled_at: DateTime<Utc>,
241 session: &Session,
242 ) -> Result<Self, InvalidTransitionError> {
243 self.stage = self.stage.fulfill(fulfilled_at, session)?;
244 Ok(self)
245 }
246
247 pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
259 self.stage = self.stage.cancel(canceld_at)?;
260 Ok(self)
261 }
262
263 #[doc(hidden)]
264 pub fn sample(now: DateTime<Utc>, rng: &mut impl RngCore) -> Self {
265 Self {
266 id: Ulid::from_datetime_with_source(now.into(), rng),
267 stage: AuthorizationGrantStage::Pending,
268 code: Some(AuthorizationCode {
269 code: Alphanumeric.sample_string(rng, 10),
270 pkce: None,
271 }),
272 client_id: Ulid::from_datetime_with_source(now.into(), rng),
273 redirect_uri: Url::parse("http://localhost:8080").unwrap(),
274 scope: Scope::from_iter([OPENID, PROFILE]),
275 state: Some(Alphanumeric.sample_string(rng, 10)),
276 nonce: Some(Alphanumeric.sample_string(rng, 10)),
277 max_age: None,
278 response_mode: ResponseMode::Query,
279 response_type_id_token: false,
280 created_at: now,
281 requires_consent: false,
282 login_hint: Some(String::from("mxid:@example-user:example.com")),
283 }
284 }
285}
286
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;

    /// Builds a sample grant whose `login_hint` is replaced by `login_hint`.
    fn sample_grant_with_hint(login_hint: Option<&str>) -> AuthorizationGrant {
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        #[allow(clippy::disallowed_methods)]
        let now = Utc::now();

        AuthorizationGrant {
            login_hint: login_hint.map(String::from),
            ..AuthorizationGrant::sample(now, &mut rng)
        }
    }

    #[test]
    fn no_login_hint() {
        let grant = sample_grant_with_hint(None);

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint() {
        let grant = sample_grant_with_hint(Some("mxid:@example-user:example.com"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user"));
    }

    #[test]
    fn invalid_login_hint() {
        let grant = sample_grant_with_hint(Some("example-user"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn malformed_mxid_login_hint() {
        // Correct `mxid:` prefix, but the value is not a valid Matrix user ID.
        let grant = sample_grant_with_hint(Some("mxid:example-user"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint_for_wrong_homeserver() {
        let grant = sample_grant_with_hint(Some("mxid:@example-user:matrix.org"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn unknown_login_hint_type() {
        let grant = sample_grant_with_hint(Some("something:anything"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }
}