Serenity Operating System
at master 375 lines 14 kB view raw
1/* 2 * Copyright (c) 2021, Tim Flynn <trflynn89@serenityos.org> 3 * Copyright (c) 2022, Alex Major 4 * 5 * SPDX-License-Identifier: BSD-2-Clause 6 */ 7 8#include <AK/DeprecatedString.h> 9#include <AK/Format.h> 10#include <AK/String.h> 11#include <AK/StringBuilder.h> 12#include <LibCore/ArgsParser.h> 13#include <LibCore/DeprecatedFile.h> 14#include <LibCore/StandardPaths.h> 15#include <LibLine/Editor.h> 16#include <LibMain/Main.h> 17#include <LibSQL/AST/Lexer.h> 18#include <LibSQL/AST/Token.h> 19#include <LibSQL/SQLClient.h> 20#include <unistd.h> 21 22class SQLRepl { 23public: 24 explicit SQLRepl(Core::EventLoop& loop, DeprecatedString const& database_name, NonnullRefPtr<SQL::SQLClient> sql_client) 25 : m_sql_client(move(sql_client)) 26 , m_loop(loop) 27 { 28 m_editor = Line::Editor::construct(); 29 m_editor->load_history(m_history_path); 30 31 m_editor->on_display_refresh = [this](Line::Editor& editor) { 32 editor.strip_styles(); 33 34 int open_indents = m_repl_line_level; 35 36 auto line = editor.line(); 37 SQL::AST::Lexer lexer(line); 38 39 bool indenters_starting_line = true; 40 for (SQL::AST::Token token = lexer.next(); token.type() != SQL::AST::TokenType::Eof; token = lexer.next()) { 41 auto start = token.start_position().column - 1; 42 auto end = token.end_position().column - 1; 43 44 if (indenters_starting_line) { 45 if (token.type() != SQL::AST::TokenType::ParenClose) 46 indenters_starting_line = false; 47 else 48 --open_indents; 49 } 50 51 switch (token.category()) { 52 case SQL::AST::TokenCategory::Invalid: 53 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::Red), Line::Style::Underline }); 54 break; 55 case SQL::AST::TokenCategory::Number: 56 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::Magenta) }); 57 break; 58 case SQL::AST::TokenCategory::String: 59 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::Green), Line::Style::Bold }); 60 break; 61 case SQL::AST::TokenCategory::Blob: 62 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::Magenta), Line::Style::Bold }); 63 break; 64 case SQL::AST::TokenCategory::Keyword: 65 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::Blue), Line::Style::Bold }); 66 break; 67 case SQL::AST::TokenCategory::Identifier: 68 editor.stylize({ start, end }, { Line::Style::Foreground(Line::Style::XtermColor::White), Line::Style::Bold }); 69 break; 70 default: 71 break; 72 } 73 } 74 75 m_editor->set_prompt(prompt_for_level(open_indents)); 76 }; 77 78 m_sql_client->on_execution_success = [this](auto result) { 79 if (result.rows_updated != 0 || result.rows_created != 0 || result.rows_deleted != 0) 80 outln("{} row(s) created, {} updated, {} deleted", result.rows_created, result.rows_updated, result.rows_deleted); 81 if (!result.has_results) 82 read_sql(); 83 }; 84 85 m_sql_client->on_next_result = [](auto result) { 86 StringBuilder builder; 87 builder.join(", "sv, result.values); 88 outln("{}", builder.to_deprecated_string()); 89 }; 90 91 m_sql_client->on_results_exhausted = [this](auto result) { 92 outln("{} row(s)", result.total_rows); 93 read_sql(); 94 }; 95 96 m_sql_client->on_execution_error = [this](auto result) { 97 outln("\033[33;1mExecution error:\033[0m {}", result.error_message); 98 read_sql(); 99 }; 100 101 if (!database_name.is_empty()) 102 connect(database_name); 103 } 104 105 ~SQLRepl() 106 { 107 m_editor->save_history(m_history_path); 108 } 109 110 void connect(DeprecatedString const& database_name) 111 { 112 if (!m_database_name.is_empty()) { 113 m_sql_client->disconnect(m_connection_id); 114 m_database_name = {}; 115 } 116 117 if (auto connection_id = m_sql_client->connect(database_name); connection_id.has_value()) { 118 outln("Connected to \033[33;1m{}\033[0m", database_name); 119 m_database_name = database_name; 120 m_connection_id = *connection_id; 121 } else { 122 warnln("\033[33;1mCould not connect to:\033[0m {}", database_name); 123 m_loop.quit(1); 124 } 125 } 126 127 void source_file(DeprecatedString file_name) 128 { 129 m_input_file_chain.append(move(file_name)); 130 m_quit_when_files_read = false; 131 } 132 133 void read_file(DeprecatedString file_name) 134 { 135 m_input_file_chain.append(move(file_name)); 136 m_quit_when_files_read = true; 137 } 138 139 auto run() 140 { 141 read_sql(); 142 return m_loop.exec(); 143 } 144 145private: 146 DeprecatedString m_history_path { DeprecatedString::formatted("{}/.sql-history", Core::StandardPaths::home_directory()) }; 147 RefPtr<Line::Editor> m_editor { nullptr }; 148 int m_repl_line_level { 0 }; 149 bool m_keep_running { true }; 150 DeprecatedString m_database_name {}; 151 NonnullRefPtr<SQL::SQLClient> m_sql_client; 152 SQL::ConnectionID m_connection_id { 0 }; 153 Core::EventLoop& m_loop; 154 OwnPtr<Core::BufferedFile> m_input_file { nullptr }; 155 bool m_quit_when_files_read { false }; 156 Vector<DeprecatedString> m_input_file_chain {}; 157 Array<u8, 4096> m_buffer {}; 158 159 Optional<DeprecatedString> get_line() 160 { 161 if (!m_input_file && !m_input_file_chain.is_empty()) { 162 auto file_name = m_input_file_chain.take_first(); 163 auto file_or_error = Core::File::open(file_name, Core::File::OpenMode::Read); 164 if (file_or_error.is_error()) { 165 warnln("Input file {} could not be opened: {}", file_name, file_or_error.error()); 166 return {}; 167 } 168 169 auto buffered_file_or_error = Core::BufferedFile::create(file_or_error.release_value()); 170 if (buffered_file_or_error.is_error()) { 171 warnln("Input file {} could not be buffered: {}", file_name, buffered_file_or_error.error()); 172 return {}; 173 } 174 175 m_input_file = buffered_file_or_error.release_value(); 176 } 177 if (m_input_file) { 178 auto line = m_input_file->read_line(m_buffer); 179 if (line.is_error()) { 180 warnln("Failed to read line: {}", line.error()); 181 return {}; 182 } 183 if (m_input_file->is_eof()) { 184 m_input_file->close(); 185 m_input_file = nullptr; 186 if (m_quit_when_files_read && m_input_file_chain.is_empty()) 187 return {}; 188 } 189 return line.release_value(); 190 // If the last file is exhausted but m_quit_when_files_read is false 191 // we fall through to the standard reading from the editor behaviour 192 } 193 auto line_result = m_editor->get_line(prompt_for_level(m_repl_line_level)); 194 if (line_result.is_error()) 195 return {}; 196 return line_result.value(); 197 } 198 199 DeprecatedString read_next_piece() 200 { 201 StringBuilder piece; 202 203 do { 204 if (!piece.is_empty()) 205 piece.append('\n'); 206 207 auto line_maybe = get_line(); 208 209 if (!line_maybe.has_value()) { 210 m_keep_running = false; 211 return {}; 212 } 213 214 auto& line = line_maybe.value(); 215 auto lexer = SQL::AST::Lexer(line); 216 217 m_editor->add_to_history(line); 218 piece.append(line); 219 220 bool is_first_token = true; 221 bool is_command = false; 222 bool last_token_ended_statement = false; 223 bool tokens_found = false; 224 225 for (SQL::AST::Token token = lexer.next(); token.type() != SQL::AST::TokenType::Eof; token = lexer.next()) { 226 tokens_found = true; 227 switch (token.type()) { 228 case SQL::AST::TokenType::ParenOpen: 229 ++m_repl_line_level; 230 break; 231 case SQL::AST::TokenType::ParenClose: 232 --m_repl_line_level; 233 break; 234 case SQL::AST::TokenType::SemiColon: 235 last_token_ended_statement = true; 236 break; 237 case SQL::AST::TokenType::Period: 238 if (is_first_token) 239 is_command = true; 240 break; 241 default: 242 last_token_ended_statement = is_command; 243 break; 244 } 245 246 is_first_token = false; 247 } 248 249 if (tokens_found) 250 m_repl_line_level = last_token_ended_statement ? 0 : (m_repl_line_level > 0 ? m_repl_line_level : 1); 251 } while ((m_repl_line_level > 0) || piece.is_empty()); 252 253 return piece.to_deprecated_string(); 254 } 255 256 void read_sql() 257 { 258 DeprecatedString piece = read_next_piece(); 259 260 // m_keep_running can be set to false when the file we are reading 261 // from is exhausted... 262 if (!m_keep_running) { 263 m_sql_client->disconnect(m_connection_id); 264 m_loop.quit(0); 265 return; 266 } 267 268 if (piece.starts_with('.')) { 269 bool ready_for_input = handle_command(piece); 270 if (ready_for_input) 271 m_loop.deferred_invoke([this]() { 272 read_sql(); 273 }); 274 } else if (auto statement_id = m_sql_client->prepare_statement(m_connection_id, piece); statement_id.has_value()) { 275 m_sql_client->async_execute_statement(*statement_id, {}); 276 } else { 277 warnln("\033[33;1mError parsing SQL statement\033[0m: {}", piece); 278 m_loop.deferred_invoke([this]() { 279 read_sql(); 280 }); 281 } 282 283 // ...But m_keep_running can also be set to false by a command handler. 284 if (!m_keep_running) { 285 m_sql_client->disconnect(m_connection_id); 286 m_loop.quit(0); 287 return; 288 } 289 }; 290 291 static DeprecatedString prompt_for_level(int level) 292 { 293 static StringBuilder prompt_builder; 294 prompt_builder.clear(); 295 prompt_builder.append("> "sv); 296 297 for (auto i = 0; i < level; ++i) 298 prompt_builder.append(" "sv); 299 300 return prompt_builder.to_deprecated_string(); 301 } 302 303 bool handle_command(StringView command) 304 { 305 bool ready_for_input = true; 306 if (command == ".exit" || command == ".quit") { 307 m_keep_running = false; 308 ready_for_input = false; 309 } else if (command.starts_with(".connect "sv)) { 310 auto parts = command.split_view(' '); 311 if (parts.size() == 2) { 312 connect(parts[1]); 313 ready_for_input = false; 314 } else { 315 outln("\033[33;1mUsage: .connect <database name>\033[0m"); 316 } 317 } else if (command.starts_with(".read "sv)) { 318 if (!m_input_file) { 319 auto parts = command.split_view(' '); 320 if (parts.size() == 2) { 321 source_file(parts[1]); 322 } else { 323 outln("\033[33;1mUsage: .read <sql file>\033[0m"); 324 } 325 } else { 326 outln("\033[33;1mCannot recursively read sql files\033[0m"); 327 } 328 } else { 329 outln("\033[33;1mUnrecognized command:\033[0m {}", command); 330 } 331 return ready_for_input; 332 } 333}; 334 335ErrorOr<int> serenity_main(Main::Arguments arguments) 336{ 337 DeprecatedString database_name(getlogin()); 338 DeprecatedString file_to_source; 339 DeprecatedString file_to_read; 340 bool suppress_sqlrc = false; 341 auto sqlrc_path = DeprecatedString::formatted("{}/.sqlrc", Core::StandardPaths::home_directory()); 342#if !defined(AK_OS_SERENITY) 343 StringView sql_server_path; 344#endif 345 346 Core::ArgsParser args_parser; 347 args_parser.set_general_help("This is a client for the SerenitySQL database server."); 348 args_parser.add_option(database_name, "Database to connect to", "database", 'd', "database"); 349 args_parser.add_option(file_to_read, "File to read", "read", 'r', "file"); 350 args_parser.add_option(file_to_source, "File to source", "source", 's', "file"); 351 args_parser.add_option(suppress_sqlrc, "Don't read ~/.sqlrc", "no-sqlrc", 'n'); 352#if !defined(AK_OS_SERENITY) 353 args_parser.add_option(sql_server_path, "Path to SQLServer to launch if needed", "sql-server-path", 's', "path"); 354#endif 355 args_parser.parse(arguments); 356 357 Core::EventLoop loop; 358 359#if defined(AK_OS_SERENITY) 360 auto sql_client = TRY(SQL::SQLClient::try_create()); 361#else 362 VERIFY(!sql_server_path.is_empty()); 363 auto sql_client = TRY(SQL::SQLClient::launch_server_and_create_client({ TRY(String::from_utf8(sql_server_path)) })); 364#endif 365 366 SQLRepl repl(loop, database_name, move(sql_client)); 367 368 if (!suppress_sqlrc && Core::DeprecatedFile::exists(sqlrc_path)) 369 repl.source_file(sqlrc_path); 370 if (!file_to_source.is_empty()) 371 repl.source_file(file_to_source); 372 if (!file_to_read.is_empty()) 373 repl.read_file(file_to_read); 374 return repl.run(); 375}