1use std::collections::HashMap;
8
9use axum::{
10 BoxError, Json,
11 extract::{
12 Form, FromRequest, FromRequestParts,
13 rejection::{FailedToDeserializeForm, FormRejection},
14 },
15 response::IntoResponse,
16};
17use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
18use headers::{Authorization, authorization::Basic};
19use http::{Request, StatusCode};
20use mas_data_model::{Client, JwksOrJwksUri};
21use mas_http::RequestBuilderExt;
22use mas_iana::oauth::OAuthClientAuthenticationMethod;
23use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
24use mas_keystore::Encrypter;
25use mas_storage::{RepositoryAccess, oauth2::OAuth2ClientRepository};
26use oauth2_types::errors::{ClientError, ClientErrorCode};
27use serde::{Deserialize, de::DeserializeOwned};
28use serde_json::Value;
29use thiserror::Error;
30
/// `client_assertion_type` value announcing a JWT bearer client assertion
/// (the only assertion type this extractor accepts).
const JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
32
/// Form body of a request that may carry client authentication parameters.
///
/// The client authentication parameters are deserialized into dedicated
/// fields; every other form field is deserialized into `inner`.
#[derive(Deserialize)]
struct AuthorizedForm<F = ()> {
    /// `client_id` form parameter, if present.
    client_id: Option<String>,
    /// `client_secret` form parameter, if present (`client_secret_post`).
    client_secret: Option<String>,
    /// `client_assertion_type` form parameter, if present.
    client_assertion_type: Option<String>,
    /// `client_assertion` form parameter, if present.
    client_assertion: Option<String>,

    /// The remaining form fields, deserialized into the caller-chosen type.
    #[serde(flatten)]
    inner: F,
}
43
/// Client credentials presented with a request, tagged by how they were sent.
#[derive(Debug, PartialEq, Eq)]
pub enum Credentials {
    /// Only a `client_id` form field was presented (no secret, no assertion).
    None {
        client_id: String,
    },
    /// Credentials taken from an HTTP Basic `Authorization` header.
    ClientSecretBasic {
        client_id: String,
        client_secret: String,
    },
    /// `client_id` and `client_secret` taken from the form body.
    ClientSecretPost {
        client_id: String,
        client_secret: String,
    },
    /// A JWT bearer `client_assertion`; `client_id` is the JWT's `sub` claim.
    ClientAssertionJwtBearer {
        client_id: String,
        jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
    },
}
62
63impl Credentials {
64 #[must_use]
66 pub fn client_id(&self) -> &str {
67 match self {
68 Credentials::None { client_id }
69 | Credentials::ClientSecretBasic { client_id, .. }
70 | Credentials::ClientSecretPost { client_id, .. }
71 | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
72 }
73 }
74
75 pub async fn fetch<E>(
82 &self,
83 repo: &mut impl RepositoryAccess<Error = E>,
84 ) -> Result<Option<Client>, E> {
85 let client_id = match self {
86 Credentials::None { client_id }
87 | Credentials::ClientSecretBasic { client_id, .. }
88 | Credentials::ClientSecretPost { client_id, .. }
89 | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
90 };
91
92 repo.oauth2_client().find_by_client_id(client_id).await
93 }
94
95 #[tracing::instrument(skip_all, err)]
101 pub async fn verify(
102 &self,
103 http_client: &reqwest::Client,
104 encrypter: &Encrypter,
105 method: &OAuthClientAuthenticationMethod,
106 client: &Client,
107 ) -> Result<(), CredentialsVerificationError> {
108 match (self, method) {
109 (Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
110
111 (
112 Credentials::ClientSecretPost { client_secret, .. },
113 OAuthClientAuthenticationMethod::ClientSecretPost,
114 )
115 | (
116 Credentials::ClientSecretBasic { client_secret, .. },
117 OAuthClientAuthenticationMethod::ClientSecretBasic,
118 ) => {
119 let encrypted_client_secret = client
121 .encrypted_client_secret
122 .as_ref()
123 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
124
125 let decrypted_client_secret = encrypter
126 .decrypt_string(encrypted_client_secret)
127 .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
128
129 if client_secret.as_bytes() != decrypted_client_secret {
131 return Err(CredentialsVerificationError::ClientSecretMismatch);
132 }
133 }
134
135 (
136 Credentials::ClientAssertionJwtBearer { jwt, .. },
137 OAuthClientAuthenticationMethod::PrivateKeyJwt,
138 ) => {
139 let jwks = client
141 .jwks
142 .as_ref()
143 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
144
145 let jwks = fetch_jwks(http_client, jwks)
146 .await
147 .map_err(|_| CredentialsVerificationError::JwksFetchFailed)?;
148
149 jwt.verify_with_jwks(&jwks)
150 .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
151 }
152
153 (
154 Credentials::ClientAssertionJwtBearer { jwt, .. },
155 OAuthClientAuthenticationMethod::ClientSecretJwt,
156 ) => {
157 let encrypted_client_secret = client
159 .encrypted_client_secret
160 .as_ref()
161 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
162
163 let decrypted_client_secret = encrypter
164 .decrypt_string(encrypted_client_secret)
165 .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
166
167 jwt.verify_with_shared_secret(decrypted_client_secret)
168 .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
169 }
170
171 (_, _) => {
172 return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
173 }
174 }
175 Ok(())
176 }
177}
178
179async fn fetch_jwks(
180 http_client: &reqwest::Client,
181 jwks: &JwksOrJwksUri,
182) -> Result<PublicJsonWebKeySet, BoxError> {
183 let uri = match jwks {
184 JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
185 JwksOrJwksUri::JwksUri(u) => u,
186 };
187
188 let response = http_client
189 .get(uri.as_str())
190 .send_traced()
191 .await?
192 .error_for_status()?
193 .json()
194 .await?;
195
196 Ok(response)
197}
198
/// Reasons why [`Credentials::verify`] rejected a client's credentials.
#[derive(Debug, Error)]
pub enum CredentialsVerificationError {
    /// The client's stored secret could not be decrypted.
    #[error("failed to decrypt client credentials")]
    DecryptionError,

    /// The client lacks the stored secret or key set that its authentication
    /// method needs.
    #[error("invalid client configuration")]
    InvalidClientConfig,

    /// The presented client secret differs from the stored one.
    #[error("client secret did not match")]
    ClientSecretMismatch,

    /// The kind of credentials presented does not match the client's
    /// registered authentication method.
    #[error("authentication method mismatch")]
    AuthenticationMethodMismatch,

    /// The client assertion's signature did not verify.
    #[error("invalid assertion signature")]
    InvalidAssertionSignature,

    /// The client's key set could not be fetched from its URI.
    #[error("failed to fetch jwks")]
    JwksFetchFailed,
}
219
/// Result of extracting client authentication from a request.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientAuthorization<F = ()> {
    /// The credentials the client presented.
    pub credentials: Credentials,
    /// The remaining form fields. This is `None` when the request body did not
    /// have a form content type.
    pub form: Option<F>,
}
225
226impl<F> ClientAuthorization<F> {
227 #[must_use]
229 pub fn client_id(&self) -> &str {
230 self.credentials.client_id()
231 }
232}
233
/// Reasons why client authentication could not be extracted from a request.
#[derive(Debug)]
pub enum ClientAuthorizationError {
    /// An `Authorization` header was present but could not be parsed as Basic.
    InvalidHeader,
    /// The form body could not be deserialized.
    BadForm(FailedToDeserializeForm),
    /// The form's `client_id` differs from the one carried by the credential.
    ClientIdMismatch { credential: String, form: String },
    /// The presented `client_assertion_type` is not supported.
    UnsupportedClientAssertion { client_assertion_type: String },
    /// No client authentication parameters were presented at all.
    MissingCredentials,
    /// The authentication parameters form an unrecognised combination, for
    /// example several authentication methods in one request.
    InvalidRequest,
    /// The `client_assertion` could not be parsed, or has no string `sub`
    /// claim.
    InvalidAssertion,
    /// The form extractor failed for a reason other than content type or
    /// deserialization.
    Internal(Box<dyn std::error::Error>),
}
245
246impl IntoResponse for ClientAuthorizationError {
247 fn into_response(self) -> axum::response::Response {
248 match self {
249 ClientAuthorizationError::InvalidHeader => (
250 StatusCode::BAD_REQUEST,
251 Json(ClientError::new(
252 ClientErrorCode::InvalidRequest,
253 "Invalid Authorization header",
254 )),
255 ),
256
257 ClientAuthorizationError::BadForm(err) => (
258 StatusCode::BAD_REQUEST,
259 Json(
260 ClientError::from(ClientErrorCode::InvalidRequest)
261 .with_description(format!("{err}")),
262 ),
263 ),
264
265 ClientAuthorizationError::ClientIdMismatch { form, credential } => {
266 let description = format!(
267 "client_id in form ({form:?}) does not match credential ({credential:?})"
268 );
269
270 (
271 StatusCode::BAD_REQUEST,
272 Json(
273 ClientError::from(ClientErrorCode::InvalidGrant)
274 .with_description(description),
275 ),
276 )
277 }
278
279 ClientAuthorizationError::UnsupportedClientAssertion {
280 client_assertion_type,
281 } => (
282 StatusCode::BAD_REQUEST,
283 Json(
284 ClientError::from(ClientErrorCode::InvalidRequest).with_description(format!(
285 "Unsupported client_assertion_type: {client_assertion_type}",
286 )),
287 ),
288 ),
289
290 ClientAuthorizationError::MissingCredentials => (
291 StatusCode::BAD_REQUEST,
292 Json(ClientError::new(
293 ClientErrorCode::InvalidRequest,
294 "No credentials were presented",
295 )),
296 ),
297
298 ClientAuthorizationError::InvalidRequest => (
299 StatusCode::BAD_REQUEST,
300 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
301 ),
302
303 ClientAuthorizationError::InvalidAssertion => (
304 StatusCode::BAD_REQUEST,
305 Json(ClientError::new(
306 ClientErrorCode::InvalidRequest,
307 "Invalid client_assertion",
308 )),
309 ),
310
311 ClientAuthorizationError::Internal(e) => (
312 StatusCode::INTERNAL_SERVER_ERROR,
313 Json(
314 ClientError::from(ClientErrorCode::ServerError)
315 .with_description(format!("{e}")),
316 ),
317 ),
318 }
319 .into_response()
320 }
321}
322
323impl<S, F> FromRequest<S> for ClientAuthorization<F>
324where
325 F: DeserializeOwned,
326 S: Send + Sync,
327{
328 type Rejection = ClientAuthorizationError;
329
330 #[allow(clippy::too_many_lines)]
331 async fn from_request(
332 req: Request<axum::body::Body>,
333 state: &S,
334 ) -> Result<Self, Self::Rejection> {
335 let (mut parts, body) = req.into_parts();
337
338 let header =
339 TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;
340
341 let credentials_from_header = match header {
343 Ok(header) => Some((header.username().to_owned(), header.password().to_owned())),
344 Err(err) => match err.reason() {
345 TypedHeaderRejectionReason::Missing => None,
347 _ => return Err(ClientAuthorizationError::InvalidHeader),
349 },
350 };
351
352 let req = Request::from_parts(parts, body);
354
355 let (
357 client_id_from_form,
358 client_secret_from_form,
359 client_assertion_type,
360 client_assertion,
361 form,
362 ) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
363 Ok(Form(form)) => (
364 form.client_id,
365 form.client_secret,
366 form.client_assertion_type,
367 form.client_assertion,
368 Some(form.inner),
369 ),
370 Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
372 Err(FormRejection::FailedToDeserializeForm(err)) => {
374 return Err(ClientAuthorizationError::BadForm(err));
375 }
376 Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
378 };
379
380 let credentials = match (
382 credentials_from_header,
383 client_id_from_form,
384 client_secret_from_form,
385 client_assertion_type,
386 client_assertion,
387 ) {
388 (Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
389 if let Some(client_id_from_form) = client_id_from_form {
390 if client_id != client_id_from_form {
392 return Err(ClientAuthorizationError::ClientIdMismatch {
393 credential: client_id,
394 form: client_id_from_form,
395 });
396 }
397 }
398
399 Credentials::ClientSecretBasic {
400 client_id,
401 client_secret,
402 }
403 }
404
405 (None, Some(client_id), Some(client_secret), None, None) => {
406 Credentials::ClientSecretPost {
408 client_id,
409 client_secret,
410 }
411 }
412
413 (None, Some(client_id), None, None, None) => {
414 Credentials::None { client_id }
416 }
417
418 (
419 None,
420 client_id_from_form,
421 None,
422 Some(client_assertion_type),
423 Some(client_assertion),
424 ) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
425 let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
427 .map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
428
429 let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
430 client_id.clone()
431 } else {
432 return Err(ClientAuthorizationError::InvalidAssertion);
433 };
434
435 if let Some(client_id_from_form) = client_id_from_form {
436 if client_id != client_id_from_form {
438 return Err(ClientAuthorizationError::ClientIdMismatch {
439 credential: client_id,
440 form: client_id_from_form,
441 });
442 }
443 }
444
445 Credentials::ClientAssertionJwtBearer {
446 client_id,
447 jwt: Box::new(jwt),
448 }
449 }
450
451 (None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
452 return Err(ClientAuthorizationError::UnsupportedClientAssertion {
454 client_assertion_type,
455 });
456 }
457
458 (None, None, None, None, None) => {
459 return Err(ClientAuthorizationError::MissingCredentials);
461 }
462
463 _ => {
464 return Err(ClientAuthorizationError::InvalidRequest);
466 }
467 };
468
469 Ok(ClientAuthorization { credentials, form })
470 }
471}
472
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use http::{Method, Request};

    use super::*;

    /// `Authorization` header value for `client-id:client-secret`.
    const BASIC_AUTH: &str = "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=";

    /// HS256 client assertion whose `iss` and `sub` are `client-id`, signed
    /// with the shared secret `client-secret`.
    const CLIENT_ASSERTION: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";

    /// Builds a `POST` request with a form-urlencoded `body` and, optionally,
    /// the given raw `Authorization` header value.
    fn form_request(body: impl Into<String>, authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().method(Method::POST).header(
            http::header::CONTENT_TYPE,
            mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
        );
        if let Some(authorization) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, authorization);
        }
        builder.body(Body::new(body.into())).unwrap()
    }

    /// Runs the extractor with a JSON-value form payload.
    async fn authorize(
        req: Request<Body>,
    ) -> Result<ClientAuthorization<serde_json::Value>, ClientAuthorizationError> {
        ClientAuthorization::<serde_json::Value>::from_request(req, &()).await
    }

    #[tokio::test]
    async fn none_test() {
        let req = form_request("client_id=client-id&foo=bar", None);

        assert_eq!(
            authorize(req).await.unwrap(),
            ClientAuthorization {
                credentials: Credentials::None {
                    client_id: "client-id".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    #[tokio::test]
    async fn client_secret_basic_test() {
        let expected = || ClientAuthorization {
            credentials: Credentials::ClientSecretBasic {
                client_id: "client-id".to_owned(),
                client_secret: "client-secret".to_owned(),
            },
            form: Some(serde_json::json!({"foo": "bar"})),
        };

        let req = form_request("foo=bar", Some(BASIC_AUTH));
        assert_eq!(authorize(req).await.unwrap(), expected());

        // A matching client_id in the form is accepted.
        let req = form_request("client_id=client-id&foo=bar", Some(BASIC_AUTH));
        assert_eq!(authorize(req).await.unwrap(), expected());

        // A different client_id in the form is rejected.
        let req = form_request("client_id=mismatch-id&foo=bar", Some(BASIC_AUTH));
        assert!(matches!(
            authorize(req).await,
            Err(ClientAuthorizationError::ClientIdMismatch { .. }),
        ));

        let req = form_request("foo=bar", Some("Basic invalid"));
        assert!(matches!(
            authorize(req).await,
            Err(ClientAuthorizationError::InvalidHeader),
        ));
    }

    #[tokio::test]
    async fn client_secret_post_test() {
        let req = form_request(
            "client_id=client-id&client_secret=client-secret&foo=bar",
            None,
        );

        assert_eq!(
            authorize(req).await.unwrap(),
            ClientAuthorization {
                credentials: Credentials::ClientSecretPost {
                    client_id: "client-id".to_owned(),
                    client_secret: "client-secret".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    #[tokio::test]
    async fn client_assertion_test() {
        let req = form_request(
            format!(
                "client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={CLIENT_ASSERTION}&foo=bar",
            ),
            None,
        );

        let authz = authorize(req).await.unwrap();
        assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));

        let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials else {
            panic!("expected a JWT client_assertion");
        };

        assert_eq!(client_id, "client-id");
        jwt.verify_with_shared_secret(b"client-secret".to_vec())
            .unwrap();
    }

    #[tokio::test]
    async fn client_assertion_client_id_mismatch_test() {
        let req = form_request(
            format!(
                "client_id=other-id&client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={CLIENT_ASSERTION}",
            ),
            None,
        );

        assert!(matches!(
            authorize(req).await,
            Err(ClientAuthorizationError::ClientIdMismatch { .. }),
        ));
    }

    #[tokio::test]
    async fn invalid_client_assertion_test() {
        let req = form_request(
            format!("client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion=not-a-jwt"),
            None,
        );

        assert!(matches!(
            authorize(req).await,
            Err(ClientAuthorizationError::InvalidAssertion),
        ));
    }

    #[tokio::test]
    async fn unsupported_client_assertion_test() {
        let req = form_request(
            "client_assertion_type=urn:example:unsupported&client_assertion=abc&foo=bar",
            None,
        );

        assert!(matches!(
            authorize(req).await,
            Err(ClientAuthorizationError::UnsupportedClientAssertion { client_assertion_type })
                if client_assertion_type == "urn:example:unsupported"
        ));
    }

    #[tokio::test]
    async fn missing_credentials_test() {
        let req = form_request("foo=bar", None);

        assert!(matches!(
            authorize(req).await,
            Err(ClientAuthorizationError::MissingCredentials),
        ));
    }

    #[tokio::test]
    async fn mixed_methods_test() {
        // Basic header plus a form client_secret is two methods at once.
        let req = form_request("client_secret=client-secret&foo=bar", Some(BASIC_AUTH));

        assert!(matches!(
            authorize(req).await,
            Err(ClientAuthorizationError::InvalidRequest),
        ));
    }
}