Serenity Operating System
1/*
2 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice, this
9 * list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright notice,
12 * this list of conditions and the following disclaimer in the documentation
13 * and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26
27#pragma once
28
29#include <AK/Function.h>
30#include <AK/HashMap.h>
31#include <AK/SinglyLinkedList.h>
32#include <AK/WeakPtr.h>
33#include <Kernel/Net/IPv4Socket.h>
34
35namespace Kernel {
36
37class TCPSocket final : public IPv4Socket
38 , public Weakable<TCPSocket> {
39public:
40 static void for_each(Function<void(TCPSocket&)>);
41 static NonnullRefPtr<TCPSocket> create(int protocol);
42 virtual ~TCPSocket() override;
43
44 enum class Direction {
45 Unspecified,
46 Outgoing,
47 Incoming,
48 Passive,
49 };
50
51 static const char* to_string(Direction direction)
52 {
53 switch (direction) {
54 case Direction::Unspecified:
55 return "Unspecified";
56 case Direction::Outgoing:
57 return "Outgoing";
58 case Direction::Incoming:
59 return "Incoming";
60 case Direction::Passive:
61 return "Passive";
62 default:
63 return "None";
64 }
65 }
66
67 enum class State {
68 Closed,
69 Listen,
70 SynSent,
71 SynReceived,
72 Established,
73 CloseWait,
74 LastAck,
75 FinWait1,
76 FinWait2,
77 Closing,
78 TimeWait,
79 };
80
81 static const char* to_string(State state)
82 {
83 switch (state) {
84 case State::Closed:
85 return "Closed";
86 case State::Listen:
87 return "Listen";
88 case State::SynSent:
89 return "SynSent";
90 case State::SynReceived:
91 return "SynReceived";
92 case State::Established:
93 return "Established";
94 case State::CloseWait:
95 return "CloseWait";
96 case State::LastAck:
97 return "LastAck";
98 case State::FinWait1:
99 return "FinWait1";
100 case State::FinWait2:
101 return "FinWait2";
102 case State::Closing:
103 return "Closing";
104 case State::TimeWait:
105 return "TimeWait";
106 default:
107 return "None";
108 }
109 }
110
111 enum class Error {
112 None,
113 FINDuringConnect,
114 RSTDuringConnect,
115 UnexpectedFlagsDuringConnect,
116 };
117
118 static const char* to_string(Error error)
119 {
120 switch (error) {
121 case Error::None:
122 return "None";
123 case Error::FINDuringConnect:
124 return "FINDuringConnect";
125 case Error::RSTDuringConnect:
126 return "RSTDuringConnect";
127 case Error::UnexpectedFlagsDuringConnect:
128 return "UnexpectedFlagsDuringConnect";
129 default:
130 return "Invalid";
131 }
132 }
133
134 State state() const { return m_state; }
135 void set_state(State state);
136
137 Direction direction() const { return m_direction; }
138
139 bool has_error() const { return m_error != Error::None; }
140 Error error() const { return m_error; }
141 void set_error(Error error) { m_error = error; }
142
143 void set_ack_number(u32 n) { m_ack_number = n; }
144 void set_sequence_number(u32 n) { m_sequence_number = n; }
145 u32 ack_number() const { return m_ack_number; }
146 u32 sequence_number() const { return m_sequence_number; }
147 u32 packets_in() const { return m_packets_in; }
148 u32 bytes_in() const { return m_bytes_in; }
149 u32 packets_out() const { return m_packets_out; }
150 u32 bytes_out() const { return m_bytes_out; }
151
152 void send_tcp_packet(u16 flags, const void* = nullptr, size_t = 0);
153 void send_outgoing_packets();
154 void receive_tcp_packet(const TCPPacket&, u16 size);
155
156 static Lockable<HashMap<IPv4SocketTuple, TCPSocket*>>& sockets_by_tuple();
157 static RefPtr<TCPSocket> from_tuple(const IPv4SocketTuple& tuple);
158 static RefPtr<TCPSocket> from_endpoints(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port);
159
160 static Lockable<HashMap<IPv4SocketTuple, RefPtr<TCPSocket>>>& closing_sockets();
161
162 RefPtr<TCPSocket> create_client(const IPv4Address& local_address, u16 local_port, const IPv4Address& peer_address, u16 peer_port);
163 void set_originator(TCPSocket& originator) { m_originator = originator.make_weak_ptr(); }
164 bool has_originator() { return !!m_originator; }
165 void release_to_originator();
166 void release_for_accept(RefPtr<TCPSocket>);
167
168 virtual void close() override;
169
170protected:
171 void set_direction(Direction direction) { m_direction = direction; }
172
173private:
174 explicit TCPSocket(int protocol);
175 virtual const char* class_name() const override { return "TCPSocket"; }
176
177 static NetworkOrdered<u16> compute_tcp_checksum(const IPv4Address& source, const IPv4Address& destination, const TCPPacket&, u16 payload_size);
178
179 virtual void shut_down_for_writing() override;
180
181 virtual int protocol_receive(const KBuffer&, void* buffer, size_t buffer_size, int flags) override;
182 virtual int protocol_send(const void*, size_t) override;
183 virtual KResult protocol_connect(FileDescription&, ShouldBlock) override;
184 virtual int protocol_allocate_local_port() override;
185 virtual bool protocol_is_disconnected() const override;
186 virtual KResult protocol_bind() override;
187 virtual KResult protocol_listen() override;
188
189 WeakPtr<TCPSocket> m_originator;
190 HashMap<IPv4SocketTuple, NonnullRefPtr<TCPSocket>> m_pending_release_for_accept;
191 Direction m_direction { Direction::Unspecified };
192 Error m_error { Error::None };
193 RefPtr<NetworkAdapter> m_adapter;
194 u32 m_sequence_number { 0 };
195 u32 m_ack_number { 0 };
196 State m_state { State::Closed };
197 u32 m_packets_in { 0 };
198 u32 m_bytes_in { 0 };
199 u32 m_packets_out { 0 };
200 u32 m_bytes_out { 0 };
201
202 struct OutgoingPacket {
203 u32 ack_number { 0 };
204 ByteBuffer buffer;
205 int tx_counter { 0 };
206 timeval tx_time { 0, 0 };
207 };
208
209 Lock m_not_acked_lock { "TCPSocket unacked packets" };
210 SinglyLinkedList<OutgoingPacket> m_not_acked;
211};
212
213}