// mas_handlers/upstream_oauth2/callback.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
6
7use axum::{
8    Form,
9    extract::{Path, State},
10    http::Method,
11    response::{Html, IntoResponse, Response},
12};
13use hyper::StatusCode;
14use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID};
15use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderResponseMode};
16use mas_jose::claims::TokenHash;
17use mas_keystore::{Encrypter, Keystore};
18use mas_oidc_client::requests::jose::JwtVerificationData;
19use mas_router::UrlBuilder;
20use mas_storage::{
21    BoxClock, BoxRepository, BoxRng, Clock,
22    upstream_oauth2::{
23        UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
24        UpstreamOAuthSessionRepository,
25    },
26};
27use mas_templates::{FormPostContext, Templates};
28use oauth2_types::{errors::ClientErrorCode, requests::AccessTokenRequest};
29use serde::{Deserialize, Serialize};
30use serde_json::json;
31use thiserror::Error;
32use ulid::Ulid;
33
34use super::{
35    UpstreamSessionsCookie,
36    cache::LazyProviderInfos,
37    client_credentials_for_provider,
38    template::{AttributeMappingContext, environment},
39};
40use crate::{PreferredLanguage, impl_from_error_for_route, upstream_oauth2::cache::MetadataCache};
41
42#[derive(Serialize, Deserialize)]
43pub struct Params {
44    #[serde(skip_serializing_if = "Option::is_none")]
45    state: Option<String>,
46
47    /// An extra parameter to track whether the POST request was re-made by us
48    /// to the same URL to escape Same-Site cookies restrictions
49    #[serde(default)]
50    did_mas_repost_to_itself: bool,
51
52    #[serde(skip_serializing_if = "Option::is_none")]
53    code: Option<String>,
54
55    #[serde(skip_serializing_if = "Option::is_none")]
56    error: Option<ClientErrorCode>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    error_description: Option<String>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    error_uri: Option<String>,
61
62    #[serde(flatten)]
63    extra_callback_parameters: Option<serde_json::Value>,
64}
65
66impl Params {
67    /// Returns true if none of the fields are set
68    pub fn is_empty(&self) -> bool {
69        self.state.is_none()
70            && self.code.is_none()
71            && self.error.is_none()
72            && self.error_description.is_none()
73            && self.error_uri.is_none()
74    }
75}
76
77#[derive(Debug, Error)]
78pub(crate) enum RouteError {
79    #[error("Session not found")]
80    SessionNotFound,
81
82    #[error("Provider not found")]
83    ProviderNotFound,
84
85    #[error("Provider mismatch")]
86    ProviderMismatch,
87
88    #[error("Session already completed")]
89    AlreadyCompleted,
90
91    #[error("State parameter mismatch")]
92    StateMismatch,
93
94    #[error("Missing state parameter")]
95    MissingState,
96
97    #[error("Missing code parameter")]
98    MissingCode,
99
100    #[error("Could not extract subject from ID token")]
101    ExtractSubject(#[source] minijinja::Error),
102
103    #[error("Subject is empty")]
104    EmptySubject,
105
106    #[error("Error from the provider: {error}")]
107    ClientError {
108        error: ClientErrorCode,
109        error_description: Option<String>,
110    },
111
112    #[error("Missing session cookie")]
113    MissingCookie,
114
115    #[error("Missing query parameters")]
116    MissingQueryParams,
117
118    #[error("Missing form parameters")]
119    MissingFormParams,
120
121    #[error("Invalid response mode, expected '{expected}'")]
122    InvalidResponseMode {
123        expected: UpstreamOAuthProviderResponseMode,
124    },
125
126    #[error(transparent)]
127    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
128}
129
130impl_from_error_for_route!(mas_templates::TemplateError);
131impl_from_error_for_route!(mas_storage::RepositoryError);
132impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
133impl_from_error_for_route!(mas_oidc_client::error::JwksError);
134impl_from_error_for_route!(mas_oidc_client::error::TokenRequestError);
135impl_from_error_for_route!(mas_oidc_client::error::IdTokenError);
136impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
137impl_from_error_for_route!(super::ProviderCredentialsError);
138impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
139
140impl IntoResponse for RouteError {
141    fn into_response(self) -> axum::response::Response {
142        let event_id = sentry::capture_error(&self);
143        let response = match self {
144            Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
145            Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
146            Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
147            e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
148        };
149
150        (SentryEventID::from(event_id), response).into_response()
151    }
152}
153
154#[tracing::instrument(
155    name = "handlers.upstream_oauth2.callback.handler",
156    fields(upstream_oauth_provider.id = %provider_id),
157    skip_all,
158    err,
159)]
160#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
161pub(crate) async fn handler(
162    mut rng: BoxRng,
163    clock: BoxClock,
164    State(metadata_cache): State<MetadataCache>,
165    mut repo: BoxRepository,
166    State(url_builder): State<UrlBuilder>,
167    State(encrypter): State<Encrypter>,
168    State(keystore): State<Keystore>,
169    State(client): State<reqwest::Client>,
170    State(templates): State<Templates>,
171    method: Method,
172    PreferredLanguage(locale): PreferredLanguage,
173    cookie_jar: CookieJar,
174    Path(provider_id): Path<Ulid>,
175    Form(params): Form<Params>,
176) -> Result<Response, RouteError> {
177    let provider = repo
178        .upstream_oauth_provider()
179        .lookup(provider_id)
180        .await?
181        .filter(UpstreamOAuthProvider::enabled)
182        .ok_or(RouteError::ProviderNotFound)?;
183
184    let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
185
186    if params.is_empty() {
187        if let Method::GET = method {
188            return Err(RouteError::MissingQueryParams);
189        }
190
191        return Err(RouteError::MissingFormParams);
192    }
193
194    // The `Form` extractor will use the body of the request for POST requests and
195    // the query parameters for GET requests. We need to then look at the method do
196    // make sure it matches the expected `response_mode`
197    match (provider.response_mode, method) {
198        (Some(UpstreamOAuthProviderResponseMode::FormPost) | None, Method::POST) => {
199            // We set the cookies with a `Same-Site` policy set to `Lax`, so because this is
200            // usually a cross-site form POST, we need to render a form with the
201            // same values, which posts back to the same URL. However, there are
202            // other valid reasons for the cookie to be missing, so to track whether we did
203            // this POST ourselves, we set a flag.
204            if sessions_cookie.is_empty() && !params.did_mas_repost_to_itself {
205                let params = Params {
206                    did_mas_repost_to_itself: true,
207                    ..params
208                };
209                let context = FormPostContext::new_for_current_url(params).with_language(&locale);
210                let html = templates.render_form_post(&context)?;
211                return Ok(Html(html).into_response());
212            }
213        }
214        (None, _) | (Some(UpstreamOAuthProviderResponseMode::Query), Method::GET) => {}
215        (Some(expected), _) => return Err(RouteError::InvalidResponseMode { expected }),
216    }
217
218    if let Some(error) = params.error {
219        return Err(RouteError::ClientError {
220            error,
221            error_description: params.error_description.clone(),
222        });
223    }
224
225    let Some(state) = params.state else {
226        return Err(RouteError::MissingState);
227    };
228
229    let (session_id, _post_auth_action) = sessions_cookie
230        .find_session(provider_id, &state)
231        .map_err(|_| RouteError::MissingCookie)?;
232
233    let session = repo
234        .upstream_oauth_session()
235        .lookup(session_id)
236        .await?
237        .ok_or(RouteError::SessionNotFound)?;
238
239    if provider.id != session.provider_id {
240        // The provider in the session cookie should match the one from the URL
241        return Err(RouteError::ProviderMismatch);
242    }
243
244    if state != session.state_str {
245        // The state in the session cookie should match the one from the params
246        return Err(RouteError::StateMismatch);
247    }
248
249    if !session.is_pending() {
250        // The session was already completed
251        return Err(RouteError::AlreadyCompleted);
252    }
253
254    // Let's extract the code from the params, and return if there was an error
255    let Some(code) = params.code else {
256        return Err(RouteError::MissingCode);
257    };
258
259    let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
260
261    // Figure out the client credentials
262    let client_credentials = client_credentials_for_provider(
263        &provider,
264        lazy_metadata.token_endpoint().await?,
265        &keystore,
266        &encrypter,
267    )?;
268
269    let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
270
271    let token_response = mas_oidc_client::requests::token::request_access_token(
272        &client,
273        client_credentials,
274        lazy_metadata.token_endpoint().await?,
275        AccessTokenRequest::AuthorizationCode(oauth2_types::requests::AuthorizationCodeGrant {
276            code: code.clone(),
277            redirect_uri: Some(redirect_uri),
278            code_verifier: session.code_challenge_verifier.clone(),
279        }),
280        clock.now(),
281        &mut rng,
282    )
283    .await?;
284
285    let mut jwks = None;
286
287    let mut context = AttributeMappingContext::new();
288    if let Some(id_token) = token_response.id_token.as_ref() {
289        jwks = Some(
290            mas_oidc_client::requests::jose::fetch_jwks(&client, lazy_metadata.jwks_uri().await?)
291                .await?,
292        );
293
294        let id_token_verification_data = JwtVerificationData {
295            issuer: provider.issuer.as_deref(),
296            jwks: jwks.as_ref().unwrap(),
297            signing_algorithm: &provider.id_token_signed_response_alg,
298            client_id: &provider.client_id,
299        };
300
301        // Decode and verify the ID token
302        let id_token = mas_oidc_client::requests::jose::verify_id_token(
303            id_token,
304            id_token_verification_data,
305            None,
306            clock.now(),
307        )?;
308
309        let (_headers, mut claims) = id_token.into_parts();
310
311        // Access token hash must match.
312        mas_jose::claims::AT_HASH
313            .extract_optional_with_options(
314                &mut claims,
315                TokenHash::new(
316                    id_token_verification_data.signing_algorithm,
317                    &token_response.access_token,
318                ),
319            )
320            .map_err(mas_oidc_client::error::IdTokenError::from)?;
321
322        // Code hash must match.
323        mas_jose::claims::C_HASH
324            .extract_optional_with_options(
325                &mut claims,
326                TokenHash::new(id_token_verification_data.signing_algorithm, &code),
327            )
328            .map_err(mas_oidc_client::error::IdTokenError::from)?;
329
330        // Nonce must match.
331        mas_jose::claims::NONCE
332            .extract_required_with_options(&mut claims, session.nonce.as_str())
333            .map_err(mas_oidc_client::error::IdTokenError::from)?;
334
335        context = context.with_id_token_claims(claims);
336    }
337
338    if let Some(extra_callback_parameters) = params.extra_callback_parameters.clone() {
339        context = context.with_extra_callback_parameters(extra_callback_parameters);
340    }
341
342    let userinfo = if provider.fetch_userinfo {
343        Some(json!(match &provider.userinfo_signed_response_alg {
344            Some(signing_algorithm) => {
345                let jwks = match jwks {
346                    Some(jwks) => jwks,
347                    None => {
348                        mas_oidc_client::requests::jose::fetch_jwks(
349                            &client,
350                            lazy_metadata.jwks_uri().await?,
351                        )
352                        .await?
353                    }
354                };
355
356                mas_oidc_client::requests::userinfo::fetch_userinfo(
357                    &client,
358                    lazy_metadata.userinfo_endpoint().await?,
359                    token_response.access_token.as_str(),
360                    Some(JwtVerificationData {
361                        issuer: provider.issuer.as_deref(),
362                        jwks: &jwks,
363                        signing_algorithm,
364                        client_id: &provider.client_id,
365                    }),
366                )
367                .await?
368            }
369            None => {
370                mas_oidc_client::requests::userinfo::fetch_userinfo(
371                    &client,
372                    lazy_metadata.userinfo_endpoint().await?,
373                    token_response.access_token.as_str(),
374                    None,
375                )
376                .await?
377            }
378        }))
379    } else {
380        None
381    };
382
383    if let Some(userinfo) = userinfo.clone() {
384        context = context.with_userinfo_claims(userinfo);
385    }
386
387    let context = context.build();
388
389    let env = environment();
390
391    let template = provider
392        .claims_imports
393        .subject
394        .template
395        .as_deref()
396        .unwrap_or("{{ user.sub }}");
397    let subject = env
398        .render_str(template, context.clone())
399        .map_err(RouteError::ExtractSubject)?;
400
401    if subject.is_empty() {
402        return Err(RouteError::EmptySubject);
403    }
404
405    // Look for an existing link
406    let maybe_link = repo
407        .upstream_oauth_link()
408        .find_by_subject(&provider, &subject)
409        .await?;
410
411    let link = if let Some(link) = maybe_link {
412        link
413    } else {
414        // Try to render the human account name if we have one,
415        // but just log if it fails
416        let human_account_name = provider
417            .claims_imports
418            .account_name
419            .template
420            .as_deref()
421            .and_then(|template| match env.render_str(template, context) {
422                Ok(name) => Some(name),
423                Err(e) => {
424                    tracing::warn!(
425                        error = &e as &dyn std::error::Error,
426                        "Failed to render account name"
427                    );
428                    None
429                }
430            });
431
432        repo.upstream_oauth_link()
433            .add(&mut rng, &clock, &provider, subject, human_account_name)
434            .await?
435    };
436
437    let session = repo
438        .upstream_oauth_session()
439        .complete_with_link(
440            &clock,
441            session,
442            &link,
443            token_response.id_token,
444            params.extra_callback_parameters,
445            userinfo,
446        )
447        .await?;
448
449    let cookie_jar = sessions_cookie
450        .add_link_to_session(session.id, link.id)?
451        .save(cookie_jar, &clock);
452
453    repo.save().await?;
454
455    Ok((
456        cookie_jar,
457        url_builder.redirect(&mas_router::UpstreamOAuth2Link::new(link.id)),
458    )
459        .into_response())
460}