crossing the streams
1use std::collections::HashMap;
2use std::error::Error;
3use std::str::FromStr;
4
5use scru128::Scru128Id;
6
7use base64::Engine;
8
9use tokio::io::AsyncWriteExt;
10use tokio_stream::wrappers::ReceiverStream;
11use tokio_stream::StreamExt;
12use tokio_util::io::ReaderStream;
13
14use http_body_util::StreamBody;
15use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
16use hyper::body::Bytes;
17use hyper::header::ACCEPT;
18use hyper::server::conn::http1;
19use hyper::service::service_fn;
20use hyper::{Method, Request, Response, StatusCode};
21use hyper_util::rt::TokioIo;
22
23use crate::listener::Listener;
24use crate::nu;
25use crate::store::{self, FollowOption, Frame, ReadOptions, Store, TTL};
26
27type BoxError = Box<dyn std::error::Error + Send + Sync>;
28type HTTPResult = Result<Response<BoxBody<Bytes, BoxError>>, BoxError>;
29
/// Wire format for the `GET /` frame stream, negotiated from the `Accept` header.
///
/// The enum has no fields, so it derives `Copy` (and `Eq`). Callers can pass it
/// by value without cloning.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum AcceptType {
    /// Newline-delimited JSON (`application/x-ndjson`); used unless SSE is requested.
    Ndjson,
    /// Server-sent events (`text/event-stream`).
    EventStream,
}
35
36enum Routes {
37 StreamCat {
38 accept_type: AcceptType,
39 options: ReadOptions,
40 },
41 StreamAppend {
42 topic: String,
43 ttl: Option<TTL>,
44 context_id: Scru128Id,
45 },
46 HeadGet {
47 topic: String,
48 follow: bool,
49 context_id: Scru128Id,
50 },
51 StreamItemGet(Scru128Id),
52 StreamItemRemove(Scru128Id),
53 CasGet(ssri::Integrity),
54 CasPost,
55 Import,
56 Version,
57 NotFound,
58 BadRequest(String),
59}
60
61/// Validates an Integrity object to ensure all its hashes are properly formatted
62fn validate_integrity(integrity: &ssri::Integrity) -> bool {
63 // Check if there are any hashes
64 if integrity.hashes.is_empty() {
65 return false;
66 }
67
68 // For each hash, check if it has a valid base64-encoded digest
69 for hash in &integrity.hashes {
70 // Check if digest is valid base64 using the modern API
71 if base64::engine::general_purpose::STANDARD
72 .decode(&hash.digest)
73 .is_err()
74 {
75 return false;
76 }
77 }
78
79 true
80}
81
82fn match_route(
83 method: &Method,
84 path: &str,
85 headers: &hyper::HeaderMap,
86 query: Option<&str>,
87) -> Routes {
88 let params: HashMap<String, String> =
89 url::form_urlencoded::parse(query.unwrap_or("").as_bytes())
90 .into_owned()
91 .collect();
92
93 match (method, path) {
94 (&Method::GET, "/version") => Routes::Version,
95
96 (&Method::GET, "/") => {
97 let accept_type = match headers.get(ACCEPT) {
98 Some(accept) if accept == "text/event-stream" => AcceptType::EventStream,
99 _ => AcceptType::Ndjson,
100 };
101
102 let options = ReadOptions::from_query(query);
103
104 match options {
105 Ok(options) => Routes::StreamCat {
106 accept_type,
107 options,
108 },
109 Err(e) => Routes::BadRequest(e.to_string()),
110 }
111 }
112
113 (&Method::GET, p) if p.starts_with("/head/") => {
114 let topic = p.strip_prefix("/head/").unwrap().to_string();
115 let follow = params.contains_key("follow");
116 let context_id = match params.get("context") {
117 None => crate::store::ZERO_CONTEXT,
118 Some(ctx) => match ctx.parse() {
119 Ok(id) => id,
120 Err(e) => return Routes::BadRequest(format!("Invalid context ID: {}", e)),
121 },
122 };
123 Routes::HeadGet {
124 topic,
125 follow,
126 context_id,
127 }
128 }
129
130 (&Method::GET, p) if p.starts_with("/cas/") => {
131 if let Some(hash) = p.strip_prefix("/cas/") {
132 match ssri::Integrity::from_str(hash) {
133 Ok(integrity) => {
134 if validate_integrity(&integrity) {
135 Routes::CasGet(integrity)
136 } else {
137 Routes::BadRequest(format!("Invalid CAS hash format: {}", hash))
138 }
139 }
140 Err(e) => Routes::BadRequest(format!("Invalid CAS hash: {}", e)),
141 }
142 } else {
143 Routes::NotFound
144 }
145 }
146
147 (&Method::POST, "/cas") => Routes::CasPost,
148 (&Method::POST, "/import") => Routes::Import,
149
150 (&Method::GET, p) => match Scru128Id::from_str(p.trim_start_matches('/')) {
151 Ok(id) => Routes::StreamItemGet(id),
152 Err(e) => Routes::BadRequest(format!("Invalid frame ID: {}", e)),
153 },
154
155 (&Method::DELETE, p) => match Scru128Id::from_str(p.trim_start_matches('/')) {
156 Ok(id) => Routes::StreamItemRemove(id),
157 Err(e) => Routes::BadRequest(format!("Invalid frame ID: {}", e)),
158 },
159
160 (&Method::POST, path) if path.starts_with('/') => {
161 let topic = path.trim_start_matches('/').to_string();
162 let context_id = match params.get("context") {
163 None => crate::store::ZERO_CONTEXT,
164 Some(ctx) => match ctx.parse() {
165 Ok(id) => id,
166 Err(e) => return Routes::BadRequest(format!("Invalid context ID: {}", e)),
167 },
168 };
169
170 match TTL::from_query(query) {
171 Ok(ttl) => Routes::StreamAppend {
172 topic,
173 ttl: Some(ttl),
174 context_id,
175 },
176 Err(e) => Routes::BadRequest(e.to_string()),
177 }
178 }
179
180 _ => Routes::NotFound,
181 }
182}
183
184async fn handle(
185 mut store: Store,
186 _engine: nu::Engine, // TODO: potentially vestigial, will .process come back?
187 req: Request<hyper::body::Incoming>,
188) -> HTTPResult {
189 let method = req.method();
190 let path = req.uri().path();
191 let headers = req.headers().clone();
192 let query = req.uri().query();
193
194 let res = match match_route(method, path, &headers, query) {
195 Routes::Version => handle_version().await,
196
197 Routes::StreamCat {
198 accept_type,
199 options,
200 } => handle_stream_cat(&mut store, options, accept_type).await,
201
202 Routes::StreamAppend {
203 topic,
204 ttl,
205 context_id,
206 } => handle_stream_append(&mut store, req, topic, ttl, context_id).await,
207
208 Routes::CasGet(hash) => {
209 let reader = store.cas_reader(hash).await?;
210 let stream = ReaderStream::new(reader);
211
212 let stream = stream.map(|frame| {
213 let frame = frame.unwrap();
214 Ok(hyper::body::Frame::data(frame))
215 });
216
217 let body = StreamBody::new(stream).boxed();
218 Ok(Response::new(body))
219 }
220
221 Routes::CasPost => handle_cas_post(&mut store, req.into_body()).await,
222
223 Routes::StreamItemGet(id) => response_frame_or_404(store.get(&id)),
224
225 Routes::StreamItemRemove(id) => handle_stream_item_remove(&mut store, id).await,
226
227 Routes::HeadGet {
228 topic,
229 follow,
230 context_id,
231 } => handle_head_get(&store, &topic, follow, context_id).await,
232
233 Routes::Import => handle_import(&mut store, req.into_body()).await,
234
235 Routes::NotFound => response_404(),
236 Routes::BadRequest(msg) => response_400(msg),
237 };
238
239 res.or_else(|e| response_500(e.to_string()))
240}
241
242async fn handle_stream_cat(
243 store: &mut Store,
244 options: ReadOptions,
245 accept_type: AcceptType,
246) -> HTTPResult {
247 let rx = store.read(options).await;
248 let stream = ReceiverStream::new(rx);
249
250 let accept_type_clone = accept_type.clone();
251 let stream = stream.map(move |frame| {
252 let bytes = match accept_type_clone {
253 AcceptType::Ndjson => {
254 let mut encoded = serde_json::to_vec(&frame).unwrap();
255 encoded.push(b'\n');
256 encoded
257 }
258 AcceptType::EventStream => format!(
259 "id: {}\ndata: {}\n\n",
260 frame.id,
261 serde_json::to_string(&frame).unwrap_or_default()
262 )
263 .into_bytes(),
264 };
265 Ok(hyper::body::Frame::data(Bytes::from(bytes)))
266 });
267
268 let body = StreamBody::new(stream).boxed();
269
270 let content_type = match accept_type {
271 AcceptType::Ndjson => "application/x-ndjson",
272 AcceptType::EventStream => "text/event-stream",
273 };
274
275 Ok(Response::builder()
276 .status(StatusCode::OK)
277 .header("Content-Type", content_type)
278 .body(body)?)
279}
280
281async fn handle_stream_append(
282 store: &mut Store,
283 req: Request<hyper::body::Incoming>,
284 topic: String,
285 ttl: Option<TTL>,
286 context_id: Scru128Id,
287) -> HTTPResult {
288 let (parts, mut body) = req.into_parts();
289
290 let hash = {
291 let mut writer = store.cas_writer().await?;
292 let mut bytes_written = 0;
293
294 while let Some(frame) = body.frame().await {
295 if let Ok(data) = frame?.into_data() {
296 writer.write_all(&data).await?;
297 bytes_written += data.len();
298 }
299 }
300
301 if bytes_written > 0 {
302 Some(writer.commit().await?)
303 } else {
304 None
305 }
306 };
307
308 let meta = match parts
309 .headers
310 .get("xs-meta")
311 .map(|x| x.to_str())
312 .transpose()
313 .unwrap()
314 .map(|s| {
315 // First decode the Base64-encoded string
316 base64::prelude::BASE64_STANDARD
317 .decode(s)
318 .map_err(|e| format!("xs-meta isn't valid Base64: {}", e))
319 .and_then(|decoded| {
320 // Then parse the decoded bytes as UTF-8 string
321 String::from_utf8(decoded)
322 .map_err(|e| format!("xs-meta isn't valid UTF-8: {}", e))
323 .and_then(|json_str| {
324 // Finally parse the UTF-8 string as JSON
325 serde_json::from_str(&json_str)
326 .map_err(|e| format!("xs-meta isn't valid JSON: {}", e))
327 })
328 })
329 })
330 .transpose()
331 {
332 Ok(meta) => meta,
333 Err(e) => return response_400(e.to_string()),
334 };
335
336 let frame = store.append(
337 Frame::builder(topic, context_id)
338 .maybe_hash(hash)
339 .maybe_meta(meta)
340 .maybe_ttl(ttl)
341 .build(),
342 )?;
343
344 Ok(Response::builder()
345 .status(StatusCode::OK)
346 .header("Content-Type", "application/json")
347 .body(full(serde_json::to_string(&frame).unwrap()))?)
348}
349
350async fn handle_cas_post(store: &mut Store, mut body: hyper::body::Incoming) -> HTTPResult {
351 let hash = {
352 let mut writer = store.cas_writer().await?;
353 let mut bytes_written = 0;
354
355 while let Some(frame) = body.frame().await {
356 if let Ok(data) = frame?.into_data() {
357 writer.write_all(&data).await?;
358 bytes_written += data.len();
359 }
360 }
361
362 if bytes_written == 0 {
363 return response_400("Empty body".to_string());
364 }
365
366 writer.commit().await?
367 };
368
369 Ok(Response::builder()
370 .status(StatusCode::OK)
371 .header("Content-Type", "text/plain")
372 .body(full(hash.to_string()))?)
373}
374
375async fn handle_version() -> HTTPResult {
376 let version = env!("CARGO_PKG_VERSION");
377 let version_info = serde_json::json!({ "version": version });
378 Ok(Response::builder()
379 .status(StatusCode::OK)
380 .header("Content-Type", "application/json")
381 .body(full(serde_json::to_string(&version_info).unwrap()))?)
382}
383
384pub async fn serve(
385 store: Store,
386 engine: nu::Engine,
387 expose: Option<String>,
388) -> Result<(), BoxError> {
389 if let Err(e) = store.append(
390 Frame::builder("xs.start", store::ZERO_CONTEXT)
391 .maybe_meta(expose.as_ref().map(|e| serde_json::json!({"expose": e})))
392 .build(),
393 ) {
394 tracing::error!("Failed to append xs.start frame: {}", e);
395 }
396
397 let path = store.path.join("sock").to_string_lossy().to_string();
398 let listener = Listener::bind(&path).await?;
399
400 let mut listeners = vec![listener];
401
402 if let Some(expose) = expose {
403 listeners.push(Listener::bind(&expose).await?);
404 }
405
406 let mut tasks = Vec::new();
407 for listener in listeners {
408 let store = store.clone();
409 let engine = engine.clone();
410 let task = tokio::spawn(async move { listener_loop(listener, store, engine).await });
411 tasks.push(task);
412 }
413
414 // TODO: graceful shutdown and error handling
415 // Wait for all listener tasks to complete (or until the first error)
416 for task in tasks {
417 task.await??;
418 }
419
420 Ok(())
421}
422
423async fn listener_loop(
424 mut listener: Listener,
425 store: Store,
426 engine: nu::Engine,
427) -> Result<(), BoxError> {
428 loop {
429 let (stream, _) = listener.accept().await?;
430 let io = TokioIo::new(stream);
431 let store = store.clone();
432 let engine = engine.clone();
433 tokio::task::spawn(async move {
434 if let Err(err) = http1::Builder::new()
435 .serve_connection(
436 io,
437 service_fn(move |req| handle(store.clone(), engine.clone(), req)),
438 )
439 .await
440 {
441 // Match against the error kind to selectively ignore `NotConnected` errors
442 if let Some(std::io::ErrorKind::NotConnected) = err.source().and_then(|source| {
443 source
444 .downcast_ref::<std::io::Error>()
445 .map(|io_err| io_err.kind())
446 }) {
447 // ignore the NotConnected error, hyper's way of saying the client disconnected
448 } else {
449 // todo, Handle or log other errors
450 tracing::error!("TBD: {:?}", err);
451 }
452 }
453 });
454 }
455}
456
457fn response_frame_or_404(frame: Option<store::Frame>) -> HTTPResult {
458 if let Some(frame) = frame {
459 Ok(Response::builder()
460 .status(StatusCode::OK)
461 .header("Content-Type", "application/json")
462 .body(full(serde_json::to_string(&frame).unwrap()))?)
463 } else {
464 response_404()
465 }
466}
467
468async fn handle_stream_item_remove(store: &mut Store, id: Scru128Id) -> HTTPResult {
469 match store.remove(&id) {
470 Ok(()) => Ok(Response::builder()
471 .status(StatusCode::NO_CONTENT)
472 .body(empty())?),
473 Err(e) => {
474 tracing::error!("Failed to remove item {}: {:?}", id, e);
475
476 Ok(Response::builder()
477 .status(StatusCode::INTERNAL_SERVER_ERROR)
478 .body(full("internal-error"))?)
479 }
480 }
481}
482
483async fn handle_head_get(
484 store: &Store,
485 topic: &str,
486 follow: bool,
487 context_id: Scru128Id,
488) -> HTTPResult {
489 let current_head = store.head(topic, context_id);
490
491 if !follow {
492 return response_frame_or_404(current_head);
493 }
494
495 let rx = store
496 .read(
497 ReadOptions::builder()
498 .follow(FollowOption::On)
499 .tail(true)
500 .maybe_last_id(current_head.as_ref().map(|f| f.id))
501 .build(),
502 )
503 .await;
504
505 let topic = topic.to_string();
506 let stream = tokio_stream::wrappers::ReceiverStream::new(rx)
507 .filter(move |frame| frame.topic == topic)
508 .map(|frame| {
509 let mut bytes = serde_json::to_vec(&frame).unwrap();
510 bytes.push(b'\n');
511 Ok::<_, BoxError>(hyper::body::Frame::data(Bytes::from(bytes)))
512 });
513
514 let body = if let Some(frame) = current_head {
515 let mut head_bytes = serde_json::to_vec(&frame).unwrap();
516 head_bytes.push(b'\n');
517 let head_chunk = Ok(hyper::body::Frame::data(Bytes::from(head_bytes)));
518 StreamBody::new(futures::stream::once(async { head_chunk }).chain(stream)).boxed()
519 } else {
520 StreamBody::new(stream).boxed()
521 };
522
523 Ok(Response::builder()
524 .status(StatusCode::OK)
525 .header("Content-Type", "application/x-ndjson")
526 .body(body)?)
527}
528
529async fn handle_import(store: &mut Store, body: hyper::body::Incoming) -> HTTPResult {
530 let bytes = body.collect().await?.to_bytes();
531 let frame: Frame = match serde_json::from_slice(&bytes) {
532 Ok(frame) => frame,
533 Err(e) => return response_400(format!("Invalid frame JSON: {}", e)),
534 };
535
536 store.insert_frame(&frame)?;
537
538 Ok(Response::builder()
539 .status(StatusCode::OK)
540 .header("Content-Type", "application/json")
541 .body(full(serde_json::to_string(&frame).unwrap()))?)
542}
543
544fn response_404() -> HTTPResult {
545 Ok(Response::builder()
546 .status(StatusCode::NOT_FOUND)
547 .body(empty())?)
548}
549
550fn response_400(message: String) -> HTTPResult {
551 let body = full(message);
552 Ok(Response::builder()
553 .status(StatusCode::BAD_REQUEST)
554 .body(body)?)
555}
556
557fn response_500(message: String) -> HTTPResult {
558 let body = full(message);
559 Ok(Response::builder()
560 .status(StatusCode::INTERNAL_SERVER_ERROR)
561 .body(body)?)
562}
563
564fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, BoxError> {
565 Full::new(chunk.into())
566 .map_err(|never| match never {})
567 .boxed()
568}
569
570fn empty() -> BoxBody<Bytes, BoxError> {
571 Empty::<Bytes>::new()
572 .map_err(|never| match never {})
573 .boxed()
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_match_route_head_follow() {
        let headers = hyper::HeaderMap::new();

        assert!(matches!(
            match_route(&Method::GET, "/head/test", &headers, None),
            Routes::HeadGet { topic, follow: false, context_id: _ } if topic == "test"
        ));

        assert!(matches!(
            match_route(&Method::GET, "/head/test", &headers, Some("follow=true")),
            Routes::HeadGet { topic, follow: true, context_id: _ } if topic == "test"
        ));
    }

    #[test]
    fn test_match_route_head_context() {
        let headers = hyper::HeaderMap::new();
        let query = format!("context={}", crate::store::ZERO_CONTEXT);

        assert!(matches!(
            match_route(&Method::GET, "/head/test", &headers, Some(&query)),
            Routes::HeadGet { context_id, .. } if context_id == crate::store::ZERO_CONTEXT
        ));

        assert!(matches!(
            match_route(&Method::GET, "/head/test", &headers, Some("context=not-an-id")),
            Routes::BadRequest(msg) if msg.starts_with("Invalid context ID")
        ));
    }

    #[test]
    fn test_match_route_static_routes() {
        let headers = hyper::HeaderMap::new();

        assert!(matches!(
            match_route(&Method::GET, "/version", &headers, None),
            Routes::Version
        ));
        assert!(matches!(
            match_route(&Method::POST, "/cas", &headers, None),
            Routes::CasPost
        ));
        assert!(matches!(
            match_route(&Method::POST, "/import", &headers, None),
            Routes::Import
        ));
        assert!(matches!(
            match_route(&Method::PUT, "/", &headers, None),
            Routes::NotFound
        ));
    }

    #[test]
    fn test_match_route_rejects_malformed_ids() {
        let headers = hyper::HeaderMap::new();

        assert!(matches!(
            match_route(&Method::GET, "/not-a-frame-id", &headers, None),
            Routes::BadRequest(msg) if msg.starts_with("Invalid frame ID")
        ));
        assert!(matches!(
            match_route(&Method::DELETE, "/not-a-frame-id", &headers, None),
            Routes::BadRequest(msg) if msg.starts_with("Invalid frame ID")
        ));
        assert!(matches!(
            match_route(&Method::GET, "/cas/sha256-!!!", &headers, None),
            Routes::BadRequest(_)
        ));
    }

    #[test]
    fn test_match_route_stream_cat_accept() {
        let mut headers = hyper::HeaderMap::new();

        assert!(matches!(
            match_route(&Method::GET, "/", &headers, None),
            Routes::StreamCat { accept_type: AcceptType::Ndjson, .. }
        ));

        headers.insert(
            ACCEPT,
            hyper::header::HeaderValue::from_static("text/event-stream"),
        );
        assert!(matches!(
            match_route(&Method::GET, "/", &headers, None),
            Routes::StreamCat { accept_type: AcceptType::EventStream, .. }
        ));
    }

    #[test]
    fn test_match_route_stream_append() {
        let headers = hyper::HeaderMap::new();

        assert!(matches!(
            match_route(&Method::POST, "/my-topic", &headers, None),
            Routes::StreamAppend { topic, ttl: Some(_), .. } if topic == "my-topic"
        ));
    }
}
594}