mas_config/sections/
upstream_oauth2.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::collections::BTreeMap;
8
9use mas_iana::jose::JsonWebSignatureAlg;
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize, de::Error};
12use serde_with::skip_serializing_none;
13use ulid::Ulid;
14use url::Url;
15
16use crate::ConfigurationSection;
17
18/// Upstream OAuth 2.0 providers configuration
19#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
20pub struct UpstreamOAuth2Config {
21    /// List of OAuth 2.0 providers
22    pub providers: Vec<Provider>,
23}
24
25impl UpstreamOAuth2Config {
26    /// Returns true if the configuration is the default one
27    pub(crate) fn is_default(&self) -> bool {
28        self.providers.is_empty()
29    }
30}
31
32impl ConfigurationSection for UpstreamOAuth2Config {
33    const PATH: Option<&'static str> = Some("upstream_oauth2");
34
35    fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
36        for (index, provider) in self.providers.iter().enumerate() {
37            let annotate = |mut error: figment::Error| {
38                error.metadata = figment
39                    .find_metadata(&format!("{root}.providers", root = Self::PATH.unwrap()))
40                    .cloned();
41                error.profile = Some(figment::Profile::Default);
42                error.path = vec![
43                    Self::PATH.unwrap().to_owned(),
44                    "providers".to_owned(),
45                    index.to_string(),
46                ];
47                Err(error)
48            };
49
50            if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
51                && provider.issuer.is_none()
52            {
53                return annotate(figment::Error::custom(
54                    "The `issuer` field is required when discovery is enabled",
55                ));
56            }
57
58            match provider.token_endpoint_auth_method {
59                TokenAuthMethod::None
60                | TokenAuthMethod::PrivateKeyJwt
61                | TokenAuthMethod::SignInWithApple => {
62                    if provider.client_secret.is_some() {
63                        return annotate(figment::Error::custom(
64                            "Unexpected field `client_secret` for the selected authentication method",
65                        ));
66                    }
67                }
68                TokenAuthMethod::ClientSecretBasic
69                | TokenAuthMethod::ClientSecretPost
70                | TokenAuthMethod::ClientSecretJwt => {
71                    if provider.client_secret.is_none() {
72                        return annotate(figment::Error::missing_field("client_secret"));
73                    }
74                }
75            }
76
77            match provider.token_endpoint_auth_method {
78                TokenAuthMethod::None
79                | TokenAuthMethod::ClientSecretBasic
80                | TokenAuthMethod::ClientSecretPost
81                | TokenAuthMethod::SignInWithApple => {
82                    if provider.token_endpoint_auth_signing_alg.is_some() {
83                        return annotate(figment::Error::custom(
84                            "Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
85                        ));
86                    }
87                }
88                TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
89                    if provider.token_endpoint_auth_signing_alg.is_none() {
90                        return annotate(figment::Error::missing_field(
91                            "token_endpoint_auth_signing_alg",
92                        ));
93                    }
94                }
95            }
96
97            match provider.token_endpoint_auth_method {
98                TokenAuthMethod::SignInWithApple => {
99                    if provider.sign_in_with_apple.is_none() {
100                        return annotate(figment::Error::missing_field("sign_in_with_apple"));
101                    }
102                }
103
104                _ => {
105                    if provider.sign_in_with_apple.is_some() {
106                        return annotate(figment::Error::custom(
107                            "Unexpected field `sign_in_with_apple` for the selected authentication method",
108                        ));
109                    }
110                }
111            }
112        }
113
114        Ok(())
115    }
116}
117
118/// The response mode we ask the provider to use for the callback
119#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
120#[serde(rename_all = "snake_case")]
121pub enum ResponseMode {
122    /// `query`: The provider will send the response as a query string in the
123    /// URL search parameters
124    Query,
125
126    /// `form_post`: The provider will send the response as a POST request with
127    /// the response parameters in the request body
128    ///
129    /// <https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html>
130    FormPost,
131}
132
133/// Authentication methods used against the OAuth 2.0 provider
134#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
135#[serde(rename_all = "snake_case")]
136pub enum TokenAuthMethod {
137    /// `none`: No authentication
138    None,
139
140    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
141    /// authorization credentials
142    ClientSecretBasic,
143
144    /// `client_secret_post`: `client_id` and `client_secret` sent in the
145    /// request body
146    ClientSecretPost,
147
148    /// `client_secret_jwt`: a `client_assertion` sent in the request body and
149    /// signed using the `client_secret`
150    ClientSecretJwt,
151
152    /// `private_key_jwt`: a `client_assertion` sent in the request body and
153    /// signed by an asymmetric key
154    PrivateKeyJwt,
155
156    /// `sign_in_with_apple`: a special method for Signin with Apple
157    SignInWithApple,
158}
159
160/// How to handle a claim
161#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
162#[serde(rename_all = "lowercase")]
163pub enum ImportAction {
164    /// Ignore the claim
165    #[default]
166    Ignore,
167
168    /// Suggest the claim value, but allow the user to change it
169    Suggest,
170
171    /// Force the claim value, but don't fail if it is missing
172    Force,
173
174    /// Force the claim value, and fail if it is missing
175    Require,
176}
177
178impl ImportAction {
179    #[allow(clippy::trivially_copy_pass_by_ref)]
180    const fn is_default(&self) -> bool {
181        matches!(self, ImportAction::Ignore)
182    }
183}
184
185/// What should be done for the subject attribute
186#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
187pub struct SubjectImportPreference {
188    /// The Jinja2 template to use for the subject attribute
189    ///
190    /// If not provided, the default template is `{{ user.sub }}`
191    #[serde(default, skip_serializing_if = "Option::is_none")]
192    pub template: Option<String>,
193}
194
195impl SubjectImportPreference {
196    const fn is_default(&self) -> bool {
197        self.template.is_none()
198    }
199}
200
201/// What should be done for the localpart attribute
202#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
203pub struct LocalpartImportPreference {
204    /// How to handle the attribute
205    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
206    pub action: ImportAction,
207
208    /// The Jinja2 template to use for the localpart attribute
209    ///
210    /// If not provided, the default template is `{{ user.preferred_username }}`
211    #[serde(default, skip_serializing_if = "Option::is_none")]
212    pub template: Option<String>,
213}
214
215impl LocalpartImportPreference {
216    const fn is_default(&self) -> bool {
217        self.action.is_default() && self.template.is_none()
218    }
219}
220
221/// What should be done for the displayname attribute
222#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
223pub struct DisplaynameImportPreference {
224    /// How to handle the attribute
225    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
226    pub action: ImportAction,
227
228    /// The Jinja2 template to use for the displayname attribute
229    ///
230    /// If not provided, the default template is `{{ user.name }}`
231    #[serde(default, skip_serializing_if = "Option::is_none")]
232    pub template: Option<String>,
233}
234
235impl DisplaynameImportPreference {
236    const fn is_default(&self) -> bool {
237        self.action.is_default() && self.template.is_none()
238    }
239}
240
241/// What should be done with the email attribute
242#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
243pub struct EmailImportPreference {
244    /// How to handle the claim
245    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
246    pub action: ImportAction,
247
248    /// The Jinja2 template to use for the email address attribute
249    ///
250    /// If not provided, the default template is `{{ user.email }}`
251    #[serde(default, skip_serializing_if = "Option::is_none")]
252    pub template: Option<String>,
253}
254
255impl EmailImportPreference {
256    const fn is_default(&self) -> bool {
257        self.action.is_default() && self.template.is_none()
258    }
259}
260
261/// What should be done for the account name attribute
262#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
263pub struct AccountNameImportPreference {
264    /// The Jinja2 template to use for the account name. This name is only used
265    /// for display purposes.
266    ///
267    /// If not provided, it will be ignored.
268    #[serde(default, skip_serializing_if = "Option::is_none")]
269    pub template: Option<String>,
270}
271
272impl AccountNameImportPreference {
273    const fn is_default(&self) -> bool {
274        self.template.is_none()
275    }
276}
277
278/// How claims should be imported
279#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
280pub struct ClaimsImports {
281    /// How to determine the subject of the user
282    #[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
283    pub subject: SubjectImportPreference,
284
285    /// Import the localpart of the MXID
286    #[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
287    pub localpart: LocalpartImportPreference,
288
289    /// Import the displayname of the user.
290    #[serde(
291        default,
292        skip_serializing_if = "DisplaynameImportPreference::is_default"
293    )]
294    pub displayname: DisplaynameImportPreference,
295
296    /// Import the email address of the user based on the `email` and
297    /// `email_verified` claims
298    #[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
299    pub email: EmailImportPreference,
300
301    /// Set a human-readable name for the upstream account for display purposes
302    #[serde(
303        default,
304        skip_serializing_if = "AccountNameImportPreference::is_default"
305    )]
306    pub account_name: AccountNameImportPreference,
307}
308
309impl ClaimsImports {
310    const fn is_default(&self) -> bool {
311        self.subject.is_default()
312            && self.localpart.is_default()
313            && self.displayname.is_default()
314            && self.email.is_default()
315    }
316}
317
318/// How to discover the provider's configuration
319#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
320#[serde(rename_all = "snake_case")]
321pub enum DiscoveryMode {
322    /// Use OIDC discovery with strict metadata verification
323    #[default]
324    Oidc,
325
326    /// Use OIDC discovery with relaxed metadata verification
327    Insecure,
328
329    /// Use a static configuration
330    Disabled,
331}
332
333impl DiscoveryMode {
334    #[allow(clippy::trivially_copy_pass_by_ref)]
335    const fn is_default(&self) -> bool {
336        matches!(self, DiscoveryMode::Oidc)
337    }
338}
339
340/// Whether to use proof key for code exchange (PKCE) when requesting and
341/// exchanging the token.
342#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
343#[serde(rename_all = "snake_case")]
344pub enum PkceMethod {
345    /// Use PKCE if the provider supports it
346    ///
347    /// Defaults to no PKCE if provider discovery is disabled
348    #[default]
349    Auto,
350
351    /// Always use PKCE with the S256 challenge method
352    Always,
353
354    /// Never use PKCE
355    Never,
356}
357
358impl PkceMethod {
359    #[allow(clippy::trivially_copy_pass_by_ref)]
360    const fn is_default(&self) -> bool {
361        matches!(self, PkceMethod::Auto)
362    }
363}
364
/// Serde `default` for boolean fields that are on unless stated otherwise
fn default_true() -> bool {
    true
}

/// Serde `skip_serializing_if` predicate: `true` when `value` equals
/// [`default_true`], so that the default is not written out.
// Takes `&bool` because serde always passes the field by reference.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_true(value: &bool) -> bool {
    *value == default_true()
}
373
374#[allow(clippy::ref_option)]
375fn is_signed_response_alg_default(signed_response_alg: &JsonWebSignatureAlg) -> bool {
376    *signed_response_alg == signed_response_alg_default()
377}
378
379#[allow(clippy::unnecessary_wraps)]
380fn signed_response_alg_default() -> JsonWebSignatureAlg {
381    JsonWebSignatureAlg::Rs256
382}
383
384#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
385pub struct SignInWithApple {
386    /// The private key used to sign the `id_token`
387    pub private_key: String,
388
389    /// The Team ID of the Apple Developer Portal
390    pub team_id: String,
391
392    /// The key ID of the Apple Developer Portal
393    pub key_id: String,
394}
395
/// Configuration for one upstream OAuth 2 provider.
// Cross-field requirements are enforced in `UpstreamOAuth2Config::validate`:
// `issuer` vs `discovery_mode`, and `client_secret`,
// `token_endpoint_auth_signing_alg` and `sign_in_with_apple` vs
// `token_endpoint_auth_method`.
//
// The derived `Serialize` emits fields in declaration order, so reordering
// fields changes the serialized output.
//
// Every `Option` field below already has an explicit
// `skip_serializing_if = "Option::is_none"`. `#[skip_serializing_none]` only
// acts as a safety net for `Option` fields added without one.
//
// NOTE(review): `Debug` is derived, so `{:?}` output includes `client_secret`
// (and the Sign in with Apple `private_key`) in clear text.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider {
    /// Whether this provider is enabled.
    ///
    /// Defaults to `true`
    // Only serialized when it differs from the default (`false`).
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub enabled: bool,

    /// An internal unique identifier for this provider
    // Exposed in the JSON schema as a 26-character Crockford base32 string.
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub id: Ulid,

    /// The OIDC issuer URL
    ///
    /// This is required if OIDC discovery is enabled (which is the default)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,

    /// A human-readable name for the provider, that will be shown to users
    #[serde(skip_serializing_if = "Option::is_none")]
    pub human_name: Option<String>,

    /// A brand identifier used to customise the UI, e.g. `apple`, `google`,
    /// `github`, etc.
    ///
    /// Values supported by the default template are:
    ///
    ///  - `apple`
    ///  - `google`
    ///  - `facebook`
    ///  - `github`
    ///  - `gitlab`
    ///  - `twitter`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub brand_name: Option<String>,

    /// The client ID to use when authenticating with the provider
    pub client_id: String,

    /// The client secret to use when authenticating with the provider
    ///
    /// Required by the `client_secret_basic`, `client_secret_post`, and
    /// `client_secret_jwt` methods, and rejected for the other methods
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,

    /// The method to authenticate the client with the provider
    pub token_endpoint_auth_method: TokenAuthMethod,

    /// Additional parameters for the `sign_in_with_apple` method
    ///
    /// Required by that method, and rejected for the other methods
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sign_in_with_apple: Option<SignInWithApple>,

    /// The JWS algorithm to use when authenticating the client with the
    /// provider
    ///
    /// Required by the `client_secret_jwt` and `private_key_jwt` methods, and
    /// rejected for the other methods
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

    /// Expected signature algorithm of the `id_token` returned by the token
    /// endpoint.
    ///
    /// Defaults to `RS256`.
    // Only serialized when it differs from the default.
    #[serde(
        default = "signed_response_alg_default",
        skip_serializing_if = "is_signed_response_alg_default"
    )]
    pub id_token_signed_response_alg: JsonWebSignatureAlg,

    /// The scopes to request from the provider
    pub scope: String,

    /// How to discover the provider's configuration
    ///
    /// Defaults to `oidc`, which uses OIDC discovery with strict metadata
    /// verification
    #[serde(default, skip_serializing_if = "DiscoveryMode::is_default")]
    pub discovery_mode: DiscoveryMode,

    /// Whether to use proof key for code exchange (PKCE) when requesting and
    /// exchanging the token.
    ///
    /// Defaults to `auto`, which uses PKCE if the provider supports it.
    #[serde(default, skip_serializing_if = "PkceMethod::is_default")]
    pub pkce_method: PkceMethod,

    /// Whether to fetch the user profile from the userinfo endpoint,
    /// or to rely on the data returned in the `id_token` from the
    /// `token_endpoint`.
    ///
    /// Defaults to `false`.
    // Unlike its neighbours, this field has no `skip_serializing_if`, so it
    // is always serialized.
    #[serde(default)]
    pub fetch_userinfo: bool,

    /// Expected signature for the JWT payload returned by the userinfo
    /// endpoint.
    ///
    /// If not specified, the response is expected to be an unsigned JSON
    /// payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,

    /// The URL to use for the provider's authorization endpoint
    ///
    /// Defaults to the `authorization_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authorization_endpoint: Option<Url>,

    /// The URL to use for the provider's userinfo endpoint
    ///
    /// Defaults to the `userinfo_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_endpoint: Option<Url>,

    /// The URL to use for the provider's token endpoint
    ///
    /// Defaults to the `token_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint: Option<Url>,

    /// The URL to use for getting the provider's public keys
    ///
    /// Defaults to the `jwks_uri` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// The response mode we ask the provider to use for the callback
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mode: Option<ResponseMode>,

    /// How claims should be imported from the `id_token` provided by the
    /// provider
    #[serde(default, skip_serializing_if = "ClaimsImports::is_default")]
    pub claims_imports: ClaimsImports,

    /// Additional parameters to include in the authorization request
    ///
    /// The order of the keys is not preserved.
    // Stored in a `BTreeMap`, so the keys come back sorted rather than in
    // their configured order.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub additional_authorization_parameters: BTreeMap<String, String>,

    /// The ID of the provider that was used by Synapse.
    /// In order to perform a Synapse-to-MAS migration, this must be specified.
    ///
    /// ## For providers that used OAuth 2.0 or OpenID Connect in Synapse
    ///
    /// ### For `oidc_providers`:
    /// This should be specified as `oidc-` followed by the ID that was
    /// configured as `idp_id` in one of the `oidc_providers` in the Synapse
    /// configuration.
    /// For example, if Synapse's configuration contained `idp_id: wombat` for
    /// this provider, then specify `oidc-wombat` here.
    ///
    /// ### For `oidc_config` (legacy):
    /// Specify `oidc` here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub synapse_idp_id: Option<String>,
}