at main 251 lines 8.1 kB view raw
1use axum::{ 2 extract, 3 response::{IntoResponse, Response}, 4}; 5use ecsdb::{Component, Entity, EntityId, query}; 6use http::{HeaderValue, header}; 7use serde::{Deserialize, Serialize}; 8use serde_json::json; 9use tracing::{debug, debug_span}; 10 11use crate::{AppState, AthleteId, strava::Athlete}; 12 13use super::UserSession; 14 15#[derive(Clone)] 16pub struct HtmxTemplate { 17 pub template_name: &'static str, 18 pub context: serde_json::Value, 19} 20 21impl HtmxTemplate { 22 pub fn new(template_name: &'static str, context: serde_json::Value) -> Self { 23 Self { 24 template_name, 25 context, 26 } 27 } 28} 29 30impl IntoResponse for HtmxTemplate { 31 fn into_response(self) -> Response<axum::body::Body> { 32 Response::builder() 33 .extension(self) 34 .body(axum::body::Body::empty()) 35 .unwrap() 36 } 37} 38 39tokio::task_local! { 40 static DB: ecsdb::Ecs; 41} 42 43#[derive(Debug, Copy, Clone)] 44pub struct TemplateEcs; 45 46impl minijinja::value::Object for TemplateEcs { 47 fn call_method( 48 self: &std::sync::Arc<Self>, 49 _state: &minijinja::State<'_, '_>, 50 method: &str, 51 args: &[minijinja::Value], 52 ) -> Result<minijinja::Value, minijinja::Error> { 53 DB.with(|db| match (method, &args) { 54 ("query", component_names) => { 55 let _span = debug_span!("query", ?component_names).entered(); 56 57 let component_names: Vec<query::ComponentName> = component_names 58 .iter() 59 .filter_map(|name| name.as_str()) 60 .map(|name| query::ComponentName(name.to_owned())) 61 .collect(); 62 63 let entities = db 64 .query_filtered::<Entity, ()>(&component_names[..]) 65 .map(|e| TemplateEntity(e.id())) 66 .map(minijinja::Value::from_object) 67 .collect::<Vec<_>>(); 68 69 debug!(?entities); 70 71 Ok(entities.into()) 72 } 73 74 (other, args) => Err(minijinja::Error::new( 75 minijinja::ErrorKind::UnknownMethod, 76 format!( 77 "{}({})", 78 other, 79 args.iter() 80 .map(|v| v.kind().to_string()) 81 .collect::<Vec<_>>() 82 .join(", ") 83 ), 84 )), 85 }) 86 } 87} 88 89#[derive(Debug, Copy, Clone, PartialEq, Eq, Ord, PartialOrd)] 90pub struct TemplateEntity(pub EntityId); 91 92impl TemplateEntity {} 93 94impl minijinja::value::Object for TemplateEntity { 95 fn is_true(self: &std::sync::Arc<Self>) -> bool { 96 DB.with(|db| db.entity(self.0).exists()) 97 } 98 99 fn get_value( 100 self: &std::sync::Arc<Self>, 101 field: &minijinja::Value, 102 ) -> Option<minijinja::Value> { 103 DB.with(|db| { 104 db.entity(self.0) 105 .dyn_component(field.as_str()?) 106 .and_then(|c| c.as_json()) 107 .map(minijinja::Value::from_serialize) 108 }) 109 } 110 111 fn call_method( 112 self: &std::sync::Arc<Self>, 113 _state: &minijinja::State<'_, '_>, 114 method: &str, 115 args: &[minijinja::Value], 116 ) -> Result<minijinja::Value, minijinja::Error> { 117 use minijinja::{Value, value::ValueKind}; 118 119 DB.with(|db| match (method, &args) { 120 ("exists", &[]) => Ok(db.entity(self.0).or_none().is_some().into()), 121 ("component", &[name]) if name.kind() == ValueKind::String => { 122 Ok(self.get_value(name).unwrap_or_default()) 123 } 124 ("components", &[]) => Ok(db 125 .entity(self.0) 126 .component_names() 127 .collect::<Vec<_>>() 128 .into()), 129 ("id", &[]) => Ok(Value::from(self.0)), 130 ("has", &[name]) if name.kind() == ValueKind::String => { 131 let name = name.as_str().unwrap(); 132 Ok(Value::from( 133 db.entity(self.0).component_names().any(|c| c == name), 134 )) 135 } 136 137 ("last_modified", &[]) => { 138 Ok(Value::from(db.entity(self.0).last_modified().to_rfc3339())) 139 } 140 ("ref", &[component]) if component.kind() == ValueKind::String => { 141 let Some(component) = db.entity(self.0).dyn_component(component.as_str().unwrap()) 142 else { 143 return Ok(Value::default()); 144 }; 145 146 #[derive(Deserialize, Serialize, Component)] 147 struct RefComponent(EntityId); 148 149 match component.as_typed() { 150 Ok(RefComponent(eid)) => Ok(Value::from_object(TemplateEntity(eid))), 151 Err(e) => Err(minijinja::Error::new( 152 minijinja::ErrorKind::CannotDeserialize, 153 format!("{} not a referencing component: {e}", component.name()), 154 )), 155 } 156 } 157 (other, args) => Err(minijinja::Error::new( 158 minijinja::ErrorKind::UnknownMethod, 159 format!( 160 "{}({})", 161 other, 162 args.iter() 163 .map(|v| v.kind().to_string()) 164 .collect::<Vec<_>>() 165 .join(", ") 166 ), 167 )), 168 }) 169 } 170} 171 172pub async fn htmx_middleware( 173 app_state: extract::State<AppState<'static>>, 174 user_session: Option<UserSession>, 175 request: axum::http::Request<axum::body::Body>, 176 next: axum::middleware::Next, 177) -> axum::response::Response { 178 let is_hx_request = request 179 .headers() 180 .get("HX-Request") 181 .is_some_and(|h| h.to_str().is_ok_and(|v| v == "true")); 182 let is_hx_boosted = request 183 .headers() 184 .get("HX-Boosted") 185 .is_some_and(|h| h.to_str().is_ok_and(|v| v == "true")); 186 187 let is_htmx_request = is_hx_request && !is_hx_boosted; 188 189 // let app_state = request.extensions().get::<AppState>().unwrap().to_owned(); 190 let mut response = next.run(request).await; 191 192 let Some(HtmxTemplate { 193 template_name, 194 mut context, 195 }) = response.extensions_mut().remove::<HtmxTemplate>() 196 else { 197 return response; 198 }; 199 200 context["debugger_enabled"] = json!(cfg!(debug_assertions)); 201 202 // Enrich `context` with user-info 203 if let Some(UserSession { athlete_id }) = user_session { 204 let db = app_state.acquire_db().await; 205 206 if let Some(account) = db.query_filtered::<Athlete, AthleteId>(athlete_id).next() { 207 context["session"]["athlete"] = json!(account); 208 context["session"]["athlete"]["profile_url"] = json!(account.profile_url()); 209 context["session"]["athlete"]["full_name"] = json!(account.format_name()); 210 } 211 212 context["athlete"]["id"] = json!(athlete_id); 213 } 214 215 let run = || { 216 let body = if is_htmx_request { 217 let template = app_state.templates.get_template(template_name).unwrap(); 218 219 let mut state = template.eval_to_state(&context).unwrap(); 220 match state.render_block("content") { 221 Ok(x) => x, 222 Err(e) if e.kind() == minijinja::ErrorKind::UnknownBlock => { 223 template.render(context).unwrap() 224 } 225 Err(e) => panic!("{e}"), 226 } 227 } else { 228 app_state.render_template(template_name, context) 229 }; 230 231 // Fix Content-Length 232 response.headers_mut().remove(header::CONTENT_LENGTH); 233 response 234 .headers_mut() 235 .insert(header::CONTENT_LENGTH, HeaderValue::from(body.len())); 236 237 *response.body_mut() = axum::body::Body::new(body); 238 239 // Default to HTML 240 response 241 .headers_mut() 242 .entry(header::CONTENT_TYPE) 243 .or_insert(http::HeaderValue::from_static( 244 mime::TEXT_HTML_UTF_8.as_ref(), 245 )); 246 247 response 248 }; 249 250 DB.sync_scope(app_state.acquire_db().await, run) 251}