this repo has no description
at main 451 lines 16 kB view raw
1use crate::{ 2 session::AxumSessionStore, 3 templates::HtmlTemplate, 4 templates::error::ErrorTemplate, 5 templates::home::{DayStatus, HomeTemplate}, 6}; 7use atrium_api::agent::atp_agent::AtpAgent; 8use atrium_api::agent::atp_agent::store::MemorySessionStore; 9use atrium_identity::{ 10 did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}, 11 handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig}, 12}; 13use atrium_oauth::{ 14 AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, DefaultHttpClient, 15 GrantType, KnownScope, OAuthClient, OAuthClientConfig, OAuthResolverConfig, Scope, 16}; 17use atrium_xrpc_client::reqwest::ReqwestClient; 18use axum::{ 19 Router, 20 extract::State, 21 http::StatusCode, 22 middleware, 23 response::IntoResponse, 24 response::Response, 25 routing::{get, post}, 26}; 27use bb8_redis::RedisConnectionManager; 28use chrono::Datelike; 29use dotenv::dotenv; 30use redis::AsyncCommands; 31use rust_embed::RustEmbed; 32use shared::{ 33 HandleResolver, OAuthClientType, PasswordAgent, 34 advent::{CompletionStatus, get_all_days_completion_status, get_global_unlock_day}, 35 atrium::dns_resolver::HickoryDnsTxtResolver, 36 atrium::stores::AtriumSessionStore, 37 atrium::stores::AtriumStateStore, 38}; 39use sqlx::{PgPool, migrate::Migrator, postgres::PgPoolOptions}; 40use std::{env, net::SocketAddr, sync::Arc, time}; 41use time::Duration; 42use tower_http::trace::TraceLayer; 43use tower_sessions::{SessionManagerLayer, cookie::SameSite}; 44use tracing_subscriber::EnvFilter; 45 46mod handlers; 47 48extern crate dotenv; 49// 50mod extractors; 51mod redis_session_store; 52mod session; 53mod templates; 54mod unlock; 55 56#[derive(RustEmbed, Clone)] 57#[folder = "./public"] 58struct Assets; 59 60#[derive(Clone)] 61pub(crate) struct AppState { 62 postgres_pool: PgPool, 63 redis_pool: bb8::Pool<RedisConnectionManager>, 64 oauth_client: OAuthClientType, 65 //Used to get did to handle leaving because I figured we'd need it 66 handle_resolver: HandleResolver, 67 challenge_agent: Option<PasswordAgent>, 68 secret_agent: Option<PasswordAgent>, 69} 70 71pub fn oauth_scopes() -> Vec<Scope> { 72 vec![ 73 Scope::Known(KnownScope::Atproto), 74 // Scope::Known(KnownScope::TransitionGeneric), 75 //This looks like it HAS to have the full collection name, before i want to say it worked with wildcard 76 //Gives full CRUD to the codes.advent.* collection 77 // Scope::Unknown("repo:codes.advent.test".to_string()), 78 ] 79} 80 81fn error_response(status: StatusCode, message: &str) -> Response { 82 IntoResponse::into_response(( 83 status, 84 HtmlTemplate(ErrorTemplate { 85 title: "at://advent - Error", 86 message, 87 is_logged_in: false, 88 }), 89 )) 90} 91 92fn build_oauth_client( 93 host: &str, 94 port: u16, 95 redis_pool: bb8::Pool<RedisConnectionManager>, 96) -> OAuthClientType { 97 let http_client = Arc::new(DefaultHttpClient::default()); 98 let state_store = AtriumStateStore::new(redis_pool.clone()); 99 let session_store = AtriumSessionStore::new(redis_pool); 100 let resolver = OAuthResolverConfig { 101 did_resolver: CommonDidResolver::new(CommonDidResolverConfig { 102 plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), 103 http_client: http_client.clone(), 104 }), 105 handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 106 dns_txt_resolver: HickoryDnsTxtResolver::default(), 107 http_client: http_client.clone(), 108 }), 109 authorization_server_metadata: Default::default(), 110 protected_resource_metadata: Default::default(), 111 }; 112 113 if let Ok(oauth_host) = env::var("OAUTH_HOST") { 114 let config = OAuthClientConfig { 115 client_metadata: AtprotoClientMetadata { 116 client_id: format!("https://{oauth_host}/oauth-client-metadata.json"), 117 client_uri: Some(format!("https://{oauth_host}")), 118 redirect_uris: vec![format!("https://{oauth_host}/oauth/callback")], 119 token_endpoint_auth_method: AuthMethod::None, 120 grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 121 scopes: oauth_scopes(), 122 jwks_uri: None, 123 token_endpoint_auth_signing_alg: None, 124 }, 125 keys: None, 126 resolver, 127 state_store, 128 session_store, 129 }; 130 Arc::new(OAuthClient::new(config).expect("failed to create OAuth client")) 131 } else { 132 let config = OAuthClientConfig { 133 client_metadata: AtprotoLocalhostClientMetadata { 134 redirect_uris: Some(vec![format!("http://{host}:{port}/oauth/callback")]), 135 scopes: Some(oauth_scopes()), 136 }, 137 keys: None, 138 resolver, 139 state_store, 140 session_store, 141 }; 142 Arc::new(OAuthClient::new(config).expect("failed to create OAuth client")) 143 } 144} 145 146#[tokio::main] 147async fn main() -> Result<(), Box<dyn std::error::Error>> { 148 dotenv().ok(); 149 150 //Sets up logging/tracing 151 tracing_subscriber::fmt() 152 .with_env_filter( 153 EnvFilter::try_from_default_env() 154 .or_else(|_| EnvFilter::try_new("info,axum_tracing_example=error,tower_http=warn")) 155 .unwrap(), 156 ) 157 .init(); 158 159 let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 160 let port: u16 = env::var("PORT") 161 .unwrap_or_else(|_| "7878".to_string()) 162 .parse() 163 .expect("PORT must be a number"); 164 165 let addr = SocketAddr::new(host.parse().expect("Invalid HOST address"), port); 166 let host = addr.ip(); 167 let port = addr.port(); 168 let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); 169 170 //sqlx pool 171 let database_url = 172 env::var("DATABASE_URL").expect("DATABASE_URL must be set in the environment or .env"); 173 174 // set up a postgres connection pool 175 let postgres_pool = PgPoolOptions::new() 176 .max_connections(5) 177 .acquire_timeout(Duration::from_secs(3)) 178 .connect(&database_url) 179 .await 180 .expect("can't connect to database"); 181 182 // Run database migrations 183 static MIGRATOR: Migrator = sqlx::migrate!("../migrations"); 184 MIGRATOR 185 .run(&postgres_pool) 186 .await 187 .expect("failed to run database migrations"); 188 log::info!("database migrations applied successfully"); 189 190 // redis pool setup 191 let redis_url = 192 env::var("REDIS_URL").expect("REDIS_URL must be set in the environment or .env"); 193 let manager = RedisConnectionManager::new(redis_url.clone()).unwrap(); 194 let redis_pool = bb8::Pool::builder().build(manager).await.unwrap(); 195 //cam be deleted, just an example for the test endpoint 196 { 197 // ping the database before starting 198 let mut conn = redis_pool.get().await.unwrap(); 199 conn.set::<&str, &str, ()>("foo", "bar").await.unwrap(); 200 let result: String = conn.get("foo").await.unwrap(); 201 assert_eq!(result, "bar"); 202 } 203 204 //Atrium/atproto setup 205 206 //Create a new handle resolver for the home page 207 let http_client = Arc::new(DefaultHttpClient::default()); 208 209 let handle_resolver = CommonDidResolver::new(CommonDidResolverConfig { 210 plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), 211 http_client: http_client.clone(), 212 }); 213 let handle_resolver = Arc::new(handle_resolver); 214 215 let client = build_oauth_client(&host.to_string(), port, redis_pool.clone()); 216 217 let session_store = redis_session_store::RedisSessionStore::new(redis_pool.clone()); 218 let session_layer = SessionManagerLayer::new(session_store) 219 //Set to lax so session id cookie can be set on redirect 220 .with_same_site(SameSite::Lax) 221 .with_secure(false); 222 223 // challenge account 224 let mut challenge_agent = None; 225 let challenge_pds = env::var("CHALLENGE_PDS"); 226 let challenge_identity = env::var("CHALLENGE_IDENTITY"); 227 let challenge_password = env::var("CHALLENGE_PASSWORD"); 228 if let (Ok(pds), Ok(identity), Ok(password)) = 229 (challenge_pds, challenge_identity, challenge_password) 230 { 231 let agent = AtpAgent::new(ReqwestClient::new(pds), MemorySessionStore::default()); 232 agent.login(identity, password).await?; 233 challenge_agent = Some(Arc::new(agent)); 234 } 235 236 // secret challenge account 237 let mut secret_challenge_agent = None; 238 let secret_challenge_pds = env::var("SECRET_CHALLENGE_PDS"); 239 let secret_challenge_identity = env::var("SECRET_CHALLENGE_IDENTITY"); 240 let secret_challenge_password = env::var("SECRET_CHALLENGE_PASSWORD"); 241 if let (Ok(pds), Ok(identity), Ok(password)) = ( 242 secret_challenge_pds, 243 secret_challenge_identity, 244 secret_challenge_password, 245 ) { 246 let agent = AtpAgent::new(ReqwestClient::new(pds), MemorySessionStore::default()); 247 agent.login(identity, password).await?; 248 secret_challenge_agent = Some(Arc::new(agent)); 249 } 250 251 let app_state = AppState { 252 postgres_pool, 253 redis_pool, 254 oauth_client: client, 255 handle_resolver, 256 challenge_agent, 257 secret_agent: secret_challenge_agent, 258 }; 259 260 //HACK Yeah I don't like it either - bt 261 let prod: bool = env::var("PROD") 262 .map(|val| val == "true") 263 .unwrap_or_else(|_| true); 264 log::info!( 265 "listening on http://{addr} (mode: {})", 266 if prod { "PROD" } else { "DEV" } 267 ); 268 269 let mut app = Router::new() 270 .route("/", get(home_handler)) 271 .route( 272 "/day/{id}", 273 match prod { 274 true => get(handlers::day::view_day_handler).route_layer( 275 middleware::from_fn_with_state(app_state.postgres_pool.clone(), unlock::unlock), 276 ), 277 false => get(handlers::day::view_day_handler), 278 }, 279 ) 280 .route( 281 "/day/{id}", 282 match prod { 283 true => post(handlers::day::post_day_handler).route_layer( 284 middleware::from_fn_with_state(app_state.postgres_pool.clone(), unlock::unlock), 285 ), 286 false => post(handlers::day::post_day_handler), 287 }, 288 ) 289 .route( 290 "/day/3/upload-car", 291 post(handlers::custom::day_three::inspect_car), // 2MB max for default axum 292 ) 293 .route( 294 "/day/5/{user_did}", 295 get(handlers::custom::day_five::create_record_handler), 296 ) 297 .route( 298 "/xrpc/codes.advent.challenge.getCode", 299 get(handlers::custom::day_six::xrpc_handler), 300 ) 301 .route( 302 "/leaderboard", 303 get(handlers::leaderboard::leaderboard_handler), 304 ) 305 .route("/admin", get(handlers::admin::admin_page_handler)) 306 .route("/admin", post(handlers::admin::admin_post_handler)) 307 .route("/login", get(handlers::auth::login_page_handler)) 308 .route("/logout", get(handlers::auth::logout_handler)) 309 .route("/redirect/login", get(handlers::auth::login_handle)) 310 .route( 311 "/oauth/callback", 312 get(handlers::auth::oauth_callback_handler), 313 ) 314 .nest_service("/public", axum_embed::ServeEmbed::<Assets>::new()); 315 316 if env::var("OAUTH_HOST").is_ok() { 317 app = app 318 .route( 319 "/oauth-client-metadata.json", 320 get(handlers::oauth_metadata::oauth_client_metadata_handler), 321 ) 322 .route( 323 "/.well-known/did.json", 324 get(handlers::did::did_json_handler), 325 ); 326 } 327 328 let app = app 329 .layer(session_layer) 330 .with_state(app_state) 331 .layer(TraceLayer::new_for_http()); 332 axum::serve(listener, app).await?; 333 Ok(()) 334} 335 336/// The default handler that will be used during the advent month, but not during amtosphere conf 337async fn home_handler(State(pool): State<PgPool>, session: AxumSessionStore) -> impl IntoResponse { 338 let mut unlocked: Vec<u8> = Vec::new(); 339 340 let did = session.get_did(); 341 let is_logged_in = session.logged_in(); 342 let all_statuses = get_all_days_completion_status(&pool, did.as_ref()) 343 .await 344 .unwrap_or_else(|_| (1..=25).map(|day| (day, CompletionStatus::None)).collect()); 345 346 let global_unlock_enabled = env::var("GLOBAL_UNLOCK_ENABLED") 347 .map(|v| v == "true") 348 .unwrap_or(false); 349 350 if global_unlock_enabled { 351 let global_unlock_day = get_global_unlock_day(&pool).await.unwrap_or(1); 352 let implemented_days = shared::advent::get_implemented_days(); 353 for d in implemented_days { 354 if d <= global_unlock_day { 355 unlocked.push(d); 356 } 357 } 358 } else { 359 //HACK Yeah I don't like it either - bt 360 let prod: bool = env::var("PROD") 361 .map(|val| val == "true") 362 .unwrap_or_else(|_| true); 363 if prod { 364 let implemented_days = shared::advent::get_implemented_days(); 365 if let Some(&first) = implemented_days.first() { 366 unlocked.push(first); 367 } 368 for window in implemented_days.windows(2) { 369 let prev = window[0]; 370 let prev_status = all_statuses 371 .iter() 372 .find(|(d, _)| *d == prev) 373 .map(|(_, s)| s) 374 .unwrap_or(&CompletionStatus::None); 375 if *prev_status == CompletionStatus::Both { 376 unlocked.push(window[1]); 377 } else if (prev == 4 || prev == 5) && *prev_status == CompletionStatus::PartOne { 378 //HACK hardcoded for the workshop since we don't have a part 2 for day 4 379 unlocked.push(window[1]); 380 } else { 381 break; 382 } 383 } 384 } else { 385 for d in 1..=25 { 386 unlocked.push(d as u8); 387 } 388 } 389 } 390 391 // Filter to only include unlocked days 392 let unlocked_with_status: Vec<DayStatus> = all_statuses 393 .into_iter() 394 .filter(|(day, _)| unlocked.contains(day)) 395 .map(|(day, status)| DayStatus { day, status }) 396 .collect(); 397 398 HtmlTemplate(HomeTemplate { 399 title: "at://advent", 400 unlocked_days: unlocked_with_status, 401 is_logged_in, 402 }) 403} 404 405/// The default handler that will be used during the advent month, but not during amtosphere conf 406#[expect(dead_code)] // until post-conf 407async fn dec_home_handler( 408 State(pool): State<PgPool>, 409 session: AxumSessionStore, 410) -> impl IntoResponse { 411 //TODO make a helper function for this since it is similar to the middleware 412 let now = chrono::Utc::now(); 413 let mut unlocked: Vec<u8> = Vec::new(); 414 415 //HACK Yeah I don't like it either - bt 416 let prod: bool = env::var("PROD") 417 .map(|val| val == "true") 418 .unwrap_or_else(|_| true); 419 if prod { 420 if now.month() == 12 { 421 let today = now.day().min(25); 422 for d in 1..=today { 423 unlocked.push(d as u8); 424 } 425 } 426 } else { 427 for d in 1..=25 { 428 unlocked.push(d as u8); 429 } 430 } 431 432 // Get completion status for all days at once 433 let did = session.get_did(); 434 let is_logged_in = session.logged_in(); 435 let all_statuses = get_all_days_completion_status(&pool, did.as_ref()) 436 .await 437 .unwrap_or_else(|_| (1..=25).map(|day| (day, CompletionStatus::None)).collect()); 438 439 // Filter to only include unlocked days 440 let unlocked_with_status: Vec<DayStatus> = all_statuses 441 .into_iter() 442 .filter(|(day, _)| unlocked.contains(day)) 443 .map(|(day, status)| DayStatus { day, status }) 444 .collect(); 445 446 HtmlTemplate(HomeTemplate { 447 title: "at://advent", 448 unlocked_days: unlocked_with_status, 449 is_logged_in, 450 }) 451}