Sniff and replay HTTP requests and responses — perfect for mocking APIs during testing.
at main 9.8 kB view raw
1use std::{ 2 process::exit, 3 sync::Arc, 4 thread, 5 time::{SystemTime, UNIX_EPOCH}, 6}; 7 8use http_body_util::{BodyExt, Full}; 9use hyper::{ 10 Request, Response, 11 body::{Buf, Bytes, Incoming}, 12 server::conn::http1, 13}; 14use hyper_util::rt::TokioIo; 15use owo_colors::OwoColorize; 16use serde::{Deserialize, Serialize}; 17use tokio::{net::TcpListener, sync::Mutex}; 18 19use crate::{ 20 replay::start_replay_server, 21 store::{LogStore, save_logs_to_file}, 22}; 23 24pub const PROXY_LOG_FILE: &str = "replay_mocks.json"; 25 26#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] 27pub struct RequestLog { 28 pub timestamp: u64, 29 pub method: String, 30 pub path: String, 31 pub query_params: Option<String>, 32 pub headers: Vec<(String, String)>, 33 pub body: Option<String>, 34} 35 36#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] 37pub struct ResponseLog { 38 pub status: u16, 39 pub headers: Vec<(String, String)>, 40 pub body: Option<String>, 41} 42 43#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] 44pub struct ProxyLog { 45 pub request: RequestLog, 46 pub response: ResponseLog, 47} 48 49pub async fn start_server( 50 target: &str, 51 listen: &str, 52) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 53 let target_uri = target.parse::<hyper::Uri>()?; 54 let target_authority = target_uri.authority().ok_or("Invalid target URL")?; 55 let target_scheme = target_uri.scheme_str().ok_or("http")?; 56 let target_host = target_authority.host(); 57 let target_port = target_authority 58 .port_u16() 59 .unwrap_or(if target_scheme == "https" { 443 } else { 80 }); 60 61 let logs = Arc::new(Mutex::new(Vec::<ProxyLog>::new())); 62 let logs_for_saving = logs.clone(); 63 tokio::spawn(async move { 64 loop { 65 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; 66 save_logs_to_file(&logs_for_saving, PROXY_LOG_FILE) 67 .await 68 .unwrap_or_else(|e| eprintln!("Error saving logs to file: {}", e)); 69 } 70 }); 71 72 let logs_for_replay = logs.clone(); 73 thread::spawn(move || { 74 let rt = tokio::runtime::Builder::new_multi_thread() 75 .enable_all() 76 .build() 77 .unwrap(); 78 rt.block_on(async { 79 match start_replay_server(logs_for_replay, "127.0.0.1:6688").await { 80 Ok(_) => { 81 println!("Replay server stopped"); 82 exit(0); 83 } 84 Err(e) => eprintln!("Replay server error: {}", e), 85 } 86 }); 87 }); 88 89 let listener = TcpListener::bind(listen).await?; 90 println!("Target URL: {}", target.magenta()); 91 println!("Proxy server is listening on {}", listen.green()); 92 println!("Replay server is running on {}", "127.0.0.1:6688".green()); 93 94 loop { 95 let (stream, _) = listener.accept().await?; 96 let io = TokioIo::new(stream); 97 98 let target_host_str = target_host.to_string(); 99 let target_scheme = target_scheme.to_string(); 100 let logs_clone = logs.clone(); 101 102 tokio::task::spawn(async move { 103 let service = hyper::service::service_fn(move |req: Request<Incoming>| { 104 let target_host = target_host_str.clone(); 105 let scheme = target_scheme.clone(); 106 let logs = logs_clone.clone(); 107 108 async move { proxy_handler(req, &target_host, target_port, &scheme, logs).await } 109 }); 110 111 if let Err(err) = http1::Builder::new() 112 .keep_alive(false) 113 .max_buf_size(30 * 1024 * 1024) 114 .serve_connection(io, service) 115 .await 116 { 117 eprintln!("> Connection error: {}", err); 118 } 119 }); 120 } 121} 122 123pub async fn proxy_handler( 124 req: Request<Incoming>, 125 target_host: &str, 126 target_port: u16, 127 scheme: &str, 128 logs: LogStore, 129) -> Result<Response<Full<Bytes>>, hyper::Error> { 130 let timestamp = SystemTime::now() 131 .duration_since(UNIX_EPOCH) 132 .unwrap() 133 .as_secs(); 134 135 let method = req.method().clone(); 136 let path = req.uri().path().to_string(); 137 let query = req.uri().query().map(|q| q.to_string()); 138 139 let headers: Vec<(String, String)> = req 140 .headers() 141 .iter() 142 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string())) 143 .collect(); 144 145 let (parts, body) = req.into_parts(); 146 let body_bytes = match body.collect().await { 147 Ok(collected) => collected.aggregate(), 148 Err(e) => { 149 eprintln!("Error collecting request body: {}", e); 150 return Ok(Response::builder() 151 .status(500) 152 .body(Full::new(Bytes::from("Internal Server Error"))) 153 .unwrap()); 154 } 155 }; 156 157 let body_vec = body_bytes.chunk().to_vec(); 158 let body_str = String::from_utf8(body_vec.clone()).ok(); 159 160 let forward_uri = if target_port != 443 && target_port != 80 { 161 format!( 162 "{}://{}:{}{}{}", 163 scheme, 164 target_host, 165 target_port, 166 parts.uri.path(), 167 parts 168 .uri 169 .query() 170 .map_or(String::new(), |q| format!("?{}", q)) 171 ) 172 } else { 173 format!( 174 "{}://{}{}{}", 175 scheme, 176 target_host, 177 parts.uri.path(), 178 parts 179 .uri 180 .query() 181 .map_or(String::new(), |q| format!("?{}", q)) 182 ) 183 }; 184 185 println!("{} {} {}", method.yellow(), path, forward_uri.magenta()); 186 187 let client = reqwest::Client::builder() 188 .timeout(std::time::Duration::from_secs(30)) 189 .danger_accept_invalid_certs(true) 190 .build() 191 .unwrap_or_else(|_| reqwest::Client::new()); 192 193 let mut req_builder = match method.as_str() { 194 "GET" => client.get(&forward_uri), 195 "POST" => client.post(&forward_uri), 196 "PUT" => client.put(&forward_uri), 197 "DELETE" => client.delete(&forward_uri), 198 "HEAD" => client.head(&forward_uri), 199 "OPTIONS" => client.request(reqwest::Method::OPTIONS, &forward_uri), 200 "PATCH" => client.patch(&forward_uri), 201 _ => { 202 eprintln!("Unsupported method: {}", method); 203 return Ok(Response::builder() 204 .status(400) 205 .body(Full::new(Bytes::from("Bad Request: Unsupported Method"))) 206 .unwrap()); 207 } 208 }; 209 210 for (name, value) in &headers { 211 if name.to_lowercase() != "host" && name.to_lowercase() != "connection" { 212 if let Ok(header_name) = reqwest::header::HeaderName::from_bytes(name.as_bytes()) { 213 if let Ok(header_value) = reqwest::header::HeaderValue::from_str(value) { 214 req_builder = req_builder.header(header_name, header_value); 215 } 216 } 217 } 218 } 219 220 if !body_vec.is_empty() { 221 req_builder = req_builder.body(body_vec.clone()); 222 } 223 224 let resp = match req_builder.send().await { 225 Ok(resp) => resp, 226 Err(e) => { 227 eprintln!("Error sending request: {}", e); 228 return Ok(Response::builder() 229 .status(502) 230 .body(Full::new(Bytes::from(format!("Bad Gateway: {}", e)))) 231 .unwrap()); 232 } 233 }; 234 235 let status = resp.status().as_u16(); 236 237 let resp_headers: Vec<(String, String)> = resp 238 .headers() 239 .iter() 240 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string())) 241 .collect(); 242 243 let resp_bytes = match resp.bytes().await { 244 Ok(bytes) => bytes, 245 Err(e) => { 246 eprintln!("Error reading response body: {}", e); 247 return Ok(Response::builder() 248 .status(500) 249 .body(Full::new(Bytes::from("Internal Server Error"))) 250 .unwrap()); 251 } 252 }; 253 254 let resp_vec = resp_bytes.to_vec(); 255 let resp_str = String::from_utf8(resp_vec.clone()).ok(); 256 257 let log_entry = ProxyLog { 258 request: RequestLog { 259 timestamp, 260 method: method.to_string(), 261 path, 262 query_params: query, 263 headers, 264 body: body_str, 265 }, 266 response: ResponseLog { 267 status, 268 headers: resp_headers.clone(), 269 body: resp_str.clone(), 270 }, 271 }; 272 273 { 274 let mut logs_guard = logs.lock().await; 275 if !logs_guard.iter().any(|log| { 276 log.request.method == log_entry.request.method 277 && log.request.path == log_entry.request.path 278 && log.request.query_params == log_entry.request.query_params 279 }) { 280 logs_guard.push(log_entry.clone()); 281 } 282 } 283 284 println!( 285 "Saved request/response to log store {}", 286 PROXY_LOG_FILE.magenta() 287 ); 288 289 let mut builder = Response::builder().status(status); 290 291 for (name, value) in resp_headers { 292 if name.to_lowercase() != "connection" && name.to_lowercase() != "transfer-encoding" { 293 if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) { 294 if let Ok(header_value) = hyper::header::HeaderValue::from_str(&value) { 295 builder = builder.header(header_name, header_value); 296 } 297 } 298 } 299 } 300 301 builder = builder.header("content-length", resp_vec.len()); 302 builder = builder.header("connection", "close"); 303 304 Ok(builder 305 .body(Full::new(Bytes::from(resp_vec))) 306 .unwrap_or_else(|_| { 307 Response::builder() 308 .status(500) 309 .body(Full::new(Bytes::from("Internal Server Error"))) 310 .unwrap() 311 })) 312}