mas_axum_utils/
cookies.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7//! Private (encrypted) cookie jar, based on axum-extra's cookie jar
8
9use std::convert::Infallible;
10
11use axum::{
12    extract::{FromRef, FromRequestParts},
13    response::{IntoResponseParts, ResponseParts},
14};
15use axum_extra::extract::cookie::{Cookie, Key, PrivateCookieJar, SameSite};
16use http::request::Parts;
17use serde::{Serialize, de::DeserializeOwned};
18use thiserror::Error;
19use url::Url;
20
21#[derive(Debug, Error)]
22#[error("could not decode cookie")]
23pub enum CookieDecodeError {
24    Deserialize(#[from] serde_json::Error),
25}
26
27/// Manages cookie options and encryption key
28///
29/// This is meant to be accessible through axum's state via the [`FromRef`]
30/// trait
31#[derive(Clone)]
32pub struct CookieManager {
33    options: CookieOption,
34    key: Key,
35}
36
37impl CookieManager {
38    #[must_use]
39    pub const fn new(base_url: Url, key: Key) -> Self {
40        let options = CookieOption::new(base_url);
41        Self { options, key }
42    }
43
44    #[must_use]
45    pub fn derive_from(base_url: Url, key: &[u8]) -> Self {
46        let key = Key::derive_from(key);
47        Self::new(base_url, key)
48    }
49
50    #[must_use]
51    pub fn cookie_jar(&self) -> CookieJar {
52        let inner = PrivateCookieJar::new(self.key.clone());
53        let options = self.options.clone();
54
55        CookieJar { inner, options }
56    }
57
58    #[must_use]
59    pub fn cookie_jar_from_headers(&self, headers: &http::HeaderMap) -> CookieJar {
60        let inner = PrivateCookieJar::from_headers(headers, self.key.clone());
61        let options = self.options.clone();
62
63        CookieJar { inner, options }
64    }
65}
66
67impl<S> FromRequestParts<S> for CookieJar
68where
69    CookieManager: FromRef<S>,
70    S: Send + Sync,
71{
72    type Rejection = Infallible;
73
74    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
75        let cookie_manager = CookieManager::from_ref(state);
76        Ok(cookie_manager.cookie_jar_from_headers(&parts.headers))
77    }
78}
79
80#[derive(Debug, Clone)]
81struct CookieOption {
82    base_url: Url,
83}
84
85impl CookieOption {
86    const fn new(base_url: Url) -> Self {
87        Self { base_url }
88    }
89
90    fn secure(&self) -> bool {
91        self.base_url.scheme() == "https"
92    }
93
94    fn path(&self) -> &str {
95        self.base_url.path()
96    }
97
98    fn apply<'a>(&self, mut cookie: Cookie<'a>) -> Cookie<'a> {
99        cookie.set_http_only(true);
100        cookie.set_secure(self.secure());
101        cookie.set_path(self.path().to_owned());
102        cookie.set_same_site(SameSite::Lax);
103        cookie
104    }
105}
106
107/// A cookie jar which encrypts cookies & sets secure options
108pub struct CookieJar {
109    inner: PrivateCookieJar<Key>,
110    options: CookieOption,
111}
112
113impl CookieJar {
114    /// Save the given payload in a cookie
115    ///
116    /// If `permanent` is true, the cookie will be valid for 10 years
117    ///
118    /// # Panics
119    ///
120    /// Panics if the payload cannot be serialized
121    #[must_use]
122    pub fn save<T: Serialize>(mut self, key: &str, payload: &T, permanent: bool) -> Self {
123        let serialized =
124            serde_json::to_string(payload).expect("failed to serialize cookie payload");
125
126        let cookie = Cookie::new(key.to_owned(), serialized);
127        let mut cookie = self.options.apply(cookie);
128
129        if permanent {
130            // XXX: this should use a clock
131            cookie.make_permanent();
132        }
133
134        self.inner = self.inner.add(cookie);
135
136        self
137    }
138
139    /// Remove a cookie from the jar
140    #[must_use]
141    pub fn remove(mut self, key: &str) -> Self {
142        self.inner = self.inner.remove(key.to_owned());
143        self
144    }
145
146    /// Load and deserialize a cookie from the jar
147    ///
148    /// Returns `None` if the cookie is not present
149    ///
150    /// # Errors
151    ///
152    /// Returns an error if the cookie cannot be deserialized
153    pub fn load<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, CookieDecodeError> {
154        let Some(cookie) = self.inner.get(key) else {
155            return Ok(None);
156        };
157
158        let decoded = serde_json::from_str(cookie.value())?;
159        Ok(Some(decoded))
160    }
161}
162
163impl IntoResponseParts for CookieJar {
164    type Error = Infallible;
165
166    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
167        self.inner.into_response_parts(res)
168    }
169}