this repo has no description
at main 9.7 kB view raw
1mod config; 2mod database; 3use std::{env, net::SocketAddr}; 4 5use axum::{ 6 Json, Router, 7 http::HeaderValue, 8 response::IntoResponse, 9 routing::{get, post}, 10}; 11use hyper::{Method, StatusCode, header}; 12use serde::{Deserialize, Serialize}; 13use serde_json::json; 14use tower_http::cors::CorsLayer; 15use tower_sessions::{ 16 Expiry, Session, SessionManagerLayer, 17 cookie::{SameSite, time::Duration}, 18}; 19use tower_sessions_redis_store::{ 20 RedisStore, 21 fred::prelude::{ClientLike, Config, Pool}, 22}; 23use url::Url; 24 25#[tokio::main] 26async fn main() -> Result<(), Box<dyn std::error::Error>> { 27 dotenv::dotenv().ok(); 28 database::init().unwrap(); 29 let config = Config::from_url_centralized(env::var("REDIS_URL").unwrap().as_str()).unwrap(); 30 let pool = Pool::new(config, None, None, None, 6)?; 31 let redis_conn = pool.connect(); 32 pool.wait_for_connect().await?; 33 34 let session_store = RedisStore::new(pool); 35 let session_layer = SessionManagerLayer::new(session_store) 36 .with_secure(true) 37 .with_same_site(SameSite::None) 38 .with_expiry(Expiry::OnInactivity(Duration::minutes(30))); 39 40 let cors = CorsLayer::new() 41 .allow_origin("http://localhost:3000".parse::<HeaderValue>().unwrap()) 42 .allow_methods(vec![Method::GET, Method::POST]) 43 .allow_credentials(true); 44 let app = Router::new() 45 .route("/auth", get(auth_handler)) 46 .route("/auth_url", get(auth_url_handler)) 47 .route("/callback", post(auth_callback)) 48 .route("/channel", get(user_handler)) 49 .layer(session_layer) 50 .layer(cors); 51 let addr = SocketAddr::from(([127, 0, 0, 1], 5173)); 52 println!("Server listening on {}", addr); 53 54 let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); 55 axum::serve(listener, app.into_make_service()).await?; 56 57 redis_conn.await??; 58 Ok(()) 59} 60 61async fn auth_url_handler(session: Session) -> impl IntoResponse { 62 let state = nanoid::nanoid!(); 63 match session.insert("state", &state).await { 64 Ok(_) => { 65 let twitch_client_id = env::var("TWITCH_CLIENT_ID").unwrap(); 66 let twitch_redirect_uri = env::var("TWITCH_REDIRECT_URI").unwrap(); 67 let scopes = vec!["user:read:email"]; 68 69 let mut auth_url = Url::parse("https://id.twitch.tv/oauth2/authorize").unwrap(); 70 auth_url 71 .query_pairs_mut() 72 .append_pair("client_id", &twitch_client_id) 73 .append_pair("redirect_uri", &twitch_redirect_uri) 74 .append_pair("response_type", "code") 75 .append_pair("scope", &scopes.join(" ")) 76 .append_pair("state", &state); 77 78 let mut headers = axum::http::HeaderMap::new(); 79 headers.insert( 80 header::LOCATION, 81 HeaderValue::from_str(auth_url.as_str()).unwrap(), 82 ); 83 84 (StatusCode::FOUND, headers, ()).into_response() 85 } 86 Err(_) => ( 87 StatusCode::INTERNAL_SERVER_ERROR, 88 "Unable to create session".to_string(), 89 ) 90 .into_response(), 91 } 92} 93 94#[derive(Deserialize, Serialize)] 95struct UserSession { 96 access_token: String, 97 refresh_token: String, 98 state: String, 99 channel: String, 100} 101 102async fn user_handler(session: Session) -> impl IntoResponse { 103 match session.get_value("user_session").await { 104 Ok(Some(user_session_json)) => { 105 let Ok(user_session) = serde_json::from_value::<UserSession>(user_session_json) else { 106 return ( 107 StatusCode::INTERNAL_SERVER_ERROR, 108 "Error deserializing user session".to_string(), 109 ); 110 }; 111 112 (StatusCode::OK, user_session.channel) 113 } 114 Ok(None) => ( 115 StatusCode::UNAUTHORIZED, 116 "User session not found".to_string(), 117 ), // TODO: maybe have the frontend redirect to /auth_url upon receiving this to re-authenticate 118 Err(e) => ( 119 StatusCode::INTERNAL_SERVER_ERROR, 120 format!("Error retrieving session: {}", e), 121 ), 122 } 123} 124 125async fn auth_handler() -> impl IntoResponse { 126 let twitch_client_id = env::var("TWITCH_CLIENT_ID").unwrap(); 127 let twitch_redirect_uri = env::var("TWITCH_REDIRECT_URI").unwrap(); 128 let scopes = vec!["user:read:email"]; 129 let state = nanoid::nanoid!(); 130 131 let mut auth_url = Url::parse("https://id.twitch.tv/oauth2/authorize").unwrap(); 132 auth_url 133 .query_pairs_mut() 134 .append_pair("client_id", &twitch_client_id) 135 .append_pair("redirect_uri", &twitch_redirect_uri) 136 .append_pair("response_type", "code") 137 .append_pair("scope", &scopes.join(" ")) 138 .append_pair("state", &state); 139 140 let mut headers = axum::http::HeaderMap::new(); 141 headers.insert(header::LOCATION, auth_url.as_str().parse().unwrap()); 142 headers.insert( 143 header::SET_COOKIE, 144 format!("state={}; HttpOnly; Secure", state) 145 .parse() 146 .unwrap(), 147 ); 148 149 (StatusCode::FOUND, headers, ()) 150} 151 152#[derive(Deserialize, Serialize)] 153struct AuthTokenRequestResponse { 154 access_token: String, 155 refresh_token: String, 156 expires_in: u64, 157 scope: Vec<String>, 158 token_type: String, 159} 160 161#[derive(Deserialize, Serialize)] 162struct AuthCallbackRequest { 163 code: String, 164 state: String, 165 scope: String, 166} 167async fn auth_callback( 168 session: Session, 169 Json(payload): Json<AuthCallbackRequest>, 170) -> impl IntoResponse { 171 let code = payload.code; 172 let request_state = payload.state; 173 let session_state: String = session 174 .get("state") 175 .await 176 .unwrap() 177 .unwrap_or("".to_string()); 178 if request_state != session_state { 179 println!("State mismatch"); 180 return (StatusCode::INTERNAL_SERVER_ERROR, "State mismatch").into_response(); 181 } 182 183 let request_client = reqwest::Client::builder().build().unwrap(); 184 match request_client 185 .post("https://id.twitch.tv/oauth2/token") 186 .form(&[ 187 ("client_id", env::var("TWITCH_CLIENT_ID").unwrap()), 188 ("client_secret", env::var("TWITCH_CLIENT_SECRET").unwrap()), 189 ("code", code.to_string()), 190 ("grant_type", "authorization_code".to_string()), 191 ("redirect_uri", env::var("TWITCH_REDIRECT_URI").unwrap()), 192 ("state", session_state.to_string()), 193 ]) 194 .send() 195 .await 196 { 197 Ok(response) => { 198 let token_response: AuthTokenRequestResponse = 199 serde_json::from_str(&response.text().await.unwrap()).unwrap(); 200 match validate_token(&token_response.access_token).await { 201 Some(validated_token) => { 202 session 203 .insert_value( 204 "user_session", 205 json!({ 206 "access_token": token_response.access_token, 207 "refresh_token": token_response.refresh_token, 208 "state": session_state, 209 "channel": validated_token.login, 210 }), 211 ) 212 .await 213 .unwrap(); 214 println!("Session created successfully."); 215 (StatusCode::OK, ()).into_response() 216 } 217 None => ( 218 StatusCode::INTERNAL_SERVER_ERROR, 219 "Token is invalid".to_string(), 220 ) 221 .into_response(), 222 } 223 } 224 Err(e) => { 225 println!("Error: {}", e); 226 (StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {}", e)).into_response() 227 } 228 } 229} 230 231// TODO: implement this for invalid tokens 232async fn logout(token: &str, session: Session) { 233 let request_client = reqwest::Client::builder().build().unwrap(); 234 match request_client 235 .post("https://id.twitch.tv/oauth2/revoke") 236 .form(&[ 237 ("client_id", env::var("TWITCH_CLIENT_ID").unwrap()), 238 ("token", token.to_string()), 239 ]) 240 .send() 241 .await 242 { 243 Ok(response) => { 244 let status = response.status(); 245 if status.is_success() { 246 println!("Token revoked successfully."); 247 let _: Option<serde_json::Value> = 248 session.remove_value("user_session").await.unwrap(); 249 println!("Session removed successfully."); 250 } else { 251 let json_response: serde_json::Value = 252 serde_json::from_str(&response.text().await.unwrap()).unwrap(); 253 println!( 254 "Failed to revoke token. Status: {} for error: {}", 255 status, 256 json_response["message"].as_str().unwrap() 257 ); 258 } 259 } 260 Err(e) => { 261 println!("Error revoking token: {}", e); 262 } 263 } 264} 265 266#[derive(Deserialize, Serialize)] 267struct ValidatedTokenResponse { 268 client_id: String, 269 login: String, 270 user_id: String, 271 expires_in: u64, 272 scopes: Vec<String>, 273} 274async fn validate_token(token: &str) -> Option<ValidatedTokenResponse> { 275 let request_client = reqwest::Client::builder().build().unwrap(); 276 match request_client 277 .get("https://id.twitch.tv/oauth2/validate") 278 .header("Authorization", format!("OAuth {}", token)) 279 .send() 280 .await 281 { 282 Ok(response) => Some(serde_json::from_str(&response.text().await.unwrap()).unwrap()), 283 Err(_) => None, 284 } 285}