// Serenity Operating System
/*
 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#pragma once

#include <AK/ByteBuffer.h>
#include <AK/NonnullOwnPtrVector.h>
#include <LibCore/Event.h>
#include <LibCore/LocalSocket.h>
#include <LibCore/Notifier.h>
#include <LibCore/SyscallUtils.h>
#include <LibIPC/Message.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

43namespace IPC {
44
45template<typename LocalEndpoint, typename PeerEndpoint>
46class ServerConnection : public Core::Object {
47public:
48 ServerConnection(LocalEndpoint& local_endpoint, const StringView& address)
49 : m_local_endpoint(local_endpoint)
50 , m_connection(Core::LocalSocket::construct(this))
51 , m_notifier(Core::Notifier::construct(m_connection->fd(), Core::Notifier::Read, this))
52 {
53 // We want to rate-limit our clients
54 m_connection->set_blocking(true);
55 m_notifier->on_ready_to_read = [this] {
56 drain_messages_from_server();
57 handle_messages();
58 };
59
60 int retries = 100000;
61 while (retries) {
62 if (m_connection->connect(Core::SocketAddress::local(address))) {
63 break;
64 }
65
66 dbgprintf("Client::Connection: connect failed: %d, %s\n", errno, strerror(errno));
67 usleep(10000);
68 --retries;
69 }
70
71#ifdef __OpenBSD__
72 sockpeercred creds;
73#else
74 ucred creds;
75#endif
76 socklen_t creds_size = sizeof(creds);
77 if (getsockopt(m_connection->fd(), SOL_SOCKET, SO_PEERCRED, &creds, &creds_size) < 0) {
78 ASSERT_NOT_REACHED();
79 }
80 m_server_pid = creds.pid;
81
82 ASSERT(m_connection->is_connected());
83 }
84
85 virtual void handshake() = 0;
86
87 pid_t server_pid() const { return m_server_pid; }
88 void set_my_client_id(int id) { m_my_client_id = id; }
89 int my_client_id() const { return m_my_client_id; }
90
91 template<typename MessageType>
92 OwnPtr<MessageType> wait_for_specific_message()
93 {
94 // Double check we don't already have the event waiting for us.
95 // Otherwise we might end up blocked for a while for no reason.
96 for (size_t i = 0; i < m_unprocessed_messages.size(); ++i) {
97 if (m_unprocessed_messages[i]->message_id() == MessageType::static_message_id()) {
98 auto message = move(m_unprocessed_messages[i]);
99 m_unprocessed_messages.remove(i);
100 return message.template release_nonnull<MessageType>();
101 }
102 }
103 for (;;) {
104 fd_set rfds;
105 FD_ZERO(&rfds);
106 FD_SET(m_connection->fd(), &rfds);
107 int rc = Core::safe_syscall(select, m_connection->fd() + 1, &rfds, nullptr, nullptr, nullptr);
108 if (rc < 0) {
109 perror("select");
110 }
111 ASSERT(rc > 0);
112 ASSERT(FD_ISSET(m_connection->fd(), &rfds));
113 if (!drain_messages_from_server())
114 return nullptr;
115 for (size_t i = 0; i < m_unprocessed_messages.size(); ++i) {
116 if (m_unprocessed_messages[i]->message_id() == MessageType::static_message_id()) {
117 auto message = move(m_unprocessed_messages[i]);
118 m_unprocessed_messages.remove(i);
119 return message.template release_nonnull<MessageType>();
120 }
121 }
122 }
123 }
124
125 bool post_message(const Message& message)
126 {
127 auto buffer = message.encode();
128 int nwritten = write(m_connection->fd(), buffer.data(), buffer.size());
129 if (nwritten < 0) {
130 perror("write");
131 ASSERT_NOT_REACHED();
132 return false;
133 }
134 ASSERT(static_cast<size_t>(nwritten) == buffer.size());
135 return true;
136 }
137
138 template<typename RequestType, typename... Args>
139 OwnPtr<typename RequestType::ResponseType> send_sync(Args&&... args)
140 {
141 bool success = post_message(RequestType(forward<Args>(args)...));
142 ASSERT(success);
143 auto response = wait_for_specific_message<typename RequestType::ResponseType>();
144 ASSERT(response);
145 return response;
146 }
147
148private:
149 bool drain_messages_from_server()
150 {
151 Vector<u8> bytes;
152 for (;;) {
153 u8 buffer[4096];
154 ssize_t nread = recv(m_connection->fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
155 if (nread < 0) {
156 if (errno == EAGAIN)
157 break;
158 perror("read");
159 exit(1);
160 return false;
161 }
162 if (nread == 0) {
163 dbg() << "EOF on IPC fd";
164 // FIXME: Dying is definitely not always appropriate!
165 exit(1);
166 return false;
167 }
168 bytes.append(buffer, nread);
169 }
170
171 size_t decoded_bytes = 0;
172 for (size_t index = 0; index < bytes.size(); index += decoded_bytes) {
173 auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index);
174 if (auto message = LocalEndpoint::decode_message(remaining_bytes, decoded_bytes)) {
175 m_unprocessed_messages.append(move(message));
176 } else if (auto message = PeerEndpoint::decode_message(remaining_bytes, decoded_bytes)) {
177 m_unprocessed_messages.append(move(message));
178 } else {
179 ASSERT_NOT_REACHED();
180 }
181 ASSERT(decoded_bytes);
182 }
183
184 if (!m_unprocessed_messages.is_empty()) {
185 deferred_invoke([this](auto&) {
186 handle_messages();
187 });
188 }
189 return true;
190 }
191
192 void handle_messages()
193 {
194 auto messages = move(m_unprocessed_messages);
195 for (auto& message : messages) {
196 if (message->endpoint_magic() == LocalEndpoint::static_magic())
197 m_local_endpoint.handle(*message);
198 }
199 }
200
201 LocalEndpoint& m_local_endpoint;
202 RefPtr<Core::LocalSocket> m_connection;
203 RefPtr<Core::Notifier> m_notifier;
204 Vector<OwnPtr<Message>> m_unprocessed_messages;
205 int m_server_pid { -1 };
206 int m_my_client_id { -1 };
207};
208
209}