Forked from lewis.moe/bspds-sandbox.
PDS software with bells & whistles you didn’t even know you needed. Will move this to its own account when ready.
1use crate::api::ApiError;
2use crate::state::AppState;
3use axum::{
4 Json,
5 extract::State,
6 http::StatusCode,
7 response::{IntoResponse, Response},
8};
9use chrono::Utc;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(rename_all = "camelCase")]
15pub struct VerificationMethod {
16 pub id: String,
17 #[serde(rename = "type")]
18 pub method_type: String,
19 pub public_key_multibase: String,
20}
21
22#[derive(Deserialize)]
23#[serde(rename_all = "camelCase")]
24pub struct UpdateDidDocumentInput {
25 pub verification_methods: Option<Vec<VerificationMethod>>,
26 pub also_known_as: Option<Vec<String>>,
27 pub service_endpoint: Option<String>,
28}
29
30#[derive(Serialize)]
31#[serde(rename_all = "camelCase")]
32pub struct UpdateDidDocumentOutput {
33 pub success: bool,
34 pub did_document: serde_json::Value,
35}
36
37pub async fn update_did_document(
38 State(state): State<AppState>,
39 headers: axum::http::HeaderMap,
40 Json(input): Json<UpdateDidDocumentInput>,
41) -> Response {
42 let extracted = match crate::auth::extract_auth_token_from_header(
43 headers.get("Authorization").and_then(|h| h.to_str().ok()),
44 ) {
45 Some(t) => t,
46 None => return ApiError::AuthenticationRequired.into_response(),
47 };
48 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
49 let http_uri = format!(
50 "https://{}/xrpc/_account.updateDidDocument",
51 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
52 );
53 let auth_user = match crate::auth::validate_token_with_dpop(
54 &state.db,
55 &extracted.token,
56 extracted.is_dpop,
57 dpop_proof,
58 "POST",
59 &http_uri,
60 true,
61 false,
62 )
63 .await
64 {
65 Ok(user) => user,
66 Err(e) => return ApiError::from(e).into_response(),
67 };
68
69 if !auth_user.did.starts_with("did:web:") {
70 return ApiError::InvalidRequest(
71 "DID document updates are only available for did:web accounts".into(),
72 )
73 .into_response();
74 }
75
76 let user = match sqlx::query!(
77 "SELECT id, handle, deactivated_at FROM users WHERE did = $1",
78 &auth_user.did
79 )
80 .fetch_optional(&state.db)
81 .await
82 {
83 Ok(Some(row)) => row,
84 Ok(None) => return ApiError::AccountNotFound.into_response(),
85 Err(e) => {
86 tracing::error!("DB error getting user: {:?}", e);
87 return ApiError::InternalError(None).into_response();
88 }
89 };
90
91 if user.deactivated_at.is_some() {
92 return ApiError::AccountDeactivated.into_response();
93 }
94
95 if let Some(ref methods) = input.verification_methods {
96 if methods.is_empty() {
97 return ApiError::InvalidRequest("verification_methods cannot be empty".into())
98 .into_response();
99 }
100 for method in methods {
101 if method.id.is_empty() {
102 return ApiError::InvalidRequest("verification method id is required".into())
103 .into_response();
104 }
105 if method.method_type != "Multikey" {
106 return ApiError::InvalidRequest(
107 "verification method type must be 'Multikey'".into(),
108 )
109 .into_response();
110 }
111 if !method.public_key_multibase.starts_with('z') {
112 return ApiError::InvalidRequest(
113 "publicKeyMultibase must start with 'z' (base58btc)".into(),
114 )
115 .into_response();
116 }
117 if method.public_key_multibase.len() < 40 {
118 return ApiError::InvalidRequest(
119 "publicKeyMultibase appears too short for a valid key".into(),
120 )
121 .into_response();
122 }
123 }
124 }
125
126 if let Some(ref handles) = input.also_known_as {
127 for handle in handles {
128 if !handle.starts_with("at://") {
129 return ApiError::InvalidRequest("alsoKnownAs entries must be at:// URIs".into())
130 .into_response();
131 }
132 }
133 }
134
135 if let Some(ref endpoint) = input.service_endpoint {
136 let endpoint = endpoint.trim();
137 if !endpoint.starts_with("https://") {
138 return ApiError::InvalidRequest("serviceEndpoint must start with https://".into())
139 .into_response();
140 }
141 }
142
143 let verification_methods_json = input
144 .verification_methods
145 .as_ref()
146 .map(|v| serde_json::to_value(v).unwrap_or_default());
147
148 let also_known_as: Option<Vec<String>> = input.also_known_as.clone();
149
150 let now = Utc::now();
151
152 let upsert_result = sqlx::query!(
153 r#"
154 INSERT INTO did_web_overrides (user_id, verification_methods, also_known_as, updated_at)
155 VALUES ($1, COALESCE($2, '[]'::jsonb), COALESCE($3, '{}'::text[]), $4)
156 ON CONFLICT (user_id) DO UPDATE SET
157 verification_methods = CASE WHEN $2 IS NOT NULL THEN $2 ELSE did_web_overrides.verification_methods END,
158 also_known_as = CASE WHEN $3 IS NOT NULL THEN $3 ELSE did_web_overrides.also_known_as END,
159 updated_at = $4
160 "#,
161 user.id,
162 verification_methods_json,
163 also_known_as.as_deref(),
164 now
165 )
166 .execute(&state.db)
167 .await;
168
169 if let Err(e) = upsert_result {
170 tracing::error!("DB error upserting did_web_overrides: {:?}", e);
171 return ApiError::InternalError(None).into_response();
172 }
173
174 if let Some(ref endpoint) = input.service_endpoint {
175 let endpoint_clean = endpoint.trim().trim_end_matches('/');
176 let update_result = sqlx::query!(
177 "UPDATE users SET migrated_to_pds = $1, migrated_at = $2 WHERE did = $3",
178 endpoint_clean,
179 now,
180 &auth_user.did
181 )
182 .execute(&state.db)
183 .await;
184
185 if let Err(e) = update_result {
186 tracing::error!("DB error updating service endpoint: {:?}", e);
187 return ApiError::InternalError(None).into_response();
188 }
189 }
190
191 let did_doc = build_did_document(&state.db, &auth_user.did).await;
192
193 tracing::info!("Updated DID document for {}", &auth_user.did);
194
195 (
196 StatusCode::OK,
197 Json(UpdateDidDocumentOutput {
198 success: true,
199 did_document: did_doc,
200 }),
201 )
202 .into_response()
203}
204
205pub async fn get_did_document(
206 State(state): State<AppState>,
207 headers: axum::http::HeaderMap,
208) -> Response {
209 let extracted = match crate::auth::extract_auth_token_from_header(
210 headers.get("Authorization").and_then(|h| h.to_str().ok()),
211 ) {
212 Some(t) => t,
213 None => return ApiError::AuthenticationRequired.into_response(),
214 };
215 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok());
216 let http_uri = format!(
217 "https://{}/xrpc/_account.getDidDocument",
218 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string())
219 );
220 let auth_user = match crate::auth::validate_token_with_dpop(
221 &state.db,
222 &extracted.token,
223 extracted.is_dpop,
224 dpop_proof,
225 "GET",
226 &http_uri,
227 true,
228 false,
229 )
230 .await
231 {
232 Ok(user) => user,
233 Err(e) => return ApiError::from(e).into_response(),
234 };
235
236 if !auth_user.did.starts_with("did:web:") {
237 return ApiError::InvalidRequest(
238 "This endpoint is only available for did:web accounts".into(),
239 )
240 .into_response();
241 }
242
243 let did_doc = build_did_document(&state.db, &auth_user.did).await;
244
245 (StatusCode::OK, Json(json!({ "didDocument": did_doc }))).into_response()
246}
247
248async fn build_did_document(db: &sqlx::PgPool, did: &str) -> serde_json::Value {
249 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string());
250
251 let user = match sqlx::query!(
252 "SELECT id, handle, migrated_to_pds FROM users WHERE did = $1",
253 did
254 )
255 .fetch_optional(db)
256 .await
257 {
258 Ok(Some(row)) => row,
259 _ => {
260 return json!({
261 "error": "User not found"
262 });
263 }
264 };
265
266 let overrides = sqlx::query!(
267 "SELECT verification_methods, also_known_as FROM did_web_overrides WHERE user_id = $1",
268 user.id
269 )
270 .fetch_optional(db)
271 .await
272 .ok()
273 .flatten();
274
275 let service_endpoint = user
276 .migrated_to_pds
277 .unwrap_or_else(|| format!("https://{}", hostname));
278
279 if let Some((ovr, parsed)) = overrides.as_ref().and_then(|ovr| {
280 serde_json::from_value::<Vec<VerificationMethod>>(ovr.verification_methods.clone())
281 .ok()
282 .filter(|p| !p.is_empty())
283 .map(|p| (ovr, p))
284 }) {
285 let also_known_as = if !ovr.also_known_as.is_empty() {
286 ovr.also_known_as.clone()
287 } else {
288 vec![format!("at://{}", user.handle)]
289 };
290 return json!({
291 "@context": [
292 "https://www.w3.org/ns/did/v1",
293 "https://w3id.org/security/multikey/v1",
294 "https://w3id.org/security/suites/secp256k1-2019/v1"
295 ],
296 "id": did,
297 "alsoKnownAs": also_known_as,
298 "verificationMethod": parsed.iter().map(|m| json!({
299 "id": format!("{}{}", did, if m.id.starts_with('#') { m.id.clone() } else { format!("#{}", m.id) }),
300 "type": m.method_type,
301 "controller": did,
302 "publicKeyMultibase": m.public_key_multibase
303 })).collect::<Vec<_>>(),
304 "service": [{
305 "id": "#atproto_pds",
306 "type": "AtprotoPersonalDataServer",
307 "serviceEndpoint": service_endpoint
308 }]
309 });
310 }
311
312 let key_row = sqlx::query!(
313 "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1",
314 user.id
315 )
316 .fetch_optional(db)
317 .await;
318
319 let public_key_multibase = match key_row {
320 Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) {
321 Ok(key_bytes) => crate::api::identity::did::get_public_key_multibase(&key_bytes)
322 .unwrap_or_else(|_| "error".to_string()),
323 Err(_) => "error".to_string(),
324 },
325 _ => "error".to_string(),
326 };
327
328 let also_known_as = if let Some(ref ovr) = overrides {
329 if !ovr.also_known_as.is_empty() {
330 ovr.also_known_as.clone()
331 } else {
332 vec![format!("at://{}", user.handle)]
333 }
334 } else {
335 vec![format!("at://{}", user.handle)]
336 };
337
338 json!({
339 "@context": [
340 "https://www.w3.org/ns/did/v1",
341 "https://w3id.org/security/multikey/v1",
342 "https://w3id.org/security/suites/secp256k1-2019/v1"
343 ],
344 "id": did,
345 "alsoKnownAs": also_known_as,
346 "verificationMethod": [{
347 "id": format!("{}#atproto", did),
348 "type": "Multikey",
349 "controller": did,
350 "publicKeyMultibase": public_key_multibase
351 }],
352 "service": [{
353 "id": "#atproto_pds",
354 "type": "AtprotoPersonalDataServer",
355 "serviceEndpoint": service_endpoint
356 }]
357 })
358}