mas_axum_utils/
csrf.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use base64ct::{Base64UrlUnpadded, Encoding};
8use chrono::{DateTime, Duration, Utc};
9use mas_storage::Clock;
10use rand::{Rng, RngCore, distributions::Standard, prelude::Distribution as _};
11use serde::{Deserialize, Serialize};
12use serde_with::{TimestampSeconds, serde_as};
13use thiserror::Error;
14
15use crate::cookies::{CookieDecodeError, CookieJar};
16
17/// Failed to validate CSRF token
18#[derive(Debug, Error)]
19pub enum CsrfError {
20    /// The token in the form did not match the token in the cookie
21    #[error("CSRF token mismatch")]
22    Mismatch,
23
24    /// The token in the form did not match the token in the cookie
25    #[error("Missing CSRF cookie")]
26    Missing,
27
28    /// Failed to decode the token
29    #[error("could not decode CSRF cookie")]
30    DecodeCookie(#[from] CookieDecodeError),
31
32    /// The token expired
33    #[error("CSRF token expired")]
34    Expired,
35
36    /// Failed to decode the token
37    #[error("could not decode CSRF token")]
38    Decode(#[from] base64ct::Error),
39}
40
41/// A CSRF token
42#[serde_as]
43#[derive(Serialize, Deserialize, Debug)]
44pub struct CsrfToken {
45    #[serde_as(as = "TimestampSeconds<i64>")]
46    expiration: DateTime<Utc>,
47    token: [u8; 32],
48}
49
50impl CsrfToken {
51    /// Create a new token from a defined value valid for a specified duration
52    fn new(token: [u8; 32], now: DateTime<Utc>, ttl: Duration) -> Self {
53        let expiration = now + ttl;
54        Self { expiration, token }
55    }
56
57    /// Generate a new random token valid for a specified duration
58    fn generate(now: DateTime<Utc>, mut rng: impl Rng, ttl: Duration) -> Self {
59        let token = Standard.sample(&mut rng);
60        Self::new(token, now, ttl)
61    }
62
63    /// Generate a new token with the same value but an up to date expiration
64    fn refresh(self, now: DateTime<Utc>, ttl: Duration) -> Self {
65        Self::new(self.token, now, ttl)
66    }
67
68    /// Get the value to include in HTML forms
69    #[must_use]
70    pub fn form_value(&self) -> String {
71        Base64UrlUnpadded::encode_string(&self.token[..])
72    }
73
74    /// Verifies that the value got from an HTML form matches this token
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if the value in the form does not match this token
79    pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
80        let form_value = Base64UrlUnpadded::decode_vec(form_value)?;
81        if self.token[..] == form_value {
82            Ok(())
83        } else {
84            Err(CsrfError::Mismatch)
85        }
86    }
87
88    fn verify_expiration(self, now: DateTime<Utc>) -> Result<Self, CsrfError> {
89        if now < self.expiration {
90            Ok(self)
91        } else {
92            Err(CsrfError::Expired)
93        }
94    }
95}
96
97// A CSRF-protected form
98#[derive(Deserialize)]
99pub struct ProtectedForm<T> {
100    csrf: String,
101
102    #[serde(flatten)]
103    inner: T,
104}
105
106pub trait CsrfExt {
107    /// Get the current CSRF token out of the cookie jar, generating a new one
108    /// if necessary
109    fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
110    where
111        R: RngCore,
112        C: Clock;
113
114    /// Verify that the given CSRF-protected form is valid, returning the inner
115    /// value
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if the CSRF cookie is missing or if the value in the
120    /// form is invalid
121    fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
122    where
123        C: Clock;
124}
125
126impl CsrfExt for CookieJar {
127    fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
128    where
129        R: RngCore,
130        C: Clock,
131    {
132        let now = clock.now();
133        let maybe_token = match self.load::<CsrfToken>("csrf") {
134            Ok(Some(token)) => {
135                let token = token.verify_expiration(now);
136
137                // If the token is expired, just ignore it
138                token.ok()
139            }
140            Ok(None) => None,
141            Err(e) => {
142                tracing::warn!("Failed to decode CSRF cookie: {}", e);
143                None
144            }
145        };
146
147        let token = maybe_token.map_or_else(
148            || CsrfToken::generate(now, rng, Duration::try_hours(1).unwrap()),
149            |token| token.refresh(now, Duration::try_hours(1).unwrap()),
150        );
151
152        let jar = self.save("csrf", &token, false);
153        (token, jar)
154    }
155
156    fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
157    where
158        C: Clock,
159    {
160        let token: CsrfToken = self.load("csrf")?.ok_or(CsrfError::Missing)?;
161        let token = token.verify_expiration(clock.now())?;
162        token.verify_form_value(&form.csrf)?;
163        Ok(form.inner)
164    }
165}