1/*
2 * This Source Code Form is subject to the terms of the Mozilla Public
3 * License, v. 2.0. If a copy of the MPL was not distributed with this
4 * file, You can obtain one at https://mozilla.org/MPL/2.0/.
5 */
6
7use std::{collections::VecDeque, ops::Deref, sync::Arc};
8
9use anyhow::{Result, bail};
10use clap::Parser;
11use tokio::{
12 io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader},
13 net::{
14 TcpListener, TcpStream,
15 tcp::{OwnedReadHalf, OwnedWriteHalf},
16 },
17 sync::Mutex,
18};
19use tracing::{debug, error, warn};
20use tracing_subscriber::{EnvFilter, prelude::*};
21
22#[derive(Parser, Debug)]
23#[command(version)]
24struct Args {
25 #[arg(short, long, default_value_t = 6667)]
26 port: u16,
27 /// The server to proxy.
28 host: String,
29}
30
/// Buffered read half of a split TCP stream, so lines can be read with `read_until`.
type BufReadHalf = BufReader<OwnedReadHalf>;

/// One side of the proxy (either the IRC client or the upstream server).
///
/// Each half is behind its own async mutex so the reading and writing tasks
/// can use the stream independently.
struct Connection {
    /// Whether this peer is still considered connected. `Bouncer::new` starts
    /// it as `true`; the `Bouncer` methods flip it to `false` once the peer is
    /// known to be gone.
    connected: Mutex<bool>,
    /// Line-buffered read half of the socket.
    read: Mutex<BufReadHalf>,
    /// Write half of the socket.
    write: Mutex<OwnedWriteHalf>,
}
38
/// State shared by all bouncer tasks, owned through `Bouncer`'s `Arc`.
struct InnerBouncer {
    /// The downstream IRC client.
    client: Connection,
    /// The upstream IRC server being proxied.
    server: Connection,
    /// Raw server lines (including their line terminator) that could not be
    /// delivered to the client, oldest first. Only held in memory.
    // NOTE(review): nothing visible in this file drains this queue yet.
    // A `std` (not tokio) mutex is fine here because it is never held across `.await`.
    // TODO: persist
    message_queue: std::sync::Mutex<VecDeque<Vec<u8>>>,
}
45
/// Cheaply clonable handle to the shared bouncer state; each spawned task owns a clone.
struct Bouncer(Arc<InnerBouncer>);
47
48#[tokio::main]
49async fn main() -> Result<()> {
50 tracing_subscriber::registry()
51 .with(EnvFilter::from_default_env())
52 .with(tracing_subscriber::fmt::layer())
53 .init();
54
55 let args = Args::parse();
56
57 let client_listener = TcpListener::bind("127.0.0.1:6667").await?;
58 let client_conn = client_listener.accept().await?.0;
59 let server_conn = TcpStream::connect(args.host).await?;
60
61 let bouncer = Bouncer::new(client_conn, server_conn);
62 let its_2_am_idk_what_to_name_this_bazinga = bouncer.clone();
63 tokio::spawn(async move {
64 its_2_am_idk_what_to_name_this_bazinga
65 .clientbound_task()
66 .await
67 .unwrap();
68 });
69
70 let its_2_am_idk_what_to_name_this_bazinga = bouncer.clone();
71 tokio::spawn(async move {
72 its_2_am_idk_what_to_name_this_bazinga
73 .serverbound_task()
74 .await
75 .unwrap();
76 });
77
78 tokio::signal::ctrl_c().await?;
79 Ok(())
80}
81
82impl Clone for Bouncer {
83 fn clone(&self) -> Self {
84 Self(self.0.clone())
85 }
86}
87
88impl Deref for Bouncer {
89 type Target = InnerBouncer;
90 fn deref(&self) -> &Self::Target {
91 &self.0
92 }
93}
94
95impl Bouncer {
96 fn new(client: TcpStream, server: TcpStream) -> Self {
97 let (client_read, client_write) = client.into_split();
98 let (server_read, server_write) = server.into_split();
99
100 Self(
101 InnerBouncer {
102 client: Connection {
103 connected: true.into(),
104 read: BufReader::new(client_read).into(),
105 write: client_write.into(),
106 },
107 server: Connection {
108 connected: true.into(),
109 read: BufReader::new(server_read).into(),
110 write: server_write.into(),
111 },
112 message_queue: Default::default(),
113 }
114 .into(),
115 )
116 }
117
118 pub async fn clientbound_task(&self) -> Result<()> {
119 let mut buf = Vec::new();
120
121 loop {
122 // TODO: don't re-lock read on each loop?
123 let bytes_read = {
124 let mut read = self.server.read.lock().await;
125 read.read_until(b'\n', &mut buf).await?
126 };
127
128 if bytes_read == 0 {
129 bail!("EOF from server, what?");
130 }
131
132 let line = String::from_utf8_lossy(&buf);
133
134 debug!("SERVER->CLIENT: {}", line.trim_end_matches("\r\n"));
135 self.write_to_client(&buf);
136 buf.clear();
137 }
138 }
139
140 pub async fn serverbound_task(&self) -> Result<()> {
141 let mut buf = Vec::new();
142
143 let mut read = self.connections.client_read.lock().await;
144 let mut write = self.connections.server_write.lock().await;
145
146 loop {
147 let read = read.read_until(b'\n', &mut buf).await?;
148 if read == 0 {
149 warn!("EOF from client, shutting down serverbound task");
150 let mut connected = self.connections.client_connected.write().await;
151 *connected = false;
152 return Ok(());
153 }
154
155 let line = String::from_utf8_lossy(&buf);
156 let line = line.trim_end_matches("\r\n");
157
158 debug!("CLIENT->SERVER: {}", line);
159
160 write.write(&buf).await?;
161 buf.clear();
162 }
163 }
164
165 // writes to the client. if disconnected, enqueues instead and marks the client as disconnected.
166 async fn write_to_client(&self, buf: &[u8]) -> std::io::Result<()> {
167 let mut connected = self.client.connected.lock().await;
168 if *connected {
169 let mut write = self.client.write.lock().await;
170 // TODO: handle backlog here?
171 match write.write(buf).await {
172 Err(ref e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
173 warn!("Client disconnected");
174 *connected = false;
175 }
176 Err(why) => return Err(why),
177 Ok(_) => return Ok(()),
178 }
179 }
180
181 // always run if not initially connected or disconnected above^^^, so OK to return afterwards
182 if !*connected {
183 debug!("writing to queue instead...");
184 self.message_queue.lock().unwrap().push_back(buf.to_vec());
185 }
186
187 // it's 5am. i'll clean this up later
188 Ok(())
189 }
190}