1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, DeviceCodeGrant, DeviceCodeGrantState, Session, UserAgent};
12use mas_storage::{
13 Clock,
14 oauth2::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository},
15};
16use oauth2_types::scope::Scope;
17use rand::RngCore;
18use sqlx::PgConnection;
19use ulid::Ulid;
20use uuid::Uuid;
21
22use crate::{DatabaseError, ExecuteExt, errors::DatabaseInconsistencyError};
23
24pub struct PgOAuth2DeviceCodeGrantRepository<'c> {
27 conn: &'c mut PgConnection,
28}
29
30impl<'c> PgOAuth2DeviceCodeGrantRepository<'c> {
31 pub fn new(conn: &'c mut PgConnection) -> Self {
34 Self { conn }
35 }
36}
37
38struct OAuth2DeviceGrantLookup {
39 oauth2_device_code_grant_id: Uuid,
40 oauth2_client_id: Uuid,
41 scope: String,
42 device_code: String,
43 user_code: String,
44 created_at: DateTime<Utc>,
45 expires_at: DateTime<Utc>,
46 fulfilled_at: Option<DateTime<Utc>>,
47 rejected_at: Option<DateTime<Utc>>,
48 exchanged_at: Option<DateTime<Utc>>,
49 user_session_id: Option<Uuid>,
50 oauth2_session_id: Option<Uuid>,
51 ip_address: Option<IpAddr>,
52 user_agent: Option<String>,
53}
54
55impl TryFrom<OAuth2DeviceGrantLookup> for DeviceCodeGrant {
56 type Error = DatabaseInconsistencyError;
57
58 fn try_from(
59 OAuth2DeviceGrantLookup {
60 oauth2_device_code_grant_id,
61 oauth2_client_id,
62 scope,
63 device_code,
64 user_code,
65 created_at,
66 expires_at,
67 fulfilled_at,
68 rejected_at,
69 exchanged_at,
70 user_session_id,
71 oauth2_session_id,
72 ip_address,
73 user_agent,
74 }: OAuth2DeviceGrantLookup,
75 ) -> Result<Self, Self::Error> {
76 let id = Ulid::from(oauth2_device_code_grant_id);
77 let client_id = Ulid::from(oauth2_client_id);
78
79 let scope: Scope = scope.parse().map_err(|e| {
80 DatabaseInconsistencyError::on("oauth2_authorization_grants")
81 .column("scope")
82 .row(id)
83 .source(e)
84 })?;
85
86 let state = match (
87 fulfilled_at,
88 rejected_at,
89 exchanged_at,
90 user_session_id,
91 oauth2_session_id,
92 ) {
93 (None, None, None, None, None) => DeviceCodeGrantState::Pending,
94
95 (Some(fulfilled_at), None, None, Some(user_session_id), None) => {
96 DeviceCodeGrantState::Fulfilled {
97 browser_session_id: Ulid::from(user_session_id),
98 fulfilled_at,
99 }
100 }
101
102 (None, Some(rejected_at), None, Some(user_session_id), None) => {
103 DeviceCodeGrantState::Rejected {
104 browser_session_id: Ulid::from(user_session_id),
105 rejected_at,
106 }
107 }
108
109 (
110 Some(fulfilled_at),
111 None,
112 Some(exchanged_at),
113 Some(user_session_id),
114 Some(oauth2_session_id),
115 ) => DeviceCodeGrantState::Exchanged {
116 browser_session_id: Ulid::from(user_session_id),
117 session_id: Ulid::from(oauth2_session_id),
118 fulfilled_at,
119 exchanged_at,
120 },
121
122 _ => return Err(DatabaseInconsistencyError::on("oauth2_device_code_grant").row(id)),
123 };
124
125 Ok(DeviceCodeGrant {
126 id,
127 state,
128 client_id,
129 scope,
130 user_code,
131 device_code,
132 created_at,
133 expires_at,
134 ip_address,
135 user_agent: user_agent.map(UserAgent::parse),
136 })
137 }
138}
139
140#[async_trait]
141impl OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<'_> {
142 type Error = DatabaseError;
143
144 #[tracing::instrument(
145 name = "db.oauth2_device_code_grant.add",
146 skip_all,
147 fields(
148 db.query.text,
149 oauth2_device_code.id,
150 oauth2_device_code.scope = %params.scope,
151 oauth2_client.id = %params.client.id,
152 ),
153 err,
154 )]
155 async fn add(
156 &mut self,
157 rng: &mut (dyn RngCore + Send),
158 clock: &dyn Clock,
159 params: OAuth2DeviceCodeGrantParams<'_>,
160 ) -> Result<DeviceCodeGrant, Self::Error> {
161 let now = clock.now();
162 let id = Ulid::from_datetime_with_source(now.into(), rng);
163 tracing::Span::current().record("oauth2_device_code.id", tracing::field::display(id));
164
165 let created_at = now;
166 let expires_at = now + params.expires_in;
167 let client_id = params.client.id;
168
169 sqlx::query!(
170 r#"
171 INSERT INTO "oauth2_device_code_grant"
172 ( oauth2_device_code_grant_id
173 , oauth2_client_id
174 , scope
175 , device_code
176 , user_code
177 , created_at
178 , expires_at
179 , ip_address
180 , user_agent
181 )
182 VALUES
183 ($1, $2, $3, $4, $5, $6, $7, $8, $9)
184 "#,
185 Uuid::from(id),
186 Uuid::from(client_id),
187 params.scope.to_string(),
188 ¶ms.device_code,
189 ¶ms.user_code,
190 created_at,
191 expires_at,
192 params.ip_address as Option<IpAddr>,
193 params.user_agent.as_deref(),
194 )
195 .traced()
196 .execute(&mut *self.conn)
197 .await?;
198
199 Ok(DeviceCodeGrant {
200 id,
201 state: DeviceCodeGrantState::Pending,
202 client_id,
203 scope: params.scope,
204 user_code: params.user_code,
205 device_code: params.device_code,
206 created_at,
207 expires_at,
208 ip_address: params.ip_address,
209 user_agent: params.user_agent,
210 })
211 }
212
213 #[tracing::instrument(
214 name = "db.oauth2_device_code_grant.lookup",
215 skip_all,
216 fields(
217 db.query.text,
218 oauth2_device_code.id = %id,
219 ),
220 err,
221 )]
222 async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error> {
223 let res = sqlx::query_as!(
224 OAuth2DeviceGrantLookup,
225 r#"
226 SELECT oauth2_device_code_grant_id
227 , oauth2_client_id
228 , scope
229 , device_code
230 , user_code
231 , created_at
232 , expires_at
233 , fulfilled_at
234 , rejected_at
235 , exchanged_at
236 , user_session_id
237 , oauth2_session_id
238 , ip_address as "ip_address: IpAddr"
239 , user_agent
240 FROM
241 oauth2_device_code_grant
242
243 WHERE oauth2_device_code_grant_id = $1
244 "#,
245 Uuid::from(id),
246 )
247 .traced()
248 .fetch_optional(&mut *self.conn)
249 .await?;
250
251 let Some(res) = res else { return Ok(None) };
252
253 Ok(Some(res.try_into()?))
254 }
255
256 #[tracing::instrument(
257 name = "db.oauth2_device_code_grant.find_by_user_code",
258 skip_all,
259 fields(
260 db.query.text,
261 oauth2_device_code.user_code = %user_code,
262 ),
263 err,
264 )]
265 async fn find_by_user_code(
266 &mut self,
267 user_code: &str,
268 ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
269 let res = sqlx::query_as!(
270 OAuth2DeviceGrantLookup,
271 r#"
272 SELECT oauth2_device_code_grant_id
273 , oauth2_client_id
274 , scope
275 , device_code
276 , user_code
277 , created_at
278 , expires_at
279 , fulfilled_at
280 , rejected_at
281 , exchanged_at
282 , user_session_id
283 , oauth2_session_id
284 , ip_address as "ip_address: IpAddr"
285 , user_agent
286 FROM
287 oauth2_device_code_grant
288
289 WHERE user_code = $1
290 "#,
291 user_code,
292 )
293 .traced()
294 .fetch_optional(&mut *self.conn)
295 .await?;
296
297 let Some(res) = res else { return Ok(None) };
298
299 Ok(Some(res.try_into()?))
300 }
301
302 #[tracing::instrument(
303 name = "db.oauth2_device_code_grant.find_by_device_code",
304 skip_all,
305 fields(
306 db.query.text,
307 oauth2_device_code.device_code = %device_code,
308 ),
309 err,
310 )]
311 async fn find_by_device_code(
312 &mut self,
313 device_code: &str,
314 ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
315 let res = sqlx::query_as!(
316 OAuth2DeviceGrantLookup,
317 r#"
318 SELECT oauth2_device_code_grant_id
319 , oauth2_client_id
320 , scope
321 , device_code
322 , user_code
323 , created_at
324 , expires_at
325 , fulfilled_at
326 , rejected_at
327 , exchanged_at
328 , user_session_id
329 , oauth2_session_id
330 , ip_address as "ip_address: IpAddr"
331 , user_agent
332 FROM
333 oauth2_device_code_grant
334
335 WHERE device_code = $1
336 "#,
337 device_code,
338 )
339 .traced()
340 .fetch_optional(&mut *self.conn)
341 .await?;
342
343 let Some(res) = res else { return Ok(None) };
344
345 Ok(Some(res.try_into()?))
346 }
347
348 #[tracing::instrument(
349 name = "db.oauth2_device_code_grant.fulfill",
350 skip_all,
351 fields(
352 db.query.text,
353 oauth2_device_code.id = %device_code_grant.id,
354 oauth2_client.id = %device_code_grant.client_id,
355 browser_session.id = %browser_session.id,
356 user.id = %browser_session.user.id,
357 ),
358 err,
359 )]
360 async fn fulfill(
361 &mut self,
362 clock: &dyn Clock,
363 device_code_grant: DeviceCodeGrant,
364 browser_session: &BrowserSession,
365 ) -> Result<DeviceCodeGrant, Self::Error> {
366 let fulfilled_at = clock.now();
367 let device_code_grant = device_code_grant
368 .fulfill(browser_session, fulfilled_at)
369 .map_err(DatabaseError::to_invalid_operation)?;
370
371 let res = sqlx::query!(
372 r#"
373 UPDATE oauth2_device_code_grant
374 SET fulfilled_at = $1
375 , user_session_id = $2
376 WHERE oauth2_device_code_grant_id = $3
377 "#,
378 fulfilled_at,
379 Uuid::from(browser_session.id),
380 Uuid::from(device_code_grant.id),
381 )
382 .traced()
383 .execute(&mut *self.conn)
384 .await?;
385
386 DatabaseError::ensure_affected_rows(&res, 1)?;
387
388 Ok(device_code_grant)
389 }
390
391 #[tracing::instrument(
392 name = "db.oauth2_device_code_grant.reject",
393 skip_all,
394 fields(
395 db.query.text,
396 oauth2_device_code.id = %device_code_grant.id,
397 oauth2_client.id = %device_code_grant.client_id,
398 browser_session.id = %browser_session.id,
399 user.id = %browser_session.user.id,
400 ),
401 err,
402 )]
403 async fn reject(
404 &mut self,
405 clock: &dyn Clock,
406 device_code_grant: DeviceCodeGrant,
407 browser_session: &BrowserSession,
408 ) -> Result<DeviceCodeGrant, Self::Error> {
409 let fulfilled_at = clock.now();
410 let device_code_grant = device_code_grant
411 .reject(browser_session, fulfilled_at)
412 .map_err(DatabaseError::to_invalid_operation)?;
413
414 let res = sqlx::query!(
415 r#"
416 UPDATE oauth2_device_code_grant
417 SET rejected_at = $1
418 , user_session_id = $2
419 WHERE oauth2_device_code_grant_id = $3
420 "#,
421 fulfilled_at,
422 Uuid::from(browser_session.id),
423 Uuid::from(device_code_grant.id),
424 )
425 .traced()
426 .execute(&mut *self.conn)
427 .await?;
428
429 DatabaseError::ensure_affected_rows(&res, 1)?;
430
431 Ok(device_code_grant)
432 }
433
434 #[tracing::instrument(
435 name = "db.oauth2_device_code_grant.exchange",
436 skip_all,
437 fields(
438 db.query.text,
439 oauth2_device_code.id = %device_code_grant.id,
440 oauth2_client.id = %device_code_grant.client_id,
441 oauth2_session.id = %session.id,
442 ),
443 err,
444 )]
445 async fn exchange(
446 &mut self,
447 clock: &dyn Clock,
448 device_code_grant: DeviceCodeGrant,
449 session: &Session,
450 ) -> Result<DeviceCodeGrant, Self::Error> {
451 let exchanged_at = clock.now();
452 let device_code_grant = device_code_grant
453 .exchange(session, exchanged_at)
454 .map_err(DatabaseError::to_invalid_operation)?;
455
456 let res = sqlx::query!(
457 r#"
458 UPDATE oauth2_device_code_grant
459 SET exchanged_at = $1
460 , oauth2_session_id = $2
461 WHERE oauth2_device_code_grant_id = $3
462 "#,
463 exchanged_at,
464 Uuid::from(session.id),
465 Uuid::from(device_code_grant.id),
466 )
467 .traced()
468 .execute(&mut *self.conn)
469 .await?;
470
471 DatabaseError::ensure_affected_rows(&res, 1)?;
472
473 Ok(device_code_grant)
474 }
475}