Serenity Operating System
1/*
2 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice, this
9 * list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright notice,
12 * this list of conditions and the following disclaimer in the documentation
13 * and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26
27#pragma once
28
29#include <AK/ByteBuffer.h>
30#include <AK/NonnullOwnPtrVector.h>
31#include <LibCore/Event.h>
32#include <LibCore/LocalSocket.h>
33#include <LibCore/Notifier.h>
34#include <LibCore/SyscallUtils.h>
35#include <LibIPC/Message.h>
36#include <stdio.h>
37#include <stdlib.h>
38#include <sys/select.h>
39#include <sys/socket.h>
40#include <sys/types.h>
41#include <unistd.h>
42
43namespace IPC {
44
45template<typename LocalEndpoint, typename PeerEndpoint>
46class ServerConnection : public Core::Object {
47public:
48 ServerConnection(LocalEndpoint& local_endpoint, const StringView& address)
49 : m_local_endpoint(local_endpoint)
50 , m_connection(Core::LocalSocket::construct(this))
51 , m_notifier(Core::Notifier::construct(m_connection->fd(), Core::Notifier::Read, this))
52 {
53 // We want to rate-limit our clients
54 m_connection->set_blocking(true);
55 m_notifier->on_ready_to_read = [this] {
56 drain_messages_from_server();
57 handle_messages();
58 };
59
60 int retries = 100000;
61 while (retries) {
62 if (m_connection->connect(Core::SocketAddress::local(address))) {
63 break;
64 }
65
66 dbgprintf("Client::Connection: connect failed: %d, %s\n", errno, strerror(errno));
67 usleep(10000);
68 --retries;
69 }
70
71 ucred creds;
72 socklen_t creds_size = sizeof(creds);
73 if (getsockopt(m_connection->fd(), SOL_SOCKET, SO_PEERCRED, &creds, &creds_size) < 0) {
74 ASSERT_NOT_REACHED();
75 }
76 m_server_pid = creds.pid;
77
78 ASSERT(m_connection->is_connected());
79 }
80
81 virtual void handshake() = 0;
82
83 pid_t server_pid() const { return m_server_pid; }
84 void set_my_client_id(int id) { m_my_client_id = id; }
85 int my_client_id() const { return m_my_client_id; }
86
87 template<typename MessageType>
88 OwnPtr<MessageType> wait_for_specific_message()
89 {
90 // Double check we don't already have the event waiting for us.
91 // Otherwise we might end up blocked for a while for no reason.
92 for (size_t i = 0; i < m_unprocessed_messages.size(); ++i) {
93 if (m_unprocessed_messages[i]->message_id() == MessageType::static_message_id()) {
94 auto message = move(m_unprocessed_messages[i]);
95 m_unprocessed_messages.remove(i);
96 return message;
97 }
98 }
99 for (;;) {
100 fd_set rfds;
101 FD_ZERO(&rfds);
102 FD_SET(m_connection->fd(), &rfds);
103 int rc = Core::safe_syscall(select, m_connection->fd() + 1, &rfds, nullptr, nullptr, nullptr);
104 if (rc < 0) {
105 perror("select");
106 }
107 ASSERT(rc > 0);
108 ASSERT(FD_ISSET(m_connection->fd(), &rfds));
109 if (!drain_messages_from_server())
110 return nullptr;
111 for (size_t i = 0; i < m_unprocessed_messages.size(); ++i) {
112 if (m_unprocessed_messages[i]->message_id() == MessageType::static_message_id()) {
113 auto message = move(m_unprocessed_messages[i]);
114 m_unprocessed_messages.remove(i);
115 return message;
116 }
117 }
118 }
119 }
120
121 bool post_message(const Message& message)
122 {
123 auto buffer = message.encode();
124 int nwritten = write(m_connection->fd(), buffer.data(), (size_t)buffer.size());
125 if (nwritten < 0) {
126 perror("write");
127 ASSERT_NOT_REACHED();
128 return false;
129 }
130 ASSERT(static_cast<size_t>(nwritten) == buffer.size());
131 return true;
132 }
133
134 template<typename RequestType, typename... Args>
135 OwnPtr<typename RequestType::ResponseType> send_sync(Args&&... args)
136 {
137 bool success = post_message(RequestType(forward<Args>(args)...));
138 ASSERT(success);
139 auto response = wait_for_specific_message<typename RequestType::ResponseType>();
140 ASSERT(response);
141 return response;
142 }
143
144private:
145 bool drain_messages_from_server()
146 {
147 Vector<u8> bytes;
148 for (;;) {
149 u8 buffer[4096];
150 ssize_t nread = recv(m_connection->fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
151 if (nread < 0) {
152 if (errno == EAGAIN)
153 break;
154 perror("read");
155 exit(1);
156 return false;
157 }
158 if (nread == 0) {
159 dbg() << "EOF on IPC fd";
160 // FIXME: Dying is definitely not always appropriate!
161 exit(1);
162 return false;
163 }
164 bytes.append(buffer, nread);
165 }
166
167 size_t decoded_bytes = 0;
168 for (size_t index = 0; index < (size_t)bytes.size(); index += decoded_bytes) {
169 auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index);
170 if (auto message = LocalEndpoint::decode_message(remaining_bytes, decoded_bytes)) {
171 m_unprocessed_messages.append(move(message));
172 } else if (auto message = PeerEndpoint::decode_message(remaining_bytes, decoded_bytes)) {
173 m_unprocessed_messages.append(move(message));
174 } else {
175 ASSERT_NOT_REACHED();
176 }
177 ASSERT(decoded_bytes);
178 }
179
180 if (!m_unprocessed_messages.is_empty()) {
181 deferred_invoke([this](auto&) {
182 handle_messages();
183 });
184 }
185 return true;
186 }
187
188 void handle_messages()
189 {
190 auto messages = move(m_unprocessed_messages);
191 for (auto& message : messages) {
192 if (message->endpoint_magic() == LocalEndpoint::static_magic())
193 m_local_endpoint.handle(*message);
194 }
195 }
196
197 LocalEndpoint& m_local_endpoint;
198 RefPtr<Core::LocalSocket> m_connection;
199 RefPtr<Core::Notifier> m_notifier;
200 Vector<OwnPtr<Message>> m_unprocessed_messages;
201 int m_server_pid { -1 };
202 int m_my_client_id { -1 };
203};
204
205}