Rust library to generate static websites
1use axum::{
2 Router,
3 body::{Body, to_bytes},
4 extract::{
5 Request, State,
6 ws::{Message, WebSocket, WebSocketUpgrade},
7 },
8 http::{HeaderValue, StatusCode, Uri, header::CONTENT_LENGTH},
9 middleware::{self, Next},
10 response::{IntoResponse, Response},
11 routing::get,
12};
13use quanta::Instant;
14use serde_json::json;
15use tokio::{
16 net::TcpSocket,
17 signal,
18 sync::{RwLock, broadcast},
19};
20use tracing::{Level, debug};
21
22use std::net::{IpAddr, SocketAddr};
23use std::sync::Arc;
24use tower_http::{
25 services::ServeDir,
26 trace::{DefaultMakeSpan, TraceLayer},
27};
28
29use axum::extract::connect_info::ConnectInfo;
30use futures::{SinkExt, stream::StreamExt};
31
32use crate::consts::PORT;
33use crate::server_utils::{CustomOnResponse, find_open_port, log_server_start};
34use axum::http::header;
35use local_ip_address::local_ip;
36use tokio::fs;
37
/// A single text payload pushed to every connected live-reload WebSocket client.
#[derive(Clone, Debug)]
pub struct WebSocketMessage {
    /// Contents of the text frame; the helpers in this module fill it with a JSON object.
    pub data: String,
}
42
/// Kind of status notification sent to live-reload clients.
///
/// Its `Display` form is the lowercase label placed in the `"type"` field of the JSON
/// messages sent over the WebSocket.
#[derive(Clone, Debug)]
pub enum StatusType {
    /// A successful outcome; `update_status` clears any stored status for it.
    Success,
    /// A transient notice; `update_status` does not persist it.
    Info,
    /// A failure; `update_status` stores it so clients connecting later also see it.
    Error,
}

impl std::fmt::Display for StatusType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Success => "success",
            Self::Info => "info",
            Self::Error => "error",
        };
        f.write_str(label)
    }
}
59
60// Persistent state for new connections
61#[derive(Clone, Debug)]
62pub struct PersistentStatus {
63 pub status_type: StatusType, // Only Success or Error
64 pub message: String,
65}
66
67#[derive(Clone)]
68struct AppState {
69 tx: broadcast::Sender<WebSocketMessage>,
70 current_status: Arc<RwLock<Option<PersistentStatus>>>,
71}
72
73fn inject_live_reload_script(html_content: &str, socket_addr: SocketAddr, host: bool) -> String {
74 let mut content = html_content.to_string();
75
76 // Run cargo xtask build-cli-js to build the client.js file if missing
77 let script_content = include_str!("../../js/dist/client.js").replace(
78 "{SERVER_ADDRESS}",
79 &format!(
80 "{}:{}",
81 if !host {
82 socket_addr.ip().to_string()
83 } else {
84 local_ip().unwrap().to_string()
85 },
86 &socket_addr.port().to_string()
87 ),
88 );
89
90 content.push_str(&format!("\n\n<script>{script_content}</script>"));
91 content
92}
93
94pub async fn start_dev_web_server(
95 start_time: Instant,
96 tx: broadcast::Sender<WebSocketMessage>,
97 host: bool,
98 port: Option<u16>,
99 initial_error: Option<String>,
100 current_status: Arc<RwLock<Option<PersistentStatus>>>,
101) {
102 // TODO: The dist dir should be configurable
103 let dist_dir = "dist";
104
105 // Send initial error if present
106 if let Some(error) = initial_error {
107 let _ = tx.send(WebSocketMessage {
108 data: json!({
109 "type": StatusType::Error.to_string(),
110 "message": error
111 })
112 .to_string(),
113 });
114 }
115
116 async fn handle_404(socket_addr: SocketAddr, host: bool, dist_dir: &str) -> impl IntoResponse {
117 let content = match fs::read_to_string(format!("{}/404.html", dist_dir)).await {
118 Ok(custom_content) => custom_content,
119 Err(_) => include_str!("./404.html").to_string(),
120 };
121
122 (
123 StatusCode::NOT_FOUND,
124 [(header::CONTENT_TYPE, "text/html; charset=utf-8")],
125 inject_live_reload_script(&content, socket_addr, host),
126 )
127 .into_response()
128 }
129
130 // run it with hyper, if --host 0.0.0.0 otherwise localhost
131 let addr = if host {
132 IpAddr::from([0, 0, 0, 0])
133 } else {
134 IpAddr::from([127, 0, 0, 1])
135 };
136
137 // Use provided port or default to the constant PORT
138 let starting_port = port.unwrap_or(PORT);
139
140 let port = find_open_port(&addr, starting_port).await;
141 let socket = TcpSocket::new_v4().unwrap();
142 let _ = socket.set_reuseaddr(true);
143
144 let socket_addr = SocketAddr::new(addr, port);
145 socket.bind(socket_addr).unwrap();
146
147 let listener = socket.listen(1024).unwrap();
148
149 debug!("listening on {}", listener.local_addr().unwrap());
150
151 let serve_dir =
152 ServeDir::new(dist_dir).not_found_service(axum::routing::any(move || async move {
153 handle_404(socket_addr, host, dist_dir).await
154 }));
155
156 // TODO: Return a `.well-known/appspecific/com.chrome.devtools.json` for Chrome
157
158 let router = Router::new()
159 .route("/ws", get(ws_handler))
160 .fallback_service(serve_dir)
161 .layer(middleware::from_fn(add_cache_headers))
162 .layer(middleware::from_fn(move |req, next| {
163 add_dev_client_script(req, next, socket_addr, host)
164 }))
165 .layer(
166 TraceLayer::new_for_http()
167 .make_span_with(DefaultMakeSpan::default().include_headers(true)),
168 )
169 .layer(
170 TraceLayer::new_for_http()
171 .make_span_with(DefaultMakeSpan::new().level(Level::INFO))
172 .on_response(CustomOnResponse),
173 )
174 .with_state(AppState {
175 tx: tx.clone(),
176 current_status: current_status.clone(),
177 });
178
179 log_server_start(
180 start_time,
181 host,
182 listener.local_addr().unwrap(),
183 "Development",
184 );
185
186 axum::serve(
187 listener,
188 router.into_make_service_with_connect_info::<SocketAddr>(),
189 )
190 .with_graceful_shutdown(shutdown_signal())
191 .await
192 .unwrap();
193}
194
195pub async fn update_status(
196 tx: &broadcast::Sender<WebSocketMessage>,
197 current_status: Arc<RwLock<Option<PersistentStatus>>>,
198 status_type: StatusType,
199 message: &str,
200) {
201 // Only store persistent states (Success clears errors, Error stores the error)
202 let persistent_status = match status_type {
203 StatusType::Success => None, // Clear any error state
204 StatusType::Error => Some(PersistentStatus {
205 status_type: StatusType::Error,
206 message: message.to_string(),
207 }),
208 // Everything else just keeps the current state
209 _ => {
210 let status = current_status.read().await;
211 status.clone() // Keep existing persistent state
212 }
213 };
214
215 // Update the stored status
216 {
217 let mut status = current_status.write().await;
218 *status = persistent_status;
219 }
220
221 // Send the message to all connected clients
222 let _ = tx.send(WebSocketMessage {
223 data: json!({
224 "type": status_type.to_string(),
225 "message": message
226 })
227 .to_string(),
228 });
229}
230
231async fn add_dev_client_script(
232 req: Request,
233 next: Next,
234 socket_addr: SocketAddr,
235 host: bool,
236) -> Response {
237 let uri = req.uri().clone();
238 let mut res: axum::http::Response<Body> = next.run(req).await;
239
240 res.extensions_mut().insert(uri.clone());
241
242 if res.headers().get(axum::http::header::CONTENT_TYPE)
243 == Some(&HeaderValue::from_static("text/html"))
244 {
245 let original_headers = res.headers().clone();
246 let body = res.into_body();
247 let bytes = to_bytes(body, usize::MAX).await.unwrap();
248
249 let body = String::from_utf8_lossy(&bytes).into_owned();
250
251 let body_with_script = inject_live_reload_script(&body, socket_addr, host);
252 let new_body_length = body_with_script.len();
253
254 // Copy the headers from the original response
255 let mut res = Response::new(body_with_script.into());
256 *res.headers_mut() = original_headers;
257
258 // Update Content-Length header to match new body size
259 res.headers_mut().insert(
260 CONTENT_LENGTH,
261 HeaderValue::from_str(&new_body_length.to_string()).unwrap(),
262 );
263
264 res.extensions_mut().insert(uri);
265
266 return res;
267 }
268
269 res
270}
271
272async fn add_cache_headers(req: Request, next: Next) -> Response {
273 let uri = req.uri().clone();
274 let mut res = next.run(req).await;
275
276 if let Some(content_type) = res.headers().get(axum::http::header::CONTENT_TYPE) {
277 let cache_header = cache_header_by_content(&uri, content_type);
278 if let Some(cache_header) = cache_header {
279 res.headers_mut()
280 .insert(header::CACHE_CONTROL, cache_header);
281 }
282 }
283
284 res
285}
286
287fn cache_header_by_content(uri: &Uri, content_type: &HeaderValue) -> Option<HeaderValue> {
288 if content_type == HeaderValue::from_static("text/html") {
289 // No cache for HTML files
290 Some(HeaderValue::from_static(
291 "no-cache, no-store, must-revalidate",
292 ))
293 }
294 // If something comes from the assets path, assume that it's fingerprinted and can be cached for a long time
295 // TODO: Same as dist, shouldn't be hardcoded
296 else if uri.path().starts_with("/_maudit/") {
297 Some(HeaderValue::from_static(
298 "public, max-age=31536000, immutable",
299 ))
300 } else {
301 // Don't try to cache anything else, the browser will decide based on the last-modified header
302 None
303 }
304}
305
306async fn ws_handler(
307 ws: WebSocketUpgrade,
308 ConnectInfo(addr): ConnectInfo<SocketAddr>,
309 State(state): State<AppState>,
310) -> impl IntoResponse {
311 debug!("`{addr} connected.");
312 // finalize the upgrade process by returning upgrade callback.
313 // we can customize the callback by sending additional info such as address.
314 ws.on_upgrade(move |socket| handle_socket(socket, addr, state.tx, state.current_status))
315}
316
317async fn handle_socket(
318 socket: WebSocket,
319 who: SocketAddr,
320 tx: broadcast::Sender<WebSocketMessage>,
321 current_status: Arc<RwLock<Option<PersistentStatus>>>,
322) {
323 let (mut sender, mut receiver) = socket.split();
324
325 // Send current persistent status to new connection if there is one
326 {
327 let status = current_status.read().await;
328 if let Some(persistent_status) = status.as_ref() {
329 let _ = sender
330 .send(Message::Text(
331 json!({
332 "type": persistent_status.status_type.to_string(),
333 "message": persistent_status.message
334 })
335 .to_string()
336 .into(),
337 ))
338 .await;
339 }
340 }
341
342 let mut rx = tx.subscribe();
343
344 tokio::select! {
345 _ = async {
346 while let Some(Ok(msg)) = receiver.next().await {
347 match msg {
348 Message::Text(_) => {}
349 Message::Binary(_) => {
350 }
351 _ => {}
352 }
353 }
354 } => {},
355 _ = async {
356 while let Ok(msg) = rx.recv().await {
357 debug!(">>> got messages from higher level: {0}", msg.data);
358 let _ = sender.send(Message::Text(msg.data.into())).await;
359 }
360 } => {},
361 }
362
363 // returning from the handler closes the websocket connection
364 debug!("Websocket context {who} destroyed");
365}
366
367async fn shutdown_signal() {
368 let ctrl_c = async {
369 signal::ctrl_c()
370 .await
371 .expect("failed to install Ctrl+C handler");
372 };
373
374 #[cfg(unix)]
375 let terminate = async {
376 signal::unix::signal(signal::unix::SignalKind::terminate())
377 .expect("failed to install signal handler")
378 .recv()
379 .await;
380 };
381
382 #[cfg(not(unix))]
383 let terminate = std::future::pending::<()>();
384
385 tokio::select! {
386 _ = ctrl_c => {},
387 _ = terminate => {},
388 }
389}