1use axum::{
8 Form,
9 extract::{Path, State},
10 http::Method,
11 response::{Html, IntoResponse, Response},
12};
13use hyper::StatusCode;
14use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID};
15use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
16use mas_jose::claims::TokenHash;
17use mas_keystore::{Encrypter, Keystore};
18use mas_oidc_client::requests::jose::JwtVerificationData;
19use mas_router::UrlBuilder;
20use mas_storage::{
21 BoxClock, BoxRepository, BoxRng, Clock,
22 upstream_oauth2::{
23 UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
24 UpstreamOAuthSessionRepository,
25 },
26};
27use mas_templates::{FormPostContext, Templates};
28use oauth2_types::{errors::ClientErrorCode, requests::AccessTokenRequest};
29use serde::{Deserialize, Serialize};
30use serde_json::json;
31use thiserror::Error;
32use ulid::Ulid;
33
34use super::{
35 UpstreamSessionsCookie,
36 cache::LazyProviderInfos,
37 client_credentials_for_provider,
38 template::{AttributeMappingContext, environment},
39};
40use crate::{PreferredLanguage, impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
41
42#[derive(Serialize, Deserialize)]
43pub struct Params {
44 #[serde(skip_serializing_if = "Option::is_none")]
45 state: Option<String>,
46
47 #[serde(default)]
50 did_mas_repost_to_itself: bool,
51
52 #[serde(skip_serializing_if = "Option::is_none")]
53 code: Option<String>,
54
55 #[serde(skip_serializing_if = "Option::is_none")]
56 error: Option<ClientErrorCode>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 error_description: Option<String>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 error_uri: Option<String>,
61
62 #[serde(flatten)]
63 extra_callback_parameters: Option<serde_json::Value>,
64}
65
66impl Params {
67 pub fn is_empty(&self) -> bool {
69 self.state.is_none()
70 && self.code.is_none()
71 && self.error.is_none()
72 && self.error_description.is_none()
73 && self.error_uri.is_none()
74 }
75}
76
77#[derive(Debug, Error)]
78pub(crate) enum RouteError {
79 #[error("Session not found")]
80 SessionNotFound,
81
82 #[error("Provider not found")]
83 ProviderNotFound,
84
85 #[error("Provider mismatch")]
86 ProviderMismatch,
87
88 #[error("Session already completed")]
89 AlreadyCompleted,
90
91 #[error("State parameter mismatch")]
92 StateMismatch,
93
94 #[error("Missing state parameter")]
95 MissingState,
96
97 #[error("Missing code parameter")]
98 MissingCode,
99
100 #[error("Could not extract subject from ID token")]
101 ExtractSubject(#[source] minijinja::Error),
102
103 #[error("Subject is empty")]
104 EmptySubject,
105
106 #[error("Error from the provider: {error}")]
107 ClientError {
108 error: ClientErrorCode,
109 error_description: Option<String>,
110 },
111
112 #[error("Missing session cookie")]
113 MissingCookie,
114
115 #[error("Missing query parameters")]
116 MissingQueryParams,
117
118 #[error("Missing form parameters")]
119 MissingFormParams,
120
121 #[error("Invalid response mode, expected '{expected}'")]
122 InvalidResponseMode {
123 expected: UpstreamOAuthProviderResponseMode,
124 },
125
126 #[error(transparent)]
127 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
128}
129
130impl_from_error_for_route!(mas_templates::TemplateError);
131impl_from_error_for_route!(mas_storage::RepositoryError);
132impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
133impl_from_error_for_route!(mas_oidc_client::error::JwksError);
134impl_from_error_for_route!(mas_oidc_client::error::TokenRequestError);
135impl_from_error_for_route!(mas_oidc_client::error::IdTokenError);
136impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
137impl_from_error_for_route!(super::ProviderCredentialsError);
138impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
139
140impl IntoResponse for RouteError {
141 fn into_response(self) -> axum::response::Response {
142 let event_id = sentry::capture_error(&self);
143 let response = match self {
144 Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
145 Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
146 Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
147 e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
148 };
149
150 (SentryEventID::from(event_id), response).into_response()
151 }
152}
153
154#[tracing::instrument(
155 name = "handlers.upstream_oauth2.callback.handler",
156 fields(upstream_oauth_provider.id = %provider_id),
157 skip_all,
158 err,
159)]
160#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
161pub(crate) async fn handler(
162 mut rng: BoxRng,
163 clock: BoxClock,
164 State(metadata_cache): State<MetadataCache>,
165 mut repo: BoxRepository,
166 State(url_builder): State<UrlBuilder>,
167 State(encrypter): State<Encrypter>,
168 State(keystore): State<Keystore>,
169 State(client): State<reqwest::Client>,
170 State(templates): State<Templates>,
171 method: Method,
172 PreferredLanguage(locale): PreferredLanguage,
173 cookie_jar: CookieJar,
174 Path(provider_id): Path<Ulid>,
175 Form(params): Form<Params>,
176) -> Result<Response, RouteError> {
177 let provider = repo
178 .upstream_oauth_provider()
179 .lookup(provider_id)
180 .await?
181 .filter(UpstreamOAuthProvider::enabled)
182 .ok_or(RouteError::ProviderNotFound)?;
183
184 let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
185
186 if params.is_empty() {
187 if let Method::GET = method {
188 return Err(RouteError::MissingQueryParams);
189 }
190
191 return Err(RouteError::MissingFormParams);
192 }
193
194 match (provider.response_mode, method) {
198 (Some(UpstreamOAuthProviderResponseMode::FormPost) | None, Method::POST) => {
199 if sessions_cookie.is_empty() && !params.did_mas_repost_to_itself {
205 let params = Params {
206 did_mas_repost_to_itself: true,
207 ..params
208 };
209 let context = FormPostContext::new_for_current_url(params).with_language(&locale);
210 let html = templates.render_form_post(&context)?;
211 return Ok(Html(html).into_response());
212 }
213 }
214 (None, _) | (Some(UpstreamOAuthProviderResponseMode::Query), Method::GET) => {}
215 (Some(expected), _) => return Err(RouteError::InvalidResponseMode { expected }),
216 }
217
218 if let Some(error) = params.error {
219 return Err(RouteError::ClientError {
220 error,
221 error_description: params.error_description.clone(),
222 });
223 }
224
225 let Some(state) = params.state else {
226 return Err(RouteError::MissingState);
227 };
228
229 let (session_id, _post_auth_action) = sessions_cookie
230 .find_session(provider_id, &state)
231 .map_err(|_| RouteError::MissingCookie)?;
232
233 let session = repo
234 .upstream_oauth_session()
235 .lookup(session_id)
236 .await?
237 .ok_or(RouteError::SessionNotFound)?;
238
239 if provider.id != session.provider_id {
240 return Err(RouteError::ProviderMismatch);
242 }
243
244 if state != session.state_str {
245 return Err(RouteError::StateMismatch);
247 }
248
249 if !session.is_pending() {
250 return Err(RouteError::AlreadyCompleted);
252 }
253
254 let Some(code) = params.code else {
256 return Err(RouteError::MissingCode);
257 };
258
259 let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
260
261 let client_credentials = client_credentials_for_provider(
263 &provider,
264 lazy_metadata.token_endpoint().await?,
265 &keystore,
266 &encrypter,
267 )?;
268
269 let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
270
271 let token_response = mas_oidc_client::requests::token::request_access_token(
272 &client,
273 client_credentials,
274 lazy_metadata.token_endpoint().await?,
275 AccessTokenRequest::AuthorizationCode(oauth2_types::requests::AuthorizationCodeGrant {
276 code: code.clone(),
277 redirect_uri: Some(redirect_uri),
278 code_verifier: session.code_challenge_verifier.clone(),
279 }),
280 clock.now(),
281 &mut rng,
282 )
283 .await?;
284
285 let mut jwks = None;
286
287 let mut context = AttributeMappingContext::new();
288 if let Some(id_token) = token_response.id_token.as_ref() {
289 jwks = Some(
290 mas_oidc_client::requests::jose::fetch_jwks(&client, lazy_metadata.jwks_uri().await?)
291 .await?,
292 );
293
294 let id_token_verification_data = JwtVerificationData {
295 issuer: provider.issuer.as_deref(),
296 jwks: jwks.as_ref().unwrap(),
297 signing_algorithm: &provider.id_token_signed_response_alg,
298 client_id: &provider.client_id,
299 };
300
301 let id_token = mas_oidc_client::requests::jose::verify_id_token(
303 id_token,
304 id_token_verification_data,
305 None,
306 clock.now(),
307 )?;
308
309 let (_headers, mut claims) = id_token.into_parts();
310
311 mas_jose::claims::AT_HASH
313 .extract_optional_with_options(
314 &mut claims,
315 TokenHash::new(
316 id_token_verification_data.signing_algorithm,
317 &token_response.access_token,
318 ),
319 )
320 .map_err(mas_oidc_client::error::IdTokenError::from)?;
321
322 mas_jose::claims::C_HASH
324 .extract_optional_with_options(
325 &mut claims,
326 TokenHash::new(id_token_verification_data.signing_algorithm, &code),
327 )
328 .map_err(mas_oidc_client::error::IdTokenError::from)?;
329
330 mas_jose::claims::NONCE
332 .extract_required_with_options(&mut claims, session.nonce.as_str())
333 .map_err(mas_oidc_client::error::IdTokenError::from)?;
334
335 context = context.with_id_token_claims(claims);
336 }
337
338 if let Some(extra_callback_parameters) = params.extra_callback_parameters.clone() {
339 context = context.with_extra_callback_parameters(extra_callback_parameters);
340 }
341
342 let userinfo = if provider.fetch_userinfo {
343 Some(json!(match &provider.userinfo_signed_response_alg {
344 Some(signing_algorithm) => {
345 let jwks = match jwks {
346 Some(jwks) => jwks,
347 None => {
348 mas_oidc_client::requests::jose::fetch_jwks(
349 &client,
350 lazy_metadata.jwks_uri().await?,
351 )
352 .await?
353 }
354 };
355
356 mas_oidc_client::requests::userinfo::fetch_userinfo(
357 &client,
358 lazy_metadata.userinfo_endpoint().await?,
359 token_response.access_token.as_str(),
360 Some(JwtVerificationData {
361 issuer: provider.issuer.as_deref(),
362 jwks: &jwks,
363 signing_algorithm,
364 client_id: &provider.client_id,
365 }),
366 )
367 .await?
368 }
369 None => {
370 mas_oidc_client::requests::userinfo::fetch_userinfo(
371 &client,
372 lazy_metadata.userinfo_endpoint().await?,
373 token_response.access_token.as_str(),
374 None,
375 )
376 .await?
377 }
378 }))
379 } else {
380 None
381 };
382
383 if let Some(userinfo) = userinfo.clone() {
384 context = context.with_userinfo_claims(userinfo);
385 }
386
387 let context = context.build();
388
389 let env = environment();
390
391 let template = provider
392 .claims_imports
393 .subject
394 .template
395 .as_deref()
396 .unwrap_or("{{ user.sub }}");
397 let subject = env
398 .render_str(template, context.clone())
399 .map_err(RouteError::ExtractSubject)?;
400
401 if subject.is_empty() {
402 return Err(RouteError::EmptySubject);
403 }
404
405 let maybe_link = repo
407 .upstream_oauth_link()
408 .find_by_subject(&provider, &subject)
409 .await?;
410
411 let link = if let Some(link) = maybe_link {
412 link
413 } else {
414 let human_account_name = provider
417 .claims_imports
418 .account_name
419 .template
420 .as_deref()
421 .and_then(|template| match env.render_str(template, context) {
422 Ok(name) => Some(name),
423 Err(e) => {
424 tracing::warn!(
425 error = &e as &dyn std::error::Error,
426 "Failed to render account name"
427 );
428 None
429 }
430 });
431
432 repo.upstream_oauth_link()
433 .add(&mut rng, &clock, &provider, subject, human_account_name)
434 .await?
435 };
436
437 let session = repo
438 .upstream_oauth_session()
439 .complete_with_link(
440 &clock,
441 session,
442 &link,
443 token_response.id_token,
444 params.extra_callback_parameters,
445 userinfo,
446 )
447 .await?;
448
449 let cookie_jar = sessions_cookie
450 .add_link_to_session(session.id, link.id)?
451 .save(cookie_jar, &clock);
452
453 repo.save().await?;
454
455 Ok((
456 cookie_jar,
457 url_builder.redirect(&mas_router::UpstreamOAuth2Link::new(link.id)),
458 )
459 .into_response())
460}