forked from
tranquil.farm/tranquil-pds
Our Personal Data Server from scratch!
1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use sqlx::PgPool;
4use tranquil_db_traits::{
5 AppPasswordCreate, AppPasswordPrivilege, AppPasswordRecord, DbError, LoginType,
6 RefreshSessionResult, SessionForRefresh, SessionId, SessionListItem, SessionMfaStatus,
7 SessionRefreshData, SessionRepository, SessionToken, SessionTokenCreate,
8};
9use tranquil_types::Did;
10use uuid::Uuid;
11
12use super::user::map_sqlx_error;
13
14pub struct PostgresSessionRepository {
15 pool: PgPool,
16}
17
18impl PostgresSessionRepository {
19 pub fn new(pool: PgPool) -> Self {
20 Self { pool }
21 }
22}
23
24#[async_trait]
25impl SessionRepository for PostgresSessionRepository {
26 async fn create_session(&self, data: &SessionTokenCreate) -> Result<SessionId, DbError> {
27 let row = sqlx::query!(
28 r#"
29 INSERT INTO session_tokens
30 (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at,
31 legacy_login, mfa_verified, scope, controller_did, app_password_name)
32 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
33 RETURNING id
34 "#,
35 data.did.as_str(),
36 data.access_jti,
37 data.refresh_jti,
38 data.access_expires_at,
39 data.refresh_expires_at,
40 bool::from(data.login_type),
41 data.mfa_verified,
42 data.scope,
43 data.controller_did.as_ref().map(|d| d.as_str()),
44 data.app_password_name
45 )
46 .fetch_one(&self.pool)
47 .await
48 .map_err(map_sqlx_error)?;
49
50 Ok(SessionId::new(row.id))
51 }
52
53 async fn get_session_by_access_jti(
54 &self,
55 access_jti: &str,
56 ) -> Result<Option<SessionToken>, DbError> {
57 let row = sqlx::query!(
58 r#"
59 SELECT id, did, access_jti, refresh_jti, access_expires_at, refresh_expires_at,
60 legacy_login, mfa_verified, scope, controller_did, app_password_name,
61 created_at, updated_at
62 FROM session_tokens
63 WHERE access_jti = $1
64 "#,
65 access_jti
66 )
67 .fetch_optional(&self.pool)
68 .await
69 .map_err(map_sqlx_error)?;
70
71 Ok(row.map(|r| SessionToken {
72 id: SessionId::new(r.id),
73 did: Did::from(r.did),
74 access_jti: r.access_jti,
75 refresh_jti: r.refresh_jti,
76 access_expires_at: r.access_expires_at,
77 refresh_expires_at: r.refresh_expires_at,
78 login_type: LoginType::from(r.legacy_login),
79 mfa_verified: r.mfa_verified,
80 scope: r.scope,
81 controller_did: r.controller_did.map(Did::from),
82 app_password_name: r.app_password_name,
83 created_at: r.created_at,
84 updated_at: r.updated_at,
85 }))
86 }
87
88 async fn get_session_for_refresh(
89 &self,
90 refresh_jti: &str,
91 ) -> Result<Option<SessionForRefresh>, DbError> {
92 let row = sqlx::query!(
93 r#"
94 SELECT st.id, st.did, st.scope, st.controller_did, k.key_bytes, k.encryption_version
95 FROM session_tokens st
96 JOIN users u ON st.did = u.did
97 JOIN user_keys k ON u.id = k.user_id
98 WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW()
99 "#,
100 refresh_jti
101 )
102 .fetch_optional(&self.pool)
103 .await
104 .map_err(map_sqlx_error)?;
105
106 Ok(row.map(|r| SessionForRefresh {
107 id: SessionId::new(r.id),
108 did: Did::from(r.did),
109 scope: r.scope,
110 controller_did: r.controller_did.map(Did::from),
111 key_bytes: r.key_bytes,
112 encryption_version: r.encryption_version.unwrap_or(0),
113 }))
114 }
115
116 async fn update_session_tokens(
117 &self,
118 session_id: SessionId,
119 new_access_jti: &str,
120 new_refresh_jti: &str,
121 new_access_expires_at: DateTime<Utc>,
122 new_refresh_expires_at: DateTime<Utc>,
123 ) -> Result<(), DbError> {
124 sqlx::query!(
125 r#"
126 UPDATE session_tokens
127 SET access_jti = $1, refresh_jti = $2, access_expires_at = $3,
128 refresh_expires_at = $4, updated_at = NOW()
129 WHERE id = $5
130 "#,
131 new_access_jti,
132 new_refresh_jti,
133 new_access_expires_at,
134 new_refresh_expires_at,
135 session_id.as_i32()
136 )
137 .execute(&self.pool)
138 .await
139 .map_err(map_sqlx_error)?;
140
141 Ok(())
142 }
143
144 async fn delete_session_by_access_jti(&self, access_jti: &str) -> Result<u64, DbError> {
145 let result = sqlx::query!(
146 "DELETE FROM session_tokens WHERE access_jti = $1",
147 access_jti
148 )
149 .execute(&self.pool)
150 .await
151 .map_err(map_sqlx_error)?;
152
153 Ok(result.rows_affected())
154 }
155
156 async fn delete_session_by_id(&self, session_id: SessionId) -> Result<u64, DbError> {
157 let result = sqlx::query!(
158 "DELETE FROM session_tokens WHERE id = $1",
159 session_id.as_i32()
160 )
161 .execute(&self.pool)
162 .await
163 .map_err(map_sqlx_error)?;
164
165 Ok(result.rows_affected())
166 }
167
168 async fn delete_sessions_by_did(&self, did: &Did) -> Result<u64, DbError> {
169 let result = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", did.as_str())
170 .execute(&self.pool)
171 .await
172 .map_err(map_sqlx_error)?;
173
174 Ok(result.rows_affected())
175 }
176
177 async fn delete_sessions_by_did_except_jti(
178 &self,
179 did: &Did,
180 except_jti: &str,
181 ) -> Result<u64, DbError> {
182 let result = sqlx::query!(
183 "DELETE FROM session_tokens WHERE did = $1 AND access_jti != $2",
184 did.as_str(),
185 except_jti
186 )
187 .execute(&self.pool)
188 .await
189 .map_err(map_sqlx_error)?;
190
191 Ok(result.rows_affected())
192 }
193
194 async fn list_sessions_by_did(&self, did: &Did) -> Result<Vec<SessionListItem>, DbError> {
195 let rows = sqlx::query!(
196 r#"
197 SELECT id, access_jti, created_at, refresh_expires_at
198 FROM session_tokens
199 WHERE did = $1 AND refresh_expires_at > NOW()
200 ORDER BY created_at DESC
201 "#,
202 did.as_str()
203 )
204 .fetch_all(&self.pool)
205 .await
206 .map_err(map_sqlx_error)?;
207
208 Ok(rows
209 .into_iter()
210 .map(|r| SessionListItem {
211 id: SessionId::new(r.id),
212 access_jti: r.access_jti,
213 created_at: r.created_at,
214 refresh_expires_at: r.refresh_expires_at,
215 })
216 .collect())
217 }
218
219 async fn get_session_access_jti_by_id(
220 &self,
221 session_id: SessionId,
222 did: &Did,
223 ) -> Result<Option<String>, DbError> {
224 let row = sqlx::query_scalar!(
225 "SELECT access_jti FROM session_tokens WHERE id = $1 AND did = $2",
226 session_id.as_i32(),
227 did.as_str()
228 )
229 .fetch_optional(&self.pool)
230 .await
231 .map_err(map_sqlx_error)?;
232
233 Ok(row)
234 }
235
236 async fn delete_sessions_by_app_password(
237 &self,
238 did: &Did,
239 app_password_name: &str,
240 ) -> Result<u64, DbError> {
241 let result = sqlx::query!(
242 "DELETE FROM session_tokens WHERE did = $1 AND app_password_name = $2",
243 did.as_str(),
244 app_password_name
245 )
246 .execute(&self.pool)
247 .await
248 .map_err(map_sqlx_error)?;
249
250 Ok(result.rows_affected())
251 }
252
253 async fn get_session_jtis_by_app_password(
254 &self,
255 did: &Did,
256 app_password_name: &str,
257 ) -> Result<Vec<String>, DbError> {
258 let rows = sqlx::query_scalar!(
259 "SELECT access_jti FROM session_tokens WHERE did = $1 AND app_password_name = $2",
260 did.as_str(),
261 app_password_name
262 )
263 .fetch_all(&self.pool)
264 .await
265 .map_err(map_sqlx_error)?;
266
267 Ok(rows)
268 }
269
270 async fn check_refresh_token_used(
271 &self,
272 refresh_jti: &str,
273 ) -> Result<Option<SessionId>, DbError> {
274 let row = sqlx::query_scalar!(
275 "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1",
276 refresh_jti
277 )
278 .fetch_optional(&self.pool)
279 .await
280 .map_err(map_sqlx_error)?;
281
282 Ok(row.map(SessionId::new))
283 }
284
285 async fn mark_refresh_token_used(
286 &self,
287 refresh_jti: &str,
288 session_id: SessionId,
289 ) -> Result<bool, DbError> {
290 let result = sqlx::query!(
291 r#"
292 INSERT INTO used_refresh_tokens (refresh_jti, session_id)
293 VALUES ($1, $2)
294 ON CONFLICT (refresh_jti) DO NOTHING
295 "#,
296 refresh_jti,
297 session_id.as_i32()
298 )
299 .execute(&self.pool)
300 .await
301 .map_err(map_sqlx_error)?;
302
303 Ok(result.rows_affected() > 0)
304 }
305
306 async fn list_app_passwords(&self, user_id: Uuid) -> Result<Vec<AppPasswordRecord>, DbError> {
307 let rows = sqlx::query!(
308 r#"
309 SELECT id, user_id, name, password_hash, created_at, privileged, scopes, created_by_controller_did
310 FROM app_passwords
311 WHERE user_id = $1
312 ORDER BY created_at DESC
313 "#,
314 user_id
315 )
316 .fetch_all(&self.pool)
317 .await
318 .map_err(map_sqlx_error)?;
319
320 Ok(rows
321 .into_iter()
322 .map(|r| AppPasswordRecord {
323 id: r.id,
324 user_id: r.user_id,
325 name: r.name,
326 password_hash: r.password_hash,
327 created_at: r.created_at,
328 privilege: AppPasswordPrivilege::from(r.privileged),
329 scopes: r.scopes,
330 created_by_controller_did: r.created_by_controller_did.map(Did::from),
331 })
332 .collect())
333 }
334
335 async fn get_app_passwords_for_login(
336 &self,
337 user_id: Uuid,
338 ) -> Result<Vec<AppPasswordRecord>, DbError> {
339 let rows = sqlx::query!(
340 r#"
341 SELECT id, user_id, name, password_hash, created_at, privileged, scopes, created_by_controller_did
342 FROM app_passwords
343 WHERE user_id = $1
344 ORDER BY created_at DESC
345 LIMIT 20
346 "#,
347 user_id
348 )
349 .fetch_all(&self.pool)
350 .await
351 .map_err(map_sqlx_error)?;
352
353 Ok(rows
354 .into_iter()
355 .map(|r| AppPasswordRecord {
356 id: r.id,
357 user_id: r.user_id,
358 name: r.name,
359 password_hash: r.password_hash,
360 created_at: r.created_at,
361 privilege: AppPasswordPrivilege::from(r.privileged),
362 scopes: r.scopes,
363 created_by_controller_did: r.created_by_controller_did.map(Did::from),
364 })
365 .collect())
366 }
367
368 async fn get_app_password_by_name(
369 &self,
370 user_id: Uuid,
371 name: &str,
372 ) -> Result<Option<AppPasswordRecord>, DbError> {
373 let row = sqlx::query!(
374 r#"
375 SELECT id, user_id, name, password_hash, created_at, privileged, scopes, created_by_controller_did
376 FROM app_passwords
377 WHERE user_id = $1 AND name = $2
378 "#,
379 user_id,
380 name
381 )
382 .fetch_optional(&self.pool)
383 .await
384 .map_err(map_sqlx_error)?;
385
386 Ok(row.map(|r| AppPasswordRecord {
387 id: r.id,
388 user_id: r.user_id,
389 name: r.name,
390 password_hash: r.password_hash,
391 created_at: r.created_at,
392 privilege: AppPasswordPrivilege::from(r.privileged),
393 scopes: r.scopes,
394 created_by_controller_did: r.created_by_controller_did.map(Did::from),
395 }))
396 }
397
398 async fn create_app_password(&self, data: &AppPasswordCreate) -> Result<Uuid, DbError> {
399 let row = sqlx::query!(
400 r#"
401 INSERT INTO app_passwords (user_id, name, password_hash, privileged, scopes, created_by_controller_did)
402 VALUES ($1, $2, $3, $4, $5, $6)
403 RETURNING id
404 "#,
405 data.user_id,
406 data.name,
407 data.password_hash,
408 bool::from(data.privilege),
409 data.scopes,
410 data.created_by_controller_did.as_ref().map(|d| d.as_str())
411 )
412 .fetch_one(&self.pool)
413 .await
414 .map_err(map_sqlx_error)?;
415
416 Ok(row.id)
417 }
418
419 async fn delete_app_password(&self, user_id: Uuid, name: &str) -> Result<u64, DbError> {
420 let result = sqlx::query!(
421 "DELETE FROM app_passwords WHERE user_id = $1 AND name = $2",
422 user_id,
423 name
424 )
425 .execute(&self.pool)
426 .await
427 .map_err(map_sqlx_error)?;
428
429 Ok(result.rows_affected())
430 }
431
432 async fn delete_app_passwords_by_controller(
433 &self,
434 did: &Did,
435 controller_did: &Did,
436 ) -> Result<u64, DbError> {
437 let result = sqlx::query!(
438 r#"DELETE FROM app_passwords
439 WHERE user_id = (SELECT id FROM users WHERE did = $1)
440 AND created_by_controller_did = $2"#,
441 did.as_str(),
442 controller_did.as_str()
443 )
444 .execute(&self.pool)
445 .await
446 .map_err(map_sqlx_error)?;
447
448 Ok(result.rows_affected())
449 }
450
451 async fn get_last_reauth_at(&self, did: &Did) -> Result<Option<DateTime<Utc>>, DbError> {
452 let row = sqlx::query_scalar!(
453 r#"SELECT last_reauth_at FROM session_tokens
454 WHERE did = $1 ORDER BY created_at DESC LIMIT 1"#,
455 did.as_str()
456 )
457 .fetch_optional(&self.pool)
458 .await
459 .map_err(map_sqlx_error)?;
460
461 Ok(row.flatten())
462 }
463
464 async fn update_last_reauth(&self, did: &Did) -> Result<DateTime<Utc>, DbError> {
465 let now = Utc::now();
466 sqlx::query!(
467 "UPDATE session_tokens SET last_reauth_at = $1, mfa_verified = TRUE WHERE did = $2",
468 now,
469 did.as_str()
470 )
471 .execute(&self.pool)
472 .await
473 .map_err(map_sqlx_error)?;
474
475 Ok(now)
476 }
477
478 async fn get_session_mfa_status(&self, did: &Did) -> Result<Option<SessionMfaStatus>, DbError> {
479 let row = sqlx::query!(
480 r#"SELECT legacy_login, mfa_verified, last_reauth_at FROM session_tokens
481 WHERE did = $1 ORDER BY created_at DESC LIMIT 1"#,
482 did.as_str()
483 )
484 .fetch_optional(&self.pool)
485 .await
486 .map_err(map_sqlx_error)?;
487
488 Ok(row.map(|r| SessionMfaStatus {
489 login_type: LoginType::from(r.legacy_login),
490 mfa_verified: r.mfa_verified,
491 last_reauth_at: r.last_reauth_at,
492 }))
493 }
494
495 async fn update_mfa_verified(&self, did: &Did) -> Result<(), DbError> {
496 sqlx::query!(
497 "UPDATE session_tokens SET mfa_verified = TRUE, last_reauth_at = NOW() WHERE did = $1",
498 did.as_str()
499 )
500 .execute(&self.pool)
501 .await
502 .map_err(map_sqlx_error)?;
503
504 Ok(())
505 }
506
507 async fn get_app_password_hashes_by_did(&self, did: &Did) -> Result<Vec<String>, DbError> {
508 let rows = sqlx::query_scalar!(
509 r#"SELECT ap.password_hash FROM app_passwords ap
510 JOIN users u ON ap.user_id = u.id
511 WHERE u.did = $1"#,
512 did.as_str()
513 )
514 .fetch_all(&self.pool)
515 .await
516 .map_err(map_sqlx_error)?;
517
518 Ok(rows)
519 }
520
521 async fn refresh_session_atomic(
522 &self,
523 data: &SessionRefreshData,
524 ) -> Result<RefreshSessionResult, DbError> {
525 let mut tx = self.pool.begin().await.map_err(map_sqlx_error)?;
526
527 if let Ok(Some(session_id)) = sqlx::query_scalar!(
528 "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1 FOR UPDATE",
529 data.old_refresh_jti
530 )
531 .fetch_optional(&mut *tx)
532 .await
533 {
534 let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id)
535 .execute(&mut *tx)
536 .await;
537 tx.commit().await.map_err(map_sqlx_error)?;
538 return Ok(RefreshSessionResult::TokenAlreadyUsed);
539 }
540
541 let result = sqlx::query!(
542 "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING",
543 data.old_refresh_jti,
544 data.session_id.as_i32()
545 )
546 .execute(&mut *tx)
547 .await
548 .map_err(map_sqlx_error)?;
549
550 if result.rows_affected() == 0 {
551 let _ = sqlx::query!(
552 "DELETE FROM session_tokens WHERE id = $1",
553 data.session_id.as_i32()
554 )
555 .execute(&mut *tx)
556 .await;
557 tx.commit().await.map_err(map_sqlx_error)?;
558 return Ok(RefreshSessionResult::ConcurrentRefresh);
559 }
560
561 sqlx::query!(
562 "UPDATE session_tokens SET access_jti = $1, refresh_jti = $2, access_expires_at = $3, refresh_expires_at = $4, updated_at = NOW() WHERE id = $5",
563 data.new_access_jti,
564 data.new_refresh_jti,
565 data.new_access_expires_at,
566 data.new_refresh_expires_at,
567 data.session_id.as_i32()
568 )
569 .execute(&mut *tx)
570 .await
571 .map_err(map_sqlx_error)?;
572
573 tx.commit().await.map_err(map_sqlx_error)?;
574 Ok(RefreshSessionResult::Success)
575 }
576}