mas_axum_utils/
cookies.rs1use std::convert::Infallible;
10
11use axum::{
12 extract::{FromRef, FromRequestParts},
13 response::{IntoResponseParts, ResponseParts},
14};
15use axum_extra::extract::cookie::{Cookie, Key, PrivateCookieJar, SameSite};
16use http::request::Parts;
17use serde::{Serialize, de::DeserializeOwned};
18use thiserror::Error;
19use url::Url;
20
21#[derive(Debug, Error)]
22#[error("could not decode cookie")]
23pub enum CookieDecodeError {
24 Deserialize(#[from] serde_json::Error),
25}
26
27#[derive(Clone)]
32pub struct CookieManager {
33 options: CookieOption,
34 key: Key,
35}
36
37impl CookieManager {
38 #[must_use]
39 pub const fn new(base_url: Url, key: Key) -> Self {
40 let options = CookieOption::new(base_url);
41 Self { options, key }
42 }
43
44 #[must_use]
45 pub fn derive_from(base_url: Url, key: &[u8]) -> Self {
46 let key = Key::derive_from(key);
47 Self::new(base_url, key)
48 }
49
50 #[must_use]
51 pub fn cookie_jar(&self) -> CookieJar {
52 let inner = PrivateCookieJar::new(self.key.clone());
53 let options = self.options.clone();
54
55 CookieJar { inner, options }
56 }
57
58 #[must_use]
59 pub fn cookie_jar_from_headers(&self, headers: &http::HeaderMap) -> CookieJar {
60 let inner = PrivateCookieJar::from_headers(headers, self.key.clone());
61 let options = self.options.clone();
62
63 CookieJar { inner, options }
64 }
65}
66
67impl<S> FromRequestParts<S> for CookieJar
68where
69 CookieManager: FromRef<S>,
70 S: Send + Sync,
71{
72 type Rejection = Infallible;
73
74 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
75 let cookie_manager = CookieManager::from_ref(state);
76 Ok(cookie_manager.cookie_jar_from_headers(&parts.headers))
77 }
78}
79
80#[derive(Debug, Clone)]
81struct CookieOption {
82 base_url: Url,
83}
84
85impl CookieOption {
86 const fn new(base_url: Url) -> Self {
87 Self { base_url }
88 }
89
90 fn secure(&self) -> bool {
91 self.base_url.scheme() == "https"
92 }
93
94 fn path(&self) -> &str {
95 self.base_url.path()
96 }
97
98 fn apply<'a>(&self, mut cookie: Cookie<'a>) -> Cookie<'a> {
99 cookie.set_http_only(true);
100 cookie.set_secure(self.secure());
101 cookie.set_path(self.path().to_owned());
102 cookie.set_same_site(SameSite::Lax);
103 cookie
104 }
105}
106
107pub struct CookieJar {
109 inner: PrivateCookieJar<Key>,
110 options: CookieOption,
111}
112
113impl CookieJar {
114 #[must_use]
122 pub fn save<T: Serialize>(mut self, key: &str, payload: &T, permanent: bool) -> Self {
123 let serialized =
124 serde_json::to_string(payload).expect("failed to serialize cookie payload");
125
126 let cookie = Cookie::new(key.to_owned(), serialized);
127 let mut cookie = self.options.apply(cookie);
128
129 if permanent {
130 cookie.make_permanent();
132 }
133
134 self.inner = self.inner.add(cookie);
135
136 self
137 }
138
139 #[must_use]
141 pub fn remove(mut self, key: &str) -> Self {
142 self.inner = self.inner.remove(key.to_owned());
143 self
144 }
145
146 pub fn load<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, CookieDecodeError> {
154 let Some(cookie) = self.inner.get(key) else {
155 return Ok(None);
156 };
157
158 let decoded = serde_json::from_str(cookie.value())?;
159 Ok(Some(decoded))
160 }
161}
162
163impl IntoResponseParts for CookieJar {
164 type Error = Infallible;
165
166 fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
167 self.inner.into_response_parts(res)
168 }
169}