mas_handlers/oauth2/
registration.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use axum::{Json, extract::State, response::IntoResponse};
8use axum_extra::TypedHeader;
9use hyper::StatusCode;
10use mas_axum_utils::sentry::SentryEventID;
11use mas_iana::oauth::OAuthClientAuthenticationMethod;
12use mas_keystore::Encrypter;
13use mas_policy::{Policy, Violation};
14use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2ClientRepository};
15use oauth2_types::{
16    errors::{ClientError, ClientErrorCode},
17    registration::{
18        ClientMetadata, ClientMetadataVerificationError, ClientRegistrationResponse, Localized,
19        VerifiedClientMetadata,
20    },
21};
22use psl::Psl;
23use rand::distributions::{Alphanumeric, DistString};
24use serde::Serialize;
25use sha2::Digest as _;
26use thiserror::Error;
27use tracing::info;
28use url::Url;
29
30use crate::{BoundActivityTracker, impl_from_error_for_route};
31
/// Errors returned by the dynamic client registration handler.
///
/// Each variant is turned into an OAuth 2.0 client error response by this
/// type's `IntoResponse` implementation.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// An unexpected server-side failure, reported to the client as a generic
    /// `server_error`.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync>),

    /// The request body could not be extracted as JSON into `ClientMetadata`.
    #[error(transparent)]
    JsonExtract(#[from] axum::extract::rejection::JsonRejection),

    /// `ClientMetadata::validate` rejected the submitted metadata.
    #[error("invalid client metadata")]
    InvalidClientMetadata(#[from] ClientMetadataVerificationError),

    /// The host of the URL in the named metadata field (e.g. `client_uri`,
    /// `redirect_uri`) is a public suffix rather than a registrable domain.
    #[error("{0} is a public suffix, not a valid domain")]
    UrlIsPublicSuffix(&'static str),

    /// The client registration policy denied the request; carries the list of
    /// violations it reported.
    #[error("denied by the policy: {0:?}")]
    PolicyDenied(Vec<Violation>),
}

// NOTE(review): these presumably generate `From` impls that wrap each error
// type into `RouteError::Internal` — the macro is defined elsewhere, confirm
// there.
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl_from_error_for_route!(serde_json::Error);
55
56impl IntoResponse for RouteError {
57    fn into_response(self) -> axum::response::Response {
58        let event_id = sentry::capture_error(&self);
59        let response = match self {
60            Self::Internal(_) => (
61                StatusCode::INTERNAL_SERVER_ERROR,
62                Json(ClientError::from(ClientErrorCode::ServerError)),
63            )
64                .into_response(),
65
66            // This error happens if we managed to parse the incomiong JSON but it can't be
67            // deserialized to the expected type. In this case we return an
68            // `invalid_client_metadata` error with the details of the error.
69            Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
70                StatusCode::BAD_REQUEST,
71                Json(
72                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
73                        .with_description(e.to_string()),
74                ),
75            )
76                .into_response(),
77
78            // For all other JSON errors we return a `invalid_request` error, since this is
79            // probably due to a malformed request.
80            Self::JsonExtract(_) => (
81                StatusCode::BAD_REQUEST,
82                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
83            )
84                .into_response(),
85
86            // This error comes from the `ClientMetadata::validate` method. We return an
87            // `invalid_redirect_uri` error if the error is related to the redirect URIs, else we
88            // return an `invalid_client_metadata` error.
89            Self::InvalidClientMetadata(
90                ClientMetadataVerificationError::MissingRedirectUris
91                | ClientMetadataVerificationError::RedirectUriWithFragment(_),
92            ) => (
93                StatusCode::BAD_REQUEST,
94                Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
95            )
96                .into_response(),
97
98            Self::InvalidClientMetadata(e) => (
99                StatusCode::BAD_REQUEST,
100                Json(
101                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
102                        .with_description(e.to_string()),
103                ),
104            )
105                .into_response(),
106
107            // This error happens if the any of the client's URIs are public suffixes. We return
108            // an `invalid_redirect_uri` error if it's a `redirect_uri`, else we return an
109            // `invalid_client_metadata` error.
110            Self::UrlIsPublicSuffix("redirect_uri") => (
111                StatusCode::BAD_REQUEST,
112                Json(
113                    ClientError::from(ClientErrorCode::InvalidRedirectUri)
114                        .with_description("redirect_uri is not using a valid domain".to_owned()),
115                ),
116            )
117                .into_response(),
118
119            Self::UrlIsPublicSuffix(field) => (
120                StatusCode::BAD_REQUEST,
121                Json(
122                    ClientError::from(ClientErrorCode::InvalidClientMetadata)
123                        .with_description(format!("{field} is not using a valid domain")),
124                ),
125            )
126                .into_response(),
127
128            // For policy violations, we return an `invalid_client_metadata` error with the details
129            // of the violations in most cases. If a violation includes `redirect_uri` in the
130            // message, we return an `invalid_redirect_uri` error instead.
131            Self::PolicyDenied(violations) => {
132                // TODO: detect them better
133                let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) {
134                    ClientErrorCode::InvalidRedirectUri
135                } else {
136                    ClientErrorCode::InvalidClientMetadata
137                };
138
139                let collected = &violations
140                    .iter()
141                    .map(|v| v.msg.clone())
142                    .collect::<Vec<String>>();
143                let joined = collected.join("; ");
144
145                (
146                    StatusCode::BAD_REQUEST,
147                    Json(ClientError::from(code).with_description(joined)),
148                )
149                    .into_response()
150            }
151        };
152
153        (SentryEventID::from(event_id), response).into_response()
154    }
155}
156
/// JSON body of a successful registration response.
///
/// Both members are flattened, so the registration fields and the client
/// metadata appear side by side in a single JSON object.
#[derive(Serialize)]
struct RouteResponse {
    /// Registration result: `client_id`, optional `client_secret`,
    /// `client_id_issued_at` and `client_secret_expires_at`.
    #[serde(flatten)]
    response: ClientRegistrationResponse,
    /// The client metadata, round-tripped from the stored client.
    #[serde(flatten)]
    metadata: VerifiedClientMetadata,
}
164
165/// Check if the host of the given URL is a public suffix
166fn host_is_public_suffix(url: &Url) -> bool {
167    let host = url.host_str().unwrap_or_default().as_bytes();
168    let Some(suffix) = psl::List.suffix(host) else {
169        // There is no suffix, which is the case for empty hosts, like with custom
170        // schemes
171        return false;
172    };
173
174    if !suffix.is_known() {
175        // The suffix is not known, so it's not a public suffix
176        return false;
177    }
178
179    // We want to cover two cases:
180    // - The host is the suffix itself, like `com`
181    // - The host is a dot followed by the suffix, like `.com`
182    if host.len() <= suffix.as_bytes().len() + 1 {
183        // The host only has the suffix in it, so it's a public suffix
184        return true;
185    }
186
187    false
188}
189
190/// Check if any of the URLs in the given `Localized` field is a public suffix
191fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
192    url.iter().any(|(_lang, url)| host_is_public_suffix(url))
193}
194
195#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all, err)]
196pub(crate) async fn post(
197    mut rng: BoxRng,
198    clock: BoxClock,
199    mut repo: BoxRepository,
200    mut policy: Policy,
201    activity_tracker: BoundActivityTracker,
202    user_agent: Option<TypedHeader<headers::UserAgent>>,
203    State(encrypter): State<Encrypter>,
204    body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
205) -> Result<impl IntoResponse, RouteError> {
206    // Propagate any JSON extraction error
207    let Json(body) = body?;
208
209    // Sort the properties to ensure a stable serialisation order for hashing
210    let body = body.sorted();
211
212    // We need to serialize the body to compute the hash, and to log it
213    let body_json = serde_json::to_string(&body)?;
214
215    info!(body = body_json, "Client registration");
216
217    let user_agent = user_agent.map(|ua| ua.to_string());
218
219    // Validate the body
220    let metadata = body.validate()?;
221
222    // Some extra validation that is hard to do in OPA and not done by the
223    // `validate` method either
224    if let Some(client_uri) = &metadata.client_uri {
225        if localised_url_has_public_suffix(client_uri) {
226            return Err(RouteError::UrlIsPublicSuffix("client_uri"));
227        }
228    }
229
230    if let Some(logo_uri) = &metadata.logo_uri {
231        if localised_url_has_public_suffix(logo_uri) {
232            return Err(RouteError::UrlIsPublicSuffix("logo_uri"));
233        }
234    }
235
236    if let Some(policy_uri) = &metadata.policy_uri {
237        if localised_url_has_public_suffix(policy_uri) {
238            return Err(RouteError::UrlIsPublicSuffix("policy_uri"));
239        }
240    }
241
242    if let Some(tos_uri) = &metadata.tos_uri {
243        if localised_url_has_public_suffix(tos_uri) {
244            return Err(RouteError::UrlIsPublicSuffix("tos_uri"));
245        }
246    }
247
248    if let Some(initiate_login_uri) = &metadata.initiate_login_uri {
249        if host_is_public_suffix(initiate_login_uri) {
250            return Err(RouteError::UrlIsPublicSuffix("initiate_login_uri"));
251        }
252    }
253
254    for redirect_uri in metadata.redirect_uris() {
255        if host_is_public_suffix(redirect_uri) {
256            return Err(RouteError::UrlIsPublicSuffix("redirect_uri"));
257        }
258    }
259
260    let res = policy
261        .evaluate_client_registration(mas_policy::ClientRegistrationInput {
262            client_metadata: &metadata,
263            requester: mas_policy::Requester {
264                ip_address: activity_tracker.ip(),
265                user_agent,
266            },
267        })
268        .await?;
269    if !res.valid() {
270        return Err(RouteError::PolicyDenied(res.violations));
271    }
272
273    let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
274        Some(
275            OAuthClientAuthenticationMethod::ClientSecretJwt
276            | OAuthClientAuthenticationMethod::ClientSecretPost
277            | OAuthClientAuthenticationMethod::ClientSecretBasic,
278        ) => {
279            // Let's generate a random client secret
280            let client_secret = Alphanumeric.sample_string(&mut rng, 20);
281            let encrypted_client_secret = encrypter.encrypt_to_string(client_secret.as_bytes())?;
282            (Some(client_secret), Some(encrypted_client_secret))
283        }
284        _ => (None, None),
285    };
286
287    // If the client doesn't have a secret, we may be able to deduplicate it. To
288    // do so, we hash the client metadata, and look for it in the database
289    let (digest_hash, existing_client) = if client_secret.is_none() {
290        // XXX: One interesting caveat is that we hash *before* saving to the database.
291        // It means it takes into account fields that we don't care about *yet*.
292        //
293        // This means that if later we start supporting a particular field, we
294        // will still serve the 'old' client_id, without updating the client in the
295        // database
296        let hash = sha2::Sha256::digest(body_json);
297        let hash = hex::encode(hash);
298        let client = repo.oauth2_client().find_by_metadata_digest(&hash).await?;
299        (Some(hash), client)
300    } else {
301        (None, None)
302    };
303
304    let client = if let Some(client) = existing_client {
305        tracing::info!(%client.id, "Reusing existing client");
306        client
307    } else {
308        let client = repo
309            .oauth2_client()
310            .add(
311                &mut rng,
312                &clock,
313                metadata.redirect_uris().to_vec(),
314                digest_hash,
315                encrypted_client_secret,
316                metadata.application_type.clone(),
317                //&metadata.response_types(),
318                metadata.grant_types().to_vec(),
319                metadata
320                    .client_name
321                    .clone()
322                    .map(Localized::to_non_localized),
323                metadata.logo_uri.clone().map(Localized::to_non_localized),
324                metadata.client_uri.clone().map(Localized::to_non_localized),
325                metadata.policy_uri.clone().map(Localized::to_non_localized),
326                metadata.tos_uri.clone().map(Localized::to_non_localized),
327                metadata.jwks_uri.clone(),
328                metadata.jwks.clone(),
329                // XXX: those might not be right, should be function calls
330                metadata.id_token_signed_response_alg.clone(),
331                metadata.userinfo_signed_response_alg.clone(),
332                metadata.token_endpoint_auth_method.clone(),
333                metadata.token_endpoint_auth_signing_alg.clone(),
334                metadata.initiate_login_uri.clone(),
335            )
336            .await?;
337        tracing::info!(%client.id, "Registered new client");
338        client
339    };
340
341    let response = ClientRegistrationResponse {
342        client_id: client.client_id.clone(),
343        client_secret,
344        // XXX: we should have a `created_at` field on the clients
345        client_id_issued_at: Some(client.id.datetime().into()),
346        client_secret_expires_at: None,
347    };
348
349    // We round-trip back to the metadata to output it in the response
350    // This should never fail, as the client is valid
351    let metadata = client.into_metadata().validate()?;
352
353    repo.save().await?;
354
355    let response = RouteResponse { response, metadata };
356
357    Ok((StatusCode::CREATED, Json(response)))
358}
359
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        errors::{ClientError, ClientErrorCode},
        registration::ClientRegistrationResponse,
    };
    use sqlx::PgPool;
    use url::Url;

    use crate::{
        oauth2::registration::host_is_public_suffix,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    #[test]
    fn test_public_suffix_list() {
        fn url_is_public_suffix(url: &str) -> bool {
            host_is_public_suffix(&Url::parse(url).unwrap())
        }

        assert!(url_is_public_suffix("https://.com"));
        assert!(url_is_public_suffix("https://.com."));
        assert!(url_is_public_suffix("https://co.uk"));
        assert!(url_is_public_suffix("https://github.io"));
        assert!(!url_is_public_suffix("https://example.com"));
        assert!(!url_is_public_suffix("https://example.com."));
        // A registrable domain under a multi-label public suffix
        assert!(!url_is_public_suffix("https://example.co.uk"));
        assert!(!url_is_public_suffix("https://x.com"));
        assert!(!url_is_public_suffix("https://x.com."));
        assert!(!url_is_public_suffix("https://matrix-org.github.io"));
        assert!(!url_is_public_suffix("http://localhost"));
        assert!(!url_is_public_suffix("org.matrix:/callback"));
        assert!(!url_is_public_suffix("http://somerandominternaldomain"));
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_error(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Body is not a JSON
        let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
            .body("this is not a json".to_owned())
            .unwrap();

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRequest);

        // Invalid client metadata
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "this is not a uri",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // Invalid redirect URI
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "application_type": "web",
                "client_uri": "https://example.com/",
                "redirect_uris": ["http://this-is-insecure.com/"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);

        // Incoherent response types
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["id_token"],
                "grant_types": ["authorization_code"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // Using a public suffix
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://github.io/",
                "redirect_uris": ["https://github.io/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // Using a public suffix in a translated URL
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_uri#fr-FR": "https://github.io/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // Using a public suffix in a redirect URI is reported as an
        // `invalid_redirect_uri` error rather than `invalid_client_metadata`
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://github.io/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);
        assert_eq!(
            response.error_description.unwrap(),
            "redirect_uri is not using a valid domain"
        );
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // A successful registration with no authentication should not return a client
        // secret
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_none());

        // A successful registration with client_secret based authentication should
        // return a client secret
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_some());
    }

    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_dedupe(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // Post a client registration twice, we should get the same client ID
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#en": "Example",
                "client_name#fr": "Exemple",
                "client_name#de": "Beispiel",
                "redirect_uris": ["https://example.com/", "https://example.com/callback"],
                "response_types": ["code"],
                "grant_types": ["authorization_code", "urn:ietf:params:oauth:grant-type:device_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Check that the order of some properties doesn't matter
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#de": "Beispiel",
                "client_name#fr": "Exemple",
                "client_name#en": "Example",
                "redirect_uris": ["https://example.com/callback", "https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Doing that with a client that has a client_secret should not deduplicate
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        // Sanity check that the client_id is different
        assert_ne!(response.client_id, client_id);
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_ne!(response.client_id, client_id);
    }
}