mas_data_model/
tokens.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use base64ct::{Base64UrlUnpadded, Encoding};
8use chrono::{DateTime, Utc};
9use crc::{CRC_32_ISO_HDLC, Crc};
10use mas_iana::oauth::OAuthTokenTypeHint;
11use rand::{Rng, RngCore, distributions::Alphanumeric};
12use thiserror::Error;
13use ulid::Ulid;
14
15use crate::InvalidTransitionError;
16
17#[derive(Debug, Clone, Default, PartialEq, Eq)]
18pub enum AccessTokenState {
19    #[default]
20    Valid,
21    Revoked {
22        revoked_at: DateTime<Utc>,
23    },
24}
25
26impl AccessTokenState {
27    fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
28        match self {
29            Self::Valid => Ok(Self::Revoked { revoked_at }),
30            Self::Revoked { .. } => Err(InvalidTransitionError),
31        }
32    }
33
34    /// Returns `true` if the refresh token state is [`Valid`].
35    ///
36    /// [`Valid`]: AccessTokenState::Valid
37    #[must_use]
38    pub fn is_valid(&self) -> bool {
39        matches!(self, Self::Valid)
40    }
41
42    /// Returns `true` if the refresh token state is [`Revoked`].
43    ///
44    /// [`Revoked`]: AccessTokenState::Revoked
45    #[must_use]
46    pub fn is_revoked(&self) -> bool {
47        matches!(self, Self::Revoked { .. })
48    }
49}
50
51#[derive(Debug, Clone, PartialEq, Eq)]
52pub struct AccessToken {
53    pub id: Ulid,
54    pub state: AccessTokenState,
55    pub session_id: Ulid,
56    pub access_token: String,
57    pub created_at: DateTime<Utc>,
58    pub expires_at: Option<DateTime<Utc>>,
59    pub first_used_at: Option<DateTime<Utc>>,
60}
61
62impl AccessToken {
63    #[must_use]
64    pub fn jti(&self) -> String {
65        self.id.to_string()
66    }
67
68    /// Whether the access token is valid, i.e. not revoked and not expired
69    ///
70    /// # Parameters
71    ///
72    /// * `now` - The current time
73    #[must_use]
74    pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
75        self.state.is_valid() && !self.is_expired(now)
76    }
77
78    /// Whether the access token is expired
79    ///
80    /// Always returns `false` if the access token does not have an expiry time.
81    ///
82    /// # Parameters
83    ///
84    /// * `now` - The current time
85    #[must_use]
86    pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
87        match self.expires_at {
88            Some(expires_at) => expires_at < now,
89            None => false,
90        }
91    }
92
93    /// Whether the access token was used at least once
94    #[must_use]
95    pub fn is_used(&self) -> bool {
96        self.first_used_at.is_some()
97    }
98
99    /// Mark the access token as revoked
100    ///
101    /// # Parameters
102    ///
103    /// * `revoked_at` - The time at which the access token was revoked
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the access token is already revoked
108    pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
109        self.state = self.state.revoke(revoked_at)?;
110        Ok(self)
111    }
112}
113
114#[derive(Debug, Clone, Default, PartialEq, Eq)]
115pub enum RefreshTokenState {
116    #[default]
117    Valid,
118    Consumed {
119        consumed_at: DateTime<Utc>,
120        next_refresh_token_id: Option<Ulid>,
121    },
122    Revoked {
123        revoked_at: DateTime<Utc>,
124    },
125}
126
127impl RefreshTokenState {
128    /// Consume the refresh token, returning a new state.
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if the refresh token is revoked.
133    fn consume(
134        self,
135        consumed_at: DateTime<Utc>,
136        replaced_by: &RefreshToken,
137    ) -> Result<Self, InvalidTransitionError> {
138        match self {
139            Self::Valid | Self::Consumed { .. } => Ok(Self::Consumed {
140                consumed_at,
141                next_refresh_token_id: Some(replaced_by.id),
142            }),
143            Self::Revoked { .. } => Err(InvalidTransitionError),
144        }
145    }
146
147    /// Revoke the refresh token, returning a new state.
148    ///
149    /// # Errors
150    ///
151    /// Returns an error if the refresh token is already consumed or revoked.
152    pub fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
153        match self {
154            Self::Valid => Ok(Self::Revoked { revoked_at }),
155            Self::Consumed { .. } | Self::Revoked { .. } => Err(InvalidTransitionError),
156        }
157    }
158
159    /// Returns `true` if the refresh token state is [`Valid`].
160    ///
161    /// [`Valid`]: RefreshTokenState::Valid
162    #[must_use]
163    pub fn is_valid(&self) -> bool {
164        matches!(self, Self::Valid)
165    }
166
167    /// Returns the next refresh token ID, if any.
168    #[must_use]
169    pub fn next_refresh_token_id(&self) -> Option<Ulid> {
170        match self {
171            Self::Valid | Self::Revoked { .. } => None,
172            Self::Consumed {
173                next_refresh_token_id,
174                ..
175            } => *next_refresh_token_id,
176        }
177    }
178}
179
180#[derive(Debug, Clone, PartialEq, Eq)]
181pub struct RefreshToken {
182    pub id: Ulid,
183    pub state: RefreshTokenState,
184    pub refresh_token: String,
185    pub session_id: Ulid,
186    pub created_at: DateTime<Utc>,
187    pub access_token_id: Option<Ulid>,
188}
189
190impl std::ops::Deref for RefreshToken {
191    type Target = RefreshTokenState;
192
193    fn deref(&self) -> &Self::Target {
194        &self.state
195    }
196}
197
198impl RefreshToken {
199    #[must_use]
200    pub fn jti(&self) -> String {
201        self.id.to_string()
202    }
203
204    /// Consumes the refresh token and returns the consumed token.
205    ///
206    /// # Errors
207    ///
208    /// Returns an error if the refresh token is revoked.
209    pub fn consume(
210        mut self,
211        consumed_at: DateTime<Utc>,
212        replaced_by: &Self,
213    ) -> Result<Self, InvalidTransitionError> {
214        self.state = self.state.consume(consumed_at, replaced_by)?;
215        Ok(self)
216    }
217
218    /// Revokes the refresh token and returns a new revoked token
219    ///
220    /// # Errors
221    ///
222    /// Returns an error if the refresh token is already revoked.
223    pub fn revoke(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
224        self.state = self.state.revoke(revoked_at)?;
225        Ok(self)
226    }
227}
228
/// The kinds of tokens that can be generated and validated
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// Access token that Relying Parties use to authenticate their requests
    AccessToken,

    /// Refresh token, exchanged through the refresh token grant
    RefreshToken,

    /// Access token issued through the legacy (compatibility) API
    CompatAccessToken,

    /// Refresh token issued through the legacy (compatibility) API
    CompatRefreshToken,
}

impl std::fmt::Display for TokenType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Human-readable label; formatting flags such as width are ignored
        let label = match self {
            Self::AccessToken => "access token",
            Self::RefreshToken => "refresh token",
            Self::CompatAccessToken => "compat access token",
            Self::CompatRefreshToken => "compat refresh token",
        };
        f.write_str(label)
    }
}
255
256impl TokenType {
257    fn prefix(self) -> &'static str {
258        match self {
259            TokenType::AccessToken => "mat",
260            TokenType::RefreshToken => "mar",
261            TokenType::CompatAccessToken => "mct",
262            TokenType::CompatRefreshToken => "mcr",
263        }
264    }
265
266    fn match_prefix(prefix: &str) -> Option<Self> {
267        match prefix {
268            "mat" => Some(TokenType::AccessToken),
269            "mar" => Some(TokenType::RefreshToken),
270            "mct" | "syt" => Some(TokenType::CompatAccessToken),
271            "mcr" | "syr" => Some(TokenType::CompatRefreshToken),
272            _ => None,
273        }
274    }
275
276    /// Generate a token for the given type
277    pub fn generate(self, rng: &mut (impl RngCore + ?Sized)) -> String {
278        let random_part: String = rng
279            .sample_iter(&Alphanumeric)
280            .take(30)
281            .map(char::from)
282            .collect();
283
284        let base = format!("{prefix}_{random_part}", prefix = self.prefix());
285        let crc = CRC.checksum(base.as_bytes());
286        let crc = base62_encode(crc);
287        format!("{base}_{crc}")
288    }
289
290    /// Check the format of a token and determine its type
291    ///
292    /// # Errors
293    ///
294    /// Returns an error if the token is not valid
295    pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
296        // these are legacy tokens imported from Synapse
297        // we don't do any validation on them and continue as is
298        if token.starts_with("syt_") || is_likely_synapse_macaroon(token) {
299            return Ok(TokenType::CompatAccessToken);
300        }
301        if token.starts_with("syr_") {
302            return Ok(TokenType::CompatRefreshToken);
303        }
304
305        let split: Vec<&str> = token.split('_').collect();
306        let [prefix, random_part, crc]: [&str; 3] = split
307            .try_into()
308            .map_err(|_| TokenFormatError::InvalidFormat)?;
309
310        if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
311            return Err(TokenFormatError::InvalidFormat);
312        }
313
314        let token_type =
315            TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
316                prefix: prefix.to_owned(),
317            })?;
318
319        let base = format!("{prefix}_{random_part}", prefix = token_type.prefix());
320        let expected_crc = CRC.checksum(base.as_bytes());
321        let expected_crc = base62_encode(expected_crc);
322        if crc != expected_crc {
323            return Err(TokenFormatError::InvalidCrc {
324                expected: expected_crc,
325                got: crc.to_owned(),
326            });
327        }
328
329        Ok(token_type)
330    }
331}
332
333impl PartialEq<OAuthTokenTypeHint> for TokenType {
334    fn eq(&self, other: &OAuthTokenTypeHint) -> bool {
335        matches!(
336            (self, other),
337            (
338                TokenType::AccessToken | TokenType::CompatAccessToken,
339                OAuthTokenTypeHint::AccessToken
340            ) | (
341                TokenType::RefreshToken | TokenType::CompatRefreshToken,
342                OAuthTokenTypeHint::RefreshToken
343            )
344        )
345    }
346}
347
348/// Returns true if and only if a token looks like it may be a macaroon.
349///
350/// Macaroons are a standard for tokens that support attenuation.
351/// Synapse used them for old sessions and for guest sessions.
352///
353/// We won't bother to decode them fully, but we can check to see if the first
354/// constraint is the `location` constraint.
355fn is_likely_synapse_macaroon(token: &str) -> bool {
356    let Ok(decoded) = Base64UrlUnpadded::decode_vec(token) else {
357        return false;
358    };
359    decoded.get(4..13) == Some(b"location ")
360}
361
/// Alphabet used by [`base62_encode`]: digits, then upper case, then lower case.
const NUM: [u8; 62] = *b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

/// Encode `num` as a six-character base62 string.
///
/// Digits are emitted least-significant first and the result is then
/// left-padded with `'0'`. This differs from conventional base62 and is not
/// injective (e.g. `1` and `62` both encode to `"000001"`). Checksums of
/// already-issued tokens depend on this exact output, so it must not be
/// "fixed". Any `u32` fits in six base62 digits.
fn base62_encode(num: u32) -> String {
    let digits: String = std::iter::successors(Some(num), |&rest| Some(rest / 62))
        .take_while(|&rest| rest > 0)
        .map(|rest| char::from(NUM[(rest % 62) as usize]))
        .collect();

    format!("{digits:0>6}")
}
373
374const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
375
376/// Invalid token
377#[derive(Debug, Error, PartialEq, Eq)]
378pub enum TokenFormatError {
379    /// Overall token format is invalid
380    #[error("invalid token format")]
381    InvalidFormat,
382
383    /// Token used an unknown prefix
384    #[error("unknown token prefix {prefix:?}")]
385    UnknownPrefix {
386        /// The prefix found in the token
387        prefix: String,
388    },
389
390    /// The CRC checksum in the token is invalid
391    #[error("invalid crc {got:?}, expected {expected:?}")]
392    InvalidCrc {
393        /// The CRC hash expected to be found in the token
394        expected: String,
395        /// The CRC found in the token
396        got: String,
397    },
398}
399
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rand::thread_rng;

    use super::*;

    #[test]
    fn test_prefix_match() {
        use TokenType::{AccessToken, CompatAccessToken, CompatRefreshToken, RefreshToken};
        assert_eq!(TokenType::match_prefix("syt"), Some(CompatAccessToken));
        assert_eq!(TokenType::match_prefix("syr"), Some(CompatRefreshToken));
        assert_eq!(TokenType::match_prefix("mct"), Some(CompatAccessToken));
        assert_eq!(TokenType::match_prefix("mcr"), Some(CompatRefreshToken));
        assert_eq!(TokenType::match_prefix("mat"), Some(AccessToken));
        assert_eq!(TokenType::match_prefix("mar"), Some(RefreshToken));
        assert_eq!(TokenType::match_prefix("matt"), None);
        assert_eq!(TokenType::match_prefix("marr"), None);
        assert_eq!(TokenType::match_prefix("ma"), None);
        assert_eq!(
            TokenType::match_prefix(TokenType::CompatAccessToken.prefix()),
            Some(TokenType::CompatAccessToken)
        );
        assert_eq!(
            TokenType::match_prefix(TokenType::CompatRefreshToken.prefix()),
            Some(TokenType::CompatRefreshToken)
        );
        assert_eq!(
            TokenType::match_prefix(TokenType::AccessToken.prefix()),
            Some(TokenType::AccessToken)
        );
        assert_eq!(
            TokenType::match_prefix(TokenType::RefreshToken.prefix()),
            Some(TokenType::RefreshToken)
        );
    }

    #[test]
    fn test_is_likely_synapse_macaroon() {
        // This is just the prefix of a Synapse macaroon, but it's enough to make the
        // sniffing work
        assert!(is_likely_synapse_macaroon(
            "MDAxYmxvY2F0aW9uIGxpYnJlcHVzaC5uZXQKMDAx"
        ));

        // This is a valid macaroon (even though Synapse did not generate this one)
        assert!(is_likely_synapse_macaroon(
            "MDAxY2xvY2F0aW9uIGh0dHA6Ly9teWJhbmsvCjAwMjZpZGVudGlmaWVyIHdlIHVzZWQgb3VyIHNlY3JldCBrZXkKMDAyZnNpZ25hdHVyZSDj2eApCFJsTAA5rhURQRXZf91ovyujebNCqvD2F9BVLwo"
        ));

        // None of these are macaroons
        assert!(!is_likely_synapse_macaroon(
            "eyJARTOhearotnaeisahtoarsnhiasra.arsohenaor.oarnsteao"
        ));
        assert!(!is_likely_synapse_macaroon("...."));
        assert!(!is_likely_synapse_macaroon("aaa"));
    }

    #[test]
    fn test_generate_and_check() {
        const COUNT: usize = 500; // Generate 500 of each token type

        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        for t in [
            TokenType::CompatAccessToken,
            TokenType::CompatRefreshToken,
            TokenType::AccessToken,
            TokenType::RefreshToken,
        ] {
            // Generate many tokens
            let tokens: HashSet<String> = (0..COUNT).map(|_| t.generate(&mut rng)).collect();

            // Check that they are all different
            assert_eq!(tokens.len(), COUNT, "All tokens are unique");

            // Check that they are all valid and detected as the right token type
            for token in tokens {
                assert_eq!(TokenType::check(&token).unwrap(), t);
            }
        }
    }

    #[test]
    fn test_check_legacy_synapse_tokens() {
        // Synapse-issued tokens are accepted without validating their content
        assert_eq!(
            TokenType::check("syt_anything_goes_here"),
            Ok(TokenType::CompatAccessToken)
        );
        assert_eq!(
            TokenType::check("syr_anything_goes_here"),
            Ok(TokenType::CompatRefreshToken)
        );
    }

    #[test]
    fn test_check_rejects_malformed_tokens() {
        // Wrong number of segments, or segments of the wrong length
        for token in ["", "mat", "mat_abc", "mat_a_b_c", "mat_short_000000"] {
            assert_eq!(
                TokenType::check(token),
                Err(TokenFormatError::InvalidFormat),
                "{token:?}"
            );
        }

        // Well-formed, but with a prefix no token type uses
        let token = format!("xyz_{}_000000", "a".repeat(30));
        assert_eq!(
            TokenType::check(&token),
            Err(TokenFormatError::UnknownPrefix {
                prefix: "xyz".to_owned()
            })
        );
    }

    #[test]
    fn test_check_detects_bad_checksum() {
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        // Alter the last checksum character so it can no longer match
        let mut token = TokenType::AccessToken.generate(&mut rng);
        let last = token.pop().unwrap();
        token.push(if last == '0' { '1' } else { '0' });

        assert!(matches!(
            TokenType::check(&token),
            Err(TokenFormatError::InvalidCrc { .. })
        ));
    }

    #[test]
    fn test_access_token_state_transitions() {
        let now = DateTime::<Utc>::default();

        let state = AccessTokenState::default();
        assert!(state.is_valid());
        assert!(!state.is_revoked());

        // A valid token can be revoked exactly once
        assert!(matches!(
            state.revoke(now),
            Ok(AccessTokenState::Revoked { revoked_at }) if revoked_at == now
        ));

        let revoked = AccessTokenState::Revoked { revoked_at: now };
        assert!(revoked.is_revoked());
        assert!(!revoked.is_valid());
        assert!(revoked.revoke(now).is_err());
    }

    #[test]
    fn test_refresh_token_state_transitions() {
        let now = DateTime::<Utc>::default();
        let next = RefreshToken {
            id: Ulid::nil(),
            state: RefreshTokenState::Valid,
            refresh_token: String::new(),
            session_id: Ulid::nil(),
            created_at: now,
            access_token_id: None,
        };

        // A valid token can be consumed and records its replacement
        let consumed = RefreshTokenState::Valid.consume(now, &next);
        assert!(matches!(
            &consumed,
            Ok(state) if state.next_refresh_token_id() == Some(next.id)
        ));
        let Ok(consumed) = consumed else {
            panic!("consuming a valid refresh token should succeed");
        };

        // Consuming again is allowed, but revoking a consumed token is not
        assert!(consumed.clone().consume(now, &next).is_ok());
        assert!(consumed.revoke(now).is_err());

        // A revoked token can be neither revoked again nor consumed
        let Ok(revoked) = RefreshTokenState::Valid.revoke(now) else {
            panic!("revoking a valid refresh token should succeed");
        };
        assert!(!revoked.is_valid());
        assert_eq!(revoked.next_refresh_token_id(), None);
        assert!(revoked.clone().revoke(now).is_err());
        assert!(revoked.consume(now, &next).is_err());
    }
}