Serenity Operating System
1/*
2 * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
3 * Copyright (c) 2022, the SerenityOS developers.
4 *
5 * SPDX-License-Identifier: BSD-2-Clause
6 */
7
8#include <AK/DeprecatedString.h>
9#include <AK/String.h>
10#include <LibSQL/SQLClient.h>
11
12#if !defined(AK_OS_SERENITY)
13# include <LibCore/DeprecatedFile.h>
14# include <LibCore/Directory.h>
15# include <LibCore/SocketAddress.h>
16# include <LibCore/StandardPaths.h>
17# include <LibCore/System.h>
18#endif
19
20namespace SQL {
21
22#if !defined(AK_OS_SERENITY)
23
24// This is heavily based on how SystemServer's Service creates its socket.
25static ErrorOr<int> create_database_socket(DeprecatedString const& socket_path)
26{
27 if (Core::DeprecatedFile::exists(socket_path))
28 TRY(Core::System::unlink(socket_path));
29
30# ifdef SOCK_NONBLOCK
31 auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
32# else
33 auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM, 0));
34
35 int option = 1;
36 TRY(Core::System::ioctl(socket_fd, FIONBIO, &option));
37 TRY(Core::System::fcntl(socket_fd, F_SETFD, FD_CLOEXEC));
38# endif
39
40# if !defined(AK_OS_BSD_GENERIC)
41 TRY(Core::System::fchmod(socket_fd, 0600));
42# endif
43
44 auto socket_address = Core::SocketAddress::local(socket_path);
45 auto socket_address_un = socket_address.to_sockaddr_un().release_value();
46
47 TRY(Core::System::bind(socket_fd, reinterpret_cast<sockaddr*>(&socket_address_un), sizeof(socket_address_un)));
48 TRY(Core::System::listen(socket_fd, 16));
49
50 return socket_fd;
51}
52
53static ErrorOr<void> launch_server(DeprecatedString const& socket_path, DeprecatedString const& pid_path, Vector<String> candidate_server_paths)
54{
55 auto server_fd_or_error = create_database_socket(socket_path);
56 if (server_fd_or_error.is_error()) {
57 warnln("Failed to create a database socket at {}: {}", socket_path, server_fd_or_error.error());
58 return server_fd_or_error.release_error();
59 }
60 auto server_fd = server_fd_or_error.value();
61 auto server_pid = TRY(Core::System::fork());
62
63 if (server_pid == 0) {
64 TRY(Core::System::setsid());
65 TRY(Core::System::signal(SIGCHLD, SIG_IGN));
66 server_pid = TRY(Core::System::fork());
67
68 if (server_pid != 0) {
69 auto server_pid_file = TRY(Core::File::open(pid_path, Core::File::OpenMode::Write));
70 TRY(server_pid_file->write_until_depleted(DeprecatedString::number(server_pid).bytes()));
71
72 TRY(Core::System::kill(getpid(), SIGTERM));
73 }
74
75 server_fd = TRY(Core::System::dup(server_fd));
76
77 auto takeover_string = DeprecatedString::formatted("SQLServer:{}", server_fd);
78 TRY(Core::System::setenv("SOCKET_TAKEOVER"sv, takeover_string, true));
79
80 ErrorOr<void> result;
81 for (auto const& server_path : candidate_server_paths) {
82 auto arguments = Array {
83 server_path.bytes_as_string_view(),
84 "--pid-file"sv,
85 pid_path,
86 };
87 result = Core::System::exec(arguments[0], arguments, Core::System::SearchInPath::Yes);
88 if (!result.is_error())
89 break;
90 }
91 if (result.is_error()) {
92 warnln("Could not launch any of {}: {}", candidate_server_paths, result.error());
93 TRY(Core::System::unlink(pid_path));
94 }
95
96 VERIFY_NOT_REACHED();
97 }
98
99 TRY(Core::System::waitpid(server_pid));
100 return {};
101}
102
103static ErrorOr<bool> should_launch_server(DeprecatedString const& pid_path)
104{
105 if (!Core::DeprecatedFile::exists(pid_path))
106 return true;
107
108 Optional<pid_t> pid;
109 {
110 auto server_pid_file = Core::File::open(pid_path, Core::File::OpenMode::Read);
111 if (server_pid_file.is_error()) {
112 warnln("Could not open SQLServer PID file '{}': {}", pid_path, server_pid_file.error());
113 return server_pid_file.release_error();
114 }
115
116 auto contents = server_pid_file.value()->read_until_eof();
117 if (contents.is_error()) {
118 warnln("Could not read SQLServer PID file '{}': {}", pid_path, contents.error());
119 return contents.release_error();
120 }
121
122 pid = StringView { contents.value() }.to_int<pid_t>();
123 }
124
125 if (!pid.has_value()) {
126 warnln("SQLServer PID file '{}' exists, but with an invalid PID", pid_path);
127 TRY(Core::System::unlink(pid_path));
128 return true;
129 }
130 if (kill(*pid, 0) < 0) {
131 warnln("SQLServer PID file '{}' exists with PID {}, but process cannot be found", pid_path, *pid);
132 TRY(Core::System::unlink(pid_path));
133 return true;
134 }
135
136 return false;
137}
138
139ErrorOr<NonnullRefPtr<SQLClient>> SQLClient::launch_server_and_create_client(Vector<String> candidate_server_paths)
140{
141 auto runtime_directory = TRY(Core::StandardPaths::runtime_directory());
142 auto socket_path = DeprecatedString::formatted("{}/SQLServer.socket", runtime_directory);
143 auto pid_path = DeprecatedString::formatted("{}/SQLServer.pid", runtime_directory);
144
145 if (TRY(should_launch_server(pid_path)))
146 TRY(launch_server(socket_path, pid_path, move(candidate_server_paths)));
147
148 auto socket = TRY(Core::LocalSocket::connect(move(socket_path)));
149 TRY(socket->set_blocking(true));
150
151 return adopt_nonnull_ref_or_enomem(new (nothrow) SQLClient(move(socket)));
152}
153
154#endif
155
156void SQLClient::execution_success(u64 statement_id, u64 execution_id, Vector<DeprecatedString> const& column_names, bool has_results, size_t created, size_t updated, size_t deleted)
157{
158 if (!on_execution_success) {
159 outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted);
160 return;
161 }
162
163 ExecutionSuccess success {
164 .statement_id = statement_id,
165 .execution_id = execution_id,
166 .column_names = move(const_cast<Vector<DeprecatedString>&>(column_names)),
167 .has_results = has_results,
168 .rows_created = created,
169 .rows_updated = updated,
170 .rows_deleted = deleted,
171 };
172
173 on_execution_success(move(success));
174}
175
176void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message)
177{
178 if (!on_execution_error) {
179 warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code));
180 return;
181 }
182
183 ExecutionError error {
184 .statement_id = statement_id,
185 .execution_id = execution_id,
186 .error_code = code,
187 .error_message = move(const_cast<DeprecatedString&>(message)),
188 };
189
190 on_execution_error(move(error));
191}
192
193void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector<Value> const& row)
194{
195 if (!on_next_result) {
196 StringBuilder builder;
197 builder.join(", "sv, row, "\"{}\""sv);
198 outln("{}", builder.string_view());
199 return;
200 }
201
202 ExecutionResult result {
203 .statement_id = statement_id,
204 .execution_id = execution_id,
205 .values = move(const_cast<Vector<Value>&>(row)),
206 };
207
208 on_next_result(move(result));
209}
210
211void SQLClient::results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows)
212{
213 if (!on_results_exhausted) {
214 outln("{} total row(s)", total_rows);
215 return;
216 }
217
218 ExecutionComplete success {
219 .statement_id = statement_id,
220 .execution_id = execution_id,
221 .total_rows = total_rows,
222 };
223
224 on_results_exhausted(move(success));
225}
226
227}