mas_handlers/
rate_limit.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{net::IpAddr, sync::Arc, time::Duration};
8
9use governor::{RateLimiter, clock::QuantaClock, state::keyed::DashMapStateStore};
10use mas_config::RateLimitingConfig;
11use mas_data_model::{User, UserEmailAuthentication};
12use ulid::Ulid;
13
14#[derive(Debug, Clone, thiserror::Error)]
15pub enum AccountRecoveryLimitedError {
16    #[error("Too many account recovery requests for requester {0}")]
17    Requester(RequesterFingerprint),
18
19    #[error("Too many account recovery requests for e-mail {0}")]
20    Email(String),
21}
22
23#[derive(Debug, Clone, Copy, thiserror::Error)]
24pub enum PasswordCheckLimitedError {
25    #[error("Too many password checks for requester {0}")]
26    Requester(RequesterFingerprint),
27
28    #[error("Too many password checks for user {0}")]
29    User(Ulid),
30}
31
32#[derive(Debug, Clone, thiserror::Error)]
33pub enum RegistrationLimitedError {
34    #[error("Too many account registration requests for requester {0}")]
35    Requester(RequesterFingerprint),
36}
37
38#[derive(Debug, Clone, thiserror::Error)]
39pub enum EmailAuthenticationLimitedError {
40    #[error("Too many email authentication requests for requester {0}")]
41    Requester(RequesterFingerprint),
42
43    #[error("Too many email authentication requests for authentication session {0}")]
44    Authentication(Ulid),
45
46    #[error("Too many email authentication requests for email {0}")]
47    Email(String),
48}
49
/// Key used to rate limit requests per requester
///
/// The requester is currently identified only by its client IP address, which
/// may be unknown.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RequesterFingerprint {
    /// Client IP address, or `None` when it could not be determined
    ip: Option<IpAddr>,
}

impl RequesterFingerprint {
    /// An anonymous key with no IP address set. This should not be used in
    /// production, and we should warn users if we can't find their client IPs.
    pub const EMPTY: Self = Self { ip: None };

    /// Create a new anonymous key with the given IP address
    #[must_use]
    pub const fn new(ip: IpAddr) -> Self {
        let ip = Some(ip);
        Self { ip }
    }
}

impl std::fmt::Display for RequesterFingerprint {
    /// Writes the client IP address, or `(NO CLIENT IP)` if there is none.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.ip {
            Some(ip) => write!(f, "{ip}"),
            None => write!(f, "(NO CLIENT IP)"),
        }
    }
}
77
78/// Rate limiters for the different operations
79#[derive(Debug, Clone)]
80pub struct Limiter {
81    inner: Arc<LimiterInner>,
82}
83
84type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;
85
86#[derive(Debug)]
87struct LimiterInner {
88    account_recovery_per_requester: KeyedRateLimiter<RequesterFingerprint>,
89    account_recovery_per_email: KeyedRateLimiter<String>,
90    password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
91    password_check_for_user: KeyedRateLimiter<Ulid>,
92    registration_per_requester: KeyedRateLimiter<RequesterFingerprint>,
93    email_authentication_per_requester: KeyedRateLimiter<RequesterFingerprint>,
94    email_authentication_per_email: KeyedRateLimiter<String>,
95    email_authentication_emails_per_session: KeyedRateLimiter<Ulid>,
96    email_authentication_attempt_per_session: KeyedRateLimiter<Ulid>,
97}
98
99impl LimiterInner {
100    fn new(config: &RateLimitingConfig) -> Option<Self> {
101        Some(Self {
102            account_recovery_per_requester: RateLimiter::keyed(
103                config.account_recovery.per_ip.to_quota()?,
104            ),
105            account_recovery_per_email: RateLimiter::keyed(
106                config.account_recovery.per_address.to_quota()?,
107            ),
108            password_check_for_requester: RateLimiter::keyed(config.login.per_ip.to_quota()?),
109            password_check_for_user: RateLimiter::keyed(config.login.per_account.to_quota()?),
110            registration_per_requester: RateLimiter::keyed(config.registration.to_quota()?),
111            email_authentication_per_email: RateLimiter::keyed(
112                config.email_authentication.per_address.to_quota()?,
113            ),
114            email_authentication_per_requester: RateLimiter::keyed(
115                config.email_authentication.per_ip.to_quota()?,
116            ),
117            email_authentication_emails_per_session: RateLimiter::keyed(
118                config.email_authentication.emails_per_session.to_quota()?,
119            ),
120            email_authentication_attempt_per_session: RateLimiter::keyed(
121                config.email_authentication.attempt_per_session.to_quota()?,
122            ),
123        })
124    }
125}
126
127impl Limiter {
128    /// Creates a new `Limiter` based on a `RateLimitingConfig`.
129    ///
130    /// If the config is not valid, returns `None`.
131    /// (This should not happen if the config was validated, though.)
132    #[must_use]
133    pub fn new(config: &RateLimitingConfig) -> Option<Self> {
134        Some(Self {
135            inner: Arc::new(LimiterInner::new(config)?),
136        })
137    }
138
139    /// Start the rate limiter housekeeping task
140    ///
141    /// This task will periodically remove old entries from the rate limiters,
142    /// to make sure we don't build up a huge number of entries in memory.
143    pub fn start(&self) {
144        // Spawn a task that will periodically clean the rate limiters
145        let this = self.clone();
146        tokio::spawn(async move {
147            // Run the task every minute
148            let mut interval = tokio::time::interval(Duration::from_secs(60));
149            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
150
151            loop {
152                // Call the retain_recent method on each rate limiter
153                this.inner.account_recovery_per_email.retain_recent();
154                this.inner.account_recovery_per_requester.retain_recent();
155                this.inner.password_check_for_requester.retain_recent();
156                this.inner.password_check_for_user.retain_recent();
157                this.inner.registration_per_requester.retain_recent();
158                this.inner.email_authentication_per_email.retain_recent();
159                this.inner
160                    .email_authentication_per_requester
161                    .retain_recent();
162                this.inner
163                    .email_authentication_emails_per_session
164                    .retain_recent();
165                this.inner
166                    .email_authentication_attempt_per_session
167                    .retain_recent();
168
169                interval.tick().await;
170            }
171        });
172    }
173
174    /// Check if an account recovery can be performed
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if the operation is rate limited.
179    pub fn check_account_recovery(
180        &self,
181        requester: RequesterFingerprint,
182        email_address: &str,
183    ) -> Result<(), AccountRecoveryLimitedError> {
184        self.inner
185            .account_recovery_per_requester
186            .check_key(&requester)
187            .map_err(|_| AccountRecoveryLimitedError::Requester(requester))?;
188
189        // Convert to lowercase to prevent bypassing the limit by enumerating different
190        // case variations.
191        // A case-folding transformation may be more proper.
192        let canonical_email = email_address.to_lowercase();
193        self.inner
194            .account_recovery_per_email
195            .check_key(&canonical_email)
196            .map_err(|_| AccountRecoveryLimitedError::Email(canonical_email))?;
197
198        Ok(())
199    }
200
201    /// Check if a password check can be performed
202    ///
203    /// # Errors
204    ///
205    /// Returns an error if the operation is rate limited
206    pub fn check_password(
207        &self,
208        key: RequesterFingerprint,
209        user: &User,
210    ) -> Result<(), PasswordCheckLimitedError> {
211        self.inner
212            .password_check_for_requester
213            .check_key(&key)
214            .map_err(|_| PasswordCheckLimitedError::Requester(key))?;
215
216        self.inner
217            .password_check_for_user
218            .check_key(&user.id)
219            .map_err(|_| PasswordCheckLimitedError::User(user.id))?;
220
221        Ok(())
222    }
223
224    /// Check if an account registration can be performed
225    ///
226    /// # Errors
227    ///
228    /// Returns an error if the operation is rate limited.
229    pub fn check_registration(
230        &self,
231        requester: RequesterFingerprint,
232    ) -> Result<(), RegistrationLimitedError> {
233        self.inner
234            .registration_per_requester
235            .check_key(&requester)
236            .map_err(|_| RegistrationLimitedError::Requester(requester))?;
237
238        Ok(())
239    }
240
241    /// Check if an email can be sent to the address for an email
242    /// authentication session
243    ///
244    /// # Errors
245    ///
246    /// Returns an error if the operation is rate limited.
247    pub fn check_email_authentication_email(
248        &self,
249        requester: RequesterFingerprint,
250        email: &str,
251    ) -> Result<(), EmailAuthenticationLimitedError> {
252        self.inner
253            .email_authentication_per_requester
254            .check_key(&requester)
255            .map_err(|_| EmailAuthenticationLimitedError::Requester(requester))?;
256
257        // Convert to lowercase to prevent bypassing the limit by enumerating different
258        // case variations.
259        // A case-folding transformation may be more proper.
260        let canonical_email = email.to_lowercase();
261        self.inner
262            .email_authentication_per_email
263            .check_key(&canonical_email)
264            .map_err(|_| EmailAuthenticationLimitedError::Email(email.to_owned()))?;
265        Ok(())
266    }
267
268    /// Check if an attempt can be done on an email authentication session
269    ///
270    /// # Errors
271    ///
272    /// Returns an error if the operation is rate limited.
273    pub fn check_email_authentication_attempt(
274        &self,
275        authentication: &UserEmailAuthentication,
276    ) -> Result<(), EmailAuthenticationLimitedError> {
277        self.inner
278            .email_authentication_attempt_per_session
279            .check_key(&authentication.id)
280            .map_err(|_| EmailAuthenticationLimitedError::Authentication(authentication.id))
281    }
282
283    /// Check if a new authentication code can be sent for an email
284    /// authentication session
285    ///
286    /// # Errors
287    ///
288    /// Returns an error if the operation is rate limited.
289    pub fn check_email_authentication_send_code(
290        &self,
291        requester: RequesterFingerprint,
292        authentication: &UserEmailAuthentication,
293    ) -> Result<(), EmailAuthenticationLimitedError> {
294        self.check_email_authentication_email(requester, &authentication.email)?;
295        self.inner
296            .email_authentication_emails_per_session
297            .check_key(&authentication.id)
298            .map_err(|_| EmailAuthenticationLimitedError::Authentication(authentication.id))
299    }
300}
301
#[cfg(test)]
mod tests {
    use mas_data_model::User;
    use mas_storage::{Clock, clock::MockClock};
    use rand::SeedableRng;

    use super::*;

    /// Exercises both the per-requester and the per-account password check
    /// limits with the default rate limiting configuration.
    #[test]
    fn test_password_check_limiter() {
        let now = MockClock::default().now();
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);

        let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();

        // Build 256 * 3 distinct requesters so that the account-level limit can
        // be reached without every attempt being stopped by the per-IP limit
        let requesters: Vec<RequesterFingerprint> = (0..=255)
            .flat_map(|a| (0..3).map(move |b| RequesterFingerprint::new([a, a, b, b].into())))
            .collect();
        assert_eq!(requesters.len(), 768);

        // Users only differ by username; IDs come from the seeded RNG, in order
        let mut make_user = |username: &str| User {
            id: Ulid::from_datetime_with_source(now.into(), &mut rng),
            username: username.to_owned(),
            sub: "123-456".to_owned(),
            created_at: now,
            locked_at: None,
            deactivated_at: None,
            can_request_admin: false,
        };
        let alice = make_user("alice");
        let bob = make_user("bob");

        let allowed = |requester: RequesterFingerprint, user: &User| {
            limiter.check_password(requester, user).is_ok()
        };

        // Three times the same IP address should be allowed
        for _ in 0..3 {
            assert!(allowed(requesters[0], &alice));
        }

        // But the fourth time should be rejected
        assert!(!allowed(requesters[0], &alice));
        // Using another user should also be rejected
        assert!(!allowed(requesters[0], &bob));

        // Using a different IP address should be allowed, the account isn't locked yet
        assert!(allowed(requesters[1], &alice));

        // Only 4 cells out of 1800 were consumed on alice so far: the rejected
        // attempts were stopped by the per-IP limit before reaching the
        // account limiter. Spread more attempts over other IPs so that we get
        // rate-limited on the account level.
        for &requester in &requesters[2..600] {
            for _ in 0..3 {
                assert!(allowed(requester, &alice));
            }
            assert!(!allowed(requester, &alice));
        }

        // We now have consumed 4+598*3 = 1798 cells on the account, so we should be
        // rejected soon
        assert!(allowed(requesters[600], &alice));
        assert!(allowed(requesters[601], &alice));
        assert!(!allowed(requesters[602], &alice));

        // The other account isn't rate-limited
        assert!(allowed(requesters[603], &bob));
    }
}
375}