this repo has no description
at main 558 lines 16 kB view raw
1use crate::AppState; 2use crate::auth::UserSession; 3 4use std::collections::{HashMap, HashSet}; 5use std::convert::Infallible; 6use std::fmt::Display; 7use std::time::Duration; 8 9use axum::{ 10 Json, Router, 11 extract::{Path, State}, 12 http::StatusCode, 13 response::{ 14 IntoResponse, 15 sse::{Event, KeepAlive, Sse}, 16 }, 17 routing::{get, post}, 18}; 19use futures_util::StreamExt; 20use serde::{Deserialize, Serialize}; 21use tokio::sync::broadcast; 22use tokio_stream::wrappers::BroadcastStream; 23use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage}; 24use tower_sessions::Session; 25 26pub fn polls_router() -> Router<AppState> { 27 Router::new() 28 .route("/poll", post(create_poll)) 29 .route("/polls", get(find_polls)) 30 .route("/poll/{id}/events", get(subscribe_poll)) 31 .route("/poll/{id}/end", post(end_poll)) 32} 33 34// -----Structs----- 35 36#[derive(Debug, Clone, Deserialize)] 37pub struct PollConfig { 38 pub name: String, 39 pub options: Vec<String>, 40 pub duration: u32, 41} 42 43impl Display for PollConfig { 44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 45 write!( 46 f, 47 "Name: {}, Choices: {:?}, Duration: {}min(s)", 48 self.name, self.options, self.duration 49 ) 50 } 51} 52 53#[derive(Debug, Clone, Serialize)] 54pub struct PollUpdate { 55 pub event_type: String, 56 pub votes: HashMap<String, u32>, 57 pub total_voters: u32, 58} 59 60#[derive(Debug)] 61pub struct ActivePoll { 62 pub config: PollConfig, 63 pub votes: HashMap<String, u32>, 64 pub voters: HashSet<String>, 65 pub tx: broadcast::Sender<PollUpdate>, 66 pub shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>, 67} 68 69#[derive(Debug, Serialize)] 70struct CreatePollResponse { 71 poll_id: String, 72 config: PollConfigResponse, 73} 74 75#[derive(Debug, Serialize)] 76struct PollConfigResponse { 77 name: String, 78 options: Vec<String>, 79} 80 81#[derive(Debug, Serialize)] 82struct Poll { 83 active: bool, 84 name: String, 85 options: Vec<String>, 86 total_votes: i32, 87} 88 89#[derive(Debug, Serialize)] 90struct FindPollsResponse { 91 polls: Vec<Poll>, 92} 93// -----Handlers----- 94 95async fn create_poll( 96 State(state): State<AppState>, 97 session: Session, 98 Json(config): Json<PollConfig>, 99) -> impl IntoResponse { 100 // Get user session 101 let user_session = match get_user_session(&session).await { 102 Ok(s) => s, 103 Err(response) => return response, 104 }; 105 106 println!( 107 "Creating poll '{}' for channel: {}", 108 config.name, user_session.channel 109 ); 110 111 // Generate poll ID 112 let poll_id = nanoid::nanoid!(10); 113 114 // Create broadcast channel for updates 115 let (tx, _) = broadcast::channel::<PollUpdate>(100); 116 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); 117 118 // Initialize vote counts 119 let mut votes = HashMap::new(); 120 for option in &config.options { 121 votes.insert(option.clone(), 0); 122 } 123 124 // Store poll state 125 let active_poll = ActivePoll { 126 config: config.clone(), 127 votes: votes.clone(), 128 voters: HashSet::new(), 129 tx: tx.clone(), 130 shutdown_tx: Some(shutdown_tx), 131 }; 132 133 state 134 .active_polls 135 .write() 136 .await 137 .insert(poll_id.clone(), active_poll); 138 139 // Spawn EventSub listener task 140 let poll_id_clone = poll_id.clone(); 141 let state_clone = state.clone(); 142 tokio::spawn(async move { 143 if let Err(e) = eventsub_listener( 144 state_clone, 145 poll_id_clone.clone(), 146 user_session, 147 shutdown_rx, 148 ) 149 .await 150 { 151 eprintln!("EventSub listener error for poll {}: {}", poll_id_clone, e); 152 } 153 }); 154 155 ( 156 StatusCode::CREATED, 157 Json(CreatePollResponse { 158 poll_id, 159 config: PollConfigResponse { 160 name: config.name, 161 options: config.options, 162 }, 163 }), 164 ) 165 .into_response() 166} 167 168async fn find_polls() -> Json<FindPollsResponse> { 169 let x = Poll { 170 active: false, 171 name: "Poll A".to_string(), 172 options: vec!["Option A".to_string(), "Option B".to_string()], 173 total_votes: 12, 174 }; 175 Json(FindPollsResponse { polls: vec![x] }) 176} 177 178async fn subscribe_poll( 179 State(state): State<AppState>, 180 Path(poll_id): Path<String>, 181) -> impl IntoResponse { 182 let polls = state.active_polls.read().await; 183 184 let Some(poll) = polls.get(&poll_id) else { 185 return (StatusCode::NOT_FOUND, "Poll not found").into_response(); 186 }; 187 188 let rx = poll.tx.subscribe(); 189 190 // Send initial state 191 let initial_update = PollUpdate { 192 event_type: "initial".to_string(), 193 votes: poll.votes.clone(), 194 total_voters: poll.voters.len() as u32, 195 }; 196 197 drop(polls); // Release lock before streaming 198 199 let stream = async_stream::stream! { 200 // Send initial state 201 yield Ok::<_, Infallible>(Event::default().json_data(&initial_update).unwrap()); 202 203 // Stream updates 204 let mut stream = BroadcastStream::new(rx); 205 while let Some(result) = stream.next().await { 206 match result { 207 Ok(update) => { 208 yield Ok(Event::default().json_data(&update).unwrap()); 209 if update.event_type == "ended" { 210 break; 211 } 212 } 213 Err(_) => break, // Channel closed 214 } 215 } 216 }; 217 218 Sse::new(stream) 219 .keep_alive(KeepAlive::default().interval(Duration::from_secs(15))) 220 .into_response() 221} 222 223async fn end_poll( 224 State(state): State<AppState>, 225 Path(poll_id): Path<String>, 226 session: Session, 227) -> impl IntoResponse { 228 // Verify user is authenticated 229 if get_user_session(&session).await.is_err() { 230 return (StatusCode::UNAUTHORIZED, "Not authenticated").into_response(); 231 } 232 233 let mut polls = state.active_polls.write().await; 234 235 let Some(poll) = polls.remove(&poll_id) else { 236 return (StatusCode::NOT_FOUND, "Poll not found").into_response(); 237 }; 238 239 // Send final update 240 let final_update = PollUpdate { 241 event_type: "ended".to_string(), 242 votes: poll.votes.clone(), 243 total_voters: poll.voters.len() as u32, 244 }; 245 let _ = poll.tx.send(final_update); 246 247 // Signal shutdown to EventSub task 248 if let Some(shutdown_tx) = poll.shutdown_tx { 249 let _ = shutdown_tx.send(()); 250 } 251 252 (StatusCode::OK, "Poll ended").into_response() 253} 254 255// -----Helper Functions----- 256 257async fn get_user_session(session: &Session) -> Result<UserSession, axum::response::Response> { 258 match session.get_value("user_session").await { 259 Ok(Some(user_session_json)) => serde_json::from_value::<UserSession>(user_session_json) 260 .map_err(|e| { 261 ( 262 StatusCode::INTERNAL_SERVER_ERROR, 263 format!("Error deserializing user session: {}", e), 264 ) 265 .into_response() 266 }), 267 Ok(None) => Err((StatusCode::UNAUTHORIZED, "Not authenticated").into_response()), 268 Err(e) => Err(( 269 StatusCode::INTERNAL_SERVER_ERROR, 270 format!("Error retrieving session: {}", e), 271 ) 272 .into_response()), 273 } 274} 275 276// -----EventSub Integration----- 277 278#[derive(Debug, Deserialize)] 279struct EventSubMessage { 280 metadata: EventSubMetadata, 281 payload: serde_json::Value, 282} 283 284#[derive(Debug, Deserialize)] 285struct EventSubMetadata { 286 message_type: String, 287 #[serde(default)] 288 subscription_type: Option<String>, 289} 290 291#[derive(Debug, Deserialize)] 292struct WelcomePayload { 293 session: EventSubSession, 294} 295 296#[derive(Debug, Deserialize)] 297struct EventSubSession { 298 id: String, 299} 300 301#[derive(Debug, Deserialize)] 302struct ChatMessageEvent { 303 chatter_user_id: String, 304 message: ChatMessage, 305} 306 307#[derive(Debug, Deserialize)] 308struct ChatMessage { 309 text: String, 310} 311 312async fn eventsub_listener( 313 state: AppState, 314 poll_id: String, 315 user_session: UserSession, 316 mut shutdown_rx: tokio::sync::oneshot::Receiver<()>, 317) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 318 println!("Starting EventSub listener for poll: {}", poll_id); 319 320 // Connect to Twitch EventSub WebSocket 321 let twitch_ws_url = "wss://eventsub.wss.twitch.tv/ws"; 322 let (ws_stream, _) = connect_async(twitch_ws_url).await?; 323 println!("Connected to Twitch EventSub WebSocket"); 324 325 let (mut write, mut read) = ws_stream.split(); 326 use futures_util::{SinkExt, StreamExt}; 327 328 // Wait for welcome message and get session_id 329 let session_id = loop { 330 tokio::select! { 331 _ = &mut shutdown_rx => { 332 println!("Shutdown signal received before welcome"); 333 return Ok(()); 334 } 335 msg = read.next() => { 336 match msg { 337 Some(Ok(TungsteniteMessage::Text(text))) => { 338 let message: EventSubMessage = serde_json::from_str(&text)?; 339 if message.metadata.message_type == "session_welcome" { 340 let payload: WelcomePayload = serde_json::from_value(message.payload)?; 341 println!("Received session_id: {}", payload.session.id); 342 break payload.session.id; 343 } 344 } 345 Some(Ok(TungsteniteMessage::Ping(data))) => { 346 write.send(TungsteniteMessage::Pong(data)).await?; 347 } 348 Some(Err(e)) => return Err(e.into()), 349 None => return Err("WebSocket closed before welcome".into()), 350 _ => {} 351 } 352 } 353 } 354 }; 355 356 // Subscribe to channel.chat.message 357 subscribe_to_chat_messages( 358 &state.client_id, 359 &user_session.access_token, 360 &user_session.user_id, 361 &session_id, 362 ) 363 .await?; 364 365 println!( 366 "Subscribed to chat messages for channel: {}", 367 user_session.channel 368 ); 369 370 // Main event loop 371 loop { 372 tokio::select! { 373 _ = &mut shutdown_rx => { 374 println!("Shutdown signal received for poll: {}", poll_id); 375 break; 376 } 377 msg = read.next() => { 378 match msg { 379 Some(Ok(TungsteniteMessage::Text(text))) => { 380 if let Err(e) = handle_eventsub_message(&state, &poll_id, &text).await { 381 eprintln!("Error handling EventSub message: {}", e); 382 } 383 } 384 Some(Ok(TungsteniteMessage::Ping(data))) => { 385 if write.send(TungsteniteMessage::Pong(data)).await.is_err() { 386 break; 387 } 388 } 389 Some(Ok(TungsteniteMessage::Close(_))) => { 390 println!("EventSub connection closed"); 391 break; 392 } 393 Some(Err(e)) => { 394 eprintln!("EventSub WebSocket error: {}", e); 395 break; 396 } 397 None => break, 398 _ => {} 399 } 400 } 401 } 402 } 403 404 // Cleanup eventsub connection 405 let _ = write.send(TungsteniteMessage::Close(None)).await; 406 println!("EventSub listener ended for poll: {}", poll_id); 407 408 Ok(()) 409} 410 411async fn subscribe_to_chat_messages( 412 client_id: &str, 413 access_token: &str, 414 user_id: &str, 415 session_id: &str, 416) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 417 let client = reqwest::Client::new(); 418 419 let body = serde_json::json!({ 420 "type": "channel.chat.message", 421 "version": "1", 422 "condition": { 423 "broadcaster_user_id": user_id, 424 "user_id": user_id 425 }, 426 "transport": { 427 "method": "websocket", 428 "session_id": session_id 429 } 430 }); 431 432 let response = client 433 .post("https://api.twitch.tv/helix/eventsub/subscriptions") 434 .header("Authorization", format!("Bearer {}", access_token)) 435 .header("Client-Id", client_id) 436 .header("Content-Type", "application/json") 437 .json(&body) 438 .send() 439 .await?; 440 441 let status = response.status(); 442 let response_text = response.text().await?; 443 println!("Subscription response ({status}): {response_text}"); 444 445 if !status.is_success() { 446 return Err(format!( 447 "Failed to subscribe to chat messages for user {}: {}", 448 user_id, response_text 449 ) 450 .into()); 451 } 452 453 Ok(()) 454} 455 456async fn handle_eventsub_message( 457 state: &AppState, 458 poll_id: &str, 459 text: &str, 460) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 461 let message: EventSubMessage = serde_json::from_str(text)?; 462 463 match message.metadata.message_type.as_str() { 464 "notification" => { 465 if message.metadata.subscription_type.as_deref() == Some("channel.chat.message") { 466 let event: ChatMessageEvent = serde_json::from_value( 467 message.payload.get("event").cloned().unwrap_or_default(), 468 )?; 469 470 process_chat_vote(state, poll_id, &event).await; 471 } 472 } 473 "session_keepalive" => { 474 // Keepalive, nothing to do 475 } 476 "session_reconnect" => { 477 // TODO: Handle reconnect by connecting to new URL 478 println!("Received reconnect request - not yet implemented"); 479 } 480 _ => { 481 println!("Received some message: {}", text); 482 } 483 } 484 485 Ok(()) 486} 487 488async fn process_chat_vote(state: &AppState, poll_id: &str, event: &ChatMessageEvent) { 489 println!( 490 "Chat message received: '{}' from user {}", 491 event.message.text, event.chatter_user_id 492 ); 493 let vote_text = event.message.text.trim().to_lowercase(); 494 495 let mut polls = state.active_polls.write().await; 496 let Some(poll) = polls.get_mut(poll_id) else { 497 return; 498 }; 499 500 // Check if user already voted 501 if poll.voters.contains(&event.chatter_user_id) { 502 println!("User has already voted in this poll"); 503 return; 504 } 505 506 // Try to match vote to an option 507 // Support formats: "1", "!vote 1", "!vote option_name", or exact option name 508 let matched_option = parse_vote(&vote_text, &poll.config.options); 509 println!("Matched option: {:?}", matched_option); 510 if let Some(option) = matched_option { 511 // Update set of voters and the vote entry count 512 poll.voters.insert(event.chatter_user_id.clone()); 513 *poll.votes.entry(option.clone()).or_insert(0) += 1; 514 515 // Broadcast update 516 let update = PollUpdate { 517 event_type: "vote".to_string(), 518 votes: poll.votes.clone(), 519 total_voters: poll.voters.len() as u32, 520 }; 521 let _ = poll.tx.send(update); 522 523 println!("Vote recorded: {} -> {}", event.chatter_user_id, option); 524 } 525} 526 527fn parse_vote(text: &str, options: &[String]) -> Option<String> { 528 let text = text.trim().to_lowercase(); 529 530 // Try exact match (case-insensitive) 531 for option in options { 532 if option.to_lowercase() == text { 533 return Some(option.clone()); 534 } 535 } 536 537 // Try "!vote X" format 538 let vote_content = text 539 .strip_prefix("!vote ") 540 .or_else(|| text.strip_prefix("!v ")) 541 .unwrap_or(&text); 542 543 // Try numeric vote (1-indexed) 544 if let Ok(num) = vote_content.parse::<usize>() { 545 if num >= 1 && num <= options.len() { 546 return Some(options[num - 1].clone()); 547 } 548 } 549 550 // Try partial match at start 551 for option in options { 552 if option.to_lowercase().starts_with(vote_content) { 553 return Some(option.clone()); 554 } 555 } 556 557 None 558}