PDS software with bells and whistles you didn’t even know you needed. Will move this to its own account when ready.
at main 11 kB view raw
1use crate::api::ApiError; 2use crate::state::AppState; 3use axum::{ 4 Json, 5 extract::State, 6 http::StatusCode, 7 response::{IntoResponse, Response}, 8}; 9use chrono::Utc; 10use serde::{Deserialize, Serialize}; 11use serde_json::json; 12 13#[derive(Debug, Clone, Serialize, Deserialize)] 14#[serde(rename_all = "camelCase")] 15pub struct VerificationMethod { 16 pub id: String, 17 #[serde(rename = "type")] 18 pub method_type: String, 19 pub public_key_multibase: String, 20} 21 22#[derive(Deserialize)] 23#[serde(rename_all = "camelCase")] 24pub struct UpdateDidDocumentInput { 25 pub verification_methods: Option<Vec<VerificationMethod>>, 26 pub also_known_as: Option<Vec<String>>, 27 pub service_endpoint: Option<String>, 28} 29 30#[derive(Serialize)] 31#[serde(rename_all = "camelCase")] 32pub struct UpdateDidDocumentOutput { 33 pub success: bool, 34 pub did_document: serde_json::Value, 35} 36 37pub async fn update_did_document( 38 State(state): State<AppState>, 39 headers: axum::http::HeaderMap, 40 Json(input): Json<UpdateDidDocumentInput>, 41) -> Response { 42 let extracted = match crate::auth::extract_auth_token_from_header( 43 headers.get("Authorization").and_then(|h| h.to_str().ok()), 44 ) { 45 Some(t) => t, 46 None => return ApiError::AuthenticationRequired.into_response(), 47 }; 48 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 49 let http_uri = format!( 50 "https://{}/xrpc/_account.updateDidDocument", 51 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 52 ); 53 let auth_user = match crate::auth::validate_token_with_dpop( 54 &state.db, 55 &extracted.token, 56 extracted.is_dpop, 57 dpop_proof, 58 "POST", 59 &http_uri, 60 true, 61 false, 62 ) 63 .await 64 { 65 Ok(user) => user, 66 Err(e) => return ApiError::from(e).into_response(), 67 }; 68 69 if !auth_user.did.starts_with("did:web:") { 70 return ApiError::InvalidRequest( 71 "DID document updates are only available for did:web accounts".into(), 72 ) 73 
.into_response(); 74 } 75 76 let user = match sqlx::query!( 77 "SELECT id, handle, deactivated_at FROM users WHERE did = $1", 78 &auth_user.did 79 ) 80 .fetch_optional(&state.db) 81 .await 82 { 83 Ok(Some(row)) => row, 84 Ok(None) => return ApiError::AccountNotFound.into_response(), 85 Err(e) => { 86 tracing::error!("DB error getting user: {:?}", e); 87 return ApiError::InternalError(None).into_response(); 88 } 89 }; 90 91 if user.deactivated_at.is_some() { 92 return ApiError::AccountDeactivated.into_response(); 93 } 94 95 if let Some(ref methods) = input.verification_methods { 96 if methods.is_empty() { 97 return ApiError::InvalidRequest("verification_methods cannot be empty".into()) 98 .into_response(); 99 } 100 for method in methods { 101 if method.id.is_empty() { 102 return ApiError::InvalidRequest("verification method id is required".into()) 103 .into_response(); 104 } 105 if method.method_type != "Multikey" { 106 return ApiError::InvalidRequest( 107 "verification method type must be 'Multikey'".into(), 108 ) 109 .into_response(); 110 } 111 if !method.public_key_multibase.starts_with('z') { 112 return ApiError::InvalidRequest( 113 "publicKeyMultibase must start with 'z' (base58btc)".into(), 114 ) 115 .into_response(); 116 } 117 if method.public_key_multibase.len() < 40 { 118 return ApiError::InvalidRequest( 119 "publicKeyMultibase appears too short for a valid key".into(), 120 ) 121 .into_response(); 122 } 123 } 124 } 125 126 if let Some(ref handles) = input.also_known_as { 127 for handle in handles { 128 if !handle.starts_with("at://") { 129 return ApiError::InvalidRequest("alsoKnownAs entries must be at:// URIs".into()) 130 .into_response(); 131 } 132 } 133 } 134 135 if let Some(ref endpoint) = input.service_endpoint { 136 let endpoint = endpoint.trim(); 137 if !endpoint.starts_with("https://") { 138 return ApiError::InvalidRequest("serviceEndpoint must start with https://".into()) 139 .into_response(); 140 } 141 } 142 143 let verification_methods_json = 
input 144 .verification_methods 145 .as_ref() 146 .map(|v| serde_json::to_value(v).unwrap_or_default()); 147 148 let also_known_as: Option<Vec<String>> = input.also_known_as.clone(); 149 150 let now = Utc::now(); 151 152 let upsert_result = sqlx::query!( 153 r#" 154 INSERT INTO did_web_overrides (user_id, verification_methods, also_known_as, updated_at) 155 VALUES ($1, COALESCE($2, '[]'::jsonb), COALESCE($3, '{}'::text[]), $4) 156 ON CONFLICT (user_id) DO UPDATE SET 157 verification_methods = CASE WHEN $2 IS NOT NULL THEN $2 ELSE did_web_overrides.verification_methods END, 158 also_known_as = CASE WHEN $3 IS NOT NULL THEN $3 ELSE did_web_overrides.also_known_as END, 159 updated_at = $4 160 "#, 161 user.id, 162 verification_methods_json, 163 also_known_as.as_deref(), 164 now 165 ) 166 .execute(&state.db) 167 .await; 168 169 if let Err(e) = upsert_result { 170 tracing::error!("DB error upserting did_web_overrides: {:?}", e); 171 return ApiError::InternalError(None).into_response(); 172 } 173 174 if let Some(ref endpoint) = input.service_endpoint { 175 let endpoint_clean = endpoint.trim().trim_end_matches('/'); 176 let update_result = sqlx::query!( 177 "UPDATE users SET migrated_to_pds = $1, migrated_at = $2 WHERE did = $3", 178 endpoint_clean, 179 now, 180 &auth_user.did 181 ) 182 .execute(&state.db) 183 .await; 184 185 if let Err(e) = update_result { 186 tracing::error!("DB error updating service endpoint: {:?}", e); 187 return ApiError::InternalError(None).into_response(); 188 } 189 } 190 191 let did_doc = build_did_document(&state.db, &auth_user.did).await; 192 193 tracing::info!("Updated DID document for {}", &auth_user.did); 194 195 ( 196 StatusCode::OK, 197 Json(UpdateDidDocumentOutput { 198 success: true, 199 did_document: did_doc, 200 }), 201 ) 202 .into_response() 203} 204 205pub async fn get_did_document( 206 State(state): State<AppState>, 207 headers: axum::http::HeaderMap, 208) -> Response { 209 let extracted = match 
crate::auth::extract_auth_token_from_header( 210 headers.get("Authorization").and_then(|h| h.to_str().ok()), 211 ) { 212 Some(t) => t, 213 None => return ApiError::AuthenticationRequired.into_response(), 214 }; 215 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 216 let http_uri = format!( 217 "https://{}/xrpc/_account.getDidDocument", 218 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 219 ); 220 let auth_user = match crate::auth::validate_token_with_dpop( 221 &state.db, 222 &extracted.token, 223 extracted.is_dpop, 224 dpop_proof, 225 "GET", 226 &http_uri, 227 true, 228 false, 229 ) 230 .await 231 { 232 Ok(user) => user, 233 Err(e) => return ApiError::from(e).into_response(), 234 }; 235 236 if !auth_user.did.starts_with("did:web:") { 237 return ApiError::InvalidRequest( 238 "This endpoint is only available for did:web accounts".into(), 239 ) 240 .into_response(); 241 } 242 243 let did_doc = build_did_document(&state.db, &auth_user.did).await; 244 245 (StatusCode::OK, Json(json!({ "didDocument": did_doc }))).into_response() 246} 247 248async fn build_did_document(db: &sqlx::PgPool, did: &str) -> serde_json::Value { 249 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 250 251 let user = match sqlx::query!( 252 "SELECT id, handle, migrated_to_pds FROM users WHERE did = $1", 253 did 254 ) 255 .fetch_optional(db) 256 .await 257 { 258 Ok(Some(row)) => row, 259 _ => { 260 return json!({ 261 "error": "User not found" 262 }); 263 } 264 }; 265 266 let overrides = sqlx::query!( 267 "SELECT verification_methods, also_known_as FROM did_web_overrides WHERE user_id = $1", 268 user.id 269 ) 270 .fetch_optional(db) 271 .await 272 .ok() 273 .flatten(); 274 275 let service_endpoint = user 276 .migrated_to_pds 277 .unwrap_or_else(|| format!("https://{}", hostname)); 278 279 if let Some((ovr, parsed)) = overrides.as_ref().and_then(|ovr| { 280 
serde_json::from_value::<Vec<VerificationMethod>>(ovr.verification_methods.clone()) 281 .ok() 282 .filter(|p| !p.is_empty()) 283 .map(|p| (ovr, p)) 284 }) { 285 let also_known_as = if !ovr.also_known_as.is_empty() { 286 ovr.also_known_as.clone() 287 } else { 288 vec![format!("at://{}", user.handle)] 289 }; 290 return json!({ 291 "@context": [ 292 "https://www.w3.org/ns/did/v1", 293 "https://w3id.org/security/multikey/v1", 294 "https://w3id.org/security/suites/secp256k1-2019/v1" 295 ], 296 "id": did, 297 "alsoKnownAs": also_known_as, 298 "verificationMethod": parsed.iter().map(|m| json!({ 299 "id": format!("{}{}", did, if m.id.starts_with('#') { m.id.clone() } else { format!("#{}", m.id) }), 300 "type": m.method_type, 301 "controller": did, 302 "publicKeyMultibase": m.public_key_multibase 303 })).collect::<Vec<_>>(), 304 "service": [{ 305 "id": "#atproto_pds", 306 "type": "AtprotoPersonalDataServer", 307 "serviceEndpoint": service_endpoint 308 }] 309 }); 310 } 311 312 let key_row = sqlx::query!( 313 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 314 user.id 315 ) 316 .fetch_optional(db) 317 .await; 318 319 let public_key_multibase = match key_row { 320 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 321 Ok(key_bytes) => crate::api::identity::did::get_public_key_multibase(&key_bytes) 322 .unwrap_or_else(|_| "error".to_string()), 323 Err(_) => "error".to_string(), 324 }, 325 _ => "error".to_string(), 326 }; 327 328 let also_known_as = if let Some(ref ovr) = overrides { 329 if !ovr.also_known_as.is_empty() { 330 ovr.also_known_as.clone() 331 } else { 332 vec![format!("at://{}", user.handle)] 333 } 334 } else { 335 vec![format!("at://{}", user.handle)] 336 }; 337 338 json!({ 339 "@context": [ 340 "https://www.w3.org/ns/did/v1", 341 "https://w3id.org/security/multikey/v1", 342 "https://w3id.org/security/suites/secp256k1-2019/v1" 343 ], 344 "id": did, 345 "alsoKnownAs": also_known_as, 346 
"verificationMethod": [{ 347 "id": format!("{}#atproto", did), 348 "type": "Multikey", 349 "controller": did, 350 "publicKeyMultibase": public_key_multibase 351 }], 352 "service": [{ 353 "id": "#atproto_pds", 354 "type": "AtprotoPersonalDataServer", 355 "serviceEndpoint": service_endpoint 356 }] 357 }) 358}