Our Personal Data Server from scratch! tranquil.farm
atproto pds rust postgresql fun oauth

refactor(api): extract repo write lifecycle to repo_ops #79

merged opened by oyster.cafe targeting main from refactor/api
Labels

None yet.

assignee

None yet.

Participants 1
AT URI
at://did:plc:3fwecdnvtcscjnrx2p4n7alz/sh.tangled.repo.pull/3mhi3qdcvqi22
+317 -485
Diff #0
+1 -2
crates/tranquil-api/src/repo/record/mod.rs
··· 2 2 pub mod delete; 3 3 pub mod pagination; 4 4 pub mod read; 5 - pub mod utils; 6 5 pub mod validation; 7 6 pub mod validation_mode; 8 7 pub mod write; ··· 13 12 pub use batch::apply_writes; 14 13 pub use delete::{DeleteRecordInput, delete_record, delete_record_internal}; 15 14 pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records}; 16 - pub use utils::*; 15 + pub use tranquil_pds::repo_ops::*; 17 16 pub use write::{ 18 17 CreateRecordInput, CreateRecordOutput, PutRecordInput, PutRecordOutput, create_record, 19 18 prepare_repo_write, put_record,
+7 -68
crates/tranquil-api/src/repo/record/read.rs
··· 1 1 use super::pagination::{PaginationDirection, deserialize_pagination_direction}; 2 + use crate::common; 2 3 use axum::{ 3 4 Json, 4 5 extract::{Query, State}, ··· 59 60 _headers: HeaderMap, 60 61 Query(input): Query<GetRecordInput>, 61 62 ) -> Response { 62 - let hostname_for_handles = tranquil_config::get().server.hostname_without_port(); 63 - let user_id_opt = if input.repo.is_did() { 64 - let did: tranquil_pds::types::Did = match input.repo.as_str().parse() { 65 - Ok(d) => d, 66 - Err(_) => return ApiError::InvalidRequest("Invalid DID format".into()).into_response(), 67 - }; 68 - state.user_repo.get_id_by_did(&did).await.map_err(|_| ()) 69 - } else { 70 - let repo_str = input.repo.as_str(); 71 - let handle_str = if !repo_str.contains('.') { 72 - format!("{}.{}", repo_str, hostname_for_handles) 73 - } else { 74 - repo_str.to_string() 75 - }; 76 - let handle: tranquil_pds::types::Handle = match handle_str.parse() { 77 - Ok(h) => h, 78 - Err(_) => { 79 - return ApiError::InvalidRequest("Invalid handle format".into()).into_response(); 80 - } 81 - }; 82 - state 83 - .user_repo 84 - .get_id_by_handle(&handle) 85 - .await 86 - .map_err(|_| ()) 87 - }; 88 - let user_id: uuid::Uuid = match user_id_opt { 89 - Ok(Some(id)) => id, 90 - Ok(None) => { 91 - return ApiError::RepoNotFound(Some("Repo not found".into())).into_response(); 92 - } 93 - Err(_) => { 94 - return ApiError::InternalError(None).into_response(); 95 - } 63 + let user_id = match common::resolve_repo_user_id(state.user_repo.as_ref(), &input.repo).await { 64 + Ok(id) => id, 65 + Err(e) => return e.into_response(), 96 66 }; 97 67 let record_row = state 98 68 .repo_repo ··· 158 128 State(state): State<AppState>, 159 129 Query(input): Query<ListRecordsInput>, 160 130 ) -> Response { 161 - let hostname_for_handles = tranquil_config::get().server.hostname_without_port(); 162 - let user_id_opt = if input.repo.is_did() { 163 - let did: tranquil_pds::types::Did = match input.repo.as_str().parse() { 164 - Ok(d) => d, 165 - Err(_) => return ApiError::InvalidRequest("Invalid DID format".into()).into_response(), 166 - }; 167 - state.user_repo.get_id_by_did(&did).await.map_err(|_| ()) 168 - } else { 169 - let repo_str = input.repo.as_str(); 170 - let handle_str = if !repo_str.contains('.') { 171 - format!("{}.{}", repo_str, hostname_for_handles) 172 - } else { 173 - repo_str.to_string() 174 - }; 175 - let handle: tranquil_pds::types::Handle = match handle_str.parse() { 176 - Ok(h) => h, 177 - Err(_) => { 178 - return ApiError::InvalidRequest("Invalid handle format".into()).into_response(); 179 - } 180 - }; 181 - state 182 - .user_repo 183 - .get_id_by_handle(&handle) 184 - .await 185 - .map_err(|_| ()) 186 - }; 187 - let user_id: uuid::Uuid = match user_id_opt { 188 - Ok(Some(id)) => id, 189 - Ok(None) => { 190 - return ApiError::RepoNotFound(Some("Repo not found".into())).into_response(); 191 - } 192 - Err(_) => { 193 - return ApiError::InternalError(None).into_response(); 194 - } 131 + let user_id = match common::resolve_repo_user_id(state.user_repo.as_ref(), &input.repo).await { 132 + Ok(id) => id, 133 + Err(e) => return e.into_response(), 195 134 }; 196 135 let limit = input.limit.unwrap_or(50).clamp(1, 100); 197 136 let limit_i64 = i64::from(limit);
+6 -9
crates/tranquil-api/src/repo/record/validation.rs
··· 1 - use axum::response::Response; 2 1 use tranquil_pds::api::error::ApiError; 3 2 use tranquil_pds::types::{Nsid, Rkey}; 4 3 use tranquil_pds::validation::{RecordValidator, ValidationError, ValidationStatus}; ··· 8 7 collection: &Nsid, 9 8 rkey: Option<&Rkey>, 10 9 require_lexicon: bool, 11 - ) -> Result<ValidationStatus, Box<Response>> { 10 + ) -> Result<ValidationStatus, ApiError> { 12 11 let registry = tranquil_lexicon::LexiconRegistry::global(); 13 12 if !registry.has_schema(collection.as_str()) { 14 13 let _ = registry.resolve_dynamic(collection.as_str()).await; 15 14 } 16 15 17 16 let validator = RecordValidator::new().require_lexicon(require_lexicon); 18 - match validator.validate_with_rkey(record, collection.as_str(), rkey.map(|r| r.as_str())) { 19 - Ok(status) => Ok(status), 20 - Err(e) => Err(validation_error_to_box_response(e)), 21 - } 17 + validator 18 + .validate_with_rkey(record, collection.as_str(), rkey.map(|v| v.as_str())) 19 + .map_err(validation_error_to_api_error) 22 20 } 23 21 24 - fn validation_error_to_box_response(e: ValidationError) -> Box<Response> { 25 - use axum::response::IntoResponse; 22 + fn validation_error_to_api_error(e: ValidationError) -> ApiError { 26 23 let msg = match e { 27 24 ValidationError::MissingType => "Record must have a $type field".to_string(), 28 25 ValidationError::TypeMismatch { expected, actual } => { ··· 44 41 ValidationError::UnknownType(type_name) => format!("Lexicon not found: lex:{}", type_name), 45 42 e => e.to_string(), 46 43 }; 47 - Box::new(ApiError::InvalidRecord(msg).into_response()) 44 + ApiError::InvalidRecord(msg) 48 45 }
+170 -403
crates/tranquil-api/src/repo/record/write.rs
··· 1 1 use super::validation::validate_record_with_status; 2 2 use super::validation_mode::{ValidationMode, deserialize_validation_mode}; 3 - use crate::repo::record::utils::{ 4 - CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids, 5 - get_current_root_cid, 6 - }; 7 - use axum::{ 8 - Json, 9 - extract::State, 10 - http::StatusCode, 11 - response::{IntoResponse, Response}, 12 - }; 3 + use axum::{Json, extract::State}; 13 4 use cid::Cid; 14 - use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 5 + use jacquard_repo::storage::BlockStore; 15 6 use serde::{Deserialize, Serialize}; 16 7 use serde_json::json; 17 8 use std::str::FromStr; 18 - use std::sync::Arc; 19 9 use tracing::error; 20 - use tranquil_pds::api::error::ApiError; 10 + use tranquil_pds::api::error::{ApiError, DbResultExt}; 21 11 use tranquil_pds::auth::{ 22 12 Active, Auth, AuthSource, RepoScopeAction, ScopeVerified, VerifyScope, require_not_migrated, 23 13 require_verified_or_delegated, 24 14 }; 25 - use tranquil_pds::cid_types::CommitCid; 26 - use tranquil_pds::delegation::DelegationActionType; 27 - use tranquil_pds::repo::tracking::TrackingBlockStore; 15 + use tranquil_pds::repo_ops::{ 16 + FinalizeParams, RecordOp, begin_repo_write, extract_backlinks, extract_blob_cids, 17 + finalize_repo_write, 18 + }; 28 19 use tranquil_pds::state::AppState; 29 20 use tranquil_pds::types::{AtIdentifier, AtUri, Did, Nsid, Rkey}; 30 21 use tranquil_pds::validation::ValidationStatus; ··· 42 33 state: &AppState, 43 34 scope_proof: &ScopeVerified<'_, A>, 44 35 repo: &AtIdentifier, 45 - ) -> Result<RepoWriteAuth, Response> { 36 + ) -> Result<RepoWriteAuth, ApiError> { 46 37 let user = scope_proof.user(); 47 38 let principal_did = scope_proof.principal_did(); 48 39 if repo.as_str() != principal_did.as_str() { 49 - return Err( 50 - ApiError::InvalidRepo("Repo does not match authenticated user".into()).into_response(), 51 - ); 40 + return Err(ApiError::InvalidRepo( 41 + "Repo does not match authenticated user".into(), 42 + )); 52 43 } 53 44 54 45 require_not_migrated(state, principal_did.as_did()).await?; ··· 58 49 .user_repo 59 50 .get_id_by_did(principal_did.as_did()) 60 51 .await 61 - .map_err(|e| { 62 - error!("DB error fetching user: {}", e); 63 - ApiError::InternalError(None).into_response() 64 - })? 65 - .ok_or_else(|| ApiError::InternalError(Some("User not found".into())).into_response())?; 52 + .log_db_err("fetching user for repo write")? 53 + .ok_or(ApiError::InternalError(Some("User not found".into())))?; 66 54 67 55 Ok(RepoWriteAuth { 68 56 did: principal_did.into_did(), ··· 72 60 controller_did: scope_proof.controller_did().map(|c| c.into_did()), 73 61 }) 74 62 } 63 + 75 64 #[derive(Deserialize)] 76 65 #[allow(dead_code)] 77 66 pub struct CreateRecordInput { ··· 84 73 #[serde(rename = "swapCommit")] 85 74 pub swap_commit: Option<String>, 86 75 } 76 + 87 77 #[derive(Serialize)] 88 78 #[serde(rename_all = "camelCase")] 89 79 pub struct CommitInfo { ··· 100 90 #[serde(skip_serializing_if = "Option::is_none")] 101 91 pub validation_status: Option<ValidationStatus>, 102 92 } 93 + 103 94 pub async fn create_record( 104 95 State(state): State<AppState>, 105 96 auth: Auth<Active>, 106 97 Json(input): Json<CreateRecordInput>, 107 - ) -> Result<Response, tranquil_pds::api::error::ApiError> { 108 - let scope_proof = match auth.verify_repo_create(&input.collection) { 109 - Ok(proof) => proof, 110 - Err(e) => return Ok(e.into_response()), 111 - }; 112 - 113 - let repo_auth = match prepare_repo_write(&state, &scope_proof, &input.repo).await { 114 - Ok(res) => res, 115 - Err(err_res) => return Ok(err_res), 116 - }; 117 - 98 + ) -> Result<Json<CreateRecordOutput>, ApiError> { 99 + let scope_proof = auth.verify_repo_create(&input.collection)?; 100 + let repo_auth = prepare_repo_write(&state, &scope_proof, &input.repo).await?; 118 101 let did = repo_auth.did; 119 102 let user_id = repo_auth.user_id; 120 103 let controller_did = repo_auth.controller_did; 121 104 122 - let _write_lock = state.repo_write_locks.lock(user_id).await; 123 - let current_root_cid = get_current_root_cid(&state, user_id).await?; 124 - 125 - if let Some(swap_commit) = &input.swap_commit 126 - && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) 127 - { 128 - return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 129 - } 105 + let (ctx, mut mst) = begin_repo_write(&state, user_id, input.swap_commit.as_deref()).await?; 130 106 131 107 let validation_status = if input.validate.should_skip() { 132 108 None 133 109 } else { 134 - match validate_record_with_status( 135 - &input.record, 136 - &input.collection, 137 - input.rkey.as_ref(), 138 - input.validate.requires_lexicon(), 110 + Some( 111 + validate_record_with_status( 112 + &input.record, 113 + &input.collection, 114 + input.rkey.as_ref(), 115 + input.validate.requires_lexicon(), 116 + ) 117 + .await?, 139 118 ) 140 - .await 141 - { 142 - Ok(status) => Some(status), 143 - Err(err_response) => return Ok(*err_response), 144 - } 145 119 }; 146 - let rkey = input.rkey.unwrap_or_else(Rkey::generate); 147 - 148 - let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 149 - let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await { 150 - Ok(Some(b)) => b, 151 - _ => { 152 - return Ok( 153 - ApiError::InternalError(Some("Commit block not found".into())).into_response(), 154 - ); 155 - } 156 - }; 157 - let commit = match Commit::from_cbor(&commit_bytes) { 158 - Ok(c) => c, 159 - _ => { 160 - return Ok( 161 - ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 162 - ); 163 - } 164 - }; 165 - let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 166 - let initial_mst_root = commit.data; 167 120 121 + let rkey = input.rkey.unwrap_or_else(Rkey::generate); 168 122 let mut ops: Vec<RecordOp> = Vec::new(); 169 123 let mut conflict_uris_to_cleanup: Vec<AtUri> = Vec::new(); 170 - let mut all_old_mst_blocks = std::collections::BTreeMap::new(); 171 124 172 125 if !input.validate.should_skip() { 173 126 let record_uri = AtUri::from_parts(&did, &input.collection, &rkey); 174 127 let backlinks = extract_backlinks(&record_uri, &input.record); 175 128 176 129 if !backlinks.is_empty() { 177 - let conflicts = match state 130 + let conflicts = state 178 131 .backlink_repo 179 132 .get_backlink_conflicts(user_id, &input.collection, &backlinks) 180 133 .await 181 - { 182 - Ok(c) => c, 183 - Err(e) => { 184 - error!("Failed to check backlink conflicts: {}", e); 185 - return Ok(ApiError::InternalError(None).into_response()); 186 - } 187 - }; 134 + .log_db_err("checking backlink conflicts")?; 188 135 189 136 for conflict_uri in conflicts { 190 - let conflict_rkey = match conflict_uri.rkey() { 191 - Some(r) => Rkey::from(r.to_string()), 192 - None => continue, 193 - }; 194 - let conflict_collection = match conflict_uri.collection() { 195 - Some(c) => Nsid::from(c.to_string()), 196 - None => continue, 137 + let (Some(conflict_rkey_str), Some(conflict_col_str)) = 138 + (conflict_uri.rkey(), conflict_uri.collection()) 139 + else { 140 + continue; 197 141 }; 142 + let conflict_rkey = Rkey::from(conflict_rkey_str.to_string()); 143 + let conflict_collection = Nsid::from(conflict_col_str.to_string()); 198 144 let conflict_key = format!("{}/{}", conflict_collection, conflict_rkey); 199 145 200 146 let prev_cid = match mst.get(&conflict_key).await { 201 147 Ok(Some(cid)) => cid, 202 - Ok(None) => continue, 203 - Err(_) => continue, 148 + _ => continue, 204 149 }; 205 150 206 - if mst 207 - .blocks_for_path(&conflict_key, &mut all_old_mst_blocks) 208 - .await 209 - .is_err() 210 - { 211 - error!("Failed to get old MST blocks for conflict {}", conflict_uri); 212 - } 213 - 214 151 mst = match mst.delete(&conflict_key).await { 215 152 Ok(m) => m, 216 153 Err(e) => { ··· 233 170 } 234 171 235 172 let record_ipld = tranquil_pds::util::json_to_ipld(&input.record); 236 - let mut record_bytes = Vec::new(); 237 - if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { 238 - return Ok(ApiError::InvalidRecord("Failed to serialize record".into()).into_response()); 239 - } 240 - let record_cid = match tracking_store.put(&record_bytes).await { 241 - Ok(c) => c, 242 - _ => { 243 - return Ok( 244 - ApiError::InternalError(Some("Failed to save record block".into())).into_response(), 245 - ); 246 - } 247 - }; 248 - let key = format!("{}/{}", input.collection, rkey); 249 - 250 - if mst 251 - .blocks_for_path(&key, &mut all_old_mst_blocks) 173 + let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld) 174 + .map_err(|_| ApiError::InvalidRecord("Failed to serialize record".into()))?; 175 + let record_cid = ctx 176 + .tracking_store 177 + .put(&record_bytes) 252 178 .await 253 - .is_err() 254 - { 255 - error!("Failed to get old MST blocks for new record path"); 256 - } 179 + .map_err(|_| ApiError::InternalError(Some("Failed to save record block".into())))?; 257 180 258 - let new_mst = match mst.add(&key, record_cid).await { 259 - Ok(m) => m, 260 - _ => { 261 - return Ok(ApiError::InternalError(Some("Failed to add to MST".into())).into_response()); 262 - } 263 - }; 264 - let new_mst_root = match new_mst.persist().await { 265 - Ok(c) => c, 266 - _ => { 267 - return Ok( 268 - ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 269 - ); 270 - } 271 - }; 181 + let key = format!("{}/{}", input.collection, rkey); 182 + mst = mst 183 + .add(&key, record_cid) 184 + .await 185 + .map_err(|_| ApiError::InternalError(Some("Failed to add to MST".into())))?; 272 186 273 187 ops.push(RecordOp::Create { 274 188 collection: input.collection.clone(), ··· 276 190 cid: record_cid, 277 191 }); 278 192 279 - let mut new_mst_blocks = std::collections::BTreeMap::new(); 280 - if new_mst 281 - .blocks_for_path(&key, &mut new_mst_blocks) 282 - .await 283 - .is_err() 284 - { 285 - return Ok( 286 - ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 287 - .into_response(), 288 - ); 289 - } 290 - 291 - let mut relevant_blocks = new_mst_blocks.clone(); 292 - relevant_blocks.extend(all_old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); 293 - relevant_blocks.insert(record_cid, bytes::Bytes::new()); 294 - let written_cids: Vec<Cid> = tracking_store 295 - .get_all_relevant_cids() 296 - .into_iter() 297 - .chain(relevant_blocks.keys().copied()) 298 - .collect::<std::collections::HashSet<_>>() 299 - .into_iter() 193 + let modified_keys: Vec<String> = ops 194 + .iter() 195 + .map(|op| match op { 196 + RecordOp::Create { 197 + collection, rkey, .. 198 + } 199 + | RecordOp::Update { 200 + collection, rkey, .. 201 + } 202 + | RecordOp::Delete { 203 + collection, rkey, .. 204 + } => format!("{}/{}", collection, rkey), 205 + }) 300 206 .collect(); 301 - let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect(); 302 207 let blob_cids = extract_blob_cids(&input.record); 303 - let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid()) 304 - .chain( 305 - all_old_mst_blocks 306 - .keys() 307 - .filter(|cid| !new_mst_blocks.contains_key(*cid)) 308 - .copied(), 309 - ) 310 - .collect(); 311 208 312 - let commit_result = match commit_and_log( 209 + let commit_result = finalize_repo_write( 313 210 &state, 314 - CommitParams { 211 + ctx, 212 + mst, 213 + FinalizeParams { 315 214 did: &did, 316 215 user_id, 317 - current_root_cid: Some(current_root_cid.into_cid()), 318 - prev_data_cid: Some(initial_mst_root), 319 - new_mst_root, 216 + controller_did: controller_did.as_ref(), 217 + delegation_detail: controller_did.as_ref().map(|_| { 218 + json!({ 219 + "action": "create", 220 + "collection": input.collection, 221 + "rkey": rkey 222 + }) 223 + }), 320 224 ops, 321 - blocks_cids: &written_cids_str, 322 - blobs: &blob_cids, 323 - obsolete_cids, 225 + modified_keys: &modified_keys, 226 + blob_cids: &blob_cids, 324 227 }, 325 228 ) 326 - .await 327 - { 328 - Ok(res) => res, 329 - Err(e) => return Ok(ApiError::from(e).into_response()), 330 - }; 331 - 332 - for conflict_uri in conflict_uris_to_cleanup { 333 - if let Err(e) = state 334 - .backlink_repo 335 - .remove_backlinks_by_uri(&conflict_uri) 336 - .await 337 - { 338 - error!("Failed to remove backlinks for {}: {}", conflict_uri, e); 339 - } 340 - } 229 + .await?; 341 230 342 - if let Some(ref controller) = controller_did { 343 - let _ = state 344 - .delegation_repo 345 - .log_delegation_action( 346 - &did, 347 - controller, 348 - Some(controller), 349 - DelegationActionType::RepoWrite, 350 - Some(json!({ 351 - "action": "create", 352 - "collection": input.collection, 353 - "rkey": rkey 354 - })), 355 - None, 356 - None, 357 - ) 358 - .await; 231 + { 232 + let backlink_repo = state.backlink_repo.clone(); 233 + futures::future::join_all(conflict_uris_to_cleanup.iter().map(|uri| { 234 + let backlink_repo = backlink_repo.clone(); 235 + async move { 236 + if let Err(e) = backlink_repo.remove_backlinks_by_uri(uri).await { 237 + error!("Failed to remove backlinks for {}: {}", uri, e); 238 + } 239 + } 240 + })) 241 + .await; 359 242 } 360 243 361 244 let created_uri = AtUri::from_parts(&did, &input.collection, &rkey); ··· 366 249 error!("Failed to add backlinks for {}: {}", created_uri, e); 367 250 } 368 251 369 - Ok(( 370 - StatusCode::OK, 371 - Json(CreateRecordOutput { 372 - uri: created_uri, 373 - cid: record_cid.to_string(), 374 - commit: CommitInfo { 375 - cid: commit_result.commit_cid.to_string(), 376 - rev: commit_result.rev, 377 - }, 378 - validation_status, 379 - }), 380 - ) 381 - .into_response()) 252 + Ok(Json(CreateRecordOutput { 253 + uri: created_uri, 254 + cid: record_cid.to_string(), 255 + commit: CommitInfo { 256 + cid: commit_result.commit_cid.to_string(), 257 + rev: commit_result.rev, 258 + }, 259 + validation_status, 260 + })) 382 261 } 262 + 383 263 #[derive(Deserialize)] 384 264 #[allow(dead_code)] 385 265 pub struct PutRecordInput { ··· 394 274 #[serde(rename = "swapRecord")] 395 275 pub swap_record: Option<String>, 396 276 } 277 + 397 278 #[derive(Serialize)] 398 279 #[serde(rename_all = "camelCase")] 399 280 pub struct PutRecordOutput { ··· 404 285 #[serde(skip_serializing_if = "Option::is_none")] 405 286 pub validation_status: Option<ValidationStatus>, 406 287 } 288 + 407 289 pub async fn put_record( 408 290 State(state): State<AppState>, 409 291 auth: Auth<Active>, 410 292 Json(input): Json<PutRecordInput>, 411 - ) -> Result<Response, tranquil_pds::api::error::ApiError> { 412 - let upsert_proof = match auth.verify_repo_upsert(&input.collection) { 413 - Ok(proof) => proof, 414 - Err(e) => return Ok(e.into_response()), 415 - }; 416 - 417 - let repo_auth = match prepare_repo_write(&state, &upsert_proof, &input.repo).await { 418 - Ok(res) => res, 419 - Err(err_res) => return Ok(err_res), 420 - }; 421 - 293 + ) -> Result<Json<PutRecordOutput>, ApiError> { 294 + let upsert_proof = auth.verify_repo_upsert(&input.collection)?; 295 + let repo_auth = prepare_repo_write(&state, &upsert_proof, &input.repo).await?; 422 296 let did = repo_auth.did; 423 297 let user_id = repo_auth.user_id; 424 298 let controller_did = repo_auth.controller_did; 425 299 426 - let _write_lock = state.repo_write_locks.lock(user_id).await; 427 - let current_root_cid = get_current_root_cid(&state, user_id).await?; 300 + let (ctx, mst) = begin_repo_write(&state, user_id, input.swap_commit.as_deref()).await?; 428 301 429 - if let Some(swap_commit) = &input.swap_commit 430 - && CommitCid::from_str(swap_commit).ok().as_ref() != Some(&current_root_cid) 431 - { 432 - return Ok(ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response()); 433 - } 434 - let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 435 - let commit_bytes = match tracking_store.get(current_root_cid.as_cid()).await { 436 - Ok(Some(b)) => b, 437 - _ => { 438 - return Ok( 439 - ApiError::InternalError(Some("Commit block not found".into())).into_response(), 440 - ); 441 - } 442 - }; 443 - let commit = match Commit::from_cbor(&commit_bytes) { 444 - Ok(c) => c, 445 - _ => { 446 - return Ok( 447 - ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 448 - ); 449 - } 450 - }; 451 - let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 452 - let key = format!("{}/{}", input.collection, input.rkey); 453 302 let validation_status = if input.validate.should_skip() { 454 303 None 455 304 } else { 456 - match validate_record_with_status( 457 - &input.record, 458 - &input.collection, 459 - Some(&input.rkey), 460 - input.validate.requires_lexicon(), 305 + Some( 306 + validate_record_with_status( 307 + &input.record, 308 + &input.collection, 309 + Some(&input.rkey), 310 + input.validate.requires_lexicon(), 311 + ) 312 + .await?, 461 313 ) 462 - .await 463 - { 464 - Ok(status) => Some(status), 465 - Err(err_response) => return Ok(*err_response), 466 - } 467 314 }; 315 + 316 + let key = format!("{}/{}", input.collection, input.rkey); 317 + 468 318 if let Some(swap_record_str) = &input.swap_record { 469 319 let expected_cid = Cid::from_str(swap_record_str).ok(); 470 320 let actual_cid = mst.get(&key).await.ok().flatten(); 471 321 if expected_cid != actual_cid { 472 - return Ok(ApiError::InvalidSwap(Some( 322 + return Err(ApiError::InvalidSwap(Some( 473 323 "Record has been modified or does not exist".into(), 474 - )) 475 - .into_response()); 324 + ))); 476 325 } 477 326 } 327 + 478 328 let existing_cid = mst.get(&key).await.ok().flatten(); 479 329 let record_ipld = tranquil_pds::util::json_to_ipld(&input.record); 480 - let mut record_bytes = Vec::new(); 481 - if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { 482 - return Ok(ApiError::InvalidRecord("Failed to serialize record".into()).into_response()); 483 - } 484 - let record_cid = match tracking_store.put(&record_bytes).await { 485 - Ok(c) => c, 486 - _ => { 487 - return Ok( 488 - ApiError::InternalError(Some("Failed to save record block".into())).into_response(), 489 - ); 490 - } 491 - }; 330 + let record_bytes = serde_ipld_dagcbor::to_vec(&record_ipld) 331 + .map_err(|_| ApiError::InvalidRecord("Failed to serialize record".into()))?; 332 + let record_cid = ctx 333 + .tracking_store 334 + .put(&record_bytes) 335 + .await 336 + .map_err(|_| ApiError::InternalError(Some("Failed to save record block".into())))?; 337 + 492 338 if existing_cid == Some(record_cid) { 493 - return Ok(( 494 - StatusCode::OK, 495 - Json(PutRecordOutput { 496 - uri: AtUri::from_parts(&did, &input.collection, &input.rkey), 497 - cid: record_cid.to_string(), 498 - commit: None, 499 - validation_status, 500 - }), 501 - ) 502 - .into_response()); 339 + return Ok(Json(PutRecordOutput { 340 + uri: AtUri::from_parts(&did, &input.collection, &input.rkey), 341 + cid: record_cid.to_string(), 342 + commit: None, 343 + validation_status, 344 + })); 503 345 } 504 - let new_mst = 505 - if existing_cid.is_some() { 506 - match mst.update(&key, record_cid).await { 507 - Ok(m) => m, 508 - Err(_) => { 509 - return Ok(ApiError::InternalError(Some("Failed to update MST".into())) 510 - .into_response()); 511 - } 512 - } 513 - } else { 514 - match mst.add(&key, record_cid).await { 515 - Ok(m) => m, 516 - Err(_) => { 517 - return Ok(ApiError::InternalError(Some("Failed to add to MST".into())) 518 - .into_response()); 519 - } 520 - } 521 - }; 522 - let new_mst_root = match new_mst.persist().await { 523 - Ok(c) => c, 524 - Err(_) => { 525 - return Ok( 526 - ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 527 - ); 528 - } 346 + 347 + let is_update = existing_cid.is_some(); 348 + let new_mst = if is_update { 349 + mst.update(&key, record_cid) 350 + .await 351 + .map_err(|_| ApiError::InternalError(Some("Failed to update MST".into())))? 352 + } else { 353 + mst.add(&key, record_cid) 354 + .await 355 + .map_err(|_| ApiError::InternalError(Some("Failed to add to MST".into())))? 529 356 }; 530 - let op = if existing_cid.is_some() { 357 + 358 + let op = if is_update { 531 359 RecordOp::Update { 532 360 collection: input.collection.clone(), 533 361 rkey: input.rkey.clone(), ··· 541 369 cid: record_cid, 542 370 } 543 371 }; 544 - let mut new_mst_blocks = std::collections::BTreeMap::new(); 545 - let mut old_mst_blocks = std::collections::BTreeMap::new(); 546 - if new_mst 547 - .blocks_for_path(&key, &mut new_mst_blocks) 548 - .await 549 - .is_err() 550 - { 551 - return Ok( 552 - ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 553 - .into_response(), 554 - ); 555 - } 556 - if mst 557 - .blocks_for_path(&key, &mut old_mst_blocks) 558 - .await 559 - .is_err() 560 - { 561 - return Ok( 562 - ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 563 - .into_response(), 564 - ); 565 - } 566 - let mut relevant_blocks = new_mst_blocks.clone(); 567 - relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); 568 - relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes)); 569 - let written_cids: Vec<Cid> = tracking_store 570 - .get_all_relevant_cids() 571 - .into_iter() 572 - .chain(relevant_blocks.keys().copied()) 573 - .collect::<std::collections::HashSet<_>>() 574 - .into_iter() 575 - .collect(); 576 - let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect(); 577 - let is_update = existing_cid.is_some(); 372 + 373 + let modified_keys = [key]; 578 374 let blob_cids = extract_blob_cids(&input.record); 579 - let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid.into_cid()) 580 - .chain( 581 - old_mst_blocks 582 - .keys() 583 - .filter(|cid| !new_mst_blocks.contains_key(*cid)) 584 - .copied(), 585 - ) 586 - .chain(existing_cid) 587 - .collect(); 588 - let commit_result = match commit_and_log( 375 + 376 + let commit_result = finalize_repo_write( 589 377 &state, 590 - CommitParams { 378 + ctx, 379 + new_mst, 380 + FinalizeParams { 591 381 did: &did, 592 382 user_id, 593 - current_root_cid: Some(current_root_cid.into_cid()), 594 - prev_data_cid: Some(commit.data), 595 - new_mst_root, 596 - ops: vec![op], 597 - blocks_cids: &written_cids_str, 598 - blobs: &blob_cids, 599 - obsolete_cids, 600 - }, 601 - ) 602 - .await 603 - { 604 - Ok(res) => res, 605 - Err(e) => return Ok(ApiError::from(e).into_response()), 606 - }; 607 - 608 - if let Some(ref controller) = controller_did { 609 - let _ = state 610 - .delegation_repo 611 - .log_delegation_action( 612 - &did, 613 - controller, 614 - Some(controller), 615 - DelegationActionType::RepoWrite, 616 - Some(json!({ 383 + controller_did: controller_did.as_ref(), 384 + delegation_detail: controller_did.as_ref().map(|_| { 385 + json!({ 617 386 "action": if is_update { "update" } else { "create" }, 618 387 "collection": input.collection, 619 388 "rkey": input.rkey 620 - })), 621 - None, 622 - None, 623 - ) 624 - .await; 625 - } 626 - 627 - Ok(( 628 - StatusCode::OK, 629 - Json(PutRecordOutput { 630 - uri: AtUri::from_parts(&did, &input.collection, &input.rkey), 631 - cid: record_cid.to_string(), 632 - commit: Some(CommitInfo { 633 - cid: commit_result.commit_cid.to_string(), 634 - rev: commit_result.rev, 389 + }) 635 390 }), 636 - validation_status, 637 - }), 391 + ops: vec![op], 392 + modified_keys: &modified_keys, 393 + blob_cids: &blob_cids, 394 + }, 638 395 ) 639 - .into_response()) 396 + .await?; 397 + 398 + Ok(Json(PutRecordOutput { 399 + uri: AtUri::from_parts(&did, &input.collection, &input.rkey), 400 + cid: record_cid.to_string(), 401 + commit: Some(CommitInfo { 402 + cid: commit_result.commit_cid.to_string(), 403 + rev: commit_result.rev, 404 + }), 405 + validation_status, 406 + })) 640 407 }
+133 -3
crates/tranquil-pds/src/repo_ops.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::cid_types::CommitCid; 3 + use crate::repo::tracking::TrackingBlockStore; 3 4 use crate::state::AppState; 4 5 use crate::types::{Did, Handle, Nsid, Rkey}; 5 6 use bytes::Bytes; 6 7 use cid::Cid; 7 8 use jacquard_common::types::{integer::LimitedU32, string::Tid}; 8 9 use jacquard_repo::commit::Commit; 10 + use jacquard_repo::mst::Mst; 9 11 use jacquard_repo::storage::BlockStore; 10 12 use k256::ecdsa::SigningKey; 11 13 use serde_json::{Value, json}; 12 14 use std::str::FromStr; 15 + use std::sync::Arc; 16 + use tokio::sync::OwnedMutexGuard; 13 17 use tracing::error; 14 18 use tranquil_db_traits::SequenceNumber; 15 19 use uuid::Uuid; ··· 158 162 } 159 163 } 160 164 165 + pub struct RepoWriteContext { 166 + pub tracking_store: TrackingBlockStore, 167 + pub current_root_cid: Cid, 168 + pub prev_data_cid: Cid, 169 + pub write_lock: OwnedMutexGuard<()>, 170 + } 171 + 172 + pub struct FinalizeParams<'a> { 173 + pub did: &'a Did, 174 + pub user_id: Uuid, 175 + pub controller_did: Option<&'a Did>, 176 + pub delegation_detail: Option<serde_json::Value>, 177 + pub ops: Vec<RecordOp>, 178 + pub modified_keys: &'a [String], 179 + pub blob_cids: &'a [String], 180 + } 181 + 182 + pub async fn begin_repo_write( 183 + state: &AppState, 184 + user_id: Uuid, 185 + swap_commit: Option<&str>, 186 + ) -> Result<(RepoWriteContext, Mst<TrackingBlockStore>), ApiError> { 187 + let write_lock = state.repo_write_locks.lock(user_id).await; 188 + 189 + let root_cid_str = state 190 + .repo_repo 191 + .get_repo_root_cid_by_user_id(user_id) 192 + .await 193 + .map_err(|e| { 194 + error!("DB error fetching repo root: {}", e); 195 + ApiError::InternalError(None) 196 + })? 197 + .ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())))?; 198 + 199 + let current_root_cid = Cid::from_str(root_cid_str.as_str()) 200 + .map_err(|_| ApiError::InternalError(Some("Invalid repo root CID".into())))?; 201 + 202 + if let Some(expected) = swap_commit { 203 + let expected_cid = Cid::from_str(expected) 204 + .map_err(|_| ApiError::InvalidSwap(Some("Invalid swap commit CID".into())))?; 205 + if expected_cid != current_root_cid { 206 + return Err(ApiError::InvalidSwap(Some("Repo has been modified".into()))); 207 + } 208 + } 209 + 210 + let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 211 + let commit_bytes = tracking_store 212 + .get(&current_root_cid) 213 + .await 214 + .map_err(|e| { 215 + error!("Failed to load commit block: {:?}", e); 216 + ApiError::InternalError(None) 217 + })? 218 + .ok_or_else(|| ApiError::InternalError(Some("Commit block not found".into())))?; 219 + 220 + let commit = Commit::from_cbor(&commit_bytes).map_err(|e| { 221 + error!("Failed to parse commit: {:?}", e); 222 + ApiError::InternalError(None) 223 + })?; 224 + 225 + let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 226 + 227 + let ctx = RepoWriteContext { 228 + tracking_store, 229 + current_root_cid, 230 + prev_data_cid: commit.data, 231 + write_lock, 232 + }; 233 + 234 + Ok((ctx, mst)) 235 + } 236 + 237 + pub async fn finalize_repo_write( 238 + state: &AppState, 239 + ctx: RepoWriteContext, 240 + mst: Mst<TrackingBlockStore>, 241 + params: FinalizeParams<'_>, 242 + ) -> Result<CommitResult, ApiError> { 243 + let new_mst_root = mst.persist().await.map_err(|e| { 244 + error!("MST persist failed: {:?}", e); 245 + ApiError::InternalError(None) 246 + })?; 247 + 248 + let written_cids: Vec<Cid> = ctx 249 + .tracking_store 250 + .get_all_relevant_cids() 251 + .into_iter() 252 + .collect::<std::collections::HashSet<_>>() 253 + .into_iter() 254 + .collect(); 255 + let written_cids_str: Vec<String> = written_cids.iter().map(ToString::to_string).collect(); 256 + 257 + let result = commit_and_log( 258 + state, 259 + CommitParams { 260 + did: params.did, 261 + user_id: params.user_id, 262 + current_root_cid: Some(ctx.current_root_cid), 263 + prev_data_cid: Some(ctx.prev_data_cid), 264 + new_mst_root, 265 + ops: params.ops, 266 + blocks_cids: &written_cids_str, 267 + blobs: params.blob_cids, 268 + obsolete_cids: vec![ctx.current_root_cid], 269 + }, 270 + ) 271 + .await?; 272 + 273 + if let Some(controller_did) = params.controller_did 274 + && let Some(detail) = params.delegation_detail 275 + && let Err(e) = state 276 + .delegation_repo 277 + .log_delegation_action( 278 + params.did, 279 + controller_did, 280 + Some(controller_did), 281 + tranquil_db_traits::DelegationActionType::RepoWrite, 282 + Some(detail), 283 + None, 284 + None, 285 + ) 286 + .await 287 + { 288 + tracing::warn!("Failed to log delegation audit: {:?}", e); 289 + } 290 + 291 + Ok(result) 292 + } 293 + 161 294 pub fn create_signed_commit( 162 295 did: &Did, 163 296 data: Cid, ··· 392 525 rkey: &Rkey, 393 526 record: &serde_json::Value, 394 527 ) -> Result<(String, Cid), CommitError> { 395 - use crate::repo::tracking::TrackingBlockStore; 396 - use jacquard_repo::mst::Mst; 397 - use std::sync::Arc; 398 528 let user_id: Uuid = state 399 529 .user_repo 400 530 .get_id_by_did(did)

History

1 round 0 comments
sign up or login to add to the discussion
oyster.cafe submitted #0
1 commit
expand
refactor(api): extract repo write lifecycle to repo_ops
expand 0 comments
pull request successfully merged