this repo has no description
1// Project Includes
2#include "client.hpp"
3#include "file_descriptor.hpp"
4#include "http.hpp"
5// Standard Library Includes
6#include <algorithm>
7#include <array>
8#include <bit>
9#include <cassert>
10#include <cctype>
11#include <filesystem>
12#include <fstream>
13#include <iostream>
14#include <print>
15#include <span>
16#include <string_view>
17#include <system_error>
18#include <vector>
19
/// Decodes %XX escapes in a URL path; every other character is copied through unchanged
/// (in particular '+' is NOT turned into a space).
///
/// Rejected inputs (returned as an error string):
///   - a '%' not followed by two more characters,
///   - a '%' followed by a non-hex digit,
///   - a backslash,
///   - a NUL byte, whether literal or produced by "%00".
[[nodiscard]] static std::expected<std::string, std::string> decode_url_path(std::string_view url_path)
{
    // Value of a single hex digit (case-insensitive), or -1 when `ch` is not one.
    const auto nibble = [](char ch) -> int {
        if ('0' <= ch && ch <= '9')
        {
            return ch - '0';
        }
        const char upper = static_cast<char>(std::toupper(static_cast<unsigned char>(ch)));
        return ('A' <= upper && upper <= 'F') ? upper - 'A' + 10 : -1;
    };

    std::string out;
    out.reserve(url_path.size());

    std::size_t pos = 0;
    while (pos < url_path.size())
    {
        const char ch = url_path[pos];

        if (ch == '\\')
        {
            // Windows-style separators are never valid in a URL path here.
            return std::unexpected("Backslash not allowed in URL path");
        }
        if (ch == '\0')
        {
            return std::unexpected("Embedded NUL in URL path");
        }
        if (ch != '%')
        {
            out.push_back(ch);
            ++pos;
            continue;
        }

        // An escape needs the '%' plus two hex digits.
        if (url_path.size() - pos < 3)
        {
            return std::unexpected("Invalid percent-encoding (truncated)");
        }
        const int high = nibble(url_path[pos + 1]);
        const int low = nibble(url_path[pos + 2]);
        if (high < 0 || low < 0)
        {
            return std::unexpected("Invalid percent-encoding (non-hex)");
        }

        const auto byte = static_cast<char>(high * 16 + low);
        if (byte == '\0')
        {
            return std::unexpected("Embedded NUL in URL path");
        }
        out.push_back(byte);
        pos += 3;
    }

    return out;
}
75
/// Returns true when `candidate` lies at or below `root`.
///
/// Compares whole path elements rather than raw characters. So a root of "/srv/www" does not
/// accept "/srv/www2/x", a root of "/" accepts every absolute path, and a trailing separator
/// on the root (an empty final element) adds no constraint.
/// Both paths are expected to already be canonical.
[[nodiscard]] static bool path_is_within(std::filesystem::path const &root, std::filesystem::path const &candidate)
{
    auto candidate_it = candidate.begin();
    for (auto const &root_part : root)
    {
        if (root_part.empty())
        {
            continue;
        }
        if (candidate_it == candidate.end() || *candidate_it != root_part)
        {
            return false;
        }
        ++candidate_it;
    }
    return true;
}

/// Maps a request URL path onto a file system path inside `root_directory`.
///
/// Steps:
///   1. Strips any query ("?...") or fragment ("#...").
///   2. Percent-decodes the remainder.
///   3. Rejects "..", rooted components and the characters refused by decode_url_path().
///   4. Appends "index.html" when the target is a directory.
///   5. Canonicalizes the result and verifies it still lies under the canonical root, which
///      also defends against symlinks that point outside the root.
///
/// @param url_path        Request target; must start with '/'.
/// @param root_directory  Document root.
/// @return The canonical path to serve, or a human-readable error description.
[[nodiscard]] std::expected<std::filesystem::path, std::string> sanitize_path(
    std::string_view url_path, std::filesystem::path const &root_directory)
{
    if (url_path.empty() || url_path.front() != '/')
    {
        return std::unexpected("URL path must start with '/'");
    }

    // Strip query and fragment: /foo/bar?x=1#frag -> /foo/bar
    if (const auto pos = url_path.find_first_of("?#"); pos != std::string_view::npos)
    {
        url_path = url_path.substr(0, pos);
    }

    // Drop the leading '/' so the remainder is relative to the document root.
    url_path.remove_prefix(1);

    auto decoded_result = decode_url_path(url_path);
    if (!decoded_result)
    {
        return std::unexpected(std::move(decoded_result.error()));
    }
    const std::string &decoded_path = decoded_result.value();

    std::filesystem::path requested_path = root_directory;
    for (const auto &part : std::filesystem::path(decoded_path))
    {
        // A rooted component (e.g. from "//etc/passwd" or "/%2Fetc") would make operator/= discard
        // root_directory entirely. Refuse it here instead of relying only on the containment check.
        if (part.has_root_path())
        {
            return std::unexpected("Absolute path component not allowed");
        }
        const auto &native = part.native();
        if (native.empty() || native == ".")
        {
            continue;
        }
        if (native == "..")
        {
            return std::unexpected("Path traversal detected");
        }
        requested_path /= part;
    }

    std::error_code ec;

    // Canonicalize root and requested paths to defend against symlink escapes.
    const auto canonical_root = std::filesystem::weakly_canonical(root_directory, ec);
    if (ec)
    {
        return std::unexpected("Failed to canonicalize root directory: " + ec.message());
    }

    // is_directory() reports false on error, so a failed status query simply means "not a directory".
    std::error_code dir_ec;
    if (std::filesystem::is_directory(requested_path, dir_ec))
    {
        requested_path /= "index.html";
    }

    auto canonical_requested = std::filesystem::weakly_canonical(requested_path, ec);
    if (ec)
    {
        return std::unexpected("Failed to canonicalize requested path: " + ec.message());
    }

    // Ensure the requested path remains inside the document root (component-wise, so root "/" works).
    if (!path_is_within(canonical_root, canonical_requested))
    {
        return std::unexpected("Requested path escapes document root");
    }

    return canonical_requested;
}
151
152namespace detail
153{
154// Centralized MIME table, constexpr and case-insensitive
155[[nodiscard]] inline std::string_view mime_type_for_extension(std::string_view ext)
156{
157 constexpr std::array kMimeTable{
158 std::pair{std::string_view{".html"}, "text/html"},
159 std::pair{std::string_view{".htm"}, "text/html"},
160 std::pair{std::string_view{".css"}, "text/css"},
161 std::pair{std::string_view{".js"}, "application/javascript"},
162 std::pair{std::string_view{".json"}, "application/json"},
163 std::pair{std::string_view{".png"}, "image/png"},
164 std::pair{std::string_view{".jpg"}, "image/jpeg"},
165 std::pair{std::string_view{".jpeg"}, "image/jpeg"},
166 std::pair{std::string_view{".gif"}, "image/gif"},
167 std::pair{std::string_view{".svg"}, "image/svg+xml"},
168 std::pair{std::string_view{".txt"}, "text/plain"},
169 };
170
171 // Normalize to lowercase to avoid case-sensitivity issues
172 std::string lower_ext(ext);
173 std::ranges::transform(lower_ext, lower_ext.begin(),
174 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
175
176 if (auto it = std::ranges::find_if(kMimeTable, [&lower_ext](auto const &p) { return p.first == lower_ext; });
177 it != kMimeTable.end())
178 {
179 return it->second;
180 }
181
182 return "application/octet-stream";
183}
184
185// Utility for generating error responses
186[[nodiscard]] inline HttpResponse make_error_response(uint16_t code, std::string_view reason,
187 std::string_view body_msg = {})
188{
189 HttpResponse r{};
190 r.http_major = 1;
191 r.http_minor = 1;
192 r.status_code = code;
193 r.reason_phrase = std::string(reason);
194
195 if (!body_msg.empty())
196 {
197 const auto *body_bytes = std::bit_cast<const std::byte *>(body_msg.data());
198 const auto body_size = body_msg.size();
199
200 r.body.assign(body_bytes, body_bytes + body_size);
201 r.headers.emplace_back("Content-Length", std::to_string(r.body.size()));
202 r.headers.emplace_back("Content-Type", "text/plain; charset=utf-8");
203 }
204 else
205 {
206 r.headers.emplace_back("Content-Length", "0");
207 }
208
209 // Basic security hardening header
210 r.headers.emplace_back("X-Content-Type-Options", "nosniff");
211
212 return r;
213}
214
215} // namespace detail
216
/// Serves a GET request by returning a file from below `root_directory`.
///
/// Status codes:
///   - 400: sanitize_path() rejects the URL.
///   - 404: the path is not an existing regular file.
///   - 413: the file size does not fit in std::size_t.
///   - 500: the size cannot be queried, or the file cannot be opened or fully read.
///
/// On success the entire file is read into the response body. Content-Type is chosen from
/// the file extension via mime_type_for_extension().
HttpResponse handle_get_request(HttpRequest const &request, std::filesystem::path const &root_directory)
{
    using namespace detail;

    // In production you may want to *respond* with 405 rather than assert.
    assert(request.method == "GET");

    auto sanitized_path_result = sanitize_path(request.url, root_directory);
    if (!sanitized_path_result.has_value())
    {
        std::println(std::cerr, "Error sanitizing path: {}", sanitized_path_result.error());
        return make_error_response(400, "Bad Request", "Invalid path.\n");
    }

    const std::filesystem::path &file_path = sanitized_path_result.value();

    std::error_code ec;
    // is_regular_file() is false both for missing paths and on error, so no separate exists() is needed.
    if (!std::filesystem::is_regular_file(file_path, ec))
    {
        return make_error_response(404, "Not Found", "File not found.\n");
    }

    const auto file_size_u = std::filesystem::file_size(file_path, ec);
    if (ec)
    {
        std::println(std::cerr, "file_size error: {}", ec.message());
        return make_error_response(500, "Internal Server Error");
    }

    // Guard against overflow when converting to size_t.
    if (file_size_u > static_cast<std::uintmax_t>(std::numeric_limits<std::size_t>::max()))
    {
        std::println(std::cerr, "File too large to serve: {}", file_path.string());
        return make_error_response(413, "Payload Too Large", "Requested file is too large.\n");
    }

    const auto file_size = static_cast<std::size_t>(file_size_u);

    // Open before allocating the body so a failed open does not cost a file-sized allocation.
    std::ifstream file(file_path, std::ios::binary);
    if (!file)
    {
        std::println(std::cerr, "Failed to open {}", file_path.string());
        return make_error_response(500, "Internal Server Error");
    }

    HttpResponse response{};
    // Responses are always HTTP/1.1, matching make_error_response(). Per RFC 9110 §6.2 a server
    // answers with the highest version it conforms to, rather than echoing the request version.
    // Echoing was wrong here: a valid minor version of 0 is not "unset", and echoing could
    // yield versions such as "HTTP/2.1".
    response.http_major = 1;
    response.http_minor = 1;
    response.status_code = 200;
    response.reason_phrase = "OK";

    response.body.resize(file_size);
    file.read(std::bit_cast<char *>(response.body.data()), static_cast<std::streamsize>(file_size));
    if (!file)
    {
        // Covers I/O errors and the file shrinking after file_size() was queried.
        std::println(std::cerr, "Error reading {}", file_path.string());
        return make_error_response(500, "Internal Server Error");
    }

    const auto mime = mime_type_for_extension(file_path.extension().string());
    response.headers.emplace_back("Content-Type", std::string(mime));
    response.headers.emplace_back("Content-Length", std::to_string(file_size));
    response.headers.emplace_back("X-Content-Type-Options", "nosniff");

    return response;
}
288
289[[nodiscard]] std::vector<HttpResponse> generate_responses(RequestList const &requests,
290 std::filesystem::path const &root_directory)
291{
292 using namespace detail;
293
294 std::vector<HttpResponse> responses;
295 responses.reserve(requests.size());
296
297 for (auto const &request : requests)
298 {
299 if (request.method == "GET")
300 {
301 responses.push_back(handle_get_request(request, root_directory));
302 }
303 else
304 {
305 auto response = make_error_response(405, "Method Not Allowed", "Only GET is supported.\n");
306 response.headers.emplace_back("Allow", "GET");
307 responses.push_back(std::move(response));
308 }
309 }
310
311 return responses;
312}
313
/// Reads from the client's socket until its parser holds at least one complete request,
/// or until the peer closes the connection (a zero-byte read, which sets `m_disconnect_requested`).
///
/// @return All requests the parser has completed. This is expected to be empty when the peer
///         disconnected before finishing a request.
/// @throws std::runtime_error if the parser rejects the received data; the parser is reset first.
kev::task<RequestList> read_requests(Client &client)
{
    using namespace stdexec;

    auto &read_buffer = client.m_buffer;
    while (client.m_parser.get_number_of_completed_requests() == 0 && !client.m_disconnect_requested)
    {
        size_t bytes_read = co_await client.m_uring_ctx.async_read(client.m_fd.get(),
                                                                   std::span(read_buffer.data(), read_buffer.size()));
        if (bytes_read == 0)
        {
            client.m_disconnect_requested = true;
            break;
        }

        auto const data = std::span(read_buffer.data(), bytes_read);
        auto result = client.m_parser.parse(data);
        if (!result.has_value())
        {
            // Both diagnostics describe the same failure; send both to stderr so they stay together.
            std::println(std::cerr, "Error parsing HTTP request. Message: {}", result.error());
            std::println(std::cerr, "Resetting parser state.");
            client.m_parser.reset();
            throw std::runtime_error("Error parsing HTTP request");
        }
    }

    RequestList requests;
    client.m_parser.get_completed_requests(requests);
    co_return requests;
}
343
344kev::task<void> handle_requests(Client const &client, ServerContext &server_context, RequestList requests)
345{
346 using namespace stdexec;
347
348 std::vector<HttpResponse> responses = generate_responses(requests, server_context.m_root_directory);
349
350 std::vector<std::byte> bytes_to_write;
351 bytes_to_write.reserve(4096);
352 for (auto const &response : responses)
353 {
354 response.serialize_into(bytes_to_write);
355 }
356
357 co_await client.m_uring_ctx.async_write_all(
358 client.m_fd.get(), std::span<const std::byte>(bytes_to_write.data(), bytes_to_write.size()));
359}
360
361kev::task<void> handle_connection_coroutine(FileDescriptor client_fd, ServerContext &ctx, UringContext &uring_ctx)
362{
363 // Client will get hoisted into the coroutine frame and it's lifetime is managed automatically
364 // Passing by reference to read_requests_coroutine and handle_requests_coroutine is safe because
365 // they are called from within this coroutine and thus cannot outlive it.
366 auto client = Client(std::move(client_fd), uring_ctx);
367 while (!client.m_disconnect_requested)
368 {
369 RequestList requests = co_await read_requests(client);
370 co_await handle_requests(client, ctx, std::move(requests));
371 }
372}