mas_storage/upstream_oauth2/link.rs
1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
9use rand_core::RngCore;
10use ulid::Ulid;
11
12use crate::{Clock, Pagination, pagination::Page, repository_impl};
13
14/// Filter parameters for listing upstream OAuth links
15#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
16pub struct UpstreamOAuthLinkFilter<'a> {
17 // XXX: we might also want to filter for links without a user linked to them
18 user: Option<&'a User>,
19 provider: Option<&'a UpstreamOAuthProvider>,
20 provider_enabled: Option<bool>,
21 subject: Option<&'a str>,
22}
23
24impl<'a> UpstreamOAuthLinkFilter<'a> {
25 /// Create a new [`UpstreamOAuthLinkFilter`] with default values
26 #[must_use]
27 pub fn new() -> Self {
28 Self::default()
29 }
30
31 /// Set the user who owns the upstream OAuth links
32 #[must_use]
33 pub fn for_user(mut self, user: &'a User) -> Self {
34 self.user = Some(user);
35 self
36 }
37
38 /// Get the user filter
39 ///
40 /// Returns [`None`] if no filter was set
41 #[must_use]
42 pub fn user(&self) -> Option<&User> {
43 self.user
44 }
45
46 /// Set the upstream OAuth provider for which to list links
47 #[must_use]
48 pub fn for_provider(mut self, provider: &'a UpstreamOAuthProvider) -> Self {
49 self.provider = Some(provider);
50 self
51 }
52
53 /// Get the upstream OAuth provider filter
54 ///
55 /// Returns [`None`] if no filter was set
56 #[must_use]
57 pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
58 self.provider
59 }
60
61 /// Set whether to filter for enabled providers
62 #[must_use]
63 pub const fn enabled_providers_only(mut self) -> Self {
64 self.provider_enabled = Some(true);
65 self
66 }
67
68 /// Set whether to filter for disabled providers
69 #[must_use]
70 pub const fn disabled_providers_only(mut self) -> Self {
71 self.provider_enabled = Some(false);
72 self
73 }
74
75 /// Get the provider enabled filter
76 #[must_use]
77 pub const fn provider_enabled(&self) -> Option<bool> {
78 self.provider_enabled
79 }
80
81 /// Set the subject filter
82 #[must_use]
83 pub const fn for_subject(mut self, subject: &'a str) -> Self {
84 self.subject = Some(subject);
85 self
86 }
87
88 /// Get the subject filter
89 #[must_use]
90 pub const fn subject(&self) -> Option<&str> {
91 self.subject
92 }
93}
94
95/// An [`UpstreamOAuthLinkRepository`] helps interacting with
96/// [`UpstreamOAuthLink`] with the storage backend
97#[async_trait]
98pub trait UpstreamOAuthLinkRepository: Send + Sync {
99 /// The error type returned by the repository
100 type Error;
101
102 /// Lookup an upstream OAuth link by its ID
103 ///
104 /// Returns `None` if the link does not exist
105 ///
106 /// # Parameters
107 ///
108 /// * `id`: The ID of the upstream OAuth link to lookup
109 ///
110 /// # Errors
111 ///
112 /// Returns [`Self::Error`] if the underlying repository fails
113 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error>;
114
115 /// Find an upstream OAuth link for a provider by its subject
116 ///
117 /// Returns `None` if no matching upstream OAuth link was found
118 ///
119 /// # Parameters
120 ///
121 /// * `upstream_oauth_provider`: The upstream OAuth provider on which to
122 /// find the link
123 /// * `subject`: The subject of the upstream OAuth link to find
124 ///
125 /// # Errors
126 ///
127 /// Returns [`Self::Error`] if the underlying repository fails
128 async fn find_by_subject(
129 &mut self,
130 upstream_oauth_provider: &UpstreamOAuthProvider,
131 subject: &str,
132 ) -> Result<Option<UpstreamOAuthLink>, Self::Error>;
133
134 /// Add a new upstream OAuth link
135 ///
136 /// Returns the newly created upstream OAuth link
137 ///
138 /// # Parameters
139 ///
140 /// * `rng`: The random number generator to use
141 /// * `clock`: The clock used to generate timestamps
142 /// * `upsream_oauth_provider`: The upstream OAuth provider for which to
143 /// create the link
144 /// * `subject`: The subject of the upstream OAuth link to create
145 /// * `human_account_name`: A human-readable name for the upstream account
146 ///
147 /// # Errors
148 ///
149 /// Returns [`Self::Error`] if the underlying repository fails
150 async fn add(
151 &mut self,
152 rng: &mut (dyn RngCore + Send),
153 clock: &dyn Clock,
154 upstream_oauth_provider: &UpstreamOAuthProvider,
155 subject: String,
156 human_account_name: Option<String>,
157 ) -> Result<UpstreamOAuthLink, Self::Error>;
158
159 /// Associate an upstream OAuth link to a user
160 ///
161 /// Returns the updated upstream OAuth link
162 ///
163 /// # Parameters
164 ///
165 /// * `upstream_oauth_link`: The upstream OAuth link to update
166 /// * `user`: The user to associate to the upstream OAuth link
167 ///
168 /// # Errors
169 ///
170 /// Returns [`Self::Error`] if the underlying repository fails
171 async fn associate_to_user(
172 &mut self,
173 upstream_oauth_link: &UpstreamOAuthLink,
174 user: &User,
175 ) -> Result<(), Self::Error>;
176
177 /// List [`UpstreamOAuthLink`] with the given filter and pagination
178 ///
179 /// # Parameters
180 ///
181 /// * `filter`: The filter to apply
182 /// * `pagination`: The pagination parameters
183 ///
184 /// # Errors
185 ///
186 /// Returns [`Self::Error`] if the underlying repository fails
187 async fn list(
188 &mut self,
189 filter: UpstreamOAuthLinkFilter<'_>,
190 pagination: Pagination,
191 ) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
192
193 /// Count the number of [`UpstreamOAuthLink`] with the given filter
194 ///
195 /// # Parameters
196 ///
197 /// * `filter`: The filter to apply
198 ///
199 /// # Errors
200 ///
201 /// Returns [`Self::Error`] if the underlying repository fails
202 async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error>;
203}
204
// Hand every `UpstreamOAuthLinkRepository` method signature to the crate's
// `repository_impl!` macro. NOTE(review): presumably the macro generates
// forwarding impls of the trait (e.g. for wrapper/pointer types) — confirm
// against the macro definition.
//
// The signatures below are copies of the trait declaration above. Keep them in
// sync whenever a method is added, removed or changed there.
//
// Only plain `//` comments may be used inside the invocation: `///` doc
// comments would be passed to the macro as `#[doc]` attribute tokens.
repository_impl!(UpstreamOAuthLinkRepository:
    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error>;

    async fn find_by_subject(
        &mut self,
        upstream_oauth_provider: &UpstreamOAuthProvider,
        subject: &str,
    ) -> Result<Option<UpstreamOAuthLink>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        upstream_oauth_provider: &UpstreamOAuthProvider,
        subject: String,
        human_account_name: Option<String>,
    ) -> Result<UpstreamOAuthLink, Self::Error>;

    async fn associate_to_user(
        &mut self,
        upstream_oauth_link: &UpstreamOAuthLink,
        user: &User,
    ) -> Result<(), Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthLinkFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthLink>, Self::Error>;

    async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error>;
);