learn and share notes on atproto (wip) 🦉 malfestio.stormlightlabs.org/
readability solid axum atproto srs

feat: implement token mgmt with DPoP and db persistence.

Changed files
+645 -2
crates
server
migrations
+34 -2
Cargo.lock
··· 1162 1162 dependencies = [ 1163 1163 "serde", 1164 1164 "serde_json", 1165 - "thiserror", 1165 + "thiserror 2.0.17", 1166 1166 ] 1167 1167 1168 1168 [[package]] ··· 1183 1183 "reqwest 0.12.28", 1184 1184 "serde", 1185 1185 "serde_json", 1186 + "serde_qs", 1186 1187 "sha2", 1187 1188 "tokio", 1188 1189 "tokio-postgres", ··· 1991 1992 ] 1992 1993 1993 1994 [[package]] 1995 + name = "serde_qs" 1996 + version = "0.13.0" 1997 + source = "registry+https://github.com/rust-lang/crates.io-index" 1998 + checksum = "cd34f36fe4c5ba9654417139a9b3a20d2e1de6012ee678ad14d240c22c78d8d6" 1999 + dependencies = [ 2000 + "percent-encoding", 2001 + "serde", 2002 + "thiserror 1.0.69", 2003 + ] 2004 + 2005 + [[package]] 1994 2006 name = "serde_urlencoded" 1995 2007 version = "0.7.1" 1996 2008 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2271 2283 2272 2284 [[package]] 2273 2285 name = "thiserror" 2286 + version = "1.0.69" 2287 + source = "registry+https://github.com/rust-lang/crates.io-index" 2288 + checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" 2289 + dependencies = [ 2290 + "thiserror-impl 1.0.69", 2291 + ] 2292 + 2293 + [[package]] 2294 + name = "thiserror" 2274 2295 version = "2.0.17" 2275 2296 source = "registry+https://github.com/rust-lang/crates.io-index" 2276 2297 checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" 2277 2298 dependencies = [ 2278 - "thiserror-impl", 2299 + "thiserror-impl 2.0.17", 2300 + ] 2301 + 2302 + [[package]] 2303 + name = "thiserror-impl" 2304 + version = "1.0.69" 2305 + source = "registry+https://github.com/rust-lang/crates.io-index" 2306 + checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" 2307 + dependencies = [ 2308 + "proc-macro2", 2309 + "quote", 2310 + "syn 2.0.111", 2279 2311 ] 2280 2312 2281 2313 [[package]]
+1
crates/server/Cargo.toml
··· 18 18 reqwest = { version = "0.12.28", features = ["json"] } 19 19 serde = "1.0.228" 20 20 serde_json = "1.0.148" 21 + serde_qs = "0.13" 21 22 sha2 = "0.10" 22 23 tokio = { version = "1.48.0", features = ["full"] } 23 24 urlencoding = "2.1"
+1
crates/server/src/api/mod.rs
··· 3 3 pub mod deck; 4 4 pub mod importer; 5 5 pub mod note; 6 + pub mod oauth;
+325
crates/server/src/api/oauth.rs
··· 1 + //! OAuth API endpoints for AT Protocol authentication. 2 + //! 3 + //! Provides endpoints for: 4 + //! - Starting the OAuth authorization flow 5 + //! - Handling OAuth callbacks 6 + //! - Refreshing tokens 7 + 8 + use crate::db::DbPool; 9 + use crate::oauth::flow::{OAuthFlow, SessionStore, generate_state, new_session_store}; 10 + use crate::repository::oauth::{DbOAuthRepository, OAuthRepository, StoreTokensRequest}; 11 + use axum::{ 12 + Json, 13 + extract::{Query, State}, 14 + http::StatusCode, 15 + response::{IntoResponse, Redirect}, 16 + }; 17 + use chrono::{Duration, Utc}; 18 + use serde::{Deserialize, Serialize}; 19 + use serde_json::json; 20 + use std::sync::Arc; 21 + 22 + /// Shared OAuth state with database repository. 23 + pub struct OAuthState { 24 + pub flow: OAuthFlow, 25 + pub sessions: SessionStore, 26 + pub repo: Arc<dyn OAuthRepository>, 27 + } 28 + 29 + impl OAuthState { 30 + /// Create OAuth state with database connection. 31 + pub fn with_pool(pool: DbPool) -> Self { 32 + Self { flow: OAuthFlow::new(), sessions: new_session_store(), repo: Arc::new(DbOAuthRepository::new(pool)) } 33 + } 34 + 35 + /// Create OAuth state without database (for testing). 36 + pub fn new() -> Self { 37 + Self { flow: OAuthFlow::new(), sessions: new_session_store(), repo: Arc::new(MockOAuthRepository) } 38 + } 39 + } 40 + 41 + impl Default for OAuthState { 42 + fn default() -> Self { 43 + Self::new() 44 + } 45 + } 46 + 47 + /// Mock repository for testing. 48 + struct MockOAuthRepository; 49 + 50 + #[async_trait::async_trait] 51 + impl OAuthRepository for MockOAuthRepository { 52 + async fn store_tokens(&self, _req: StoreTokensRequest<'_>) -> Result<(), crate::repository::oauth::OAuthRepoError> { 53 + Ok(()) 54 + } 55 + 56 + async fn get_tokens( 57 + &self, did: &str, 58 + ) -> Result<crate::repository::oauth::StoredToken, crate::repository::oauth::OAuthRepoError> { 59 + Err(crate::repository::oauth::OAuthRepoError::NotFound(did.to_string())) 60 + } 61 + 62 + async fn update_tokens( 63 + &self, _did: &str, _access_token: &str, _refresh_token: Option<&str>, 64 + _expires_at: Option<chrono::DateTime<Utc>>, 65 + ) -> Result<(), crate::repository::oauth::OAuthRepoError> { 66 + Ok(()) 67 + } 68 + 69 + async fn delete_tokens(&self, _did: &str) -> Result<(), crate::repository::oauth::OAuthRepoError> { 70 + Ok(()) 71 + } 72 + } 73 + 74 + /// Request to start OAuth authorization. 75 + #[derive(Deserialize)] 76 + pub struct AuthorizeRequest { 77 + /// Handle or DID to authenticate 78 + pub handle: String, 79 + } 80 + 81 + /// Response from starting authorization. 82 + #[derive(Serialize)] 83 + pub struct AuthorizeResponse { 84 + /// URL to redirect the user to 85 + pub authorization_url: String, 86 + /// State parameter (for CSRF protection) 87 + pub state: String, 88 + } 89 + 90 + /// Query parameters from OAuth callback. 91 + #[derive(Deserialize)] 92 + pub struct CallbackQuery { 93 + pub code: String, 94 + pub state: String, 95 + #[serde(default)] 96 + pub error: Option<String>, 97 + #[serde(default)] 98 + pub error_description: Option<String>, 99 + } 100 + 101 + /// Start the OAuth authorization flow. 102 + /// 103 + /// POST /api/oauth/authorize 104 + /// Body: { "handle": "alice.bsky.social" } 105 + pub async fn authorize( 106 + State(oauth): State<Arc<OAuthState>>, Json(payload): Json<AuthorizeRequest>, 107 + ) -> impl IntoResponse { 108 + let state = generate_state(); 109 + 110 + match oauth 111 + .flow 112 + .start_authorization(&payload.handle, &state, &oauth.sessions) 113 + .await 114 + { 115 + Ok(auth_url) => ( 116 + StatusCode::OK, 117 + Json(AuthorizeResponse { authorization_url: auth_url, state }), 118 + ) 119 + .into_response(), 120 + Err(e) => (StatusCode::BAD_REQUEST, Json(json!({ "error": e.to_string() }))).into_response(), 121 + } 122 + } 123 + 124 + /// Handle OAuth callback from authorization server. 125 + /// 126 + /// GET /api/oauth/callback?code=...&state=... 127 + pub async fn callback(State(oauth): State<Arc<OAuthState>>, Query(params): Query<CallbackQuery>) -> impl IntoResponse { 128 + if let Some(error) = params.error { 129 + let description = params.error_description.unwrap_or_default(); 130 + return Redirect::to(&format!( 131 + "/login?error={}&description={}", 132 + urlencoding::encode(&error), 133 + urlencoding::encode(&description) 134 + )) 135 + .into_response(); 136 + } 137 + 138 + let session = { 139 + let sessions = oauth.sessions.read().unwrap(); 140 + sessions.get(&params.state).cloned() 141 + }; 142 + 143 + let session = match session { 144 + Some(s) => s, 145 + None => { 146 + return Redirect::to("/login?error=session_not_found").into_response(); 147 + } 148 + }; 149 + 150 + match oauth 151 + .flow 152 + .exchange_code(&params.code, &params.state, &oauth.sessions) 153 + .await 154 + { 155 + Ok(tokens) => { 156 + let did = session.did.unwrap_or_default(); 157 + let pds_url = session.pds_url.unwrap_or_default(); 158 + let expires_at = tokens 159 + .expires_in 160 + .map(|secs| Utc::now() + Duration::seconds(secs as i64)); 161 + 162 + if let Err(e) = oauth 163 + .repo 164 + .store_tokens(StoreTokensRequest { 165 + did: &did, 166 + pds_url: &pds_url, 167 + access_token: &tokens.access_token, 168 + refresh_token: tokens.refresh_token.as_deref(), 169 + token_type: &tokens.token_type, 170 + expires_at, 171 + dpop_keypair: &session.dpop_keypair, 172 + }) 173 + .await 174 + { 175 + tracing::error!("Failed to store tokens: {}", e); 176 + return Redirect::to(&format!("/login?error={}", urlencoding::encode("token_storage_failed"))) 177 + .into_response(); 178 + } 179 + 180 + Redirect::to(&format!("/login/success?did={}", urlencoding::encode(&did))).into_response() 181 + } 182 + Err(e) => Redirect::to(&format!("/login?error={}", urlencoding::encode(&e.to_string()))).into_response(), 183 + } 184 + } 185 + 186 + /// Request to refresh tokens. 187 + #[derive(Deserialize)] 188 + pub struct RefreshRequest { 189 + pub did: String, 190 + } 191 + 192 + /// Response from token refresh. 193 + #[derive(Serialize)] 194 + pub struct RefreshResponse { 195 + pub success: bool, 196 + pub expires_at: Option<String>, 197 + } 198 + 199 + /// Refresh an access token. 200 + /// 201 + /// POST /api/oauth/refresh 202 + /// Body: { "did": "did:plc:..." } 203 + pub async fn refresh(State(oauth): State<Arc<OAuthState>>, Json(payload): Json<RefreshRequest>) -> impl IntoResponse { 204 + // Get stored tokens from database 205 + let stored = match oauth.repo.get_tokens(&payload.did).await { 206 + Ok(t) => t, 207 + Err(e) => { 208 + return (StatusCode::NOT_FOUND, Json(json!({ "error": e.to_string() }))).into_response(); 209 + } 210 + }; 211 + 212 + // Reconstruct DPoP keypair 213 + let dpop_keypair = match stored.dpop_keypair() { 214 + Some(kp) => kp, 215 + None => { 216 + return ( 217 + StatusCode::INTERNAL_SERVER_ERROR, 218 + Json(json!({ "error": "Invalid stored keypair" })), 219 + ) 220 + .into_response(); 221 + } 222 + }; 223 + 224 + // Get refresh token 225 + let refresh_token = match &stored.refresh_token { 226 + Some(rt) => rt.clone(), 227 + None => { 228 + return ( 229 + StatusCode::BAD_REQUEST, 230 + Json(json!({ "error": "No refresh token available" })), 231 + ) 232 + .into_response(); 233 + } 234 + }; 235 + 236 + // Refresh tokens via OAuth flow 237 + match oauth 238 + .flow 239 + .refresh_token(&refresh_token, &stored.pds_url, &dpop_keypair) 240 + .await 241 + { 242 + Ok(new_tokens) => { 243 + let expires_at = new_tokens 244 + .expires_in 245 + .map(|secs| Utc::now() + Duration::seconds(secs as i64)); 246 + 247 + if let Err(e) = oauth 248 + .repo 249 + .update_tokens( 250 + &payload.did, 251 + &new_tokens.access_token, 252 + new_tokens.refresh_token.as_deref(), 253 + expires_at, 254 + ) 255 + .await 256 + { 257 + tracing::error!("Failed to update tokens: {}", e); 258 + return ( 259 + StatusCode::INTERNAL_SERVER_ERROR, 260 + Json(json!({ "error": "Failed to update tokens" })), 261 + ) 262 + .into_response(); 263 + } 264 + 265 + ( 266 + StatusCode::OK, 267 + Json(RefreshResponse { success: true, expires_at: expires_at.map(|dt| dt.to_rfc3339()) }), 268 + ) 269 + .into_response() 270 + } 271 + Err(e) => (StatusCode::BAD_REQUEST, Json(json!({ "error": e.to_string() }))).into_response(), 272 + } 273 + } 274 + 275 + #[cfg(test)] 276 + mod tests { 277 + use super::*; 278 + 279 + #[test] 280 + fn test_oauth_state_creation() { 281 + let state = OAuthState::new(); 282 + assert!(state.sessions.read().unwrap().is_empty()); 283 + } 284 + 285 + #[test] 286 + fn test_authorize_request_deserialization() { 287 + let json = r#"{"handle": "alice.bsky.social"}"#; 288 + let request: AuthorizeRequest = serde_json::from_str(json).unwrap(); 289 + assert_eq!(request.handle, "alice.bsky.social"); 290 + } 291 + 292 + #[test] 293 + fn test_authorize_response_serialization() { 294 + let response = AuthorizeResponse { 295 + authorization_url: "https://example.com/oauth".to_string(), 296 + state: "abc123".to_string(), 297 + }; 298 + let json = serde_json::to_string(&response).unwrap(); 299 + assert!(json.contains("authorization_url")); 300 + assert!(json.contains("state")); 301 + } 302 + 303 + #[test] 304 + fn test_callback_query_deserialization() { 305 + let query = "code=abc123&state=xyz789"; 306 + let parsed: CallbackQuery = serde_qs::from_str(query).unwrap(); 307 + assert_eq!(parsed.code, "abc123"); 308 + assert_eq!(parsed.state, "xyz789"); 309 + assert!(parsed.error.is_none()); 310 + } 311 + 312 + #[test] 313 + fn test_callback_query_with_error() { 314 + let query = "code=&state=xyz789&error=access_denied&error_description=User+denied"; 315 + let parsed: CallbackQuery = serde_qs::from_str(query).unwrap(); 316 + assert_eq!(parsed.error, Some("access_denied".to_string())); 317 + } 318 + 319 + #[test] 320 + fn test_refresh_request_deserialization() { 321 + let json = r#"{"did": "did:plc:abc123"}"#; 322 + let request: RefreshRequest = serde_json::from_str(json).unwrap(); 323 + assert_eq!(request.did, "did:plc:abc123"); 324 + } 325 + }
+12
crates/server/src/lib.rs
··· 40 40 tracing::info!("Database connection pool created"); 41 41 42 42 let state = state::AppState::new(pool); 43 + let oauth_state = std::sync::Arc::new(api::oauth::OAuthState::new()); 43 44 44 45 let auth_routes = Router::new() 45 46 .route("/me", get(api::auth::me)) ··· 56 57 .route("/notes/{id}", get(api::note::get_note)) 57 58 .layer(axum_middleware::from_fn(middleware::auth::optional_auth_middleware)); 58 59 60 + let oauth_routes = Router::new() 61 + .route("/authorize", post(api::oauth::authorize)) 62 + .route("/callback", get(api::oauth::callback)) 63 + .route("/refresh", post(api::oauth::refresh)) 64 + .with_state(oauth_state.clone()); 65 + 59 66 let app = Router::new() 60 67 .route("/health", get(health_check)) 68 + .route( 69 + "/.well-known/oauth-client-metadata", 70 + get(oauth::client_metadata::client_metadata_handler), 71 + ) 61 72 .route("/api/auth/login", post(api::auth::login)) 62 73 .route("/api/import/article", post(api::importer::import_article)) 74 + .nest("/api/oauth", oauth_routes) 63 75 .nest("/api", optional_auth_routes) 64 76 .nest("/api", auth_routes) 65 77 .layer(TraceLayer::new_for_http())
+10
crates/server/src/oauth/dpop.rs
··· 89 89 90 90 format!("{}.{}.{}", header_b64, payload_b64, signature_b64) 91 91 } 92 + 93 + /// Create a DpopKeypair from an existing SigningKey. 94 + pub fn from_signing_key(signing_key: SigningKey) -> Self { 95 + Self { signing_key } 96 + } 97 + 98 + /// Get the private key bytes for storage. 99 + pub fn private_key_bytes(&self) -> Vec<u8> { 100 + self.signing_key.to_bytes().to_vec() 101 + } 92 102 } 93 103 94 104 /// Generate a unique JWT ID.
+1
crates/server/src/repository/mod.rs
··· 1 1 pub mod card; 2 2 pub mod note; 3 + pub mod oauth;
+213
crates/server/src/repository/oauth.rs
··· 1 + //! OAuth token repository for database storage. 2 + //! 3 + //! Handles storage and retrieval of OAuth tokens and sessions. 4 + 5 + use crate::db::DbPool; 6 + use crate::oauth::dpop::DpopKeypair; 7 + use async_trait::async_trait; 8 + use chrono::{DateTime, Utc}; 9 + use ed25519_dalek::SigningKey; 10 + use serde::{Deserialize, Serialize}; 11 + 12 + /// Stored OAuth token record. 13 + #[derive(Clone, Serialize, Deserialize)] 14 + pub struct StoredToken { 15 + pub did: String, 16 + pub pds_url: String, 17 + pub access_token: String, 18 + pub refresh_token: Option<String>, 19 + pub token_type: String, 20 + pub expires_at: Option<DateTime<Utc>>, 21 + pub dpop_private_key: Vec<u8>, 22 + pub created_at: DateTime<Utc>, 23 + pub updated_at: DateTime<Utc>, 24 + } 25 + 26 + impl StoredToken { 27 + /// Reconstruct the DPoP keypair from stored bytes. 28 + pub fn dpop_keypair(&self) -> Option<DpopKeypair> { 29 + if self.dpop_private_key.len() != 32 { 30 + return None; 31 + } 32 + let mut key_bytes = [0u8; 32]; 33 + key_bytes.copy_from_slice(&self.dpop_private_key); 34 + let signing_key = SigningKey::from_bytes(&key_bytes); 35 + Some(DpopKeypair::from_signing_key(signing_key)) 36 + } 37 + } 38 + 39 + /// Error type for OAuth repository operations. 40 + #[derive(Debug, Clone)] 41 + pub enum OAuthRepoError { 42 + DatabaseError(String), 43 + NotFound(String), 44 + SerializationError(String), 45 + } 46 + 47 + impl std::fmt::Display for OAuthRepoError { 48 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 49 + match self { 50 + OAuthRepoError::DatabaseError(e) => write!(f, "Database error: {}", e), 51 + OAuthRepoError::NotFound(e) => write!(f, "Not found: {}", e), 52 + OAuthRepoError::SerializationError(e) => write!(f, "Serialization error: {}", e), 53 + } 54 + } 55 + } 56 + 57 + impl std::error::Error for OAuthRepoError {} 58 + 59 + /// Request to store OAuth tokens. 60 + pub struct StoreTokensRequest<'a> { 61 + pub did: &'a str, 62 + pub pds_url: &'a str, 63 + pub access_token: &'a str, 64 + pub refresh_token: Option<&'a str>, 65 + pub token_type: &'a str, 66 + pub expires_at: Option<DateTime<Utc>>, 67 + pub dpop_keypair: &'a DpopKeypair, 68 + } 69 + 70 + /// Repository trait for OAuth token operations. 71 + #[async_trait] 72 + pub trait OAuthRepository: Send + Sync { 73 + /// Store OAuth tokens for a user. 74 + async fn store_tokens(&self, req: StoreTokensRequest<'_>) -> Result<(), OAuthRepoError>; 75 + 76 + /// Get stored tokens for a user. 77 + async fn get_tokens(&self, did: &str) -> Result<StoredToken, OAuthRepoError>; 78 + 79 + /// Update tokens after refresh. 80 + async fn update_tokens( 81 + &self, did: &str, access_token: &str, refresh_token: Option<&str>, expires_at: Option<DateTime<Utc>>, 82 + ) -> Result<(), OAuthRepoError>; 83 + 84 + /// Delete tokens for a user (logout). 85 + async fn delete_tokens(&self, did: &str) -> Result<(), OAuthRepoError>; 86 + } 87 + 88 + /// Database-backed OAuth repository. 89 + pub struct DbOAuthRepository { 90 + pool: DbPool, 91 + } 92 + 93 + impl DbOAuthRepository { 94 + pub fn new(pool: DbPool) -> Self { 95 + Self { pool } 96 + } 97 + } 98 + 99 + #[async_trait] 100 + impl OAuthRepository for DbOAuthRepository { 101 + async fn store_tokens(&self, req: StoreTokensRequest<'_>) -> Result<(), OAuthRepoError> { 102 + let client = self 103 + .pool 104 + .get() 105 + .await 106 + .map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?; 107 + 108 + let dpop_bytes = req.dpop_keypair.private_key_bytes(); 109 + 110 + client 111 + .execute( 112 + "INSERT INTO oauth_tokens (did, pds_url, access_token, refresh_token, token_type, expires_at, dpop_private_key) 113 + VALUES ($1, $2, $3, $4, $5, $6, $7) 114 + ON CONFLICT (did) DO UPDATE SET 115 + pds_url = EXCLUDED.pds_url, 116 + access_token = EXCLUDED.access_token, 117 + refresh_token = EXCLUDED.refresh_token, 118 + token_type = EXCLUDED.token_type, 119 + expires_at = EXCLUDED.expires_at, 120 + dpop_private_key = EXCLUDED.dpop_private_key, 121 + updated_at = NOW()", 122 + &[&req.did, &req.pds_url, &req.access_token, &req.refresh_token, &req.token_type, &req.expires_at, &dpop_bytes.as_slice()], 123 + ) 124 + .await 125 + .map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?; 126 + 127 + Ok(()) 128 + } 129 + 130 + async fn get_tokens(&self, did: &str) -> Result<StoredToken, OAuthRepoError> { 131 + let client = self 132 + .pool 133 + .get() 134 + .await 135 + .map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?; 136 + 137 + let row = client 138 + .query_opt( 139 + "SELECT did, pds_url, access_token, refresh_token, token_type, expires_at, dpop_private_key, created_at, updated_at 140 + FROM oauth_tokens WHERE did = $1", 141 + &[&did], 142 + ) 143 + .await 144 + .map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))? 145 + .ok_or_else(|| OAuthRepoError::NotFound(format!("No tokens for DID: {}", did)))?; 146 + 147 + Ok(StoredToken { 148 + did: row.get("did"), 149 + pds_url: row.get("pds_url"), 150 + access_token: row.get("access_token"), 151 + refresh_token: row.get("refresh_token"), 152 + token_type: row.get("token_type"), 153 + expires_at: row.get("expires_at"), 154 + dpop_private_key: row.get("dpop_private_key"), 155 + created_at: row.get("created_at"), 156 + updated_at: row.get("updated_at"), 157 + }) 158 + } 159 + 160 + async fn update_tokens( 161 + &self, did: &str, access_token: &str, refresh_token: Option<&str>, expires_at: Option<DateTime<Utc>>, 162 + ) -> Result<(), OAuthRepoError> { 163 + let client = self 164 + .pool 165 + .get() 166 + .await 167 + .map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?; 168 + 169 + let result = client 170 + .execute( 171 + "UPDATE oauth_tokens SET access_token = $2, refresh_token = $3, expires_at = $4, updated_at = NOW() 172 + WHERE did = $1", 173 + &[&did, &access_token, &refresh_token, &expires_at], 174 + ) 175 + .await 176 + .map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?; 177 + 178 + if result == 0 { 179 + return Err(OAuthRepoError::NotFound(format!("No tokens for DID: {}", did))); 180 + } 181 + 182 + Ok(()) 183 + } 184 + 185 + async fn delete_tokens(&self, did: &str) -> Result<(), OAuthRepoError> { 186 + let client = self 187 + .pool 188 + .get() 189 + .await 190 + .map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?; 191 + 192 + client 193 + .execute("DELETE FROM oauth_tokens WHERE did = $1", &[&did]) 194 + .await 195 + .map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?; 196 + 197 + Ok(()) 198 + } 199 + } 200 + 201 + #[cfg(test)] 202 + mod tests { 203 + use super::*; 204 + 205 + #[test] 206 + fn test_oauth_repo_error_display() { 207 + let err = OAuthRepoError::NotFound("test".to_string()); 208 + assert!(err.to_string().contains("test")); 209 + 210 + let err = OAuthRepoError::DatabaseError("connection failed".to_string()); 211 + assert!(err.to_string().contains("connection failed")); 212 + } 213 + }
+48
migrations/002_2025_12_28_oauth_tokens.sql
··· 1 + -- OAuth tokens and AT-URI storage for AT Protocol integration 2 + -- Adds tables for OAuth session management and AT-URI references 3 + 4 + -- OAuth sessions for tracking authorization flow state 5 + CREATE TABLE oauth_sessions ( 6 + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), 7 + state TEXT NOT NULL UNIQUE, 8 + code_verifier TEXT NOT NULL, 9 + dpop_private_key BYTEA NOT NULL, 10 + did TEXT, 11 + pds_url TEXT, 12 + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), 13 + expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '10 minutes' 14 + ); 15 + 16 + CREATE INDEX idx_oauth_sessions_state ON oauth_sessions(state); 17 + CREATE INDEX idx_oauth_sessions_expires_at ON oauth_sessions(expires_at); 18 + 19 + -- OAuth tokens for authenticated users 20 + CREATE TABLE oauth_tokens ( 21 + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), 22 + did TEXT NOT NULL UNIQUE, 23 + pds_url TEXT NOT NULL, 24 + access_token TEXT NOT NULL, 25 + refresh_token TEXT, 26 + token_type TEXT NOT NULL DEFAULT 'DPoP', 27 + expires_at TIMESTAMPTZ, 28 + dpop_private_key BYTEA NOT NULL, 29 + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), 30 + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() 31 + ); 32 + 33 + CREATE INDEX idx_oauth_tokens_did ON oauth_tokens(did); 34 + 35 + CREATE TRIGGER update_oauth_tokens_updated_at BEFORE UPDATE ON oauth_tokens 36 + FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); 37 + 38 + -- Add AT-URI columns to existing tables for tracking PDS record references 39 + ALTER TABLE decks ADD COLUMN at_uri TEXT; 40 + ALTER TABLE cards ADD COLUMN at_uri TEXT; 41 + ALTER TABLE notes ADD COLUMN at_uri TEXT; 42 + 43 + CREATE INDEX idx_decks_at_uri ON decks(at_uri) WHERE at_uri IS NOT NULL; 44 + CREATE INDEX idx_cards_at_uri ON cards(at_uri) WHERE at_uri IS NOT NULL; 45 + CREATE INDEX idx_notes_at_uri ON notes(at_uri) WHERE at_uri IS NOT NULL; 46 + 47 + -- Cleanup job for expired sessions (run periodically via cron or similar) 48 + -- DELETE FROM oauth_sessions WHERE expires_at < NOW();