use crate::AppState; use crate::auth::UserSession; use std::collections::{HashMap, HashSet}; use std::convert::Infallible; use std::fmt::Display; use std::time::Duration; use axum::{ Json, Router, extract::{Path, State}, http::StatusCode, response::{ IntoResponse, sse::{Event, KeepAlive, Sse}, }, routing::{get, post}, }; use futures_util::StreamExt; use serde::{Deserialize, Serialize}; use tokio::sync::broadcast; use tokio_stream::wrappers::BroadcastStream; use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage}; use tower_sessions::Session; pub fn polls_router() -> Router { Router::new() .route("/poll", post(create_poll)) .route("/polls", get(find_polls)) .route("/poll/{id}/events", get(subscribe_poll)) .route("/poll/{id}/end", post(end_poll)) } // -----Structs----- #[derive(Debug, Clone, Deserialize)] pub struct PollConfig { pub name: String, pub options: Vec, pub duration: u32, } impl Display for PollConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "Name: {}, Choices: {:?}, Duration: {}min(s)", self.name, self.options, self.duration ) } } #[derive(Debug, Clone, Serialize)] pub struct PollUpdate { pub event_type: String, pub votes: HashMap, pub total_voters: u32, } #[derive(Debug)] pub struct ActivePoll { pub config: PollConfig, pub votes: HashMap, pub voters: HashSet, pub tx: broadcast::Sender, pub shutdown_tx: Option>, } #[derive(Debug, Serialize)] struct CreatePollResponse { poll_id: String, config: PollConfigResponse, } #[derive(Debug, Serialize)] struct PollConfigResponse { name: String, options: Vec, } #[derive(Debug, Serialize)] struct Poll { active: bool, name: String, options: Vec, total_votes: i32, } #[derive(Debug, Serialize)] struct FindPollsResponse { polls: Vec, } // -----Handlers----- async fn create_poll( State(state): State, session: Session, Json(config): Json, ) -> impl IntoResponse { // Get user session let user_session = match get_user_session(&session).await { Ok(s) => s, Err(response) => return response, }; println!( "Creating poll '{}' for channel: {}", config.name, user_session.channel ); // Generate poll ID let poll_id = nanoid::nanoid!(10); // Create broadcast channel for updates let (tx, _) = broadcast::channel::(100); let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); // Initialize vote counts let mut votes = HashMap::new(); for option in &config.options { votes.insert(option.clone(), 0); } // Store poll state let active_poll = ActivePoll { config: config.clone(), votes: votes.clone(), voters: HashSet::new(), tx: tx.clone(), shutdown_tx: Some(shutdown_tx), }; state .active_polls .write() .await .insert(poll_id.clone(), active_poll); // Spawn EventSub listener task let poll_id_clone = poll_id.clone(); let state_clone = state.clone(); tokio::spawn(async move { if let Err(e) = eventsub_listener( state_clone, poll_id_clone.clone(), user_session, shutdown_rx, ) .await { eprintln!("EventSub listener error for poll {}: {}", poll_id_clone, e); } }); ( StatusCode::CREATED, Json(CreatePollResponse { poll_id, config: PollConfigResponse { name: config.name, options: config.options, }, }), ) .into_response() } async fn find_polls() -> Json { let x = Poll { active: false, name: "Poll A".to_string(), options: vec!["Option A".to_string(), "Option B".to_string()], total_votes: 12, }; Json(FindPollsResponse { polls: vec![x] }) } async fn subscribe_poll( State(state): State, Path(poll_id): Path, ) -> impl IntoResponse { let polls = state.active_polls.read().await; let Some(poll) = polls.get(&poll_id) else { return (StatusCode::NOT_FOUND, "Poll not found").into_response(); }; let rx = poll.tx.subscribe(); // Send initial state let initial_update = PollUpdate { event_type: "initial".to_string(), votes: poll.votes.clone(), total_voters: poll.voters.len() as u32, }; drop(polls); // Release lock before streaming let stream = async_stream::stream! { // Send initial state yield Ok::<_, Infallible>(Event::default().json_data(&initial_update).unwrap()); // Stream updates let mut stream = BroadcastStream::new(rx); while let Some(result) = stream.next().await { match result { Ok(update) => { yield Ok(Event::default().json_data(&update).unwrap()); if update.event_type == "ended" { break; } } Err(_) => break, // Channel closed } } }; Sse::new(stream) .keep_alive(KeepAlive::default().interval(Duration::from_secs(15))) .into_response() } async fn end_poll( State(state): State, Path(poll_id): Path, session: Session, ) -> impl IntoResponse { // Verify user is authenticated if get_user_session(&session).await.is_err() { return (StatusCode::UNAUTHORIZED, "Not authenticated").into_response(); } let mut polls = state.active_polls.write().await; let Some(poll) = polls.remove(&poll_id) else { return (StatusCode::NOT_FOUND, "Poll not found").into_response(); }; // Send final update let final_update = PollUpdate { event_type: "ended".to_string(), votes: poll.votes.clone(), total_voters: poll.voters.len() as u32, }; let _ = poll.tx.send(final_update); // Signal shutdown to EventSub task if let Some(shutdown_tx) = poll.shutdown_tx { let _ = shutdown_tx.send(()); } (StatusCode::OK, "Poll ended").into_response() } // -----Helper Functions----- async fn get_user_session(session: &Session) -> Result { match session.get_value("user_session").await { Ok(Some(user_session_json)) => serde_json::from_value::(user_session_json) .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Error deserializing user session: {}", e), ) .into_response() }), Ok(None) => Err((StatusCode::UNAUTHORIZED, "Not authenticated").into_response()), Err(e) => Err(( StatusCode::INTERNAL_SERVER_ERROR, format!("Error retrieving session: {}", e), ) .into_response()), } } // -----EventSub Integration----- #[derive(Debug, Deserialize)] struct EventSubMessage { metadata: EventSubMetadata, payload: serde_json::Value, } #[derive(Debug, Deserialize)] struct EventSubMetadata { message_type: String, #[serde(default)] subscription_type: Option, } #[derive(Debug, Deserialize)] struct WelcomePayload { session: EventSubSession, } #[derive(Debug, Deserialize)] struct EventSubSession { id: String, } #[derive(Debug, Deserialize)] struct ChatMessageEvent { chatter_user_id: String, message: ChatMessage, } #[derive(Debug, Deserialize)] struct ChatMessage { text: String, } async fn eventsub_listener( state: AppState, poll_id: String, user_session: UserSession, mut shutdown_rx: tokio::sync::oneshot::Receiver<()>, ) -> Result<(), Box> { println!("Starting EventSub listener for poll: {}", poll_id); // Connect to Twitch EventSub WebSocket let twitch_ws_url = "wss://eventsub.wss.twitch.tv/ws"; let (ws_stream, _) = connect_async(twitch_ws_url).await?; println!("Connected to Twitch EventSub WebSocket"); let (mut write, mut read) = ws_stream.split(); use futures_util::{SinkExt, StreamExt}; // Wait for welcome message and get session_id let session_id = loop { tokio::select! { _ = &mut shutdown_rx => { println!("Shutdown signal received before welcome"); return Ok(()); } msg = read.next() => { match msg { Some(Ok(TungsteniteMessage::Text(text))) => { let message: EventSubMessage = serde_json::from_str(&text)?; if message.metadata.message_type == "session_welcome" { let payload: WelcomePayload = serde_json::from_value(message.payload)?; println!("Received session_id: {}", payload.session.id); break payload.session.id; } } Some(Ok(TungsteniteMessage::Ping(data))) => { write.send(TungsteniteMessage::Pong(data)).await?; } Some(Err(e)) => return Err(e.into()), None => return Err("WebSocket closed before welcome".into()), _ => {} } } } }; // Subscribe to channel.chat.message subscribe_to_chat_messages( &state.client_id, &user_session.access_token, &user_session.user_id, &session_id, ) .await?; println!( "Subscribed to chat messages for channel: {}", user_session.channel ); // Main event loop loop { tokio::select! { _ = &mut shutdown_rx => { println!("Shutdown signal received for poll: {}", poll_id); break; } msg = read.next() => { match msg { Some(Ok(TungsteniteMessage::Text(text))) => { if let Err(e) = handle_eventsub_message(&state, &poll_id, &text).await { eprintln!("Error handling EventSub message: {}", e); } } Some(Ok(TungsteniteMessage::Ping(data))) => { if write.send(TungsteniteMessage::Pong(data)).await.is_err() { break; } } Some(Ok(TungsteniteMessage::Close(_))) => { println!("EventSub connection closed"); break; } Some(Err(e)) => { eprintln!("EventSub WebSocket error: {}", e); break; } None => break, _ => {} } } } } // Cleanup eventsub connection let _ = write.send(TungsteniteMessage::Close(None)).await; println!("EventSub listener ended for poll: {}", poll_id); Ok(()) } async fn subscribe_to_chat_messages( client_id: &str, access_token: &str, user_id: &str, session_id: &str, ) -> Result<(), Box> { let client = reqwest::Client::new(); let body = serde_json::json!({ "type": "channel.chat.message", "version": "1", "condition": { "broadcaster_user_id": user_id, "user_id": user_id }, "transport": { "method": "websocket", "session_id": session_id } }); let response = client .post("https://api.twitch.tv/helix/eventsub/subscriptions") .header("Authorization", format!("Bearer {}", access_token)) .header("Client-Id", client_id) .header("Content-Type", "application/json") .json(&body) .send() .await?; let status = response.status(); let response_text = response.text().await?; println!("Subscription response ({status}): {response_text}"); if !status.is_success() { return Err(format!( "Failed to subscribe to chat messages for user {}: {}", user_id, response_text ) .into()); } Ok(()) } async fn handle_eventsub_message( state: &AppState, poll_id: &str, text: &str, ) -> Result<(), Box> { let message: EventSubMessage = serde_json::from_str(text)?; match message.metadata.message_type.as_str() { "notification" => { if message.metadata.subscription_type.as_deref() == Some("channel.chat.message") { let event: ChatMessageEvent = serde_json::from_value( message.payload.get("event").cloned().unwrap_or_default(), )?; process_chat_vote(state, poll_id, &event).await; } } "session_keepalive" => { // Keepalive, nothing to do } "session_reconnect" => { // TODO: Handle reconnect by connecting to new URL println!("Received reconnect request - not yet implemented"); } _ => { println!("Received some message: {}", text); } } Ok(()) } async fn process_chat_vote(state: &AppState, poll_id: &str, event: &ChatMessageEvent) { println!( "Chat message received: '{}' from user {}", event.message.text, event.chatter_user_id ); let vote_text = event.message.text.trim().to_lowercase(); let mut polls = state.active_polls.write().await; let Some(poll) = polls.get_mut(poll_id) else { return; }; // Check if user already voted if poll.voters.contains(&event.chatter_user_id) { println!("User has already voted in this poll"); return; } // Try to match vote to an option // Support formats: "1", "!vote 1", "!vote option_name", or exact option name let matched_option = parse_vote(&vote_text, &poll.config.options); println!("Matched option: {:?}", matched_option); if let Some(option) = matched_option { // Update set of voters and the vote entry count poll.voters.insert(event.chatter_user_id.clone()); *poll.votes.entry(option.clone()).or_insert(0) += 1; // Broadcast update let update = PollUpdate { event_type: "vote".to_string(), votes: poll.votes.clone(), total_voters: poll.voters.len() as u32, }; let _ = poll.tx.send(update); println!("Vote recorded: {} -> {}", event.chatter_user_id, option); } } fn parse_vote(text: &str, options: &[String]) -> Option { let text = text.trim().to_lowercase(); // Try exact match (case-insensitive) for option in options { if option.to_lowercase() == text { return Some(option.clone()); } } // Try "!vote X" format let vote_content = text .strip_prefix("!vote ") .or_else(|| text.strip_prefix("!v ")) .unwrap_or(&text); // Try numeric vote (1-indexed) if let Ok(num) = vote_content.parse::() { if num >= 1 && num <= options.len() { return Some(options[num - 1].clone()); } } // Try partial match at start for option in options { if option.to_lowercase().starts_with(vote_content) { return Some(option.clone()); } } None }