Serenity Operating System
1/*
2 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice, this
9 * list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright notice,
12 * this list of conditions and the following disclaimer in the documentation
13 * and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26
27#include <AK/Time.h>
28#include <Kernel/Devices/RandomDevice.h>
29#include <Kernel/FileSystem/FileDescription.h>
30#include <Kernel/Net/NetworkAdapter.h>
31#include <Kernel/Net/Routing.h>
32#include <Kernel/Net/TCP.h>
33#include <Kernel/Net/TCPSocket.h>
34#include <Kernel/Process.h>
35#include <Kernel/Random.h>
36
37//#define TCP_SOCKET_DEBUG
38
39namespace Kernel {
40
41void TCPSocket::for_each(Function<void(TCPSocket&)> callback)
42{
43 LOCKER(sockets_by_tuple().lock());
44 for (auto& it : sockets_by_tuple().resource())
45 callback(*it.value);
46}
47
48void TCPSocket::set_state(State new_state)
49{
50#ifdef TCP_SOCKET_DEBUG
51 dbg() << "TCPSocket{" << this << "} state moving from " << to_string(m_state) << " to " << to_string(new_state);
52#endif
53
54 m_state = new_state;
55
56 if (new_state == State::Established && m_direction == Direction::Outgoing)
57 m_role = Role::Connected;
58
59 if (new_state == State::Closed) {
60 LOCKER(closing_sockets().lock());
61 closing_sockets().resource().remove(tuple());
62 }
63}
64
65Lockable<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>& TCPSocket::closing_sockets()
66{
67 static Lockable<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>* s_map;
68 if (!s_map)
69 s_map = new Lockable<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>;
70 return *s_map;
71}
72
73Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& TCPSocket::sockets_by_tuple()
74{
75 static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>* s_map;
76 if (!s_map)
77 s_map = new Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>;
78 return *s_map;
79}
80
81RefPtr<TCPSocket> TCPSocket::from_tuple(const IPv4SocketTuple& tuple)
82{
83 LOCKER(sockets_by_tuple().lock());
84
85 auto exact_match = sockets_by_tuple().resource().get(tuple);
86 if (exact_match.has_value())
87 return { *exact_match.value() };
88
89 auto address_tuple = IPv4SocketTuple(tuple.local_address(), tuple.local_port(), IPv4Address(), 0);
90 auto address_match = sockets_by_tuple().resource().get(address_tuple);
91 if (address_match.has_value())
92 return { *address_match.value() };
93
94 auto wildcard_tuple = IPv4SocketTuple(IPv4Address(), tuple.local_port(), IPv4Address(), 0);
95 auto wildcard_match = sockets_by_tuple().resource().get(wildcard_tuple);
96 if (wildcard_match.has_value())
97 return { *wildcard_match.value() };
98
99 return {};
100}
101
102RefPtr<TCPSocket> TCPSocket::from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port)
103{
104 return from_tuple(IPv4SocketTuple(local_address, local_port, peer_address, peer_port));
105}
106
107RefPtr<TCPSocket> TCPSocket::create_client(const IPv4Address& new_local_address, u16 new_local_port, const IPv4Address& new_peer_address, u16 new_peer_port)
108{
109 auto tuple = IPv4SocketTuple(new_local_address, new_local_port, new_peer_address, new_peer_port);
110
111 LOCKER(sockets_by_tuple().lock());
112 if (sockets_by_tuple().resource().contains(tuple))
113 return {};
114
115 auto client = TCPSocket::create(protocol());
116
117 client->set_setup_state(SetupState::InProgress);
118 client->set_local_address(new_local_address);
119 client->set_local_port(new_local_port);
120 client->set_peer_address(new_peer_address);
121 client->set_peer_port(new_peer_port);
122 client->set_direction(Direction::Incoming);
123 client->set_originator(*this);
124
125 m_pending_release_for_accept.set(tuple, client);
126 sockets_by_tuple().resource().set(tuple, client);
127
128 return from_tuple(tuple);
129}
130
131void TCPSocket::release_to_originator()
132{
133 ASSERT(!!m_originator);
134 m_originator->release_for_accept(this);
135}
136
137void TCPSocket::release_for_accept(RefPtr<TCPSocket> socket)
138{
139 ASSERT(m_pending_release_for_accept.contains(socket->tuple()));
140 m_pending_release_for_accept.remove(socket->tuple());
141 queue_connection_from(*socket);
142}
143
144TCPSocket::TCPSocket(int protocol)
145 : IPv4Socket(SOCK_STREAM, protocol)
146{
147}
148
149TCPSocket::~TCPSocket()
150{
151 LOCKER(sockets_by_tuple().lock());
152 sockets_by_tuple().resource().remove(tuple());
153
154#ifdef TCP_SOCKET_DEBUG
155 dbg() << "~TCPSocket in state " << to_string(state());
156#endif
157}
158
159NonnullRefPtr<TCPSocket> TCPSocket::create(int protocol)
160{
161 return adopt(*new TCPSocket(protocol));
162}
163
164int TCPSocket::protocol_receive(const KBuffer& packet_buffer, void* buffer, size_t buffer_size, int flags)
165{
166 (void)flags;
167 auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.data());
168 auto& tcp_packet = *static_cast<const TCPPacket*>(ipv4_packet.payload());
169 size_t payload_size = packet_buffer.size() - sizeof(IPv4Packet) - tcp_packet.header_size();
170#ifdef TCP_SOCKET_DEBUG
171 klog() << "payload_size " << payload_size << ", will it fit in " << buffer_size << "?";
172#endif
173 ASSERT(buffer_size >= payload_size);
174 memcpy(buffer, tcp_packet.payload(), payload_size);
175 return payload_size;
176}
177
178int TCPSocket::protocol_send(const void* data, size_t data_length)
179{
180 send_tcp_packet(TCPFlags::PUSH | TCPFlags::ACK, data, data_length);
181 return data_length;
182}
183
184void TCPSocket::send_tcp_packet(u16 flags, const void* payload, size_t payload_size)
185{
186 auto buffer = ByteBuffer::create_zeroed(sizeof(TCPPacket) + payload_size);
187 auto& tcp_packet = *(TCPPacket*)(buffer.data());
188 ASSERT(local_port());
189 tcp_packet.set_source_port(local_port());
190 tcp_packet.set_destination_port(peer_port());
191 tcp_packet.set_window_size(1024);
192 tcp_packet.set_sequence_number(m_sequence_number);
193 tcp_packet.set_data_offset(sizeof(TCPPacket) / sizeof(u32));
194 tcp_packet.set_flags(flags);
195
196 if (flags & TCPFlags::ACK)
197 tcp_packet.set_ack_number(m_ack_number);
198
199 if (flags & TCPFlags::SYN) {
200 ++m_sequence_number;
201 } else {
202 m_sequence_number += payload_size;
203 }
204
205 memcpy(tcp_packet.payload(), payload, payload_size);
206 tcp_packet.set_checksum(compute_tcp_checksum(local_address(), peer_address(), tcp_packet, payload_size));
207
208 if (tcp_packet.has_syn() || payload_size > 0) {
209 LOCKER(m_not_acked_lock);
210 m_not_acked.append({ m_sequence_number, move(buffer) });
211 send_outgoing_packets();
212 return;
213 }
214
215 auto routing_decision = route_to(peer_address(), local_address(), bound_interface());
216 ASSERT(!routing_decision.is_zero());
217
218 routing_decision.adapter->send_ipv4(
219 routing_decision.next_hop, peer_address(), IPv4Protocol::TCP,
220 buffer.data(), buffer.size(), ttl());
221
222 m_packets_out++;
223 m_bytes_out += buffer.size();
224}
225
226void TCPSocket::send_outgoing_packets()
227{
228 auto routing_decision = route_to(peer_address(), local_address(), bound_interface());
229 ASSERT(!routing_decision.is_zero());
230
231 auto now = kgettimeofday();
232
233 LOCKER(m_not_acked_lock);
234 for (auto& packet : m_not_acked) {
235 timeval diff;
236 timeval_sub(packet.tx_time, now, diff);
237 if (diff.tv_sec == 0 && diff.tv_usec <= 500000)
238 continue;
239 packet.tx_time = now;
240 packet.tx_counter++;
241
242#ifdef TCP_SOCKET_DEBUG
243 auto& tcp_packet = *(TCPPacket*)(packet.buffer.data());
244 klog() << "sending tcp packet from " << local_address().to_string().characters() << ":" << local_port() << " to " << peer_address().to_string().characters() << ":" << peer_port() << " with (" << (tcp_packet.has_syn() ? "SYN " : "") << (tcp_packet.has_ack() ? "ACK " : "") << (tcp_packet.has_fin() ? "FIN " : "") << (tcp_packet.has_rst() ? "RST " : "") << ") seq_no=" << tcp_packet.sequence_number() << ", ack_no=" << tcp_packet.ack_number() << ", tx_counter=" << packet.tx_counter;
245#endif
246 routing_decision.adapter->send_ipv4(
247 routing_decision.next_hop, peer_address(), IPv4Protocol::TCP,
248 packet.buffer.data(), packet.buffer.size(), ttl());
249
250 m_packets_out++;
251 m_bytes_out += packet.buffer.size();
252 }
253}
254
255void TCPSocket::receive_tcp_packet(const TCPPacket& packet, u16 size)
256{
257 if (packet.has_ack()) {
258 u32 ack_number = packet.ack_number();
259
260#ifdef TCP_SOCKET_DEBUG
261 dbg() << "TCPSocket: receive_tcp_packet: " << ack_number;
262#endif
263
264 int removed = 0;
265 LOCKER(m_not_acked_lock);
266 while (!m_not_acked.is_empty()) {
267 auto& packet = m_not_acked.first();
268
269#ifdef TCP_SOCKET_DEBUG
270 dbg() << "TCPSocket: iterate: " << packet.ack_number;
271#endif
272
273 if (packet.ack_number <= ack_number) {
274 m_not_acked.take_first();
275 removed++;
276 } else {
277 break;
278 }
279 }
280
281#ifdef TCP_SOCKET_DEBUG
282 dbg() << "TCPSocket: receive_tcp_packet acknowledged " << removed << " packets";
283#endif
284 }
285
286 m_packets_in++;
287 m_bytes_in += packet.header_size() + size;
288}
289
290NetworkOrdered<u16> TCPSocket::compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket& packet, u16 payload_size)
291{
292 struct [[gnu::packed]] PseudoHeader
293 {
294 IPv4Address source;
295 IPv4Address destination;
296 u8 zero;
297 u8 protocol;
298 NetworkOrdered<u16> payload_size;
299 };
300
301 PseudoHeader pseudo_header { source, destination, 0, (u8)IPv4Protocol::TCP, sizeof(TCPPacket) + payload_size };
302
303 u32 checksum = 0;
304 auto* w = (const NetworkOrdered<u16>*)&pseudo_header;
305 for (size_t i = 0; i < sizeof(pseudo_header) / sizeof(u16); ++i) {
306 checksum += w[i];
307 if (checksum > 0xffff)
308 checksum = (checksum >> 16) + (checksum & 0xffff);
309 }
310 w = (const NetworkOrdered<u16>*)&packet;
311 for (size_t i = 0; i < sizeof(packet) / sizeof(u16); ++i) {
312 checksum += w[i];
313 if (checksum > 0xffff)
314 checksum = (checksum >> 16) + (checksum & 0xffff);
315 }
316 ASSERT(packet.data_offset() * 4 == sizeof(TCPPacket));
317 w = (const NetworkOrdered<u16>*)packet.payload();
318 for (size_t i = 0; i < payload_size / sizeof(u16); ++i) {
319 checksum += w[i];
320 if (checksum > 0xffff)
321 checksum = (checksum >> 16) + (checksum & 0xffff);
322 }
323 if (payload_size & 1) {
324 u16 expanded_byte = ((const u8*)packet.payload())[payload_size - 1] << 8;
325 checksum += expanded_byte;
326 if (checksum > 0xffff)
327 checksum = (checksum >> 16) + (checksum & 0xffff);
328 }
329 return ~(checksum & 0xffff);
330}
331
332KResult TCPSocket::protocol_bind()
333{
334 if (has_specific_local_address() && !m_adapter) {
335 m_adapter = NetworkAdapter::from_ipv4_address(local_address());
336 if (!m_adapter)
337 return KResult(-EADDRNOTAVAIL);
338 }
339
340 return KSuccess;
341}
342
343KResult TCPSocket::protocol_listen()
344{
345 LOCKER(sockets_by_tuple().lock());
346 if (sockets_by_tuple().resource().contains(tuple()))
347 return KResult(-EADDRINUSE);
348 sockets_by_tuple().resource().set(tuple(), this);
349 set_direction(Direction::Passive);
350 set_state(State::Listen);
351 set_setup_state(SetupState::Completed);
352 return KSuccess;
353}
354
355KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock should_block)
356{
357 auto routing_decision = route_to(peer_address(), local_address());
358 if (routing_decision.is_zero())
359 return KResult(-EHOSTUNREACH);
360 if (!has_specific_local_address())
361 set_local_address(routing_decision.adapter->ipv4_address());
362
363 allocate_local_port_if_needed();
364
365 m_sequence_number = get_good_random<u32>();
366 m_ack_number = 0;
367
368 set_setup_state(SetupState::InProgress);
369 send_tcp_packet(TCPFlags::SYN);
370 m_state = State::SynSent;
371 m_role = Role::Connecting;
372 m_direction = Direction::Outgoing;
373
374 if (should_block == ShouldBlock::Yes) {
375 if (Thread::current->block<Thread::ConnectBlocker>(description) != Thread::BlockResult::WokeNormally)
376 return KResult(-EINTR);
377 ASSERT(setup_state() == SetupState::Completed);
378 if (has_error()) {
379 m_role = Role::None;
380 return KResult(-ECONNREFUSED);
381 }
382 return KSuccess;
383 }
384
385 return KResult(-EINPROGRESS);
386}
387
388int TCPSocket::protocol_allocate_local_port()
389{
390 static const u16 first_ephemeral_port = 32768;
391 static const u16 last_ephemeral_port = 60999;
392 static const u16 ephemeral_port_range_size = last_ephemeral_port - first_ephemeral_port;
393 u16 first_scan_port = first_ephemeral_port + get_good_random<u16>() % ephemeral_port_range_size;
394
395 LOCKER(sockets_by_tuple().lock());
396 for (u16 port = first_scan_port;;) {
397 IPv4SocketTuple proposed_tuple(local_address(), port, peer_address(), peer_port());
398
399 auto it = sockets_by_tuple().resource().find(proposed_tuple);
400 if (it == sockets_by_tuple().resource().end()) {
401 set_local_port(port);
402 sockets_by_tuple().resource().set(proposed_tuple, this);
403 return port;
404 }
405 ++port;
406 if (port > last_ephemeral_port)
407 port = first_ephemeral_port;
408 if (port == first_scan_port)
409 break;
410 }
411 return -EADDRINUSE;
412}
413
414bool TCPSocket::protocol_is_disconnected() const
415{
416 switch (m_state) {
417 case State::Closed:
418 case State::CloseWait:
419 case State::LastAck:
420 case State::FinWait1:
421 case State::FinWait2:
422 case State::Closing:
423 case State::TimeWait:
424 return true;
425 default:
426 return false;
427 }
428}
429
430void TCPSocket::shut_down_for_writing()
431{
432 if (state() == State::Established) {
433#ifdef TCP_SOCKET_DEBUG
434 dbg() << " Sending FIN/ACK from Established and moving into FinWait1";
435#endif
436 send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
437 set_state(State::FinWait1);
438 } else {
439 dbg() << " Shutting down TCPSocket for writing but not moving to FinWait1 since state is " << to_string(state());
440 }
441}
442
443void TCPSocket::close()
444{
445 IPv4Socket::close();
446 if (state() == State::CloseWait) {
447#ifdef TCP_SOCKET_DEBUG
448 dbg() << " Sending FIN from CloseWait and moving into LastAck";
449#endif
450 send_tcp_packet(TCPFlags::FIN | TCPFlags::ACK);
451 set_state(State::LastAck);
452 }
453
454 LOCKER(closing_sockets().lock());
455 closing_sockets().resource().set(tuple(), *this);
456}
457
458}