crossing the streams
at main 594 lines 18 kB view raw
1use std::collections::HashMap; 2use std::error::Error; 3use std::str::FromStr; 4 5use scru128::Scru128Id; 6 7use base64::Engine; 8 9use tokio::io::AsyncWriteExt; 10use tokio_stream::wrappers::ReceiverStream; 11use tokio_stream::StreamExt; 12use tokio_util::io::ReaderStream; 13 14use http_body_util::StreamBody; 15use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; 16use hyper::body::Bytes; 17use hyper::header::ACCEPT; 18use hyper::server::conn::http1; 19use hyper::service::service_fn; 20use hyper::{Method, Request, Response, StatusCode}; 21use hyper_util::rt::TokioIo; 22 23use crate::listener::Listener; 24use crate::nu; 25use crate::store::{self, FollowOption, Frame, ReadOptions, Store, TTL}; 26 27type BoxError = Box<dyn std::error::Error + Send + Sync>; 28type HTTPResult = Result<Response<BoxBody<Bytes, BoxError>>, BoxError>; 29 30#[derive(Debug, PartialEq, Clone)] 31enum AcceptType { 32 Ndjson, 33 EventStream, 34} 35 36enum Routes { 37 StreamCat { 38 accept_type: AcceptType, 39 options: ReadOptions, 40 }, 41 StreamAppend { 42 topic: String, 43 ttl: Option<TTL>, 44 context_id: Scru128Id, 45 }, 46 HeadGet { 47 topic: String, 48 follow: bool, 49 context_id: Scru128Id, 50 }, 51 StreamItemGet(Scru128Id), 52 StreamItemRemove(Scru128Id), 53 CasGet(ssri::Integrity), 54 CasPost, 55 Import, 56 Version, 57 NotFound, 58 BadRequest(String), 59} 60 61/// Validates an Integrity object to ensure all its hashes are properly formatted 62fn validate_integrity(integrity: &ssri::Integrity) -> bool { 63 // Check if there are any hashes 64 if integrity.hashes.is_empty() { 65 return false; 66 } 67 68 // For each hash, check if it has a valid base64-encoded digest 69 for hash in &integrity.hashes { 70 // Check if digest is valid base64 using the modern API 71 if base64::engine::general_purpose::STANDARD 72 .decode(&hash.digest) 73 .is_err() 74 { 75 return false; 76 } 77 } 78 79 true 80} 81 82fn match_route( 83 method: &Method, 84 path: &str, 85 headers: &hyper::HeaderMap, 86 query: Option<&str>, 87) -> Routes { 88 let params: HashMap<String, String> = 89 url::form_urlencoded::parse(query.unwrap_or("").as_bytes()) 90 .into_owned() 91 .collect(); 92 93 match (method, path) { 94 (&Method::GET, "/version") => Routes::Version, 95 96 (&Method::GET, "/") => { 97 let accept_type = match headers.get(ACCEPT) { 98 Some(accept) if accept == "text/event-stream" => AcceptType::EventStream, 99 _ => AcceptType::Ndjson, 100 }; 101 102 let options = ReadOptions::from_query(query); 103 104 match options { 105 Ok(options) => Routes::StreamCat { 106 accept_type, 107 options, 108 }, 109 Err(e) => Routes::BadRequest(e.to_string()), 110 } 111 } 112 113 (&Method::GET, p) if p.starts_with("/head/") => { 114 let topic = p.strip_prefix("/head/").unwrap().to_string(); 115 let follow = params.contains_key("follow"); 116 let context_id = match params.get("context") { 117 None => crate::store::ZERO_CONTEXT, 118 Some(ctx) => match ctx.parse() { 119 Ok(id) => id, 120 Err(e) => return Routes::BadRequest(format!("Invalid context ID: {}", e)), 121 }, 122 }; 123 Routes::HeadGet { 124 topic, 125 follow, 126 context_id, 127 } 128 } 129 130 (&Method::GET, p) if p.starts_with("/cas/") => { 131 if let Some(hash) = p.strip_prefix("/cas/") { 132 match ssri::Integrity::from_str(hash) { 133 Ok(integrity) => { 134 if validate_integrity(&integrity) { 135 Routes::CasGet(integrity) 136 } else { 137 Routes::BadRequest(format!("Invalid CAS hash format: {}", hash)) 138 } 139 } 140 Err(e) => Routes::BadRequest(format!("Invalid CAS hash: {}", e)), 141 } 142 } else { 143 Routes::NotFound 144 } 145 } 146 147 (&Method::POST, "/cas") => Routes::CasPost, 148 (&Method::POST, "/import") => Routes::Import, 149 150 (&Method::GET, p) => match Scru128Id::from_str(p.trim_start_matches('/')) { 151 Ok(id) => Routes::StreamItemGet(id), 152 Err(e) => Routes::BadRequest(format!("Invalid frame ID: {}", e)), 153 }, 154 155 (&Method::DELETE, p) => match Scru128Id::from_str(p.trim_start_matches('/')) { 156 Ok(id) => Routes::StreamItemRemove(id), 157 Err(e) => Routes::BadRequest(format!("Invalid frame ID: {}", e)), 158 }, 159 160 (&Method::POST, path) if path.starts_with('/') => { 161 let topic = path.trim_start_matches('/').to_string(); 162 let context_id = match params.get("context") { 163 None => crate::store::ZERO_CONTEXT, 164 Some(ctx) => match ctx.parse() { 165 Ok(id) => id, 166 Err(e) => return Routes::BadRequest(format!("Invalid context ID: {}", e)), 167 }, 168 }; 169 170 match TTL::from_query(query) { 171 Ok(ttl) => Routes::StreamAppend { 172 topic, 173 ttl: Some(ttl), 174 context_id, 175 }, 176 Err(e) => Routes::BadRequest(e.to_string()), 177 } 178 } 179 180 _ => Routes::NotFound, 181 } 182} 183 184async fn handle( 185 mut store: Store, 186 _engine: nu::Engine, // TODO: potentially vestigial, will .process come back? 187 req: Request<hyper::body::Incoming>, 188) -> HTTPResult { 189 let method = req.method(); 190 let path = req.uri().path(); 191 let headers = req.headers().clone(); 192 let query = req.uri().query(); 193 194 let res = match match_route(method, path, &headers, query) { 195 Routes::Version => handle_version().await, 196 197 Routes::StreamCat { 198 accept_type, 199 options, 200 } => handle_stream_cat(&mut store, options, accept_type).await, 201 202 Routes::StreamAppend { 203 topic, 204 ttl, 205 context_id, 206 } => handle_stream_append(&mut store, req, topic, ttl, context_id).await, 207 208 Routes::CasGet(hash) => { 209 let reader = store.cas_reader(hash).await?; 210 let stream = ReaderStream::new(reader); 211 212 let stream = stream.map(|frame| { 213 let frame = frame.unwrap(); 214 Ok(hyper::body::Frame::data(frame)) 215 }); 216 217 let body = StreamBody::new(stream).boxed(); 218 Ok(Response::new(body)) 219 } 220 221 Routes::CasPost => handle_cas_post(&mut store, req.into_body()).await, 222 223 Routes::StreamItemGet(id) => response_frame_or_404(store.get(&id)), 224 225 Routes::StreamItemRemove(id) => handle_stream_item_remove(&mut store, id).await, 226 227 Routes::HeadGet { 228 topic, 229 follow, 230 context_id, 231 } => handle_head_get(&store, &topic, follow, context_id).await, 232 233 Routes::Import => handle_import(&mut store, req.into_body()).await, 234 235 Routes::NotFound => response_404(), 236 Routes::BadRequest(msg) => response_400(msg), 237 }; 238 239 res.or_else(|e| response_500(e.to_string())) 240} 241 242async fn handle_stream_cat( 243 store: &mut Store, 244 options: ReadOptions, 245 accept_type: AcceptType, 246) -> HTTPResult { 247 let rx = store.read(options).await; 248 let stream = ReceiverStream::new(rx); 249 250 let accept_type_clone = accept_type.clone(); 251 let stream = stream.map(move |frame| { 252 let bytes = match accept_type_clone { 253 AcceptType::Ndjson => { 254 let mut encoded = serde_json::to_vec(&frame).unwrap(); 255 encoded.push(b'\n'); 256 encoded 257 } 258 AcceptType::EventStream => format!( 259 "id: {}\ndata: {}\n\n", 260 frame.id, 261 serde_json::to_string(&frame).unwrap_or_default() 262 ) 263 .into_bytes(), 264 }; 265 Ok(hyper::body::Frame::data(Bytes::from(bytes))) 266 }); 267 268 let body = StreamBody::new(stream).boxed(); 269 270 let content_type = match accept_type { 271 AcceptType::Ndjson => "application/x-ndjson", 272 AcceptType::EventStream => "text/event-stream", 273 }; 274 275 Ok(Response::builder() 276 .status(StatusCode::OK) 277 .header("Content-Type", content_type) 278 .body(body)?) 279} 280 281async fn handle_stream_append( 282 store: &mut Store, 283 req: Request<hyper::body::Incoming>, 284 topic: String, 285 ttl: Option<TTL>, 286 context_id: Scru128Id, 287) -> HTTPResult { 288 let (parts, mut body) = req.into_parts(); 289 290 let hash = { 291 let mut writer = store.cas_writer().await?; 292 let mut bytes_written = 0; 293 294 while let Some(frame) = body.frame().await { 295 if let Ok(data) = frame?.into_data() { 296 writer.write_all(&data).await?; 297 bytes_written += data.len(); 298 } 299 } 300 301 if bytes_written > 0 { 302 Some(writer.commit().await?) 303 } else { 304 None 305 } 306 }; 307 308 let meta = match parts 309 .headers 310 .get("xs-meta") 311 .map(|x| x.to_str()) 312 .transpose() 313 .unwrap() 314 .map(|s| { 315 // First decode the Base64-encoded string 316 base64::prelude::BASE64_STANDARD 317 .decode(s) 318 .map_err(|e| format!("xs-meta isn't valid Base64: {}", e)) 319 .and_then(|decoded| { 320 // Then parse the decoded bytes as UTF-8 string 321 String::from_utf8(decoded) 322 .map_err(|e| format!("xs-meta isn't valid UTF-8: {}", e)) 323 .and_then(|json_str| { 324 // Finally parse the UTF-8 string as JSON 325 serde_json::from_str(&json_str) 326 .map_err(|e| format!("xs-meta isn't valid JSON: {}", e)) 327 }) 328 }) 329 }) 330 .transpose() 331 { 332 Ok(meta) => meta, 333 Err(e) => return response_400(e.to_string()), 334 }; 335 336 let frame = store.append( 337 Frame::builder(topic, context_id) 338 .maybe_hash(hash) 339 .maybe_meta(meta) 340 .maybe_ttl(ttl) 341 .build(), 342 )?; 343 344 Ok(Response::builder() 345 .status(StatusCode::OK) 346 .header("Content-Type", "application/json") 347 .body(full(serde_json::to_string(&frame).unwrap()))?) 348} 349 350async fn handle_cas_post(store: &mut Store, mut body: hyper::body::Incoming) -> HTTPResult { 351 let hash = { 352 let mut writer = store.cas_writer().await?; 353 let mut bytes_written = 0; 354 355 while let Some(frame) = body.frame().await { 356 if let Ok(data) = frame?.into_data() { 357 writer.write_all(&data).await?; 358 bytes_written += data.len(); 359 } 360 } 361 362 if bytes_written == 0 { 363 return response_400("Empty body".to_string()); 364 } 365 366 writer.commit().await? 367 }; 368 369 Ok(Response::builder() 370 .status(StatusCode::OK) 371 .header("Content-Type", "text/plain") 372 .body(full(hash.to_string()))?) 373} 374 375async fn handle_version() -> HTTPResult { 376 let version = env!("CARGO_PKG_VERSION"); 377 let version_info = serde_json::json!({ "version": version }); 378 Ok(Response::builder() 379 .status(StatusCode::OK) 380 .header("Content-Type", "application/json") 381 .body(full(serde_json::to_string(&version_info).unwrap()))?) 382} 383 384pub async fn serve( 385 store: Store, 386 engine: nu::Engine, 387 expose: Option<String>, 388) -> Result<(), BoxError> { 389 if let Err(e) = store.append( 390 Frame::builder("xs.start", store::ZERO_CONTEXT) 391 .maybe_meta(expose.as_ref().map(|e| serde_json::json!({"expose": e}))) 392 .build(), 393 ) { 394 tracing::error!("Failed to append xs.start frame: {}", e); 395 } 396 397 let path = store.path.join("sock").to_string_lossy().to_string(); 398 let listener = Listener::bind(&path).await?; 399 400 let mut listeners = vec![listener]; 401 402 if let Some(expose) = expose { 403 listeners.push(Listener::bind(&expose).await?); 404 } 405 406 let mut tasks = Vec::new(); 407 for listener in listeners { 408 let store = store.clone(); 409 let engine = engine.clone(); 410 let task = tokio::spawn(async move { listener_loop(listener, store, engine).await }); 411 tasks.push(task); 412 } 413 414 // TODO: graceful shutdown and error handling 415 // Wait for all listener tasks to complete (or until the first error) 416 for task in tasks { 417 task.await??; 418 } 419 420 Ok(()) 421} 422 423async fn listener_loop( 424 mut listener: Listener, 425 store: Store, 426 engine: nu::Engine, 427) -> Result<(), BoxError> { 428 loop { 429 let (stream, _) = listener.accept().await?; 430 let io = TokioIo::new(stream); 431 let store = store.clone(); 432 let engine = engine.clone(); 433 tokio::task::spawn(async move { 434 if let Err(err) = http1::Builder::new() 435 .serve_connection( 436 io, 437 service_fn(move |req| handle(store.clone(), engine.clone(), req)), 438 ) 439 .await 440 { 441 // Match against the error kind to selectively ignore `NotConnected` errors 442 if let Some(std::io::ErrorKind::NotConnected) = err.source().and_then(|source| { 443 source 444 .downcast_ref::<std::io::Error>() 445 .map(|io_err| io_err.kind()) 446 }) { 447 // ignore the NotConnected error, hyper's way of saying the client disconnected 448 } else { 449 // todo, Handle or log other errors 450 tracing::error!("TBD: {:?}", err); 451 } 452 } 453 }); 454 } 455} 456 457fn response_frame_or_404(frame: Option<store::Frame>) -> HTTPResult { 458 if let Some(frame) = frame { 459 Ok(Response::builder() 460 .status(StatusCode::OK) 461 .header("Content-Type", "application/json") 462 .body(full(serde_json::to_string(&frame).unwrap()))?) 463 } else { 464 response_404() 465 } 466} 467 468async fn handle_stream_item_remove(store: &mut Store, id: Scru128Id) -> HTTPResult { 469 match store.remove(&id) { 470 Ok(()) => Ok(Response::builder() 471 .status(StatusCode::NO_CONTENT) 472 .body(empty())?), 473 Err(e) => { 474 tracing::error!("Failed to remove item {}: {:?}", id, e); 475 476 Ok(Response::builder() 477 .status(StatusCode::INTERNAL_SERVER_ERROR) 478 .body(full("internal-error"))?) 479 } 480 } 481} 482 483async fn handle_head_get( 484 store: &Store, 485 topic: &str, 486 follow: bool, 487 context_id: Scru128Id, 488) -> HTTPResult { 489 let current_head = store.head(topic, context_id); 490 491 if !follow { 492 return response_frame_or_404(current_head); 493 } 494 495 let rx = store 496 .read( 497 ReadOptions::builder() 498 .follow(FollowOption::On) 499 .tail(true) 500 .maybe_last_id(current_head.as_ref().map(|f| f.id)) 501 .build(), 502 ) 503 .await; 504 505 let topic = topic.to_string(); 506 let stream = tokio_stream::wrappers::ReceiverStream::new(rx) 507 .filter(move |frame| frame.topic == topic) 508 .map(|frame| { 509 let mut bytes = serde_json::to_vec(&frame).unwrap(); 510 bytes.push(b'\n'); 511 Ok::<_, BoxError>(hyper::body::Frame::data(Bytes::from(bytes))) 512 }); 513 514 let body = if let Some(frame) = current_head { 515 let mut head_bytes = serde_json::to_vec(&frame).unwrap(); 516 head_bytes.push(b'\n'); 517 let head_chunk = Ok(hyper::body::Frame::data(Bytes::from(head_bytes))); 518 StreamBody::new(futures::stream::once(async { head_chunk }).chain(stream)).boxed() 519 } else { 520 StreamBody::new(stream).boxed() 521 }; 522 523 Ok(Response::builder() 524 .status(StatusCode::OK) 525 .header("Content-Type", "application/x-ndjson") 526 .body(body)?) 527} 528 529async fn handle_import(store: &mut Store, body: hyper::body::Incoming) -> HTTPResult { 530 let bytes = body.collect().await?.to_bytes(); 531 let frame: Frame = match serde_json::from_slice(&bytes) { 532 Ok(frame) => frame, 533 Err(e) => return response_400(format!("Invalid frame JSON: {}", e)), 534 }; 535 536 store.insert_frame(&frame)?; 537 538 Ok(Response::builder() 539 .status(StatusCode::OK) 540 .header("Content-Type", "application/json") 541 .body(full(serde_json::to_string(&frame).unwrap()))?) 542} 543 544fn response_404() -> HTTPResult { 545 Ok(Response::builder() 546 .status(StatusCode::NOT_FOUND) 547 .body(empty())?) 548} 549 550fn response_400(message: String) -> HTTPResult { 551 let body = full(message); 552 Ok(Response::builder() 553 .status(StatusCode::BAD_REQUEST) 554 .body(body)?) 555} 556 557fn response_500(message: String) -> HTTPResult { 558 let body = full(message); 559 Ok(Response::builder() 560 .status(StatusCode::INTERNAL_SERVER_ERROR) 561 .body(body)?) 562} 563 564fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, BoxError> { 565 Full::new(chunk.into()) 566 .map_err(|never| match never {}) 567 .boxed() 568} 569 570fn empty() -> BoxBody<Bytes, BoxError> { 571 Empty::<Bytes>::new() 572 .map_err(|never| match never {}) 573 .boxed() 574} 575 576#[cfg(test)] 577mod tests { 578 use super::*; 579 580 #[test] 581 fn test_match_route_head_follow() { 582 let headers = hyper::HeaderMap::new(); 583 584 assert!(matches!( 585 match_route(&Method::GET, "/head/test", &headers, None), 586 Routes::HeadGet { topic, follow: false, context_id: _ } if topic == "test" 587 )); 588 589 assert!(matches!( 590 match_route(&Method::GET, "/head/test", &headers, Some("follow=true")), 591 Routes::HeadGet { topic, follow: true, context_id: _ } if topic == "test" 592 )); 593 } 594}