Our Personal Data Server from scratch!
at main 576 lines 18 kB view raw
1use async_trait::async_trait; 2use chrono::{DateTime, Utc}; 3use sqlx::PgPool; 4use tranquil_db_traits::{ 5 AppPasswordCreate, AppPasswordPrivilege, AppPasswordRecord, DbError, LoginType, 6 RefreshSessionResult, SessionForRefresh, SessionId, SessionListItem, SessionMfaStatus, 7 SessionRefreshData, SessionRepository, SessionToken, SessionTokenCreate, 8}; 9use tranquil_types::Did; 10use uuid::Uuid; 11 12use super::user::map_sqlx_error; 13 14pub struct PostgresSessionRepository { 15 pool: PgPool, 16} 17 18impl PostgresSessionRepository { 19 pub fn new(pool: PgPool) -> Self { 20 Self { pool } 21 } 22} 23 24#[async_trait] 25impl SessionRepository for PostgresSessionRepository { 26 async fn create_session(&self, data: &SessionTokenCreate) -> Result<SessionId, DbError> { 27 let row = sqlx::query!( 28 r#" 29 INSERT INTO session_tokens 30 (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at, 31 legacy_login, mfa_verified, scope, controller_did, app_password_name) 32 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) 33 RETURNING id 34 "#, 35 data.did.as_str(), 36 data.access_jti, 37 data.refresh_jti, 38 data.access_expires_at, 39 data.refresh_expires_at, 40 bool::from(data.login_type), 41 data.mfa_verified, 42 data.scope, 43 data.controller_did.as_ref().map(|d| d.as_str()), 44 data.app_password_name 45 ) 46 .fetch_one(&self.pool) 47 .await 48 .map_err(map_sqlx_error)?; 49 50 Ok(SessionId::new(row.id)) 51 } 52 53 async fn get_session_by_access_jti( 54 &self, 55 access_jti: &str, 56 ) -> Result<Option<SessionToken>, DbError> { 57 let row = sqlx::query!( 58 r#" 59 SELECT id, did, access_jti, refresh_jti, access_expires_at, refresh_expires_at, 60 legacy_login, mfa_verified, scope, controller_did, app_password_name, 61 created_at, updated_at 62 FROM session_tokens 63 WHERE access_jti = $1 64 "#, 65 access_jti 66 ) 67 .fetch_optional(&self.pool) 68 .await 69 .map_err(map_sqlx_error)?; 70 71 Ok(row.map(|r| SessionToken { 72 id: SessionId::new(r.id), 73 did: Did::from(r.did), 74 access_jti: r.access_jti, 75 refresh_jti: r.refresh_jti, 76 access_expires_at: r.access_expires_at, 77 refresh_expires_at: r.refresh_expires_at, 78 login_type: LoginType::from(r.legacy_login), 79 mfa_verified: r.mfa_verified, 80 scope: r.scope, 81 controller_did: r.controller_did.map(Did::from), 82 app_password_name: r.app_password_name, 83 created_at: r.created_at, 84 updated_at: r.updated_at, 85 })) 86 } 87 88 async fn get_session_for_refresh( 89 &self, 90 refresh_jti: &str, 91 ) -> Result<Option<SessionForRefresh>, DbError> { 92 let row = sqlx::query!( 93 r#" 94 SELECT st.id, st.did, st.scope, st.controller_did, k.key_bytes, k.encryption_version 95 FROM session_tokens st 96 JOIN users u ON st.did = u.did 97 JOIN user_keys k ON u.id = k.user_id 98 WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW() 99 "#, 100 refresh_jti 101 ) 102 .fetch_optional(&self.pool) 103 .await 104 .map_err(map_sqlx_error)?; 105 106 Ok(row.map(|r| SessionForRefresh { 107 id: SessionId::new(r.id), 108 did: Did::from(r.did), 109 scope: r.scope, 110 controller_did: r.controller_did.map(Did::from), 111 key_bytes: r.key_bytes, 112 encryption_version: r.encryption_version.unwrap_or(0), 113 })) 114 } 115 116 async fn update_session_tokens( 117 &self, 118 session_id: SessionId, 119 new_access_jti: &str, 120 new_refresh_jti: &str, 121 new_access_expires_at: DateTime<Utc>, 122 new_refresh_expires_at: DateTime<Utc>, 123 ) -> Result<(), DbError> { 124 sqlx::query!( 125 r#" 126 UPDATE session_tokens 127 SET access_jti = $1, refresh_jti = $2, access_expires_at = $3, 128 refresh_expires_at = $4, updated_at = NOW() 129 WHERE id = $5 130 "#, 131 new_access_jti, 132 new_refresh_jti, 133 new_access_expires_at, 134 new_refresh_expires_at, 135 session_id.as_i32() 136 ) 137 .execute(&self.pool) 138 .await 139 .map_err(map_sqlx_error)?; 140 141 Ok(()) 142 } 143 144 async fn delete_session_by_access_jti(&self, access_jti: &str) -> Result<u64, DbError> { 145 let result = sqlx::query!( 146 "DELETE FROM session_tokens WHERE access_jti = $1", 147 access_jti 148 ) 149 .execute(&self.pool) 150 .await 151 .map_err(map_sqlx_error)?; 152 153 Ok(result.rows_affected()) 154 } 155 156 async fn delete_session_by_id(&self, session_id: SessionId) -> Result<u64, DbError> { 157 let result = sqlx::query!( 158 "DELETE FROM session_tokens WHERE id = $1", 159 session_id.as_i32() 160 ) 161 .execute(&self.pool) 162 .await 163 .map_err(map_sqlx_error)?; 164 165 Ok(result.rows_affected()) 166 } 167 168 async fn delete_sessions_by_did(&self, did: &Did) -> Result<u64, DbError> { 169 let result = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", did.as_str()) 170 .execute(&self.pool) 171 .await 172 .map_err(map_sqlx_error)?; 173 174 Ok(result.rows_affected()) 175 } 176 177 async fn delete_sessions_by_did_except_jti( 178 &self, 179 did: &Did, 180 except_jti: &str, 181 ) -> Result<u64, DbError> { 182 let result = sqlx::query!( 183 "DELETE FROM session_tokens WHERE did = $1 AND access_jti != $2", 184 did.as_str(), 185 except_jti 186 ) 187 .execute(&self.pool) 188 .await 189 .map_err(map_sqlx_error)?; 190 191 Ok(result.rows_affected()) 192 } 193 194 async fn list_sessions_by_did(&self, did: &Did) -> Result<Vec<SessionListItem>, DbError> { 195 let rows = sqlx::query!( 196 r#" 197 SELECT id, access_jti, created_at, refresh_expires_at 198 FROM session_tokens 199 WHERE did = $1 AND refresh_expires_at > NOW() 200 ORDER BY created_at DESC 201 "#, 202 did.as_str() 203 ) 204 .fetch_all(&self.pool) 205 .await 206 .map_err(map_sqlx_error)?; 207 208 Ok(rows 209 .into_iter() 210 .map(|r| SessionListItem { 211 id: SessionId::new(r.id), 212 access_jti: r.access_jti, 213 created_at: r.created_at, 214 refresh_expires_at: r.refresh_expires_at, 215 }) 216 .collect()) 217 } 218 219 async fn get_session_access_jti_by_id( 220 &self, 221 session_id: SessionId, 222 did: &Did, 223 ) -> Result<Option<String>, DbError> { 224 let row = sqlx::query_scalar!( 225 "SELECT access_jti FROM session_tokens WHERE id = $1 AND did = $2", 226 session_id.as_i32(), 227 did.as_str() 228 ) 229 .fetch_optional(&self.pool) 230 .await 231 .map_err(map_sqlx_error)?; 232 233 Ok(row) 234 } 235 236 async fn delete_sessions_by_app_password( 237 &self, 238 did: &Did, 239 app_password_name: &str, 240 ) -> Result<u64, DbError> { 241 let result = sqlx::query!( 242 "DELETE FROM session_tokens WHERE did = $1 AND app_password_name = $2", 243 did.as_str(), 244 app_password_name 245 ) 246 .execute(&self.pool) 247 .await 248 .map_err(map_sqlx_error)?; 249 250 Ok(result.rows_affected()) 251 } 252 253 async fn get_session_jtis_by_app_password( 254 &self, 255 did: &Did, 256 app_password_name: &str, 257 ) -> Result<Vec<String>, DbError> { 258 let rows = sqlx::query_scalar!( 259 "SELECT access_jti FROM session_tokens WHERE did = $1 AND app_password_name = $2", 260 did.as_str(), 261 app_password_name 262 ) 263 .fetch_all(&self.pool) 264 .await 265 .map_err(map_sqlx_error)?; 266 267 Ok(rows) 268 } 269 270 async fn check_refresh_token_used( 271 &self, 272 refresh_jti: &str, 273 ) -> Result<Option<SessionId>, DbError> { 274 let row = sqlx::query_scalar!( 275 "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1", 276 refresh_jti 277 ) 278 .fetch_optional(&self.pool) 279 .await 280 .map_err(map_sqlx_error)?; 281 282 Ok(row.map(SessionId::new)) 283 } 284 285 async fn mark_refresh_token_used( 286 &self, 287 refresh_jti: &str, 288 session_id: SessionId, 289 ) -> Result<bool, DbError> { 290 let result = sqlx::query!( 291 r#" 292 INSERT INTO used_refresh_tokens (refresh_jti, session_id) 293 VALUES ($1, $2) 294 ON CONFLICT (refresh_jti) DO NOTHING 295 "#, 296 refresh_jti, 297 session_id.as_i32() 298 ) 299 .execute(&self.pool) 300 .await 301 .map_err(map_sqlx_error)?; 302 303 Ok(result.rows_affected() > 0) 304 } 305 306 async fn list_app_passwords(&self, user_id: Uuid) -> Result<Vec<AppPasswordRecord>, DbError> { 307 let rows = sqlx::query!( 308 r#" 309 SELECT id, user_id, name, password_hash, created_at, privileged, scopes, created_by_controller_did 310 FROM app_passwords 311 WHERE user_id = $1 312 ORDER BY created_at DESC 313 "#, 314 user_id 315 ) 316 .fetch_all(&self.pool) 317 .await 318 .map_err(map_sqlx_error)?; 319 320 Ok(rows 321 .into_iter() 322 .map(|r| AppPasswordRecord { 323 id: r.id, 324 user_id: r.user_id, 325 name: r.name, 326 password_hash: r.password_hash, 327 created_at: r.created_at, 328 privilege: AppPasswordPrivilege::from(r.privileged), 329 scopes: r.scopes, 330 created_by_controller_did: r.created_by_controller_did.map(Did::from), 331 }) 332 .collect()) 333 } 334 335 async fn get_app_passwords_for_login( 336 &self, 337 user_id: Uuid, 338 ) -> Result<Vec<AppPasswordRecord>, DbError> { 339 let rows = sqlx::query!( 340 r#" 341 SELECT id, user_id, name, password_hash, created_at, privileged, scopes, created_by_controller_did 342 FROM app_passwords 343 WHERE user_id = $1 344 ORDER BY created_at DESC 345 LIMIT 20 346 "#, 347 user_id 348 ) 349 .fetch_all(&self.pool) 350 .await 351 .map_err(map_sqlx_error)?; 352 353 Ok(rows 354 .into_iter() 355 .map(|r| AppPasswordRecord { 356 id: r.id, 357 user_id: r.user_id, 358 name: r.name, 359 password_hash: r.password_hash, 360 created_at: r.created_at, 361 privilege: AppPasswordPrivilege::from(r.privileged), 362 scopes: r.scopes, 363 created_by_controller_did: r.created_by_controller_did.map(Did::from), 364 }) 365 .collect()) 366 } 367 368 async fn get_app_password_by_name( 369 &self, 370 user_id: Uuid, 371 name: &str, 372 ) -> Result<Option<AppPasswordRecord>, DbError> { 373 let row = sqlx::query!( 374 r#" 375 SELECT id, user_id, name, password_hash, created_at, privileged, scopes, created_by_controller_did 376 FROM app_passwords 377 WHERE user_id = $1 AND name = $2 378 "#, 379 user_id, 380 name 381 ) 382 .fetch_optional(&self.pool) 383 .await 384 .map_err(map_sqlx_error)?; 385 386 Ok(row.map(|r| AppPasswordRecord { 387 id: r.id, 388 user_id: r.user_id, 389 name: r.name, 390 password_hash: r.password_hash, 391 created_at: r.created_at, 392 privilege: AppPasswordPrivilege::from(r.privileged), 393 scopes: r.scopes, 394 created_by_controller_did: r.created_by_controller_did.map(Did::from), 395 })) 396 } 397 398 async fn create_app_password(&self, data: &AppPasswordCreate) -> Result<Uuid, DbError> { 399 let row = sqlx::query!( 400 r#" 401 INSERT INTO app_passwords (user_id, name, password_hash, privileged, scopes, created_by_controller_did) 402 VALUES ($1, $2, $3, $4, $5, $6) 403 RETURNING id 404 "#, 405 data.user_id, 406 data.name, 407 data.password_hash, 408 bool::from(data.privilege), 409 data.scopes, 410 data.created_by_controller_did.as_ref().map(|d| d.as_str()) 411 ) 412 .fetch_one(&self.pool) 413 .await 414 .map_err(map_sqlx_error)?; 415 416 Ok(row.id) 417 } 418 419 async fn delete_app_password(&self, user_id: Uuid, name: &str) -> Result<u64, DbError> { 420 let result = sqlx::query!( 421 "DELETE FROM app_passwords WHERE user_id = $1 AND name = $2", 422 user_id, 423 name 424 ) 425 .execute(&self.pool) 426 .await 427 .map_err(map_sqlx_error)?; 428 429 Ok(result.rows_affected()) 430 } 431 432 async fn delete_app_passwords_by_controller( 433 &self, 434 did: &Did, 435 controller_did: &Did, 436 ) -> Result<u64, DbError> { 437 let result = sqlx::query!( 438 r#"DELETE FROM app_passwords 439 WHERE user_id = (SELECT id FROM users WHERE did = $1) 440 AND created_by_controller_did = $2"#, 441 did.as_str(), 442 controller_did.as_str() 443 ) 444 .execute(&self.pool) 445 .await 446 .map_err(map_sqlx_error)?; 447 448 Ok(result.rows_affected()) 449 } 450 451 async fn get_last_reauth_at(&self, did: &Did) -> Result<Option<DateTime<Utc>>, DbError> { 452 let row = sqlx::query_scalar!( 453 r#"SELECT last_reauth_at FROM session_tokens 454 WHERE did = $1 ORDER BY created_at DESC LIMIT 1"#, 455 did.as_str() 456 ) 457 .fetch_optional(&self.pool) 458 .await 459 .map_err(map_sqlx_error)?; 460 461 Ok(row.flatten()) 462 } 463 464 async fn update_last_reauth(&self, did: &Did) -> Result<DateTime<Utc>, DbError> { 465 let now = Utc::now(); 466 sqlx::query!( 467 "UPDATE session_tokens SET last_reauth_at = $1, mfa_verified = TRUE WHERE did = $2", 468 now, 469 did.as_str() 470 ) 471 .execute(&self.pool) 472 .await 473 .map_err(map_sqlx_error)?; 474 475 Ok(now) 476 } 477 478 async fn get_session_mfa_status(&self, did: &Did) -> Result<Option<SessionMfaStatus>, DbError> { 479 let row = sqlx::query!( 480 r#"SELECT legacy_login, mfa_verified, last_reauth_at FROM session_tokens 481 WHERE did = $1 ORDER BY created_at DESC LIMIT 1"#, 482 did.as_str() 483 ) 484 .fetch_optional(&self.pool) 485 .await 486 .map_err(map_sqlx_error)?; 487 488 Ok(row.map(|r| SessionMfaStatus { 489 login_type: LoginType::from(r.legacy_login), 490 mfa_verified: r.mfa_verified, 491 last_reauth_at: r.last_reauth_at, 492 })) 493 } 494 495 async fn update_mfa_verified(&self, did: &Did) -> Result<(), DbError> { 496 sqlx::query!( 497 "UPDATE session_tokens SET mfa_verified = TRUE, last_reauth_at = NOW() WHERE did = $1", 498 did.as_str() 499 ) 500 .execute(&self.pool) 501 .await 502 .map_err(map_sqlx_error)?; 503 504 Ok(()) 505 } 506 507 async fn get_app_password_hashes_by_did(&self, did: &Did) -> Result<Vec<String>, DbError> { 508 let rows = sqlx::query_scalar!( 509 r#"SELECT ap.password_hash FROM app_passwords ap 510 JOIN users u ON ap.user_id = u.id 511 WHERE u.did = $1"#, 512 did.as_str() 513 ) 514 .fetch_all(&self.pool) 515 .await 516 .map_err(map_sqlx_error)?; 517 518 Ok(rows) 519 } 520 521 async fn refresh_session_atomic( 522 &self, 523 data: &SessionRefreshData, 524 ) -> Result<RefreshSessionResult, DbError> { 525 let mut tx = self.pool.begin().await.map_err(map_sqlx_error)?; 526 527 if let Ok(Some(session_id)) = sqlx::query_scalar!( 528 "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1 FOR UPDATE", 529 data.old_refresh_jti 530 ) 531 .fetch_optional(&mut *tx) 532 .await 533 { 534 let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id) 535 .execute(&mut *tx) 536 .await; 537 tx.commit().await.map_err(map_sqlx_error)?; 538 return Ok(RefreshSessionResult::TokenAlreadyUsed); 539 } 540 541 let result = sqlx::query!( 542 "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING", 543 data.old_refresh_jti, 544 data.session_id.as_i32() 545 ) 546 .execute(&mut *tx) 547 .await 548 .map_err(map_sqlx_error)?; 549 550 if result.rows_affected() == 0 { 551 let _ = sqlx::query!( 552 "DELETE FROM session_tokens WHERE id = $1", 553 data.session_id.as_i32() 554 ) 555 .execute(&mut *tx) 556 .await; 557 tx.commit().await.map_err(map_sqlx_error)?; 558 return Ok(RefreshSessionResult::ConcurrentRefresh); 559 } 560 561 sqlx::query!( 562 "UPDATE session_tokens SET access_jti = $1, refresh_jti = $2, access_expires_at = $3, refresh_expires_at = $4, updated_at = NOW() WHERE id = $5", 563 data.new_access_jti, 564 data.new_refresh_jti, 565 data.new_access_expires_at, 566 data.new_refresh_expires_at, 567 data.session_id.as_i32() 568 ) 569 .execute(&mut *tx) 570 .await 571 .map_err(map_sqlx_error)?; 572 573 tx.commit().await.map_err(map_sqlx_error)?; 574 Ok(RefreshSessionResult::Success) 575 } 576}