use crate::{ session::AxumSessionStore, templates::HtmlTemplate, templates::error::ErrorTemplate, templates::home::{DayStatus, HomeTemplate}, }; use atrium_api::agent::atp_agent::AtpAgent; use atrium_api::agent::atp_agent::store::MemorySessionStore; use atrium_identity::{ did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}, handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig}, }; use atrium_oauth::{ AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, DefaultHttpClient, GrantType, KnownScope, OAuthClient, OAuthClientConfig, OAuthResolverConfig, Scope, }; use atrium_xrpc_client::reqwest::ReqwestClient; use axum::{ Router, extract::State, http::StatusCode, middleware, response::IntoResponse, response::Response, routing::{get, post}, }; use bb8_redis::RedisConnectionManager; use chrono::Datelike; use dotenv::dotenv; use redis::AsyncCommands; use rust_embed::RustEmbed; use shared::{ HandleResolver, OAuthClientType, PasswordAgent, advent::{CompletionStatus, get_all_days_completion_status, get_global_unlock_day}, atrium::dns_resolver::HickoryDnsTxtResolver, atrium::stores::AtriumSessionStore, atrium::stores::AtriumStateStore, }; use sqlx::{PgPool, migrate::Migrator, postgres::PgPoolOptions}; use std::{env, net::SocketAddr, sync::Arc, time}; use time::Duration; use tower_http::trace::TraceLayer; use tower_sessions::{SessionManagerLayer, cookie::SameSite}; use tracing_subscriber::EnvFilter; mod handlers; extern crate dotenv; // mod extractors; mod redis_session_store; mod session; mod templates; mod unlock; #[derive(RustEmbed, Clone)] #[folder = "./public"] struct Assets; #[derive(Clone)] pub(crate) struct AppState { postgres_pool: PgPool, redis_pool: bb8::Pool, oauth_client: OAuthClientType, //Used to get did to handle leaving because I figured we'd need it handle_resolver: HandleResolver, challenge_agent: Option, secret_agent: Option, } pub fn oauth_scopes() -> Vec { vec![ Scope::Known(KnownScope::Atproto), // Scope::Known(KnownScope::TransitionGeneric), //This looks like it HAS to have the full collection name, before i want to say it worked with wildcard //Gives full CRUD to the codes.advent.* collection // Scope::Unknown("repo:codes.advent.test".to_string()), ] } fn error_response(status: StatusCode, message: &str) -> Response { IntoResponse::into_response(( status, HtmlTemplate(ErrorTemplate { title: "at://advent - Error", message, is_logged_in: false, }), )) } fn build_oauth_client( host: &str, port: u16, redis_pool: bb8::Pool, ) -> OAuthClientType { let http_client = Arc::new(DefaultHttpClient::default()); let state_store = AtriumStateStore::new(redis_pool.clone()); let session_store = AtriumSessionStore::new(redis_pool); let resolver = OAuthResolverConfig { did_resolver: CommonDidResolver::new(CommonDidResolverConfig { plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), http_client: http_client.clone(), }), handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { dns_txt_resolver: HickoryDnsTxtResolver::default(), http_client: http_client.clone(), }), authorization_server_metadata: Default::default(), protected_resource_metadata: Default::default(), }; if let Ok(oauth_host) = env::var("OAUTH_HOST") { let config = OAuthClientConfig { client_metadata: AtprotoClientMetadata { client_id: format!("https://{oauth_host}/oauth-client-metadata.json"), client_uri: Some(format!("https://{oauth_host}")), redirect_uris: vec![format!("https://{oauth_host}/oauth/callback")], token_endpoint_auth_method: AuthMethod::None, grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], scopes: oauth_scopes(), jwks_uri: None, token_endpoint_auth_signing_alg: None, }, keys: None, resolver, state_store, session_store, }; Arc::new(OAuthClient::new(config).expect("failed to create OAuth client")) } else { let config = OAuthClientConfig { client_metadata: AtprotoLocalhostClientMetadata { redirect_uris: Some(vec![format!("http://{host}:{port}/oauth/callback")]), scopes: Some(oauth_scopes()), }, keys: None, resolver, state_store, session_store, }; Arc::new(OAuthClient::new(config).expect("failed to create OAuth client")) } } #[tokio::main] async fn main() -> Result<(), Box> { dotenv().ok(); //Sets up logging/tracing tracing_subscriber::fmt() .with_env_filter( EnvFilter::try_from_default_env() .or_else(|_| EnvFilter::try_new("info,axum_tracing_example=error,tower_http=warn")) .unwrap(), ) .init(); let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); let port: u16 = env::var("PORT") .unwrap_or_else(|_| "7878".to_string()) .parse() .expect("PORT must be a number"); let addr = SocketAddr::new(host.parse().expect("Invalid HOST address"), port); let host = addr.ip(); let port = addr.port(); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); //sqlx pool let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set in the environment or .env"); // set up a postgres connection pool let postgres_pool = PgPoolOptions::new() .max_connections(5) .acquire_timeout(Duration::from_secs(3)) .connect(&database_url) .await .expect("can't connect to database"); // Run database migrations static MIGRATOR: Migrator = sqlx::migrate!("../migrations"); MIGRATOR .run(&postgres_pool) .await .expect("failed to run database migrations"); log::info!("database migrations applied successfully"); // redis pool setup let redis_url = env::var("REDIS_URL").expect("REDIS_URL must be set in the environment or .env"); let manager = RedisConnectionManager::new(redis_url.clone()).unwrap(); let redis_pool = bb8::Pool::builder().build(manager).await.unwrap(); //cam be deleted, just an example for the test endpoint { // ping the database before starting let mut conn = redis_pool.get().await.unwrap(); conn.set::<&str, &str, ()>("foo", "bar").await.unwrap(); let result: String = conn.get("foo").await.unwrap(); assert_eq!(result, "bar"); } //Atrium/atproto setup //Create a new handle resolver for the home page let http_client = Arc::new(DefaultHttpClient::default()); let handle_resolver = CommonDidResolver::new(CommonDidResolverConfig { plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), http_client: http_client.clone(), }); let handle_resolver = Arc::new(handle_resolver); let client = build_oauth_client(&host.to_string(), port, redis_pool.clone()); let session_store = redis_session_store::RedisSessionStore::new(redis_pool.clone()); let session_layer = SessionManagerLayer::new(session_store) //Set to lax so session id cookie can be set on redirect .with_same_site(SameSite::Lax) .with_secure(false); // challenge account let mut challenge_agent = None; let challenge_pds = env::var("CHALLENGE_PDS"); let challenge_identity = env::var("CHALLENGE_IDENTITY"); let challenge_password = env::var("CHALLENGE_PASSWORD"); if let (Ok(pds), Ok(identity), Ok(password)) = (challenge_pds, challenge_identity, challenge_password) { let agent = AtpAgent::new(ReqwestClient::new(pds), MemorySessionStore::default()); agent.login(identity, password).await?; challenge_agent = Some(Arc::new(agent)); } // secret challenge account let mut secret_challenge_agent = None; let secret_challenge_pds = env::var("SECRET_CHALLENGE_PDS"); let secret_challenge_identity = env::var("SECRET_CHALLENGE_IDENTITY"); let secret_challenge_password = env::var("SECRET_CHALLENGE_PASSWORD"); if let (Ok(pds), Ok(identity), Ok(password)) = ( secret_challenge_pds, secret_challenge_identity, secret_challenge_password, ) { let agent = AtpAgent::new(ReqwestClient::new(pds), MemorySessionStore::default()); agent.login(identity, password).await?; secret_challenge_agent = Some(Arc::new(agent)); } let app_state = AppState { postgres_pool, redis_pool, oauth_client: client, handle_resolver, challenge_agent, secret_agent: secret_challenge_agent, }; //HACK Yeah I don't like it either - bt let prod: bool = env::var("PROD") .map(|val| val == "true") .unwrap_or_else(|_| true); log::info!( "listening on http://{addr} (mode: {})", if prod { "PROD" } else { "DEV" } ); let mut app = Router::new() .route("/", get(home_handler)) .route( "/day/{id}", match prod { true => get(handlers::day::view_day_handler).route_layer( middleware::from_fn_with_state(app_state.postgres_pool.clone(), unlock::unlock), ), false => get(handlers::day::view_day_handler), }, ) .route( "/day/{id}", match prod { true => post(handlers::day::post_day_handler).route_layer( middleware::from_fn_with_state(app_state.postgres_pool.clone(), unlock::unlock), ), false => post(handlers::day::post_day_handler), }, ) .route( "/day/3/upload-car", post(handlers::custom::day_three::inspect_car), // 2MB max for default axum ) .route( "/day/5/{user_did}", get(handlers::custom::day_five::create_record_handler), ) .route( "/xrpc/codes.advent.challenge.getCode", get(handlers::custom::day_six::xrpc_handler), ) .route( "/leaderboard", get(handlers::leaderboard::leaderboard_handler), ) .route("/admin", get(handlers::admin::admin_page_handler)) .route("/admin", post(handlers::admin::admin_post_handler)) .route("/login", get(handlers::auth::login_page_handler)) .route("/logout", get(handlers::auth::logout_handler)) .route("/redirect/login", get(handlers::auth::login_handle)) .route( "/oauth/callback", get(handlers::auth::oauth_callback_handler), ) .nest_service("/public", axum_embed::ServeEmbed::::new()); if env::var("OAUTH_HOST").is_ok() { app = app .route( "/oauth-client-metadata.json", get(handlers::oauth_metadata::oauth_client_metadata_handler), ) .route( "/.well-known/did.json", get(handlers::did::did_json_handler), ); } let app = app .layer(session_layer) .with_state(app_state) .layer(TraceLayer::new_for_http()); axum::serve(listener, app).await?; Ok(()) } /// The default handler that will be used during the advent month, but not during amtosphere conf async fn home_handler(State(pool): State, session: AxumSessionStore) -> impl IntoResponse { let mut unlocked: Vec = Vec::new(); let did = session.get_did(); let is_logged_in = session.logged_in(); let all_statuses = get_all_days_completion_status(&pool, did.as_ref()) .await .unwrap_or_else(|_| (1..=25).map(|day| (day, CompletionStatus::None)).collect()); let global_unlock_enabled = env::var("GLOBAL_UNLOCK_ENABLED") .map(|v| v == "true") .unwrap_or(false); if global_unlock_enabled { let global_unlock_day = get_global_unlock_day(&pool).await.unwrap_or(1); let implemented_days = shared::advent::get_implemented_days(); for d in implemented_days { if d <= global_unlock_day { unlocked.push(d); } } } else { //HACK Yeah I don't like it either - bt let prod: bool = env::var("PROD") .map(|val| val == "true") .unwrap_or_else(|_| true); if prod { let implemented_days = shared::advent::get_implemented_days(); if let Some(&first) = implemented_days.first() { unlocked.push(first); } for window in implemented_days.windows(2) { let prev = window[0]; let prev_status = all_statuses .iter() .find(|(d, _)| *d == prev) .map(|(_, s)| s) .unwrap_or(&CompletionStatus::None); if *prev_status == CompletionStatus::Both { unlocked.push(window[1]); } else if (prev == 4 || prev == 5) && *prev_status == CompletionStatus::PartOne { //HACK hardcoded for the workshop since we don't have a part 2 for day 4 unlocked.push(window[1]); } else { break; } } } else { for d in 1..=25 { unlocked.push(d as u8); } } } // Filter to only include unlocked days let unlocked_with_status: Vec = all_statuses .into_iter() .filter(|(day, _)| unlocked.contains(day)) .map(|(day, status)| DayStatus { day, status }) .collect(); HtmlTemplate(HomeTemplate { title: "at://advent", unlocked_days: unlocked_with_status, is_logged_in, }) } /// The default handler that will be used during the advent month, but not during amtosphere conf #[expect(dead_code)] // until post-conf async fn dec_home_handler( State(pool): State, session: AxumSessionStore, ) -> impl IntoResponse { //TODO make a helper function for this since it is similar to the middleware let now = chrono::Utc::now(); let mut unlocked: Vec = Vec::new(); //HACK Yeah I don't like it either - bt let prod: bool = env::var("PROD") .map(|val| val == "true") .unwrap_or_else(|_| true); if prod { if now.month() == 12 { let today = now.day().min(25); for d in 1..=today { unlocked.push(d as u8); } } } else { for d in 1..=25 { unlocked.push(d as u8); } } // Get completion status for all days at once let did = session.get_did(); let is_logged_in = session.logged_in(); let all_statuses = get_all_days_completion_status(&pool, did.as_ref()) .await .unwrap_or_else(|_| (1..=25).map(|day| (day, CompletionStatus::None)).collect()); // Filter to only include unlocked days let unlocked_with_status: Vec = all_statuses .into_iter() .filter(|(day, _)| unlocked.contains(day)) .map(|(day, status)| DayStatus { day, status }) .collect(); HtmlTemplate(HomeTemplate { title: "at://advent", unlocked_days: unlocked_with_status, is_logged_in, }) }