this repo has no description
1use crate::AppState;
2use crate::auth::UserSession;
3
4use std::collections::{HashMap, HashSet};
5use std::convert::Infallible;
6use std::fmt::Display;
7use std::time::Duration;
8
9use axum::{
10 Json, Router,
11 extract::{Path, State},
12 http::StatusCode,
13 response::{
14 IntoResponse,
15 sse::{Event, KeepAlive, Sse},
16 },
17 routing::{get, post},
18};
19use futures_util::StreamExt;
20use serde::{Deserialize, Serialize};
21use tokio::sync::broadcast;
22use tokio_stream::wrappers::BroadcastStream;
23use tokio_tungstenite::{connect_async, tungstenite::Message as TungsteniteMessage};
24use tower_sessions::Session;
25
26pub fn polls_router() -> Router<AppState> {
27 Router::new()
28 .route("/poll", post(create_poll))
29 .route("/polls", get(find_polls))
30 .route("/poll/{id}/events", get(subscribe_poll))
31 .route("/poll/{id}/end", post(end_poll))
32}
33
34// -----Structs-----
35
36#[derive(Debug, Clone, Deserialize)]
37pub struct PollConfig {
38 pub name: String,
39 pub options: Vec<String>,
40 pub duration: u32,
41}
42
43impl Display for PollConfig {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 write!(
46 f,
47 "Name: {}, Choices: {:?}, Duration: {}min(s)",
48 self.name, self.options, self.duration
49 )
50 }
51}
52
53#[derive(Debug, Clone, Serialize)]
54pub struct PollUpdate {
55 pub event_type: String,
56 pub votes: HashMap<String, u32>,
57 pub total_voters: u32,
58}
59
60#[derive(Debug)]
61pub struct ActivePoll {
62 pub config: PollConfig,
63 pub votes: HashMap<String, u32>,
64 pub voters: HashSet<String>,
65 pub tx: broadcast::Sender<PollUpdate>,
66 pub shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
67}
68
69#[derive(Debug, Serialize)]
70struct CreatePollResponse {
71 poll_id: String,
72 config: PollConfigResponse,
73}
74
75#[derive(Debug, Serialize)]
76struct PollConfigResponse {
77 name: String,
78 options: Vec<String>,
79}
80
81#[derive(Debug, Serialize)]
82struct Poll {
83 active: bool,
84 name: String,
85 options: Vec<String>,
86 total_votes: i32,
87}
88
89#[derive(Debug, Serialize)]
90struct FindPollsResponse {
91 polls: Vec<Poll>,
92}
93// -----Handlers-----
94
95async fn create_poll(
96 State(state): State<AppState>,
97 session: Session,
98 Json(config): Json<PollConfig>,
99) -> impl IntoResponse {
100 // Get user session
101 let user_session = match get_user_session(&session).await {
102 Ok(s) => s,
103 Err(response) => return response,
104 };
105
106 println!(
107 "Creating poll '{}' for channel: {}",
108 config.name, user_session.channel
109 );
110
111 // Generate poll ID
112 let poll_id = nanoid::nanoid!(10);
113
114 // Create broadcast channel for updates
115 let (tx, _) = broadcast::channel::<PollUpdate>(100);
116 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
117
118 // Initialize vote counts
119 let mut votes = HashMap::new();
120 for option in &config.options {
121 votes.insert(option.clone(), 0);
122 }
123
124 // Store poll state
125 let active_poll = ActivePoll {
126 config: config.clone(),
127 votes: votes.clone(),
128 voters: HashSet::new(),
129 tx: tx.clone(),
130 shutdown_tx: Some(shutdown_tx),
131 };
132
133 state
134 .active_polls
135 .write()
136 .await
137 .insert(poll_id.clone(), active_poll);
138
139 // Spawn EventSub listener task
140 let poll_id_clone = poll_id.clone();
141 let state_clone = state.clone();
142 tokio::spawn(async move {
143 if let Err(e) = eventsub_listener(
144 state_clone,
145 poll_id_clone.clone(),
146 user_session,
147 shutdown_rx,
148 )
149 .await
150 {
151 eprintln!("EventSub listener error for poll {}: {}", poll_id_clone, e);
152 }
153 });
154
155 (
156 StatusCode::CREATED,
157 Json(CreatePollResponse {
158 poll_id,
159 config: PollConfigResponse {
160 name: config.name,
161 options: config.options,
162 },
163 }),
164 )
165 .into_response()
166}
167
168async fn find_polls() -> Json<FindPollsResponse> {
169 let x = Poll {
170 active: false,
171 name: "Poll A".to_string(),
172 options: vec!["Option A".to_string(), "Option B".to_string()],
173 total_votes: 12,
174 };
175 Json(FindPollsResponse { polls: vec![x] })
176}
177
178async fn subscribe_poll(
179 State(state): State<AppState>,
180 Path(poll_id): Path<String>,
181) -> impl IntoResponse {
182 let polls = state.active_polls.read().await;
183
184 let Some(poll) = polls.get(&poll_id) else {
185 return (StatusCode::NOT_FOUND, "Poll not found").into_response();
186 };
187
188 let rx = poll.tx.subscribe();
189
190 // Send initial state
191 let initial_update = PollUpdate {
192 event_type: "initial".to_string(),
193 votes: poll.votes.clone(),
194 total_voters: poll.voters.len() as u32,
195 };
196
197 drop(polls); // Release lock before streaming
198
199 let stream = async_stream::stream! {
200 // Send initial state
201 yield Ok::<_, Infallible>(Event::default().json_data(&initial_update).unwrap());
202
203 // Stream updates
204 let mut stream = BroadcastStream::new(rx);
205 while let Some(result) = stream.next().await {
206 match result {
207 Ok(update) => {
208 yield Ok(Event::default().json_data(&update).unwrap());
209 if update.event_type == "ended" {
210 break;
211 }
212 }
213 Err(_) => break, // Channel closed
214 }
215 }
216 };
217
218 Sse::new(stream)
219 .keep_alive(KeepAlive::default().interval(Duration::from_secs(15)))
220 .into_response()
221}
222
223async fn end_poll(
224 State(state): State<AppState>,
225 Path(poll_id): Path<String>,
226 session: Session,
227) -> impl IntoResponse {
228 // Verify user is authenticated
229 if get_user_session(&session).await.is_err() {
230 return (StatusCode::UNAUTHORIZED, "Not authenticated").into_response();
231 }
232
233 let mut polls = state.active_polls.write().await;
234
235 let Some(poll) = polls.remove(&poll_id) else {
236 return (StatusCode::NOT_FOUND, "Poll not found").into_response();
237 };
238
239 // Send final update
240 let final_update = PollUpdate {
241 event_type: "ended".to_string(),
242 votes: poll.votes.clone(),
243 total_voters: poll.voters.len() as u32,
244 };
245 let _ = poll.tx.send(final_update);
246
247 // Signal shutdown to EventSub task
248 if let Some(shutdown_tx) = poll.shutdown_tx {
249 let _ = shutdown_tx.send(());
250 }
251
252 (StatusCode::OK, "Poll ended").into_response()
253}
254
255// -----Helper Functions-----
256
257async fn get_user_session(session: &Session) -> Result<UserSession, axum::response::Response> {
258 match session.get_value("user_session").await {
259 Ok(Some(user_session_json)) => serde_json::from_value::<UserSession>(user_session_json)
260 .map_err(|e| {
261 (
262 StatusCode::INTERNAL_SERVER_ERROR,
263 format!("Error deserializing user session: {}", e),
264 )
265 .into_response()
266 }),
267 Ok(None) => Err((StatusCode::UNAUTHORIZED, "Not authenticated").into_response()),
268 Err(e) => Err((
269 StatusCode::INTERNAL_SERVER_ERROR,
270 format!("Error retrieving session: {}", e),
271 )
272 .into_response()),
273 }
274}
275
276// -----EventSub Integration-----
277
278#[derive(Debug, Deserialize)]
279struct EventSubMessage {
280 metadata: EventSubMetadata,
281 payload: serde_json::Value,
282}
283
284#[derive(Debug, Deserialize)]
285struct EventSubMetadata {
286 message_type: String,
287 #[serde(default)]
288 subscription_type: Option<String>,
289}
290
291#[derive(Debug, Deserialize)]
292struct WelcomePayload {
293 session: EventSubSession,
294}
295
296#[derive(Debug, Deserialize)]
297struct EventSubSession {
298 id: String,
299}
300
301#[derive(Debug, Deserialize)]
302struct ChatMessageEvent {
303 chatter_user_id: String,
304 message: ChatMessage,
305}
306
307#[derive(Debug, Deserialize)]
308struct ChatMessage {
309 text: String,
310}
311
312async fn eventsub_listener(
313 state: AppState,
314 poll_id: String,
315 user_session: UserSession,
316 mut shutdown_rx: tokio::sync::oneshot::Receiver<()>,
317) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
318 println!("Starting EventSub listener for poll: {}", poll_id);
319
320 // Connect to Twitch EventSub WebSocket
321 let twitch_ws_url = "wss://eventsub.wss.twitch.tv/ws";
322 let (ws_stream, _) = connect_async(twitch_ws_url).await?;
323 println!("Connected to Twitch EventSub WebSocket");
324
325 let (mut write, mut read) = ws_stream.split();
326 use futures_util::{SinkExt, StreamExt};
327
328 // Wait for welcome message and get session_id
329 let session_id = loop {
330 tokio::select! {
331 _ = &mut shutdown_rx => {
332 println!("Shutdown signal received before welcome");
333 return Ok(());
334 }
335 msg = read.next() => {
336 match msg {
337 Some(Ok(TungsteniteMessage::Text(text))) => {
338 let message: EventSubMessage = serde_json::from_str(&text)?;
339 if message.metadata.message_type == "session_welcome" {
340 let payload: WelcomePayload = serde_json::from_value(message.payload)?;
341 println!("Received session_id: {}", payload.session.id);
342 break payload.session.id;
343 }
344 }
345 Some(Ok(TungsteniteMessage::Ping(data))) => {
346 write.send(TungsteniteMessage::Pong(data)).await?;
347 }
348 Some(Err(e)) => return Err(e.into()),
349 None => return Err("WebSocket closed before welcome".into()),
350 _ => {}
351 }
352 }
353 }
354 };
355
356 // Subscribe to channel.chat.message
357 subscribe_to_chat_messages(
358 &state.client_id,
359 &user_session.access_token,
360 &user_session.user_id,
361 &session_id,
362 )
363 .await?;
364
365 println!(
366 "Subscribed to chat messages for channel: {}",
367 user_session.channel
368 );
369
370 // Main event loop
371 loop {
372 tokio::select! {
373 _ = &mut shutdown_rx => {
374 println!("Shutdown signal received for poll: {}", poll_id);
375 break;
376 }
377 msg = read.next() => {
378 match msg {
379 Some(Ok(TungsteniteMessage::Text(text))) => {
380 if let Err(e) = handle_eventsub_message(&state, &poll_id, &text).await {
381 eprintln!("Error handling EventSub message: {}", e);
382 }
383 }
384 Some(Ok(TungsteniteMessage::Ping(data))) => {
385 if write.send(TungsteniteMessage::Pong(data)).await.is_err() {
386 break;
387 }
388 }
389 Some(Ok(TungsteniteMessage::Close(_))) => {
390 println!("EventSub connection closed");
391 break;
392 }
393 Some(Err(e)) => {
394 eprintln!("EventSub WebSocket error: {}", e);
395 break;
396 }
397 None => break,
398 _ => {}
399 }
400 }
401 }
402 }
403
404 // Cleanup eventsub connection
405 let _ = write.send(TungsteniteMessage::Close(None)).await;
406 println!("EventSub listener ended for poll: {}", poll_id);
407
408 Ok(())
409}
410
411async fn subscribe_to_chat_messages(
412 client_id: &str,
413 access_token: &str,
414 user_id: &str,
415 session_id: &str,
416) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
417 let client = reqwest::Client::new();
418
419 let body = serde_json::json!({
420 "type": "channel.chat.message",
421 "version": "1",
422 "condition": {
423 "broadcaster_user_id": user_id,
424 "user_id": user_id
425 },
426 "transport": {
427 "method": "websocket",
428 "session_id": session_id
429 }
430 });
431
432 let response = client
433 .post("https://api.twitch.tv/helix/eventsub/subscriptions")
434 .header("Authorization", format!("Bearer {}", access_token))
435 .header("Client-Id", client_id)
436 .header("Content-Type", "application/json")
437 .json(&body)
438 .send()
439 .await?;
440
441 let status = response.status();
442 let response_text = response.text().await?;
443 println!("Subscription response ({status}): {response_text}");
444
445 if !status.is_success() {
446 return Err(format!(
447 "Failed to subscribe to chat messages for user {}: {}",
448 user_id, response_text
449 )
450 .into());
451 }
452
453 Ok(())
454}
455
456async fn handle_eventsub_message(
457 state: &AppState,
458 poll_id: &str,
459 text: &str,
460) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
461 let message: EventSubMessage = serde_json::from_str(text)?;
462
463 match message.metadata.message_type.as_str() {
464 "notification" => {
465 if message.metadata.subscription_type.as_deref() == Some("channel.chat.message") {
466 let event: ChatMessageEvent = serde_json::from_value(
467 message.payload.get("event").cloned().unwrap_or_default(),
468 )?;
469
470 process_chat_vote(state, poll_id, &event).await;
471 }
472 }
473 "session_keepalive" => {
474 // Keepalive, nothing to do
475 }
476 "session_reconnect" => {
477 // TODO: Handle reconnect by connecting to new URL
478 println!("Received reconnect request - not yet implemented");
479 }
480 _ => {
481 println!("Received some message: {}", text);
482 }
483 }
484
485 Ok(())
486}
487
488async fn process_chat_vote(state: &AppState, poll_id: &str, event: &ChatMessageEvent) {
489 println!(
490 "Chat message received: '{}' from user {}",
491 event.message.text, event.chatter_user_id
492 );
493 let vote_text = event.message.text.trim().to_lowercase();
494
495 let mut polls = state.active_polls.write().await;
496 let Some(poll) = polls.get_mut(poll_id) else {
497 return;
498 };
499
500 // Check if user already voted
501 if poll.voters.contains(&event.chatter_user_id) {
502 println!("User has already voted in this poll");
503 return;
504 }
505
506 // Try to match vote to an option
507 // Support formats: "1", "!vote 1", "!vote option_name", or exact option name
508 let matched_option = parse_vote(&vote_text, &poll.config.options);
509 println!("Matched option: {:?}", matched_option);
510 if let Some(option) = matched_option {
511 // Update set of voters and the vote entry count
512 poll.voters.insert(event.chatter_user_id.clone());
513 *poll.votes.entry(option.clone()).or_insert(0) += 1;
514
515 // Broadcast update
516 let update = PollUpdate {
517 event_type: "vote".to_string(),
518 votes: poll.votes.clone(),
519 total_voters: poll.voters.len() as u32,
520 };
521 let _ = poll.tx.send(update);
522
523 println!("Vote recorded: {} -> {}", event.chatter_user_id, option);
524 }
525}
526
/// Maps a chat message to one of `options`, returning the matched option's original text.
///
/// Matching is case-insensitive and ignores surrounding whitespace. Forms are tried in order:
/// 1. the whole message equals an option (`"pizza"`);
/// 2. after an optional `!vote ` / `!v ` prefix, a 1-based option number (`"2"`, `"!vote 2"`);
/// 3. after the optional prefix, an exact option name (`"!vote pizza"`);
/// 4. after the optional prefix, the start of exactly one option's name (`"!v piz"`).
///
/// Returns `None` in three cases: nothing matches, the vote is empty, or the abbreviation in
/// form 4 fits several options. The last rule means an ambiguous vote never silently favours
/// the option listed first.
fn parse_vote(text: &str, options: &[String]) -> Option<String> {
    let text = text.trim().to_lowercase();

    let find_exact = |needle: &str| {
        options
            .iter()
            .find(|option| option.to_lowercase() == needle)
            .cloned()
    };

    // 1. Whole message is an option name.
    if let Some(option) = find_exact(text.as_str()) {
        return Some(option);
    }

    // Strip an optional command prefix. Trim again so "!vote  2" (extra spaces) still parses.
    let vote_content = text
        .strip_prefix("!vote ")
        .or_else(|| text.strip_prefix("!v "))
        .unwrap_or(&text)
        .trim();

    // An empty vote would otherwise prefix-match every option.
    if vote_content.is_empty() {
        return None;
    }

    // 2. 1-based option number; 0 and out-of-range numbers fall through.
    if let Ok(num) = vote_content.parse::<usize>() {
        if let Some(option) = num.checked_sub(1).and_then(|index| options.get(index)) {
            return Some(option.clone());
        }
    }

    // 3. Exact option name after the prefix. This is checked before prefixes so that "pizza"
    //    picks "Pizza" rather than being ambiguous with "Pizza Hut".
    if let Some(option) = find_exact(vote_content) {
        return Some(option);
    }

    // 4. Unambiguous prefix of an option name.
    let mut prefixed = options
        .iter()
        .filter(|option| option.to_lowercase().starts_with(vote_content));
    match (prefixed.next(), prefixed.next()) {
        (Some(option), None) => Some(option.clone()),
        _ => None,
    }
}