1#![allow(clippy::module_name_repetitions)]
8
9use std::{net::IpAddr, ops::Deref, sync::Arc};
10
11use async_graphql::{
12 EmptySubscription, InputObject,
13 extensions::Tracing,
14 http::{GraphQLPlaygroundConfig, MultipartOptions, playground_source},
15};
16use axum::{
17 Extension, Json,
18 body::Body,
19 extract::{RawQuery, State as AxumState},
20 http::StatusCode,
21 response::{Html, IntoResponse, Response},
22};
23use axum_extra::typed_header::TypedHeader;
24use chrono::{DateTime, Utc};
25use futures_util::TryStreamExt;
26use headers::{Authorization, ContentType, HeaderValue, authorization::Bearer};
27use hyper::header::CACHE_CONTROL;
28use mas_axum_utils::{
29 FancyError, SessionInfo, SessionInfoExt, cookies::CookieJar, sentry::SentryEventID,
30};
31use mas_data_model::{BrowserSession, Session, SiteConfig, User};
32use mas_matrix::HomeserverConnection;
33use mas_policy::{InstantiateError, Policy, PolicyFactory};
34use mas_router::UrlBuilder;
35use mas_storage::{BoxClock, BoxRepository, BoxRng, Clock, RepositoryError, SystemClock};
36use mas_storage_pg::PgRepository;
37use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
38use rand::{SeedableRng, thread_rng};
39use rand_chacha::ChaChaRng;
40use sqlx::PgPool;
41use tracing::{Instrument, info_span};
42use ulid::Ulid;
43
44mod model;
45mod mutations;
46mod query;
47mod state;
48
49pub use self::state::{BoxState, State};
50use self::{
51 model::{CreationEvent, Node},
52 mutations::Mutation,
53 query::Query,
54};
55use crate::{
56 BoundActivityTracker, Limiter, RequesterFingerprint, impl_from_error_for_route,
57 passwords::PasswordManager,
58};
59
60#[cfg(test)]
61mod tests;
62
/// Extra parameters made available to the GraphQL handlers through an axum
/// `Extension`.
#[derive(Clone, Debug)]
pub struct ExtraRouterParameters {
    /// Whether a bearer OAuth 2.0 access token may be used to authenticate a
    /// GraphQL request. When this is `false`, any request carrying a bearer
    /// token is rejected as having an invalid token.
    pub undocumented_oauth2_access: bool,
}
69
70struct GraphQLState {
71 pool: PgPool,
72 homeserver_connection: Arc<dyn HomeserverConnection>,
73 policy_factory: Arc<PolicyFactory>,
74 site_config: SiteConfig,
75 password_manager: PasswordManager,
76 url_builder: UrlBuilder,
77 limiter: Limiter,
78}
79
80#[async_trait::async_trait]
81impl state::State for GraphQLState {
82 async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
83 let repo = PgRepository::from_pool(&self.pool)
84 .await
85 .map_err(RepositoryError::from_error)?;
86
87 Ok(repo.boxed())
88 }
89
90 async fn policy(&self) -> Result<Policy, InstantiateError> {
91 self.policy_factory.instantiate().await
92 }
93
94 fn password_manager(&self) -> PasswordManager {
95 self.password_manager.clone()
96 }
97
98 fn site_config(&self) -> &SiteConfig {
99 &self.site_config
100 }
101
102 fn homeserver_connection(&self) -> &dyn HomeserverConnection {
103 self.homeserver_connection.as_ref()
104 }
105
106 fn url_builder(&self) -> &UrlBuilder {
107 &self.url_builder
108 }
109
110 fn limiter(&self) -> &Limiter {
111 &self.limiter
112 }
113
114 fn clock(&self) -> BoxClock {
115 let clock = SystemClock::default();
116 Box::new(clock)
117 }
118
119 fn rng(&self) -> BoxRng {
120 #[allow(clippy::disallowed_methods)]
121 let rng = thread_rng();
122
123 let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng");
124 Box::new(rng)
125 }
126}
127
128#[must_use]
129pub fn schema(
130 pool: &PgPool,
131 policy_factory: &Arc<PolicyFactory>,
132 homeserver_connection: impl HomeserverConnection + 'static,
133 site_config: SiteConfig,
134 password_manager: PasswordManager,
135 url_builder: UrlBuilder,
136 limiter: Limiter,
137) -> Schema {
138 let state = GraphQLState {
139 pool: pool.clone(),
140 policy_factory: Arc::clone(policy_factory),
141 homeserver_connection: Arc::new(homeserver_connection),
142 site_config,
143 password_manager,
144 url_builder,
145 limiter,
146 };
147 let state: BoxState = Box::new(state);
148
149 schema_builder().extension(Tracing).data(state).finish()
150}
151
152fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
153 let span = info_span!(
154 "GraphQL operation",
155 "otel.name" = tracing::field::Empty,
156 "otel.kind" = "server",
157 { GRAPHQL_DOCUMENT } = request.query,
158 { GRAPHQL_OPERATION_NAME } = tracing::field::Empty,
159 );
160
161 if let Some(name) = &request.operation_name {
162 span.record("otel.name", name);
163 span.record(GRAPHQL_OPERATION_NAME, name);
164 }
165
166 span
167}
168
169#[derive(thiserror::Error, Debug)]
170pub enum RouteError {
171 #[error(transparent)]
172 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
173
174 #[error("Loading of some database objects failed")]
175 LoadFailed,
176
177 #[error("Invalid access token")]
178 InvalidToken,
179
180 #[error("Missing scope")]
181 MissingScope,
182
183 #[error(transparent)]
184 ParseRequest(#[from] async_graphql::ParseRequestError),
185}
186
187impl_from_error_for_route!(mas_storage::RepositoryError);
188
189impl IntoResponse for RouteError {
190 fn into_response(self) -> Response {
191 let event_id = sentry::capture_error(&self);
192
193 let response = match self {
194 e @ (Self::Internal(_) | Self::LoadFailed) => {
195 let error = async_graphql::Error::new_with_source(e);
196 (
197 StatusCode::INTERNAL_SERVER_ERROR,
198 Json(serde_json::json!({"errors": [error]})),
199 )
200 .into_response()
201 }
202
203 Self::InvalidToken => {
204 let error = async_graphql::Error::new("Invalid token");
205 (
206 StatusCode::UNAUTHORIZED,
207 Json(serde_json::json!({"errors": [error]})),
208 )
209 .into_response()
210 }
211
212 Self::MissingScope => {
213 let error = async_graphql::Error::new("Missing urn:mas:graphql:* scope");
214 (
215 StatusCode::UNAUTHORIZED,
216 Json(serde_json::json!({"errors": [error]})),
217 )
218 .into_response()
219 }
220
221 Self::ParseRequest(e) => {
222 let error = async_graphql::Error::new_with_source(e);
223 (
224 StatusCode::BAD_REQUEST,
225 Json(serde_json::json!({"errors": [error]})),
226 )
227 .into_response()
228 }
229 };
230
231 (SentryEventID::from(event_id), response).into_response()
232 }
233}
234
235async fn get_requester(
236 undocumented_oauth2_access: bool,
237 clock: &impl Clock,
238 activity_tracker: &BoundActivityTracker,
239 mut repo: BoxRepository,
240 session_info: SessionInfo,
241 user_agent: Option<String>,
242 token: Option<&str>,
243) -> Result<Requester, RouteError> {
244 let entity = if let Some(token) = token {
245 if !undocumented_oauth2_access {
247 return Err(RouteError::InvalidToken);
248 }
249
250 let token = repo
251 .oauth2_access_token()
252 .find_by_token(token)
253 .await?
254 .ok_or(RouteError::InvalidToken)?;
255
256 let session = repo
257 .oauth2_session()
258 .lookup(token.session_id)
259 .await?
260 .ok_or(RouteError::LoadFailed)?;
261
262 activity_tracker
263 .record_oauth2_session(clock, &session)
264 .await;
265
266 let user = if let Some(user_id) = session.user_id {
268 let user = repo
269 .user()
270 .lookup(user_id)
271 .await?
272 .ok_or(RouteError::LoadFailed)?;
273 Some(user)
274 } else {
275 None
276 };
277
278 let user_valid = user.as_ref().is_none_or(User::is_valid);
280
281 if !token.is_valid(clock.now()) || !session.is_valid() || !user_valid {
282 return Err(RouteError::InvalidToken);
283 }
284
285 if !session.scope.contains("urn:mas:graphql:*") {
286 return Err(RouteError::MissingScope);
287 }
288
289 RequestingEntity::OAuth2Session(Box::new((session, user)))
290 } else {
291 let maybe_session = session_info.load_active_session(&mut repo).await?;
292
293 if let Some(session) = maybe_session.as_ref() {
294 activity_tracker
295 .record_browser_session(clock, session)
296 .await;
297 }
298
299 RequestingEntity::from(maybe_session)
300 };
301
302 let requester = Requester {
303 entity,
304 ip_address: activity_tracker.ip(),
305 user_agent,
306 };
307
308 repo.cancel().await?;
309 Ok(requester)
310}
311
312pub async fn post(
313 AxumState(schema): AxumState<Schema>,
314 Extension(ExtraRouterParameters {
315 undocumented_oauth2_access,
316 }): Extension<ExtraRouterParameters>,
317 clock: BoxClock,
318 repo: BoxRepository,
319 activity_tracker: BoundActivityTracker,
320 cookie_jar: CookieJar,
321 content_type: Option<TypedHeader<ContentType>>,
322 authorization: Option<TypedHeader<Authorization<Bearer>>>,
323 user_agent: Option<TypedHeader<headers::UserAgent>>,
324 body: Body,
325) -> Result<impl IntoResponse, RouteError> {
326 let body = body.into_data_stream();
327 let token = authorization
328 .as_ref()
329 .map(|TypedHeader(Authorization(bearer))| bearer.token());
330 let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
331 let (session_info, _cookie_jar) = cookie_jar.session_info();
332 let requester = get_requester(
333 undocumented_oauth2_access,
334 &clock,
335 &activity_tracker,
336 repo,
337 session_info,
338 user_agent,
339 token,
340 )
341 .await?;
342
343 let content_type = content_type.map(|TypedHeader(h)| h.to_string());
344
345 let request = async_graphql::http::receive_body(
346 content_type,
347 body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
348 .into_async_read(),
349 MultipartOptions::default(),
350 )
351 .await?
352 .data(requester); let span = span_for_graphql_request(&request);
355 let response = schema.execute(request).instrument(span).await;
356
357 let cache_control = response
358 .cache_control
359 .value()
360 .and_then(|v| HeaderValue::from_str(&v).ok())
361 .map(|h| [(CACHE_CONTROL, h)]);
362
363 let headers = response.http_headers.clone();
364
365 Ok((headers, cache_control, Json(response)))
366}
367
368pub async fn get(
369 AxumState(schema): AxumState<Schema>,
370 Extension(ExtraRouterParameters {
371 undocumented_oauth2_access,
372 }): Extension<ExtraRouterParameters>,
373 clock: BoxClock,
374 repo: BoxRepository,
375 activity_tracker: BoundActivityTracker,
376 cookie_jar: CookieJar,
377 authorization: Option<TypedHeader<Authorization<Bearer>>>,
378 user_agent: Option<TypedHeader<headers::UserAgent>>,
379 RawQuery(query): RawQuery,
380) -> Result<impl IntoResponse, FancyError> {
381 let token = authorization
382 .as_ref()
383 .map(|TypedHeader(Authorization(bearer))| bearer.token());
384 let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
385 let (session_info, _cookie_jar) = cookie_jar.session_info();
386 let requester = get_requester(
387 undocumented_oauth2_access,
388 &clock,
389 &activity_tracker,
390 repo,
391 session_info,
392 user_agent,
393 token,
394 )
395 .await?;
396
397 let request =
398 async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
399
400 let span = span_for_graphql_request(&request);
401 let response = schema.execute(request).instrument(span).await;
402
403 let cache_control = response
404 .cache_control
405 .value()
406 .and_then(|v| HeaderValue::from_str(&v).ok())
407 .map(|h| [(CACHE_CONTROL, h)]);
408
409 let headers = response.http_headers.clone();
410
411 Ok((headers, cache_control, Json(response)))
412}
413
414pub async fn playground() -> impl IntoResponse {
415 Html(playground_source(
416 GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
417 ))
418}
419
420pub type Schema = async_graphql::Schema<Query, Mutation, EmptySubscription>;
421pub type SchemaBuilder = async_graphql::SchemaBuilder<Query, Mutation, EmptySubscription>;
422
423#[must_use]
424pub fn schema_builder() -> SchemaBuilder {
425 async_graphql::Schema::build(Query::new(), Mutation::new(), EmptySubscription)
426 .register_output_type::<Node>()
427 .register_output_type::<CreationEvent>()
428}
429
430pub struct Requester {
431 entity: RequestingEntity,
432 ip_address: Option<IpAddr>,
433 user_agent: Option<String>,
434}
435
436impl Requester {
437 pub fn fingerprint(&self) -> RequesterFingerprint {
438 if let Some(ip) = self.ip_address {
439 RequesterFingerprint::new(ip)
440 } else {
441 RequesterFingerprint::EMPTY
442 }
443 }
444
445 pub fn for_policy(&self) -> mas_policy::Requester {
446 mas_policy::Requester {
447 ip_address: self.ip_address,
448 user_agent: self.user_agent.clone(),
449 }
450 }
451}
452
453impl Deref for Requester {
454 type Target = RequestingEntity;
455
456 fn deref(&self) -> &Self::Target {
457 &self.entity
458 }
459}
460
461#[derive(Debug, Clone, Default, PartialEq, Eq)]
463pub enum RequestingEntity {
464 #[default]
466 Anonymous,
467
468 BrowserSession(Box<BrowserSession>),
470
471 OAuth2Session(Box<(Session, Option<User>)>),
473}
474
475trait OwnerId {
476 fn owner_id(&self) -> Option<Ulid>;
477}
478
479impl OwnerId for User {
480 fn owner_id(&self) -> Option<Ulid> {
481 Some(self.id)
482 }
483}
484
485impl OwnerId for BrowserSession {
486 fn owner_id(&self) -> Option<Ulid> {
487 Some(self.user.id)
488 }
489}
490
491impl OwnerId for mas_data_model::UserEmail {
492 fn owner_id(&self) -> Option<Ulid> {
493 Some(self.user_id)
494 }
495}
496
497impl OwnerId for Session {
498 fn owner_id(&self) -> Option<Ulid> {
499 self.user_id
500 }
501}
502
503impl OwnerId for mas_data_model::CompatSession {
504 fn owner_id(&self) -> Option<Ulid> {
505 Some(self.user_id)
506 }
507}
508
509impl OwnerId for mas_data_model::UpstreamOAuthLink {
510 fn owner_id(&self) -> Option<Ulid> {
511 self.user_id
512 }
513}
514
515pub struct UserId(Ulid);
517
518impl OwnerId for UserId {
519 fn owner_id(&self) -> Option<Ulid> {
520 Some(self.0)
521 }
522}
523
524impl RequestingEntity {
525 fn browser_session(&self) -> Option<&BrowserSession> {
526 match self {
527 Self::BrowserSession(session) => Some(session),
528 Self::OAuth2Session(_) | Self::Anonymous => None,
529 }
530 }
531
532 fn user(&self) -> Option<&User> {
533 match self {
534 Self::BrowserSession(session) => Some(&session.user),
535 Self::OAuth2Session(tuple) => tuple.1.as_ref(),
536 Self::Anonymous => None,
537 }
538 }
539
540 fn oauth2_session(&self) -> Option<&Session> {
541 match self {
542 Self::OAuth2Session(tuple) => Some(&tuple.0),
543 Self::BrowserSession(_) | Self::Anonymous => None,
544 }
545 }
546
547 fn is_owner_or_admin(&self, resource: &impl OwnerId) -> bool {
549 if self.is_admin() {
551 return true;
552 }
553
554 let Some(owner_id) = resource.owner_id() else {
556 return false;
557 };
558
559 let Some(user) = self.user() else {
560 return false;
561 };
562
563 user.id == owner_id
564 }
565
566 fn is_admin(&self) -> bool {
567 match self {
568 Self::OAuth2Session(tuple) => {
569 tuple.0.scope.contains("urn:mas:admin")
572 }
573 Self::BrowserSession(_) | Self::Anonymous => false,
574 }
575 }
576
577 fn is_unauthenticated(&self) -> bool {
578 matches!(self, Self::Anonymous)
579 }
580}
581
582impl From<BrowserSession> for RequestingEntity {
583 fn from(session: BrowserSession) -> Self {
584 Self::BrowserSession(Box::new(session))
585 }
586}
587
588impl<T> From<Option<T>> for RequestingEntity
589where
590 T: Into<RequestingEntity>,
591{
592 fn from(session: Option<T>) -> Self {
593 session.map(Into::into).unwrap_or_default()
594 }
595}
596
// GraphQL input filtering on a date range; both bounds are optional and the
// `Default` value has neither bound set.
// NOTE(review): these are plain `//` comments on purpose. `///` doc comments
// on an async-graphql `InputObject` become schema descriptions and would
// change the published GraphQL schema.
#[derive(InputObject, Default, Clone, Copy)]
pub struct DateFilter {
    // Lower bound of the range. Inclusive/exclusive semantics are decided by
    // the resolvers that consume this filter; TODO confirm there.
    after: Option<DateTime<Utc>>,

    // Upper bound of the range. Same caveat as `after` about boundary
    // semantics; TODO confirm in the resolvers.
    before: Option<DateTime<Utc>>,
}