Rust library to generate static websites
at fix/misc-errors 389 lines 12 kB view raw
1use axum::{ 2 Router, 3 body::{Body, to_bytes}, 4 extract::{ 5 Request, State, 6 ws::{Message, WebSocket, WebSocketUpgrade}, 7 }, 8 http::{HeaderValue, StatusCode, Uri, header::CONTENT_LENGTH}, 9 middleware::{self, Next}, 10 response::{IntoResponse, Response}, 11 routing::get, 12}; 13use quanta::Instant; 14use serde_json::json; 15use tokio::{ 16 net::TcpSocket, 17 signal, 18 sync::{RwLock, broadcast}, 19}; 20use tracing::{Level, debug}; 21 22use std::net::{IpAddr, SocketAddr}; 23use std::sync::Arc; 24use tower_http::{ 25 services::ServeDir, 26 trace::{DefaultMakeSpan, TraceLayer}, 27}; 28 29use axum::extract::connect_info::ConnectInfo; 30use futures::{SinkExt, stream::StreamExt}; 31 32use crate::consts::PORT; 33use crate::server_utils::{CustomOnResponse, find_open_port, log_server_start}; 34use axum::http::header; 35use local_ip_address::local_ip; 36use tokio::fs; 37 38#[derive(Clone, Debug)] 39pub struct WebSocketMessage { 40 pub data: String, 41} 42 43#[derive(Clone, Debug)] 44pub enum StatusType { 45 Success, 46 Info, 47 Error, 48} 49 50impl std::fmt::Display for StatusType { 51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 52 match self { 53 StatusType::Success => write!(f, "success"), 54 StatusType::Info => write!(f, "info"), 55 StatusType::Error => write!(f, "error"), 56 } 57 } 58} 59 60// Persistent state for new connections 61#[derive(Clone, Debug)] 62pub struct PersistentStatus { 63 pub status_type: StatusType, // Only Success or Error 64 pub message: String, 65} 66 67#[derive(Clone)] 68struct AppState { 69 tx: broadcast::Sender<WebSocketMessage>, 70 current_status: Arc<RwLock<Option<PersistentStatus>>>, 71} 72 73fn inject_live_reload_script(html_content: &str, socket_addr: SocketAddr, host: bool) -> String { 74 let mut content = html_content.to_string(); 75 76 // Run cargo xtask build-cli-js to build the client.js file if missing 77 let script_content = include_str!("../../js/dist/client.js").replace( 78 "{SERVER_ADDRESS}", 79 &format!( 80 "{}:{}", 81 if !host { 82 socket_addr.ip().to_string() 83 } else { 84 local_ip().unwrap().to_string() 85 }, 86 &socket_addr.port().to_string() 87 ), 88 ); 89 90 content.push_str(&format!("\n\n<script>{script_content}</script>")); 91 content 92} 93 94pub async fn start_dev_web_server( 95 start_time: Instant, 96 tx: broadcast::Sender<WebSocketMessage>, 97 host: bool, 98 port: Option<u16>, 99 initial_error: Option<String>, 100 current_status: Arc<RwLock<Option<PersistentStatus>>>, 101) { 102 // TODO: The dist dir should be configurable 103 let dist_dir = "dist"; 104 105 // Send initial error if present 106 if let Some(error) = initial_error { 107 let _ = tx.send(WebSocketMessage { 108 data: json!({ 109 "type": StatusType::Error.to_string(), 110 "message": error 111 }) 112 .to_string(), 113 }); 114 } 115 116 async fn handle_404(socket_addr: SocketAddr, host: bool, dist_dir: &str) -> impl IntoResponse { 117 let content = match fs::read_to_string(format!("{}/404.html", dist_dir)).await { 118 Ok(custom_content) => custom_content, 119 Err(_) => include_str!("./404.html").to_string(), 120 }; 121 122 ( 123 StatusCode::NOT_FOUND, 124 [(header::CONTENT_TYPE, "text/html; charset=utf-8")], 125 inject_live_reload_script(&content, socket_addr, host), 126 ) 127 .into_response() 128 } 129 130 // run it with hyper, if --host 0.0.0.0 otherwise localhost 131 let addr = if host { 132 IpAddr::from([0, 0, 0, 0]) 133 } else { 134 IpAddr::from([127, 0, 0, 1]) 135 }; 136 137 // Use provided port or default to the constant PORT 138 let starting_port = port.unwrap_or(PORT); 139 140 let port = find_open_port(&addr, starting_port).await; 141 let socket = TcpSocket::new_v4().unwrap(); 142 let _ = socket.set_reuseaddr(true); 143 144 let socket_addr = SocketAddr::new(addr, port); 145 socket.bind(socket_addr).unwrap(); 146 147 let listener = socket.listen(1024).unwrap(); 148 149 debug!("listening on {}", listener.local_addr().unwrap()); 150 151 let serve_dir = 152 ServeDir::new(dist_dir).not_found_service(axum::routing::any(move || async move { 153 handle_404(socket_addr, host, dist_dir).await 154 })); 155 156 // TODO: Return a `.well-known/appspecific/com.chrome.devtools.json` for Chrome 157 158 let router = Router::new() 159 .route("/ws", get(ws_handler)) 160 .fallback_service(serve_dir) 161 .layer(middleware::from_fn(add_cache_headers)) 162 .layer(middleware::from_fn(move |req, next| { 163 add_dev_client_script(req, next, socket_addr, host) 164 })) 165 .layer( 166 TraceLayer::new_for_http() 167 .make_span_with(DefaultMakeSpan::default().include_headers(true)), 168 ) 169 .layer( 170 TraceLayer::new_for_http() 171 .make_span_with(DefaultMakeSpan::new().level(Level::INFO)) 172 .on_response(CustomOnResponse), 173 ) 174 .with_state(AppState { 175 tx: tx.clone(), 176 current_status: current_status.clone(), 177 }); 178 179 log_server_start( 180 start_time, 181 host, 182 listener.local_addr().unwrap(), 183 "Development", 184 ); 185 186 axum::serve( 187 listener, 188 router.into_make_service_with_connect_info::<SocketAddr>(), 189 ) 190 .with_graceful_shutdown(shutdown_signal()) 191 .await 192 .unwrap(); 193} 194 195pub async fn update_status( 196 tx: &broadcast::Sender<WebSocketMessage>, 197 current_status: Arc<RwLock<Option<PersistentStatus>>>, 198 status_type: StatusType, 199 message: &str, 200) { 201 // Only store persistent states (Success clears errors, Error stores the error) 202 let persistent_status = match status_type { 203 StatusType::Success => None, // Clear any error state 204 StatusType::Error => Some(PersistentStatus { 205 status_type: StatusType::Error, 206 message: message.to_string(), 207 }), 208 // Everything else just keeps the current state 209 _ => { 210 let status = current_status.read().await; 211 status.clone() // Keep existing persistent state 212 } 213 }; 214 215 // Update the stored status 216 { 217 let mut status = current_status.write().await; 218 *status = persistent_status; 219 } 220 221 // Send the message to all connected clients 222 let _ = tx.send(WebSocketMessage { 223 data: json!({ 224 "type": status_type.to_string(), 225 "message": message 226 }) 227 .to_string(), 228 }); 229} 230 231async fn add_dev_client_script( 232 req: Request, 233 next: Next, 234 socket_addr: SocketAddr, 235 host: bool, 236) -> Response { 237 let uri = req.uri().clone(); 238 let mut res: axum::http::Response<Body> = next.run(req).await; 239 240 res.extensions_mut().insert(uri.clone()); 241 242 if res.headers().get(axum::http::header::CONTENT_TYPE) 243 == Some(&HeaderValue::from_static("text/html")) 244 { 245 let original_headers = res.headers().clone(); 246 let body = res.into_body(); 247 let bytes = to_bytes(body, usize::MAX).await.unwrap(); 248 249 let body = String::from_utf8_lossy(&bytes).into_owned(); 250 251 let body_with_script = inject_live_reload_script(&body, socket_addr, host); 252 let new_body_length = body_with_script.len(); 253 254 // Copy the headers from the original response 255 let mut res = Response::new(body_with_script.into()); 256 *res.headers_mut() = original_headers; 257 258 // Update Content-Length header to match new body size 259 res.headers_mut().insert( 260 CONTENT_LENGTH, 261 HeaderValue::from_str(&new_body_length.to_string()).unwrap(), 262 ); 263 264 res.extensions_mut().insert(uri); 265 266 return res; 267 } 268 269 res 270} 271 272async fn add_cache_headers(req: Request, next: Next) -> Response { 273 let uri = req.uri().clone(); 274 let mut res = next.run(req).await; 275 276 if let Some(content_type) = res.headers().get(axum::http::header::CONTENT_TYPE) { 277 let cache_header = cache_header_by_content(&uri, content_type); 278 if let Some(cache_header) = cache_header { 279 res.headers_mut() 280 .insert(header::CACHE_CONTROL, cache_header); 281 } 282 } 283 284 res 285} 286 287fn cache_header_by_content(uri: &Uri, content_type: &HeaderValue) -> Option<HeaderValue> { 288 if content_type == HeaderValue::from_static("text/html") { 289 // No cache for HTML files 290 Some(HeaderValue::from_static( 291 "no-cache, no-store, must-revalidate", 292 )) 293 } 294 // If something comes from the assets path, assume that it's fingerprinted and can be cached for a long time 295 // TODO: Same as dist, shouldn't be hardcoded 296 else if uri.path().starts_with("/_maudit/") { 297 Some(HeaderValue::from_static( 298 "public, max-age=31536000, immutable", 299 )) 300 } else { 301 // Don't try to cache anything else, the browser will decide based on the last-modified header 302 None 303 } 304} 305 306async fn ws_handler( 307 ws: WebSocketUpgrade, 308 ConnectInfo(addr): ConnectInfo<SocketAddr>, 309 State(state): State<AppState>, 310) -> impl IntoResponse { 311 debug!("`{addr} connected."); 312 // finalize the upgrade process by returning upgrade callback. 313 // we can customize the callback by sending additional info such as address. 314 ws.on_upgrade(move |socket| handle_socket(socket, addr, state.tx, state.current_status)) 315} 316 317async fn handle_socket( 318 socket: WebSocket, 319 who: SocketAddr, 320 tx: broadcast::Sender<WebSocketMessage>, 321 current_status: Arc<RwLock<Option<PersistentStatus>>>, 322) { 323 let (mut sender, mut receiver) = socket.split(); 324 325 // Send current persistent status to new connection if there is one 326 { 327 let status = current_status.read().await; 328 if let Some(persistent_status) = status.as_ref() { 329 let _ = sender 330 .send(Message::Text( 331 json!({ 332 "type": persistent_status.status_type.to_string(), 333 "message": persistent_status.message 334 }) 335 .to_string() 336 .into(), 337 )) 338 .await; 339 } 340 } 341 342 let mut rx = tx.subscribe(); 343 344 tokio::select! { 345 _ = async { 346 while let Some(Ok(msg)) = receiver.next().await { 347 match msg { 348 Message::Text(_) => {} 349 Message::Binary(_) => { 350 } 351 _ => {} 352 } 353 } 354 } => {}, 355 _ = async { 356 while let Ok(msg) = rx.recv().await { 357 debug!(">>> got messages from higher level: {0}", msg.data); 358 let _ = sender.send(Message::Text(msg.data.into())).await; 359 } 360 } => {}, 361 } 362 363 // returning from the handler closes the websocket connection 364 debug!("Websocket context {who} destroyed"); 365} 366 367async fn shutdown_signal() { 368 let ctrl_c = async { 369 signal::ctrl_c() 370 .await 371 .expect("failed to install Ctrl+C handler"); 372 }; 373 374 #[cfg(unix)] 375 let terminate = async { 376 signal::unix::signal(signal::unix::SignalKind::terminate()) 377 .expect("failed to install signal handler") 378 .recv() 379 .await; 380 }; 381 382 #[cfg(not(unix))] 383 let terminate = std::future::pending::<()>(); 384 385 tokio::select! { 386 _ = ctrl_c => {}, 387 _ = terminate => {}, 388 } 389}