mas_handlers/activity_tracker/mod.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

7mod bound;
8mod worker;
9
10use std::net::IpAddr;
11
12use chrono::{DateTime, Utc};
13use mas_data_model::{BrowserSession, CompatSession, Session};
14use mas_storage::Clock;
15use sqlx::PgPool;
16use tokio_util::{sync::CancellationToken, task::TaskTracker};
17use ulid::Ulid;
18
19pub use self::bound::Bound;
20use self::worker::Worker;
21
/// Capacity of the bounded channel between [`ActivityTracker`] handles and
/// the background worker. Once this many messages are queued, senders wait
/// until the worker has drained some of them.
const MESSAGE_QUEUE_SIZE: usize = 1000;
23
/// Which kind of session an activity record refers to.
#[derive(Clone, Copy, Debug, PartialOrd, PartialEq, Eq, Hash)]
enum SessionKind {
    /// An OAuth 2.0 session.
    OAuth2,
    /// A compat session.
    Compat,
    /// A browser session.
    Browser,
}

impl SessionKind {
    /// Returns a static lowercase label identifying this session kind.
    const fn as_str(self) -> &'static str {
        let label = match self {
            Self::OAuth2 => "oauth2",
            Self::Compat => "compat",
            Self::Browser => "browser",
        };
        label
    }
}
40
/// A message sent from an [`ActivityTracker`] handle to the background worker.
enum Message {
    /// Record activity for one session.
    Record {
        /// Which kind of session the activity belongs to.
        kind: SessionKind,
        /// The ID of the session (taken from the session's `id` field).
        id: Ulid,
        /// When the activity happened, as reported by the caller's clock.
        date_time: DateTime<Utc>,
        /// The client IP address, if known.
        ip: Option<IpAddr>,
    },
    /// Ask the worker to flush. `ActivityTracker::flush` awaits the receiving
    /// end of this channel; presumably the worker sends on it (or drops it)
    /// once the flush is done. NOTE(review): confirm against the worker.
    Flush(tokio::sync::oneshot::Sender<()>),
}
50
/// A cloneable handle used to record session activity.
///
/// Records are sent over a bounded channel to a background worker spawned by
/// [`ActivityTracker::new`]. Cloning the handle clones the sending side of
/// that channel.
#[derive(Clone)]
pub struct ActivityTracker {
    /// Sending side of the queue consumed by the background worker.
    channel: tokio::sync::mpsc::Sender<Message>,
}
55
56impl ActivityTracker {
57    /// Create a new activity tracker
58    ///
59    /// It will spawn the background worker and a loop to flush the tracker on
60    /// the task tracker, and both will shut themselves down, flushing one last
61    /// time, when the cancellation token is cancelled.
62    #[must_use]
63    pub fn new(
64        pool: PgPool,
65        flush_interval: std::time::Duration,
66        task_tracker: &TaskTracker,
67        cancellation_token: CancellationToken,
68    ) -> Self {
69        let worker = Worker::new(pool);
70        let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
71        let tracker = ActivityTracker { channel: sender };
72
73        // Spawn the flush loop and the worker
74        task_tracker.spawn(
75            tracker
76                .clone()
77                .flush_loop(flush_interval, cancellation_token.clone()),
78        );
79        task_tracker.spawn(worker.run(receiver, cancellation_token));
80
81        tracker
82    }
83
84    /// Bind the activity tracker to an IP address.
85    #[must_use]
86    pub fn bind(self, ip: Option<IpAddr>) -> Bound {
87        Bound::new(self, ip)
88    }
89
90    /// Record activity in an OAuth 2.0 session.
91    pub async fn record_oauth2_session(
92        &self,
93        clock: &dyn Clock,
94        session: &Session,
95        ip: Option<IpAddr>,
96    ) {
97        let res = self
98            .channel
99            .send(Message::Record {
100                kind: SessionKind::OAuth2,
101                id: session.id,
102                date_time: clock.now(),
103                ip,
104            })
105            .await;
106
107        if let Err(e) = res {
108            tracing::error!("Failed to record OAuth2 session: {}", e);
109        }
110    }
111
112    /// Record activity in a compat session.
113    pub async fn record_compat_session(
114        &self,
115        clock: &dyn Clock,
116        compat_session: &CompatSession,
117        ip: Option<IpAddr>,
118    ) {
119        let res = self
120            .channel
121            .send(Message::Record {
122                kind: SessionKind::Compat,
123                id: compat_session.id,
124                date_time: clock.now(),
125                ip,
126            })
127            .await;
128
129        if let Err(e) = res {
130            tracing::error!("Failed to record compat session: {}", e);
131        }
132    }
133
134    /// Record activity in a browser session.
135    pub async fn record_browser_session(
136        &self,
137        clock: &dyn Clock,
138        browser_session: &BrowserSession,
139        ip: Option<IpAddr>,
140    ) {
141        let res = self
142            .channel
143            .send(Message::Record {
144                kind: SessionKind::Browser,
145                id: browser_session.id,
146                date_time: clock.now(),
147                ip,
148            })
149            .await;
150
151        if let Err(e) = res {
152            tracing::error!("Failed to record browser session: {}", e);
153        }
154    }
155
156    /// Manually flush the activity tracker.
157    pub async fn flush(&self) {
158        let (tx, rx) = tokio::sync::oneshot::channel();
159        let res = self.channel.send(Message::Flush(tx)).await;
160
161        match res {
162            Ok(()) => {
163                if let Err(e) = rx.await {
164                    tracing::error!(
165                        error = &e as &dyn std::error::Error,
166                        "Failed to flush activity tracker"
167                    );
168                }
169            }
170            Err(e) => {
171                tracing::error!(
172                    error = &e as &dyn std::error::Error,
173                    "Failed to flush activity tracker"
174                );
175            }
176        }
177    }
178
179    /// Regularly flush the activity tracker.
180    async fn flush_loop(
181        self,
182        interval: std::time::Duration,
183        cancellation_token: CancellationToken,
184    ) {
185        // This guard on the shutdown token is to ensure that if this task crashes for
186        // any reason, the server will shut down
187        let _guard = cancellation_token.clone().drop_guard();
188
189        loop {
190            tokio::select! {
191                biased;
192
193                () = cancellation_token.cancelled() => {
194                    // The cancellation token was cancelled, so we should exit
195                    return;
196                }
197
198                // First check if the channel is closed, then check if the timer expired
199                () = self.channel.closed() => {
200                    // The channel was closed, so we should exit
201                    return;
202                }
203
204
205                () = tokio::time::sleep(interval) => {
206                    self.flush().await;
207                }
208            }
209        }
210    }
211}