use axum::{ extract, response::{IntoResponse, Response}, }; use ecsdb::{Component, Entity, EntityId, query}; use http::{HeaderValue, header}; use serde::{Deserialize, Serialize}; use serde_json::json; use tracing::{debug, debug_span}; use crate::{AppState, AthleteId, strava::Athlete}; use super::UserSession; #[derive(Clone)] pub struct HtmxTemplate { pub template_name: &'static str, pub context: serde_json::Value, } impl HtmxTemplate { pub fn new(template_name: &'static str, context: serde_json::Value) -> Self { Self { template_name, context, } } } impl IntoResponse for HtmxTemplate { fn into_response(self) -> Response { Response::builder() .extension(self) .body(axum::body::Body::empty()) .unwrap() } } tokio::task_local! { static DB: ecsdb::Ecs; } #[derive(Debug, Copy, Clone)] pub struct TemplateEcs; impl minijinja::value::Object for TemplateEcs { fn call_method( self: &std::sync::Arc, _state: &minijinja::State<'_, '_>, method: &str, args: &[minijinja::Value], ) -> Result { DB.with(|db| match (method, &args) { ("query", component_names) => { let _span = debug_span!("query", ?component_names).entered(); let component_names: Vec = component_names .iter() .filter_map(|name| name.as_str()) .map(|name| query::ComponentName(name.to_owned())) .collect(); let entities = db .query_filtered::(&component_names[..]) .map(|e| TemplateEntity(e.id())) .map(minijinja::Value::from_object) .collect::>(); debug!(?entities); Ok(entities.into()) } (other, args) => Err(minijinja::Error::new( minijinja::ErrorKind::UnknownMethod, format!( "{}({})", other, args.iter() .map(|v| v.kind().to_string()) .collect::>() .join(", ") ), )), }) } } #[derive(Debug, Copy, Clone, PartialEq, Eq, Ord, PartialOrd)] pub struct TemplateEntity(pub EntityId); impl TemplateEntity {} impl minijinja::value::Object for TemplateEntity { fn is_true(self: &std::sync::Arc) -> bool { DB.with(|db| db.entity(self.0).exists()) } fn get_value( self: &std::sync::Arc, field: &minijinja::Value, ) -> Option { DB.with(|db| { db.entity(self.0) .dyn_component(field.as_str()?) .and_then(|c| c.as_json()) .map(minijinja::Value::from_serialize) }) } fn call_method( self: &std::sync::Arc, _state: &minijinja::State<'_, '_>, method: &str, args: &[minijinja::Value], ) -> Result { use minijinja::{Value, value::ValueKind}; DB.with(|db| match (method, &args) { ("exists", &[]) => Ok(db.entity(self.0).or_none().is_some().into()), ("component", &[name]) if name.kind() == ValueKind::String => { Ok(self.get_value(name).unwrap_or_default()) } ("components", &[]) => Ok(db .entity(self.0) .component_names() .collect::>() .into()), ("id", &[]) => Ok(Value::from(self.0)), ("has", &[name]) if name.kind() == ValueKind::String => { let name = name.as_str().unwrap(); Ok(Value::from( db.entity(self.0).component_names().any(|c| c == name), )) } ("last_modified", &[]) => { Ok(Value::from(db.entity(self.0).last_modified().to_rfc3339())) } ("ref", &[component]) if component.kind() == ValueKind::String => { let Some(component) = db.entity(self.0).dyn_component(component.as_str().unwrap()) else { return Ok(Value::default()); }; #[derive(Deserialize, Serialize, Component)] struct RefComponent(EntityId); match component.as_typed() { Ok(RefComponent(eid)) => Ok(Value::from_object(TemplateEntity(eid))), Err(e) => Err(minijinja::Error::new( minijinja::ErrorKind::CannotDeserialize, format!("{} not a referencing component: {e}", component.name()), )), } } (other, args) => Err(minijinja::Error::new( minijinja::ErrorKind::UnknownMethod, format!( "{}({})", other, args.iter() .map(|v| v.kind().to_string()) .collect::>() .join(", ") ), )), }) } } pub async fn htmx_middleware( app_state: extract::State>, user_session: Option, request: axum::http::Request, next: axum::middleware::Next, ) -> axum::response::Response { let is_hx_request = request .headers() .get("HX-Request") .is_some_and(|h| h.to_str().is_ok_and(|v| v == "true")); let is_hx_boosted = request .headers() .get("HX-Boosted") .is_some_and(|h| h.to_str().is_ok_and(|v| v == "true")); let is_htmx_request = is_hx_request && !is_hx_boosted; // let app_state = request.extensions().get::().unwrap().to_owned(); let mut response = next.run(request).await; let Some(HtmxTemplate { template_name, mut context, }) = response.extensions_mut().remove::() else { return response; }; context["debugger_enabled"] = json!(cfg!(debug_assertions)); // Enrich `context` with user-info if let Some(UserSession { athlete_id }) = user_session { let db = app_state.acquire_db().await; if let Some(account) = db.query_filtered::(athlete_id).next() { context["session"]["athlete"] = json!(account); context["session"]["athlete"]["profile_url"] = json!(account.profile_url()); context["session"]["athlete"]["full_name"] = json!(account.format_name()); } context["athlete"]["id"] = json!(athlete_id); } let run = || { let body = if is_htmx_request { let template = app_state.templates.get_template(template_name).unwrap(); let mut state = template.eval_to_state(&context).unwrap(); match state.render_block("content") { Ok(x) => x, Err(e) if e.kind() == minijinja::ErrorKind::UnknownBlock => { template.render(context).unwrap() } Err(e) => panic!("{e}"), } } else { app_state.render_template(template_name, context) }; // Fix Content-Length response.headers_mut().remove(header::CONTENT_LENGTH); response .headers_mut() .insert(header::CONTENT_LENGTH, HeaderValue::from(body.len())); *response.body_mut() = axum::body::Body::new(body); // Default to HTML response .headers_mut() .entry(header::CONTENT_TYPE) .or_insert(http::HeaderValue::from_static( mime::TEXT_HTML_UTF_8.as_ref(), )); response }; DB.sync_scope(app_state.acquire_db().await, run) }