The world's most clever kitty cat
at main 283 lines 8.7 kB view raw
1#![allow(unused_features)] 2#![feature(iter_map_windows)] 3#![feature(iter_intersperse)] 4#![feature(test)] 5 6mod brain; 7mod cmd; 8mod on_message; 9mod status; 10 11pub mod prelude { 12 pub use anyhow::Context; 13 use std::result::Result as StdResult; 14 pub type Result<T = (), E = anyhow::Error> = StdResult<T, E>; 15} 16 17use std::{ 18 collections::HashSet, 19 fs::File, 20 path::{Path, PathBuf}, 21 sync::{ 22 Arc, 23 atomic::{AtomicBool, Ordering}, 24 }, 25}; 26 27use brotli::enc::{BrotliEncoderParams, backward_references::BrotliEncoderMode}; 28use log::{debug, error, info, warn}; 29use prelude::*; 30use tokio::{ 31 signal::unix::{SignalKind, signal}, 32 sync::RwLock, 33 time::{self, Duration}, 34}; 35use twilight_gateway::{ 36 CloseFrame, Event, EventTypeFlags, Intents, MessageSender, Shard, ShardId, StreamExt, 37}; 38use twilight_http::Client as HttpClient; 39use twilight_model::{ 40 application::interaction::InteractionData, 41 id::{ 42 Id, 43 marker::{ApplicationMarker, ChannelMarker, UserMarker}, 44 }, 45}; 46 47use crate::{ 48 brain::Brain, 49 cmd::{handle_app_command, register_all_commands}, 50 on_message::handle_discord_message, 51 status::update_status, 52}; 53 54pub type BrainHandle = RwLock<Brain>; 55 56#[derive(Debug)] 57pub struct BotContext { 58 http: HttpClient, 59 self_id: Id<UserMarker>, 60 app_id: Id<ApplicationMarker>, 61 owners: HashSet<Id<UserMarker>>, 62 brain_file_path: PathBuf, 63 reply_channels: HashSet<Id<ChannelMarker>>, 64 brain_handle: BrainHandle, 65 shard_sender: MessageSender, 66 pending_save: AtomicBool, 67} 68 69async fn handle_discord_event(event: Event, ctx: Arc<BotContext>) -> Result { 70 match event { 71 Event::MessageCreate(msg) => handle_discord_message(msg, ctx) 72 .await 73 .context("While handling a new message"), 74 Event::InteractionCreate(mut inter) => { 75 if let Some(InteractionData::ApplicationCommand(data)) = 76 std::mem::take(&mut inter.0.data) 77 { 78 handle_app_command(*data, ctx, inter.0) 79 .await 80 .context("While handling an app command") 81 } else { 82 Ok(()) 83 } 84 } 85 Event::Ready(ev) => { 86 info!("Connected to gateway as {}", ev.user.name); 87 let brain = ctx.brain_handle.read().await; 88 update_status(&brain, &ctx.shard_sender).context("Failed to update status on ready") 89 } 90 _ => Ok(()), 91 } 92} 93 94const BROTLI_BUF_SIZE: usize = 1024 * 1000; 95fn get_brotli_params() -> BrotliEncoderParams { 96 BrotliEncoderParams { 97 quality: 5, 98 mode: BrotliEncoderMode::BROTLI_MODE_TEXT, 99 ..Default::default() 100 } 101} 102 103fn load_brain(path: &Path) -> Result<Option<Brain>> { 104 if path.exists() { 105 let mut file = File::open(path).context("Failed to open brain file")?; 106 let mut brotli_stream = brotli::Decompressor::new(&mut file, BROTLI_BUF_SIZE); 107 rmp_serde::from_read(&mut brotli_stream) 108 .map(Some) 109 .context("Failed to decode brain file") 110 } else { 111 Ok(None) 112 } 113} 114 115async fn save_brain(ctx: Arc<BotContext>) -> Result { 116 let scratch_path = ctx.brain_file_path.with_file_name(format!( 117 "~{}", 118 ctx.brain_file_path.file_name().unwrap().to_str().unwrap() 119 )); 120 let mut file = File::create(&scratch_path).context("Failed to open brain file")?; 121 let mut brotli_writer = 122 brotli::CompressorWriter::with_params(&mut file, BROTLI_BUF_SIZE, &get_brotli_params()); 123 124 let brain = ctx.brain_handle.read().await; 125 rmp_serde::encode::write(&mut brotli_writer, &*brain) 126 .context("Failed to write serialized brain")?; 127 128 std::fs::rename(&scratch_path, &ctx.brain_file_path) 129 .context("Failed to override scratch file")?; 130 131 debug!("Saved brain file"); 132 Ok(()) 133} 134 135#[tokio::main] 136async fn main() -> Result { 137 let mut clog = colog::default_builder(); 138 clog.filter( 139 None, 140 if cfg!(debug_assertions) { 141 log::LevelFilter::Debug 142 } else { 143 log::LevelFilter::Info 144 }, 145 ); 146 clog.try_init().context("Failed to initialize colog")?; 147 148 info!("Start of bingus-bot {}", env!("CARGO_PKG_VERSION")); 149 150 // Config 151 let token_file = std::env::var("TOKEN_FILE").context("Missing TOKEN_FILE env var")?; 152 let reply_channels = std::env::var("REPLY_CHANNELS") 153 .context("Missing REPLY_CHANNELS env var")? 154 .split(",") 155 .filter_map(|s| { 156 if s.trim().is_empty() { 157 None 158 } else { 159 Some(s.trim().parse::<u64>().map(|c| Id::new(c))) 160 } 161 }) 162 .collect::<Result<_, _>>() 163 .context("Invalid channel IDs for REPLY_CHANNELS")?; 164 let brain_file_path = 165 PathBuf::from(std::env::var("BRAIN_FILE").unwrap_or_else(|_| "brain.msgpackz".to_string())); 166 let intents = Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT; 167 168 // Read token 169 let token = std::fs::read_to_string(token_file).context("Failed to read bot token")?; 170 let token = token.trim(); 171 172 // Read Brain 173 let brain = if let Some(brain) = load_brain(&brain_file_path)? { 174 info!("Loading brain from {brain_file_path:?}"); 175 brain 176 } else { 177 info!("Creating new brain file at {brain_file_path:?}"); 178 Brain::default() 179 }; 180 let brain_handle = RwLock::new(brain); 181 182 // Init 183 let mut shard = Shard::new(ShardId::ONE, token.to_string(), intents); 184 let http = HttpClient::new(token.to_string()); 185 186 let app = http 187 .current_user_application() 188 .await 189 .context("Failed to get current App")? 190 .model() 191 .await 192 .context("Failed to deserialize")?; 193 194 let app_id = app.id; 195 196 let self_id = app.bot.context("App is not a bot!")?.id; 197 198 let owners = if let Some(user) = app.owner { 199 HashSet::from_iter([user.id]) 200 } else if let Some(team) = app.team { 201 team.members.iter().map(|m| m.user.id).collect() 202 } else { 203 warn!("No Owner?? Bingus is free!!!"); 204 HashSet::new() 205 }; 206 207 let context = Arc::new(BotContext { 208 http, 209 self_id, 210 app_id, 211 owners, 212 reply_channels, 213 brain_file_path, 214 brain_handle, 215 shard_sender: shard.sender(), 216 pending_save: AtomicBool::new(false), 217 }); 218 219 info!("Registering Commands..."); 220 register_all_commands(context.clone()).await?; 221 222 let mut interval = time::interval(Duration::from_secs(60)); 223 interval.tick().await; 224 225 let mut sigterm = signal(SignalKind::terminate()).context("Failed to listen to SIGTERM")?; 226 227 info!("Connecting to gateway..."); 228 229 loop { 230 tokio::select! { 231 232 biased; 233 234 Ok(()) = tokio::signal::ctrl_c() => { 235 info!("SIGINT: Closing connection and saving"); 236 shard.close(CloseFrame::NORMAL); 237 } 238 _ = sigterm.recv() => { 239 info!("SIGTERM: Closing connection and saving"); 240 shard.close(CloseFrame::NORMAL); 241 } 242 _ = interval.tick() => { 243 debug!("Save Interval"); 244 if context.pending_save.load(Ordering::Relaxed) { 245 let ctx = context.clone(); 246 tokio::spawn(async move { 247 if let Err(why) = save_brain(ctx.clone()).await { 248 error!("Failed to save brain file:\n{why:?}"); 249 } 250 ctx.pending_save.store(false, Ordering::Relaxed); 251 }); 252 } 253 }, 254 opt = shard.next_event(EventTypeFlags::all()) => { 255 match opt { 256 Some(Ok(Event::GatewayClose(_))) | None => { 257 info!("Disconnected from Discord"); 258 break; 259 } 260 Some(Ok(event)) => { 261 let ctx = context.clone(); 262 tokio::spawn(async move { 263 if let Err(why) = handle_discord_event(event, ctx).await { 264 error!("Error while processing Discord event:\n{why:?}"); 265 } 266 }); 267 } 268 Some(Err(why)) => { 269 warn!("Failed to receive event:\n{why:?}"); 270 } 271 } 272 } 273 } 274 } 275 276 if context.pending_save.load(Ordering::Relaxed) { 277 save_brain(context) 278 .await 279 .context("Failed to write brain file on exit")?; 280 } 281 282 Ok(()) 283}