mas_storage_pg/oauth2/device_code_grant.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, DeviceCodeGrant, DeviceCodeGrantState, Session, UserAgent};
12use mas_storage::{
13    Clock,
14    oauth2::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository},
15};
16use oauth2_types::scope::Scope;
17use rand::RngCore;
18use sqlx::PgConnection;
19use ulid::Ulid;
20use uuid::Uuid;
21
22use crate::{DatabaseError, ExecuteExt, errors::DatabaseInconsistencyError};
23
24/// An implementation of [`OAuth2DeviceCodeGrantRepository`] for a PostgreSQL
25/// connection
26pub struct PgOAuth2DeviceCodeGrantRepository<'c> {
27    conn: &'c mut PgConnection,
28}
29
30impl<'c> PgOAuth2DeviceCodeGrantRepository<'c> {
31    /// Create a new [`PgOAuth2DeviceCodeGrantRepository`] from an active
32    /// PostgreSQL connection
33    pub fn new(conn: &'c mut PgConnection) -> Self {
34        Self { conn }
35    }
36}
37
38struct OAuth2DeviceGrantLookup {
39    oauth2_device_code_grant_id: Uuid,
40    oauth2_client_id: Uuid,
41    scope: String,
42    device_code: String,
43    user_code: String,
44    created_at: DateTime<Utc>,
45    expires_at: DateTime<Utc>,
46    fulfilled_at: Option<DateTime<Utc>>,
47    rejected_at: Option<DateTime<Utc>>,
48    exchanged_at: Option<DateTime<Utc>>,
49    user_session_id: Option<Uuid>,
50    oauth2_session_id: Option<Uuid>,
51    ip_address: Option<IpAddr>,
52    user_agent: Option<String>,
53}
54
55impl TryFrom<OAuth2DeviceGrantLookup> for DeviceCodeGrant {
56    type Error = DatabaseInconsistencyError;
57
58    fn try_from(
59        OAuth2DeviceGrantLookup {
60            oauth2_device_code_grant_id,
61            oauth2_client_id,
62            scope,
63            device_code,
64            user_code,
65            created_at,
66            expires_at,
67            fulfilled_at,
68            rejected_at,
69            exchanged_at,
70            user_session_id,
71            oauth2_session_id,
72            ip_address,
73            user_agent,
74        }: OAuth2DeviceGrantLookup,
75    ) -> Result<Self, Self::Error> {
76        let id = Ulid::from(oauth2_device_code_grant_id);
77        let client_id = Ulid::from(oauth2_client_id);
78
79        let scope: Scope = scope.parse().map_err(|e| {
80            DatabaseInconsistencyError::on("oauth2_authorization_grants")
81                .column("scope")
82                .row(id)
83                .source(e)
84        })?;
85
86        let state = match (
87            fulfilled_at,
88            rejected_at,
89            exchanged_at,
90            user_session_id,
91            oauth2_session_id,
92        ) {
93            (None, None, None, None, None) => DeviceCodeGrantState::Pending,
94
95            (Some(fulfilled_at), None, None, Some(user_session_id), None) => {
96                DeviceCodeGrantState::Fulfilled {
97                    browser_session_id: Ulid::from(user_session_id),
98                    fulfilled_at,
99                }
100            }
101
102            (None, Some(rejected_at), None, Some(user_session_id), None) => {
103                DeviceCodeGrantState::Rejected {
104                    browser_session_id: Ulid::from(user_session_id),
105                    rejected_at,
106                }
107            }
108
109            (
110                Some(fulfilled_at),
111                None,
112                Some(exchanged_at),
113                Some(user_session_id),
114                Some(oauth2_session_id),
115            ) => DeviceCodeGrantState::Exchanged {
116                browser_session_id: Ulid::from(user_session_id),
117                session_id: Ulid::from(oauth2_session_id),
118                fulfilled_at,
119                exchanged_at,
120            },
121
122            _ => return Err(DatabaseInconsistencyError::on("oauth2_device_code_grant").row(id)),
123        };
124
125        Ok(DeviceCodeGrant {
126            id,
127            state,
128            client_id,
129            scope,
130            user_code,
131            device_code,
132            created_at,
133            expires_at,
134            ip_address,
135            user_agent: user_agent.map(UserAgent::parse),
136        })
137    }
138}
139
140#[async_trait]
141impl OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<'_> {
142    type Error = DatabaseError;
143
144    #[tracing::instrument(
145        name = "db.oauth2_device_code_grant.add",
146        skip_all,
147        fields(
148            db.query.text,
149            oauth2_device_code.id,
150            oauth2_device_code.scope = %params.scope,
151            oauth2_client.id = %params.client.id,
152        ),
153        err,
154    )]
155    async fn add(
156        &mut self,
157        rng: &mut (dyn RngCore + Send),
158        clock: &dyn Clock,
159        params: OAuth2DeviceCodeGrantParams<'_>,
160    ) -> Result<DeviceCodeGrant, Self::Error> {
161        let now = clock.now();
162        let id = Ulid::from_datetime_with_source(now.into(), rng);
163        tracing::Span::current().record("oauth2_device_code.id", tracing::field::display(id));
164
165        let created_at = now;
166        let expires_at = now + params.expires_in;
167        let client_id = params.client.id;
168
169        sqlx::query!(
170            r#"
171                INSERT INTO "oauth2_device_code_grant"
172                    ( oauth2_device_code_grant_id
173                    , oauth2_client_id
174                    , scope
175                    , device_code
176                    , user_code
177                    , created_at
178                    , expires_at
179                    , ip_address
180                    , user_agent
181                    )
182                VALUES
183                    ($1, $2, $3, $4, $5, $6, $7, $8, $9)
184            "#,
185            Uuid::from(id),
186            Uuid::from(client_id),
187            params.scope.to_string(),
188            &params.device_code,
189            &params.user_code,
190            created_at,
191            expires_at,
192            params.ip_address as Option<IpAddr>,
193            params.user_agent.as_deref(),
194        )
195        .traced()
196        .execute(&mut *self.conn)
197        .await?;
198
199        Ok(DeviceCodeGrant {
200            id,
201            state: DeviceCodeGrantState::Pending,
202            client_id,
203            scope: params.scope,
204            user_code: params.user_code,
205            device_code: params.device_code,
206            created_at,
207            expires_at,
208            ip_address: params.ip_address,
209            user_agent: params.user_agent,
210        })
211    }
212
213    #[tracing::instrument(
214        name = "db.oauth2_device_code_grant.lookup",
215        skip_all,
216        fields(
217            db.query.text,
218            oauth2_device_code.id = %id,
219        ),
220        err,
221    )]
222    async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error> {
223        let res = sqlx::query_as!(
224            OAuth2DeviceGrantLookup,
225            r#"
226                SELECT oauth2_device_code_grant_id
227                     , oauth2_client_id
228                     , scope
229                     , device_code
230                     , user_code
231                     , created_at
232                     , expires_at
233                     , fulfilled_at
234                     , rejected_at
235                     , exchanged_at
236                     , user_session_id
237                     , oauth2_session_id
238                     , ip_address as "ip_address: IpAddr"
239                     , user_agent
240                FROM
241                    oauth2_device_code_grant
242
243                WHERE oauth2_device_code_grant_id = $1
244            "#,
245            Uuid::from(id),
246        )
247        .traced()
248        .fetch_optional(&mut *self.conn)
249        .await?;
250
251        let Some(res) = res else { return Ok(None) };
252
253        Ok(Some(res.try_into()?))
254    }
255
256    #[tracing::instrument(
257        name = "db.oauth2_device_code_grant.find_by_user_code",
258        skip_all,
259        fields(
260            db.query.text,
261            oauth2_device_code.user_code = %user_code,
262        ),
263        err,
264    )]
265    async fn find_by_user_code(
266        &mut self,
267        user_code: &str,
268    ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
269        let res = sqlx::query_as!(
270            OAuth2DeviceGrantLookup,
271            r#"
272                SELECT oauth2_device_code_grant_id
273                     , oauth2_client_id
274                     , scope
275                     , device_code
276                     , user_code
277                     , created_at
278                     , expires_at
279                     , fulfilled_at
280                     , rejected_at
281                     , exchanged_at
282                     , user_session_id
283                     , oauth2_session_id
284                     , ip_address as "ip_address: IpAddr"
285                     , user_agent
286                FROM
287                    oauth2_device_code_grant
288
289                WHERE user_code = $1
290            "#,
291            user_code,
292        )
293        .traced()
294        .fetch_optional(&mut *self.conn)
295        .await?;
296
297        let Some(res) = res else { return Ok(None) };
298
299        Ok(Some(res.try_into()?))
300    }
301
302    #[tracing::instrument(
303        name = "db.oauth2_device_code_grant.find_by_device_code",
304        skip_all,
305        fields(
306            db.query.text,
307            oauth2_device_code.device_code = %device_code,
308        ),
309        err,
310    )]
311    async fn find_by_device_code(
312        &mut self,
313        device_code: &str,
314    ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
315        let res = sqlx::query_as!(
316            OAuth2DeviceGrantLookup,
317            r#"
318                SELECT oauth2_device_code_grant_id
319                     , oauth2_client_id
320                     , scope
321                     , device_code
322                     , user_code
323                     , created_at
324                     , expires_at
325                     , fulfilled_at
326                     , rejected_at
327                     , exchanged_at
328                     , user_session_id
329                     , oauth2_session_id
330                     , ip_address as "ip_address: IpAddr"
331                     , user_agent
332                FROM
333                    oauth2_device_code_grant
334
335                WHERE device_code = $1
336            "#,
337            device_code,
338        )
339        .traced()
340        .fetch_optional(&mut *self.conn)
341        .await?;
342
343        let Some(res) = res else { return Ok(None) };
344
345        Ok(Some(res.try_into()?))
346    }
347
348    #[tracing::instrument(
349        name = "db.oauth2_device_code_grant.fulfill",
350        skip_all,
351        fields(
352            db.query.text,
353            oauth2_device_code.id = %device_code_grant.id,
354            oauth2_client.id = %device_code_grant.client_id,
355            browser_session.id = %browser_session.id,
356            user.id = %browser_session.user.id,
357        ),
358        err,
359    )]
360    async fn fulfill(
361        &mut self,
362        clock: &dyn Clock,
363        device_code_grant: DeviceCodeGrant,
364        browser_session: &BrowserSession,
365    ) -> Result<DeviceCodeGrant, Self::Error> {
366        let fulfilled_at = clock.now();
367        let device_code_grant = device_code_grant
368            .fulfill(browser_session, fulfilled_at)
369            .map_err(DatabaseError::to_invalid_operation)?;
370
371        let res = sqlx::query!(
372            r#"
373                UPDATE oauth2_device_code_grant
374                SET fulfilled_at = $1
375                  , user_session_id = $2
376                WHERE oauth2_device_code_grant_id = $3
377            "#,
378            fulfilled_at,
379            Uuid::from(browser_session.id),
380            Uuid::from(device_code_grant.id),
381        )
382        .traced()
383        .execute(&mut *self.conn)
384        .await?;
385
386        DatabaseError::ensure_affected_rows(&res, 1)?;
387
388        Ok(device_code_grant)
389    }
390
391    #[tracing::instrument(
392        name = "db.oauth2_device_code_grant.reject",
393        skip_all,
394        fields(
395            db.query.text,
396            oauth2_device_code.id = %device_code_grant.id,
397            oauth2_client.id = %device_code_grant.client_id,
398            browser_session.id = %browser_session.id,
399            user.id = %browser_session.user.id,
400        ),
401        err,
402    )]
403    async fn reject(
404        &mut self,
405        clock: &dyn Clock,
406        device_code_grant: DeviceCodeGrant,
407        browser_session: &BrowserSession,
408    ) -> Result<DeviceCodeGrant, Self::Error> {
409        let fulfilled_at = clock.now();
410        let device_code_grant = device_code_grant
411            .reject(browser_session, fulfilled_at)
412            .map_err(DatabaseError::to_invalid_operation)?;
413
414        let res = sqlx::query!(
415            r#"
416                UPDATE oauth2_device_code_grant
417                SET rejected_at = $1
418                  , user_session_id = $2
419                WHERE oauth2_device_code_grant_id = $3
420            "#,
421            fulfilled_at,
422            Uuid::from(browser_session.id),
423            Uuid::from(device_code_grant.id),
424        )
425        .traced()
426        .execute(&mut *self.conn)
427        .await?;
428
429        DatabaseError::ensure_affected_rows(&res, 1)?;
430
431        Ok(device_code_grant)
432    }
433
434    #[tracing::instrument(
435        name = "db.oauth2_device_code_grant.exchange",
436        skip_all,
437        fields(
438            db.query.text,
439            oauth2_device_code.id = %device_code_grant.id,
440            oauth2_client.id = %device_code_grant.client_id,
441            oauth2_session.id = %session.id,
442        ),
443        err,
444    )]
445    async fn exchange(
446        &mut self,
447        clock: &dyn Clock,
448        device_code_grant: DeviceCodeGrant,
449        session: &Session,
450    ) -> Result<DeviceCodeGrant, Self::Error> {
451        let exchanged_at = clock.now();
452        let device_code_grant = device_code_grant
453            .exchange(session, exchanged_at)
454            .map_err(DatabaseError::to_invalid_operation)?;
455
456        let res = sqlx::query!(
457            r#"
458                UPDATE oauth2_device_code_grant
459                SET exchanged_at = $1
460                  , oauth2_session_id = $2
461                WHERE oauth2_device_code_grant_id = $3
462            "#,
463            exchanged_at,
464            Uuid::from(session.id),
465            Uuid::from(device_code_grant.id),
466        )
467        .traced()
468        .execute(&mut *self.conn)
469        .await?;
470
471        DatabaseError::ensure_affected_rows(&res, 1)?;
472
473        Ok(device_code_grant)
474    }
475}