this repo has no description
1use axum::{
2 extract,
3 response::{IntoResponse, Response},
4};
5use ecsdb::{Component, Entity, EntityId, query};
6use http::{HeaderValue, header};
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9use tracing::{debug, debug_span};
10
11use crate::{AppState, AthleteId, strava::Athlete};
12
13use super::UserSession;
14
15#[derive(Clone)]
16pub struct HtmxTemplate {
17 pub template_name: &'static str,
18 pub context: serde_json::Value,
19}
20
21impl HtmxTemplate {
22 pub fn new(template_name: &'static str, context: serde_json::Value) -> Self {
23 Self {
24 template_name,
25 context,
26 }
27 }
28}
29
30impl IntoResponse for HtmxTemplate {
31 fn into_response(self) -> Response<axum::body::Body> {
32 Response::builder()
33 .extension(self)
34 .body(axum::body::Body::empty())
35 .unwrap()
36 }
37}
38
39tokio::task_local! {
40 static DB: ecsdb::Ecs;
41}
42
/// Stateless handle that exposes database queries to templates.
///
/// All data access goes through the task-local `DB`, so this type carries no
/// fields of its own.
#[derive(Clone, Copy, Debug)]
pub struct TemplateEcs;
45
46impl minijinja::value::Object for TemplateEcs {
47 fn call_method(
48 self: &std::sync::Arc<Self>,
49 _state: &minijinja::State<'_, '_>,
50 method: &str,
51 args: &[minijinja::Value],
52 ) -> Result<minijinja::Value, minijinja::Error> {
53 DB.with(|db| match (method, &args) {
54 ("query", component_names) => {
55 let _span = debug_span!("query", ?component_names).entered();
56
57 let component_names: Vec<query::ComponentName> = component_names
58 .iter()
59 .filter_map(|name| name.as_str())
60 .map(|name| query::ComponentName(name.to_owned()))
61 .collect();
62
63 let entities = db
64 .query_filtered::<Entity, ()>(&component_names[..])
65 .map(|e| TemplateEntity(e.id()))
66 .map(minijinja::Value::from_object)
67 .collect::<Vec<_>>();
68
69 debug!(?entities);
70
71 Ok(entities.into())
72 }
73
74 (other, args) => Err(minijinja::Error::new(
75 minijinja::ErrorKind::UnknownMethod,
76 format!(
77 "{}({})",
78 other,
79 args.iter()
80 .map(|v| v.kind().to_string())
81 .collect::<Vec<_>>()
82 .join(", ")
83 ),
84 )),
85 })
86 }
87}
88
89#[derive(Debug, Copy, Clone, PartialEq, Eq, Ord, PartialOrd)]
90pub struct TemplateEntity(pub EntityId);
91
92impl TemplateEntity {}
93
94impl minijinja::value::Object for TemplateEntity {
95 fn is_true(self: &std::sync::Arc<Self>) -> bool {
96 DB.with(|db| db.entity(self.0).exists())
97 }
98
99 fn get_value(
100 self: &std::sync::Arc<Self>,
101 field: &minijinja::Value,
102 ) -> Option<minijinja::Value> {
103 DB.with(|db| {
104 db.entity(self.0)
105 .dyn_component(field.as_str()?)
106 .and_then(|c| c.as_json())
107 .map(minijinja::Value::from_serialize)
108 })
109 }
110
111 fn call_method(
112 self: &std::sync::Arc<Self>,
113 _state: &minijinja::State<'_, '_>,
114 method: &str,
115 args: &[minijinja::Value],
116 ) -> Result<minijinja::Value, minijinja::Error> {
117 use minijinja::{Value, value::ValueKind};
118
119 DB.with(|db| match (method, &args) {
120 ("exists", &[]) => Ok(db.entity(self.0).or_none().is_some().into()),
121 ("component", &[name]) if name.kind() == ValueKind::String => {
122 Ok(self.get_value(name).unwrap_or_default())
123 }
124 ("components", &[]) => Ok(db
125 .entity(self.0)
126 .component_names()
127 .collect::<Vec<_>>()
128 .into()),
129 ("id", &[]) => Ok(Value::from(self.0)),
130 ("has", &[name]) if name.kind() == ValueKind::String => {
131 let name = name.as_str().unwrap();
132 Ok(Value::from(
133 db.entity(self.0).component_names().any(|c| c == name),
134 ))
135 }
136
137 ("last_modified", &[]) => {
138 Ok(Value::from(db.entity(self.0).last_modified().to_rfc3339()))
139 }
140 ("ref", &[component]) if component.kind() == ValueKind::String => {
141 let Some(component) = db.entity(self.0).dyn_component(component.as_str().unwrap())
142 else {
143 return Ok(Value::default());
144 };
145
146 #[derive(Deserialize, Serialize, Component)]
147 struct RefComponent(EntityId);
148
149 match component.as_typed() {
150 Ok(RefComponent(eid)) => Ok(Value::from_object(TemplateEntity(eid))),
151 Err(e) => Err(minijinja::Error::new(
152 minijinja::ErrorKind::CannotDeserialize,
153 format!("{} not a referencing component: {e}", component.name()),
154 )),
155 }
156 }
157 (other, args) => Err(minijinja::Error::new(
158 minijinja::ErrorKind::UnknownMethod,
159 format!(
160 "{}({})",
161 other,
162 args.iter()
163 .map(|v| v.kind().to_string())
164 .collect::<Vec<_>>()
165 .join(", ")
166 ),
167 )),
168 })
169 }
170}
171
172pub async fn htmx_middleware(
173 app_state: extract::State<AppState<'static>>,
174 user_session: Option<UserSession>,
175 request: axum::http::Request<axum::body::Body>,
176 next: axum::middleware::Next,
177) -> axum::response::Response {
178 let is_hx_request = request
179 .headers()
180 .get("HX-Request")
181 .is_some_and(|h| h.to_str().is_ok_and(|v| v == "true"));
182 let is_hx_boosted = request
183 .headers()
184 .get("HX-Boosted")
185 .is_some_and(|h| h.to_str().is_ok_and(|v| v == "true"));
186
187 let is_htmx_request = is_hx_request && !is_hx_boosted;
188
189 // let app_state = request.extensions().get::<AppState>().unwrap().to_owned();
190 let mut response = next.run(request).await;
191
192 let Some(HtmxTemplate {
193 template_name,
194 mut context,
195 }) = response.extensions_mut().remove::<HtmxTemplate>()
196 else {
197 return response;
198 };
199
200 context["debugger_enabled"] = json!(cfg!(debug_assertions));
201
202 // Enrich `context` with user-info
203 if let Some(UserSession { athlete_id }) = user_session {
204 let db = app_state.acquire_db().await;
205
206 if let Some(account) = db.query_filtered::<Athlete, AthleteId>(athlete_id).next() {
207 context["session"]["athlete"] = json!(account);
208 context["session"]["athlete"]["profile_url"] = json!(account.profile_url());
209 context["session"]["athlete"]["full_name"] = json!(account.format_name());
210 }
211
212 context["athlete"]["id"] = json!(athlete_id);
213 }
214
215 let run = || {
216 let body = if is_htmx_request {
217 let template = app_state.templates.get_template(template_name).unwrap();
218
219 let mut state = template.eval_to_state(&context).unwrap();
220 match state.render_block("content") {
221 Ok(x) => x,
222 Err(e) if e.kind() == minijinja::ErrorKind::UnknownBlock => {
223 template.render(context).unwrap()
224 }
225 Err(e) => panic!("{e}"),
226 }
227 } else {
228 app_state.render_template(template_name, context)
229 };
230
231 // Fix Content-Length
232 response.headers_mut().remove(header::CONTENT_LENGTH);
233 response
234 .headers_mut()
235 .insert(header::CONTENT_LENGTH, HeaderValue::from(body.len()));
236
237 *response.body_mut() = axum::body::Body::new(body);
238
239 // Default to HTML
240 response
241 .headers_mut()
242 .entry(header::CONTENT_TYPE)
243 .or_insert(http::HeaderValue::from_static(
244 mime::TEXT_HTML_UTF_8.as_ref(),
245 ));
246
247 response
248 };
249
250 DB.sync_scope(app_state.acquire_db().await, run)
251}