1use base64ct::{Base64UrlUnpadded, Encoding};
8use chrono::{DateTime, Duration, Utc};
9use mas_storage::Clock;
10use rand::{Rng, RngCore, distributions::Standard, prelude::Distribution as _};
11use serde::{Deserialize, Serialize};
12use serde_with::{TimestampSeconds, serde_as};
13use thiserror::Error;
14
15use crate::cookies::{CookieDecodeError, CookieJar};
16
17#[derive(Debug, Error)]
19pub enum CsrfError {
20 #[error("CSRF token mismatch")]
22 Mismatch,
23
24 #[error("Missing CSRF cookie")]
26 Missing,
27
28 #[error("could not decode CSRF cookie")]
30 DecodeCookie(#[from] CookieDecodeError),
31
32 #[error("CSRF token expired")]
34 Expired,
35
36 #[error("could not decode CSRF token")]
38 Decode(#[from] base64ct::Error),
39}
40
41#[serde_as]
43#[derive(Serialize, Deserialize, Debug)]
44pub struct CsrfToken {
45 #[serde_as(as = "TimestampSeconds<i64>")]
46 expiration: DateTime<Utc>,
47 token: [u8; 32],
48}
49
50impl CsrfToken {
51 fn new(token: [u8; 32], now: DateTime<Utc>, ttl: Duration) -> Self {
53 let expiration = now + ttl;
54 Self { expiration, token }
55 }
56
57 fn generate(now: DateTime<Utc>, mut rng: impl Rng, ttl: Duration) -> Self {
59 let token = Standard.sample(&mut rng);
60 Self::new(token, now, ttl)
61 }
62
63 fn refresh(self, now: DateTime<Utc>, ttl: Duration) -> Self {
65 Self::new(self.token, now, ttl)
66 }
67
68 #[must_use]
70 pub fn form_value(&self) -> String {
71 Base64UrlUnpadded::encode_string(&self.token[..])
72 }
73
74 pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
80 let form_value = Base64UrlUnpadded::decode_vec(form_value)?;
81 if self.token[..] == form_value {
82 Ok(())
83 } else {
84 Err(CsrfError::Mismatch)
85 }
86 }
87
88 fn verify_expiration(self, now: DateTime<Utc>) -> Result<Self, CsrfError> {
89 if now < self.expiration {
90 Ok(self)
91 } else {
92 Err(CsrfError::Expired)
93 }
94 }
95}
96
97#[derive(Deserialize)]
99pub struct ProtectedForm<T> {
100 csrf: String,
101
102 #[serde(flatten)]
103 inner: T,
104}
105
106pub trait CsrfExt {
107 fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
110 where
111 R: RngCore,
112 C: Clock;
113
114 fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
122 where
123 C: Clock;
124}
125
126impl CsrfExt for CookieJar {
127 fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
128 where
129 R: RngCore,
130 C: Clock,
131 {
132 let now = clock.now();
133 let maybe_token = match self.load::<CsrfToken>("csrf") {
134 Ok(Some(token)) => {
135 let token = token.verify_expiration(now);
136
137 token.ok()
139 }
140 Ok(None) => None,
141 Err(e) => {
142 tracing::warn!("Failed to decode CSRF cookie: {}", e);
143 None
144 }
145 };
146
147 let token = maybe_token.map_or_else(
148 || CsrfToken::generate(now, rng, Duration::try_hours(1).unwrap()),
149 |token| token.refresh(now, Duration::try_hours(1).unwrap()),
150 );
151
152 let jar = self.save("csrf", &token, false);
153 (token, jar)
154 }
155
156 fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
157 where
158 C: Clock,
159 {
160 let token: CsrfToken = self.load("csrf")?.ok_or(CsrfError::Missing)?;
161 let token = token.verify_expiration(clock.now())?;
162 token.verify_form_value(&form.csrf)?;
163 Ok(form.inner)
164 }
165}