Your one-stop cake shop for everything Freshly Baked has to offer
1// SPDX-FileCopyrightText: 2026 Freshly Baked Cake
2//
3// SPDX-License-Identifier: MIT
4mod auth;
5mod direct;
6mod regex;
7mod static_html;
8
9use axum::{
10 Router, ServiceExt,
11 extract::{Path, Query, Request},
12 http::{HeaderMap, StatusCode},
13 response::{IntoResponse, Redirect, Response, Result},
14 routing::get,
15};
16use include_dir::{Dir, include_dir};
17use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
18use phf::phf_map;
19use sqlx::{Connection, PgConnection};
20use tower_sessions::{MemoryStore, Session, SessionManagerLayer};
21
22use std::{collections::HashMap, env, sync::OnceLock};
23use tokio::{
24 sync::Mutex,
25 time::{Duration, sleep},
26};
27use tower_http::{self, normalize_path::NormalizePathLayer};
28use tower_layer::Layer;
29use tower_serve_static;
30
31use crate::{
32 auth::{ensure_authenticated, ensure_token},
33 static_html::{StaticPageType, handle_static_page},
34};
35
/// Contents of `src/html/public`, embedded into the binary by `include_dir!`.
///
/// `main` serves this under `/_/public` in release builds, and in debug builds whenever the
/// `DEVELOPMENT` flag is not `true` (otherwise the directory is served from disk).
static PUBLIC_DIR: Dir<'static> = include_dir!("src/html/public");
37
/// Debug builds only: whether the `DEVELOPMENT` environment variable equals `"true"`.
///
/// Set once at the very start of `main`. When `true`, `/_/public` is served straight from
/// `src/html/public` on disk instead of from the embedded [`PUBLIC_DIR`].
#[cfg(debug_assertions)]
static DEVELOPMENT: OnceLock<bool> = OnceLock::new();
40
/// Process-wide application state, created once in `main` and stored in [`STATE`].
#[derive(Debug)]
struct State {
    /// The single Postgres connection opened by `main` (after migrations have run).
    /// It sits behind a `tokio::sync::Mutex`, so concurrent users take turns with it.
    sqlx_connection: Mutex<PgConnection>,
}
/// Global handle to [`State`]; `main` sets it exactly once and treats a second set as a bug.
static STATE: OnceLock<State> = OnceLock::new();
46
/// Hostnames that [`clean_host`] passes through unchanged; any other host is replaced by `"go"`.
///
/// `'static` is implied for references in a `const`, so it is not spelled out here.
const ALLOWED_HOSTS: &[&str] = &[
    "cakeme.nu",
    "go",
    "go.search.freshly.space",
    "menu.freshlybakedca.ke",
    "starry.sk",
];
54
/// Supported search engines, keyed by the identifier accepted in the `engine` query parameter.
///
/// Each value is a positional array:
/// - `[0]`: display name. Used (HTML-escaped) in the OpenSearch `<link>` title and
///   (percent-encoded) as the `name` query parameter in `get_search_engines`.
/// - `[1]`: search URL prefix. `get_redirect_search` appends the percent-encoded query to it.
/// - `[2]`: autosuggest URL prefix. `handle_search_suggest` appends the percent-encoded query.
///
/// When no `engine` parameter is supplied, the handlers in this file fall back to `"kagi"`.
const SEARCH_ENGINES: phf::Map<&'static str, [&'static str; 3]> = phf_map! {
    "kagi" => [
        "Kagi",
        "https://kagi.com/search?q=",
        "https://kagi.com/api/autosuggest?q=",
    ],
    "google" => [
        "Google",
        "https://www.google.com/search?q=",
        "https://www.google.com/complete/search?q=",
    ],
    "udm14" => [
        "Google+UDM14",
        "https://www.google.com/search?udm=14&q=",
        "https://www.google.com/complete/search?q=",
    ],
    "ddg" => [
        "DuckDuckGo",
        "https://duckduckgo.com?q=",
        "https://duckduckgo.com/ac/?q=",
    ],
    "noai" => [
        "DuckDuckGo+NoAI",
        "https://noai.duckduckgo.com?q=",
        "https://noai.duckduckgo.com/ac/?q=",
    ]
};
82
83fn clean_host(provided_host: &str) -> &str {
84 if ALLOWED_HOSTS.contains(&provided_host) {
85 return provided_host;
86 }
87
88 return "go";
89}
90
91async fn get_redirect(go: &str) -> Option<Redirect> {
92 if let Some(redirect) = direct::get_redirect(go).await {
93 return Some(redirect);
94 }
95
96 if let Some(redirect) = regex::get_redirect(go).await {
97 return Some(redirect);
98 }
99
100 None
101}
102
103async fn get_redirect_base(go: &str) -> Redirect {
104 get_redirect(go).await.unwrap_or_else(|| {
105 Redirect::temporary(
106 &("/_/create?format=direct&from=".to_string()
107 + &utf8_percent_encode(go, NON_ALPHANUMERIC).to_string()),
108 )
109 })
110}
111
112struct InvalidSearchEngine {
113 engine: String,
114}
115impl IntoResponse for InvalidSearchEngine {
116 fn into_response(self) -> Response {
117 return (
118 StatusCode::NOT_FOUND,
119 format!("Invalid Search Engine {}", self.engine),
120 )
121 .into_response();
122 }
123}
124
125async fn handle_search_suggest(Query(params): Query<HashMap<String, String>>) -> Result<String> {
126 if let Some(q) = params.get("q") {
127 let Some(search_engine_metadata) = (match params.get("engine") {
128 Some(e) => SEARCH_ENGINES.get(e),
129 None => SEARCH_ENGINES.get("kagi"), // This is the default for historical reasons ...
130 }) else {
131 return Err(InvalidSearchEngine {
132 engine: params
133 .get("engine")
134 .and_then(|e| Some(e.as_str()))
135 .unwrap_or("null")
136 .to_owned(),
137 }
138 .into());
139 };
140
141 Ok(reqwest::get(
142 &(search_engine_metadata[2].to_owned()
143 + &utf8_percent_encode(q, NON_ALPHANUMERIC).to_string()),
144 )
145 .await
146 .map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?
147 .text()
148 .await
149 .map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?)
150 } else {
151 Err(StatusCode::BAD_REQUEST.into())
152 }
153}
154
155async fn get_redirect_search(go: &str, engine: Option<&str>) -> Result<Redirect> {
156 get_redirect(go)
157 .await
158 .and_then(|r| Some(Ok(r)))
159 .unwrap_or_else(|| {
160 let Some(search_engine_metadata) = (match engine {
161 Some(e) => SEARCH_ENGINES.get(e),
162 None => SEARCH_ENGINES.get("kagi"), // This is the default for historical reasons ...
163 }) else {
164 return Err(InvalidSearchEngine {
165 engine: engine.unwrap_or("null").to_owned(),
166 }
167 .into());
168 };
169
170 Ok(Redirect::temporary(
171 &(search_engine_metadata[1].to_owned()
172 + &utf8_percent_encode(go, NON_ALPHANUMERIC).to_string()),
173 ))
174 })
175}
176
177fn get_search_engines() -> String {
178 let mut result = "".to_owned();
179 for (engine, meta) in SEARCH_ENGINES.entries() {
180 let engine_url = utf8_percent_encode(engine, NON_ALPHANUMERIC).to_string();
181 let name_attr = html_escape::encode_quoted_attribute(meta[0]);
182 let name_url = utf8_percent_encode(meta[0], NON_ALPHANUMERIC).to_string();
183
184 result += format!(
185 r#"<link rel="search" type="application/opensearchdescription+xml" title="Menu {name_attr}" href="/_/opensearch.xml?name={name_url}&engine={engine_url}" />"#
186 ).as_str();
187 }
188
189 result
190}
191
192#[axum::debug_handler]
193async fn handle_index(
194 session: Session,
195 headers: HeaderMap,
196 Query(params): Query<HashMap<String, String>>,
197) -> Result<Response> {
198 handle_static_page(StaticPageType::Index, session, ¶ms, &headers).await
199}
200
201async fn handle_base(Path(go): Path<String>) -> Redirect {
202 get_redirect_base(&go).await
203}
204
205async fn handle_search(Query(params): Query<HashMap<String, String>>) -> Result<Redirect> {
206 if let Some(go) = params.get("q") {
207 get_redirect_search(&go, params.get("engine").and_then(|s| Some(s.as_str()))).await
208 } else {
209 Ok(Redirect::temporary("/"))
210 }
211}
212
/// Outcome of `direct::create` / `regex::create`.
///
/// `handle_create_do` maps each variant to the create success, conflict or failure page.
enum CreationResult {
    /// Reported to the user via `/_/create/success`.
    Success,
    /// Reported via `/_/create/conflict`. The payload is forwarded as the `current` query
    /// parameter — presumably the existing target of `from`; TODO confirm in `direct`/`regex`.
    Conflict(String),
    /// Reported via `/_/create/failure`.
    Failure,
}
218
/// Outcome of `direct::delete` / `regex::delete`, as consumed by `handle_delete_do`.
enum DeletionResult {
    /// Reported to the user via `/_/delete/success`.
    Success,
    /// Presumably means no matching redirect existed (TODO confirm).
    /// `handle_delete_do` sends this to the failure page, just like `Failure`.
    NotFound,
    /// Reported via `/_/delete/failure`.
    Failure,
}
224
225async fn handle_create_page(
226 session: Session,
227 Query(params): Query<HashMap<String, String>>,
228 headers: HeaderMap,
229) -> Result<Response> {
230 handle_static_page(StaticPageType::Create, session, ¶ms, &headers).await
231}
232async fn handle_create_success_page(
233 session: Session,
234 Query(params): Query<HashMap<String, String>>,
235 headers: HeaderMap,
236) -> Result<Response> {
237 match params.get("format").and_then(|s| Some(s.as_str())) {
238 Some("direct") => {
239 handle_static_page(
240 StaticPageType::CreateDirectSuccess,
241 session,
242 ¶ms,
243 &headers,
244 )
245 .await
246 }
247 Some("regex") => {
248 handle_static_page(
249 StaticPageType::CreateRegexSuccess,
250 session,
251 ¶ms,
252 &headers,
253 )
254 .await
255 }
256 _ => Err("Invalid format".into()),
257 }
258}
259async fn handle_create_conflict_page(
260 session: Session,
261 Query(params): Query<HashMap<String, String>>,
262 headers: HeaderMap,
263) -> Result<Response> {
264 match params.get("format").and_then(|s| Some(s.as_str())) {
265 Some("direct") => {
266 handle_static_page(
267 StaticPageType::CreateDirectConflict,
268 session,
269 ¶ms,
270 &headers,
271 )
272 .await
273 }
274 Some("regex") => {
275 handle_static_page(
276 StaticPageType::CreateRegexConflict,
277 session,
278 ¶ms,
279 &headers,
280 )
281 .await
282 }
283 _ => Err("Invalid format".into()),
284 }
285}
286async fn handle_create_failure_page(
287 session: Session,
288 Query(params): Query<HashMap<String, String>>,
289 headers: HeaderMap,
290) -> Result<impl IntoResponse> {
291 handle_static_page(StaticPageType::CreateFailure, session, ¶ms, &headers)
292 .await
293 .and_then(|html| Ok((StatusCode::INTERNAL_SERVER_ERROR, html)))
294}
295async fn handle_delete_success_page(
296 session: Session,
297 Query(params): Query<HashMap<String, String>>,
298 headers: HeaderMap,
299) -> Result<Response> {
300 match params.get("format").and_then(|s| Some(s.as_str())) {
301 Some("direct") => {
302 handle_static_page(
303 StaticPageType::DeleteDirectSuccess,
304 session,
305 ¶ms,
306 &headers,
307 )
308 .await
309 }
310 Some("regex") => {
311 handle_static_page(
312 StaticPageType::DeleteRegexSuccess,
313 session,
314 ¶ms,
315 &headers,
316 )
317 .await
318 }
319 _ => Err("Invalid format".into()),
320 }
321}
322async fn handle_delete_failure_page(
323 session: Session,
324 Query(params): Query<HashMap<String, String>>,
325 headers: HeaderMap,
326) -> Result<Response> {
327 handle_static_page(StaticPageType::DeleteFailure, session, ¶ms, &headers).await
328}
329async fn handle_opensearch_xml_page(
330 session: Session,
331 Query(params): Query<HashMap<String, String>>,
332 headers: HeaderMap,
333) -> Result<Response> {
334 handle_static_page(StaticPageType::OpenSearch, session, ¶ms, &headers).await
335}
336
337#[axum::debug_handler]
338async fn handle_create_do(
339 session: Session,
340 headers: HeaderMap,
341 Query(params): Query<HashMap<String, String>>,
342) -> Result<Response> {
343 ensure_token(&session, ¶ms).await?;
344 let owner = ensure_authenticated(
345 &headers,
346 #[cfg(debug_assertions)]
347 ¶ms,
348 )?;
349
350 let from = params.get("from").ok_or("Missing from query")?;
351 let to = params.get("to").ok_or("Missing to query")?;
352 let format = params.get("format").ok_or("Missing format query")?;
353
354 let create_response = match format.as_str() {
355 "direct" => direct::create(from, to, owner, params.get("current")).await,
356 "regex" => regex::create(from, to, owner, params.get("current")).await,
357 _ => return Err(format!("Invalid format {}", format).into_response().into()),
358 };
359
360 match create_response {
361 CreationResult::Success => Ok(Redirect::to(&format!(
362 "/_/create/success?from={}&to={}&format={}",
363 utf8_percent_encode(&from, NON_ALPHANUMERIC).to_string(),
364 utf8_percent_encode(&to, NON_ALPHANUMERIC).to_string(),
365 utf8_percent_encode(&format, NON_ALPHANUMERIC).to_string(),
366 ))
367 .into_response()),
368 CreationResult::Conflict(conflict) => Ok(Redirect::to(&format!(
369 "/_/create/conflict?from={}&to={}¤t={}&format={}",
370 utf8_percent_encode(&from, NON_ALPHANUMERIC).to_string(),
371 utf8_percent_encode(&to, NON_ALPHANUMERIC).to_string(),
372 utf8_percent_encode(&conflict, NON_ALPHANUMERIC).to_string(),
373 utf8_percent_encode(&format, NON_ALPHANUMERIC).to_string(),
374 ))
375 .into_response()),
376 CreationResult::Failure => Ok(Redirect::to(&format!(
377 "/_/create/failure?from={}&to={}&format={}",
378 utf8_percent_encode(&from, NON_ALPHANUMERIC).to_string(),
379 utf8_percent_encode(&to, NON_ALPHANUMERIC).to_string(),
380 utf8_percent_encode(&format, NON_ALPHANUMERIC).to_string(),
381 ))
382 .into_response()),
383 }
384}
385
386#[axum::debug_handler]
387async fn handle_delete_do(
388 session: Session,
389 headers: HeaderMap,
390 Query(params): Query<HashMap<String, String>>,
391) -> Result<Response> {
392 ensure_token(&session, ¶ms).await?;
393 ensure_authenticated(
394 &headers,
395 #[cfg(debug_assertions)]
396 ¶ms,
397 )?;
398
399 let from = params.get("from").ok_or("Missing from query")?;
400 let current = params.get("current").ok_or("Missing current query")?;
401 let format = params.get("format").ok_or("Missing format query")?;
402
403 let delete_result = match format.as_str() {
404 "direct" => direct::delete(from, current).await,
405 "regex" => regex::delete(from, current).await,
406 _ => return Err(format!("Invalid format {}", format).into_response().into()),
407 };
408
409 match delete_result {
410 DeletionResult::Success => Ok(Redirect::to(&format!(
411 "/_/delete/success?from={}¤t={}&format={}",
412 utf8_percent_encode(&from, NON_ALPHANUMERIC).to_string(),
413 utf8_percent_encode(¤t, NON_ALPHANUMERIC).to_string(),
414 utf8_percent_encode(&format, NON_ALPHANUMERIC).to_string(),
415 ))
416 .into_response()),
417 _ => Ok(Redirect::to(&format!(
418 "/_/delete/failure?from={}&to={}&format={}",
419 utf8_percent_encode(&from, NON_ALPHANUMERIC).to_string(),
420 utf8_percent_encode(¤t, NON_ALPHANUMERIC).to_string(),
421 utf8_percent_encode(&format, NON_ALPHANUMERIC).to_string(),
422 ))
423 .into_response()),
424 }
425}
426
427async fn handle_404() -> impl IntoResponse {
428 (StatusCode::NOT_FOUND, "Not Found")
429}
430
431#[tokio::main]
432async fn main() {
433 #[cfg(debug_assertions)]
434 {
435 DEVELOPMENT
436 .set(env::var("DEVELOPMENT").is_ok_and(|value| value == "true"))
437 .unwrap();
438 }
439 let mut connection = {
440 let mut maybe_connection;
441 let mut tries = 3;
442 loop {
443 // We can't use a for loop here as rust doesn't know it will run at least once...
444 tries -= 1;
445 maybe_connection = PgConnection::connect(
446 env::var("DATABASE_URL")
447 .expect(
448 "Please ensure you set your database URL in the $DATABASE_URL environment variable",
449 )
450 .as_str(),
451 )
452 .await;
453
454 if maybe_connection.is_ok() || tries == 0 {
455 break;
456 }
457
458 sleep(Duration::from_secs(5)).await;
459 }
460
461 maybe_connection
462 .expect("Failed to connect to database defined in $DATABASE_URL after 3 retries")
463 };
464
465 let session_layer = {
466 let session_store = MemoryStore::default();
467 SessionManagerLayer::new(session_store).with_secure(false) // must be false for go:// support
468 };
469
470 sqlx::migrate!()
471 .run(&mut connection)
472 .await
473 .expect("Failed to run database migrations");
474
475 STATE
476 .set(State {
477 sqlx_connection: Mutex::new(connection),
478 })
479 .expect("Consistency issue: failed to set STATE - was it already set?");
480
481 let mut router = Router::new();
482
483 #[cfg(not(debug_assertions))]
484 {
485 router = router.nest_service("/_/public", tower_serve_static::ServeDir::new(&PUBLIC_DIR));
486 }
487
488 #[cfg(debug_assertions)]
489 {
490 if *DEVELOPMENT.get().unwrap() {
491 router = router.nest_service(
492 "/_/public",
493 tower_http::services::ServeDir::new("src/html/public"),
494 );
495 } else {
496 router =
497 router.nest_service("/_/public", tower_serve_static::ServeDir::new(&PUBLIC_DIR));
498 }
499 }
500
501 router = router
502 .route("/", get(handle_index))
503 .route("/_/create", get(handle_create_page))
504 .route("/_/create/success", get(handle_create_success_page))
505 .route("/_/create/conflict", get(handle_create_conflict_page))
506 .route("/_/create/failure", get(handle_create_failure_page))
507 .route("/_/create/do", get(handle_create_do))
508 .route("/_/delete/do", get(handle_delete_do))
509 .route("/_/delete/success", get(handle_delete_success_page))
510 .route("/_/delete/failure", get(handle_delete_failure_page))
511 .route("/_/search", get(handle_search))
512 .route("/_/suggest", get(handle_search_suggest))
513 .route("/_/opensearch.xml", get(handle_opensearch_xml_page))
514 .route("/_/{*route}", get(handle_404))
515 .route("/{*go}", get(handle_base));
516 let app = NormalizePathLayer::trim_trailing_slash().layer(router.layer(session_layer));
517
518 let listener = tokio::net::TcpListener::bind(
519 env::var("BIND_ADDR").unwrap_or_else(|_| "0.0.0.0:3000".to_string()),
520 )
521 .await
522 .unwrap();
523 axum::serve(listener, ServiceExt::<Request>::into_make_service(app))
524 .await
525 .unwrap();
526}