1use base64ct::{Base64UrlUnpadded, Encoding};
8use chrono::{DateTime, Utc};
9use crc::{CRC_32_ISO_HDLC, Crc};
10use mas_iana::oauth::OAuthTokenTypeHint;
11use rand::{Rng, RngCore, distributions::Alphanumeric};
12use thiserror::Error;
13use ulid::Ulid;
14
15use crate::InvalidTransitionError;
16
/// Lifecycle state of an [`AccessToken`].
///
/// A token starts out [`Valid`](Self::Valid) (the `Default`) and can move to
/// [`Revoked`](Self::Revoked) only once: revoking an already revoked token is
/// an invalid transition.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum AccessTokenState {
    /// The token has not been revoked.
    #[default]
    Valid,
    /// The token has been revoked.
    Revoked {
        /// When the token was revoked.
        revoked_at: DateTime<Utc>,
    },
}
25
26impl AccessTokenState {
27 fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
28 match self {
29 Self::Valid => Ok(Self::Revoked { revoked_at }),
30 Self::Revoked { .. } => Err(InvalidTransitionError),
31 }
32 }
33
34 #[must_use]
38 pub fn is_valid(&self) -> bool {
39 matches!(self, Self::Valid)
40 }
41
42 #[must_use]
46 pub fn is_revoked(&self) -> bool {
47 matches!(self, Self::Revoked { .. })
48 }
49}
50
/// An access token together with its lifecycle state and timestamps.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccessToken {
    /// Unique identifier of the token; its string form is what
    /// [`AccessToken::jti`] returns.
    pub id: Ulid,
    /// Current lifecycle state (valid or revoked).
    pub state: AccessTokenState,
    /// Identifier of the session this token belongs to.
    /// NOTE(review): the session type is not visible in this module.
    pub session_id: Ulid,
    /// The opaque token string (presumably the value clients present — TODO confirm).
    pub access_token: String,
    /// When the token was created.
    pub created_at: DateTime<Utc>,
    /// When the token expires; `None` means it never expires
    /// (see [`AccessToken::is_expired`]).
    pub expires_at: Option<DateTime<Utc>>,
    /// When the token was first used, if ever; [`AccessToken::is_used`]
    /// reports whether this is set.
    pub first_used_at: Option<DateTime<Utc>>,
}
61
62impl AccessToken {
63 #[must_use]
64 pub fn jti(&self) -> String {
65 self.id.to_string()
66 }
67
68 #[must_use]
74 pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
75 self.state.is_valid() && !self.is_expired(now)
76 }
77
78 #[must_use]
86 pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
87 match self.expires_at {
88 Some(expires_at) => expires_at < now,
89 None => false,
90 }
91 }
92
93 #[must_use]
95 pub fn is_used(&self) -> bool {
96 self.first_used_at.is_some()
97 }
98
99 pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
109 self.state = self.state.revoke(revoked_at)?;
110 Ok(self)
111 }
112}
113
/// Lifecycle state of a [`RefreshToken`].
///
/// A token starts out [`Valid`](Self::Valid) (the `Default`). The allowed
/// transitions are as follows:
///
/// - A valid token can be consumed or revoked.
/// - A consumed token can be consumed again but not revoked.
/// - A revoked token accepts no further transitions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RefreshTokenState {
    /// The token has been neither consumed nor revoked.
    #[default]
    Valid,
    /// The token was consumed and replaced by another refresh token.
    Consumed {
        /// When the token was (most recently) consumed.
        consumed_at: DateTime<Utc>,
        /// Identifier of the replacing refresh token. `consume` always sets
        /// this to `Some`; NOTE(review): `None` presumably comes from stored
        /// records that did not track it — confirm.
        next_refresh_token_id: Option<Ulid>,
    },
    /// The token has been revoked.
    Revoked {
        /// When the token was revoked.
        revoked_at: DateTime<Utc>,
    },
}
126
127impl RefreshTokenState {
128 fn consume(
134 self,
135 consumed_at: DateTime<Utc>,
136 replaced_by: &RefreshToken,
137 ) -> Result<Self, InvalidTransitionError> {
138 match self {
139 Self::Valid | Self::Consumed { .. } => Ok(Self::Consumed {
140 consumed_at,
141 next_refresh_token_id: Some(replaced_by.id),
142 }),
143 Self::Revoked { .. } => Err(InvalidTransitionError),
144 }
145 }
146
147 pub fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
153 match self {
154 Self::Valid => Ok(Self::Revoked { revoked_at }),
155 Self::Consumed { .. } | Self::Revoked { .. } => Err(InvalidTransitionError),
156 }
157 }
158
159 #[must_use]
163 pub fn is_valid(&self) -> bool {
164 matches!(self, Self::Valid)
165 }
166
167 #[must_use]
169 pub fn next_refresh_token_id(&self) -> Option<Ulid> {
170 match self {
171 Self::Valid | Self::Revoked { .. } => None,
172 Self::Consumed {
173 next_refresh_token_id,
174 ..
175 } => *next_refresh_token_id,
176 }
177 }
178}
179
/// A refresh token together with its lifecycle state.
///
/// Dereferences to its [`RefreshTokenState`], so state helpers such as
/// `is_valid` and `next_refresh_token_id` can be called on the token directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RefreshToken {
    /// Unique identifier of the token; its string form is what
    /// [`RefreshToken::jti`] returns.
    pub id: Ulid,
    /// Current lifecycle state (valid, consumed or revoked).
    pub state: RefreshTokenState,
    /// The opaque token string (presumably the value clients present — TODO confirm).
    pub refresh_token: String,
    /// Identifier of the session this token belongs to.
    /// NOTE(review): the session type is not visible in this module.
    pub session_id: Ulid,
    /// When the token was created.
    pub created_at: DateTime<Utc>,
    /// Identifier of an associated access token, if any.
    /// NOTE(review): the exact relationship is not visible here — confirm.
    pub access_token_id: Option<Ulid>,
}
189
190impl std::ops::Deref for RefreshToken {
191 type Target = RefreshTokenState;
192
193 fn deref(&self) -> &Self::Target {
194 &self.state
195 }
196}
197
198impl RefreshToken {
199 #[must_use]
200 pub fn jti(&self) -> String {
201 self.id.to_string()
202 }
203
204 pub fn consume(
210 mut self,
211 consumed_at: DateTime<Utc>,
212 replaced_by: &Self,
213 ) -> Result<Self, InvalidTransitionError> {
214 self.state = self.state.consume(consumed_at, replaced_by)?;
215 Ok(self)
216 }
217
218 pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
224 self.state = self.state.revoke(revoked_at)?;
225 Ok(self)
226 }
227}
228
/// The kinds of token strings this module can generate and recognise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// Access token; generated with the `mat` prefix.
    AccessToken,

    /// Refresh token; generated with the `mar` prefix.
    RefreshToken,

    /// Compatibility access token; generated with the `mct` prefix. The
    /// `syt` prefix is also recognised as this type.
    CompatAccessToken,

    /// Compatibility refresh token; generated with the `mcr` prefix. The
    /// `syr` prefix is also recognised as this type.
    CompatRefreshToken,
}

/// Human-readable name of the token type, e.g. `"compat access token"`.
impl std::fmt::Display for TokenType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::AccessToken => "access token",
            Self::RefreshToken => "refresh token",
            Self::CompatAccessToken => "compat access token",
            Self::CompatRefreshToken => "compat refresh token",
        };
        // `write_str` (like `write!` with a plain literal) ignores any
        // width/fill flags the caller may have set.
        f.write_str(label)
    }
}
255
256impl TokenType {
257 fn prefix(self) -> &'static str {
258 match self {
259 TokenType::AccessToken => "mat",
260 TokenType::RefreshToken => "mar",
261 TokenType::CompatAccessToken => "mct",
262 TokenType::CompatRefreshToken => "mcr",
263 }
264 }
265
266 fn match_prefix(prefix: &str) -> Option<Self> {
267 match prefix {
268 "mat" => Some(TokenType::AccessToken),
269 "mar" => Some(TokenType::RefreshToken),
270 "mct" | "syt" => Some(TokenType::CompatAccessToken),
271 "mcr" | "syr" => Some(TokenType::CompatRefreshToken),
272 _ => None,
273 }
274 }
275
276 pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String {
278 let random_part: String = rng
279 .sample_iter(&Alphanumeric)
280 .take(30)
281 .map(char::from)
282 .collect();
283
284 let base = format!("{prefix}_{random_part}", prefix = self.prefix());
285 let crc = CRC.checksum(base.as_bytes());
286 let crc = base62_encode(crc);
287 format!("{base}_{crc}")
288 }
289
290 pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
296 if token.starts_with("syt_") || is_likely_synapse_macaroon(token) {
299 return Ok(TokenType::CompatAccessToken);
300 }
301 if token.starts_with("syr_") {
302 return Ok(TokenType::CompatRefreshToken);
303 }
304
305 let split: Vec<&str> = token.split('_').collect();
306 let [prefix, random_part, crc]: [&str; 3] = split
307 .try_into()
308 .map_err(|_| TokenFormatError::InvalidFormat)?;
309
310 if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
311 return Err(TokenFormatError::InvalidFormat);
312 }
313
314 let token_type =
315 TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
316 prefix: prefix.to_owned(),
317 })?;
318
319 let base = format!("{prefix}_{random_part}", prefix = token_type.prefix());
320 let expected_crc = CRC.checksum(base.as_bytes());
321 let expected_crc = base62_encode(expected_crc);
322 if crc != expected_crc {
323 return Err(TokenFormatError::InvalidCrc {
324 expected: expected_crc,
325 got: crc.to_owned(),
326 });
327 }
328
329 Ok(token_type)
330 }
331}
332
333impl PartialEq<OAuthTokenTypeHint> for TokenType {
334 fn eq(&self, other: &OAuthTokenTypeHint) -> bool {
335 matches!(
336 (self, other),
337 (
338 TokenType::AccessToken | TokenType::CompatAccessToken,
339 OAuthTokenTypeHint::AccessToken
340 ) | (
341 TokenType::RefreshToken | TokenType::CompatRefreshToken,
342 OAuthTokenTypeHint::RefreshToken
343 )
344 )
345 }
346}
347
348fn is_likely_synapse_macaroon(token: &str) -> bool {
356 let Ok(decoded) = Base64UrlUnpadded::decode_vec(token) else {
357 return false;
358 };
359 decoded.get(4..13) == Some(b"location ")
360}
361
/// Base62 alphabet used by [`base62_encode`]: digits, upper-case letters,
/// then lower-case letters.
const NUM: [u8; 62] = *b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

/// Encodes `num` as a base62 string, left-padded with `'0'` to six characters.
///
/// Digits are emitted least-significant first, i.e. in reverse positional
/// order. Six characters always suffice because `62^6 > u32::MAX`.
///
/// NOTE(review): since the digits are reversed *before* padding, a trailing
/// zero digit is indistinguishable from padding. For example, `1` and `62`
/// both encode to `"000001"`. The scheme is kept as-is because generated
/// tokens embed checksums produced by exactly this encoding.
fn base62_encode(num: u32) -> String {
    // Successive quotients num, num/62, num/62^2, ... until zero; each one's
    // remainder mod 62 is the next (least-significant-first) digit.
    let digits: String = std::iter::successors(Some(num), |&rest| Some(rest / 62))
        .take_while(|&rest| rest != 0)
        .map(|rest| char::from(NUM[(rest % 62) as usize]))
        .collect();

    format!("{digits:0>6}")
}
373
/// CRC-32 (ISO-HDLC parameters) used to checksum the `{prefix}_{random}`
/// part of generated tokens; the result is base62-encoded into the token.
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
375
/// Error returned by [`TokenType::check`] when a token string is malformed.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum TokenFormatError {
    /// The token is not made of exactly three `_`-separated parts with
    /// lengths 3, 30 and 6.
    #[error("invalid token format")]
    InvalidFormat,

    /// The prefix does not correspond to any known token type.
    #[error("unknown token prefix {prefix:?}")]
    UnknownPrefix {
        /// The prefix found in the token.
        prefix: String,
    },

    /// The trailing checksum does not match the one computed from the
    /// token's prefix and random part.
    #[error("invalid crc {got:?}, expected {expected:?}")]
    InvalidCrc {
        /// The checksum computed from the token's prefix and random part.
        expected: String,
        /// The checksum found in the token.
        got: String,
    },
}
399
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rand::thread_rng;

    use super::*;

    /// Every token type, used to drive the table-based tests below.
    const ALL_TOKEN_TYPES: [TokenType; 4] = [
        TokenType::CompatAccessToken,
        TokenType::CompatRefreshToken,
        TokenType::AccessToken,
        TokenType::RefreshToken,
    ];

    #[test]
    fn test_prefix_match() {
        // The first group is the known prefixes, including the `syt`/`syr`
        // aliases. The second group is near-misses that must be rejected.
        let cases = [
            ("syt", Some(TokenType::CompatAccessToken)),
            ("syr", Some(TokenType::CompatRefreshToken)),
            ("mct", Some(TokenType::CompatAccessToken)),
            ("mcr", Some(TokenType::CompatRefreshToken)),
            ("mat", Some(TokenType::AccessToken)),
            ("mar", Some(TokenType::RefreshToken)),
            ("matt", None),
            ("marr", None),
            ("ma", None),
        ];
        for (prefix, expected) in cases {
            assert_eq!(
                TokenType::match_prefix(prefix),
                expected,
                "prefix {prefix:?}"
            );
        }

        // The canonical prefix of every type must map back to that same type.
        for token_type in ALL_TOKEN_TYPES {
            assert_eq!(
                TokenType::match_prefix(token_type.prefix()),
                Some(token_type)
            );
        }
    }

    #[test]
    fn test_is_likely_synapse_macaroon() {
        let macaroons = [
            "MDAxYmxvY2F0aW9uIGxpYnJlcHVzaC5uZXQKMDAx",
            "MDAxY2xvY2F0aW9uIGh0dHA6Ly9teWJhbmsvCjAwMjZpZGVudGlmaWVyIHdlIHVzZWQgb3VyIHNlY3JldCBrZXkKMDAyZnNpZ25hdHVyZSDj2eApCFJsTAA5rhURQRXZf91ovyujebNCqvD2F9BVLwo",
        ];
        for token in macaroons {
            assert!(is_likely_synapse_macaroon(token), "{token:?}");
        }

        let not_macaroons = [
            "eyJARTOhearotnaeisahtoarsnhiasra.arsohenaor.oarnsteao",
            "....",
            "aaa",
        ];
        for token in not_macaroons {
            assert!(!is_likely_synapse_macaroon(token), "{token:?}");
        }
    }

    #[test]
    fn test_generate_and_check() {
        const COUNT: usize = 500;

        // NOTE(review): `thread_rng` is presumably on the project's clippy
        // `disallowed-methods` list; an unseeded RNG is fine in this test.
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        for t in ALL_TOKEN_TYPES {
            let tokens: HashSet<String> = (0..COUNT).map(|_| t.generate(&mut rng)).collect();

            assert_eq!(tokens.len(), COUNT, "All tokens are unique");

            // Every generated token must round-trip through `check`.
            for token in &tokens {
                assert_eq!(TokenType::check(token).unwrap(), t);
            }
        }
    }
}