Serenity Operating System
/*
 * Copyright (c) 2021, sin-ack <sin-ack@protonmail.com>
 * Copyright (c) 2022, the SerenityOS developers.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 */
7
8#pragma once
9
10#include <AK/BufferedStream.h>
11#include <AK/Function.h>
12#include <AK/Stream.h>
13#include <AK/Time.h>
14#include <LibCore/Notifier.h>
15#include <LibCore/SocketAddress.h>
16
17namespace Core {
18
19/// The Socket class is the base class for all concrete BSD-style socket
20/// classes. Sockets are non-seekable streams which can be read byte-wise.
21class Socket : public Stream {
22public:
23 Socket(Socket&&) = default;
24 Socket& operator=(Socket&&) = default;
25
26 /// Checks how many bytes of data are currently available to read on the
27 /// socket. For datagram-based socket, this is the size of the first
28 /// datagram that can be read. Returns either the amount of bytes, or an
29 /// errno in the case of failure.
30 virtual ErrorOr<size_t> pending_bytes() const = 0;
31 /// Returns whether there's any data that can be immediately read, or an
32 /// errno on failure.
33 virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const = 0;
34 // Sets the blocking mode of the socket. If blocking mode is disabled, reads
35 // will fail with EAGAIN when there's no data available to read, and writes
36 // will fail with EAGAIN when the data cannot be written without blocking
37 // (due to the send buffer being full, for example).
38 virtual ErrorOr<void> set_blocking(bool enabled) = 0;
39 // Sets the close-on-exec mode of the socket. If close-on-exec mode is
40 // enabled, then the socket will be automatically closed by the kernel when
41 // an exec call happens.
42 virtual ErrorOr<void> set_close_on_exec(bool enabled) = 0;
43
44 /// Disables any listening mechanisms that this socket uses.
45 /// Can be called with 'false' when `on_ready_to_read` notifications are no longer needed.
46 /// Conversely, set_notifications_enabled(true) will re-enable notifications.
47 virtual void set_notifications_enabled(bool) { }
48
49 Function<void()> on_ready_to_read;
50
51 enum class PreventSIGPIPE {
52 No,
53 Yes,
54 };
55
56protected:
57 enum class SocketDomain {
58 Local,
59 Inet,
60 };
61
62 enum class SocketType {
63 Stream,
64 Datagram,
65 };
66
67 Socket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No)
68 : m_prevent_sigpipe(prevent_sigpipe == PreventSIGPIPE::Yes)
69 {
70 }
71
72 static ErrorOr<int> create_fd(SocketDomain, SocketType);
73 // FIXME: This will need to be updated when IPv6 socket arrives. Perhaps a
74 // base class for all address types is appropriate.
75 static ErrorOr<IPv4Address> resolve_host(DeprecatedString const&, SocketType);
76
77 static ErrorOr<void> connect_local(int fd, DeprecatedString const& path);
78 static ErrorOr<void> connect_inet(int fd, SocketAddress const&);
79
80 int default_flags() const
81 {
82 int flags = 0;
83 if (m_prevent_sigpipe)
84 flags |= MSG_NOSIGNAL;
85 return flags;
86 }
87
88private:
89 bool m_prevent_sigpipe { false };
90};
91
92/// A reusable socket maintains state about being connected in addition to
93/// normal Socket capabilities, and can be reconnected once disconnected.
94class ReusableSocket : public Socket {
95public:
96 /// Returns whether the socket is currently connected.
97 virtual bool is_connected() = 0;
98 /// Reconnects the socket to the given host and port. Returns EALREADY if
99 /// is_connected() is true.
100 virtual ErrorOr<void> reconnect(DeprecatedString const& host, u16 port) = 0;
101 /// Connects the socket to the given socket address (IP address + port).
102 /// Returns EALREADY is_connected() is true.
103 virtual ErrorOr<void> reconnect(SocketAddress const&) = 0;
104};
105
106class PosixSocketHelper {
107 AK_MAKE_NONCOPYABLE(PosixSocketHelper);
108
109public:
110 template<typename T>
111 PosixSocketHelper(Badge<T>)
112 requires(IsBaseOf<Socket, T>)
113 {
114 }
115
116 PosixSocketHelper(PosixSocketHelper&& other)
117 {
118 operator=(move(other));
119 }
120
121 PosixSocketHelper& operator=(PosixSocketHelper&& other)
122 {
123 m_fd = exchange(other.m_fd, -1);
124 m_last_read_was_eof = exchange(other.m_last_read_was_eof, false);
125 m_notifier = move(other.m_notifier);
126 return *this;
127 }
128
129 int fd() const { return m_fd; }
130 void set_fd(int fd) { m_fd = fd; }
131
132 ErrorOr<Bytes> read(Bytes, int flags);
133 ErrorOr<size_t> write(ReadonlyBytes, int flags);
134
135 bool is_eof() const { return !is_open() || m_last_read_was_eof; }
136 bool is_open() const { return m_fd != -1; }
137 void close();
138
139 ErrorOr<size_t> pending_bytes() const;
140 ErrorOr<bool> can_read_without_blocking(int timeout) const;
141
142 ErrorOr<void> set_blocking(bool enabled);
143 ErrorOr<void> set_close_on_exec(bool enabled);
144 ErrorOr<void> set_receive_timeout(Time timeout);
145
146 void setup_notifier();
147 RefPtr<Core::Notifier> notifier() { return m_notifier; }
148
149private:
150 int m_fd { -1 };
151 bool m_last_read_was_eof { false };
152 RefPtr<Core::Notifier> m_notifier;
153};
154
155class TCPSocket final : public Socket {
156public:
157 static ErrorOr<NonnullOwnPtr<TCPSocket>> connect(DeprecatedString const& host, u16 port);
158 static ErrorOr<NonnullOwnPtr<TCPSocket>> connect(SocketAddress const& address);
159 static ErrorOr<NonnullOwnPtr<TCPSocket>> adopt_fd(int fd);
160
161 TCPSocket(TCPSocket&& other)
162 : Socket(static_cast<Socket&&>(other))
163 , m_helper(move(other.m_helper))
164 {
165 if (is_open())
166 setup_notifier();
167 }
168
169 TCPSocket& operator=(TCPSocket&& other)
170 {
171 Socket::operator=(static_cast<Socket&&>(other));
172 m_helper = move(other.m_helper);
173 if (is_open())
174 setup_notifier();
175
176 return *this;
177 }
178
179 virtual ErrorOr<Bytes> read_some(Bytes buffer) override { return m_helper.read(buffer, default_flags()); }
180 virtual ErrorOr<size_t> write_some(ReadonlyBytes buffer) override { return m_helper.write(buffer, default_flags()); }
181 virtual bool is_eof() const override { return m_helper.is_eof(); }
182 virtual bool is_open() const override { return m_helper.is_open(); };
183 virtual void close() override { m_helper.close(); };
184 virtual ErrorOr<size_t> pending_bytes() const override { return m_helper.pending_bytes(); }
185 virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_helper.can_read_without_blocking(timeout); }
186 virtual void set_notifications_enabled(bool enabled) override
187 {
188 if (auto notifier = m_helper.notifier())
189 notifier->set_enabled(enabled);
190 }
191 ErrorOr<void> set_blocking(bool enabled) override { return m_helper.set_blocking(enabled); }
192 ErrorOr<void> set_close_on_exec(bool enabled) override { return m_helper.set_close_on_exec(enabled); }
193
194 virtual ~TCPSocket() override { close(); }
195
196private:
197 TCPSocket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No)
198 : Socket(prevent_sigpipe)
199 {
200 }
201
202 void setup_notifier()
203 {
204 VERIFY(is_open());
205
206 m_helper.setup_notifier();
207 m_helper.notifier()->on_ready_to_read = [this] {
208 if (on_ready_to_read)
209 on_ready_to_read();
210 };
211 }
212
213 PosixSocketHelper m_helper { Badge<TCPSocket> {} };
214};
215
216class UDPSocket final : public Socket {
217public:
218 static ErrorOr<NonnullOwnPtr<UDPSocket>> connect(DeprecatedString const& host, u16 port, Optional<Time> timeout = {});
219 static ErrorOr<NonnullOwnPtr<UDPSocket>> connect(SocketAddress const& address, Optional<Time> timeout = {});
220
221 UDPSocket(UDPSocket&& other)
222 : Socket(static_cast<Socket&&>(other))
223 , m_helper(move(other.m_helper))
224 {
225 if (is_open())
226 setup_notifier();
227 }
228
229 UDPSocket& operator=(UDPSocket&& other)
230 {
231 Socket::operator=(static_cast<Socket&&>(other));
232 m_helper = move(other.m_helper);
233 if (is_open())
234 setup_notifier();
235
236 return *this;
237 }
238
239 virtual ErrorOr<Bytes> read_some(Bytes buffer) override
240 {
241 auto pending_bytes = TRY(this->pending_bytes());
242 if (pending_bytes > buffer.size()) {
243 // With UDP datagrams, reading a datagram into a buffer that's
244 // smaller than the datagram's size will cause the rest of the
245 // datagram to be discarded. That's not very nice, so let's bail
246 // early, telling the caller that he should allocate a bigger
247 // buffer.
248 return Error::from_errno(EMSGSIZE);
249 }
250
251 return m_helper.read(buffer, default_flags());
252 }
253
254 virtual ErrorOr<size_t> write_some(ReadonlyBytes buffer) override { return m_helper.write(buffer, default_flags()); }
255 virtual bool is_eof() const override { return m_helper.is_eof(); }
256 virtual bool is_open() const override { return m_helper.is_open(); }
257 virtual void close() override { m_helper.close(); }
258 virtual ErrorOr<size_t> pending_bytes() const override { return m_helper.pending_bytes(); }
259 virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_helper.can_read_without_blocking(timeout); }
260 virtual void set_notifications_enabled(bool enabled) override
261 {
262 if (auto notifier = m_helper.notifier())
263 notifier->set_enabled(enabled);
264 }
265 ErrorOr<void> set_blocking(bool enabled) override { return m_helper.set_blocking(enabled); }
266 ErrorOr<void> set_close_on_exec(bool enabled) override { return m_helper.set_close_on_exec(enabled); }
267
268 virtual ~UDPSocket() override { close(); }
269
270private:
271 UDPSocket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No)
272 : Socket(prevent_sigpipe)
273 {
274 }
275
276 void setup_notifier()
277 {
278 VERIFY(is_open());
279
280 m_helper.setup_notifier();
281 m_helper.notifier()->on_ready_to_read = [this] {
282 if (on_ready_to_read)
283 on_ready_to_read();
284 };
285 }
286
287 PosixSocketHelper m_helper { Badge<UDPSocket> {} };
288};
289
290class LocalSocket final : public Socket {
291public:
292 static ErrorOr<NonnullOwnPtr<LocalSocket>> connect(DeprecatedString const& path, PreventSIGPIPE = PreventSIGPIPE::No);
293 static ErrorOr<NonnullOwnPtr<LocalSocket>> adopt_fd(int fd, PreventSIGPIPE = PreventSIGPIPE::No);
294
295 LocalSocket(LocalSocket&& other)
296 : Socket(static_cast<Socket&&>(other))
297 , m_helper(move(other.m_helper))
298 {
299 if (is_open())
300 setup_notifier();
301 }
302
303 LocalSocket& operator=(LocalSocket&& other)
304 {
305 Socket::operator=(static_cast<Socket&&>(other));
306 m_helper = move(other.m_helper);
307 if (is_open())
308 setup_notifier();
309
310 return *this;
311 }
312
313 virtual ErrorOr<Bytes> read_some(Bytes buffer) override { return m_helper.read(buffer, default_flags()); }
314 virtual ErrorOr<size_t> write_some(ReadonlyBytes buffer) override { return m_helper.write(buffer, default_flags()); }
315 virtual bool is_eof() const override { return m_helper.is_eof(); }
316 virtual bool is_open() const override { return m_helper.is_open(); }
317 virtual void close() override { m_helper.close(); }
318 virtual ErrorOr<size_t> pending_bytes() const override { return m_helper.pending_bytes(); }
319 virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_helper.can_read_without_blocking(timeout); }
320 virtual ErrorOr<void> set_blocking(bool enabled) override { return m_helper.set_blocking(enabled); }
321 virtual ErrorOr<void> set_close_on_exec(bool enabled) override { return m_helper.set_close_on_exec(enabled); }
322 virtual void set_notifications_enabled(bool enabled) override
323 {
324 if (auto notifier = m_helper.notifier())
325 notifier->set_enabled(enabled);
326 }
327
328 ErrorOr<int> receive_fd(int flags);
329 ErrorOr<void> send_fd(int fd);
330 ErrorOr<pid_t> peer_pid() const;
331 ErrorOr<Bytes> read_without_waiting(Bytes buffer);
332
333 /// Release the fd associated with this LocalSocket. After the fd is
334 /// released, the socket will be considered "closed" and all operations done
335 /// on it will fail with ENOTCONN. Fails with ENOTCONN if the socket is
336 /// already closed.
337 ErrorOr<int> release_fd();
338
339 Optional<int> fd() const;
340 RefPtr<Core::Notifier> notifier() { return m_helper.notifier(); }
341
342 virtual ~LocalSocket() { close(); }
343
344private:
345 LocalSocket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::No)
346 : Socket(prevent_sigpipe)
347 {
348 }
349
350 void setup_notifier()
351 {
352 VERIFY(is_open());
353
354 m_helper.setup_notifier();
355 m_helper.notifier()->on_ready_to_read = [this] {
356 if (on_ready_to_read)
357 on_ready_to_read();
358 };
359 }
360
361 PosixSocketHelper m_helper { Badge<LocalSocket> {} };
362};
363
364template<typename T>
365concept SocketLike = IsBaseOf<Socket, T>;
366
367class BufferedSocketBase : public Socket {
368public:
369 virtual ErrorOr<StringView> read_line(Bytes buffer) = 0;
370 virtual ErrorOr<Bytes> read_until(Bytes buffer, StringView candidate) = 0;
371 virtual ErrorOr<bool> can_read_line() = 0;
372 virtual size_t buffer_size() const = 0;
373};
374
375template<SocketLike T>
376class BufferedSocket final : public BufferedSocketBase {
377 friend BufferedHelper<T>;
378
379public:
380 static ErrorOr<NonnullOwnPtr<BufferedSocket<T>>> create(NonnullOwnPtr<T> stream, size_t buffer_size = 16384)
381 {
382 return BufferedHelper<T>::template create_buffered<BufferedSocket>(move(stream), buffer_size);
383 }
384
385 BufferedSocket(BufferedSocket&& other)
386 : BufferedSocketBase(static_cast<BufferedSocketBase&&>(other))
387 , m_helper(move(other.m_helper))
388 {
389 setup_notifier();
390 }
391
392 BufferedSocket& operator=(BufferedSocket&& other)
393 {
394 Socket::operator=(static_cast<Socket&&>(other));
395 m_helper = move(other.m_helper);
396
397 setup_notifier();
398 return *this;
399 }
400
401 virtual ErrorOr<Bytes> read_some(Bytes buffer) override { return m_helper.read(move(buffer)); }
402 virtual ErrorOr<size_t> write_some(ReadonlyBytes buffer) override { return m_helper.stream().write_some(buffer); }
403 virtual bool is_eof() const override { return m_helper.is_eof(); }
404 virtual bool is_open() const override { return m_helper.stream().is_open(); }
405 virtual void close() override { m_helper.stream().close(); }
406 virtual ErrorOr<size_t> pending_bytes() const override
407 {
408 return TRY(m_helper.stream().pending_bytes()) + m_helper.buffered_data_size();
409 }
410 virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_helper.buffered_data_size() > 0 || TRY(m_helper.stream().can_read_without_blocking(timeout)); }
411 virtual ErrorOr<void> set_blocking(bool enabled) override { return m_helper.stream().set_blocking(enabled); }
412 virtual ErrorOr<void> set_close_on_exec(bool enabled) override { return m_helper.stream().set_close_on_exec(enabled); }
413 virtual void set_notifications_enabled(bool enabled) override { m_helper.stream().set_notifications_enabled(enabled); }
414
415 virtual ErrorOr<StringView> read_line(Bytes buffer) override { return m_helper.read_line(move(buffer)); }
416 virtual ErrorOr<Bytes> read_until(Bytes buffer, StringView candidate) override { return m_helper.read_until(move(buffer), move(candidate)); }
417 template<size_t N>
418 ErrorOr<Bytes> read_until_any_of(Bytes buffer, Array<StringView, N> candidates) { return m_helper.read_until_any_of(move(buffer), move(candidates)); }
419 virtual ErrorOr<bool> can_read_line() override { return m_helper.can_read_line(); }
420
421 virtual size_t buffer_size() const override { return m_helper.buffer_size(); }
422
423 virtual ~BufferedSocket() override = default;
424
425private:
426 BufferedSocket(NonnullOwnPtr<T> stream, CircularBuffer buffer)
427 : m_helper(Badge<BufferedSocket<T>> {}, move(stream), move(buffer))
428 {
429 setup_notifier();
430 }
431
432 void setup_notifier()
433 {
434 m_helper.stream().on_ready_to_read = [this] {
435 if (on_ready_to_read)
436 on_ready_to_read();
437 };
438 }
439
440 BufferedHelper<T> m_helper;
441};
442
443using BufferedTCPSocket = BufferedSocket<TCPSocket>;
444using BufferedUDPSocket = BufferedSocket<UDPSocket>;
445using BufferedLocalSocket = BufferedSocket<LocalSocket>;
446
447/// A BasicReusableSocket allows one to use one of the base Core::Stream classes
448/// as a ReusableSocket. It does not preserve any connection state or options,
449/// and instead just recreates the stream when reconnecting.
450template<SocketLike T>
451class BasicReusableSocket final : public ReusableSocket {
452public:
453 static ErrorOr<NonnullOwnPtr<BasicReusableSocket<T>>> connect(DeprecatedString const& host, u16 port)
454 {
455 return make<BasicReusableSocket<T>>(TRY(T::connect(host, port)));
456 }
457
458 static ErrorOr<NonnullOwnPtr<BasicReusableSocket<T>>> connect(SocketAddress const& address)
459 {
460 return make<BasicReusableSocket<T>>(TRY(T::connect(address)));
461 }
462
463 virtual bool is_connected() override
464 {
465 return m_socket.is_open();
466 }
467
468 virtual ErrorOr<void> reconnect(DeprecatedString const& host, u16 port) override
469 {
470 if (is_connected())
471 return Error::from_errno(EALREADY);
472
473 m_socket = TRY(T::connect(host, port));
474 return {};
475 }
476
477 virtual ErrorOr<void> reconnect(SocketAddress const& address) override
478 {
479 if (is_connected())
480 return Error::from_errno(EALREADY);
481
482 m_socket = TRY(T::connect(address));
483 return {};
484 }
485
486 virtual ErrorOr<Bytes> read_some(Bytes buffer) override { return m_socket.read(move(buffer)); }
487 virtual ErrorOr<size_t> write_some(ReadonlyBytes buffer) override { return m_socket.write(buffer); }
488 virtual bool is_eof() const override { return m_socket.is_eof(); }
489 virtual bool is_open() const override { return m_socket.is_open(); }
490 virtual void close() override { m_socket.close(); }
491 virtual ErrorOr<size_t> pending_bytes() const override { return m_socket.pending_bytes(); }
492 virtual ErrorOr<bool> can_read_without_blocking(int timeout = 0) const override { return m_socket.can_read_without_blocking(timeout); }
493 virtual ErrorOr<void> set_blocking(bool enabled) override { return m_socket.set_blocking(enabled); }
494 virtual ErrorOr<void> set_close_on_exec(bool enabled) override { return m_socket.set_close_on_exec(enabled); }
495
496private:
497 BasicReusableSocket(NonnullOwnPtr<T> socket)
498 : m_socket(move(socket))
499 {
500 }
501
502 NonnullOwnPtr<T> m_socket;
503};
504
505using ReusableTCPSocket = BasicReusableSocket<TCPSocket>;
506using ReusableUDPSocket = BasicReusableSocket<UDPSocket>;
507
508}