//! Personal Activity Index (Bluesky, Leaflet, Substack)
//! pai.desertthunder.dev
//! RSS
//! Bluesky
1use crate::storage::SqliteStorage;
2
3use axum::{
4 extract::{Path, Query, Request, State},
5 http::{header, HeaderValue, Method, StatusCode},
6 middleware::{self, Next},
7 response::{IntoResponse, Response},
8 routing::get,
9 Json, Router,
10};
11use chrono::DateTime;
12use owo_colors::OwoColorize;
13use pai_core::{Config, CorsConfig, Item, ListFilter, PaiError, SourceKind};
14use rss::{Channel, ChannelBuilder, ItemBuilder};
15use serde::{Deserialize, Serialize};
16use std::{io, net::SocketAddr, path::PathBuf, sync::Arc, time::Instant};
17use tokio::net::TcpListener;
18
/// Page size applied to `/api/feed` and `/rss.xml` when the request omits `limit`.
const DEFAULT_LIMIT: usize = 20;
/// Crate version reported by `/status`, taken from Cargo's `CARGO_PKG_VERSION` at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
21
22/// Launches the HTTP server using the provided config and address.
23pub fn serve(config: Config, db_path: PathBuf, address: &str) -> Result<(), PaiError> {
24 let addr: SocketAddr = address
25 .parse()
26 .map_err(|e| PaiError::Config(format!("Invalid listen address '{address}': {e}")))?;
27
28 let runtime = tokio::runtime::Builder::new_multi_thread()
29 .enable_all()
30 .build()
31 .map_err(PaiError::Io)?;
32
33 runtime.block_on(async move { run_server(config, db_path, addr).await })
34}
35
36async fn run_server(config: Config, db_path: PathBuf, addr: SocketAddr) -> Result<(), PaiError> {
37 let storage = SqliteStorage::new(&db_path)?;
38 storage.verify_schema()?;
39 drop(storage);
40
41 let state =
42 AppState { db_path: Arc::new(db_path), start_time: Instant::now(), cors_config: Arc::new(config.cors.clone()) };
43
44 let mut app = Router::new()
45 .route("/api/feed", get(feed_handler))
46 .route("/api/item/:id", get(item_handler))
47 .route("/status", get(status_handler))
48 .route("/rss.xml", get(rss_handler))
49 .with_state(state.clone());
50
51 if !config.cors.allowed_origins.is_empty() || config.cors.dev_key.is_some() {
52 app = app.layer(middleware::from_fn_with_state(state.clone(), cors_middleware));
53 }
54
55 let listener = TcpListener::bind(addr).await.map_err(PaiError::Io)?;
56 let local_addr = listener.local_addr().map_err(PaiError::Io)?;
57 println!("{} Listening on http://{}", "Info:".cyan(), local_addr);
58
59 axum::serve(listener, app.into_make_service())
60 .with_graceful_shutdown(shutdown_signal())
61 .await
62 .map_err(|e| io::Error::other(e).into())
63}
64
65/// CORS middleware that validates origins and dev keys
66async fn cors_middleware(State(state): State<AppState>, request: Request, next: Next) -> Result<Response, StatusCode> {
67 let origin = request
68 .headers()
69 .get(header::ORIGIN)
70 .and_then(|v| v.to_str().ok())
71 .map(|s| s.to_string());
72 let dev_key = request
73 .headers()
74 .get("x-local-dev-key")
75 .and_then(|v| v.to_str().ok())
76 .map(|s| s.to_string());
77 let method = request.method().clone();
78
79 let is_authorized = if let Some(ref key) = dev_key {
80 state.cors_config.is_dev_key_valid(Some(key))
81 } else if let Some(ref origin_str) = origin {
82 state.cors_config.is_origin_allowed(origin_str)
83 } else {
84 true
85 };
86
87 if method == Method::OPTIONS {
88 if !is_authorized {
89 return Err(StatusCode::FORBIDDEN);
90 }
91
92 let mut response = Response::new(String::new().into());
93 if let Some(ref origin_str) = origin {
94 response.headers_mut().insert(
95 header::ACCESS_CONTROL_ALLOW_ORIGIN,
96 HeaderValue::from_str(origin_str).unwrap_or(HeaderValue::from_static("*")),
97 );
98 }
99 response.headers_mut().insert(
100 header::ACCESS_CONTROL_ALLOW_METHODS,
101 HeaderValue::from_static("GET, POST, OPTIONS"),
102 );
103 response.headers_mut().insert(
104 header::ACCESS_CONTROL_ALLOW_HEADERS,
105 HeaderValue::from_static("Content-Type, X-Local-Dev-Key"),
106 );
107 response
108 .headers_mut()
109 .insert(header::ACCESS_CONTROL_MAX_AGE, HeaderValue::from_static("3600"));
110 return Ok(response);
111 }
112
113 if origin.is_some() && !is_authorized {
114 return Err(StatusCode::FORBIDDEN);
115 }
116
117 let mut response = next.run(request).await;
118
119 if let Some(ref origin_str) = origin {
120 if is_authorized {
121 response.headers_mut().insert(
122 header::ACCESS_CONTROL_ALLOW_ORIGIN,
123 HeaderValue::from_str(origin_str).unwrap_or(HeaderValue::from_static("*")),
124 );
125 response.headers_mut().insert(
126 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
127 HeaderValue::from_static("true"),
128 );
129 }
130 }
131
132 Ok(response)
133}
134
135#[derive(Clone)]
136struct AppState {
137 db_path: Arc<PathBuf>,
138 start_time: Instant,
139 cors_config: Arc<CorsConfig>,
140}
141
142impl AppState {
143 fn open_storage(&self) -> Result<SqliteStorage, PaiError> {
144 SqliteStorage::new(self.db_path.as_ref())
145 }
146
147 fn status_snapshot(&self) -> Result<StatusResponse, PaiError> {
148 let storage = self.open_storage()?;
149 let total_items = storage.count_items()?;
150 let sources = storage
151 .get_stats()?
152 .into_iter()
153 .map(|(kind, count)| SourceStat { kind, count })
154 .collect();
155
156 Ok(StatusResponse {
157 status: "ok",
158 version: VERSION,
159 uptime_seconds: self.start_time.elapsed().as_secs(),
160 database_path: self.db_path.display().to_string(),
161 total_items,
162 sources,
163 })
164 }
165}
166
167#[derive(Debug, Default, Deserialize)]
168struct FeedQuery {
169 source_kind: Option<SourceKind>,
170 source_id: Option<String>,
171 limit: Option<usize>,
172 since: Option<String>,
173 q: Option<String>,
174}
175
176impl FeedQuery {
177 fn into_filter(self) -> Result<ListFilter, PaiError> {
178 let limit = match self.limit {
179 Some(value) => ensure_positive_limit(value)?,
180 None => DEFAULT_LIMIT,
181 };
182
183 Ok(ListFilter {
184 source_kind: self.source_kind,
185 source_id: normalize_optional_string(self.source_id),
186 limit: Some(limit),
187 since: normalize_optional_string(self.since),
188 query: normalize_optional_string(self.q),
189 })
190 }
191}
192
/// JSON body returned by `/api/feed`.
#[derive(Serialize)]
struct FeedResponse {
    /// Number of entries in `items`.
    count: usize,
    /// Items matching the request's filter, in the order storage returned them.
    items: Vec<Item>,
}
198
/// JSON body returned by `/status`.
#[derive(Serialize)]
struct StatusResponse {
    /// Always `"ok"` when the snapshot could be built.
    status: &'static str,
    /// Crate version (`VERSION`).
    version: &'static str,
    /// Whole seconds elapsed since the server state was created.
    uptime_seconds: u64,
    // NOTE(review): this exposes a server-side filesystem path to anyone who can reach `/status`.
    /// Filesystem path of the SQLite database, as displayed text.
    database_path: String,
    /// Total item count as reported by storage.
    total_items: usize,
    /// Per-source-kind item counts.
    sources: Vec<SourceStat>,
}
208
/// One entry of the per-source breakdown reported by `/status`.
#[derive(Serialize)]
struct SourceStat {
    /// Source kind name as returned by the storage statistics (e.g. `"substack"`).
    kind: String,
    /// Number of stored items for that kind.
    count: usize,
}
214
215async fn feed_handler(
216 State(state): State<AppState>, Query(query): Query<FeedQuery>,
217) -> Result<Json<FeedResponse>, ApiError> {
218 let filter = query.into_filter()?;
219 let storage = state.open_storage()?;
220 let items = pai_core::Storage::list_items(&storage, &filter)?;
221
222 Ok(Json(FeedResponse { count: items.len(), items }))
223}
224
225async fn item_handler(State(state): State<AppState>, Path(id): Path<String>) -> Result<Json<Item>, ApiError> {
226 let storage = state.open_storage()?;
227 let item = storage
228 .get_item(&id)?
229 .ok_or_else(|| ApiError::not_found(format!("Item '{id}' not found")))?;
230
231 Ok(Json(item))
232}
233
234async fn status_handler(State(state): State<AppState>) -> Result<Json<StatusResponse>, ApiError> {
235 let snapshot = state.status_snapshot()?;
236 Ok(Json(snapshot))
237}
238
239async fn rss_handler(State(state): State<AppState>, Query(query): Query<FeedQuery>) -> Result<RssResponse, ApiError> {
240 let filter = query.into_filter()?;
241 let storage = state.open_storage()?;
242 let items = pai_core::Storage::list_items(&storage, &filter)?;
243
244 let channel = build_rss_channel(&items)?;
245 Ok(RssResponse(channel))
246}
247
248fn build_rss_channel(items: &[Item]) -> Result<Channel, PaiError> {
249 const TITLE: &str = "Personal Activity Index";
250 const LINK: &str = "https://personal-activity-index.local/";
251 const DESCRIPTION: &str = "Aggregated feed exported by the Personal Activity Index.";
252
253 let rss_items: Vec<rss::Item> = items
254 .iter()
255 .map(|item| {
256 let title = item
257 .title
258 .as_deref()
259 .or(item.summary.as_deref())
260 .unwrap_or(&item.url)
261 .to_string();
262 let description = item
263 .summary
264 .as_deref()
265 .or(item.content_html.as_deref())
266 .unwrap_or("")
267 .to_string();
268 let author = item.author.as_deref().unwrap_or("Unknown").to_string();
269 let pub_date = format_rss_date(&item.published_at);
270
271 ItemBuilder::default()
272 .title(Some(title))
273 .link(Some(item.url.clone()))
274 .guid(Some(
275 rss::GuidBuilder::default().value(&item.id).permalink(false).build(),
276 ))
277 .pub_date(Some(pub_date))
278 .author(Some(author))
279 .description(Some(description))
280 .categories(vec![rss::CategoryBuilder::default()
281 .name(item.source_kind.to_string())
282 .build()])
283 .build()
284 })
285 .collect();
286
287 let channel = ChannelBuilder::default()
288 .title(TITLE)
289 .link(LINK)
290 .description(DESCRIPTION)
291 .items(rss_items)
292 .build();
293
294 Ok(channel)
295}
296
297fn format_rss_date(value: &str) -> String {
298 if let Ok(dt) = DateTime::parse_from_rfc3339(value) {
299 dt.to_rfc2822()
300 } else if let Ok(dt) = DateTime::parse_from_rfc2822(value) {
301 dt.to_rfc2822()
302 } else {
303 value.to_string()
304 }
305}
306
307struct RssResponse(Channel);
308
309impl IntoResponse for RssResponse {
310 fn into_response(self) -> Response {
311 let rss_string = self.0.to_string();
312 (
313 [(header::CONTENT_TYPE, "application/rss+xml; charset=utf-8")],
314 rss_string,
315 )
316 .into_response()
317 }
318}
319
320struct ApiError {
321 status: StatusCode,
322 message: String,
323}
324
325impl ApiError {
326 fn bad_request(msg: impl Into<String>) -> Self {
327 Self { status: StatusCode::BAD_REQUEST, message: msg.into() }
328 }
329
330 fn not_found(msg: impl Into<String>) -> Self {
331 Self { status: StatusCode::NOT_FOUND, message: msg.into() }
332 }
333
334 fn internal(msg: impl Into<String>) -> Self {
335 Self { status: StatusCode::INTERNAL_SERVER_ERROR, message: msg.into() }
336 }
337}
338
339impl From<PaiError> for ApiError {
340 fn from(err: PaiError) -> Self {
341 match err {
342 PaiError::InvalidArgument(msg) => Self::bad_request(msg),
343 other => Self::internal(other.to_string()),
344 }
345 }
346}
347
/// JSON error payload, serialized as `{"error": "<message>"}`.
#[derive(Serialize)]
struct ErrorBody {
    /// Human-readable error message taken from `ApiError`.
    error: String,
}
352
353impl IntoResponse for ApiError {
354 fn into_response(self) -> Response {
355 (self.status, Json(ErrorBody { error: self.message })).into_response()
356 }
357}
358
359async fn shutdown_signal() {
360 let _ = tokio::signal::ctrl_c().await;
361}
362
363fn ensure_positive_limit(limit: usize) -> Result<usize, PaiError> {
364 if limit == 0 {
365 return Err(PaiError::InvalidArgument("Limit must be greater than zero".to_string()));
366 }
367 Ok(limit)
368}
369
/// Trims surrounding whitespace from an optional text parameter.
///
/// Returns `None` when the input is `None` or whitespace-only. Otherwise returns the trimmed
/// text as a new `String`.
fn normalize_optional_string(value: Option<String>) -> Option<String> {
    let trimmed = value.as_deref().map(str::trim)?;
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
380
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use pai_core::Storage;
    use tempfile::tempdir;

    #[test]
    fn feed_query_defaults() {
        let filter = FeedQuery::default().into_filter().unwrap();
        assert_eq!(filter.limit, Some(DEFAULT_LIMIT));
        assert!(filter.source_kind.is_none());
        assert!(filter.source_id.is_none());
    }

    #[test]
    fn feed_query_respects_parameters() {
        let query = FeedQuery {
            source_kind: Some(SourceKind::Bluesky),
            source_id: Some(" desertthunder.dev ".to_string()),
            limit: Some(5),
            since: Some("2024-01-01T00:00:00Z".to_string()),
            q: Some(" rust ".to_string()),
        };

        let filter = query.into_filter().unwrap();
        assert_eq!(filter.limit, Some(5));
        assert_eq!(filter.source_kind, Some(SourceKind::Bluesky));
        assert_eq!(filter.source_id.as_deref(), Some("desertthunder.dev"));
        assert_eq!(filter.query.as_deref(), Some("rust"));
        assert_eq!(filter.since.as_deref(), Some("2024-01-01T00:00:00Z"));
    }

    #[test]
    fn feed_query_rejects_zero_limit() {
        let err = FeedQuery { limit: Some(0), ..Default::default() }
            .into_filter()
            .unwrap_err();
        assert!(matches!(err, PaiError::InvalidArgument(_)));
    }

    #[test]
    fn api_error_into_response_sets_status() {
        let resp = ApiError::bad_request("oops").into_response();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn api_error_maps_invalid_argument_to_bad_request() {
        let err = ApiError::from(PaiError::InvalidArgument("bad".to_string()));
        assert_eq!(err.status, StatusCode::BAD_REQUEST);
        assert_eq!(err.message, "bad");
    }

    #[test]
    fn normalize_optional_string_trims_and_drops_blank_values() {
        assert_eq!(normalize_optional_string(None), None);
        assert_eq!(normalize_optional_string(Some("   ".to_string())), None);
        assert_eq!(normalize_optional_string(Some(" a b ".to_string())).as_deref(), Some("a b"));
    }

    #[test]
    fn format_rss_date_converts_rfc3339() {
        assert_eq!(format_rss_date("2024-01-15T03:04:05Z"), "Mon, 15 Jan 2024 03:04:05 +0000");
    }

    #[test]
    fn format_rss_date_passes_through_unparseable_values() {
        assert_eq!(format_rss_date("not a date"), "not a date");
    }

    #[test]
    fn rss_title_falls_back_to_url_when_title_and_summary_missing() {
        let now = Utc::now().to_rfc3339();
        let item = Item {
            id: "rss-test".to_string(),
            source_kind: SourceKind::Bluesky,
            source_id: "desertthunder.dev".to_string(),
            author: None,
            title: None,
            summary: None,
            url: "https://example.com/rss".to_string(),
            content_html: None,
            published_at: now.clone(),
            created_at: now,
        };

        let channel = build_rss_channel(&[item]).unwrap();
        assert_eq!(channel.items().len(), 1);
        assert_eq!(channel.items()[0].title(), Some("https://example.com/rss"));
    }

    #[test]
    fn status_snapshot_reports_counts() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("status.db");
        let state = AppState {
            db_path: Arc::new(db_path),
            start_time: Instant::now(),
            cors_config: Arc::new(pai_core::CorsConfig::default()),
        };

        let storage = state.open_storage().unwrap();
        let now = Utc::now().to_rfc3339();
        let item = Item {
            id: "status-test".to_string(),
            source_kind: SourceKind::Substack,
            source_id: "status.substack.com".to_string(),
            author: None,
            title: Some("Status".to_string()),
            summary: None,
            url: "https://example.com/status".to_string(),
            content_html: None,
            published_at: now.clone(),
            created_at: now,
        };
        storage.insert_or_replace_item(&item).unwrap();

        let snapshot = state.status_snapshot().unwrap();
        assert_eq!(snapshot.status, "ok");
        assert_eq!(snapshot.version, VERSION);
        assert!(snapshot.uptime_seconds < 5);
        assert_eq!(snapshot.total_items, 1);
        assert_eq!(snapshot.sources.len(), 1);
        assert_eq!(snapshot.sources[0].kind, "substack");
    }
}