// Serenity Operating System
/*
 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
 *
 * SPDX-License-Identifier: BSD-2-Clause
 */
6
7#pragma once
8
9#include <AK/Error.h>
10#include <AK/Function.h>
11#include <AK/HashMap.h>
12#include <AK/SinglyLinkedList.h>
13#include <Kernel/Library/LockWeakPtr.h>
14#include <Kernel/Locking/MutexProtected.h>
15#include <Kernel/Net/IPv4Socket.h>
16
17namespace Kernel {
18
19class TCPSocket final : public IPv4Socket {
20public:
21 static void for_each(Function<void(TCPSocket const&)>);
22 static ErrorOr<void> try_for_each(Function<ErrorOr<void>(TCPSocket const&)>);
23 static ErrorOr<NonnullRefPtr<TCPSocket>> try_create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
24 virtual ~TCPSocket() override;
25
26 virtual bool unref() const override;
27
28 enum class Direction {
29 Unspecified,
30 Outgoing,
31 Incoming,
32 Passive,
33 };
34
35 static StringView to_string(Direction direction)
36 {
37 switch (direction) {
38 case Direction::Unspecified:
39 return "Unspecified"sv;
40 case Direction::Outgoing:
41 return "Outgoing"sv;
42 case Direction::Incoming:
43 return "Incoming"sv;
44 case Direction::Passive:
45 return "Passive"sv;
46 default:
47 return "None"sv;
48 }
49 }
50
51 enum class State {
52 Closed,
53 Listen,
54 SynSent,
55 SynReceived,
56 Established,
57 CloseWait,
58 LastAck,
59 FinWait1,
60 FinWait2,
61 Closing,
62 TimeWait,
63 };
64
65 static StringView to_string(State state)
66 {
67 switch (state) {
68 case State::Closed:
69 return "Closed"sv;
70 case State::Listen:
71 return "Listen"sv;
72 case State::SynSent:
73 return "SynSent"sv;
74 case State::SynReceived:
75 return "SynReceived"sv;
76 case State::Established:
77 return "Established"sv;
78 case State::CloseWait:
79 return "CloseWait"sv;
80 case State::LastAck:
81 return "LastAck"sv;
82 case State::FinWait1:
83 return "FinWait1"sv;
84 case State::FinWait2:
85 return "FinWait2"sv;
86 case State::Closing:
87 return "Closing"sv;
88 case State::TimeWait:
89 return "TimeWait"sv;
90 default:
91 return "None"sv;
92 }
93 }
94
95 enum class Error {
96 None,
97 FINDuringConnect,
98 RSTDuringConnect,
99 UnexpectedFlagsDuringConnect,
100 RetransmitTimeout,
101 };
102
103 static StringView to_string(Error error)
104 {
105 switch (error) {
106 case Error::None:
107 return "None"sv;
108 case Error::FINDuringConnect:
109 return "FINDuringConnect"sv;
110 case Error::RSTDuringConnect:
111 return "RSTDuringConnect"sv;
112 case Error::UnexpectedFlagsDuringConnect:
113 return "UnexpectedFlagsDuringConnect"sv;
114 default:
115 return "Invalid"sv;
116 }
117 }
118
119 State state() const { return m_state; }
120 void set_state(State state);
121
122 Direction direction() const { return m_direction; }
123
124 bool has_error() const { return m_error != Error::None; }
125 Error error() const { return m_error; }
126 void set_error(Error error) { m_error = error; }
127
128 void set_ack_number(u32 n) { m_ack_number = n; }
129 void set_sequence_number(u32 n) { m_sequence_number = n; }
130 u32 ack_number() const { return m_ack_number; }
131 u32 sequence_number() const { return m_sequence_number; }
132 u32 packets_in() const { return m_packets_in; }
133 u32 bytes_in() const { return m_bytes_in; }
134 u32 packets_out() const { return m_packets_out; }
135 u32 bytes_out() const { return m_bytes_out; }
136
137 // FIXME: Make this configurable?
138 static constexpr u32 maximum_duplicate_acks = 5;
139 void set_duplicate_acks(u32 acks) { m_duplicate_acks = acks; }
140 u32 duplicate_acks() const { return m_duplicate_acks; }
141
142 ErrorOr<void> send_ack(bool allow_duplicate = false);
143 ErrorOr<void> send_tcp_packet(u16 flags, UserOrKernelBuffer const* = nullptr, size_t = 0, RoutingDecision* = nullptr);
144 void receive_tcp_packet(TCPPacket const&, u16 size);
145
146 bool should_delay_next_ack() const;
147
148 static MutexProtected<HashMap<IPv4SocketTuple, TCPSocket*>>& sockets_by_tuple();
149 static RefPtr<TCPSocket> from_tuple(IPv4SocketTuple const& tuple);
150
151 static MutexProtected<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>& closing_sockets();
152
153 ErrorOr<NonnullRefPtr<TCPSocket>> try_create_client(IPv4Address const& local_address, u16 local_port, IPv4Address const& peer_address, u16 peer_port);
154 void set_originator(TCPSocket& originator) { m_originator = originator; }
155 bool has_originator() { return !!m_originator; }
156 void release_to_originator();
157 void release_for_accept(NonnullRefPtr<TCPSocket>);
158
159 void retransmit_packets();
160
161 virtual ErrorOr<void> close() override;
162
163 virtual bool can_write(OpenFileDescription const&, u64) const override;
164
165 static NetworkOrdered<u16> compute_tcp_checksum(IPv4Address const& source, IPv4Address const& destination, TCPPacket const&, u16 payload_size);
166
167protected:
168 void set_direction(Direction direction) { m_direction = direction; }
169
170private:
171 explicit TCPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer, NonnullOwnPtr<KBuffer> scratch_buffer);
172 virtual StringView class_name() const override { return "TCPSocket"sv; }
173
174 virtual void shut_down_for_writing() override;
175
176 virtual ErrorOr<size_t> protocol_receive(ReadonlyBytes raw_ipv4_packet, UserOrKernelBuffer& buffer, size_t buffer_size, int flags) override;
177 virtual ErrorOr<size_t> protocol_send(UserOrKernelBuffer const&, size_t) override;
178 virtual ErrorOr<void> protocol_connect(OpenFileDescription&) override;
179 virtual ErrorOr<u16> protocol_allocate_local_port() override;
180 virtual ErrorOr<size_t> protocol_size(ReadonlyBytes raw_ipv4_packet) override;
181 virtual bool protocol_is_disconnected() const override;
182 virtual ErrorOr<void> protocol_bind() override;
183 virtual ErrorOr<void> protocol_listen(bool did_allocate_port) override;
184
185 void enqueue_for_retransmit();
186 void dequeue_for_retransmit();
187
188 LockWeakPtr<TCPSocket> m_originator;
189 HashMap<IPv4SocketTuple, NonnullRefPtr<TCPSocket>> m_pending_release_for_accept;
190 Direction m_direction { Direction::Unspecified };
191 Error m_error { Error::None };
192 LockRefPtr<NetworkAdapter> m_adapter;
193 u32 m_sequence_number { 0 };
194 u32 m_ack_number { 0 };
195 State m_state { State::Closed };
196 u32 m_packets_in { 0 };
197 u32 m_bytes_in { 0 };
198 u32 m_packets_out { 0 };
199 u32 m_bytes_out { 0 };
200
201 struct OutgoingPacket {
202 u32 ack_number { 0 };
203 LockRefPtr<PacketWithTimestamp> buffer;
204 size_t ipv4_payload_offset;
205 LockWeakPtr<NetworkAdapter> adapter;
206 int tx_counter { 0 };
207 };
208
209 struct UnackedPackets {
210 SinglyLinkedList<OutgoingPacket> packets;
211 size_t size { 0 };
212 };
213
214 MutexProtected<UnackedPackets> m_unacked_packets;
215
216 u32 m_duplicate_acks { 0 };
217
218 u32 m_last_ack_number_sent { 0 };
219 Time m_last_ack_sent_time;
220
221 // FIXME: Make this configurable (sysctl)
222 static constexpr u32 maximum_retransmits = 5;
223 Time m_last_retransmit_time;
224 u32 m_retransmit_attempts { 0 };
225
226 // FIXME: Parse window size TCP option from the peer
227 u32 m_send_window_size { 64 * KiB };
228
229 IntrusiveListNode<TCPSocket> m_retransmit_list_node;
230
231public:
232 using RetransmitList = IntrusiveList<&TCPSocket::m_retransmit_list_node>;
233 static MutexProtected<TCPSocket::RetransmitList>& sockets_for_retransmit();
234};
235
236}