mas_storage/
repository.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use futures_util::future::BoxFuture;
8use thiserror::Error;
9
10use crate::{
11    app_session::AppSessionRepository,
12    compat::{
13        CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
14        CompatSsoLoginRepository,
15    },
16    oauth2::{
17        OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
18        OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
19    },
20    policy_data::PolicyDataRepository,
21    queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
22    upstream_oauth2::{
23        UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
24        UpstreamOAuthSessionRepository,
25    },
26    user::{
27        BrowserSessionRepository, UserEmailRepository, UserPasswordRepository,
28        UserRecoveryRepository, UserRegistrationRepository, UserRepository, UserTermsRepository,
29    },
30};
31
32/// A [`Repository`] helps interacting with the underlying storage backend.
33pub trait Repository<E>:
34    RepositoryAccess<Error = E> + RepositoryTransaction<Error = E> + Send
35where
36    E: std::error::Error + Send + Sync + 'static,
37{
38}
39
40/// An opaque, type-erased error
41#[derive(Debug, Error)]
42#[error(transparent)]
43pub struct RepositoryError {
44    source: Box<dyn std::error::Error + Send + Sync + 'static>,
45}
46
47impl RepositoryError {
48    /// Construct a [`RepositoryError`] from any error kind
49    pub fn from_error<E>(value: E) -> Self
50    where
51        E: std::error::Error + Send + Sync + 'static,
52    {
53        Self {
54            source: Box::new(value),
55        }
56    }
57}
58
59/// A type-erased [`Repository`]
60pub type BoxRepository = Box<dyn Repository<RepositoryError> + Send + Sync + 'static>;
61
62/// A [`RepositoryTransaction`] can be saved or cancelled, after a series
63/// of operations.
64pub trait RepositoryTransaction {
65    /// The error type used by the [`Self::save`] and [`Self::cancel`] functions
66    type Error;
67
68    /// Commit the transaction
69    ///
70    /// # Errors
71    ///
72    /// Returns an error if the underlying storage backend failed to commit the
73    /// transaction.
74    fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
75
76    /// Rollback the transaction
77    ///
78    /// # Errors
79    ///
80    /// Returns an error if the underlying storage backend failed to rollback
81    /// the transaction.
82    fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>>;
83}
84
85/// Access the various repositories the backend implements.
86///
87/// All the methods return a boxed trait object, which can be used to access a
88/// particular repository. The lifetime of the returned object is bound to the
89/// lifetime of the whole repository, so that only one mutable reference to the
90/// repository is used at a time.
91///
92/// When adding a new repository, you should add a new method to this trait, and
93/// update the implementations for [`crate::MapErr`] and [`Box<R>`] below.
94///
95/// Note: this used to have generic associated types to avoid boxing all the
96/// repository traits, but that was removed because it made almost impossible to
97/// box the trait object. This might be a shortcoming of the initial
98/// implementation of generic associated types, and might be fixed in the
99/// future.
100pub trait RepositoryAccess: Send {
101    /// The backend-specific error type used by each repository.
102    type Error: std::error::Error + Send + Sync + 'static;
103
104    /// Get an [`UpstreamOAuthLinkRepository`]
105    fn upstream_oauth_link<'c>(
106        &'c mut self,
107    ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c>;
108
109    /// Get an [`UpstreamOAuthProviderRepository`]
110    fn upstream_oauth_provider<'c>(
111        &'c mut self,
112    ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c>;
113
114    /// Get an [`UpstreamOAuthSessionRepository`]
115    fn upstream_oauth_session<'c>(
116        &'c mut self,
117    ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c>;
118
119    /// Get an [`UserRepository`]
120    fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c>;
121
122    /// Get an [`UserEmailRepository`]
123    fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c>;
124
125    /// Get an [`UserPasswordRepository`]
126    fn user_password<'c>(&'c mut self)
127    -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c>;
128
129    /// Get an [`UserRecoveryRepository`]
130    fn user_recovery<'c>(&'c mut self)
131    -> Box<dyn UserRecoveryRepository<Error = Self::Error> + 'c>;
132
133    /// Get an [`UserRegistrationRepository`]
134    fn user_registration<'c>(
135        &'c mut self,
136    ) -> Box<dyn UserRegistrationRepository<Error = Self::Error> + 'c>;
137
138    /// Get an [`UserTermsRepository`]
139    fn user_terms<'c>(&'c mut self) -> Box<dyn UserTermsRepository<Error = Self::Error> + 'c>;
140
141    /// Get a [`BrowserSessionRepository`]
142    fn browser_session<'c>(
143        &'c mut self,
144    ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c>;
145
146    /// Get a [`AppSessionRepository`]
147    fn app_session<'c>(&'c mut self) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c>;
148
149    /// Get an [`OAuth2ClientRepository`]
150    fn oauth2_client<'c>(&'c mut self)
151    -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c>;
152
153    /// Get an [`OAuth2AuthorizationGrantRepository`]
154    fn oauth2_authorization_grant<'c>(
155        &'c mut self,
156    ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c>;
157
158    /// Get an [`OAuth2SessionRepository`]
159    fn oauth2_session<'c>(
160        &'c mut self,
161    ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c>;
162
163    /// Get an [`OAuth2AccessTokenRepository`]
164    fn oauth2_access_token<'c>(
165        &'c mut self,
166    ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c>;
167
168    /// Get an [`OAuth2RefreshTokenRepository`]
169    fn oauth2_refresh_token<'c>(
170        &'c mut self,
171    ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c>;
172
173    /// Get an [`OAuth2DeviceCodeGrantRepository`]
174    fn oauth2_device_code_grant<'c>(
175        &'c mut self,
176    ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c>;
177
178    /// Get a [`CompatSessionRepository`]
179    fn compat_session<'c>(
180        &'c mut self,
181    ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c>;
182
183    /// Get a [`CompatSsoLoginRepository`]
184    fn compat_sso_login<'c>(
185        &'c mut self,
186    ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c>;
187
188    /// Get a [`CompatAccessTokenRepository`]
189    fn compat_access_token<'c>(
190        &'c mut self,
191    ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c>;
192
193    /// Get a [`CompatRefreshTokenRepository`]
194    fn compat_refresh_token<'c>(
195        &'c mut self,
196    ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c>;
197
198    /// Get a [`QueueWorkerRepository`]
199    fn queue_worker<'c>(&'c mut self) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c>;
200
201    /// Get a [`QueueJobRepository`]
202    fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c>;
203
204    /// Get a [`QueueScheduleRepository`]
205    fn queue_schedule<'c>(
206        &'c mut self,
207    ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c>;
208
209    /// Get a [`PolicyDataRepository`]
210    fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c>;
211}
212
213/// Implementations of the [`RepositoryAccess`], [`RepositoryTransaction`] and
214/// [`Repository`] for the [`crate::MapErr`] wrapper and [`Box<R>`]
215mod impls {
216    use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
217
218    use super::RepositoryAccess;
219    use crate::{
220        MapErr, Repository, RepositoryTransaction,
221        app_session::AppSessionRepository,
222        compat::{
223            CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
224            CompatSsoLoginRepository,
225        },
226        oauth2::{
227            OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository,
228            OAuth2ClientRepository, OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository,
229            OAuth2SessionRepository,
230        },
231        policy_data::PolicyDataRepository,
232        queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
233        upstream_oauth2::{
234            UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
235            UpstreamOAuthSessionRepository,
236        },
237        user::{
238            BrowserSessionRepository, UserEmailRepository, UserPasswordRepository,
239            UserRegistrationRepository, UserRepository, UserTermsRepository,
240        },
241    };
242
243    // --- Repository ---
244    impl<R, F, E1, E2> Repository<E2> for MapErr<R, F>
245    where
246        R: Repository<E1> + RepositoryAccess<Error = E1> + RepositoryTransaction<Error = E1>,
247        F: FnMut(E1) -> E2 + Send + Sync + 'static,
248        E1: std::error::Error + Send + Sync + 'static,
249        E2: std::error::Error + Send + Sync + 'static,
250    {
251    }
252
253    // --- RepositoryTransaction --
254    impl<R, F, E> RepositoryTransaction for MapErr<R, F>
255    where
256        R: RepositoryTransaction,
257        R::Error: 'static,
258        F: FnMut(R::Error) -> E + Send + Sync + 'static,
259        E: std::error::Error,
260    {
261        type Error = E;
262
263        fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
264            Box::new(self.inner).save().map_err(self.mapper).boxed()
265        }
266
267        fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
268            Box::new(self.inner).cancel().map_err(self.mapper).boxed()
269        }
270    }
271
272    // --- RepositoryAccess --
273    impl<R, F, E> RepositoryAccess for MapErr<R, F>
274    where
275        R: RepositoryAccess,
276        R::Error: 'static,
277        F: FnMut(R::Error) -> E + Send + Sync + 'static,
278        E: std::error::Error + Send + Sync + 'static,
279    {
280        type Error = E;
281
282        fn upstream_oauth_link<'c>(
283            &'c mut self,
284        ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
285            Box::new(MapErr::new(
286                self.inner.upstream_oauth_link(),
287                &mut self.mapper,
288            ))
289        }
290
291        fn upstream_oauth_provider<'c>(
292            &'c mut self,
293        ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
294            Box::new(MapErr::new(
295                self.inner.upstream_oauth_provider(),
296                &mut self.mapper,
297            ))
298        }
299
300        fn upstream_oauth_session<'c>(
301            &'c mut self,
302        ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
303            Box::new(MapErr::new(
304                self.inner.upstream_oauth_session(),
305                &mut self.mapper,
306            ))
307        }
308
309        fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
310            Box::new(MapErr::new(self.inner.user(), &mut self.mapper))
311        }
312
313        fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
314            Box::new(MapErr::new(self.inner.user_email(), &mut self.mapper))
315        }
316
317        fn user_password<'c>(
318            &'c mut self,
319        ) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
320            Box::new(MapErr::new(self.inner.user_password(), &mut self.mapper))
321        }
322
323        fn user_recovery<'c>(
324            &'c mut self,
325        ) -> Box<dyn crate::user::UserRecoveryRepository<Error = Self::Error> + 'c> {
326            Box::new(MapErr::new(self.inner.user_recovery(), &mut self.mapper))
327        }
328
329        fn user_registration<'c>(
330            &'c mut self,
331        ) -> Box<dyn UserRegistrationRepository<Error = Self::Error> + 'c> {
332            Box::new(MapErr::new(
333                self.inner.user_registration(),
334                &mut self.mapper,
335            ))
336        }
337
338        fn user_terms<'c>(&'c mut self) -> Box<dyn UserTermsRepository<Error = Self::Error> + 'c> {
339            Box::new(MapErr::new(self.inner.user_terms(), &mut self.mapper))
340        }
341
342        fn browser_session<'c>(
343            &'c mut self,
344        ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
345            Box::new(MapErr::new(self.inner.browser_session(), &mut self.mapper))
346        }
347
348        fn app_session<'c>(
349            &'c mut self,
350        ) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
351            Box::new(MapErr::new(self.inner.app_session(), &mut self.mapper))
352        }
353
354        fn oauth2_client<'c>(
355            &'c mut self,
356        ) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
357            Box::new(MapErr::new(self.inner.oauth2_client(), &mut self.mapper))
358        }
359
360        fn oauth2_authorization_grant<'c>(
361            &'c mut self,
362        ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
363            Box::new(MapErr::new(
364                self.inner.oauth2_authorization_grant(),
365                &mut self.mapper,
366            ))
367        }
368
369        fn oauth2_session<'c>(
370            &'c mut self,
371        ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
372            Box::new(MapErr::new(self.inner.oauth2_session(), &mut self.mapper))
373        }
374
375        fn oauth2_access_token<'c>(
376            &'c mut self,
377        ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
378            Box::new(MapErr::new(
379                self.inner.oauth2_access_token(),
380                &mut self.mapper,
381            ))
382        }
383
384        fn oauth2_refresh_token<'c>(
385            &'c mut self,
386        ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
387            Box::new(MapErr::new(
388                self.inner.oauth2_refresh_token(),
389                &mut self.mapper,
390            ))
391        }
392
393        fn oauth2_device_code_grant<'c>(
394            &'c mut self,
395        ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
396            Box::new(MapErr::new(
397                self.inner.oauth2_device_code_grant(),
398                &mut self.mapper,
399            ))
400        }
401
402        fn compat_session<'c>(
403            &'c mut self,
404        ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
405            Box::new(MapErr::new(self.inner.compat_session(), &mut self.mapper))
406        }
407
408        fn compat_sso_login<'c>(
409            &'c mut self,
410        ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
411            Box::new(MapErr::new(self.inner.compat_sso_login(), &mut self.mapper))
412        }
413
414        fn compat_access_token<'c>(
415            &'c mut self,
416        ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
417            Box::new(MapErr::new(
418                self.inner.compat_access_token(),
419                &mut self.mapper,
420            ))
421        }
422
423        fn compat_refresh_token<'c>(
424            &'c mut self,
425        ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
426            Box::new(MapErr::new(
427                self.inner.compat_refresh_token(),
428                &mut self.mapper,
429            ))
430        }
431
432        fn queue_worker<'c>(
433            &'c mut self,
434        ) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
435            Box::new(MapErr::new(self.inner.queue_worker(), &mut self.mapper))
436        }
437
438        fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
439            Box::new(MapErr::new(self.inner.queue_job(), &mut self.mapper))
440        }
441
442        fn queue_schedule<'c>(
443            &'c mut self,
444        ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
445            Box::new(MapErr::new(self.inner.queue_schedule(), &mut self.mapper))
446        }
447
448        fn policy_data<'c>(
449            &'c mut self,
450        ) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
451            Box::new(MapErr::new(self.inner.policy_data(), &mut self.mapper))
452        }
453    }
454
455    impl<R: RepositoryAccess + ?Sized> RepositoryAccess for Box<R> {
456        type Error = R::Error;
457
458        fn upstream_oauth_link<'c>(
459            &'c mut self,
460        ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
461            (**self).upstream_oauth_link()
462        }
463
464        fn upstream_oauth_provider<'c>(
465            &'c mut self,
466        ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
467            (**self).upstream_oauth_provider()
468        }
469
470        fn upstream_oauth_session<'c>(
471            &'c mut self,
472        ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
473            (**self).upstream_oauth_session()
474        }
475
476        fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
477            (**self).user()
478        }
479
480        fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
481            (**self).user_email()
482        }
483
484        fn user_password<'c>(
485            &'c mut self,
486        ) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
487            (**self).user_password()
488        }
489
490        fn user_recovery<'c>(
491            &'c mut self,
492        ) -> Box<dyn crate::user::UserRecoveryRepository<Error = Self::Error> + 'c> {
493            (**self).user_recovery()
494        }
495
496        fn user_registration<'c>(
497            &'c mut self,
498        ) -> Box<dyn UserRegistrationRepository<Error = Self::Error> + 'c> {
499            (**self).user_registration()
500        }
501
502        fn user_terms<'c>(&'c mut self) -> Box<dyn UserTermsRepository<Error = Self::Error> + 'c> {
503            (**self).user_terms()
504        }
505
506        fn browser_session<'c>(
507            &'c mut self,
508        ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
509            (**self).browser_session()
510        }
511
512        fn app_session<'c>(
513            &'c mut self,
514        ) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
515            (**self).app_session()
516        }
517
518        fn oauth2_client<'c>(
519            &'c mut self,
520        ) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
521            (**self).oauth2_client()
522        }
523
524        fn oauth2_authorization_grant<'c>(
525            &'c mut self,
526        ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
527            (**self).oauth2_authorization_grant()
528        }
529
530        fn oauth2_session<'c>(
531            &'c mut self,
532        ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
533            (**self).oauth2_session()
534        }
535
536        fn oauth2_access_token<'c>(
537            &'c mut self,
538        ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
539            (**self).oauth2_access_token()
540        }
541
542        fn oauth2_refresh_token<'c>(
543            &'c mut self,
544        ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
545            (**self).oauth2_refresh_token()
546        }
547
548        fn oauth2_device_code_grant<'c>(
549            &'c mut self,
550        ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
551            (**self).oauth2_device_code_grant()
552        }
553
554        fn compat_session<'c>(
555            &'c mut self,
556        ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
557            (**self).compat_session()
558        }
559
560        fn compat_sso_login<'c>(
561            &'c mut self,
562        ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
563            (**self).compat_sso_login()
564        }
565
566        fn compat_access_token<'c>(
567            &'c mut self,
568        ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
569            (**self).compat_access_token()
570        }
571
572        fn compat_refresh_token<'c>(
573            &'c mut self,
574        ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
575            (**self).compat_refresh_token()
576        }
577
578        fn queue_worker<'c>(
579            &'c mut self,
580        ) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
581            (**self).queue_worker()
582        }
583
584        fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
585            (**self).queue_job()
586        }
587
588        fn queue_schedule<'c>(
589            &'c mut self,
590        ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
591            (**self).queue_schedule()
592        }
593
594        fn policy_data<'c>(
595            &'c mut self,
596        ) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
597            (**self).policy_data()
598        }
599    }
600}