Serenity Operating System
1/*
2 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice, this
9 * list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright notice,
12 * this list of conditions and the following disclaimer in the documentation
13 * and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26
27#pragma once
28
29#include <AK/ByteBuffer.h>
30#include <LibCore/Event.h>
31#include <LibCore/EventLoop.h>
32#include <LibCore/IODevice.h>
33#include <LibCore/LocalSocket.h>
34#include <LibCore/Notifier.h>
35#include <LibCore/Object.h>
36#include <LibIPC/Endpoint.h>
37#include <LibIPC/Message.h>
38#include <errno.h>
39#include <stdio.h>
40#include <sys/socket.h>
41#include <sys/types.h>
42#include <unistd.h>
43
44namespace IPC {
45
46class Event : public Core::Event {
47public:
48 enum Type {
49 Invalid = 2000,
50 Disconnected,
51 };
52 Event() {}
53 explicit Event(Type type)
54 : Core::Event(type)
55 {
56 }
57};
58
59class DisconnectedEvent : public Event {
60public:
61 explicit DisconnectedEvent(int client_id)
62 : Event(Disconnected)
63 , m_client_id(client_id)
64 {
65 }
66
67 int client_id() const { return m_client_id; }
68
69private:
70 int m_client_id { 0 };
71};
72
73template<typename T, class... Args>
74NonnullRefPtr<T> new_client_connection(Args&&... args)
75{
76 return T::construct(forward<Args>(args)...) /* arghs */;
77}
78
79template<typename Endpoint>
80class ClientConnection : public Core::Object {
81public:
82 ClientConnection(Endpoint& endpoint, Core::LocalSocket& socket, int client_id)
83 : m_endpoint(endpoint)
84 , m_socket(socket)
85 , m_client_id(client_id)
86 {
87 ASSERT(socket.is_connected());
88#ifdef __OpenBSD__
89 sockpeercred creds;
90#else
91 ucred creds;
92#endif
93 socklen_t creds_size = sizeof(creds);
94 if (getsockopt(m_socket->fd(), SOL_SOCKET, SO_PEERCRED, &creds, &creds_size) < 0) {
95 ASSERT_NOT_REACHED();
96 }
97 m_client_pid = creds.pid;
98 add_child(socket);
99 m_socket->on_ready_to_read = [this] { drain_messages_from_client(); };
100 }
101
102 virtual ~ClientConnection() override
103 {
104 }
105
106 void post_message(const Message& message)
107 {
108 // NOTE: If this connection is being shut down, but has not yet been destroyed,
109 // the socket will be closed. Don't try to send more messages.
110 if (!m_socket->is_open())
111 return;
112
113 auto buffer = message.encode();
114
115 int nwritten = write(m_socket->fd(), buffer.data(), buffer.size());
116 if (nwritten < 0) {
117 switch (errno) {
118 case EPIPE:
119 dbg() << *this << "::post_message: Disconnected from peer";
120 shutdown();
121 return;
122 case EAGAIN:
123 dbg() << *this << "::post_message: Client buffer overflowed.";
124 did_misbehave();
125 return;
126 default:
127 perror("Connection::post_message write");
128 ASSERT_NOT_REACHED();
129 }
130 }
131
132 ASSERT(static_cast<size_t>(nwritten) == buffer.size());
133 }
134
135 void drain_messages_from_client()
136 {
137 Vector<u8> bytes;
138 for (;;) {
139 u8 buffer[4096];
140 ssize_t nread = recv(m_socket->fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
141 if (nread == 0 || (nread == -1 && errno == EAGAIN)) {
142 if (bytes.is_empty()) {
143 Core::EventLoop::current().post_event(*this, make<DisconnectedEvent>(client_id()));
144 return;
145 }
146 break;
147 }
148 if (nread < 0) {
149 perror("recv");
150 ASSERT_NOT_REACHED();
151 }
152 bytes.append(buffer, nread);
153 }
154
155 size_t decoded_bytes = 0;
156 for (size_t index = 0; index < bytes.size(); index += decoded_bytes) {
157 auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index);
158 auto message = Endpoint::decode_message(remaining_bytes, decoded_bytes);
159 if (!message) {
160 dbg() << "drain_messages_from_client: Endpoint didn't recognize message";
161 did_misbehave();
162 return;
163 }
164 if (auto response = m_endpoint.handle(*message))
165 post_message(*response);
166 ASSERT(decoded_bytes);
167 }
168 }
169
170 void did_misbehave()
171 {
172 dbg() << *this << " (id=" << m_client_id << ", pid=" << m_client_pid << ") misbehaved, disconnecting.";
173 shutdown();
174 }
175
176 void did_misbehave(const char* message)
177 {
178 dbg() << *this << " (id=" << m_client_id << ", pid=" << m_client_pid << ") misbehaved (" << message << "), disconnecting.";
179 shutdown();
180 }
181
182 void shutdown()
183 {
184 m_socket->close();
185 die();
186 }
187
188 int client_id() const { return m_client_id; }
189 pid_t client_pid() const { return m_client_pid; }
190
191 virtual void die() = 0;
192
193protected:
194 void event(Core::Event& event) override
195 {
196 if (event.type() == Event::Disconnected) {
197 int client_id = static_cast<const DisconnectedEvent&>(event).client_id();
198 dbg() << *this << ": Client disconnected: " << client_id;
199 die();
200 return;
201 }
202
203 Core::Object::event(event);
204 }
205
206private:
207 Endpoint& m_endpoint;
208 RefPtr<Core::LocalSocket> m_socket;
209 int m_client_id { -1 };
210 int m_client_pid { -1 };
211};
212
213}