Serenity Operating System
1/*
2 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice, this
9 * list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright notice,
12 * this list of conditions and the following disclaimer in the documentation
13 * and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26
27#pragma once
28
29#include <AK/ByteBuffer.h>
30#include <LibCore/Event.h>
31#include <LibCore/EventLoop.h>
32#include <LibCore/IODevice.h>
33#include <LibCore/LocalSocket.h>
34#include <LibCore/Notifier.h>
35#include <LibCore/Object.h>
36#include <LibIPC/Endpoint.h>
37#include <LibIPC/Message.h>
38#include <errno.h>
39#include <stdio.h>
40#include <sys/socket.h>
41#include <sys/types.h>
42#include <unistd.h>
43
44namespace IPC {
45
46class Event : public Core::Event {
47public:
48 enum Type {
49 Invalid = 2000,
50 Disconnected,
51 };
52 Event() {}
53 explicit Event(Type type)
54 : Core::Event(type)
55 {
56 }
57};
58
59class DisconnectedEvent : public Event {
60public:
61 explicit DisconnectedEvent(int client_id)
62 : Event(Disconnected)
63 , m_client_id(client_id)
64 {
65 }
66
67 int client_id() const { return m_client_id; }
68
69private:
70 int m_client_id { 0 };
71};
72
73template<typename T, class... Args>
74NonnullRefPtr<T> new_client_connection(Args&&... args)
75{
76 return T::construct(forward<Args>(args)...) /* arghs */;
77}
78
79template<typename Endpoint>
80class ClientConnection : public Core::Object {
81public:
82 ClientConnection(Endpoint& endpoint, Core::LocalSocket& socket, int client_id)
83 : m_endpoint(endpoint)
84 , m_socket(socket)
85 , m_client_id(client_id)
86 {
87 ASSERT(socket.is_connected());
88 ucred creds;
89 socklen_t creds_size = sizeof(creds);
90 if (getsockopt(m_socket->fd(), SOL_SOCKET, SO_PEERCRED, &creds, &creds_size) < 0) {
91 ASSERT_NOT_REACHED();
92 }
93 m_client_pid = creds.pid;
94 add_child(socket);
95 m_socket->on_ready_to_read = [this] { drain_messages_from_client(); };
96 }
97
98 virtual ~ClientConnection() override
99 {
100 }
101
102 void post_message(const Message& message)
103 {
104 // NOTE: If this connection is being shut down, but has not yet been destroyed,
105 // the socket will be closed. Don't try to send more messages.
106 if (!m_socket->is_open())
107 return;
108
109 auto buffer = message.encode();
110
111 int nwritten = write(m_socket->fd(), buffer.data(), (size_t)buffer.size());
112 if (nwritten < 0) {
113 switch (errno) {
114 case EPIPE:
115 dbg() << *this << "::post_message: Disconnected from peer";
116 shutdown();
117 return;
118 case EAGAIN:
119 dbg() << *this << "::post_message: Client buffer overflowed.";
120 did_misbehave();
121 return;
122 default:
123 perror("Connection::post_message write");
124 ASSERT_NOT_REACHED();
125 }
126 }
127
128 ASSERT(static_cast<size_t>(nwritten) == buffer.size());
129 }
130
131 void drain_messages_from_client()
132 {
133 Vector<u8> bytes;
134 for (;;) {
135 u8 buffer[4096];
136 ssize_t nread = recv(m_socket->fd(), buffer, sizeof(buffer), MSG_DONTWAIT);
137 if (nread == 0 || (nread == -1 && errno == EAGAIN)) {
138 if (bytes.is_empty()) {
139 Core::EventLoop::current().post_event(*this, make<DisconnectedEvent>(client_id()));
140 return;
141 }
142 break;
143 }
144 if (nread < 0) {
145 perror("recv");
146 ASSERT_NOT_REACHED();
147 }
148 bytes.append(buffer, nread);
149 }
150
151 size_t decoded_bytes = 0;
152 for (size_t index = 0; index < (size_t)bytes.size(); index += decoded_bytes) {
153 auto remaining_bytes = ByteBuffer::wrap(bytes.data() + index, bytes.size() - index);
154 auto message = Endpoint::decode_message(remaining_bytes, decoded_bytes);
155 if (!message) {
156 dbg() << "drain_messages_from_client: Endpoint didn't recognize message";
157 did_misbehave();
158 return;
159 }
160 if (auto response = m_endpoint.handle(*message))
161 post_message(*response);
162 ASSERT(decoded_bytes);
163 }
164 }
165
166 void did_misbehave()
167 {
168 dbg() << *this << " (id=" << m_client_id << ", pid=" << m_client_pid << ") misbehaved, disconnecting.";
169 shutdown();
170 }
171
172 void did_misbehave(const char* message)
173 {
174 dbg() << *this << " (id=" << m_client_id << ", pid=" << m_client_pid << ") misbehaved (" << message << "), disconnecting.";
175 shutdown();
176 }
177
178 void shutdown()
179 {
180 m_socket->close();
181 die();
182 }
183
184 int client_id() const { return m_client_id; }
185 pid_t client_pid() const { return m_client_pid; }
186
187 virtual void die() = 0;
188
189protected:
190 void event(Core::Event& event) override
191 {
192 if (event.type() == Event::Disconnected) {
193 int client_id = static_cast<const DisconnectedEvent&>(event).client_id();
194 dbg() << *this << ": Client disconnected: " << client_id;
195 die();
196 return;
197 }
198
199 Core::Object::event(event);
200 }
201
202private:
203 Endpoint& m_endpoint;
204 RefPtr<Core::LocalSocket> m_socket;
205 int m_client_id { -1 };
206 int m_client_pid { -1 };
207};
208
209}