Serenity Operating System
1/*
2 * Copyright (c) 2021, Tim Flynn <trflynn89@serenityos.org>
3 * Copyright (c) 2022, Alex Major
4 *
5 * SPDX-License-Identifier: BSD-2-Clause
6 */
7
8#include <AK/DeprecatedString.h>
9#include <AK/Format.h>
10#include <AK/String.h>
11#include <AK/StringBuilder.h>
12#include <LibCore/ArgsParser.h>
13#include <LibCore/DeprecatedFile.h>
14#include <LibCore/StandardPaths.h>
15#include <LibLine/Editor.h>
16#include <LibMain/Main.h>
17#include <LibSQL/AST/Lexer.h>
18#include <LibSQL/AST/Token.h>
19#include <LibSQL/SQLClient.h>
20#include <unistd.h>
21
22class SQLRepl {
23public:
24 explicit SQLRepl(Core::EventLoop& loop, DeprecatedString const& database_name, NonnullRefPtr<SQL::SQLClient> sql_client)
25 : m_sql_client(move(sql_client))
26 , m_loop(loop)
27 {
28 m_editor = Line::Editor::construct();
29 m_editor->load_history(m_history_path);
30
31 m_editor->on_display_refresh = [this](Line::Editor& editor) {
32 editor.strip_styles();
33
34 int open_indents = m_repl_line_level;
35
36 auto line = editor.line();
37 SQL::AST::Lexer lexer(line);
38
39 bool indenters_starting_line = true;
40 for (SQL::AST::Token token = lexer.next(); token.type() != SQL::AST::TokenType::Eof; token = lexer.next()) {
41 auto start = token.start_position().column - 1;
42 auto end = token.end_position().column - 1;
43
44 if (indenters_starting_line) {
45 if (token.type() != SQL::AST::TokenType::ParenClose)
46 indenters_starting_line = false;
47 else
48 --open_indents;
49 }
50
51 switch (token.category()) {
52 case SQL::AST::TokenCategory::Invalid:
53 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::Red), Line::Style::Underline });
54 break;
55 case SQL::AST::TokenCategory::Number:
56 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::Magenta) });
57 break;
58 case SQL::AST::TokenCategory::String:
59 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::Green), Line::Style::Bold });
60 break;
61 case SQL::AST::TokenCategory::Blob:
62 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::Magenta), Line::Style::Bold });
63 break;
64 case SQL::AST::TokenCategory::Keyword:
65 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::Blue), Line::Style::Bold });
66 break;
67 case SQL::AST::TokenCategory::Identifier:
68 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::White), Line::Style::Bold });
69 break;
70 default:
71 break;
72 }
73 }
74
75 m_editor->set_prompt(prompt_for_level(open_indents));
76 };
77
78 m_sql_client->on_execution_success = [this](auto result) {
79 if (result.rows_updated != 0 || result.rows_created != 0 || result.rows_deleted != 0)
80 outln("{} row(s) created, {} updated, {} deleted", result.rows_created, result.rows_updated, result.rows_deleted);
81 if (!result.has_results)
82 read_sql();
83 };
84
85 m_sql_client->on_next_result = [](auto result) {
86 StringBuilder builder;
87 builder.join(", "sv, result.values);
88 outln("{}", builder.to_deprecated_string());
89 };
90
91 m_sql_client->on_results_exhausted = [this](auto result) {
92 outln("{} row(s)", result.total_rows);
93 read_sql();
94 };
95
96 m_sql_client->on_execution_error = [this](auto result) {
97 outln("\033[33;1mExecution error:\033[0m {}", result.error_message);
98 read_sql();
99 };
100
101 if (!database_name.is_empty())
102 connect(database_name);
103 }
104
105 ~SQLRepl()
106 {
107 m_editor->save_history(m_history_path);
108 }
109
110 void connect(DeprecatedString const& database_name)
111 {
112 if (!m_database_name.is_empty()) {
113 m_sql_client->disconnect(m_connection_id);
114 m_database_name = {};
115 }
116
117 if (auto connection_id = m_sql_client->connect(database_name); connection_id.has_value()) {
118 outln("Connected to \033[33;1m{}\033[0m", database_name);
119 m_database_name = database_name;
120 m_connection_id = *connection_id;
121 } else {
122 warnln("\033[33;1mCould not connect to:\033[0m {}", database_name);
123 m_loop.quit(1);
124 }
125 }
126
127 void source_file(DeprecatedString file_name)
128 {
129 m_input_file_chain.append(move(file_name));
130 m_quit_when_files_read = false;
131 }
132
133 void read_file(DeprecatedString file_name)
134 {
135 m_input_file_chain.append(move(file_name));
136 m_quit_when_files_read = true;
137 }
138
139 auto run()
140 {
141 read_sql();
142 return m_loop.exec();
143 }
144
145private:
146 DeprecatedString m_history_path { DeprecatedString::formatted("{}/.sql-history", Core::StandardPaths::home_directory()) };
147 RefPtr<Line::Editor> m_editor { nullptr };
148 int m_repl_line_level { 0 };
149 bool m_keep_running { true };
150 DeprecatedString m_database_name {};
151 NonnullRefPtr<SQL::SQLClient> m_sql_client;
152 SQL::ConnectionID m_connection_id { 0 };
153 Core::EventLoop& m_loop;
154 OwnPtr<Core::BufferedFile> m_input_file { nullptr };
155 bool m_quit_when_files_read { false };
156 Vector<DeprecatedString> m_input_file_chain {};
157 Array<u8, 4096> m_buffer {};
158
159 Optional<DeprecatedString> get_line()
160 {
161 if (!m_input_file && !m_input_file_chain.is_empty()) {
162 auto file_name = m_input_file_chain.take_first();
163 auto file_or_error = Core::File::open(file_name, Core::File::OpenMode::Read);
164 if (file_or_error.is_error()) {
165 warnln("Input file {} could not be opened: {}", file_name, file_or_error.error());
166 return {};
167 }
168
169 auto buffered_file_or_error = Core::BufferedFile::create(file_or_error.release_value());
170 if (buffered_file_or_error.is_error()) {
171 warnln("Input file {} could not be buffered: {}", file_name, buffered_file_or_error.error());
172 return {};
173 }
174
175 m_input_file = buffered_file_or_error.release_value();
176 }
177 if (m_input_file) {
178 auto line = m_input_file->read_line(m_buffer);
179 if (line.is_error()) {
180 warnln("Failed to read line: {}", line.error());
181 return {};
182 }
183 if (m_input_file->is_eof()) {
184 m_input_file->close();
185 m_input_file = nullptr;
186 if (m_quit_when_files_read && m_input_file_chain.is_empty())
187 return {};
188 }
189 return line.release_value();
190 // If the last file is exhausted but m_quit_when_files_read is false
191 // we fall through to the standard reading from the editor behaviour
192 }
193 auto line_result = m_editor->get_line(prompt_for_level(m_repl_line_level));
194 if (line_result.is_error())
195 return {};
196 return line_result.value();
197 }
198
199 DeprecatedString read_next_piece()
200 {
201 StringBuilder piece;
202
203 do {
204 if (!piece.is_empty())
205 piece.append('\n');
206
207 auto line_maybe = get_line();
208
209 if (!line_maybe.has_value()) {
210 m_keep_running = false;
211 return {};
212 }
213
214 auto& line = line_maybe.value();
215 auto lexer = SQL::AST::Lexer(line);
216
217 m_editor->add_to_history(line);
218 piece.append(line);
219
220 bool is_first_token = true;
221 bool is_command = false;
222 bool last_token_ended_statement = false;
223 bool tokens_found = false;
224
225 for (SQL::AST::Token token = lexer.next(); token.type() != SQL::AST::TokenType::Eof; token = lexer.next()) {
226 tokens_found = true;
227 switch (token.type()) {
228 case SQL::AST::TokenType::ParenOpen:
229 ++m_repl_line_level;
230 break;
231 case SQL::AST::TokenType::ParenClose:
232 --m_repl_line_level;
233 break;
234 case SQL::AST::TokenType::SemiColon:
235 last_token_ended_statement = true;
236 break;
237 case SQL::AST::TokenType::Period:
238 if (is_first_token)
239 is_command = true;
240 break;
241 default:
242 last_token_ended_statement = is_command;
243 break;
244 }
245
246 is_first_token = false;
247 }
248
249 if (tokens_found)
250 m_repl_line_level = last_token_ended_statement ? 0 : (m_repl_line_level > 0 ? m_repl_line_level : 1);
251 } while ((m_repl_line_level > 0) || piece.is_empty());
252
253 return piece.to_deprecated_string();
254 }
255
256 void read_sql()
257 {
258 DeprecatedString piece = read_next_piece();
259
260 // m_keep_running can be set to false when the file we are reading
261 // from is exhausted...
262 if (!m_keep_running) {
263 m_sql_client->disconnect(m_connection_id);
264 m_loop.quit(0);
265 return;
266 }
267
268 if (piece.starts_with('.')) {
269 bool ready_for_input = handle_command(piece);
270 if (ready_for_input)
271 m_loop.deferred_invoke([this]() {
272 read_sql();
273 });
274 } else if (auto statement_id = m_sql_client->prepare_statement(m_connection_id, piece); statement_id.has_value()) {
275 m_sql_client->async_execute_statement(*statement_id, {});
276 } else {
277 warnln("\033[33;1mError parsing SQL statement\033[0m: {}", piece);
278 m_loop.deferred_invoke([this]() {
279 read_sql();
280 });
281 }
282
283 // ...But m_keep_running can also be set to false by a command handler.
284 if (!m_keep_running) {
285 m_sql_client->disconnect(m_connection_id);
286 m_loop.quit(0);
287 return;
288 }
289 };
290
291 static DeprecatedString prompt_for_level(int level)
292 {
293 static StringBuilder prompt_builder;
294 prompt_builder.clear();
295 prompt_builder.append("> "sv);
296
297 for (auto i = 0; i < level; ++i)
298 prompt_builder.append(" "sv);
299
300 return prompt_builder.to_deprecated_string();
301 }
302
303 bool handle_command(StringView command)
304 {
305 bool ready_for_input = true;
306 if (command == ".exit" || command == ".quit") {
307 m_keep_running = false;
308 ready_for_input = false;
309 } else if (command.starts_with(".connect "sv)) {
310 auto parts = command.split_view(' ');
311 if (parts.size() == 2) {
312 connect(parts[1]);
313 ready_for_input = false;
314 } else {
315 outln("\033[33;1mUsage: .connect <database name>\033[0m");
316 }
317 } else if (command.starts_with(".read "sv)) {
318 if (!m_input_file) {
319 auto parts = command.split_view(' ');
320 if (parts.size() == 2) {
321 source_file(parts[1]);
322 } else {
323 outln("\033[33;1mUsage: .read <sql file>\033[0m");
324 }
325 } else {
326 outln("\033[33;1mCannot recursively read sql files\033[0m");
327 }
328 } else {
329 outln("\033[33;1mUnrecognized command:\033[0m {}", command);
330 }
331 return ready_for_input;
332 }
333};
334
335ErrorOr<int> serenity_main(Main::Arguments arguments)
336{
337 DeprecatedString database_name(getlogin());
338 DeprecatedString file_to_source;
339 DeprecatedString file_to_read;
340 bool suppress_sqlrc = false;
341 auto sqlrc_path = DeprecatedString::formatted("{}/.sqlrc", Core::StandardPaths::home_directory());
342#if !defined(AK_OS_SERENITY)
343 StringView sql_server_path;
344#endif
345
346 Core::ArgsParser args_parser;
347 args_parser.set_general_help("This is a client for the SerenitySQL database server.");
348 args_parser.add_option(database_name, "Database to connect to", "database", 'd', "database");
349 args_parser.add_option(file_to_read, "File to read", "read", 'r', "file");
350 args_parser.add_option(file_to_source, "File to source", "source", 's', "file");
351 args_parser.add_option(suppress_sqlrc, "Don't read ~/.sqlrc", "no-sqlrc", 'n');
352#if !defined(AK_OS_SERENITY)
353 args_parser.add_option(sql_server_path, "Path to SQLServer to launch if needed", "sql-server-path", 's', "path");
354#endif
355 args_parser.parse(arguments);
356
357 Core::EventLoop loop;
358
359#if defined(AK_OS_SERENITY)
360 auto sql_client = TRY(SQL::SQLClient::try_create());
361#else
362 VERIFY(!sql_server_path.is_empty());
363 auto sql_client = TRY(SQL::SQLClient::launch_server_and_create_client({ TRY(String::from_utf8(sql_server_path)) }));
364#endif
365
366 SQLRepl repl(loop, database_name, move(sql_client));
367
368 if (!suppress_sqlrc && Core::DeprecatedFile::exists(sqlrc_path))
369 repl.source_file(sqlrc_path);
370 if (!file_to_source.is_empty())
371 repl.source_file(file_to_source);
372 if (!file_to_read.is_empty())
373 repl.read_file(file_to_read);
374 return repl.run();
375}