mas_storage/oauth2/client.rs
1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::collections::{BTreeMap, BTreeSet};
8
9use async_trait::async_trait;
10use mas_data_model::{Client, User};
11use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
12use mas_jose::jwk::PublicJsonWebKeySet;
13use oauth2_types::{oidc::ApplicationType, requests::GrantType, scope::Scope};
14use rand_core::RngCore;
15use ulid::Ulid;
16use url::Url;
17
18use crate::{Clock, repository_impl};
19
20/// An [`OAuth2ClientRepository`] helps interacting with [`Client`] saved in the
21/// storage backend
22#[async_trait]
23pub trait OAuth2ClientRepository: Send + Sync {
24 /// The error type returned by the repository
25 type Error;
26
27 /// Lookup an OAuth2 client by its ID
28 ///
29 /// Returns `None` if the client does not exist
30 ///
31 /// # Parameters
32 ///
33 /// * `id`: The ID of the client to lookup
34 ///
35 /// # Errors
36 ///
37 /// Returns [`Self::Error`] if the underlying repository fails
38 async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error>;
39
40 /// Find an OAuth2 client by its client ID
41 async fn find_by_client_id(&mut self, client_id: &str) -> Result<Option<Client>, Self::Error> {
42 let Ok(id) = client_id.parse() else {
43 return Ok(None);
44 };
45 self.lookup(id).await
46 }
47
48 /// Find an OAuth2 client by its metadata digest
49 ///
50 /// Returns `None` if the client does not exist
51 ///
52 /// # Parameters
53 ///
54 /// * `digest`: The metadata digest (SHA-256 hash encoded in hex) of the
55 /// client to find
56 ///
57 /// # Errors
58 ///
59 /// Returns [`Self::Error`] if the underlying repository fails
60 async fn find_by_metadata_digest(
61 &mut self,
62 digest: &str,
63 ) -> Result<Option<Client>, Self::Error>;
64
65 /// Load a batch of OAuth2 clients by their IDs
66 ///
67 /// Returns a map of client IDs to clients. If a client does not exist, it
68 /// is not present in the map.
69 ///
70 /// # Parameters
71 ///
72 /// * `ids`: The IDs of the clients to load
73 ///
74 /// # Errors
75 ///
76 /// Returns [`Self::Error`] if the underlying repository fails
77 async fn load_batch(
78 &mut self,
79 ids: BTreeSet<Ulid>,
80 ) -> Result<BTreeMap<Ulid, Client>, Self::Error>;
81
82 /// Add a new OAuth2 client
83 ///
84 /// Returns the client that was added
85 ///
86 /// # Parameters
87 ///
88 /// * `rng`: The random number generator to use
89 /// * `clock`: The clock used to generate timestamps
90 /// * `redirect_uris`: The list of redirect URIs used by this client
91 /// * `metadata_digest`: The hash of the client metadata, if computed
92 /// * `encrypted_client_secret`: The encrypted client secret, if any
93 /// * `application_type`: The application type of this client
94 /// * `grant_types`: The list of grant types this client can use
95 /// * `client_name`: The human-readable name of this client, if given
96 /// * `logo_uri`: The URI of the logo of this client, if given
97 /// * `client_uri`: The URI of a website of this client, if given
98 /// * `policy_uri`: The URI of the privacy policy of this client, if given
99 /// * `tos_uri`: The URI of the terms of service of this client, if given
100 /// * `jwks_uri`: The URI of the JWKS of this client, if given
101 /// * `jwks`: The JWKS of this client, if given
102 /// * `id_token_signed_response_alg`: The algorithm used to sign the ID
103 /// token
104 /// * `userinfo_signed_response_alg`: The algorithm used to sign the user
105 /// info. If none, the user info endpoint will not sign the response
106 /// * `token_endpoint_auth_method`: The authentication method used by this
107 /// client when calling the token endpoint
108 /// * `token_endpoint_auth_signing_alg`: The algorithm used to sign the JWT
109 /// when using the `client_secret_jwt` or `private_key_jwt` authentication
110 /// methods
111 /// * `initiate_login_uri`: The URI used to initiate a login, if given
112 ///
113 /// # Errors
114 ///
115 /// Returns [`Self::Error`] if the underlying repository fails
116 #[allow(clippy::too_many_arguments)]
117 async fn add(
118 &mut self,
119 rng: &mut (dyn RngCore + Send),
120 clock: &dyn Clock,
121 redirect_uris: Vec<Url>,
122 metadata_digest: Option<String>,
123 encrypted_client_secret: Option<String>,
124 application_type: Option<ApplicationType>,
125 grant_types: Vec<GrantType>,
126 client_name: Option<String>,
127 logo_uri: Option<Url>,
128 client_uri: Option<Url>,
129 policy_uri: Option<Url>,
130 tos_uri: Option<Url>,
131 jwks_uri: Option<Url>,
132 jwks: Option<PublicJsonWebKeySet>,
133 id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
134 userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
135 token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
136 token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
137 initiate_login_uri: Option<Url>,
138 ) -> Result<Client, Self::Error>;
139
140 /// Add or replace a static client
141 ///
142 /// Returns the client that was added or replaced
143 ///
144 /// # Parameters
145 ///
146 /// * `client_id`: The client ID
147 /// * `client_auth_method`: The authentication method this client uses
148 /// * `encrypted_client_secret`: The encrypted client secret, if any
149 /// * `jwks`: The client JWKS, if any
150 /// * `jwks_uri`: The client JWKS URI, if any
151 /// * `redirect_uris`: The list of redirect URIs used by this client
152 ///
153 /// # Errors
154 ///
155 /// Returns [`Self::Error`] if the underlying repository fails
156 #[allow(clippy::too_many_arguments)]
157 async fn upsert_static(
158 &mut self,
159 client_id: Ulid,
160 client_auth_method: OAuthClientAuthenticationMethod,
161 encrypted_client_secret: Option<String>,
162 jwks: Option<PublicJsonWebKeySet>,
163 jwks_uri: Option<Url>,
164 redirect_uris: Vec<Url>,
165 ) -> Result<Client, Self::Error>;
166
167 /// List all static clients
168 ///
169 /// # Errors
170 ///
171 /// Returns [`Self::Error`] if the underlying repository fails
172 async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error>;
173
174 /// Get the list of scopes that the user has given consent for the given
175 /// client
176 ///
177 /// # Parameters
178 ///
179 /// * `client`: The client to get the consent for
180 /// * `user`: The user to get the consent for
181 ///
182 /// # Errors
183 ///
184 /// Returns [`Self::Error`] if the underlying repository fails
185 async fn get_consent_for_user(
186 &mut self,
187 client: &Client,
188 user: &User,
189 ) -> Result<Scope, Self::Error>;
190
191 /// Give consent for a set of scopes for the given client and user
192 ///
193 /// # Parameters
194 ///
195 /// * `rng`: The random number generator to use
196 /// * `clock`: The clock used to generate timestamps
197 /// * `client`: The client to give the consent for
198 /// * `user`: The user to give the consent for
199 /// * `scope`: The scope to give consent for
200 ///
201 /// # Errors
202 ///
203 /// Returns [`Self::Error`] if the underlying repository fails
204 async fn give_consent_for_user(
205 &mut self,
206 rng: &mut (dyn RngCore + Send),
207 clock: &dyn Clock,
208 client: &Client,
209 user: &User,
210 scope: &Scope,
211 ) -> Result<(), Self::Error>;
212
213 /// Delete a client
214 ///
215 /// # Parameters
216 ///
217 /// * `client`: The client to delete
218 ///
219 /// # Errors
220 ///
221 /// Returns [`Self::Error`] if the underlying repository fails, or if the
222 /// client does not exist
223 async fn delete(&mut self, client: Client) -> Result<(), Self::Error> {
224 self.delete_by_id(client.id).await
225 }
226
227 /// Delete a client by ID
228 ///
229 /// # Parameters
230 ///
231 /// * `id`: The ID of the client to delete
232 ///
233 /// # Errors
234 ///
235 /// Returns [`Self::Error`] if the underlying repository fails, or if the
236 /// client does not exist
237 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
238}
239
240repository_impl!(OAuth2ClientRepository:
241 async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error>;
242
243 async fn find_by_metadata_digest(
244 &mut self,
245 digest: &str,
246 ) -> Result<Option<Client>, Self::Error>;
247
248 async fn load_batch(
249 &mut self,
250 ids: BTreeSet<Ulid>,
251 ) -> Result<BTreeMap<Ulid, Client>, Self::Error>;
252
253 async fn add(
254 &mut self,
255 rng: &mut (dyn RngCore + Send),
256 clock: &dyn Clock,
257 redirect_uris: Vec<Url>,
258 metadata_digest: Option<String>,
259 encrypted_client_secret: Option<String>,
260 application_type: Option<ApplicationType>,
261 grant_types: Vec<GrantType>,
262 client_name: Option<String>,
263 logo_uri: Option<Url>,
264 client_uri: Option<Url>,
265 policy_uri: Option<Url>,
266 tos_uri: Option<Url>,
267 jwks_uri: Option<Url>,
268 jwks: Option<PublicJsonWebKeySet>,
269 id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
270 userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
271 token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
272 token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
273 initiate_login_uri: Option<Url>,
274 ) -> Result<Client, Self::Error>;
275
276 async fn upsert_static(
277 &mut self,
278 client_id: Ulid,
279 client_auth_method: OAuthClientAuthenticationMethod,
280 encrypted_client_secret: Option<String>,
281 jwks: Option<PublicJsonWebKeySet>,
282 jwks_uri: Option<Url>,
283 redirect_uris: Vec<Url>,
284 ) -> Result<Client, Self::Error>;
285
286 async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error>;
287
288 async fn delete(&mut self, client: Client) -> Result<(), Self::Error>;
289
290 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>;
291
292 async fn get_consent_for_user(
293 &mut self,
294 client: &Client,
295 user: &User,
296 ) -> Result<Scope, Self::Error>;
297
298 async fn give_consent_for_user(
299 &mut self,
300 rng: &mut (dyn RngCore + Send),
301 clock: &dyn Clock,
302 client: &Client,
303 user: &User,
304 scope: &Scope,
305 ) -> Result<(), Self::Error>;
306);