mas_handlers/graphql/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7#![allow(clippy::module_name_repetitions)]
8
9use std::{net::IpAddr, ops::Deref, sync::Arc};
10
11use async_graphql::{
12    EmptySubscription, InputObject,
13    extensions::Tracing,
14    http::{GraphQLPlaygroundConfig, MultipartOptions, playground_source},
15};
16use axum::{
17    Extension, Json,
18    body::Body,
19    extract::{RawQuery, State as AxumState},
20    http::StatusCode,
21    response::{Html, IntoResponse, Response},
22};
23use axum_extra::typed_header::TypedHeader;
24use chrono::{DateTime, Utc};
25use futures_util::TryStreamExt;
26use headers::{Authorization, ContentType, HeaderValue, authorization::Bearer};
27use hyper::header::CACHE_CONTROL;
28use mas_axum_utils::{
29    FancyError, SessionInfo, SessionInfoExt, cookies::CookieJar, sentry::SentryEventID,
30};
31use mas_data_model::{BrowserSession, Session, SiteConfig, User};
32use mas_matrix::HomeserverConnection;
33use mas_policy::{InstantiateError, Policy, PolicyFactory};
34use mas_router::UrlBuilder;
35use mas_storage::{BoxClock, BoxRepository, BoxRng, Clock, RepositoryError, SystemClock};
36use mas_storage_pg::PgRepository;
37use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
38use rand::{SeedableRng, thread_rng};
39use rand_chacha::ChaChaRng;
40use sqlx::PgPool;
41use tracing::{Instrument, info_span};
42use ulid::Ulid;
43
44mod model;
45mod mutations;
46mod query;
47mod state;
48
49pub use self::state::{BoxState, State};
50use self::{
51    model::{CreationEvent, Node},
52    mutations::Mutation,
53    query::Query,
54};
55use crate::{
56    BoundActivityTracker, Limiter, RequesterFingerprint, impl_from_error_for_route,
57    passwords::PasswordManager,
58};
59
60#[cfg(test)]
61mod tests;
62
/// Per-listener options needed by the GraphQL handlers.
///
/// These are configured on each listener rather than globally, which is why
/// they reach the handlers through request extensions instead of the router
/// state.
#[derive(Clone, Debug)]
pub struct ExtraRouterParameters {
    /// Whether requests may authenticate with an OAuth 2.0 bearer access
    /// token instead of a browser session cookie.
    pub undocumented_oauth2_access: bool,
}
69
70struct GraphQLState {
71    pool: PgPool,
72    homeserver_connection: Arc<dyn HomeserverConnection>,
73    policy_factory: Arc<PolicyFactory>,
74    site_config: SiteConfig,
75    password_manager: PasswordManager,
76    url_builder: UrlBuilder,
77    limiter: Limiter,
78}
79
80#[async_trait::async_trait]
81impl state::State for GraphQLState {
82    async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
83        let repo = PgRepository::from_pool(&self.pool)
84            .await
85            .map_err(RepositoryError::from_error)?;
86
87        Ok(repo.boxed())
88    }
89
90    async fn policy(&self) -> Result<Policy, InstantiateError> {
91        self.policy_factory.instantiate().await
92    }
93
94    fn password_manager(&self) -> PasswordManager {
95        self.password_manager.clone()
96    }
97
98    fn site_config(&self) -> &SiteConfig {
99        &self.site_config
100    }
101
102    fn homeserver_connection(&self) -> &dyn HomeserverConnection {
103        self.homeserver_connection.as_ref()
104    }
105
106    fn url_builder(&self) -> &UrlBuilder {
107        &self.url_builder
108    }
109
110    fn limiter(&self) -> &Limiter {
111        &self.limiter
112    }
113
114    fn clock(&self) -> BoxClock {
115        let clock = SystemClock::default();
116        Box::new(clock)
117    }
118
119    fn rng(&self) -> BoxRng {
120        #[allow(clippy::disallowed_methods)]
121        let rng = thread_rng();
122
123        let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng");
124        Box::new(rng)
125    }
126}
127
128#[must_use]
129pub fn schema(
130    pool: &PgPool,
131    policy_factory: &Arc<PolicyFactory>,
132    homeserver_connection: impl HomeserverConnection + 'static,
133    site_config: SiteConfig,
134    password_manager: PasswordManager,
135    url_builder: UrlBuilder,
136    limiter: Limiter,
137) -> Schema {
138    let state = GraphQLState {
139        pool: pool.clone(),
140        policy_factory: Arc::clone(policy_factory),
141        homeserver_connection: Arc::new(homeserver_connection),
142        site_config,
143        password_manager,
144        url_builder,
145        limiter,
146    };
147    let state: BoxState = Box::new(state);
148
149    schema_builder().extension(Tracing).data(state).finish()
150}
151
152fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
153    let span = info_span!(
154        "GraphQL operation",
155        "otel.name" = tracing::field::Empty,
156        "otel.kind" = "server",
157        { GRAPHQL_DOCUMENT } = request.query,
158        { GRAPHQL_OPERATION_NAME } = tracing::field::Empty,
159    );
160
161    if let Some(name) = &request.operation_name {
162        span.record("otel.name", name);
163        span.record(GRAPHQL_OPERATION_NAME, name);
164    }
165
166    span
167}
168
169#[derive(thiserror::Error, Debug)]
170pub enum RouteError {
171    #[error(transparent)]
172    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
173
174    #[error("Loading of some database objects failed")]
175    LoadFailed,
176
177    #[error("Invalid access token")]
178    InvalidToken,
179
180    #[error("Missing scope")]
181    MissingScope,
182
183    #[error(transparent)]
184    ParseRequest(#[from] async_graphql::ParseRequestError),
185}
186
187impl_from_error_for_route!(mas_storage::RepositoryError);
188
189impl IntoResponse for RouteError {
190    fn into_response(self) -> Response {
191        let event_id = sentry::capture_error(&self);
192
193        let response = match self {
194            e @ (Self::Internal(_) | Self::LoadFailed) => {
195                let error = async_graphql::Error::new_with_source(e);
196                (
197                    StatusCode::INTERNAL_SERVER_ERROR,
198                    Json(serde_json::json!({"errors": [error]})),
199                )
200                    .into_response()
201            }
202
203            Self::InvalidToken => {
204                let error = async_graphql::Error::new("Invalid token");
205                (
206                    StatusCode::UNAUTHORIZED,
207                    Json(serde_json::json!({"errors": [error]})),
208                )
209                    .into_response()
210            }
211
212            Self::MissingScope => {
213                let error = async_graphql::Error::new("Missing urn:mas:graphql:* scope");
214                (
215                    StatusCode::UNAUTHORIZED,
216                    Json(serde_json::json!({"errors": [error]})),
217                )
218                    .into_response()
219            }
220
221            Self::ParseRequest(e) => {
222                let error = async_graphql::Error::new_with_source(e);
223                (
224                    StatusCode::BAD_REQUEST,
225                    Json(serde_json::json!({"errors": [error]})),
226                )
227                    .into_response()
228            }
229        };
230
231        (SentryEventID::from(event_id), response).into_response()
232    }
233}
234
235async fn get_requester(
236    undocumented_oauth2_access: bool,
237    clock: &impl Clock,
238    activity_tracker: &BoundActivityTracker,
239    mut repo: BoxRepository,
240    session_info: SessionInfo,
241    user_agent: Option<String>,
242    token: Option<&str>,
243) -> Result<Requester, RouteError> {
244    let entity = if let Some(token) = token {
245        // If we haven't enabled undocumented_oauth2_access on the listener, we bail out
246        if !undocumented_oauth2_access {
247            return Err(RouteError::InvalidToken);
248        }
249
250        let token = repo
251            .oauth2_access_token()
252            .find_by_token(token)
253            .await?
254            .ok_or(RouteError::InvalidToken)?;
255
256        let session = repo
257            .oauth2_session()
258            .lookup(token.session_id)
259            .await?
260            .ok_or(RouteError::LoadFailed)?;
261
262        activity_tracker
263            .record_oauth2_session(clock, &session)
264            .await;
265
266        // Load the user if there is one
267        let user = if let Some(user_id) = session.user_id {
268            let user = repo
269                .user()
270                .lookup(user_id)
271                .await?
272                .ok_or(RouteError::LoadFailed)?;
273            Some(user)
274        } else {
275            None
276        };
277
278        // If there is a user for this session, check that it is not locked
279        let user_valid = user.as_ref().is_none_or(User::is_valid);
280
281        if !token.is_valid(clock.now()) || !session.is_valid() || !user_valid {
282            return Err(RouteError::InvalidToken);
283        }
284
285        if !session.scope.contains("urn:mas:graphql:*") {
286            return Err(RouteError::MissingScope);
287        }
288
289        RequestingEntity::OAuth2Session(Box::new((session, user)))
290    } else {
291        let maybe_session = session_info.load_active_session(&mut repo).await?;
292
293        if let Some(session) = maybe_session.as_ref() {
294            activity_tracker
295                .record_browser_session(clock, session)
296                .await;
297        }
298
299        RequestingEntity::from(maybe_session)
300    };
301
302    let requester = Requester {
303        entity,
304        ip_address: activity_tracker.ip(),
305        user_agent,
306    };
307
308    repo.cancel().await?;
309    Ok(requester)
310}
311
312pub async fn post(
313    AxumState(schema): AxumState<Schema>,
314    Extension(ExtraRouterParameters {
315        undocumented_oauth2_access,
316    }): Extension<ExtraRouterParameters>,
317    clock: BoxClock,
318    repo: BoxRepository,
319    activity_tracker: BoundActivityTracker,
320    cookie_jar: CookieJar,
321    content_type: Option<TypedHeader<ContentType>>,
322    authorization: Option<TypedHeader<Authorization<Bearer>>>,
323    user_agent: Option<TypedHeader<headers::UserAgent>>,
324    body: Body,
325) -> Result<impl IntoResponse, RouteError> {
326    let body = body.into_data_stream();
327    let token = authorization
328        .as_ref()
329        .map(|TypedHeader(Authorization(bearer))| bearer.token());
330    let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
331    let (session_info, _cookie_jar) = cookie_jar.session_info();
332    let requester = get_requester(
333        undocumented_oauth2_access,
334        &clock,
335        &activity_tracker,
336        repo,
337        session_info,
338        user_agent,
339        token,
340    )
341    .await?;
342
343    let content_type = content_type.map(|TypedHeader(h)| h.to_string());
344
345    let request = async_graphql::http::receive_body(
346        content_type,
347        body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
348            .into_async_read(),
349        MultipartOptions::default(),
350    )
351    .await?
352    .data(requester); // XXX: this should probably return another error response?
353
354    let span = span_for_graphql_request(&request);
355    let response = schema.execute(request).instrument(span).await;
356
357    let cache_control = response
358        .cache_control
359        .value()
360        .and_then(|v| HeaderValue::from_str(&v).ok())
361        .map(|h| [(CACHE_CONTROL, h)]);
362
363    let headers = response.http_headers.clone();
364
365    Ok((headers, cache_control, Json(response)))
366}
367
368pub async fn get(
369    AxumState(schema): AxumState<Schema>,
370    Extension(ExtraRouterParameters {
371        undocumented_oauth2_access,
372    }): Extension<ExtraRouterParameters>,
373    clock: BoxClock,
374    repo: BoxRepository,
375    activity_tracker: BoundActivityTracker,
376    cookie_jar: CookieJar,
377    authorization: Option<TypedHeader<Authorization<Bearer>>>,
378    user_agent: Option<TypedHeader<headers::UserAgent>>,
379    RawQuery(query): RawQuery,
380) -> Result<impl IntoResponse, FancyError> {
381    let token = authorization
382        .as_ref()
383        .map(|TypedHeader(Authorization(bearer))| bearer.token());
384    let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
385    let (session_info, _cookie_jar) = cookie_jar.session_info();
386    let requester = get_requester(
387        undocumented_oauth2_access,
388        &clock,
389        &activity_tracker,
390        repo,
391        session_info,
392        user_agent,
393        token,
394    )
395    .await?;
396
397    let request =
398        async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
399
400    let span = span_for_graphql_request(&request);
401    let response = schema.execute(request).instrument(span).await;
402
403    let cache_control = response
404        .cache_control
405        .value()
406        .and_then(|v| HeaderValue::from_str(&v).ok())
407        .map(|h| [(CACHE_CONTROL, h)]);
408
409    let headers = response.http_headers.clone();
410
411    Ok((headers, cache_control, Json(response)))
412}
413
414pub async fn playground() -> impl IntoResponse {
415    Html(playground_source(
416        GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
417    ))
418}
419
420pub type Schema = async_graphql::Schema<Query, Mutation, EmptySubscription>;
421pub type SchemaBuilder = async_graphql::SchemaBuilder<Query, Mutation, EmptySubscription>;
422
423#[must_use]
424pub fn schema_builder() -> SchemaBuilder {
425    async_graphql::Schema::build(Query::new(), Mutation::new(), EmptySubscription)
426        .register_output_type::<Node>()
427        .register_output_type::<CreationEvent>()
428}
429
430pub struct Requester {
431    entity: RequestingEntity,
432    ip_address: Option<IpAddr>,
433    user_agent: Option<String>,
434}
435
436impl Requester {
437    pub fn fingerprint(&self) -> RequesterFingerprint {
438        if let Some(ip) = self.ip_address {
439            RequesterFingerprint::new(ip)
440        } else {
441            RequesterFingerprint::EMPTY
442        }
443    }
444
445    pub fn for_policy(&self) -> mas_policy::Requester {
446        mas_policy::Requester {
447            ip_address: self.ip_address,
448            user_agent: self.user_agent.clone(),
449        }
450    }
451}
452
453impl Deref for Requester {
454    type Target = RequestingEntity;
455
456    fn deref(&self) -> &Self::Target {
457        &self.entity
458    }
459}
460
461/// The identity of the requester.
462#[derive(Debug, Clone, Default, PartialEq, Eq)]
463pub enum RequestingEntity {
464    /// The requester presented no authentication information.
465    #[default]
466    Anonymous,
467
468    /// The requester is a browser session, stored in a cookie.
469    BrowserSession(Box<BrowserSession>),
470
471    /// The requester is a `OAuth2` session, with an access token.
472    OAuth2Session(Box<(Session, Option<User>)>),
473}
474
475trait OwnerId {
476    fn owner_id(&self) -> Option<Ulid>;
477}
478
479impl OwnerId for User {
480    fn owner_id(&self) -> Option<Ulid> {
481        Some(self.id)
482    }
483}
484
485impl OwnerId for BrowserSession {
486    fn owner_id(&self) -> Option<Ulid> {
487        Some(self.user.id)
488    }
489}
490
491impl OwnerId for mas_data_model::UserEmail {
492    fn owner_id(&self) -> Option<Ulid> {
493        Some(self.user_id)
494    }
495}
496
497impl OwnerId for Session {
498    fn owner_id(&self) -> Option<Ulid> {
499        self.user_id
500    }
501}
502
503impl OwnerId for mas_data_model::CompatSession {
504    fn owner_id(&self) -> Option<Ulid> {
505        Some(self.user_id)
506    }
507}
508
509impl OwnerId for mas_data_model::UpstreamOAuthLink {
510    fn owner_id(&self) -> Option<Ulid> {
511        self.user_id
512    }
513}
514
515/// A dumb wrapper around a `Ulid` to implement `OwnerId` for it.
516pub struct UserId(Ulid);
517
518impl OwnerId for UserId {
519    fn owner_id(&self) -> Option<Ulid> {
520        Some(self.0)
521    }
522}
523
524impl RequestingEntity {
525    fn browser_session(&self) -> Option<&BrowserSession> {
526        match self {
527            Self::BrowserSession(session) => Some(session),
528            Self::OAuth2Session(_) | Self::Anonymous => None,
529        }
530    }
531
532    fn user(&self) -> Option<&User> {
533        match self {
534            Self::BrowserSession(session) => Some(&session.user),
535            Self::OAuth2Session(tuple) => tuple.1.as_ref(),
536            Self::Anonymous => None,
537        }
538    }
539
540    fn oauth2_session(&self) -> Option<&Session> {
541        match self {
542            Self::OAuth2Session(tuple) => Some(&tuple.0),
543            Self::BrowserSession(_) | Self::Anonymous => None,
544        }
545    }
546
547    /// Returns true if the requester can access the resource.
548    fn is_owner_or_admin(&self, resource: &impl OwnerId) -> bool {
549        // If the requester is an admin, they can do anything.
550        if self.is_admin() {
551            return true;
552        }
553
554        // Otherwise, they must be the owner of the resource.
555        let Some(owner_id) = resource.owner_id() else {
556            return false;
557        };
558
559        let Some(user) = self.user() else {
560            return false;
561        };
562
563        user.id == owner_id
564    }
565
566    fn is_admin(&self) -> bool {
567        match self {
568            Self::OAuth2Session(tuple) => {
569                // TODO: is this the right scope?
570                // This has to be in sync with the policy
571                tuple.0.scope.contains("urn:mas:admin")
572            }
573            Self::BrowserSession(_) | Self::Anonymous => false,
574        }
575    }
576
577    fn is_unauthenticated(&self) -> bool {
578        matches!(self, Self::Anonymous)
579    }
580}
581
582impl From<BrowserSession> for RequestingEntity {
583    fn from(session: BrowserSession) -> Self {
584        Self::BrowserSession(Box::new(session))
585    }
586}
587
588impl<T> From<Option<T>> for RequestingEntity
589where
590    T: Into<RequestingEntity>,
591{
592    fn from(session: Option<T>) -> Self {
593        session.map(Into::into).unwrap_or_default()
594    }
595}
596
597/// A filter for dates, with a lower bound and an upper bound
598#[derive(InputObject, Default, Clone, Copy)]
599pub struct DateFilter {
600    /// The lower bound of the date range
601    after: Option<DateTime<Utc>>,
602
603    /// The upper bound of the date range
604    before: Option<DateTime<Utc>>,
605}