// Serenity Operating System — SQLServer/SQLStatement.cpp
// (master branch; original listing: 145 lines, 5.5 kB)
1/* 2 * Copyright (c) 2021, Jan de Visser <jan@de-visser.net> 3 * 4 * SPDX-License-Identifier: BSD-2-Clause 5 */ 6 7#include <LibCore/Object.h> 8#include <LibSQL/AST/Parser.h> 9#include <SQLServer/ConnectionFromClient.h> 10#include <SQLServer/DatabaseConnection.h> 11#include <SQLServer/SQLStatement.h> 12 13namespace SQLServer { 14 15static HashMap<SQL::StatementID, NonnullRefPtr<SQLStatement>> s_statements; 16static SQL::StatementID s_next_statement_id = 0; 17 18RefPtr<SQLStatement> SQLStatement::statement_for(SQL::StatementID statement_id) 19{ 20 if (s_statements.contains(statement_id)) 21 return *s_statements.get(statement_id).value(); 22 dbgln_if(SQLSERVER_DEBUG, "Invalid statement_id {}", statement_id); 23 return nullptr; 24} 25 26SQL::ResultOr<NonnullRefPtr<SQLStatement>> SQLStatement::create(DatabaseConnection& connection, StringView sql) 27{ 28 auto parser = SQL::AST::Parser(SQL::AST::Lexer(sql)); 29 auto statement = parser.next_statement(); 30 31 if (parser.has_errors()) 32 return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::SyntaxError, parser.errors()[0].to_deprecated_string() }; 33 34 return TRY(adopt_nonnull_ref_or_enomem(new (nothrow) SQLStatement(connection, move(statement)))); 35} 36 37SQLStatement::SQLStatement(DatabaseConnection& connection, NonnullRefPtr<SQL::AST::Statement> statement) 38 : Core::Object(&connection) 39 , m_statement_id(s_next_statement_id++) 40 , m_statement(move(statement)) 41{ 42 dbgln_if(SQLSERVER_DEBUG, "SQLStatement({})", connection.connection_id()); 43 s_statements.set(m_statement_id, *this); 44} 45 46void SQLStatement::report_error(SQL::Result result, SQL::ExecutionID execution_id) 47{ 48 dbgln_if(SQLSERVER_DEBUG, "SQLStatement::report_error(statement_id {}, error {}", statement_id(), result.error_string()); 49 50 auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); 51 52 s_statements.remove(statement_id()); 53 remove_from_parent(); 54 55 if (client_connection) 56 
client_connection->async_execution_error(statement_id(), execution_id, result.error(), result.error_string()); 57 else 58 warnln("Cannot return execution error. Client disconnected"); 59} 60 61Optional<SQL::ExecutionID> SQLStatement::execute(Vector<SQL::Value> placeholder_values) 62{ 63 dbgln_if(SQLSERVER_DEBUG, "SQLStatement::execute(statement_id {}", statement_id()); 64 65 auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); 66 if (!client_connection) { 67 warnln("Cannot yield next result. Client disconnected"); 68 return {}; 69 } 70 71 auto execution_id = m_next_execution_id++; 72 m_ongoing_executions.set(execution_id); 73 74 deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] { 75 auto execution_result = m_statement->execute(connection()->database(), placeholder_values); 76 m_ongoing_executions.remove(execution_id); 77 78 if (execution_result.is_error()) { 79 report_error(execution_result.release_error(), execution_id); 80 return; 81 } 82 83 auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); 84 if (!client_connection) { 85 warnln("Cannot return statement execution results. 
Client disconnected"); 86 return; 87 } 88 89 auto result = execution_result.release_value(); 90 91 if (should_send_result_rows(result)) { 92 client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), true, 0, 0, 0); 93 94 auto result_size = result.size(); 95 next(execution_id, move(result), result_size); 96 } else { 97 if (result.command() == SQL::SQLCommand::Insert) 98 client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, result.size(), 0, 0); 99 else if (result.command() == SQL::SQLCommand::Update) 100 client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, result.size(), 0); 101 else if (result.command() == SQL::SQLCommand::Delete) 102 client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, result.size()); 103 else 104 client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, 0); 105 } 106 }); 107 108 return execution_id; 109} 110 111bool SQLStatement::should_send_result_rows(SQL::ResultSet const& result) const 112{ 113 if (result.is_empty()) 114 return false; 115 116 switch (result.command()) { 117 case SQL::SQLCommand::Describe: 118 case SQL::SQLCommand::Select: 119 return true; 120 default: 121 return false; 122 } 123} 124 125void SQLStatement::next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size) 126{ 127 auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); 128 if (!client_connection) { 129 warnln("Cannot yield next result. 
Client disconnected"); 130 return; 131 } 132 133 if (!result.is_empty()) { 134 auto result_row = result.take_first(); 135 client_connection->async_next_result(statement_id(), execution_id, result_row.row.take_data()); 136 137 deferred_invoke([this, execution_id, result = move(result), result_size]() mutable { 138 next(execution_id, move(result), result_size); 139 }); 140 } else { 141 client_connection->async_results_exhausted(statement_id(), execution_id, result_size); 142 } 143} 144 145}