mas_storage/upstream_oauth2/
link.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
9use rand_core::RngCore;
10use ulid::Ulid;
11
12use crate::{Clock, Pagination, pagination::Page, repository_impl};
13
14/// Filter parameters for listing upstream OAuth links
15#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
16pub struct UpstreamOAuthLinkFilter<'a> {
17    // XXX: we might also want to filter for links without a user linked to them
18    user: Option<&'a User>,
19    provider: Option<&'a UpstreamOAuthProvider>,
20    provider_enabled: Option<bool>,
21    subject: Option<&'a str>,
22}
23
24impl<'a> UpstreamOAuthLinkFilter<'a> {
25    /// Create a new [`UpstreamOAuthLinkFilter`] with default values
26    #[must_use]
27    pub fn new() -> Self {
28        Self::default()
29    }
30
31    /// Set the user who owns the upstream OAuth links
32    #[must_use]
33    pub fn for_user(mut self, user: &'a User) -> Self {
34        self.user = Some(user);
35        self
36    }
37
38    /// Get the user filter
39    ///
40    /// Returns [`None`] if no filter was set
41    #[must_use]
42    pub fn user(&self) -> Option<&User> {
43        self.user
44    }
45
46    /// Set the upstream OAuth provider for which to list links
47    #[must_use]
48    pub fn for_provider(mut self, provider: &'a UpstreamOAuthProvider) -> Self {
49        self.provider = Some(provider);
50        self
51    }
52
53    /// Get the upstream OAuth provider filter
54    ///
55    /// Returns [`None`] if no filter was set
56    #[must_use]
57    pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
58        self.provider
59    }
60
61    /// Set whether to filter for enabled providers
62    #[must_use]
63    pub const fn enabled_providers_only(mut self) -> Self {
64        self.provider_enabled = Some(true);
65        self
66    }
67
68    /// Set whether to filter for disabled providers
69    #[must_use]
70    pub const fn disabled_providers_only(mut self) -> Self {
71        self.provider_enabled = Some(false);
72        self
73    }
74
75    /// Get the provider enabled filter
76    #[must_use]
77    pub const fn provider_enabled(&self) -> Option<bool> {
78        self.provider_enabled
79    }
80
81    /// Set the subject filter
82    #[must_use]
83    pub const fn for_subject(mut self, subject: &'a str) -> Self {
84        self.subject = Some(subject);
85        self
86    }
87
88    /// Get the subject filter
89    #[must_use]
90    pub const fn subject(&self) -> Option<&str> {
91        self.subject
92    }
93}
94
95/// An [`UpstreamOAuthLinkRepository`] helps interacting with
96/// [`UpstreamOAuthLink`] with the storage backend
97#[async_trait]
98pub trait UpstreamOAuthLinkRepository: Send + Sync {
99    /// The error type returned by the repository
100    type Error;
101
102    /// Lookup an upstream OAuth link by its ID
103    ///
104    /// Returns `None` if the link does not exist
105    ///
106    /// # Parameters
107    ///
108    /// * `id`: The ID of the upstream OAuth link to lookup
109    ///
110    /// # Errors
111    ///
112    /// Returns [`Self::Error`] if the underlying repository fails
113    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error>;
114
115    /// Find an upstream OAuth link for a provider by its subject
116    ///
117    /// Returns `None` if no matching upstream OAuth link was found
118    ///
119    /// # Parameters
120    ///
121    /// * `upstream_oauth_provider`: The upstream OAuth provider on which to
122    ///   find the link
123    /// * `subject`: The subject of the upstream OAuth link to find
124    ///
125    /// # Errors
126    ///
127    /// Returns [`Self::Error`] if the underlying repository fails
128    async fn find_by_subject(
129        &mut self,
130        upstream_oauth_provider: &UpstreamOAuthProvider,
131        subject: &str,
132    ) -> Result<Option<UpstreamOAuthLink>, Self::Error>;
133
134    /// Add a new upstream OAuth link
135    ///
136    /// Returns the newly created upstream OAuth link
137    ///
138    /// # Parameters
139    ///
140    /// * `rng`: The random number generator to use
141    /// * `clock`: The clock used to generate timestamps
142    /// * `upsream_oauth_provider`: The upstream OAuth provider for which to
143    ///   create the link
144    /// * `subject`: The subject of the upstream OAuth link to create
145    /// * `human_account_name`: A human-readable name for the upstream account
146    ///
147    /// # Errors
148    ///
149    /// Returns [`Self::Error`] if the underlying repository fails
150    async fn add(
151        &mut self,
152        rng: &mut (dyn RngCore + Send),
153        clock: &dyn Clock,
154        upstream_oauth_provider: &UpstreamOAuthProvider,
155        subject: String,
156        human_account_name: Option<String>,
157    ) -> Result<UpstreamOAuthLink, Self::Error>;
158
159    /// Associate an upstream OAuth link to a user
160    ///
161    /// Returns the updated upstream OAuth link
162    ///
163    /// # Parameters
164    ///
165    /// * `upstream_oauth_link`: The upstream OAuth link to update
166    /// * `user`: The user to associate to the upstream OAuth link
167    ///
168    /// # Errors
169    ///
170    /// Returns [`Self::Error`] if the underlying repository fails
171    async fn associate_to_user(
172        &mut self,
173        upstream_oauth_link: &UpstreamOAuthLink,
174        user: &User,
175    ) -> Result<(), Self::Error>;
176
177    /// List [`UpstreamOAuthLink`] with the given filter and pagination
178    ///
179    /// # Parameters
180    ///
181    /// * `filter`: The filter to apply
182    /// * `pagination`: The pagination parameters
183    ///
184    /// # Errors
185    ///
186    /// Returns [`Self::Error`] if the underlying repository fails
187    async fn list(
188        &mut self,
189        filter: UpstreamOAuthLinkFilter<'_>,
190        pagination: Pagination,
191    ) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
192
193    /// Count the number of [`UpstreamOAuthLink`] with the given filter
194    ///
195    /// # Parameters
196    ///
197    /// * `filter`: The filter to apply
198    ///
199    /// # Errors
200    ///
201    /// Returns [`Self::Error`] if the underlying repository fails
202    async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error>;
203}
204
205repository_impl!(UpstreamOAuthLinkRepository:
206    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error>;
207
208    async fn find_by_subject(
209        &mut self,
210        upstream_oauth_provider: &UpstreamOAuthProvider,
211        subject: &str,
212    ) -> Result<Option<UpstreamOAuthLink>, Self::Error>;
213
214    async fn add(
215        &mut self,
216        rng: &mut (dyn RngCore + Send),
217        clock: &dyn Clock,
218        upstream_oauth_provider: &UpstreamOAuthProvider,
219        subject: String,
220        human_account_name: Option<String>,
221    ) -> Result<UpstreamOAuthLink, Self::Error>;
222
223    async fn associate_to_user(
224        &mut self,
225        upstream_oauth_link: &UpstreamOAuthLink,
226        user: &User,
227    ) -> Result<(), Self::Error>;
228
229    async fn list(
230        &mut self,
231        filter: UpstreamOAuthLinkFilter<'_>,
232        pagination: Pagination,
233    ) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
234
235    async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error>;
236);