//! Sniff and replay HTTP requests and responses — perfect for mocking APIs during testing.
1use std::{
2 process::exit,
3 sync::Arc,
4 thread,
5 time::{SystemTime, UNIX_EPOCH},
6};
7
8use http_body_util::{BodyExt, Full};
9use hyper::{
10 Request, Response,
11 body::{Buf, Bytes, Incoming},
12 server::conn::http1,
13};
14use hyper_util::rt::TokioIo;
15use owo_colors::OwoColorize;
16use serde::{Deserialize, Serialize};
17use tokio::{net::TcpListener, sync::Mutex};
18
19use crate::{
20 replay::start_replay_server,
21 store::{LogStore, save_logs_to_file},
22};
23
/// File name that `start_server` passes to `save_logs_to_file` when it periodically
/// persists the captured request/response pairs.
pub const PROXY_LOG_FILE: &str = "replay_mocks.json";
25
26#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
27pub struct RequestLog {
28 pub timestamp: u64,
29 pub method: String,
30 pub path: String,
31 pub query_params: Option<String>,
32 pub headers: Vec<(String, String)>,
33 pub body: Option<String>,
34}
35
36#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
37pub struct ResponseLog {
38 pub status: u16,
39 pub headers: Vec<(String, String)>,
40 pub body: Option<String>,
41}
42
43#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
44pub struct ProxyLog {
45 pub request: RequestLog,
46 pub response: ResponseLog,
47}
48
49pub async fn start_server(
50 target: &str,
51 listen: &str,
52) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
53 let target_uri = target.parse::<hyper::Uri>()?;
54 let target_authority = target_uri.authority().ok_or("Invalid target URL")?;
55 let target_scheme = target_uri.scheme_str().ok_or("http")?;
56 let target_host = target_authority.host();
57 let target_port = target_authority
58 .port_u16()
59 .unwrap_or(if target_scheme == "https" { 443 } else { 80 });
60
61 let logs = Arc::new(Mutex::new(Vec::<ProxyLog>::new()));
62 let logs_for_saving = logs.clone();
63 tokio::spawn(async move {
64 loop {
65 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
66 save_logs_to_file(&logs_for_saving, PROXY_LOG_FILE)
67 .await
68 .unwrap_or_else(|e| eprintln!("Error saving logs to file: {}", e));
69 }
70 });
71
72 let logs_for_replay = logs.clone();
73 thread::spawn(move || {
74 let rt = tokio::runtime::Builder::new_multi_thread()
75 .enable_all()
76 .build()
77 .unwrap();
78 rt.block_on(async {
79 match start_replay_server(logs_for_replay, "127.0.0.1:6688").await {
80 Ok(_) => {
81 println!("Replay server stopped");
82 exit(0);
83 }
84 Err(e) => eprintln!("Replay server error: {}", e),
85 }
86 });
87 });
88
89 let listener = TcpListener::bind(listen).await?;
90 println!("Target URL: {}", target.magenta());
91 println!("Proxy server is listening on {}", listen.green());
92 println!("Replay server is running on {}", "127.0.0.1:6688".green());
93
94 loop {
95 let (stream, _) = listener.accept().await?;
96 let io = TokioIo::new(stream);
97
98 let target_host_str = target_host.to_string();
99 let target_scheme = target_scheme.to_string();
100 let logs_clone = logs.clone();
101
102 tokio::task::spawn(async move {
103 let service = hyper::service::service_fn(move |req: Request<Incoming>| {
104 let target_host = target_host_str.clone();
105 let scheme = target_scheme.clone();
106 let logs = logs_clone.clone();
107
108 async move { proxy_handler(req, &target_host, target_port, &scheme, logs).await }
109 });
110
111 if let Err(err) = http1::Builder::new()
112 .keep_alive(false)
113 .max_buf_size(30 * 1024 * 1024)
114 .serve_connection(io, service)
115 .await
116 {
117 eprintln!("> Connection error: {}", err);
118 }
119 });
120 }
121}
122
123pub async fn proxy_handler(
124 req: Request<Incoming>,
125 target_host: &str,
126 target_port: u16,
127 scheme: &str,
128 logs: LogStore,
129) -> Result<Response<Full<Bytes>>, hyper::Error> {
130 let timestamp = SystemTime::now()
131 .duration_since(UNIX_EPOCH)
132 .unwrap()
133 .as_secs();
134
135 let method = req.method().clone();
136 let path = req.uri().path().to_string();
137 let query = req.uri().query().map(|q| q.to_string());
138
139 let headers: Vec<(String, String)> = req
140 .headers()
141 .iter()
142 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
143 .collect();
144
145 let (parts, body) = req.into_parts();
146 let body_bytes = match body.collect().await {
147 Ok(collected) => collected.aggregate(),
148 Err(e) => {
149 eprintln!("Error collecting request body: {}", e);
150 return Ok(Response::builder()
151 .status(500)
152 .body(Full::new(Bytes::from("Internal Server Error")))
153 .unwrap());
154 }
155 };
156
157 let body_vec = body_bytes.chunk().to_vec();
158 let body_str = String::from_utf8(body_vec.clone()).ok();
159
160 let forward_uri = if target_port != 443 && target_port != 80 {
161 format!(
162 "{}://{}:{}{}{}",
163 scheme,
164 target_host,
165 target_port,
166 parts.uri.path(),
167 parts
168 .uri
169 .query()
170 .map_or(String::new(), |q| format!("?{}", q))
171 )
172 } else {
173 format!(
174 "{}://{}{}{}",
175 scheme,
176 target_host,
177 parts.uri.path(),
178 parts
179 .uri
180 .query()
181 .map_or(String::new(), |q| format!("?{}", q))
182 )
183 };
184
185 println!("{} {} {}", method.yellow(), path, forward_uri.magenta());
186
187 let client = reqwest::Client::builder()
188 .timeout(std::time::Duration::from_secs(30))
189 .danger_accept_invalid_certs(true)
190 .build()
191 .unwrap_or_else(|_| reqwest::Client::new());
192
193 let mut req_builder = match method.as_str() {
194 "GET" => client.get(&forward_uri),
195 "POST" => client.post(&forward_uri),
196 "PUT" => client.put(&forward_uri),
197 "DELETE" => client.delete(&forward_uri),
198 "HEAD" => client.head(&forward_uri),
199 "OPTIONS" => client.request(reqwest::Method::OPTIONS, &forward_uri),
200 "PATCH" => client.patch(&forward_uri),
201 _ => {
202 eprintln!("Unsupported method: {}", method);
203 return Ok(Response::builder()
204 .status(400)
205 .body(Full::new(Bytes::from("Bad Request: Unsupported Method")))
206 .unwrap());
207 }
208 };
209
210 for (name, value) in &headers {
211 if name.to_lowercase() != "host" && name.to_lowercase() != "connection" {
212 if let Ok(header_name) = reqwest::header::HeaderName::from_bytes(name.as_bytes()) {
213 if let Ok(header_value) = reqwest::header::HeaderValue::from_str(value) {
214 req_builder = req_builder.header(header_name, header_value);
215 }
216 }
217 }
218 }
219
220 if !body_vec.is_empty() {
221 req_builder = req_builder.body(body_vec.clone());
222 }
223
224 let resp = match req_builder.send().await {
225 Ok(resp) => resp,
226 Err(e) => {
227 eprintln!("Error sending request: {}", e);
228 return Ok(Response::builder()
229 .status(502)
230 .body(Full::new(Bytes::from(format!("Bad Gateway: {}", e))))
231 .unwrap());
232 }
233 };
234
235 let status = resp.status().as_u16();
236
237 let resp_headers: Vec<(String, String)> = resp
238 .headers()
239 .iter()
240 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
241 .collect();
242
243 let resp_bytes = match resp.bytes().await {
244 Ok(bytes) => bytes,
245 Err(e) => {
246 eprintln!("Error reading response body: {}", e);
247 return Ok(Response::builder()
248 .status(500)
249 .body(Full::new(Bytes::from("Internal Server Error")))
250 .unwrap());
251 }
252 };
253
254 let resp_vec = resp_bytes.to_vec();
255 let resp_str = String::from_utf8(resp_vec.clone()).ok();
256
257 let log_entry = ProxyLog {
258 request: RequestLog {
259 timestamp,
260 method: method.to_string(),
261 path,
262 query_params: query,
263 headers,
264 body: body_str,
265 },
266 response: ResponseLog {
267 status,
268 headers: resp_headers.clone(),
269 body: resp_str.clone(),
270 },
271 };
272
273 {
274 let mut logs_guard = logs.lock().await;
275 if !logs_guard.iter().any(|log| {
276 log.request.method == log_entry.request.method
277 && log.request.path == log_entry.request.path
278 && log.request.query_params == log_entry.request.query_params
279 }) {
280 logs_guard.push(log_entry.clone());
281 }
282 }
283
284 println!(
285 "Saved request/response to log store {}",
286 PROXY_LOG_FILE.magenta()
287 );
288
289 let mut builder = Response::builder().status(status);
290
291 for (name, value) in resp_headers {
292 if name.to_lowercase() != "connection" && name.to_lowercase() != "transfer-encoding" {
293 if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) {
294 if let Ok(header_value) = hyper::header::HeaderValue::from_str(&value) {
295 builder = builder.header(header_name, header_value);
296 }
297 }
298 }
299 }
300
301 builder = builder.header("content-length", resp_vec.len());
302 builder = builder.header("connection", "close");
303
304 Ok(builder
305 .body(Full::new(Bytes::from(resp_vec)))
306 .unwrap_or_else(|_| {
307 Response::builder()
308 .status(500)
309 .body(Full::new(Bytes::from("Internal Server Error")))
310 .unwrap()
311 }))
312}