a simple IRC bouncer
at master 5.5 kB view raw
1/* 2 * This Source Code Form is subject to the terms of the Mozilla Public 3 * License, v. 2.0. If a copy of the MPL was not distributed with this 4 * file, You can obtain one at https://mozilla.org/MPL/2.0/. 5 */ 6 7use std::{collections::VecDeque, ops::Deref, sync::Arc}; 8 9use anyhow::{Result, bail}; 10use clap::Parser; 11use tokio::{ 12 io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader}, 13 net::{ 14 TcpListener, TcpStream, 15 tcp::{OwnedReadHalf, OwnedWriteHalf}, 16 }, 17 sync::Mutex, 18}; 19use tracing::{debug, error, warn}; 20use tracing_subscriber::{EnvFilter, prelude::*}; 21 22#[derive(Parser, Debug)] 23#[command(version)] 24struct Args { 25 #[arg(short, long, default_value_t = 6667)] 26 port: u16, 27 /// The server to proxy. 28 host: String, 29} 30 31type BufReadHalf = BufReader<OwnedReadHalf>; 32 33struct Connection { 34 connected: Mutex<bool>, 35 read: Mutex<BufReadHalf>, 36 write: Mutex<OwnedWriteHalf>, 37} 38 39struct InnerBouncer { 40 client: Connection, 41 server: Connection, 42 // TODO: persist 43 message_queue: std::sync::Mutex<VecDeque<Vec<u8>>>, 44} 45 46struct Bouncer(Arc<InnerBouncer>); 47 48#[tokio::main] 49async fn main() -> Result<()> { 50 tracing_subscriber::registry() 51 .with(EnvFilter::from_default_env()) 52 .with(tracing_subscriber::fmt::layer()) 53 .init(); 54 55 let args = Args::parse(); 56 57 let client_listener = TcpListener::bind("127.0.0.1:6667").await?; 58 let client_conn = client_listener.accept().await?.0; 59 let server_conn = TcpStream::connect(args.host).await?; 60 61 let bouncer = Bouncer::new(client_conn, server_conn); 62 let its_2_am_idk_what_to_name_this_bazinga = bouncer.clone(); 63 tokio::spawn(async move { 64 its_2_am_idk_what_to_name_this_bazinga 65 .clientbound_task() 66 .await 67 .unwrap(); 68 }); 69 70 let its_2_am_idk_what_to_name_this_bazinga = bouncer.clone(); 71 tokio::spawn(async move { 72 its_2_am_idk_what_to_name_this_bazinga 73 .serverbound_task() 74 .await 75 .unwrap(); 76 }); 77 78 tokio::signal::ctrl_c().await?; 79 Ok(()) 80} 81 82impl Clone for Bouncer { 83 fn clone(&self) -> Self { 84 Self(self.0.clone()) 85 } 86} 87 88impl Deref for Bouncer { 89 type Target = InnerBouncer; 90 fn deref(&self) -> &Self::Target { 91 &self.0 92 } 93} 94 95impl Bouncer { 96 fn new(client: TcpStream, server: TcpStream) -> Self { 97 let (client_read, client_write) = client.into_split(); 98 let (server_read, server_write) = server.into_split(); 99 100 Self( 101 InnerBouncer { 102 client: Connection { 103 connected: true.into(), 104 read: BufReader::new(client_read).into(), 105 write: client_write.into(), 106 }, 107 server: Connection { 108 connected: true.into(), 109 read: BufReader::new(server_read).into(), 110 write: server_write.into(), 111 }, 112 message_queue: Default::default(), 113 } 114 .into(), 115 ) 116 } 117 118 pub async fn clientbound_task(&self) -> Result<()> { 119 let mut buf = Vec::new(); 120 121 loop { 122 // TODO: don't re-lock read on each loop? 123 let bytes_read = { 124 let mut read = self.server.read.lock().await; 125 read.read_until(b'\n', &mut buf).await? 126 }; 127 128 if bytes_read == 0 { 129 bail!("EOF from server, what?"); 130 } 131 132 let line = String::from_utf8_lossy(&buf); 133 134 debug!("SERVER->CLIENT: {}", line.trim_end_matches("\r\n")); 135 self.write_to_client(&buf); 136 buf.clear(); 137 } 138 } 139 140 pub async fn serverbound_task(&self) -> Result<()> { 141 let mut buf = Vec::new(); 142 143 let mut read = self.connections.client_read.lock().await; 144 let mut write = self.connections.server_write.lock().await; 145 146 loop { 147 let read = read.read_until(b'\n', &mut buf).await?; 148 if read == 0 { 149 warn!("EOF from client, shutting down serverbound task"); 150 let mut connected = self.connections.client_connected.write().await; 151 *connected = false; 152 return Ok(()); 153 } 154 155 let line = String::from_utf8_lossy(&buf); 156 let line = line.trim_end_matches("\r\n"); 157 158 debug!("CLIENT->SERVER: {}", line); 159 160 write.write(&buf).await?; 161 buf.clear(); 162 } 163 } 164 165 // writes to the client. if disconnected, enqueues instead and marks the client as disconnected. 166 async fn write_to_client(&self, buf: &[u8]) -> std::io::Result<()> { 167 let mut connected = self.client.connected.lock().await; 168 if *connected { 169 let mut write = self.client.write.lock().await; 170 // TODO: handle backlog here? 171 match write.write(buf).await { 172 Err(ref e) if e.kind() == std::io::ErrorKind::BrokenPipe => { 173 warn!("Client disconnected"); 174 *connected = false; 175 } 176 Err(why) => return Err(why), 177 Ok(_) => return Ok(()), 178 } 179 } 180 181 // always run if not initially connected or disconnected above^^^, so OK to return afterwards 182 if !*connected { 183 debug!("writing to queue instead..."); 184 self.message_queue.lock().unwrap().push_back(buf.to_vec()); 185 } 186 187 // it's 5am. i'll clean this up later 188 Ok(()) 189 } 190}