+34
-2
Cargo.lock
+34
-2
Cargo.lock
···
1162
1162
dependencies = [
1163
1163
"serde",
1164
1164
"serde_json",
1165
-
"thiserror",
1165
+
"thiserror 2.0.17",
1166
1166
]
1167
1167
1168
1168
[[package]]
···
1183
1183
"reqwest 0.12.28",
1184
1184
"serde",
1185
1185
"serde_json",
1186
+
"serde_qs",
1186
1187
"sha2",
1187
1188
"tokio",
1188
1189
"tokio-postgres",
···
1991
1992
]
1992
1993
1993
1994
[[package]]
1995
+
name = "serde_qs"
1996
+
version = "0.13.0"
1997
+
source = "registry+https://github.com/rust-lang/crates.io-index"
1998
+
checksum = "cd34f36fe4c5ba9654417139a9b3a20d2e1de6012ee678ad14d240c22c78d8d6"
1999
+
dependencies = [
2000
+
"percent-encoding",
2001
+
"serde",
2002
+
"thiserror 1.0.69",
2003
+
]
2004
+
2005
+
[[package]]
1994
2006
name = "serde_urlencoded"
1995
2007
version = "0.7.1"
1996
2008
source = "registry+https://github.com/rust-lang/crates.io-index"
···
2271
2283
2272
2284
[[package]]
2273
2285
name = "thiserror"
2286
+
version = "1.0.69"
2287
+
source = "registry+https://github.com/rust-lang/crates.io-index"
2288
+
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
2289
+
dependencies = [
2290
+
"thiserror-impl 1.0.69",
2291
+
]
2292
+
2293
+
[[package]]
2294
+
name = "thiserror"
2274
2295
version = "2.0.17"
2275
2296
source = "registry+https://github.com/rust-lang/crates.io-index"
2276
2297
checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8"
2277
2298
dependencies = [
2278
-
"thiserror-impl",
2299
+
"thiserror-impl 2.0.17",
2300
+
]
2301
+
2302
+
[[package]]
2303
+
name = "thiserror-impl"
2304
+
version = "1.0.69"
2305
+
source = "registry+https://github.com/rust-lang/crates.io-index"
2306
+
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
2307
+
dependencies = [
2308
+
"proc-macro2",
2309
+
"quote",
2310
+
"syn 2.0.111",
2279
2311
]
2280
2312
2281
2313
[[package]]
+1
crates/server/Cargo.toml
+1
crates/server/Cargo.toml
+1
crates/server/src/api/mod.rs
+1
crates/server/src/api/mod.rs
+325
crates/server/src/api/oauth.rs
+325
crates/server/src/api/oauth.rs
···
1
+
//! OAuth API endpoints for AT Protocol authentication.
2
+
//!
3
+
//! Provides endpoints for:
4
+
//! - Starting the OAuth authorization flow
5
+
//! - Handling OAuth callbacks
6
+
//! - Refreshing tokens
7
+
8
+
use crate::db::DbPool;
9
+
use crate::oauth::flow::{OAuthFlow, SessionStore, generate_state, new_session_store};
10
+
use crate::repository::oauth::{DbOAuthRepository, OAuthRepository, StoreTokensRequest};
11
+
use axum::{
12
+
Json,
13
+
extract::{Query, State},
14
+
http::StatusCode,
15
+
response::{IntoResponse, Redirect},
16
+
};
17
+
use chrono::{Duration, Utc};
18
+
use serde::{Deserialize, Serialize};
19
+
use serde_json::json;
20
+
use std::sync::Arc;
21
+
22
+
/// Shared OAuth state with database repository.
23
+
pub struct OAuthState {
24
+
pub flow: OAuthFlow,
25
+
pub sessions: SessionStore,
26
+
pub repo: Arc<dyn OAuthRepository>,
27
+
}
28
+
29
+
impl OAuthState {
30
+
/// Create OAuth state with database connection.
31
+
pub fn with_pool(pool: DbPool) -> Self {
32
+
Self { flow: OAuthFlow::new(), sessions: new_session_store(), repo: Arc::new(DbOAuthRepository::new(pool)) }
33
+
}
34
+
35
+
/// Create OAuth state without database (for testing).
36
+
pub fn new() -> Self {
37
+
Self { flow: OAuthFlow::new(), sessions: new_session_store(), repo: Arc::new(MockOAuthRepository) }
38
+
}
39
+
}
40
+
41
+
impl Default for OAuthState {
42
+
fn default() -> Self {
43
+
Self::new()
44
+
}
45
+
}
46
+
47
+
/// Mock repository for testing.
48
+
struct MockOAuthRepository;
49
+
50
+
#[async_trait::async_trait]
51
+
impl OAuthRepository for MockOAuthRepository {
52
+
async fn store_tokens(&self, _req: StoreTokensRequest<'_>) -> Result<(), crate::repository::oauth::OAuthRepoError> {
53
+
Ok(())
54
+
}
55
+
56
+
async fn get_tokens(
57
+
&self, did: &str,
58
+
) -> Result<crate::repository::oauth::StoredToken, crate::repository::oauth::OAuthRepoError> {
59
+
Err(crate::repository::oauth::OAuthRepoError::NotFound(did.to_string()))
60
+
}
61
+
62
+
async fn update_tokens(
63
+
&self, _did: &str, _access_token: &str, _refresh_token: Option<&str>,
64
+
_expires_at: Option<chrono::DateTime<Utc>>,
65
+
) -> Result<(), crate::repository::oauth::OAuthRepoError> {
66
+
Ok(())
67
+
}
68
+
69
+
async fn delete_tokens(&self, _did: &str) -> Result<(), crate::repository::oauth::OAuthRepoError> {
70
+
Ok(())
71
+
}
72
+
}
73
+
74
+
/// Request to start OAuth authorization.
75
+
#[derive(Deserialize)]
76
+
pub struct AuthorizeRequest {
77
+
/// Handle or DID to authenticate
78
+
pub handle: String,
79
+
}
80
+
81
+
/// Response from starting authorization.
82
+
#[derive(Serialize)]
83
+
pub struct AuthorizeResponse {
84
+
/// URL to redirect the user to
85
+
pub authorization_url: String,
86
+
/// State parameter (for CSRF protection)
87
+
pub state: String,
88
+
}
89
+
90
+
/// Query parameters from OAuth callback.
91
+
#[derive(Deserialize)]
92
+
pub struct CallbackQuery {
93
+
pub code: String,
94
+
pub state: String,
95
+
#[serde(default)]
96
+
pub error: Option<String>,
97
+
#[serde(default)]
98
+
pub error_description: Option<String>,
99
+
}
100
+
101
+
/// Start the OAuth authorization flow.
102
+
///
103
+
/// POST /api/oauth/authorize
104
+
/// Body: { "handle": "alice.bsky.social" }
105
+
pub async fn authorize(
106
+
State(oauth): State<Arc<OAuthState>>, Json(payload): Json<AuthorizeRequest>,
107
+
) -> impl IntoResponse {
108
+
let state = generate_state();
109
+
110
+
match oauth
111
+
.flow
112
+
.start_authorization(&payload.handle, &state, &oauth.sessions)
113
+
.await
114
+
{
115
+
Ok(auth_url) => (
116
+
StatusCode::OK,
117
+
Json(AuthorizeResponse { authorization_url: auth_url, state }),
118
+
)
119
+
.into_response(),
120
+
Err(e) => (StatusCode::BAD_REQUEST, Json(json!({ "error": e.to_string() }))).into_response(),
121
+
}
122
+
}
123
+
124
+
/// Handle OAuth callback from authorization server.
125
+
///
126
+
/// GET /api/oauth/callback?code=...&state=...
127
+
pub async fn callback(State(oauth): State<Arc<OAuthState>>, Query(params): Query<CallbackQuery>) -> impl IntoResponse {
128
+
if let Some(error) = params.error {
129
+
let description = params.error_description.unwrap_or_default();
130
+
return Redirect::to(&format!(
131
+
"/login?error={}&description={}",
132
+
urlencoding::encode(&error),
133
+
urlencoding::encode(&description)
134
+
))
135
+
.into_response();
136
+
}
137
+
138
+
let session = {
139
+
let sessions = oauth.sessions.read().unwrap();
140
+
sessions.get(¶ms.state).cloned()
141
+
};
142
+
143
+
let session = match session {
144
+
Some(s) => s,
145
+
None => {
146
+
return Redirect::to("/login?error=session_not_found").into_response();
147
+
}
148
+
};
149
+
150
+
match oauth
151
+
.flow
152
+
.exchange_code(¶ms.code, ¶ms.state, &oauth.sessions)
153
+
.await
154
+
{
155
+
Ok(tokens) => {
156
+
let did = session.did.unwrap_or_default();
157
+
let pds_url = session.pds_url.unwrap_or_default();
158
+
let expires_at = tokens
159
+
.expires_in
160
+
.map(|secs| Utc::now() + Duration::seconds(secs as i64));
161
+
162
+
if let Err(e) = oauth
163
+
.repo
164
+
.store_tokens(StoreTokensRequest {
165
+
did: &did,
166
+
pds_url: &pds_url,
167
+
access_token: &tokens.access_token,
168
+
refresh_token: tokens.refresh_token.as_deref(),
169
+
token_type: &tokens.token_type,
170
+
expires_at,
171
+
dpop_keypair: &session.dpop_keypair,
172
+
})
173
+
.await
174
+
{
175
+
tracing::error!("Failed to store tokens: {}", e);
176
+
return Redirect::to(&format!("/login?error={}", urlencoding::encode("token_storage_failed")))
177
+
.into_response();
178
+
}
179
+
180
+
Redirect::to(&format!("/login/success?did={}", urlencoding::encode(&did))).into_response()
181
+
}
182
+
Err(e) => Redirect::to(&format!("/login?error={}", urlencoding::encode(&e.to_string()))).into_response(),
183
+
}
184
+
}
185
+
186
+
/// Request to refresh tokens.
187
+
#[derive(Deserialize)]
188
+
pub struct RefreshRequest {
189
+
pub did: String,
190
+
}
191
+
192
+
/// Response from token refresh.
193
+
#[derive(Serialize)]
194
+
pub struct RefreshResponse {
195
+
pub success: bool,
196
+
pub expires_at: Option<String>,
197
+
}
198
+
199
+
/// Refresh an access token.
200
+
///
201
+
/// POST /api/oauth/refresh
202
+
/// Body: { "did": "did:plc:..." }
203
+
pub async fn refresh(State(oauth): State<Arc<OAuthState>>, Json(payload): Json<RefreshRequest>) -> impl IntoResponse {
204
+
// Get stored tokens from database
205
+
let stored = match oauth.repo.get_tokens(&payload.did).await {
206
+
Ok(t) => t,
207
+
Err(e) => {
208
+
return (StatusCode::NOT_FOUND, Json(json!({ "error": e.to_string() }))).into_response();
209
+
}
210
+
};
211
+
212
+
// Reconstruct DPoP keypair
213
+
let dpop_keypair = match stored.dpop_keypair() {
214
+
Some(kp) => kp,
215
+
None => {
216
+
return (
217
+
StatusCode::INTERNAL_SERVER_ERROR,
218
+
Json(json!({ "error": "Invalid stored keypair" })),
219
+
)
220
+
.into_response();
221
+
}
222
+
};
223
+
224
+
// Get refresh token
225
+
let refresh_token = match &stored.refresh_token {
226
+
Some(rt) => rt.clone(),
227
+
None => {
228
+
return (
229
+
StatusCode::BAD_REQUEST,
230
+
Json(json!({ "error": "No refresh token available" })),
231
+
)
232
+
.into_response();
233
+
}
234
+
};
235
+
236
+
// Refresh tokens via OAuth flow
237
+
match oauth
238
+
.flow
239
+
.refresh_token(&refresh_token, &stored.pds_url, &dpop_keypair)
240
+
.await
241
+
{
242
+
Ok(new_tokens) => {
243
+
let expires_at = new_tokens
244
+
.expires_in
245
+
.map(|secs| Utc::now() + Duration::seconds(secs as i64));
246
+
247
+
if let Err(e) = oauth
248
+
.repo
249
+
.update_tokens(
250
+
&payload.did,
251
+
&new_tokens.access_token,
252
+
new_tokens.refresh_token.as_deref(),
253
+
expires_at,
254
+
)
255
+
.await
256
+
{
257
+
tracing::error!("Failed to update tokens: {}", e);
258
+
return (
259
+
StatusCode::INTERNAL_SERVER_ERROR,
260
+
Json(json!({ "error": "Failed to update tokens" })),
261
+
)
262
+
.into_response();
263
+
}
264
+
265
+
(
266
+
StatusCode::OK,
267
+
Json(RefreshResponse { success: true, expires_at: expires_at.map(|dt| dt.to_rfc3339()) }),
268
+
)
269
+
.into_response()
270
+
}
271
+
Err(e) => (StatusCode::BAD_REQUEST, Json(json!({ "error": e.to_string() }))).into_response(),
272
+
}
273
+
}
274
+
275
+
#[cfg(test)]
276
+
mod tests {
277
+
use super::*;
278
+
279
+
#[test]
280
+
fn test_oauth_state_creation() {
281
+
let state = OAuthState::new();
282
+
assert!(state.sessions.read().unwrap().is_empty());
283
+
}
284
+
285
+
#[test]
286
+
fn test_authorize_request_deserialization() {
287
+
let json = r#"{"handle": "alice.bsky.social"}"#;
288
+
let request: AuthorizeRequest = serde_json::from_str(json).unwrap();
289
+
assert_eq!(request.handle, "alice.bsky.social");
290
+
}
291
+
292
+
#[test]
293
+
fn test_authorize_response_serialization() {
294
+
let response = AuthorizeResponse {
295
+
authorization_url: "https://example.com/oauth".to_string(),
296
+
state: "abc123".to_string(),
297
+
};
298
+
let json = serde_json::to_string(&response).unwrap();
299
+
assert!(json.contains("authorization_url"));
300
+
assert!(json.contains("state"));
301
+
}
302
+
303
+
#[test]
304
+
fn test_callback_query_deserialization() {
305
+
let query = "code=abc123&state=xyz789";
306
+
let parsed: CallbackQuery = serde_qs::from_str(query).unwrap();
307
+
assert_eq!(parsed.code, "abc123");
308
+
assert_eq!(parsed.state, "xyz789");
309
+
assert!(parsed.error.is_none());
310
+
}
311
+
312
+
#[test]
313
+
fn test_callback_query_with_error() {
314
+
let query = "code=&state=xyz789&error=access_denied&error_description=User+denied";
315
+
let parsed: CallbackQuery = serde_qs::from_str(query).unwrap();
316
+
assert_eq!(parsed.error, Some("access_denied".to_string()));
317
+
}
318
+
319
+
#[test]
320
+
fn test_refresh_request_deserialization() {
321
+
let json = r#"{"did": "did:plc:abc123"}"#;
322
+
let request: RefreshRequest = serde_json::from_str(json).unwrap();
323
+
assert_eq!(request.did, "did:plc:abc123");
324
+
}
325
+
}
+12
crates/server/src/lib.rs
+12
crates/server/src/lib.rs
···
40
40
tracing::info!("Database connection pool created");
41
41
42
42
let state = state::AppState::new(pool);
43
+
let oauth_state = std::sync::Arc::new(api::oauth::OAuthState::new());
43
44
44
45
let auth_routes = Router::new()
45
46
.route("/me", get(api::auth::me))
···
56
57
.route("/notes/{id}", get(api::note::get_note))
57
58
.layer(axum_middleware::from_fn(middleware::auth::optional_auth_middleware));
58
59
60
+
let oauth_routes = Router::new()
61
+
.route("/authorize", post(api::oauth::authorize))
62
+
.route("/callback", get(api::oauth::callback))
63
+
.route("/refresh", post(api::oauth::refresh))
64
+
.with_state(oauth_state.clone());
65
+
59
66
let app = Router::new()
60
67
.route("/health", get(health_check))
68
+
.route(
69
+
"/.well-known/oauth-client-metadata",
70
+
get(oauth::client_metadata::client_metadata_handler),
71
+
)
61
72
.route("/api/auth/login", post(api::auth::login))
62
73
.route("/api/import/article", post(api::importer::import_article))
74
+
.nest("/api/oauth", oauth_routes)
63
75
.nest("/api", optional_auth_routes)
64
76
.nest("/api", auth_routes)
65
77
.layer(TraceLayer::new_for_http())
+10
crates/server/src/oauth/dpop.rs
+10
crates/server/src/oauth/dpop.rs
···
89
89
90
90
format!("{}.{}.{}", header_b64, payload_b64, signature_b64)
91
91
}
92
+
93
+
/// Create a DpopKeypair from an existing SigningKey.
94
+
pub fn from_signing_key(signing_key: SigningKey) -> Self {
95
+
Self { signing_key }
96
+
}
97
+
98
+
/// Get the private key bytes for storage.
99
+
pub fn private_key_bytes(&self) -> Vec<u8> {
100
+
self.signing_key.to_bytes().to_vec()
101
+
}
92
102
}
93
103
94
104
/// Generate a unique JWT ID.
+213
crates/server/src/repository/oauth.rs
+213
crates/server/src/repository/oauth.rs
···
1
+
//! OAuth token repository for database storage.
2
+
//!
3
+
//! Handles storage and retrieval of OAuth tokens and sessions.
4
+
5
+
use crate::db::DbPool;
6
+
use crate::oauth::dpop::DpopKeypair;
7
+
use async_trait::async_trait;
8
+
use chrono::{DateTime, Utc};
9
+
use ed25519_dalek::SigningKey;
10
+
use serde::{Deserialize, Serialize};
11
+
12
+
/// Stored OAuth token record.
13
+
#[derive(Clone, Serialize, Deserialize)]
14
+
pub struct StoredToken {
15
+
pub did: String,
16
+
pub pds_url: String,
17
+
pub access_token: String,
18
+
pub refresh_token: Option<String>,
19
+
pub token_type: String,
20
+
pub expires_at: Option<DateTime<Utc>>,
21
+
pub dpop_private_key: Vec<u8>,
22
+
pub created_at: DateTime<Utc>,
23
+
pub updated_at: DateTime<Utc>,
24
+
}
25
+
26
+
impl StoredToken {
27
+
/// Reconstruct the DPoP keypair from stored bytes.
28
+
pub fn dpop_keypair(&self) -> Option<DpopKeypair> {
29
+
if self.dpop_private_key.len() != 32 {
30
+
return None;
31
+
}
32
+
let mut key_bytes = [0u8; 32];
33
+
key_bytes.copy_from_slice(&self.dpop_private_key);
34
+
let signing_key = SigningKey::from_bytes(&key_bytes);
35
+
Some(DpopKeypair::from_signing_key(signing_key))
36
+
}
37
+
}
38
+
39
+
/// Error type for OAuth repository operations.
40
+
#[derive(Debug, Clone)]
41
+
pub enum OAuthRepoError {
42
+
DatabaseError(String),
43
+
NotFound(String),
44
+
SerializationError(String),
45
+
}
46
+
47
+
impl std::fmt::Display for OAuthRepoError {
48
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49
+
match self {
50
+
OAuthRepoError::DatabaseError(e) => write!(f, "Database error: {}", e),
51
+
OAuthRepoError::NotFound(e) => write!(f, "Not found: {}", e),
52
+
OAuthRepoError::SerializationError(e) => write!(f, "Serialization error: {}", e),
53
+
}
54
+
}
55
+
}
56
+
57
+
impl std::error::Error for OAuthRepoError {}
58
+
59
+
/// Request to store OAuth tokens.
60
+
pub struct StoreTokensRequest<'a> {
61
+
pub did: &'a str,
62
+
pub pds_url: &'a str,
63
+
pub access_token: &'a str,
64
+
pub refresh_token: Option<&'a str>,
65
+
pub token_type: &'a str,
66
+
pub expires_at: Option<DateTime<Utc>>,
67
+
pub dpop_keypair: &'a DpopKeypair,
68
+
}
69
+
70
+
/// Repository trait for OAuth token operations.
71
+
#[async_trait]
72
+
pub trait OAuthRepository: Send + Sync {
73
+
/// Store OAuth tokens for a user.
74
+
async fn store_tokens(&self, req: StoreTokensRequest<'_>) -> Result<(), OAuthRepoError>;
75
+
76
+
/// Get stored tokens for a user.
77
+
async fn get_tokens(&self, did: &str) -> Result<StoredToken, OAuthRepoError>;
78
+
79
+
/// Update tokens after refresh.
80
+
async fn update_tokens(
81
+
&self, did: &str, access_token: &str, refresh_token: Option<&str>, expires_at: Option<DateTime<Utc>>,
82
+
) -> Result<(), OAuthRepoError>;
83
+
84
+
/// Delete tokens for a user (logout).
85
+
async fn delete_tokens(&self, did: &str) -> Result<(), OAuthRepoError>;
86
+
}
87
+
88
+
/// Database-backed OAuth repository.
89
+
pub struct DbOAuthRepository {
90
+
pool: DbPool,
91
+
}
92
+
93
+
impl DbOAuthRepository {
94
+
pub fn new(pool: DbPool) -> Self {
95
+
Self { pool }
96
+
}
97
+
}
98
+
99
+
#[async_trait]
100
+
impl OAuthRepository for DbOAuthRepository {
101
+
async fn store_tokens(&self, req: StoreTokensRequest<'_>) -> Result<(), OAuthRepoError> {
102
+
let client = self
103
+
.pool
104
+
.get()
105
+
.await
106
+
.map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?;
107
+
108
+
let dpop_bytes = req.dpop_keypair.private_key_bytes();
109
+
110
+
client
111
+
.execute(
112
+
"INSERT INTO oauth_tokens (did, pds_url, access_token, refresh_token, token_type, expires_at, dpop_private_key)
113
+
VALUES ($1, $2, $3, $4, $5, $6, $7)
114
+
ON CONFLICT (did) DO UPDATE SET
115
+
pds_url = EXCLUDED.pds_url,
116
+
access_token = EXCLUDED.access_token,
117
+
refresh_token = EXCLUDED.refresh_token,
118
+
token_type = EXCLUDED.token_type,
119
+
expires_at = EXCLUDED.expires_at,
120
+
dpop_private_key = EXCLUDED.dpop_private_key,
121
+
updated_at = NOW()",
122
+
&[&req.did, &req.pds_url, &req.access_token, &req.refresh_token, &req.token_type, &req.expires_at, &dpop_bytes.as_slice()],
123
+
)
124
+
.await
125
+
.map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?;
126
+
127
+
Ok(())
128
+
}
129
+
130
+
async fn get_tokens(&self, did: &str) -> Result<StoredToken, OAuthRepoError> {
131
+
let client = self
132
+
.pool
133
+
.get()
134
+
.await
135
+
.map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?;
136
+
137
+
let row = client
138
+
.query_opt(
139
+
"SELECT did, pds_url, access_token, refresh_token, token_type, expires_at, dpop_private_key, created_at, updated_at
140
+
FROM oauth_tokens WHERE did = $1",
141
+
&[&did],
142
+
)
143
+
.await
144
+
.map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?
145
+
.ok_or_else(|| OAuthRepoError::NotFound(format!("No tokens for DID: {}", did)))?;
146
+
147
+
Ok(StoredToken {
148
+
did: row.get("did"),
149
+
pds_url: row.get("pds_url"),
150
+
access_token: row.get("access_token"),
151
+
refresh_token: row.get("refresh_token"),
152
+
token_type: row.get("token_type"),
153
+
expires_at: row.get("expires_at"),
154
+
dpop_private_key: row.get("dpop_private_key"),
155
+
created_at: row.get("created_at"),
156
+
updated_at: row.get("updated_at"),
157
+
})
158
+
}
159
+
160
+
async fn update_tokens(
161
+
&self, did: &str, access_token: &str, refresh_token: Option<&str>, expires_at: Option<DateTime<Utc>>,
162
+
) -> Result<(), OAuthRepoError> {
163
+
let client = self
164
+
.pool
165
+
.get()
166
+
.await
167
+
.map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?;
168
+
169
+
let result = client
170
+
.execute(
171
+
"UPDATE oauth_tokens SET access_token = $2, refresh_token = $3, expires_at = $4, updated_at = NOW()
172
+
WHERE did = $1",
173
+
&[&did, &access_token, &refresh_token, &expires_at],
174
+
)
175
+
.await
176
+
.map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?;
177
+
178
+
if result == 0 {
179
+
return Err(OAuthRepoError::NotFound(format!("No tokens for DID: {}", did)));
180
+
}
181
+
182
+
Ok(())
183
+
}
184
+
185
+
async fn delete_tokens(&self, did: &str) -> Result<(), OAuthRepoError> {
186
+
let client = self
187
+
.pool
188
+
.get()
189
+
.await
190
+
.map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?;
191
+
192
+
client
193
+
.execute("DELETE FROM oauth_tokens WHERE did = $1", &[&did])
194
+
.await
195
+
.map_err(|e| OAuthRepoError::DatabaseError(e.to_string()))?;
196
+
197
+
Ok(())
198
+
}
199
+
}
200
+
201
+
#[cfg(test)]
202
+
mod tests {
203
+
use super::*;
204
+
205
+
#[test]
206
+
fn test_oauth_repo_error_display() {
207
+
let err = OAuthRepoError::NotFound("test".to_string());
208
+
assert!(err.to_string().contains("test"));
209
+
210
+
let err = OAuthRepoError::DatabaseError("connection failed".to_string());
211
+
assert!(err.to_string().contains("connection failed"));
212
+
}
213
+
}
+48
migrations/002_2025_12_28_oauth_tokens.sql
+48
migrations/002_2025_12_28_oauth_tokens.sql
···
1
+
-- OAuth tokens and AT-URI storage for AT Protocol integration
2
+
-- Adds tables for OAuth session management and AT-URI references
3
+
4
+
-- OAuth sessions for tracking authorization flow state
5
+
CREATE TABLE oauth_sessions (
6
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
7
+
state TEXT NOT NULL UNIQUE,
8
+
code_verifier TEXT NOT NULL,
9
+
dpop_private_key BYTEA NOT NULL,
10
+
did TEXT,
11
+
pds_url TEXT,
12
+
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
13
+
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '10 minutes'
14
+
);
15
+
16
+
CREATE INDEX idx_oauth_sessions_state ON oauth_sessions(state);
17
+
CREATE INDEX idx_oauth_sessions_expires_at ON oauth_sessions(expires_at);
18
+
19
+
-- OAuth tokens for authenticated users
20
+
CREATE TABLE oauth_tokens (
21
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
22
+
did TEXT NOT NULL UNIQUE,
23
+
pds_url TEXT NOT NULL,
24
+
access_token TEXT NOT NULL,
25
+
refresh_token TEXT,
26
+
token_type TEXT NOT NULL DEFAULT 'DPoP',
27
+
expires_at TIMESTAMPTZ,
28
+
dpop_private_key BYTEA NOT NULL,
29
+
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
30
+
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
31
+
);
32
+
33
+
CREATE INDEX idx_oauth_tokens_did ON oauth_tokens(did);
34
+
35
+
CREATE TRIGGER update_oauth_tokens_updated_at BEFORE UPDATE ON oauth_tokens
36
+
FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
37
+
38
+
-- Add AT-URI columns to existing tables for tracking PDS record references
39
+
ALTER TABLE decks ADD COLUMN at_uri TEXT;
40
+
ALTER TABLE cards ADD COLUMN at_uri TEXT;
41
+
ALTER TABLE notes ADD COLUMN at_uri TEXT;
42
+
43
+
CREATE INDEX idx_decks_at_uri ON decks(at_uri) WHERE at_uri IS NOT NULL;
44
+
CREATE INDEX idx_cards_at_uri ON cards(at_uri) WHERE at_uri IS NOT NULL;
45
+
CREATE INDEX idx_notes_at_uri ON notes(at_uri) WHERE at_uri IS NOT NULL;
46
+
47
+
-- Cleanup job for expired sessions (run periodically via cron or similar)
48
+
-- DELETE FROM oauth_sessions WHERE expires_at < NOW();