Serenity Operating System
1/*
2 * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 */
6
7#include <LibCore/Object.h>
8#include <LibSQL/AST/Parser.h>
9#include <SQLServer/ConnectionFromClient.h>
10#include <SQLServer/DatabaseConnection.h>
11#include <SQLServer/SQLStatement.h>
12
13namespace SQLServer {
14
15static HashMap<SQL::StatementID, NonnullRefPtr<SQLStatement>> s_statements;
16static SQL::StatementID s_next_statement_id = 0;
17
18RefPtr<SQLStatement> SQLStatement::statement_for(SQL::StatementID statement_id)
19{
20 if (s_statements.contains(statement_id))
21 return *s_statements.get(statement_id).value();
22 dbgln_if(SQLSERVER_DEBUG, "Invalid statement_id {}", statement_id);
23 return nullptr;
24}
25
26SQL::ResultOr<NonnullRefPtr<SQLStatement>> SQLStatement::create(DatabaseConnection& connection, StringView sql)
27{
28 auto parser = SQL::AST::Parser(SQL::AST::Lexer(sql));
29 auto statement = parser.next_statement();
30
31 if (parser.has_errors())
32 return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::SyntaxError, parser.errors()[0].to_deprecated_string() };
33
34 return TRY(adopt_nonnull_ref_or_enomem(new (nothrow) SQLStatement(connection, move(statement))));
35}
36
37SQLStatement::SQLStatement(DatabaseConnection& connection, NonnullRefPtr<SQL::AST::Statement> statement)
38 : Core::Object(&connection)
39 , m_statement_id(s_next_statement_id++)
40 , m_statement(move(statement))
41{
42 dbgln_if(SQLSERVER_DEBUG, "SQLStatement({})", connection.connection_id());
43 s_statements.set(m_statement_id, *this);
44}
45
46void SQLStatement::report_error(SQL::Result result, SQL::ExecutionID execution_id)
47{
48 dbgln_if(SQLSERVER_DEBUG, "SQLStatement::report_error(statement_id {}, error {}", statement_id(), result.error_string());
49
50 auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
51
52 s_statements.remove(statement_id());
53 remove_from_parent();
54
55 if (client_connection)
56 client_connection->async_execution_error(statement_id(), execution_id, result.error(), result.error_string());
57 else
58 warnln("Cannot return execution error. Client disconnected");
59}
60
61Optional<SQL::ExecutionID> SQLStatement::execute(Vector<SQL::Value> placeholder_values)
62{
63 dbgln_if(SQLSERVER_DEBUG, "SQLStatement::execute(statement_id {}", statement_id());
64
65 auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
66 if (!client_connection) {
67 warnln("Cannot yield next result. Client disconnected");
68 return {};
69 }
70
71 auto execution_id = m_next_execution_id++;
72 m_ongoing_executions.set(execution_id);
73
74 deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] {
75 auto execution_result = m_statement->execute(connection()->database(), placeholder_values);
76 m_ongoing_executions.remove(execution_id);
77
78 if (execution_result.is_error()) {
79 report_error(execution_result.release_error(), execution_id);
80 return;
81 }
82
83 auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
84 if (!client_connection) {
85 warnln("Cannot return statement execution results. Client disconnected");
86 return;
87 }
88
89 auto result = execution_result.release_value();
90
91 if (should_send_result_rows(result)) {
92 client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), true, 0, 0, 0);
93
94 auto result_size = result.size();
95 next(execution_id, move(result), result_size);
96 } else {
97 if (result.command() == SQL::SQLCommand::Insert)
98 client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, result.size(), 0, 0);
99 else if (result.command() == SQL::SQLCommand::Update)
100 client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, result.size(), 0);
101 else if (result.command() == SQL::SQLCommand::Delete)
102 client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, result.size());
103 else
104 client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, 0);
105 }
106 });
107
108 return execution_id;
109}
110
111bool SQLStatement::should_send_result_rows(SQL::ResultSet const& result) const
112{
113 if (result.is_empty())
114 return false;
115
116 switch (result.command()) {
117 case SQL::SQLCommand::Describe:
118 case SQL::SQLCommand::Select:
119 return true;
120 default:
121 return false;
122 }
123}
124
125void SQLStatement::next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size)
126{
127 auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id());
128 if (!client_connection) {
129 warnln("Cannot yield next result. Client disconnected");
130 return;
131 }
132
133 if (!result.is_empty()) {
134 auto result_row = result.take_first();
135 client_connection->async_next_result(statement_id(), execution_id, result_row.row.take_data());
136
137 deferred_invoke([this, execution_id, result = move(result), result_size]() mutable {
138 next(execution_id, move(result), result_size);
139 });
140 } else {
141 client_connection->async_results_exhausted(statement_id(), execution_id, result_size);
142 }
143}
144
145}