"""Tests for streaming send/receive with retransmits.""" import time import pytest from i2p_streaming.stream_io import ( MessageInputStream, MessageOutputStream, StreamSession, ) # --------------------------------------------------------------------------- # MessageInputStream tests # --------------------------------------------------------------------------- class TestMessageInputStream: """Tests for the receive-side buffer and ACK generation.""" def test_in_order_delivery(self): """Receive packets 0, 1, 2 in order -- read returns all data.""" stream = MessageInputStream() stream.receive_packet(0, b"aaa") stream.receive_packet(1, b"bbb") stream.receive_packet(2, b"ccc") data = stream.read(9) assert data == b"aaabbbccc" def test_out_of_order_gap(self): """Receive 0, then 2 (skip 1) -- read returns only seq 0 data.""" stream = MessageInputStream() stream.receive_packet(0, b"aaa") stream.receive_packet(2, b"ccc") data = stream.read(9) assert data == b"aaa" assert stream.readable_bytes() == 0 def test_gap_filled(self): """Receive 0, 2, then fill gap with 1 -- read returns all.""" stream = MessageInputStream() stream.receive_packet(0, b"aaa") stream.receive_packet(2, b"ccc") stream.receive_packet(1, b"bbb") data = stream.read(9) assert data == b"aaabbbccc" def test_generate_acks_no_gaps(self): """ACKs with contiguous reception -- no NACKs.""" stream = MessageInputStream() stream.receive_packet(0, b"a") stream.receive_packet(1, b"b") stream.receive_packet(2, b"c") ack_through, nacks = stream.generate_acks() assert ack_through == 2 assert nacks == [] def test_generate_acks_with_gaps(self): """ACKs with gaps produce NACKs for missing sequence numbers.""" stream = MessageInputStream() stream.receive_packet(0, b"a") stream.receive_packet(2, b"c") stream.receive_packet(4, b"e") ack_through, nacks = stream.generate_acks() assert ack_through == 0 assert sorted(nacks) == [1, 3] def test_is_complete_after_close(self): """Stream is complete after CLOSE received and all data consumed.""" stream = MessageInputStream() stream.receive_packet(0, b"a") stream.receive_packet(1, b"b", is_close=True) assert not stream.is_complete() stream.read(2) assert stream.is_complete() def test_read_partial(self): """read(n) returns at most n bytes.""" stream = MessageInputStream() stream.receive_packet(0, b"hello") stream.receive_packet(1, b"world") data = stream.read(3) assert data == b"hel" assert stream.readable_bytes() == 7 def test_readable_bytes(self): """readable_bytes counts only in-order contiguous data.""" stream = MessageInputStream() stream.receive_packet(0, b"aaa") stream.receive_packet(2, b"ccc") # gap at 1 assert stream.readable_bytes() == 3 # --------------------------------------------------------------------------- # MessageOutputStream tests # --------------------------------------------------------------------------- class TestMessageOutputStream: """Tests for the send-side chunking and retransmit tracking.""" def test_write_splits_into_chunks(self): """Writing data larger than max_packet_size splits it.""" stream = MessageOutputStream() data = b"x" * 2500 packets = stream.write(data, max_packet_size=1024) assert len(packets) == 3 assert packets[0][0] == 0 # seq 0 assert packets[1][0] == 1 # seq 1 assert packets[2][0] == 2 # seq 2 assert len(packets[0][1]) == 1024 assert len(packets[1][1]) == 1024 assert len(packets[2][1]) == 452 # Reconstructed data matches assert b"".join(p[1] for p in packets) == data def test_on_ack_removes_packets(self): """on_ack removes all packets up through ack_through.""" stream = MessageOutputStream() stream.write(b"x" * 3000, max_packet_size=1024) assert stream.pending_count() == 3 stream.on_ack(ack_through=1) assert stream.pending_count() == 1 # only seq 2 remains def test_on_ack_with_nacks(self): """on_ack with nacks keeps nacked packets pending.""" stream = MessageOutputStream() stream.write(b"x" * 3000, max_packet_size=1024) assert stream.pending_count() == 3 stream.on_ack(ack_through=2, nacks=[1]) assert stream.pending_count() == 1 # seq 1 still pending def test_get_retransmit_packets(self): """Packets older than RTO are returned for retransmit.""" stream = MessageOutputStream() stream.write(b"hello", max_packet_size=1024) assert stream.pending_count() == 1 # Simulate time passing: query with now_ms far in the future now_ms = int(time.time() * 1000) + 2000 retransmits = stream.get_retransmit_packets(now_ms, rto_ms=1000) assert len(retransmits) == 1 assert retransmits[0][1] == b"hello" def test_retransmit_not_triggered_early(self): """Packets within RTO window are not returned.""" stream = MessageOutputStream() stream.write(b"hello", max_packet_size=1024) now_ms = int(time.time() * 1000) + 500 retransmits = stream.get_retransmit_packets(now_ms, rto_ms=1000) assert len(retransmits) == 0 def test_pending_count_tracks_unacked(self): """pending_count reflects outstanding packets.""" stream = MessageOutputStream() assert stream.pending_count() == 0 stream.write(b"a" * 2048, max_packet_size=1024) assert stream.pending_count() == 2 stream.on_ack(ack_through=0) assert stream.pending_count() == 1 stream.on_ack(ack_through=1) assert stream.pending_count() == 0 # --------------------------------------------------------------------------- # StreamSession tests # --------------------------------------------------------------------------- class TestStreamSession: """Tests for the streaming session state machine.""" def test_connect_transitions_to_syn_sent(self): """connect() moves state to SYN_SENT.""" session = StreamSession(local_id=1) seq, syn_data = session.connect() assert session.state == "SYN_SENT" assert isinstance(seq, int) assert isinstance(syn_data, bytes) def test_accept_transitions_to_established(self): """accept() moves state to ESTABLISHED.""" session = StreamSession(local_id=2) seq, syn_ack_data = session.accept(remote_id=1) assert session.state == "ESTABLISHED" assert isinstance(seq, int) assert isinstance(syn_ack_data, bytes) def test_full_syn_handshake(self): """Initiator connect + receive_syn_ack => ESTABLISHED.""" initiator = StreamSession(local_id=10) responder = StreamSession(local_id=20) # Initiator sends SYN seq, syn_data = initiator.connect() assert initiator.state == "SYN_SENT" # Responder accepts seq2, syn_ack_data = responder.accept(remote_id=10) assert responder.state == "ESTABLISHED" # Initiator receives SYN-ACK initiator.receive_syn_ack(remote_id=20) assert initiator.state == "ESTABLISHED" def test_send_receive_data(self): """Send and receive data after ESTABLISHED.""" sender = StreamSession(local_id=10) receiver = StreamSession(local_id=20) # Establish sender.connect() receiver.accept(remote_id=10) sender.receive_syn_ack(remote_id=20) # Send data packets = sender.send(b"hello world") assert len(packets) >= 1 # Receive data for seq, data in packets: receiver.receive(seq, data) result = receiver.read(11) assert result == b"hello world" def test_close_handshake(self): """Close transitions through CLOSE_SENT -> CLOSED.""" session = StreamSession(local_id=10) session.connect() session.receive_syn_ack(remote_id=20) assert session.state == "ESTABLISHED" seq, close_data = session.close() assert session.state == "CLOSE_SENT" assert isinstance(close_data, bytes) session.receive_close() assert session.state == "CLOSED" def test_send_before_established_raises(self): """Cannot send data before connection is established.""" session = StreamSession(local_id=1) with pytest.raises(RuntimeError): session.send(b"data") def test_local_id_auto_generated(self): """local_id is auto-generated if not provided.""" session = StreamSession() assert isinstance(session.local_id, int) assert 0 <= session.local_id < 2**32