# A Python port of the Invisible Internet Project (I2P).
1"""Tests for streaming send/receive with retransmits."""
2
3import time
4import pytest
5
6from i2p_streaming.stream_io import (
7 MessageInputStream,
8 MessageOutputStream,
9 StreamSession,
10)
11
12
# ---------------------------------------------------------------------------
# MessageInputStream tests
# ---------------------------------------------------------------------------
16
class TestMessageInputStream:
    """Receive-side tests: in-order reassembly, gap handling and ACK/NACK generation."""

    @staticmethod
    def _stream_with(*pairs):
        """Build a MessageInputStream and deliver each (seq, payload) pair in order."""
        inbound = MessageInputStream()
        for seq_no, payload in pairs:
            inbound.receive_packet(seq_no, payload)
        return inbound

    def test_in_order_delivery(self):
        """Packets 0, 1, 2 delivered in order are all readable in one call."""
        inbound = self._stream_with((0, b"aaa"), (1, b"bbb"), (2, b"ccc"))
        assert inbound.read(9) == b"aaabbbccc"

    def test_out_of_order_gap(self):
        """With seq 1 missing, only seq 0's payload can be read."""
        inbound = self._stream_with((0, b"aaa"), (2, b"ccc"))
        assert inbound.read(9) == b"aaa"
        # Seq 2 sits behind the hole at seq 1, so nothing further is readable.
        assert inbound.readable_bytes() == 0

    def test_gap_filled(self):
        """Once the missing seq 1 arrives, everything through seq 2 is readable."""
        inbound = self._stream_with((0, b"aaa"), (2, b"ccc"), (1, b"bbb"))
        assert inbound.read(9) == b"aaabbbccc"

    def test_generate_acks_no_gaps(self):
        """Contiguous reception acknowledges through the highest seq with no NACKs."""
        inbound = self._stream_with((0, b"a"), (1, b"b"), (2, b"c"))
        highest_acked, missing = inbound.generate_acks()
        assert highest_acked == 2
        assert missing == []

    def test_generate_acks_with_gaps(self):
        """Each hole in the received sequence numbers is reported as a NACK."""
        inbound = self._stream_with((0, b"a"), (2, b"c"), (4, b"e"))
        highest_acked, missing = inbound.generate_acks()
        # The cumulative ACK stops before the first hole; NACK order is not pinned.
        assert highest_acked == 0
        assert sorted(missing) == [1, 3]

    def test_is_complete_after_close(self):
        """Completion requires both the CLOSE packet and draining all buffered data."""
        inbound = MessageInputStream()
        inbound.receive_packet(0, b"a")
        inbound.receive_packet(1, b"b", is_close=True)

        assert not inbound.is_complete()
        inbound.read(2)
        assert inbound.is_complete()

    def test_read_partial(self):
        """read(n) is capped at n bytes and leaves the remainder buffered."""
        inbound = self._stream_with((0, b"hello"), (1, b"world"))
        assert inbound.read(3) == b"hel"
        assert inbound.readable_bytes() == 7

    def test_readable_bytes(self):
        """readable_bytes ignores data stranded behind a gap."""
        inbound = self._stream_with((0, b"aaa"), (2, b"ccc"))  # seq 1 never arrives
        assert inbound.readable_bytes() == 3
99
100
# ---------------------------------------------------------------------------
# MessageOutputStream tests
# ---------------------------------------------------------------------------
104
class TestMessageOutputStream:
    """Tests for the send-side chunking and retransmit tracking.

    ``write(data, max_packet_size=...)`` returns one entry per chunk, indexed
    as ``entry[0]`` = sequence number and ``entry[1]`` = payload.  Packets stay
    pending (``pending_count``) until acknowledged via ``on_ack``.
    """

    @staticmethod
    def _now_ms():
        """Return the current wall-clock time as integer milliseconds.

        This is the ``now_ms`` form passed to ``get_retransmit_packets``.
        NOTE(review): assumes MessageOutputStream timestamps sent packets with
        the same ``time.time()``-based millisecond clock -- confirm in stream_io.
        """
        return int(time.time() * 1000)

    def test_write_splits_into_chunks(self):
        """Writing data larger than max_packet_size splits it."""
        stream = MessageOutputStream()
        data = b"x" * 2500
        packets = stream.write(data, max_packet_size=1024)

        assert len(packets) == 3
        # Consecutive sequence numbers starting at 0.
        assert [p[0] for p in packets] == [0, 1, 2]
        # Two full-size chunks, then the 2500 - 2 * 1024 = 452-byte tail.
        assert [len(p[1]) for p in packets] == [1024, 1024, 452]
        # Reconstructed data matches
        assert b"".join(p[1] for p in packets) == data

    def test_on_ack_removes_packets(self):
        """on_ack removes all packets up through ack_through."""
        stream = MessageOutputStream()
        stream.write(b"x" * 3000, max_packet_size=1024)
        assert stream.pending_count() == 3

        stream.on_ack(ack_through=1)
        assert stream.pending_count() == 1  # only seq 2 remains

    def test_on_ack_with_nacks(self):
        """on_ack with nacks keeps nacked packets pending."""
        stream = MessageOutputStream()
        stream.write(b"x" * 3000, max_packet_size=1024)
        assert stream.pending_count() == 3

        stream.on_ack(ack_through=2, nacks=[1])
        assert stream.pending_count() == 1  # seq 1 still pending

    def test_get_retransmit_packets(self):
        """Packets older than RTO are returned for retransmit."""
        stream = MessageOutputStream()
        stream.write(b"hello", max_packet_size=1024)
        assert stream.pending_count() == 1

        # Sampled AFTER write(), so the packet is at least 2000 ms old as of
        # now_ms -- comfortably past the 1000 ms RTO.
        now_ms = self._now_ms() + 2000
        retransmits = stream.get_retransmit_packets(now_ms, rto_ms=1000)
        assert len(retransmits) == 1
        assert retransmits[0][1] == b"hello"

    def test_retransmit_not_triggered_early(self):
        """Packets within RTO window are not returned."""
        stream = MessageOutputStream()
        # Sample the clock BEFORE write(): the packet's send time cannot be
        # earlier than this, so now_ms is at most 500 ms after it no matter
        # how long write() takes.  Sampling after write() made the simulated
        # age 500 ms *plus* write latency, which could exceed the 1000 ms RTO
        # on a slow or loaded machine and fail spuriously.
        sent_floor_ms = self._now_ms()
        stream.write(b"hello", max_packet_size=1024)

        now_ms = sent_floor_ms + 500
        retransmits = stream.get_retransmit_packets(now_ms, rto_ms=1000)
        assert len(retransmits) == 0

    def test_pending_count_tracks_unacked(self):
        """pending_count reflects outstanding packets."""
        stream = MessageOutputStream()
        assert stream.pending_count() == 0

        stream.write(b"a" * 2048, max_packet_size=1024)
        assert stream.pending_count() == 2

        stream.on_ack(ack_through=0)
        assert stream.pending_count() == 1

        stream.on_ack(ack_through=1)
        assert stream.pending_count() == 0
176
177
# ---------------------------------------------------------------------------
# StreamSession tests
# ---------------------------------------------------------------------------
181
class TestStreamSession:
    """State-machine tests for StreamSession: handshake, data transfer and teardown."""

    def test_connect_transitions_to_syn_sent(self):
        """Calling connect() leaves the session in SYN_SENT and yields (int, bytes)."""
        session = StreamSession(local_id=1)
        seq_no, syn_payload = session.connect()

        assert session.state == "SYN_SENT"
        assert isinstance(seq_no, int)
        assert isinstance(syn_payload, bytes)

    def test_accept_transitions_to_established(self):
        """Calling accept() moves the session straight to ESTABLISHED."""
        session = StreamSession(local_id=2)
        seq_no, syn_ack_payload = session.accept(remote_id=1)

        assert session.state == "ESTABLISHED"
        assert isinstance(seq_no, int)
        assert isinstance(syn_ack_payload, bytes)

    def test_full_syn_handshake(self):
        """An initiator that connects and then sees the SYN-ACK becomes ESTABLISHED."""
        initiator_id, responder_id = 10, 20
        initiator = StreamSession(local_id=initiator_id)
        responder = StreamSession(local_id=responder_id)

        # Step 1: the initiator emits its SYN.
        _seq, _syn = initiator.connect()
        assert initiator.state == "SYN_SENT"

        # Step 2: the responder accepts the initiator's stream.
        _seq, _syn_ack = responder.accept(remote_id=initiator_id)
        assert responder.state == "ESTABLISHED"

        # Step 3: the initiator processes the SYN-ACK.
        initiator.receive_syn_ack(remote_id=responder_id)
        assert initiator.state == "ESTABLISHED"

    def test_send_receive_data(self):
        """Bytes sent after the handshake are readable on the peer."""
        sender_id, receiver_id = 10, 20
        sender = StreamSession(local_id=sender_id)
        receiver = StreamSession(local_id=receiver_id)

        # Handshake: SYN, accept, SYN-ACK.
        sender.connect()
        receiver.accept(remote_id=sender_id)
        sender.receive_syn_ack(remote_id=receiver_id)

        message = b"hello world"
        outbound = sender.send(message)
        assert len(outbound) >= 1

        # Deliver every produced (seq, data) packet to the receiver.
        for seq_no, chunk in outbound:
            receiver.receive(seq_no, chunk)

        assert receiver.read(len(message)) == message

    def test_close_handshake(self):
        """close() enters CLOSE_SENT; the peer's close then finishes in CLOSED."""
        session = StreamSession(local_id=10)
        session.connect()
        session.receive_syn_ack(remote_id=20)
        assert session.state == "ESTABLISHED"

        _seq, close_payload = session.close()
        assert session.state == "CLOSE_SENT"
        assert isinstance(close_payload, bytes)

        session.receive_close()
        assert session.state == "CLOSED"

    def test_send_before_established_raises(self):
        """send() on a session that never connected raises RuntimeError."""
        idle_session = StreamSession(local_id=1)
        with pytest.raises(RuntimeError):
            idle_session.send(b"data")

    def test_local_id_auto_generated(self):
        """Omitting local_id yields a generated id in the unsigned 32-bit range."""
        session = StreamSession()
        assert isinstance(session.local_id, int)
        assert 0 <= session.local_id < 2 ** 32