1use std::{net::IpAddr, sync::Arc, time::Duration};
8
9use governor::{RateLimiter, clock::QuantaClock, state::keyed::DashMapStateStore};
10use mas_config::RateLimitingConfig;
11use mas_data_model::{User, UserEmailAuthentication};
12use ulid::Ulid;
13
14#[derive(Debug, Clone, thiserror::Error)]
15pub enum AccountRecoveryLimitedError {
16 #[error("Too many account recovery requests for requester {0}")]
17 Requester(RequesterFingerprint),
18
19 #[error("Too many account recovery requests for e-mail {0}")]
20 Email(String),
21}
22
23#[derive(Debug, Clone, Copy, thiserror::Error)]
24pub enum PasswordCheckLimitedError {
25 #[error("Too many password checks for requester {0}")]
26 Requester(RequesterFingerprint),
27
28 #[error("Too many password checks for user {0}")]
29 User(Ulid),
30}
31
32#[derive(Debug, Clone, thiserror::Error)]
33pub enum RegistrationLimitedError {
34 #[error("Too many account registration requests for requester {0}")]
35 Requester(RequesterFingerprint),
36}
37
38#[derive(Debug, Clone, thiserror::Error)]
39pub enum EmailAuthenticationLimitedError {
40 #[error("Too many email authentication requests for requester {0}")]
41 Requester(RequesterFingerprint),
42
43 #[error("Too many email authentication requests for authentication session {0}")]
44 Authentication(Ulid),
45
46 #[error("Too many email authentication requests for email {0}")]
47 Email(String),
48}
49
/// Identifies who is making a request, for the purpose of rate limiting.
///
/// At the moment the only component is the client IP address, which may be
/// unknown.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RequesterFingerprint {
    /// Client IP address, if it is known.
    ip: Option<IpAddr>,
}
55
56impl std::fmt::Display for RequesterFingerprint {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 if let Some(ip) = self.ip {
59 write!(f, "{ip}")
60 } else {
61 write!(f, "(NO CLIENT IP)")
62 }
63 }
64}
65
66impl RequesterFingerprint {
67 pub const EMPTY: Self = Self { ip: None };
70
71 #[must_use]
73 pub const fn new(ip: IpAddr) -> Self {
74 Self { ip: Some(ip) }
75 }
76}
77
78#[derive(Debug, Clone)]
80pub struct Limiter {
81 inner: Arc<LimiterInner>,
82}
83
84type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;
85
86#[derive(Debug)]
87struct LimiterInner {
88 account_recovery_per_requester: KeyedRateLimiter<RequesterFingerprint>,
89 account_recovery_per_email: KeyedRateLimiter<String>,
90 password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
91 password_check_for_user: KeyedRateLimiter<Ulid>,
92 registration_per_requester: KeyedRateLimiter<RequesterFingerprint>,
93 email_authentication_per_requester: KeyedRateLimiter<RequesterFingerprint>,
94 email_authentication_per_email: KeyedRateLimiter<String>,
95 email_authentication_emails_per_session: KeyedRateLimiter<Ulid>,
96 email_authentication_attempt_per_session: KeyedRateLimiter<Ulid>,
97}
98
99impl LimiterInner {
100 fn new(config: &RateLimitingConfig) -> Option<Self> {
101 Some(Self {
102 account_recovery_per_requester: RateLimiter::keyed(
103 config.account_recovery.per_ip.to_quota()?,
104 ),
105 account_recovery_per_email: RateLimiter::keyed(
106 config.account_recovery.per_address.to_quota()?,
107 ),
108 password_check_for_requester: RateLimiter::keyed(config.login.per_ip.to_quota()?),
109 password_check_for_user: RateLimiter::keyed(config.login.per_account.to_quota()?),
110 registration_per_requester: RateLimiter::keyed(config.registration.to_quota()?),
111 email_authentication_per_email: RateLimiter::keyed(
112 config.email_authentication.per_address.to_quota()?,
113 ),
114 email_authentication_per_requester: RateLimiter::keyed(
115 config.email_authentication.per_ip.to_quota()?,
116 ),
117 email_authentication_emails_per_session: RateLimiter::keyed(
118 config.email_authentication.emails_per_session.to_quota()?,
119 ),
120 email_authentication_attempt_per_session: RateLimiter::keyed(
121 config.email_authentication.attempt_per_session.to_quota()?,
122 ),
123 })
124 }
125}
126
127impl Limiter {
128 #[must_use]
133 pub fn new(config: &RateLimitingConfig) -> Option<Self> {
134 Some(Self {
135 inner: Arc::new(LimiterInner::new(config)?),
136 })
137 }
138
139 pub fn start(&self) {
144 let this = self.clone();
146 tokio::spawn(async move {
147 let mut interval = tokio::time::interval(Duration::from_secs(60));
149 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
150
151 loop {
152 this.inner.account_recovery_per_email.retain_recent();
154 this.inner.account_recovery_per_requester.retain_recent();
155 this.inner.password_check_for_requester.retain_recent();
156 this.inner.password_check_for_user.retain_recent();
157 this.inner.registration_per_requester.retain_recent();
158 this.inner.email_authentication_per_email.retain_recent();
159 this.inner
160 .email_authentication_per_requester
161 .retain_recent();
162 this.inner
163 .email_authentication_emails_per_session
164 .retain_recent();
165 this.inner
166 .email_authentication_attempt_per_session
167 .retain_recent();
168
169 interval.tick().await;
170 }
171 });
172 }
173
174 pub fn check_account_recovery(
180 &self,
181 requester: RequesterFingerprint,
182 email_address: &str,
183 ) -> Result<(), AccountRecoveryLimitedError> {
184 self.inner
185 .account_recovery_per_requester
186 .check_key(&requester)
187 .map_err(|_| AccountRecoveryLimitedError::Requester(requester))?;
188
189 let canonical_email = email_address.to_lowercase();
193 self.inner
194 .account_recovery_per_email
195 .check_key(&canonical_email)
196 .map_err(|_| AccountRecoveryLimitedError::Email(canonical_email))?;
197
198 Ok(())
199 }
200
201 pub fn check_password(
207 &self,
208 key: RequesterFingerprint,
209 user: &User,
210 ) -> Result<(), PasswordCheckLimitedError> {
211 self.inner
212 .password_check_for_requester
213 .check_key(&key)
214 .map_err(|_| PasswordCheckLimitedError::Requester(key))?;
215
216 self.inner
217 .password_check_for_user
218 .check_key(&user.id)
219 .map_err(|_| PasswordCheckLimitedError::User(user.id))?;
220
221 Ok(())
222 }
223
224 pub fn check_registration(
230 &self,
231 requester: RequesterFingerprint,
232 ) -> Result<(), RegistrationLimitedError> {
233 self.inner
234 .registration_per_requester
235 .check_key(&requester)
236 .map_err(|_| RegistrationLimitedError::Requester(requester))?;
237
238 Ok(())
239 }
240
241 pub fn check_email_authentication_email(
248 &self,
249 requester: RequesterFingerprint,
250 email: &str,
251 ) -> Result<(), EmailAuthenticationLimitedError> {
252 self.inner
253 .email_authentication_per_requester
254 .check_key(&requester)
255 .map_err(|_| EmailAuthenticationLimitedError::Requester(requester))?;
256
257 let canonical_email = email.to_lowercase();
261 self.inner
262 .email_authentication_per_email
263 .check_key(&canonical_email)
264 .map_err(|_| EmailAuthenticationLimitedError::Email(email.to_owned()))?;
265 Ok(())
266 }
267
268 pub fn check_email_authentication_attempt(
274 &self,
275 authentication: &UserEmailAuthentication,
276 ) -> Result<(), EmailAuthenticationLimitedError> {
277 self.inner
278 .email_authentication_attempt_per_session
279 .check_key(&authentication.id)
280 .map_err(|_| EmailAuthenticationLimitedError::Authentication(authentication.id))
281 }
282
283 pub fn check_email_authentication_send_code(
290 &self,
291 requester: RequesterFingerprint,
292 authentication: &UserEmailAuthentication,
293 ) -> Result<(), EmailAuthenticationLimitedError> {
294 self.check_email_authentication_email(requester, &authentication.email)?;
295 self.inner
296 .email_authentication_emails_per_session
297 .check_key(&authentication.id)
298 .map_err(|_| EmailAuthenticationLimitedError::Authentication(authentication.id))
299 }
300}
301
#[cfg(test)]
mod tests {
    use mas_data_model::User;
    use mas_storage::{Clock, clock::MockClock};
    use rand::SeedableRng;

    use super::*;

    /// Exercises the per-requester and per-user password check limits with
    /// the default configuration.
    ///
    /// The expected outcomes encode a burst of 3 checks per requester and of
    /// 1800 checks per user. This assumes the test runs fast enough that no
    /// quota is replenished while it executes.
    #[test]
    fn test_password_check_limiter() {
        let now = MockClock::default().now();
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);

        let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();

        // 768 distinct requesters (256 x 3 IPv4 addresses): enough to
        // exhaust a single user's quota from many different addresses.
        let requesters: [_; 768] = (0..=255)
            .flat_map(|a| (0..3).map(move |b| RequesterFingerprint::new([a, a, b, b].into())))
            .collect::<Vec<_>>()
            .try_into()
            .unwrap();

        let alice = User {
            id: Ulid::from_datetime_with_source(now.into(), &mut rng),
            username: "alice".to_owned(),
            sub: "123-456".to_owned(),
            created_at: now,
            locked_at: None,
            deactivated_at: None,
            can_request_admin: false,
        };

        let bob = User {
            id: Ulid::from_datetime_with_source(now.into(), &mut rng),
            username: "bob".to_owned(),
            sub: "123-457".to_owned(),
            created_at: now,
            locked_at: None,
            deactivated_at: None,
            can_request_admin: false,
        };

        // The first requester can check alice's password 3 times...
        assert!(limiter.check_password(requesters[0], &alice).is_ok());
        assert!(limiter.check_password(requesters[0], &alice).is_ok());
        assert!(limiter.check_password(requesters[0], &alice).is_ok());

        // ...and is then limited, whichever user it targets.
        assert!(limiter.check_password(requesters[0], &alice).is_err());
        assert!(limiter.check_password(requesters[0], &bob).is_err());

        // Another requester is not affected by the first one's limit.
        assert!(limiter.check_password(requesters[1], &alice).is_ok());

        // 598 more requesters each get 3 checks against alice before being
        // limited. Rejected checks do not count against alice, because the
        // requester check fails first.
        for requester in requesters.iter().skip(2).take(598) {
            assert!(limiter.check_password(*requester, &alice).is_ok());
            assert!(limiter.check_password(*requester, &alice).is_ok());
            assert!(limiter.check_password(*requester, &alice).is_ok());
            assert!(limiter.check_password(*requester, &alice).is_err());
        }

        // alice has now been checked 3 + 1 + 598 * 3 = 1798 times. Two more
        // checks go through, then the per-user limit kicks in, even for a
        // fresh requester.
        assert!(limiter.check_password(requesters[600], &alice).is_ok());
        assert!(limiter.check_password(requesters[601], &alice).is_ok());
        assert!(limiter.check_password(requesters[602], &alice).is_err());

        // The per-user limit on alice does not affect bob.
        assert!(limiter.check_password(requesters[603], &bob).is_ok());
    }
}