// Personal Activity Index (Bluesky, Leaflet, Substack) — pai.desertthunder.dev
// Topics: RSS, Bluesky. Source view: branch `main`, 462 lines, 14 kB.
1use crate::storage::SqliteStorage; 2 3use axum::{ 4 extract::{Path, Query, Request, State}, 5 http::{header, HeaderValue, Method, StatusCode}, 6 middleware::{self, Next}, 7 response::{IntoResponse, Response}, 8 routing::get, 9 Json, Router, 10}; 11use chrono::DateTime; 12use owo_colors::OwoColorize; 13use pai_core::{Config, CorsConfig, Item, ListFilter, PaiError, SourceKind}; 14use rss::{Channel, ChannelBuilder, ItemBuilder}; 15use serde::{Deserialize, Serialize}; 16use std::{io, net::SocketAddr, path::PathBuf, sync::Arc, time::Instant}; 17use tokio::net::TcpListener; 18 19const DEFAULT_LIMIT: usize = 20; 20const VERSION: &str = env!("CARGO_PKG_VERSION"); 21 22/// Launches the HTTP server using the provided config and address. 23pub fn serve(config: Config, db_path: PathBuf, address: &str) -> Result<(), PaiError> { 24 let addr: SocketAddr = address 25 .parse() 26 .map_err(|e| PaiError::Config(format!("Invalid listen address '{address}': {e}")))?; 27 28 let runtime = tokio::runtime::Builder::new_multi_thread() 29 .enable_all() 30 .build() 31 .map_err(PaiError::Io)?; 32 33 runtime.block_on(async move { run_server(config, db_path, addr).await }) 34} 35 36async fn run_server(config: Config, db_path: PathBuf, addr: SocketAddr) -> Result<(), PaiError> { 37 let storage = SqliteStorage::new(&db_path)?; 38 storage.verify_schema()?; 39 drop(storage); 40 41 let state = 42 AppState { db_path: Arc::new(db_path), start_time: Instant::now(), cors_config: Arc::new(config.cors.clone()) }; 43 44 let mut app = Router::new() 45 .route("/api/feed", get(feed_handler)) 46 .route("/api/item/:id", get(item_handler)) 47 .route("/status", get(status_handler)) 48 .route("/rss.xml", get(rss_handler)) 49 .with_state(state.clone()); 50 51 if !config.cors.allowed_origins.is_empty() || config.cors.dev_key.is_some() { 52 app = app.layer(middleware::from_fn_with_state(state.clone(), cors_middleware)); 53 } 54 55 let listener = TcpListener::bind(addr).await.map_err(PaiError::Io)?; 56 let local_addr 
= listener.local_addr().map_err(PaiError::Io)?; 57 println!("{} Listening on http://{}", "Info:".cyan(), local_addr); 58 59 axum::serve(listener, app.into_make_service()) 60 .with_graceful_shutdown(shutdown_signal()) 61 .await 62 .map_err(|e| io::Error::other(e).into()) 63} 64 65/// CORS middleware that validates origins and dev keys 66async fn cors_middleware(State(state): State<AppState>, request: Request, next: Next) -> Result<Response, StatusCode> { 67 let origin = request 68 .headers() 69 .get(header::ORIGIN) 70 .and_then(|v| v.to_str().ok()) 71 .map(|s| s.to_string()); 72 let dev_key = request 73 .headers() 74 .get("x-local-dev-key") 75 .and_then(|v| v.to_str().ok()) 76 .map(|s| s.to_string()); 77 let method = request.method().clone(); 78 79 let is_authorized = if let Some(ref key) = dev_key { 80 state.cors_config.is_dev_key_valid(Some(key)) 81 } else if let Some(ref origin_str) = origin { 82 state.cors_config.is_origin_allowed(origin_str) 83 } else { 84 true 85 }; 86 87 if method == Method::OPTIONS { 88 if !is_authorized { 89 return Err(StatusCode::FORBIDDEN); 90 } 91 92 let mut response = Response::new(String::new().into()); 93 if let Some(ref origin_str) = origin { 94 response.headers_mut().insert( 95 header::ACCESS_CONTROL_ALLOW_ORIGIN, 96 HeaderValue::from_str(origin_str).unwrap_or(HeaderValue::from_static("*")), 97 ); 98 } 99 response.headers_mut().insert( 100 header::ACCESS_CONTROL_ALLOW_METHODS, 101 HeaderValue::from_static("GET, POST, OPTIONS"), 102 ); 103 response.headers_mut().insert( 104 header::ACCESS_CONTROL_ALLOW_HEADERS, 105 HeaderValue::from_static("Content-Type, X-Local-Dev-Key"), 106 ); 107 response 108 .headers_mut() 109 .insert(header::ACCESS_CONTROL_MAX_AGE, HeaderValue::from_static("3600")); 110 return Ok(response); 111 } 112 113 if origin.is_some() && !is_authorized { 114 return Err(StatusCode::FORBIDDEN); 115 } 116 117 let mut response = next.run(request).await; 118 119 if let Some(ref origin_str) = origin { 120 if is_authorized { 121 
response.headers_mut().insert( 122 header::ACCESS_CONTROL_ALLOW_ORIGIN, 123 HeaderValue::from_str(origin_str).unwrap_or(HeaderValue::from_static("*")), 124 ); 125 response.headers_mut().insert( 126 header::ACCESS_CONTROL_ALLOW_CREDENTIALS, 127 HeaderValue::from_static("true"), 128 ); 129 } 130 } 131 132 Ok(response) 133} 134 135#[derive(Clone)] 136struct AppState { 137 db_path: Arc<PathBuf>, 138 start_time: Instant, 139 cors_config: Arc<CorsConfig>, 140} 141 142impl AppState { 143 fn open_storage(&self) -> Result<SqliteStorage, PaiError> { 144 SqliteStorage::new(self.db_path.as_ref()) 145 } 146 147 fn status_snapshot(&self) -> Result<StatusResponse, PaiError> { 148 let storage = self.open_storage()?; 149 let total_items = storage.count_items()?; 150 let sources = storage 151 .get_stats()? 152 .into_iter() 153 .map(|(kind, count)| SourceStat { kind, count }) 154 .collect(); 155 156 Ok(StatusResponse { 157 status: "ok", 158 version: VERSION, 159 uptime_seconds: self.start_time.elapsed().as_secs(), 160 database_path: self.db_path.display().to_string(), 161 total_items, 162 sources, 163 }) 164 } 165} 166 167#[derive(Debug, Default, Deserialize)] 168struct FeedQuery { 169 source_kind: Option<SourceKind>, 170 source_id: Option<String>, 171 limit: Option<usize>, 172 since: Option<String>, 173 q: Option<String>, 174} 175 176impl FeedQuery { 177 fn into_filter(self) -> Result<ListFilter, PaiError> { 178 let limit = match self.limit { 179 Some(value) => ensure_positive_limit(value)?, 180 None => DEFAULT_LIMIT, 181 }; 182 183 Ok(ListFilter { 184 source_kind: self.source_kind, 185 source_id: normalize_optional_string(self.source_id), 186 limit: Some(limit), 187 since: normalize_optional_string(self.since), 188 query: normalize_optional_string(self.q), 189 }) 190 } 191} 192 193#[derive(Serialize)] 194struct FeedResponse { 195 count: usize, 196 items: Vec<Item>, 197} 198 199#[derive(Serialize)] 200struct StatusResponse { 201 status: &'static str, 202 version: &'static str, 203 
uptime_seconds: u64, 204 database_path: String, 205 total_items: usize, 206 sources: Vec<SourceStat>, 207} 208 209#[derive(Serialize)] 210struct SourceStat { 211 kind: String, 212 count: usize, 213} 214 215async fn feed_handler( 216 State(state): State<AppState>, Query(query): Query<FeedQuery>, 217) -> Result<Json<FeedResponse>, ApiError> { 218 let filter = query.into_filter()?; 219 let storage = state.open_storage()?; 220 let items = pai_core::Storage::list_items(&storage, &filter)?; 221 222 Ok(Json(FeedResponse { count: items.len(), items })) 223} 224 225async fn item_handler(State(state): State<AppState>, Path(id): Path<String>) -> Result<Json<Item>, ApiError> { 226 let storage = state.open_storage()?; 227 let item = storage 228 .get_item(&id)? 229 .ok_or_else(|| ApiError::not_found(format!("Item '{id}' not found")))?; 230 231 Ok(Json(item)) 232} 233 234async fn status_handler(State(state): State<AppState>) -> Result<Json<StatusResponse>, ApiError> { 235 let snapshot = state.status_snapshot()?; 236 Ok(Json(snapshot)) 237} 238 239async fn rss_handler(State(state): State<AppState>, Query(query): Query<FeedQuery>) -> Result<RssResponse, ApiError> { 240 let filter = query.into_filter()?; 241 let storage = state.open_storage()?; 242 let items = pai_core::Storage::list_items(&storage, &filter)?; 243 244 let channel = build_rss_channel(&items)?; 245 Ok(RssResponse(channel)) 246} 247 248fn build_rss_channel(items: &[Item]) -> Result<Channel, PaiError> { 249 const TITLE: &str = "Personal Activity Index"; 250 const LINK: &str = "https://personal-activity-index.local/"; 251 const DESCRIPTION: &str = "Aggregated feed exported by the Personal Activity Index."; 252 253 let rss_items: Vec<rss::Item> = items 254 .iter() 255 .map(|item| { 256 let title = item 257 .title 258 .as_deref() 259 .or(item.summary.as_deref()) 260 .unwrap_or(&item.url) 261 .to_string(); 262 let description = item 263 .summary 264 .as_deref() 265 .or(item.content_html.as_deref()) 266 .unwrap_or("") 267 
.to_string(); 268 let author = item.author.as_deref().unwrap_or("Unknown").to_string(); 269 let pub_date = format_rss_date(&item.published_at); 270 271 ItemBuilder::default() 272 .title(Some(title)) 273 .link(Some(item.url.clone())) 274 .guid(Some( 275 rss::GuidBuilder::default().value(&item.id).permalink(false).build(), 276 )) 277 .pub_date(Some(pub_date)) 278 .author(Some(author)) 279 .description(Some(description)) 280 .categories(vec![rss::CategoryBuilder::default() 281 .name(item.source_kind.to_string()) 282 .build()]) 283 .build() 284 }) 285 .collect(); 286 287 let channel = ChannelBuilder::default() 288 .title(TITLE) 289 .link(LINK) 290 .description(DESCRIPTION) 291 .items(rss_items) 292 .build(); 293 294 Ok(channel) 295} 296 297fn format_rss_date(value: &str) -> String { 298 if let Ok(dt) = DateTime::parse_from_rfc3339(value) { 299 dt.to_rfc2822() 300 } else if let Ok(dt) = DateTime::parse_from_rfc2822(value) { 301 dt.to_rfc2822() 302 } else { 303 value.to_string() 304 } 305} 306 307struct RssResponse(Channel); 308 309impl IntoResponse for RssResponse { 310 fn into_response(self) -> Response { 311 let rss_string = self.0.to_string(); 312 ( 313 [(header::CONTENT_TYPE, "application/rss+xml; charset=utf-8")], 314 rss_string, 315 ) 316 .into_response() 317 } 318} 319 320struct ApiError { 321 status: StatusCode, 322 message: String, 323} 324 325impl ApiError { 326 fn bad_request(msg: impl Into<String>) -> Self { 327 Self { status: StatusCode::BAD_REQUEST, message: msg.into() } 328 } 329 330 fn not_found(msg: impl Into<String>) -> Self { 331 Self { status: StatusCode::NOT_FOUND, message: msg.into() } 332 } 333 334 fn internal(msg: impl Into<String>) -> Self { 335 Self { status: StatusCode::INTERNAL_SERVER_ERROR, message: msg.into() } 336 } 337} 338 339impl From<PaiError> for ApiError { 340 fn from(err: PaiError) -> Self { 341 match err { 342 PaiError::InvalidArgument(msg) => Self::bad_request(msg), 343 other => Self::internal(other.to_string()), 344 } 345 } 346} 
347 348#[derive(Serialize)] 349struct ErrorBody { 350 error: String, 351} 352 353impl IntoResponse for ApiError { 354 fn into_response(self) -> Response { 355 (self.status, Json(ErrorBody { error: self.message })).into_response() 356 } 357} 358 359async fn shutdown_signal() { 360 let _ = tokio::signal::ctrl_c().await; 361} 362 363fn ensure_positive_limit(limit: usize) -> Result<usize, PaiError> { 364 if limit == 0 { 365 return Err(PaiError::InvalidArgument("Limit must be greater than zero".to_string())); 366 } 367 Ok(limit) 368} 369 370fn normalize_optional_string(value: Option<String>) -> Option<String> { 371 value.and_then(|input| { 372 let trimmed = input.trim(); 373 if trimmed.is_empty() { 374 None 375 } else { 376 Some(trimmed.to_string()) 377 } 378 }) 379} 380 381#[cfg(test)] 382mod tests { 383 use super::*; 384 use chrono::Utc; 385 use pai_core::Storage; 386 use tempfile::tempdir; 387 388 #[test] 389 fn feed_query_defaults() { 390 let filter = FeedQuery::default().into_filter().unwrap(); 391 assert_eq!(filter.limit, Some(DEFAULT_LIMIT)); 392 assert!(filter.source_kind.is_none()); 393 assert!(filter.source_id.is_none()); 394 } 395 396 #[test] 397 fn feed_query_respects_parameters() { 398 let query = FeedQuery { 399 source_kind: Some(SourceKind::Bluesky), 400 source_id: Some(" desertthunder.dev ".to_string()), 401 limit: Some(5), 402 since: Some("2024-01-01T00:00:00Z".to_string()), 403 q: Some(" rust ".to_string()), 404 }; 405 406 let filter = query.into_filter().unwrap(); 407 assert_eq!(filter.limit, Some(5)); 408 assert_eq!(filter.source_kind, Some(SourceKind::Bluesky)); 409 assert_eq!(filter.source_id.as_deref(), Some("desertthunder.dev")); 410 assert_eq!(filter.query.as_deref(), Some("rust")); 411 assert_eq!(filter.since.as_deref(), Some("2024-01-01T00:00:00Z")); 412 } 413 414 #[test] 415 fn feed_query_rejects_zero_limit() { 416 let err = FeedQuery { limit: Some(0), ..Default::default() } 417 .into_filter() 418 .unwrap_err(); 419 assert!(matches!(err, 
PaiError::InvalidArgument(_))); 420 } 421 422 #[test] 423 fn api_error_into_response_sets_status() { 424 let resp = ApiError::bad_request("oops").into_response(); 425 assert_eq!(resp.status(), StatusCode::BAD_REQUEST); 426 } 427 428 #[test] 429 fn status_snapshot_reports_counts() { 430 let dir = tempdir().unwrap(); 431 let db_path = dir.path().join("status.db"); 432 let state = AppState { 433 db_path: Arc::new(db_path), 434 start_time: Instant::now(), 435 cors_config: Arc::new(pai_core::CorsConfig::default()), 436 }; 437 438 let storage = state.open_storage().unwrap(); 439 let now = Utc::now().to_rfc3339(); 440 let item = Item { 441 id: "status-test".to_string(), 442 source_kind: SourceKind::Substack, 443 source_id: "status.substack.com".to_string(), 444 author: None, 445 title: Some("Status".to_string()), 446 summary: None, 447 url: "https://example.com/status".to_string(), 448 content_html: None, 449 published_at: now.clone(), 450 created_at: now, 451 }; 452 storage.insert_or_replace_item(&item).unwrap(); 453 454 let snapshot = state.status_snapshot().unwrap(); 455 assert_eq!(snapshot.status, "ok"); 456 assert_eq!(snapshot.version, VERSION); 457 assert!(snapshot.uptime_seconds < 5); 458 assert_eq!(snapshot.total_items, 1); 459 assert_eq!(snapshot.sources.len(), 1); 460 assert_eq!(snapshot.sources[0].kind, "substack"); 461 } 462}