1use axum::{Json, extract::State, response::IntoResponse};
8use axum_extra::TypedHeader;
9use hyper::StatusCode;
10use mas_axum_utils::sentry::SentryEventID;
11use mas_iana::oauth::OAuthClientAuthenticationMethod;
12use mas_keystore::Encrypter;
13use mas_policy::{Policy, Violation};
14use mas_storage::{BoxClock, BoxRepository, BoxRng, oauth2::OAuth2ClientRepository};
15use oauth2_types::{
16 errors::{ClientError, ClientErrorCode},
17 registration::{
18 ClientMetadata, ClientMetadataVerificationError, ClientRegistrationResponse, Localized,
19 VerifiedClientMetadata,
20 },
21};
22use psl::Psl;
23use rand::distributions::{Alphanumeric, DistString};
24use serde::Serialize;
25use sha2::Digest as _;
26use thiserror::Error;
27use tracing::info;
28use url::Url;
29
30use crate::{BoundActivityTracker, impl_from_error_for_route};
31
/// Errors that can occur while handling a dynamic client registration request.
///
/// Each variant is turned into an OAuth 2.0 client error response by the
/// `IntoResponse` implementation below.
#[derive(Debug, Error)]
pub(crate) enum RouteError {
    /// Unexpected internal failure; rendered as a generic `server_error`
    /// with HTTP 500 and no details exposed to the client.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync>),

    /// The request body could not be extracted as JSON client metadata.
    #[error(transparent)]
    JsonExtract(#[from] axum::extract::rejection::JsonRejection),

    /// The submitted client metadata failed validation.
    #[error("invalid client metadata")]
    InvalidClientMetadata(#[from] ClientMetadataVerificationError),

    /// A URL in the metadata has a host that is itself a public suffix.
    /// The payload is the name of the offending metadata field.
    #[error("{0} is a public suffix, not a valid domain")]
    UrlIsPublicSuffix(&'static str),

    /// The policy engine rejected the registration with these violations.
    #[error("denied by the policy: {0:?}")]
    PolicyDenied(Vec<Violation>),
}
49
// Conversions for infrastructure errors that can bubble up from the handler
// via `?`. The macro is defined elsewhere in the crate; presumably it wraps
// each error into `RouteError::Internal` (HTTP 500) — confirm against its
// definition.
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_keystore::aead::Error);
impl_from_error_for_route!(serde_json::Error);
55
56impl IntoResponse for RouteError {
57 fn into_response(self) -> axum::response::Response {
58 let event_id = sentry::capture_error(&self);
59 let response = match self {
60 Self::Internal(_) => (
61 StatusCode::INTERNAL_SERVER_ERROR,
62 Json(ClientError::from(ClientErrorCode::ServerError)),
63 )
64 .into_response(),
65
66 Self::JsonExtract(axum::extract::rejection::JsonRejection::JsonDataError(e)) => (
70 StatusCode::BAD_REQUEST,
71 Json(
72 ClientError::from(ClientErrorCode::InvalidClientMetadata)
73 .with_description(e.to_string()),
74 ),
75 )
76 .into_response(),
77
78 Self::JsonExtract(_) => (
81 StatusCode::BAD_REQUEST,
82 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
83 )
84 .into_response(),
85
86 Self::InvalidClientMetadata(
90 ClientMetadataVerificationError::MissingRedirectUris
91 | ClientMetadataVerificationError::RedirectUriWithFragment(_),
92 ) => (
93 StatusCode::BAD_REQUEST,
94 Json(ClientError::from(ClientErrorCode::InvalidRedirectUri)),
95 )
96 .into_response(),
97
98 Self::InvalidClientMetadata(e) => (
99 StatusCode::BAD_REQUEST,
100 Json(
101 ClientError::from(ClientErrorCode::InvalidClientMetadata)
102 .with_description(e.to_string()),
103 ),
104 )
105 .into_response(),
106
107 Self::UrlIsPublicSuffix("redirect_uri") => (
111 StatusCode::BAD_REQUEST,
112 Json(
113 ClientError::from(ClientErrorCode::InvalidRedirectUri)
114 .with_description("redirect_uri is not using a valid domain".to_owned()),
115 ),
116 )
117 .into_response(),
118
119 Self::UrlIsPublicSuffix(field) => (
120 StatusCode::BAD_REQUEST,
121 Json(
122 ClientError::from(ClientErrorCode::InvalidClientMetadata)
123 .with_description(format!("{field} is not using a valid domain")),
124 ),
125 )
126 .into_response(),
127
128 Self::PolicyDenied(violations) => {
132 let code = if violations.iter().any(|v| v.msg.contains("redirect_uri")) {
134 ClientErrorCode::InvalidRedirectUri
135 } else {
136 ClientErrorCode::InvalidClientMetadata
137 };
138
139 let collected = &violations
140 .iter()
141 .map(|v| v.msg.clone())
142 .collect::<Vec<String>>();
143 let joined = collected.join("; ");
144
145 (
146 StatusCode::BAD_REQUEST,
147 Json(ClientError::from(code).with_description(joined)),
148 )
149 .into_response()
150 }
151 };
152
153 (SentryEventID::from(event_id), response).into_response()
154 }
155}
156
/// Body of a successful registration response.
///
/// Both fields are flattened, so the serialized JSON object carries the
/// registration fields (`client_id`, `client_secret`, ...) and the verified
/// client metadata side by side at the top level. Field order matters for the
/// serialized key order, so keep it as is.
#[derive(Serialize)]
struct RouteResponse {
    #[serde(flatten)]
    response: ClientRegistrationResponse,
    #[serde(flatten)]
    metadata: VerifiedClientMetadata,
}
164
165fn host_is_public_suffix(url: &Url) -> bool {
167 let host = url.host_str().unwrap_or_default().as_bytes();
168 let Some(suffix) = psl::List.suffix(host) else {
169 return false;
172 };
173
174 if !suffix.is_known() {
175 return false;
177 }
178
179 if host.len() <= suffix.as_bytes().len() + 1 {
183 return true;
185 }
186
187 false
188}
189
190fn localised_url_has_public_suffix(url: &Localized<Url>) -> bool {
192 url.iter().any(|(_lang, url)| host_is_public_suffix(url))
193}
194
195#[tracing::instrument(name = "handlers.oauth2.registration.post", skip_all, err)]
196pub(crate) async fn post(
197 mut rng: BoxRng,
198 clock: BoxClock,
199 mut repo: BoxRepository,
200 mut policy: Policy,
201 activity_tracker: BoundActivityTracker,
202 user_agent: Option<TypedHeader<headers::UserAgent>>,
203 State(encrypter): State<Encrypter>,
204 body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
205) -> Result<impl IntoResponse, RouteError> {
206 let Json(body) = body?;
208
209 let body = body.sorted();
211
212 let body_json = serde_json::to_string(&body)?;
214
215 info!(body = body_json, "Client registration");
216
217 let user_agent = user_agent.map(|ua| ua.to_string());
218
219 let metadata = body.validate()?;
221
222 if let Some(client_uri) = &metadata.client_uri {
225 if localised_url_has_public_suffix(client_uri) {
226 return Err(RouteError::UrlIsPublicSuffix("client_uri"));
227 }
228 }
229
230 if let Some(logo_uri) = &metadata.logo_uri {
231 if localised_url_has_public_suffix(logo_uri) {
232 return Err(RouteError::UrlIsPublicSuffix("logo_uri"));
233 }
234 }
235
236 if let Some(policy_uri) = &metadata.policy_uri {
237 if localised_url_has_public_suffix(policy_uri) {
238 return Err(RouteError::UrlIsPublicSuffix("policy_uri"));
239 }
240 }
241
242 if let Some(tos_uri) = &metadata.tos_uri {
243 if localised_url_has_public_suffix(tos_uri) {
244 return Err(RouteError::UrlIsPublicSuffix("tos_uri"));
245 }
246 }
247
248 if let Some(initiate_login_uri) = &metadata.initiate_login_uri {
249 if host_is_public_suffix(initiate_login_uri) {
250 return Err(RouteError::UrlIsPublicSuffix("initiate_login_uri"));
251 }
252 }
253
254 for redirect_uri in metadata.redirect_uris() {
255 if host_is_public_suffix(redirect_uri) {
256 return Err(RouteError::UrlIsPublicSuffix("redirect_uri"));
257 }
258 }
259
260 let res = policy
261 .evaluate_client_registration(mas_policy::ClientRegistrationInput {
262 client_metadata: &metadata,
263 requester: mas_policy::Requester {
264 ip_address: activity_tracker.ip(),
265 user_agent,
266 },
267 })
268 .await?;
269 if !res.valid() {
270 return Err(RouteError::PolicyDenied(res.violations));
271 }
272
273 let (client_secret, encrypted_client_secret) = match metadata.token_endpoint_auth_method {
274 Some(
275 OAuthClientAuthenticationMethod::ClientSecretJwt
276 | OAuthClientAuthenticationMethod::ClientSecretPost
277 | OAuthClientAuthenticationMethod::ClientSecretBasic,
278 ) => {
279 let client_secret = Alphanumeric.sample_string(&mut rng, 20);
281 let encrypted_client_secret = encrypter.encrypt_to_string(client_secret.as_bytes())?;
282 (Some(client_secret), Some(encrypted_client_secret))
283 }
284 _ => (None, None),
285 };
286
287 let (digest_hash, existing_client) = if client_secret.is_none() {
290 let hash = sha2::Sha256::digest(body_json);
297 let hash = hex::encode(hash);
298 let client = repo.oauth2_client().find_by_metadata_digest(&hash).await?;
299 (Some(hash), client)
300 } else {
301 (None, None)
302 };
303
304 let client = if let Some(client) = existing_client {
305 tracing::info!(%client.id, "Reusing existing client");
306 client
307 } else {
308 let client = repo
309 .oauth2_client()
310 .add(
311 &mut rng,
312 &clock,
313 metadata.redirect_uris().to_vec(),
314 digest_hash,
315 encrypted_client_secret,
316 metadata.application_type.clone(),
317 metadata.grant_types().to_vec(),
319 metadata
320 .client_name
321 .clone()
322 .map(Localized::to_non_localized),
323 metadata.logo_uri.clone().map(Localized::to_non_localized),
324 metadata.client_uri.clone().map(Localized::to_non_localized),
325 metadata.policy_uri.clone().map(Localized::to_non_localized),
326 metadata.tos_uri.clone().map(Localized::to_non_localized),
327 metadata.jwks_uri.clone(),
328 metadata.jwks.clone(),
329 metadata.id_token_signed_response_alg.clone(),
331 metadata.userinfo_signed_response_alg.clone(),
332 metadata.token_endpoint_auth_method.clone(),
333 metadata.token_endpoint_auth_signing_alg.clone(),
334 metadata.initiate_login_uri.clone(),
335 )
336 .await?;
337 tracing::info!(%client.id, "Registered new client");
338 client
339 };
340
341 let response = ClientRegistrationResponse {
342 client_id: client.client_id.clone(),
343 client_secret,
344 client_id_issued_at: Some(client.id.datetime().into()),
346 client_secret_expires_at: None,
347 };
348
349 let metadata = client.into_metadata().validate()?;
352
353 repo.save().await?;
354
355 let response = RouteResponse { response, metadata };
356
357 Ok((StatusCode::CREATED, Json(response)))
358}
359
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use mas_router::SimpleRoute;
    use oauth2_types::{
        errors::{ClientError, ClientErrorCode},
        registration::ClientRegistrationResponse,
    };
    use sqlx::PgPool;
    use url::Url;

    use crate::{
        oauth2::registration::host_is_public_suffix,
        test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
    };

    /// Pins which hosts `host_is_public_suffix` classifies as bare public
    /// suffixes.
    #[test]
    fn test_public_suffix_list() {
        fn url_is_public_suffix(url: &str) -> bool {
            host_is_public_suffix(&Url::parse(url).unwrap())
        }

        // Hosts that are only a public suffix, with or without a leading or
        // trailing dot.
        assert!(url_is_public_suffix("https://.com"));
        assert!(url_is_public_suffix("https://.com."));
        assert!(url_is_public_suffix("https://co.uk"));
        assert!(url_is_public_suffix("https://github.io"));
        // Domains registered under a public suffix.
        assert!(!url_is_public_suffix("https://example.com"));
        assert!(!url_is_public_suffix("https://example.com."));
        assert!(!url_is_public_suffix("https://x.com"));
        assert!(!url_is_public_suffix("https://x.com."));
        assert!(!url_is_public_suffix("https://matrix-org.github.io"));
        // Single-label hosts and hostless custom-scheme URLs are not flagged.
        assert!(!url_is_public_suffix("http://localhost"));
        assert!(!url_is_public_suffix("org.matrix:/callback"));
        assert!(!url_is_public_suffix("http://somerandominternaldomain"));
    }

    /// Checks the error code (and, where relevant, description) returned for
    /// a range of invalid registration requests.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_error(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        // A body that is not JSON at all.
        let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
            .body("this is not a json".to_owned())
            .unwrap();

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRequest);

        // Valid JSON, but `client_uri` is not a URI.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "this is not a uri",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // A web application with a plain-`http` redirect URI.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "application_type": "web",
                "client_uri": "https://example.com/",
                "redirect_uris": ["http://this-is-insecure.com/"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidRedirectUri);

        // An `id_token` response type with only the `authorization_code` grant.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["id_token"],
                "grant_types": ["authorization_code"],
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);

        // `client_uri` hosted directly on a public suffix.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://github.io/",
                "redirect_uris": ["https://github.io/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );

        // Only a localized `client_uri` variant is hosted on a public suffix.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_uri#fr-FR": "https://github.io/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::BAD_REQUEST);
        let response: ClientError = response.json();
        assert_eq!(response.error, ClientErrorCode::InvalidClientMetadata);
        assert_eq!(
            response.error_description.unwrap(),
            "client_uri is not using a valid domain"
        );
    }

    /// A client registered with `none` gets no secret; one registered with
    /// `client_secret_basic` does.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_none());

        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert!(response.client_secret.is_some());
    }

    /// Secret-less clients with equivalent metadata reuse the same client ID,
    /// even when fields and array items are reordered. Clients with a secret
    /// always get a fresh ID.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_registration_dedupe(pool: PgPool) {
        setup();
        let state = TestState::from_pool(pool).await.unwrap();

        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#en": "Example",
                "client_name#fr": "Exemple",
                "client_name#de": "Beispiel",
                "redirect_uris": ["https://example.com/", "https://example.com/callback"],
                "response_types": ["code"],
                "grant_types": ["authorization_code", "urn:ietf:params:oauth:grant-type:device_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        let client_id = response.client_id;

        // Same request again: same client.
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Same metadata with localized names, redirect URIs and grant types
        // in a different order: still the same client.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "client_name": "Example",
                "client_name#de": "Beispiel",
                "client_name#fr": "Exemple",
                "client_name#en": "Example",
                "redirect_uris": ["https://example.com/callback", "https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "authorization_code"],
                "token_endpoint_auth_method": "none",
            }));

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_eq!(response.client_id, client_id);

        // Clients with a secret are never deduplicated.
        let request =
            Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
                "client_uri": "https://example.com/",
                "redirect_uris": ["https://example.com/"],
                "response_types": ["code"],
                "grant_types": ["authorization_code"],
                "token_endpoint_auth_method": "client_secret_basic",
            }));

        let response = state.request(request.clone()).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_ne!(response.client_id, client_id);
        let client_id = response.client_id;

        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let response: ClientRegistrationResponse = response.json();
        assert_ne!(response.client_id, client_id);
    }
}