this repo has no description
at main 372 lines 12 kB view raw
// Project Includes
#include "client.hpp"
#include "file_descriptor.hpp"
#include "http.hpp"
// Standard Library Includes
#include <algorithm>
#include <array>
#include <bit>
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <expected>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <limits>
#include <print>
#include <span>
#include <stdexcept>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>
#include <vector>

/// Percent-decodes the path component of a request target.
///
/// Rejects truncated or non-hex escapes, embedded NUL bytes and backslashes. The NUL and
/// backslash checks are applied to the *decoded* character, so `%00` and `%5C` are rejected
/// exactly like their literal forms.
///
/// @param url_path Path without the leading '/' and without query/fragment.
/// @return The decoded path, or a human-readable error message.
[[nodiscard]] static std::expected<std::string, std::string> decode_url_path(std::string_view url_path)
{
    std::string decoded;
    decoded.reserve(url_path.size());

    // Value of one hex digit, or -1 if `c` is not a hex digit (ASCII only, no locale involvement).
    auto hex_value = [](char c) -> int {
        if (c >= '0' && c <= '9')
            return c - '0';
        if (c >= 'A' && c <= 'F')
            return c - 'A' + 10;
        if (c >= 'a' && c <= 'f')
            return c - 'a' + 10;
        return -1;
    };

    for (std::size_t i = 0; i < url_path.size(); ++i)
    {
        char c = url_path[i];

        if (c == '%')
        {
            // A percent escape needs two more characters after the '%'.
            if (i + 2 >= url_path.size())
            {
                return std::unexpected("Invalid percent-encoding (truncated)");
            }
            const int hi = hex_value(url_path[i + 1]);
            const int lo = hex_value(url_path[i + 2]);
            if (hi < 0 || lo < 0)
            {
                return std::unexpected("Invalid percent-encoding (non-hex)");
            }
            c = static_cast<char>((hi << 4) | lo);
            i += 2;
        }

        // The checks below apply to literal and percent-decoded characters alike.
        if (c == '\\')
        {
            // Disallow Windows-style separators in URL paths
            return std::unexpected("Backslash not allowed in URL path");
        }
        if (c == '\0')
        {
            return std::unexpected("Embedded NUL in URL path");
        }
        decoded.push_back(c);
    }

    return decoded;
}

/// Maps a request target onto a filesystem path inside `root_directory`.
///
/// Steps: strip query/fragment, percent-decode, reject `..` components, append `index.html`
/// when the target is a directory, then canonicalize (resolving symlinks) and verify the result
/// is still inside the canonicalized root.
///
/// @param url_path       Request target; must start with '/'.
/// @param root_directory Document root.
/// @return The canonical path to serve, or an error message describing why it was rejected.
[[nodiscard]] std::expected<std::filesystem::path, std::string> sanitize_path(
    std::string_view url_path, std::filesystem::path const &root_directory)
{
    // Basic validation
    if (url_path.empty() || url_path.front() != '/')
    {
        return std::unexpected("URL path must start with '/'");
    }

    // Strip query and fragment: /foo/bar?x=1#frag → /foo/bar
    if (const auto pos = url_path.find_first_of("?#"); pos != std::string_view::npos)
    {
        url_path = url_path.substr(0, pos);
    }

    // Drop leading '/' (still present: `pos` above can never be 0 because front() == '/')
    url_path.remove_prefix(1);

    // Decode percent-encodings
    auto decoded_result = decode_url_path(url_path);
    if (!decoded_result)
    {
        return std::unexpected(decoded_result.error());
    }
    const std::string &decoded_path = decoded_result.value();

    // Build requested path under the root, one component at a time
    std::filesystem::path requested_path = root_directory;

    for (const auto &part : std::filesystem::path(decoded_path))
    {
        const auto &native = part.native();
        // Skip empty and "." components, and any root component produced by a doubled or encoded
        // leading slash ("//etc", "/%2Fetc"): appending an absolute path with /= would *replace*
        // requested_path instead of extending it.
        if (native.empty() || native == "." || part.has_root_path())
        {
            continue;
        }
        if (native == "..")
        {
            return std::unexpected("Path traversal detected");
        }
        requested_path /= part;
    }

    std::error_code ec;

    // Canonicalize root and requested paths to defend against symlink escapes
    const auto canonical_root = std::filesystem::weakly_canonical(root_directory, ec);
    if (ec)
    {
        return std::unexpected("Failed to canonicalize root directory: " + ec.message());
    }

    if (std::filesystem::is_directory(requested_path, ec) && !ec)
    {
        requested_path /= "index.html";
    }

    const auto canonical_requested = std::filesystem::weakly_canonical(requested_path, ec);
    if (ec)
    {
        return std::unexpected("Failed to canonicalize requested path: " + ec.message());
    }

    // Ensure the requested path remains inside the document root. Comparing path components
    // (instead of a raw string prefix plus a separator check) also works when the root is "/".
    // An empty result means the paths are not comparable; a leading ".." means "outside".
    const auto relative = canonical_requested.lexically_relative(canonical_root);
    if (relative.empty() || relative.begin()->native() == "..")
    {
        return std::unexpected("Requested path escapes document root");
    }

    return canonical_requested;
}

namespace detail
{
/// Returns the MIME type for a file extension (including the leading '.'), compared
/// case-insensitively. Unknown extensions map to "application/octet-stream".
[[nodiscard]] inline std::string_view mime_type_for_extension(std::string_view ext)
{
    constexpr std::array kMimeTable{
        std::pair{std::string_view{".html"}, "text/html"},
        std::pair{std::string_view{".htm"}, "text/html"},
        std::pair{std::string_view{".css"}, "text/css"},
        std::pair{std::string_view{".js"}, "application/javascript"},
        std::pair{std::string_view{".json"}, "application/json"},
        std::pair{std::string_view{".png"}, "image/png"},
        std::pair{std::string_view{".jpg"}, "image/jpeg"},
        std::pair{std::string_view{".jpeg"}, "image/jpeg"},
        std::pair{std::string_view{".gif"}, "image/gif"},
        std::pair{std::string_view{".svg"}, "image/svg+xml"},
        std::pair{std::string_view{".txt"}, "text/plain"},
    };

    // Normalize to lowercase to avoid case-sensitivity issues
    std::string lower_ext(ext);
    std::ranges::transform(lower_ext, lower_ext.begin(),
                           [](unsigned char c) { return static_cast<char>(std::tolower(c)); });

    if (auto it = std::ranges::find_if(kMimeTable, [&lower_ext](auto const &p) { return p.first == lower_ext; });
        it != kMimeTable.end())
    {
        return it->second;
    }

    return "application/octet-stream";
}

/// Builds an HTTP/1.1 error response with an optional text/plain body.
///
/// @param code     Status code.
/// @param reason   Reason phrase.
/// @param body_msg Optional body; when empty, the response carries `Content-Length: 0`.
[[nodiscard]] inline HttpResponse make_error_response(std::uint16_t code, std::string_view reason,
                                                      std::string_view body_msg = {})
{
    HttpResponse r{};
    r.http_major = 1;
    r.http_minor = 1;
    r.status_code = code;
    r.reason_phrase = std::string(reason);

    if (!body_msg.empty())
    {
        const auto *body_bytes = std::bit_cast<const std::byte *>(body_msg.data());
        const auto body_size = body_msg.size();

        r.body.assign(body_bytes, body_bytes + body_size);
        r.headers.emplace_back("Content-Length", std::to_string(r.body.size()));
        r.headers.emplace_back("Content-Type", "text/plain; charset=utf-8");
    }
    else
    {
        r.headers.emplace_back("Content-Length", "0");
    }

    // Basic security hardening header
    r.headers.emplace_back("X-Content-Type-Options", "nosniff");

    return r;
}

} // namespace detail

/// Serves a GET request by reading the sanitized target file from `root_directory` into the
/// response body.
///
/// Returns 400 for rejected paths, 404 when the target is not a regular file, 413 when the file
/// size does not fit in `size_t`, and 500 on filesystem or read errors.
HttpResponse handle_get_request(HttpRequest const &request, std::filesystem::path const &root_directory)
{
    using namespace detail;

    // Callers must route only GET here; generate_responses answers other methods with 405.
    assert(request.method == "GET");

    auto sanitized_path_result = sanitize_path(request.url, root_directory);
    if (!sanitized_path_result.has_value())
    {
        std::println(std::cerr, "Error sanitizing path: {}", sanitized_path_result.error());
        return make_error_response(400, "Bad Request", "Invalid path.\n");
    }

    const std::filesystem::path &file_path = sanitized_path_result.value();

    // is_regular_file() is false for missing paths and on error, so it subsumes an exists() check.
    std::error_code ec;
    if (!std::filesystem::is_regular_file(file_path, ec))
    {
        return make_error_response(404, "Not Found", "File not found.\n");
    }

    const auto file_size_u = std::filesystem::file_size(file_path, ec);
    if (ec)
    {
        std::println(std::cerr, "file_size error: {}", ec.message());
        return make_error_response(500, "Internal Server Error");
    }

    // Guard against overflow when converting to size_t
    if (file_size_u > static_cast<std::uintmax_t>(std::numeric_limits<std::size_t>::max()))
    {
        std::println(std::cerr, "File too large to serve: {}", file_path.string());
        return make_error_response(413, "Payload Too Large", "Requested file is too large.\n");
    }

    const auto file_size = static_cast<std::size_t>(file_size_u);

    // Open before allocating the body so a failed open does not pay for a large buffer.
    std::ifstream file(file_path, std::ios::binary);
    if (!file)
    {
        std::println(std::cerr, "Failed to open {}", file_path.string());
        return make_error_response(500, "Internal Server Error");
    }

    HttpResponse response{};
    // Always answer as HTTP/1.1, like make_error_response. RFC 9110 §2.5: a server sends the
    // highest version it conforms to within the request's major version. Echoing the request's
    // major version could yield nonsense such as "HTTP/2.1".
    response.http_major = 1;
    response.http_minor = 1;
    response.status_code = 200;
    response.reason_phrase = "OK";

    const auto mime = mime_type_for_extension(file_path.extension().string());
    response.headers.emplace_back("Content-Type", std::string(mime));
    response.headers.emplace_back("Content-Length", std::to_string(file_size));
    response.headers.emplace_back("X-Content-Type-Options", "nosniff");

    response.body.resize(file_size);
    if (file_size > 0)
    {
        file.read(std::bit_cast<char *>(response.body.data()), static_cast<std::streamsize>(file_size));

        // A short read (e.g. the file shrank after file_size()) sets failbit.
        if (!file)
        {
            std::println(std::cerr, "Error reading {}", file_path.string());
            return make_error_response(500, "Internal Server Error");
        }
    }

    return response;
}

/// Produces one response per request, in order: GET is served from `root_directory`,
/// anything else receives 405 with an `Allow: GET` header.
[[nodiscard]] std::vector<HttpResponse> generate_responses(RequestList const &requests,
                                                           std::filesystem::path const &root_directory)
{
    using namespace detail;

    std::vector<HttpResponse> responses;
    responses.reserve(requests.size());

    for (auto const &request : requests)
    {
        if (request.method == "GET")
        {
            responses.push_back(handle_get_request(request, root_directory));
        }
        else
        {
            auto response = make_error_response(405, "Method Not Allowed", "Only GET is supported.\n");
            response.headers.emplace_back("Allow", "GET");
            responses.push_back(std::move(response));
        }
    }

    return responses;
}

/// Reads from the client until the parser reports at least one completed request or the peer
/// closes the connection (a 0-byte read sets `m_disconnect_requested`), then returns all
/// completed requests. The result may be empty after a disconnect.
///
/// @throws std::runtime_error if the parser rejects the received bytes.
kev::task<RequestList> read_requests(Client &client)
{
    using namespace stdexec;

    auto &read_buffer = client.m_buffer;
    while (client.m_parser.get_number_of_completed_requests() == 0 && !client.m_disconnect_requested)
    {
        std::size_t bytes_read = co_await client.m_uring_ctx.async_read(
            client.m_fd.get(), std::span(read_buffer.data(), read_buffer.size()));
        if (bytes_read == 0)
        {
            client.m_disconnect_requested = true;
            break;
        }

        auto const data = std::span(read_buffer.data(), bytes_read);
        auto result = client.m_parser.parse(data);
        if (!result.has_value())
        {
            // All diagnostics go to stderr, consistent with the rest of this file.
            std::println(std::cerr, "Error parsing HTTP request. Message: {}", result.error());
            std::println(std::cerr, "Resetting parser state.");
            client.m_parser.reset();
            throw std::runtime_error("Error parsing HTTP request");
        }
    }
    RequestList requests;
    client.m_parser.get_completed_requests(requests);
    co_return requests;
}

/// Generates responses for `requests`, serializes them back-to-back into one buffer, and
/// writes the whole buffer to the client.
kev::task<void> handle_requests(Client const &client, ServerContext &server_context, RequestList requests)
{
    using namespace stdexec;

    std::vector<HttpResponse> responses = generate_responses(requests, server_context.m_root_directory);

    // Reserve roughly enough for every body plus a per-response header allowance, so that large
    // file bodies are not copied through repeated vector growth while serializing.
    constexpr std::size_t kHeaderAllowance = 512;
    std::size_t estimated_size = 0;
    for (auto const &response : responses)
    {
        estimated_size += response.body.size() + kHeaderAllowance;
    }

    std::vector<std::byte> bytes_to_write;
    bytes_to_write.reserve(std::max<std::size_t>(estimated_size, 4096));
    for (auto const &response : responses)
    {
        response.serialize_into(bytes_to_write);
    }

    co_await client.m_uring_ctx.async_write_all(
        client.m_fd.get(), std::span<const std::byte>(bytes_to_write.data(), bytes_to_write.size()));
}

/// Drives one client connection: repeatedly reads a batch of requests and answers it until the
/// peer disconnects.
kev::task<void> handle_connection_coroutine(FileDescriptor client_fd, ServerContext &ctx, UringContext &uring_ctx)
{
    // `client` lives in this coroutine's frame, so its lifetime spans the whole connection.
    // Passing it by reference to read_requests and handle_requests is safe because each of those
    // tasks is co_awaited here and therefore completes before this frame is destroyed.
    auto client = Client(std::move(client_fd), uring_ctx);
    while (!client.m_disconnect_requested)
    {
        RequestList requests = co_await read_requests(client);
        if (requests.size() == 0)
        {
            // Peer closed before a complete request arrived: nothing to answer, so skip the write.
            continue;
        }
        co_await handle_requests(client, ctx, std::move(requests));
    }
}