1use async_trait::async_trait;
10use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState, UserAgent};
11use mas_storage::{
12 Page, Pagination,
13 app_session::{AppSession, AppSessionFilter, AppSessionRepository, AppSessionState},
14 compat::CompatSessionFilter,
15 oauth2::OAuth2SessionFilter,
16};
17use oauth2_types::scope::{Scope, ScopeToken};
18use sea_query::{
19 Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType,
20};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24
25use crate::{
26 DatabaseError, ExecuteExt,
27 errors::DatabaseInconsistencyError,
28 filter::StatementExt,
29 iden::{CompatSessions, OAuth2Sessions},
30 pagination::QueryBuilderExt,
31};
32
/// An implementation of [`AppSessionRepository`] which runs its queries on a
/// borrowed PostgreSQL connection.
pub struct PgAppSessionRepository<'c> {
    /// The connection every query of this repository is executed on.
    conn: &'c mut PgConnection,
}
37
38impl<'c> PgAppSessionRepository<'c> {
39 pub fn new(conn: &'c mut PgConnection) -> Self {
42 Self { conn }
43 }
44}
45
// NOTE(review): presumably kept in a private module so that the
// `AppSessionLookupIden` enum generated by `#[enum_def]` is not exposed
// outside of this file — confirm.
mod priv_ {
    use std::net::IpAddr;

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    /// One row of the combined `sessions` query, which is the `UNION ALL` of
    /// the OAuth 2.0 sessions and the compatibility sessions.
    ///
    /// Columns which only exist for one of the two kinds of sessions are
    /// selected as `NULL` for the other kind, which is why most fields are
    /// optional. `#[enum_def]` generates the `AppSessionLookupIden` enum used
    /// to alias the selected columns.
    #[derive(sqlx::FromRow)]
    #[enum_def]
    pub(super) struct AppSessionLookup {
        /// ID of the session, whichever its kind; this is the column the
        /// pagination is applied on.
        pub(super) cursor: Uuid,
        /// ID of the compatibility session; `NULL` for OAuth 2.0 sessions.
        pub(super) compat_session_id: Option<Uuid>,
        /// ID of the OAuth 2.0 session; `NULL` for compatibility sessions.
        pub(super) oauth2_session_id: Option<Uuid>,
        /// ID of the OAuth 2.0 client; `NULL` for compatibility sessions.
        pub(super) oauth2_client_id: Option<Uuid>,
        /// ID of the related user (browser) session, selected for both kinds.
        pub(super) user_session_id: Option<Uuid>,
        /// ID of the user owning the session, selected for both kinds.
        pub(super) user_id: Option<Uuid>,
        /// Scope tokens of the OAuth 2.0 session; `NULL` for compatibility
        /// sessions.
        pub(super) scope_list: Option<Vec<String>>,
        /// Device ID of the compatibility session; `NULL` for OAuth 2.0
        /// sessions.
        pub(super) device_id: Option<String>,
        /// Human-readable name of the compatibility session; `NULL` for
        /// OAuth 2.0 sessions.
        pub(super) human_name: Option<String>,
        /// When the session was created.
        pub(super) created_at: DateTime<Utc>,
        /// When the session was finished, if it was.
        pub(super) finished_at: Option<DateTime<Utc>>,
        /// Synapse admin flag of the compatibility session; `NULL` for
        /// OAuth 2.0 sessions.
        pub(super) is_synapse_admin: Option<bool>,
        /// Raw user agent string recorded for the session.
        pub(super) user_agent: Option<String>,
        /// When the session was last seen active.
        pub(super) last_active_at: Option<DateTime<Utc>>,
        /// IP address the session was last seen active from.
        pub(super) last_active_ip: Option<IpAddr>,
    }
}
76
77use priv_::{AppSessionLookup, AppSessionLookupIden};
78
79impl TryFrom<AppSessionLookup> for AppSession {
80 type Error = DatabaseError;
81
82 #[allow(clippy::too_many_lines)]
83 fn try_from(value: AppSessionLookup) -> Result<Self, Self::Error> {
84 let AppSessionLookup {
87 cursor,
88 compat_session_id,
89 oauth2_session_id,
90 oauth2_client_id,
91 user_session_id,
92 user_id,
93 scope_list,
94 device_id,
95 human_name,
96 created_at,
97 finished_at,
98 is_synapse_admin,
99 user_agent,
100 last_active_at,
101 last_active_ip,
102 } = value;
103
104 let user_agent = user_agent.map(UserAgent::parse);
105 let user_session_id = user_session_id.map(Ulid::from);
106
107 match (
108 compat_session_id,
109 oauth2_session_id,
110 oauth2_client_id,
111 user_id,
112 scope_list,
113 device_id,
114 is_synapse_admin,
115 ) {
116 (
117 Some(compat_session_id),
118 None,
119 None,
120 Some(user_id),
121 None,
122 device_id_opt,
123 Some(is_synapse_admin),
124 ) => {
125 let id = compat_session_id.into();
126 let device = device_id_opt
127 .map(Device::try_from)
128 .transpose()
129 .map_err(|e| {
130 DatabaseInconsistencyError::on("compat_sessions")
131 .column("device_id")
132 .row(id)
133 .source(e)
134 })?;
135
136 let state = match finished_at {
137 None => CompatSessionState::Valid,
138 Some(finished_at) => CompatSessionState::Finished { finished_at },
139 };
140
141 let session = CompatSession {
142 id,
143 state,
144 user_id: user_id.into(),
145 device,
146 human_name,
147 user_session_id,
148 created_at,
149 is_synapse_admin,
150 user_agent,
151 last_active_at,
152 last_active_ip,
153 };
154
155 Ok(AppSession::Compat(Box::new(session)))
156 }
157
158 (
159 None,
160 Some(oauth2_session_id),
161 Some(oauth2_client_id),
162 user_id,
163 Some(scope_list),
164 None,
165 None,
166 ) => {
167 let id = oauth2_session_id.into();
168 let scope: Result<Scope, _> =
169 scope_list.iter().map(|s| s.parse::<ScopeToken>()).collect();
170 let scope = scope.map_err(|e| {
171 DatabaseInconsistencyError::on("oauth2_sessions")
172 .column("scope")
173 .row(id)
174 .source(e)
175 })?;
176
177 let state = match value.finished_at {
178 None => SessionState::Valid,
179 Some(finished_at) => SessionState::Finished { finished_at },
180 };
181
182 let session = Session {
183 id,
184 state,
185 created_at,
186 client_id: oauth2_client_id.into(),
187 user_id: user_id.map(Ulid::from),
188 user_session_id,
189 scope,
190 user_agent,
191 last_active_at,
192 last_active_ip,
193 };
194
195 Ok(AppSession::OAuth2(Box::new(session)))
196 }
197
198 _ => Err(DatabaseInconsistencyError::on("sessions")
199 .row(cursor.into())
200 .into()),
201 }
202 }
203}
204
205fn split_filter(
208 filter: AppSessionFilter<'_>,
209) -> (CompatSessionFilter<'_>, OAuth2SessionFilter<'_>) {
210 let mut compat_filter = CompatSessionFilter::new();
211 let mut oauth2_filter = OAuth2SessionFilter::new();
212
213 if let Some(user) = filter.user() {
214 compat_filter = compat_filter.for_user(user);
215 oauth2_filter = oauth2_filter.for_user(user);
216 }
217
218 match filter.state() {
219 Some(AppSessionState::Active) => {
220 compat_filter = compat_filter.active_only();
221 oauth2_filter = oauth2_filter.active_only();
222 }
223 Some(AppSessionState::Finished) => {
224 compat_filter = compat_filter.finished_only();
225 oauth2_filter = oauth2_filter.finished_only();
226 }
227 None => {}
228 }
229
230 if let Some(device) = filter.device() {
231 compat_filter = compat_filter.for_device(device);
232 oauth2_filter = oauth2_filter.for_device(device);
233 }
234
235 if let Some(browser_session) = filter.browser_session() {
236 compat_filter = compat_filter.for_browser_session(browser_session);
237 oauth2_filter = oauth2_filter.for_browser_session(browser_session);
238 }
239
240 if let Some(last_active_before) = filter.last_active_before() {
241 compat_filter = compat_filter.with_last_active_before(last_active_before);
242 oauth2_filter = oauth2_filter.with_last_active_before(last_active_before);
243 }
244
245 if let Some(last_active_after) = filter.last_active_after() {
246 compat_filter = compat_filter.with_last_active_after(last_active_after);
247 oauth2_filter = oauth2_filter.with_last_active_after(last_active_after);
248 }
249
250 (compat_filter, oauth2_filter)
251}
252
253#[async_trait]
254impl AppSessionRepository for PgAppSessionRepository<'_> {
255 type Error = DatabaseError;
256
257 #[allow(clippy::too_many_lines)]
258 #[tracing::instrument(
259 name = "db.app_session.list",
260 fields(
261 db.query.text,
262 ),
263 skip_all,
264 err,
265 )]
266 async fn list(
267 &mut self,
268 filter: AppSessionFilter<'_>,
269 pagination: Pagination,
270 ) -> Result<Page<AppSession>, Self::Error> {
271 let (compat_filter, oauth2_filter) = split_filter(filter);
272
273 let mut oauth2_session_select = Query::select()
274 .expr_as(
275 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
276 AppSessionLookupIden::Cursor,
277 )
278 .expr_as(Expr::cust("NULL"), AppSessionLookupIden::CompatSessionId)
279 .expr_as(
280 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
281 AppSessionLookupIden::Oauth2SessionId,
282 )
283 .expr_as(
284 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
285 AppSessionLookupIden::Oauth2ClientId,
286 )
287 .expr_as(
288 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
289 AppSessionLookupIden::UserSessionId,
290 )
291 .expr_as(
292 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
293 AppSessionLookupIden::UserId,
294 )
295 .expr_as(
296 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
297 AppSessionLookupIden::ScopeList,
298 )
299 .expr_as(Expr::cust("NULL"), AppSessionLookupIden::DeviceId)
300 .expr_as(Expr::cust("NULL"), AppSessionLookupIden::HumanName)
301 .expr_as(
302 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
303 AppSessionLookupIden::CreatedAt,
304 )
305 .expr_as(
306 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
307 AppSessionLookupIden::FinishedAt,
308 )
309 .expr_as(Expr::cust("NULL"), AppSessionLookupIden::IsSynapseAdmin)
310 .expr_as(
311 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
312 AppSessionLookupIden::UserAgent,
313 )
314 .expr_as(
315 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
316 AppSessionLookupIden::LastActiveAt,
317 )
318 .expr_as(
319 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
320 AppSessionLookupIden::LastActiveIp,
321 )
322 .from(OAuth2Sessions::Table)
323 .apply_filter(oauth2_filter)
324 .clone();
325
326 let compat_session_select = Query::select()
327 .expr_as(
328 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
329 AppSessionLookupIden::Cursor,
330 )
331 .expr_as(
332 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
333 AppSessionLookupIden::CompatSessionId,
334 )
335 .expr_as(Expr::cust("NULL"), AppSessionLookupIden::Oauth2SessionId)
336 .expr_as(Expr::cust("NULL"), AppSessionLookupIden::Oauth2ClientId)
337 .expr_as(
338 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
339 AppSessionLookupIden::UserSessionId,
340 )
341 .expr_as(
342 Expr::col((CompatSessions::Table, CompatSessions::UserId)),
343 AppSessionLookupIden::UserId,
344 )
345 .expr_as(Expr::cust("NULL"), AppSessionLookupIden::ScopeList)
346 .expr_as(
347 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
348 AppSessionLookupIden::DeviceId,
349 )
350 .expr_as(
351 Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
352 AppSessionLookupIden::HumanName,
353 )
354 .expr_as(
355 Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
356 AppSessionLookupIden::CreatedAt,
357 )
358 .expr_as(
359 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
360 AppSessionLookupIden::FinishedAt,
361 )
362 .expr_as(
363 Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
364 AppSessionLookupIden::IsSynapseAdmin,
365 )
366 .expr_as(
367 Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
368 AppSessionLookupIden::UserAgent,
369 )
370 .expr_as(
371 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
372 AppSessionLookupIden::LastActiveAt,
373 )
374 .expr_as(
375 Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
376 AppSessionLookupIden::LastActiveIp,
377 )
378 .from(CompatSessions::Table)
379 .apply_filter(compat_filter)
380 .clone();
381
382 let common_table_expression = CommonTableExpression::new()
383 .query(
384 oauth2_session_select
385 .union(UnionType::All, compat_session_select)
386 .clone(),
387 )
388 .table_name(Alias::new("sessions"))
389 .clone();
390
391 let with_clause = Query::with().cte(common_table_expression).clone();
392
393 let select = Query::select()
394 .column(ColumnRef::Asterisk)
395 .from(Alias::new("sessions"))
396 .generate_pagination(AppSessionLookupIden::Cursor, pagination)
397 .clone();
398
399 let (sql, arguments) = with_clause.query(select).build_sqlx(PostgresQueryBuilder);
400
401 let edges: Vec<AppSessionLookup> = sqlx::query_as_with(&sql, arguments)
402 .traced()
403 .fetch_all(&mut *self.conn)
404 .await?;
405
406 let page = pagination.process(edges).try_map(TryFrom::try_from)?;
407
408 Ok(page)
409 }
410
411 #[tracing::instrument(
412 name = "db.app_session.count",
413 fields(
414 db.query.text,
415 ),
416 skip_all,
417 err,
418 )]
419 async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result<usize, Self::Error> {
420 let (compat_filter, oauth2_filter) = split_filter(filter);
421 let mut oauth2_session_select = Query::select()
422 .expr(Expr::cust("1"))
423 .from(OAuth2Sessions::Table)
424 .apply_filter(oauth2_filter)
425 .clone();
426
427 let compat_session_select = Query::select()
428 .expr(Expr::cust("1"))
429 .from(CompatSessions::Table)
430 .apply_filter(compat_filter)
431 .clone();
432
433 let common_table_expression = CommonTableExpression::new()
434 .query(
435 oauth2_session_select
436 .union(UnionType::All, compat_session_select)
437 .clone(),
438 )
439 .table_name(Alias::new("sessions"))
440 .clone();
441
442 let with_clause = Query::with().cte(common_table_expression).clone();
443
444 let select = Query::select()
445 .expr(Expr::cust("COUNT(*)"))
446 .from(Alias::new("sessions"))
447 .clone();
448
449 let (sql, arguments) = with_clause.query(select).build_sqlx(PostgresQueryBuilder);
450
451 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
452 .traced()
453 .fetch_one(&mut *self.conn)
454 .await?;
455
456 count
457 .try_into()
458 .map_err(DatabaseError::to_invalid_operation)
459 }
460}
461
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::Device;
    use mas_storage::{
        Pagination, RepositoryAccess,
        app_session::{AppSession, AppSessionFilter},
        clock::MockClock,
        oauth2::OAuth2SessionRepository,
    };
    use oauth2_types::{
        requests::GrantType,
        scope::{OPENID, Scope},
    };
    use rand::SeedableRng;
    use rand_chacha::ChaChaRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Walk through the lifecycle of a compatibility session and of an
    /// OAuth 2.0 session, checking after each step what `count` and `list`
    /// return for the "all", "active only" and "finished only" filters, then
    /// check the per-device and per-user filters.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_app_repo(pool: PgPool) {
        let mut rng = ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create a user
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();

        let all = AppSessionFilter::new().for_user(&user);
        let active = all.active_only();
        let finished = all.finished_only();
        let pagination = Pagination::first(10);

        // Nothing has been created yet, so everything is empty
        assert_eq!(repo.app_session().count(all).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 0);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert!(full_list.edges.is_empty());
        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert!(finished_list.edges.is_empty());

        // Start a compat session for that user
        let device = Device::generate(&mut rng);
        let compat_session = repo
            .compat_session()
            .add(&mut rng, &clock, &user, device.clone(), None, false)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 0);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(
            full_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert_eq!(active_list.edges.len(), 1);
        assert_eq!(
            active_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert!(finished_list.edges.is_empty());

        // Finish the compat session: it moves from the active list to the
        // finished list
        let compat_session = repo
            .compat_session()
            .finish(&clock, compat_session)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 1);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 1);
        assert_eq!(
            full_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert!(active_list.edges.is_empty());
        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert_eq!(finished_list.edges.len(), 1);
        assert_eq!(
            finished_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        // Start an OAuth 2.0 session, which first requires a client
        let client = repo
            .oauth2_client()
            .add(
                &mut rng,
                &clock,
                vec!["https://example.com/redirect".parse().unwrap()],
                None,
                None,
                None,
                vec![GrantType::AuthorizationCode],
                Some("First client".to_owned()),
                Some("https://example.com/logo.png".parse().unwrap()),
                Some("https://example.com/".parse().unwrap()),
                Some("https://example.com/policy".parse().unwrap()),
                Some("https://example.com/tos".parse().unwrap()),
                Some("https://example.com/jwks.json".parse().unwrap()),
                None,
                None,
                None,
                None,
                None,
                Some("https://example.com/login".parse().unwrap()),
            )
            .await
            .unwrap();

        let device2 = Device::generate(&mut rng);
        let scope = Scope::from_iter([OPENID, device2.to_scope_token().unwrap()]);

        // NOTE(review): presumably the clock is advanced so that the new
        // session is listed after the compat session — confirm
        clock.advance(Duration::try_minutes(1).unwrap());

        let oauth_session = repo
            .oauth2_session()
            .add(&mut rng, &clock, &client, Some(&user), None, scope)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 2);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 1);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 1);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 2);
        assert_eq!(
            full_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        assert_eq!(
            full_list.edges[1],
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert_eq!(active_list.edges.len(), 1);
        assert_eq!(
            active_list.edges[0],
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert_eq!(finished_list.edges.len(), 1);
        assert_eq!(
            finished_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        // Finish the OAuth 2.0 session: both sessions are now finished
        let oauth_session = repo
            .oauth2_session()
            .finish(&clock, oauth_session)
            .await
            .unwrap();

        assert_eq!(repo.app_session().count(all).await.unwrap(), 2);
        assert_eq!(repo.app_session().count(active).await.unwrap(), 0);
        assert_eq!(repo.app_session().count(finished).await.unwrap(), 2);

        let full_list = repo.app_session().list(all, pagination).await.unwrap();
        assert_eq!(full_list.edges.len(), 2);
        assert_eq!(
            full_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        assert_eq!(
            full_list.edges[1],
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        let active_list = repo.app_session().list(active, pagination).await.unwrap();
        assert!(active_list.edges.is_empty());

        let finished_list = repo.app_session().list(finished, pagination).await.unwrap();
        assert_eq!(finished_list.edges.len(), 2);
        assert_eq!(
            finished_list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );
        // This must look at the finished list, not at the full list which was
        // already checked above
        assert_eq!(
            finished_list.edges[1],
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        // Filtering on the first device only yields the compat session
        let filter = AppSessionFilter::new().for_device(&device);
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 1);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(
            list.edges[0],
            AppSession::Compat(Box::new(compat_session.clone()))
        );

        // Filtering on the second device, which is part of the scope of the
        // OAuth 2.0 session, only yields that session
        let filter = AppSessionFilter::new().for_device(&device2);
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 1);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert_eq!(list.edges.len(), 1);
        assert_eq!(
            list.edges[0],
            AppSession::OAuth2(Box::new(oauth_session.clone()))
        );

        // Create a second user, which has no session at all
        let user2 = repo
            .user()
            .add(&mut rng, &clock, "alice".to_owned())
            .await
            .unwrap();

        // Filtering on that user yields nothing
        let filter = AppSessionFilter::new().for_user(&user2);
        assert_eq!(repo.app_session().count(filter).await.unwrap(), 0);
        let list = repo.app_session().list(filter, pagination).await.unwrap();
        assert!(list.edges.is_empty());
    }
}