// Serenity Operating System
// at portability — 205 lines, 7.4 kB (view raw)
1/* 2 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org> 3 * All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are met: 7 * 8 * 1. Redistributions of source code must retain the above copyright notice, this 9 * list of conditions and the following disclaimer. 10 * 11 * 2. Redistributions in binary form must reproduce the above copyright notice, 12 * this list of conditions and the following disclaimer in the documentation 13 * and/or other materials provided with the distribution. 14 * 15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 */ 26 27#pragma once 28 29#include <AK/ByteBuffer.h> 30#include <AK/NonnullOwnPtrVector.h> 31#include <LibCore/Event.h> 32#include <LibCore/LocalSocket.h> 33#include <LibCore/Notifier.h> 34#include <LibCore/SyscallUtils.h> 35#include <LibIPC/Message.h> 36#include <stdio.h> 37#include <stdlib.h> 38#include <sys/select.h> 39#include <sys/socket.h> 40#include <sys/types.h> 41#include <unistd.h> 42 43namespace IPC { 44 45template<typename LocalEndpoint, typename PeerEndpoint> 46class ServerConnection : public Core::Object { 47public: 48 ServerConnection(LocalEndpoint& local_endpoint, const StringView& address) 49 : m_local_endpoint(local_endpoint) 50 , m_connection(Core::LocalSocket::construct(this)) 51 , m_notifier(Core::Notifier::construct(m_connection->fd(), Core::Notifier::Read, this)) 52 { 53 // We want to rate-limit our clients 54 m_connection->set_blocking(true); 55 m_notifier->on_ready_to_read = [this] { 56 drain_messages_from_server(); 57 handle_messages(); 58 }; 59 60 int retries = 100000; 61 while (retries) { 62 if (m_connection->connect(Core::SocketAddress::local(address))) { 63 break; 64 } 65 66 dbgprintf("Client::Connection: connect failed: %d, %s\n", errno, strerror(errno)); 67 usleep(10000); 68 --retries; 69 } 70 71 ucred creds; 72 socklen_t creds_size = sizeof(creds); 73 if (getsockopt(m_connection->fd(), SOL_SOCKET, SO_PEERCRED, &creds, &creds_size) < 0) { 74 ASSERT_NOT_REACHED(); 75 } 76 m_server_pid = creds.pid; 77 78 ASSERT(m_connection->is_connected()); 79 } 80 81 virtual void handshake() = 0; 82 83 pid_t server_pid() const { return m_server_pid; } 84 void set_my_client_id(int id) { m_my_client_id = id; } 85 int my_client_id() const { return m_my_client_id; } 86 87 template<typename MessageType> 88 OwnPtr<MessageType> wait_for_specific_message() 89 { 90 // Double check we don't already have the event waiting for us. 91 // Otherwise we might end up blocked for a while for no reason. 
92 for (size_t i = 0; i < m_unprocessed_messages.size(); ++i) { 93 if (m_unprocessed_messages[i]->message_id() == MessageType::static_message_id()) { 94 auto message = move(m_unprocessed_messages[i]); 95 m_unprocessed_messages.remove(i); 96 return message; 97 } 98 } 99 for (;;) { 100 fd_set rfds; 101 FD_ZERO(&rfds); 102 FD_SET(m_connection->fd(), &rfds); 103 int rc = Core::safe_syscall(select, m_connection->fd() + 1, &rfds, nullptr, nullptr, nullptr); 104 if (rc < 0) { 105 perror("select"); 106 } 107 ASSERT(rc > 0); 108 ASSERT(FD_ISSET(m_connection->fd(), &rfds)); 109 if (!drain_messages_from_server()) 110 return nullptr; 111 for (size_t i = 0; i < m_unprocessed_messages.size(); ++i) { 112 if (m_unprocessed_messages[i]->message_id() == MessageType::static_message_id()) { 113 auto message = move(m_unprocessed_messages[i]); 114 m_unprocessed_messages.remove(i); 115 return message; 116 } 117 } 118 } 119 } 120 121 bool post_message(const Message& message) 122 { 123 auto buffer = message.encode(); 124 int nwritten = write(m_connection->fd(), buffer.data(), (size_t)buffer.size()); 125 if (nwritten < 0) { 126 perror("write"); 127 ASSERT_NOT_REACHED(); 128 return false; 129 } 130 ASSERT(static_cast<size_t>(nwritten) == buffer.size()); 131 return true; 132 } 133 134 template<typename RequestType, typename... Args> 135 OwnPtr<typename RequestType::ResponseType> send_sync(Args&&... 
args) 136 { 137 bool success = post_message(RequestType(forward<Args>(args)...)); 138 ASSERT(success); 139 auto response = wait_for_specific_message<typename RequestType::ResponseType>(); 140 ASSERT(response); 141 return response; 142 } 143 144private: 145 bool drain_messages_from_server() 146 { 147 Vector<u8> bytes; 148 for (;;) { 149 u8 buffer[4096]; 150 ssize_t nread = recv(m_connection->fd(), buffer, sizeof(buffer), MSG_DONTWAIT); 151 if (nread < 0) { 152 if (errno == EAGAIN) 153 break; 154 perror("read"); 155 exit(1); 156 return false; 157 } 158 if (nread == 0) { 159 dbg() << "EOF on IPC fd"; 160 // FIXME: Dying is definitely not always appropriate! 161 exit(1); 162 return false; 163 } 164 bytes.append(buffer, nread); 165 } 166 167 size_t decoded_bytes = 0; 168 for (size_t index = 0; index < (size_t)bytes.size(); index += decoded_bytes) { 169 auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index); 170 if (auto message = LocalEndpoint::decode_message(remaining_bytes, decoded_bytes)) { 171 m_unprocessed_messages.append(move(message)); 172 } else if (auto message = PeerEndpoint::decode_message(remaining_bytes, decoded_bytes)) { 173 m_unprocessed_messages.append(move(message)); 174 } else { 175 ASSERT_NOT_REACHED(); 176 } 177 ASSERT(decoded_bytes); 178 } 179 180 if (!m_unprocessed_messages.is_empty()) { 181 deferred_invoke([this](auto&) { 182 handle_messages(); 183 }); 184 } 185 return true; 186 } 187 188 void handle_messages() 189 { 190 auto messages = move(m_unprocessed_messages); 191 for (auto& message : messages) { 192 if (message->endpoint_magic() == LocalEndpoint::static_magic()) 193 m_local_endpoint.handle(*message); 194 } 195 } 196 197 LocalEndpoint& m_local_endpoint; 198 RefPtr<Core::LocalSocket> m_connection; 199 RefPtr<Core::Notifier> m_notifier; 200 Vector<OwnPtr<Message>> m_unprocessed_messages; 201 int m_server_pid { -1 }; 202 int m_my_client_id { -1 }; 203}; 204 205}