Serenity Operating System
at hosted 209 lines 7.6 kB view raw
/*
 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */ 26 27#pragma once 28 29#include <AK/ByteBuffer.h> 30#include <AK/NonnullOwnPtrVector.h> 31#include <LibCore/Event.h> 32#include <LibCore/LocalSocket.h> 33#include <LibCore/Notifier.h> 34#include <LibCore/SyscallUtils.h> 35#include <LibIPC/Message.h> 36#include <stdio.h> 37#include <stdlib.h> 38#include <sys/select.h> 39#include <sys/socket.h> 40#include <sys/types.h> 41#include <unistd.h> 42 43namespace IPC { 44 45template<typename LocalEndpoint, typename PeerEndpoint> 46class ServerConnection : public Core::Object { 47public: 48 ServerConnection(LocalEndpoint& local_endpoint, const StringView& address) 49 : m_local_endpoint(local_endpoint) 50 , m_connection(Core::LocalSocket::construct(this)) 51 , m_notifier(Core::Notifier::construct(m_connection->fd(), Core::Notifier::Read, this)) 52 { 53 // We want to rate-limit our clients 54 m_connection->set_blocking(true); 55 m_notifier->on_ready_to_read = [this] { 56 drain_messages_from_server(); 57 handle_messages(); 58 }; 59 60 int retries = 100000; 61 while (retries) { 62 if (m_connection->connect(Core::SocketAddress::local(address))) { 63 break; 64 } 65 66 dbgprintf("Client::Connection: connect failed: %d, %s\n", errno, strerror(errno)); 67 usleep(10000); 68 --retries; 69 } 70 71#ifdef __OpenBSD__ 72 sockpeercred creds; 73#else 74 ucred creds; 75#endif 76 socklen_t creds_size = sizeof(creds); 77 if (getsockopt(m_connection->fd(), SOL_SOCKET, SO_PEERCRED, &creds, &creds_size) < 0) { 78 ASSERT_NOT_REACHED(); 79 } 80 m_server_pid = creds.pid; 81 82 ASSERT(m_connection->is_connected()); 83 } 84 85 virtual void handshake() = 0; 86 87 pid_t server_pid() const { return m_server_pid; } 88 void set_my_client_id(int id) { m_my_client_id = id; } 89 int my_client_id() const { return m_my_client_id; } 90 91 template<typename MessageType> 92 OwnPtr<MessageType> wait_for_specific_message() 93 { 94 // Double check we don't already have the event waiting for us. 95 // Otherwise we might end up blocked for a while for no reason. 
96 for (size_t i = 0; i < m_unprocessed_messages.size(); ++i) { 97 if (m_unprocessed_messages[i]->message_id() == MessageType::static_message_id()) { 98 auto message = move(m_unprocessed_messages[i]); 99 m_unprocessed_messages.remove(i); 100 return message.template release_nonnull<MessageType>(); 101 } 102 } 103 for (;;) { 104 fd_set rfds; 105 FD_ZERO(&rfds); 106 FD_SET(m_connection->fd(), &rfds); 107 int rc = Core::safe_syscall(select, m_connection->fd() + 1, &rfds, nullptr, nullptr, nullptr); 108 if (rc < 0) { 109 perror("select"); 110 } 111 ASSERT(rc > 0); 112 ASSERT(FD_ISSET(m_connection->fd(), &rfds)); 113 if (!drain_messages_from_server()) 114 return nullptr; 115 for (size_t i = 0; i < m_unprocessed_messages.size(); ++i) { 116 if (m_unprocessed_messages[i]->message_id() == MessageType::static_message_id()) { 117 auto message = move(m_unprocessed_messages[i]); 118 m_unprocessed_messages.remove(i); 119 return message.template release_nonnull<MessageType>(); 120 } 121 } 122 } 123 } 124 125 bool post_message(const Message& message) 126 { 127 auto buffer = message.encode(); 128 int nwritten = write(m_connection->fd(), buffer.data(), buffer.size()); 129 if (nwritten < 0) { 130 perror("write"); 131 ASSERT_NOT_REACHED(); 132 return false; 133 } 134 ASSERT(static_cast<size_t>(nwritten) == buffer.size()); 135 return true; 136 } 137 138 template<typename RequestType, typename... Args> 139 OwnPtr<typename RequestType::ResponseType> send_sync(Args&&... 
args) 140 { 141 bool success = post_message(RequestType(forward<Args>(args)...)); 142 ASSERT(success); 143 auto response = wait_for_specific_message<typename RequestType::ResponseType>(); 144 ASSERT(response); 145 return response; 146 } 147 148private: 149 bool drain_messages_from_server() 150 { 151 Vector<u8> bytes; 152 for (;;) { 153 u8 buffer[4096]; 154 ssize_t nread = recv(m_connection->fd(), buffer, sizeof(buffer), MSG_DONTWAIT); 155 if (nread < 0) { 156 if (errno == EAGAIN) 157 break; 158 perror("read"); 159 exit(1); 160 return false; 161 } 162 if (nread == 0) { 163 dbg() << "EOF on IPC fd"; 164 // FIXME: Dying is definitely not always appropriate! 165 exit(1); 166 return false; 167 } 168 bytes.append(buffer, nread); 169 } 170 171 size_t decoded_bytes = 0; 172 for (size_t index = 0; index < bytes.size(); index += decoded_bytes) { 173 auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index); 174 if (auto message = LocalEndpoint::decode_message(remaining_bytes, decoded_bytes)) { 175 m_unprocessed_messages.append(move(message)); 176 } else if (auto message = PeerEndpoint::decode_message(remaining_bytes, decoded_bytes)) { 177 m_unprocessed_messages.append(move(message)); 178 } else { 179 ASSERT_NOT_REACHED(); 180 } 181 ASSERT(decoded_bytes); 182 } 183 184 if (!m_unprocessed_messages.is_empty()) { 185 deferred_invoke([this](auto&) { 186 handle_messages(); 187 }); 188 } 189 return true; 190 } 191 192 void handle_messages() 193 { 194 auto messages = move(m_unprocessed_messages); 195 for (auto& message : messages) { 196 if (message->endpoint_magic() == LocalEndpoint::static_magic()) 197 m_local_endpoint.handle(*message); 198 } 199 } 200 201 LocalEndpoint& m_local_endpoint; 202 RefPtr<Core::LocalSocket> m_connection; 203 RefPtr<Core::Notifier> m_notifier; 204 Vector<OwnPtr<Message>> m_unprocessed_messages; 205 int m_server_pid { -1 }; 206 int m_my_client_id { -1 }; 207}; 208 209}