mod config; mod database; use std::{env, net::SocketAddr}; use axum::{ Json, Router, http::HeaderValue, response::IntoResponse, routing::{get, post}, }; use hyper::{Method, StatusCode, header}; use serde::{Deserialize, Serialize}; use serde_json::json; use tower_http::cors::CorsLayer; use tower_sessions::{ Expiry, Session, SessionManagerLayer, cookie::{SameSite, time::Duration}, }; use tower_sessions_redis_store::{ RedisStore, fred::prelude::{ClientLike, Config, Pool}, }; use url::Url; #[tokio::main] async fn main() -> Result<(), Box> { dotenv::dotenv().ok(); database::init().unwrap(); let config = Config::from_url_centralized(env::var("REDIS_URL").unwrap().as_str()).unwrap(); let pool = Pool::new(config, None, None, None, 6)?; let redis_conn = pool.connect(); pool.wait_for_connect().await?; let session_store = RedisStore::new(pool); let session_layer = SessionManagerLayer::new(session_store) .with_secure(true) .with_same_site(SameSite::None) .with_expiry(Expiry::OnInactivity(Duration::minutes(30))); let cors = CorsLayer::new() .allow_origin("http://localhost:3000".parse::().unwrap()) .allow_methods(vec![Method::GET, Method::POST]) .allow_credentials(true); let app = Router::new() .route("/auth", get(auth_handler)) .route("/auth_url", get(auth_url_handler)) .route("/callback", post(auth_callback)) .route("/channel", get(user_handler)) .layer(session_layer) .layer(cors); let addr = SocketAddr::from(([127, 0, 0, 1], 5173)); println!("Server listening on {}", addr); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); axum::serve(listener, app.into_make_service()).await?; redis_conn.await??; Ok(()) } async fn auth_url_handler(session: Session) -> impl IntoResponse { let state = nanoid::nanoid!(); match session.insert("state", &state).await { Ok(_) => { let twitch_client_id = env::var("TWITCH_CLIENT_ID").unwrap(); let twitch_redirect_uri = env::var("TWITCH_REDIRECT_URI").unwrap(); let scopes = vec!["user:read:email"]; let mut auth_url = Url::parse("https://id.twitch.tv/oauth2/authorize").unwrap(); auth_url .query_pairs_mut() .append_pair("client_id", &twitch_client_id) .append_pair("redirect_uri", &twitch_redirect_uri) .append_pair("response_type", "code") .append_pair("scope", &scopes.join(" ")) .append_pair("state", &state); let mut headers = axum::http::HeaderMap::new(); headers.insert( header::LOCATION, HeaderValue::from_str(auth_url.as_str()).unwrap(), ); (StatusCode::FOUND, headers, ()).into_response() } Err(_) => ( StatusCode::INTERNAL_SERVER_ERROR, "Unable to create session".to_string(), ) .into_response(), } } #[derive(Deserialize, Serialize)] struct UserSession { access_token: String, refresh_token: String, state: String, channel: String, } async fn user_handler(session: Session) -> impl IntoResponse { match session.get_value("user_session").await { Ok(Some(user_session_json)) => { let Ok(user_session) = serde_json::from_value::(user_session_json) else { return ( StatusCode::INTERNAL_SERVER_ERROR, "Error deserializing user session".to_string(), ); }; (StatusCode::OK, user_session.channel) } Ok(None) => ( StatusCode::UNAUTHORIZED, "User session not found".to_string(), ), // TODO: maybe have the frontend redirect to /auth_url upon receiving this to re-authenticate Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, format!("Error retrieving session: {}", e), ), } } async fn auth_handler() -> impl IntoResponse { let twitch_client_id = env::var("TWITCH_CLIENT_ID").unwrap(); let twitch_redirect_uri = env::var("TWITCH_REDIRECT_URI").unwrap(); let scopes = vec!["user:read:email"]; let state = nanoid::nanoid!(); let mut auth_url = Url::parse("https://id.twitch.tv/oauth2/authorize").unwrap(); auth_url .query_pairs_mut() .append_pair("client_id", &twitch_client_id) .append_pair("redirect_uri", &twitch_redirect_uri) .append_pair("response_type", "code") .append_pair("scope", &scopes.join(" ")) .append_pair("state", &state); let mut headers = axum::http::HeaderMap::new(); headers.insert(header::LOCATION, auth_url.as_str().parse().unwrap()); headers.insert( header::SET_COOKIE, format!("state={}; HttpOnly; Secure", state) .parse() .unwrap(), ); (StatusCode::FOUND, headers, ()) } #[derive(Deserialize, Serialize)] struct AuthTokenRequestResponse { access_token: String, refresh_token: String, expires_in: u64, scope: Vec, token_type: String, } #[derive(Deserialize, Serialize)] struct AuthCallbackRequest { code: String, state: String, scope: String, } async fn auth_callback( session: Session, Json(payload): Json, ) -> impl IntoResponse { let code = payload.code; let request_state = payload.state; let session_state: String = session .get("state") .await .unwrap() .unwrap_or("".to_string()); if request_state != session_state { println!("State mismatch"); return (StatusCode::INTERNAL_SERVER_ERROR, "State mismatch").into_response(); } let request_client = reqwest::Client::builder().build().unwrap(); match request_client .post("https://id.twitch.tv/oauth2/token") .form(&[ ("client_id", env::var("TWITCH_CLIENT_ID").unwrap()), ("client_secret", env::var("TWITCH_CLIENT_SECRET").unwrap()), ("code", code.to_string()), ("grant_type", "authorization_code".to_string()), ("redirect_uri", env::var("TWITCH_REDIRECT_URI").unwrap()), ("state", session_state.to_string()), ]) .send() .await { Ok(response) => { let token_response: AuthTokenRequestResponse = serde_json::from_str(&response.text().await.unwrap()).unwrap(); match validate_token(&token_response.access_token).await { Some(validated_token) => { session .insert_value( "user_session", json!({ "access_token": token_response.access_token, "refresh_token": token_response.refresh_token, "state": session_state, "channel": validated_token.login, }), ) .await .unwrap(); println!("Session created successfully."); (StatusCode::OK, ()).into_response() } None => ( StatusCode::INTERNAL_SERVER_ERROR, "Token is invalid".to_string(), ) .into_response(), } } Err(e) => { println!("Error: {}", e); (StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {}", e)).into_response() } } } // TODO: implement this for invalid tokens async fn logout(token: &str, session: Session) { let request_client = reqwest::Client::builder().build().unwrap(); match request_client .post("https://id.twitch.tv/oauth2/revoke") .form(&[ ("client_id", env::var("TWITCH_CLIENT_ID").unwrap()), ("token", token.to_string()), ]) .send() .await { Ok(response) => { let status = response.status(); if status.is_success() { println!("Token revoked successfully."); let _: Option = session.remove_value("user_session").await.unwrap(); println!("Session removed successfully."); } else { let json_response: serde_json::Value = serde_json::from_str(&response.text().await.unwrap()).unwrap(); println!( "Failed to revoke token. Status: {} for error: {}", status, json_response["message"].as_str().unwrap() ); } } Err(e) => { println!("Error revoking token: {}", e); } } } #[derive(Deserialize, Serialize)] struct ValidatedTokenResponse { client_id: String, login: String, user_id: String, expires_in: u64, scopes: Vec, } async fn validate_token(token: &str) -> Option { let request_client = reqwest::Client::builder().build().unwrap(); match request_client .get("https://id.twitch.tv/oauth2/validate") .header("Authorization", format!("OAuth {}", token)) .send() .await { Ok(response) => Some(serde_json::from_str(&response.text().await.unwrap()).unwrap()), Err(_) => None, } }