# A Python port of the Invisible Internet Project (I2P).
1"""SSU2 peer session state after handshake establishment.
2
3Ported from net.i2p.router.transport.udp.PeerState2.
4
5Manages the post-handshake data phase:
6- Cipher states for encrypt/decrypt
7- Packet numbering
8- ACK tracking (which packets we've received, which we've sent that are acked)
9- Retransmission of unacked packets
10- Connection migration (path challenge/response)
11"""
12
13from __future__ import annotations
14
15import os
16import struct
17import time
18from collections.abc import Sequence
19from dataclasses import dataclass, field
20
21from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
22
23from i2p_transport.ssu2_handshake import HandshakeKeys, SHORT_HEADER_SIZE, PKT_DATA
24from i2p_transport.ssu2_bitfield import SSU2Bitfield
25from i2p_transport.ssu2_payload import (
26 AckBlock,
27 PathChallengeBlock,
28 PathResponseBlock,
29 PaddingBlock,
30 TerminationBlock,
31 SSU2PayloadBlock,
32 build_payload,
33 parse_payload,
34)
35from i2p_transport.ssu2 import SSU2HeaderProtection
36
# Length in bytes of the Poly1305 authentication tag that ChaCha20-Poly1305
# appends to every ciphertext.
MAC_LEN: int = 16
38
39
@dataclass
class SentPacket:
    """Bookkeeping record for one outbound packet, kept until it is acknowledged."""

    # Packet number allocated to this packet by the sender.
    packet_num: int
    # Complete packet bytes as produced for the wire.
    data: bytes
    # Timestamp taken when the packet was built (time.monotonic() in this module).
    sent_at: float
    # Retransmission counter; starts at zero.
    retransmit_count: int = field(default=0)
    # Becomes True once the peer has acknowledged this packet number.
    acked: bool = field(default=False)
48
49
class SSU2Connection:
    """Post-handshake SSU2 session state.

    After Noise_XK completes, both sides hold symmetric cipher states.
    Each data packet is a 16-byte short header followed by the
    ChaCha20-Poly1305 encrypted payload (ciphertext plus MAC_LEN-byte tag).

    The object tracks outbound packet numbers, which inbound packet numbers
    have been seen, and which sent packets are still awaiting an ACK.
    No locking is performed.
    """

    def __init__(self, keys: HandshakeKeys, src_conn_id: int, dest_conn_id: int,
                 remote_address: tuple[str, int], is_initiator: bool):
        """Create data-phase state from completed handshake keys.

        Args:
            keys: Handshake output. ``send_cipher_key``, ``recv_cipher_key``,
                ``send_header_key`` and ``recv_header_key`` are read.
            src_conn_id: Source connection ID (stored on the instance).
            dest_conn_id: Destination connection ID; written into every
                outbound short header.
            remote_address: Peer ``(host, port)`` tuple.
            is_initiator: Whether this side initiated the handshake.
        """
        # Connection IDs
        self.src_conn_id = src_conn_id
        self.dest_conn_id = dest_conn_id
        self.remote_address = remote_address
        self.is_initiator = is_initiator

        # Payload AEAD ciphers, one per direction.
        self._send_cipher = ChaCha20Poly1305(keys.send_cipher_key)
        self._recv_cipher = ChaCha20Poly1305(keys.recv_cipher_key)
        self._send_header_key = keys.send_header_key
        self._recv_header_key = keys.recv_header_key

        # Header protection. Outbound headers are protected with our send
        # header key. Inbound headers were protected by the peer's send key,
        # which is our recv header key.
        # NOTE(review): the same key is passed for both SSU2HeaderProtection
        # arguments. The SSU2 spec derives data-phase k_header_1 from the
        # receiver's intro key, and only k_header_2 comes from the data-phase
        # KDF. Confirm against SSU2HeaderProtection / HandshakeKeys.
        self._send_header_prot = SSU2HeaderProtection(
            keys.send_header_key, keys.send_header_key)
        self._recv_header_prot = SSU2HeaderProtection(
            keys.recv_header_key, keys.recv_header_key)

        # Packet numbering
        self._next_send_num: int = 0
        self._next_expected_recv: int = 0  # not read by this class

        # ACK tracking. _sent_packets holds only packets not yet acked;
        # process_ack() removes entries once the peer acknowledges them.
        self._recv_bitfield = SSU2Bitfield()
        self._sent_packets: dict[int, SentPacket] = {}
        self._acked_through: int = -1

        # Timing (time.monotonic() seconds; 0.0 means "never").
        self.created_at = time.monotonic()
        self.last_send_time: float = 0.0
        self.last_recv_time: float = 0.0

        # State
        self._closed = False
        self._valid_frames_received: int = 0

    def _build_short_header(self, pkt_num: int) -> bytes:
        """Build a 16-byte short header for a DATA packet.

        Layout: dest_conn_id(8) | pkt_num(4) | type(1) | flags(3),
        integers in network (big-endian) byte order. Flags are all zero.
        """
        flags_bytes = b"\x00\x00\x00"
        return struct.pack("!QIB", self.dest_conn_id, pkt_num, PKT_DATA) + flags_bytes

    def _build_nonce(self, packet_num: int) -> bytes:
        """Build the 12-byte ChaCha20-Poly1305 nonce for ``packet_num``.

        Uses the Noise ChaChaPoly nonce encoding that SSU2 specifies:
        32 bits of zeros followed by the counter as a 64-bit
        *little-endian* integer. A big-endian encoding would still
        round-trip between two instances of this class, but no
        spec-conformant peer could decrypt the packets.
        """
        return struct.pack("<4xQ", packet_num)

    def encrypt_data_packet(self, payload_blocks: Sequence[SSU2PayloadBlock]) -> bytes:
        """Build and encrypt a data packet.

        Steps:
        1. Allocate the next packet number.
        2. Build the 16-byte short header.
        3. Serialize the payload blocks.
        4. Encrypt with ChaCha20-Poly1305. The nonce is derived from the
           packet number and the unprotected header is the associated data.
        5. Apply header protection.

        The finished packet is recorded in the retransmission table and
        ``last_send_time`` is updated.

        Returns:
            The complete packet bytes.

        Raises:
            RuntimeError: if the 32-bit packet-number space is exhausted.
        """
        if self._next_send_num >= 0xFFFFFFFF:
            raise RuntimeError("SSU2 packet number exhausted — connection must be replaced")
        pkt_num = self._next_send_num
        self._next_send_num += 1

        header = self._build_short_header(pkt_num)
        plaintext = build_payload(payload_blocks)
        nonce = self._build_nonce(pkt_num)

        # Authenticate the unprotected header as associated data.
        ciphertext = self._send_cipher.encrypt(nonce, plaintext, header)

        packet = bytearray(header + ciphertext)

        # Header protection is applied in place to the assembled packet. The
        # AEAD output always includes the 16-byte tag, so this guard holds
        # for every packet built here.
        if len(ciphertext) >= 12:
            self._send_header_prot.encrypt_short_header(packet)

        result = bytes(packet)

        # Track for retransmission until the peer ACKs it.
        now = time.monotonic()
        self._sent_packets[pkt_num] = SentPacket(
            packet_num=pkt_num, data=result, sent_at=now)
        self.last_send_time = now

        return result

    def decrypt_data_packet(self, packet: bytes) -> tuple[int, list[SSU2PayloadBlock]]:
        """Decrypt a received data packet.

        Steps:
        1. Remove header protection.
        2. Parse the short header to get the packet number.
        3. Decrypt and authenticate the payload.
        4. Parse the payload blocks.
        5. Record the packet number in the receive bitfield.

        Returns:
            ``(packet_number, list_of_blocks)``.

        Raises:
            struct.error: if the packet is shorter than 13 bytes.
            Whatever ``ChaCha20Poly1305.decrypt`` raises on authentication
            failure (``cryptography.exceptions.InvalidTag``). This includes
            packets too short to carry a tag.
        """
        buf = bytearray(packet)

        # Remove header protection. A packet that can authenticate carries at
        # least SHORT_HEADER_SIZE + MAC_LEN bytes, so this guard only skips
        # packets whose decryption below is bound to fail.
        if len(buf) > SHORT_HEADER_SIZE + 12:
            self._recv_header_prot.decrypt_short_header(buf)

        # Parse the short header. Bytes 13-15 are flags and are ignored for now.
        _dest_conn_id, pkt_num, _pkt_type = struct.unpack("!QIB", buf[:13])

        header = bytes(buf[:SHORT_HEADER_SIZE])
        ciphertext = bytes(buf[SHORT_HEADER_SIZE:])

        # Decrypt and authenticate, with the unprotected header as associated data.
        nonce = self._build_nonce(pkt_num)
        plaintext = self._recv_cipher.decrypt(nonce, ciphertext, header)

        blocks = parse_payload(plaintext)

        # Update receive tracking.
        # NOTE(review): a packet number that was already recorded is not
        # rejected here, so a replayed packet is returned to the caller
        # again. Confirm whether duplicates are filtered elsewhere.
        self._recv_bitfield.set(pkt_num)
        self._valid_frames_received += 1
        self.last_recv_time = time.monotonic()

        return pkt_num, blocks

    def build_ack_block(self) -> AckBlock:
        """Build an ACK block describing the packet numbers received so far.

        If nothing has been received (``get_highest() < 0``), an AckBlock
        with ``ack_through=0`` and no ranges is returned.
        NOTE(review): confirm that the peer does not read this empty form
        as an ACK of packet 0.
        """
        highest = self._recv_bitfield.get_highest()
        if highest < 0:
            return AckBlock(ack_through=0, ack_count=0, ranges=[])

        ranges = self._recv_bitfield.to_ack_blocks()
        return AckBlock(
            ack_through=highest,
            ack_count=len(ranges),
            ranges=ranges,
        )

    def process_ack(self, ack_block: AckBlock) -> list[int]:
        """Apply a received ACK block.

        Rebuilds the peer's view of our packets from the ACK block. Each
        matching pending packet is marked acked and then removed from the
        retransmission table. Without that removal the table would grow
        with every packet sent, and each ACK would rescan the whole
        send history.

        Returns:
            Packet numbers newly acknowledged by this block, in the order
            the packets were sent.
        """
        peer_bf = SSU2Bitfield.from_ack_blocks(ack_block.ack_through, ack_block.ranges)

        newly_acked = [
            pkt_num
            for pkt_num, sp in self._sent_packets.items()
            if not sp.acked and peer_bf.get(pkt_num)
        ]
        # Mutate the dict only after iteration has finished.
        for pkt_num in newly_acked:
            self._sent_packets.pop(pkt_num).acked = True

        if ack_block.ack_through > self._acked_through:
            self._acked_through = ack_block.ack_through

        return newly_acked

    def get_unacked_packets(self, max_age_seconds: float = 5.0) -> list[SentPacket]:
        """Return packets that are due for retransmission.

        A packet qualifies if it is not acked and was sent at least
        ``max_age_seconds`` ago. This method does not update ``sent_at`` or
        ``retransmit_count``; that is left to the caller.
        """
        now = time.monotonic()
        return [
            sp for sp in self._sent_packets.values()
            if not sp.acked and (now - sp.sent_at) >= max_age_seconds
        ]

    def send_path_challenge(self) -> bytes:
        """Build a data packet containing a PathChallenge with 8 random bytes.

        NOTE(review): the challenge bytes are not stored on the instance, so
        a later PathResponse cannot be checked against them here.
        """
        challenge_data = os.urandom(8)
        blocks: list[SSU2PayloadBlock] = [
            PathChallengeBlock(challenge_data=challenge_data),
        ]
        return self.encrypt_data_packet(blocks)

    def process_path_challenge(self, challenge_data: bytes) -> bytes:
        """Build a PathResponse packet that echoes the received challenge data."""
        blocks: list[SSU2PayloadBlock] = [
            PathResponseBlock(response_data=challenge_data),
        ]
        return self.encrypt_data_packet(blocks)

    def close(self, reason: int = 0) -> bytes:
        """Build a termination packet and mark the connection closed.

        The packet reports ``reason`` and the count of valid frames received.
        """
        blocks: list[SSU2PayloadBlock] = [
            TerminationBlock(
                reason=reason,
                valid_frames_received=self._valid_frames_received,
            ),
        ]
        packet = self.encrypt_data_packet(blocks)
        self._closed = True
        return packet

    @property
    def is_established(self) -> bool:
        """True until ``close()`` has been called."""
        return not self._closed

    @property
    def idle_time(self) -> float:
        """Seconds since the last send/receive, or since creation if neither happened."""
        last_activity = max(self.last_send_time, self.last_recv_time)
        if last_activity == 0.0:
            return time.monotonic() - self.created_at
        return time.monotonic() - last_activity