A Python port of the Invisible Internet Project (I2P)
at main 266 lines 9.1 kB view raw
1"""SSU2 peer session state after handshake establishment. 2 3Ported from net.i2p.router.transport.udp.PeerState2. 4 5Manages the post-handshake data phase: 6- Cipher states for encrypt/decrypt 7- Packet numbering 8- ACK tracking (which packets we've received, which we've sent that are acked) 9- Retransmission of unacked packets 10- Connection migration (path challenge/response) 11""" 12 13from __future__ import annotations 14 15import os 16import struct 17import time 18from collections.abc import Sequence 19from dataclasses import dataclass, field 20 21from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 22 23from i2p_transport.ssu2_handshake import HandshakeKeys, SHORT_HEADER_SIZE, PKT_DATA 24from i2p_transport.ssu2_bitfield import SSU2Bitfield 25from i2p_transport.ssu2_payload import ( 26 AckBlock, 27 PathChallengeBlock, 28 PathResponseBlock, 29 PaddingBlock, 30 TerminationBlock, 31 SSU2PayloadBlock, 32 build_payload, 33 parse_payload, 34) 35from i2p_transport.ssu2 import SSU2HeaderProtection 36 37MAC_LEN = 16 # ChaCha20-Poly1305 auth tag length 38 39 40@dataclass 41class SentPacket: 42 """Tracks a sent packet for retransmission.""" 43 packet_num: int 44 data: bytes 45 sent_at: float 46 retransmit_count: int = 0 47 acked: bool = False 48 49 50class SSU2Connection: 51 """Post-handshake SSU2 session state. 52 53 After Noise_XK completes, both sides have symmetric cipher states. 54 Data packets use short headers (16 bytes) + encrypted payload. 55 """ 56 57 def __init__(self, keys: HandshakeKeys, src_conn_id: int, dest_conn_id: int, 58 remote_address: tuple[str, int], is_initiator: bool): 59 # Connection IDs 60 self.src_conn_id = src_conn_id 61 self.dest_conn_id = dest_conn_id 62 self.remote_address = remote_address 63 self.is_initiator = is_initiator 64 65 # Cipher states 66 self._send_cipher = ChaCha20Poly1305(keys.send_cipher_key) 67 self._recv_cipher = ChaCha20Poly1305(keys.recv_cipher_key) 68 self._send_header_key = keys.send_header_key 69 self._recv_header_key = keys.recv_header_key 70 71 # Header protection uses send/recv header keys 72 # For encrypting outbound: we protect with our send header key 73 # For decrypting inbound: peer protected with their send header key = our recv header key 74 self._send_header_prot = SSU2HeaderProtection( 75 keys.send_header_key, keys.send_header_key) 76 self._recv_header_prot = SSU2HeaderProtection( 77 keys.recv_header_key, keys.recv_header_key) 78 79 # Packet numbering 80 self._next_send_num: int = 0 81 self._next_expected_recv: int = 0 82 83 # ACK tracking 84 self._recv_bitfield = SSU2Bitfield() 85 self._sent_packets: dict[int, SentPacket] = {} 86 self._acked_through: int = -1 87 88 # Timing 89 self.created_at = time.monotonic() 90 self.last_send_time: float = 0.0 91 self.last_recv_time: float = 0.0 92 93 # State 94 self._closed = False 95 self._valid_frames_received: int = 0 96 97 def _build_short_header(self, pkt_num: int) -> bytes: 98 """Build a 16-byte short header for a DATA packet. 99 100 Layout: dest_conn_id(8) | pkt_num(4) | type(1) | flags(3) 101 """ 102 flags_bytes = b"\x00\x00\x00" 103 return struct.pack("!QIB", self.dest_conn_id, pkt_num, PKT_DATA) + flags_bytes 104 105 def _build_nonce(self, packet_num: int) -> bytes: 106 """Build 12-byte nonce from packet number (big-endian, zero-padded).""" 107 return b"\x00" * 4 + struct.pack("!Q", packet_num) 108 109 def encrypt_data_packet(self, payload_blocks: Sequence[SSU2PayloadBlock]) -> bytes: 110 """Build and encrypt a data packet. 111 112 1. Allocate next packet number 113 2. Build short header (16 bytes) 114 3. Serialize payload blocks 115 4. Encrypt payload with ChaCha20-Poly1305 (nonce = packet number) 116 5. Apply header protection 117 Returns complete packet bytes. 118 """ 119 if self._next_send_num >= 0xFFFFFFFF: 120 raise RuntimeError("SSU2 packet number exhausted — connection must be replaced") 121 pkt_num = self._next_send_num 122 self._next_send_num += 1 123 124 header = self._build_short_header(pkt_num) 125 plaintext = build_payload(payload_blocks) 126 nonce = self._build_nonce(pkt_num) 127 128 # Encrypt with associated data = header 129 ciphertext = self._send_cipher.encrypt(nonce, plaintext, header) 130 131 packet = bytearray(header + ciphertext) 132 133 # Apply header protection (XOR first 8 bytes using keystream from ciphertext) 134 if len(ciphertext) >= 12: 135 self._send_header_prot.encrypt_short_header(packet) 136 137 result = bytes(packet) 138 139 # Track for retransmission 140 now = time.monotonic() 141 self._sent_packets[pkt_num] = SentPacket( 142 packet_num=pkt_num, data=result, sent_at=now) 143 self.last_send_time = now 144 145 return result 146 147 def decrypt_data_packet(self, packet: bytes) -> tuple[int, list[SSU2PayloadBlock]]: 148 """Decrypt a received data packet. 149 150 1. Remove header protection 151 2. Parse short header (extract packet number) 152 3. Decrypt payload 153 4. Parse payload blocks 154 5. Update recv bitfield 155 Returns (packet_number, list_of_blocks). 156 """ 157 buf = bytearray(packet) 158 159 # Remove header protection 160 if len(buf) > SHORT_HEADER_SIZE + 12: 161 self._recv_header_prot.decrypt_short_header(buf) 162 163 # Parse short header 164 dest_conn_id, pkt_num, pkt_type = struct.unpack("!QIB", buf[:13]) 165 # bytes 13-15 are flags, ignored for now 166 167 header = bytes(buf[:SHORT_HEADER_SIZE]) 168 ciphertext = bytes(buf[SHORT_HEADER_SIZE:]) 169 170 # Decrypt 171 nonce = self._build_nonce(pkt_num) 172 plaintext = self._recv_cipher.decrypt(nonce, ciphertext, header) 173 174 # Parse payload blocks 175 blocks = parse_payload(plaintext) 176 177 # Update recv tracking 178 self._recv_bitfield.set(pkt_num) 179 self._valid_frames_received += 1 180 self.last_recv_time = time.monotonic() 181 182 return pkt_num, blocks 183 184 def build_ack_block(self) -> AckBlock: 185 """Build an ACK block reflecting what we've received.""" 186 highest = self._recv_bitfield.get_highest() 187 if highest < 0: 188 return AckBlock(ack_through=0, ack_count=0, ranges=[]) 189 190 ranges = self._recv_bitfield.to_ack_blocks() 191 return AckBlock( 192 ack_through=highest, 193 ack_count=len(ranges), 194 ranges=ranges, 195 ) 196 197 def process_ack(self, ack_block: AckBlock) -> list[int]: 198 """Process received ACK block. Returns list of newly acked packet numbers. 199 200 Reconstructs the bitfield from the ACK block and marks matching 201 sent packets as acked. 202 """ 203 peer_bf = SSU2Bitfield.from_ack_blocks(ack_block.ack_through, ack_block.ranges) 204 newly_acked: list[int] = [] 205 206 for pkt_num, sp in self._sent_packets.items(): 207 if not sp.acked and peer_bf.get(pkt_num): 208 sp.acked = True 209 newly_acked.append(pkt_num) 210 211 if ack_block.ack_through > self._acked_through: 212 self._acked_through = ack_block.ack_through 213 214 return newly_acked 215 216 def get_unacked_packets(self, max_age_seconds: float = 5.0) -> list[SentPacket]: 217 """Get packets that need retransmission. 218 219 Returns sent packets that are not yet acked and are older than 220 max_age_seconds. 221 """ 222 now = time.monotonic() 223 result: list[SentPacket] = [] 224 for sp in self._sent_packets.values(): 225 if not sp.acked and (now - sp.sent_at) >= max_age_seconds: 226 result.append(sp) 227 return result 228 229 def send_path_challenge(self) -> bytes: 230 """Build a data packet containing a PathChallenge block.""" 231 challenge_data = os.urandom(8) 232 blocks: list[SSU2PayloadBlock] = [ 233 PathChallengeBlock(challenge_data=challenge_data), 234 ] 235 return self.encrypt_data_packet(blocks) 236 237 def process_path_challenge(self, challenge_data: bytes) -> bytes: 238 """Build PathResponse packet for received challenge.""" 239 blocks: list[SSU2PayloadBlock] = [ 240 PathResponseBlock(response_data=challenge_data), 241 ] 242 return self.encrypt_data_packet(blocks) 243 244 def close(self, reason: int = 0) -> bytes: 245 """Build termination packet.""" 246 blocks: list[SSU2PayloadBlock] = [ 247 TerminationBlock( 248 reason=reason, 249 valid_frames_received=self._valid_frames_received, 250 ), 251 ] 252 packet = self.encrypt_data_packet(blocks) 253 self._closed = True 254 return packet 255 256 @property 257 def is_established(self) -> bool: 258 return not self._closed 259 260 @property 261 def idle_time(self) -> float: 262 """Seconds since last activity.""" 263 last_activity = max(self.last_send_time, self.last_recv_time) 264 if last_activity == 0.0: 265 return time.monotonic() - self.created_at 266 return time.monotonic() - last_activity