"""SSU2 peer session state after handshake establishment. Ported from net.i2p.router.transport.udp.PeerState2. Manages the post-handshake data phase: - Cipher states for encrypt/decrypt - Packet numbering - ACK tracking (which packets we've received, which we've sent that are acked) - Retransmission of unacked packets - Connection migration (path challenge/response) """ from __future__ import annotations import os import struct import time from collections.abc import Sequence from dataclasses import dataclass, field from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 from i2p_transport.ssu2_handshake import HandshakeKeys, SHORT_HEADER_SIZE, PKT_DATA from i2p_transport.ssu2_bitfield import SSU2Bitfield from i2p_transport.ssu2_payload import ( AckBlock, PathChallengeBlock, PathResponseBlock, PaddingBlock, TerminationBlock, SSU2PayloadBlock, build_payload, parse_payload, ) from i2p_transport.ssu2 import SSU2HeaderProtection MAC_LEN = 16 # ChaCha20-Poly1305 auth tag length @dataclass class SentPacket: """Tracks a sent packet for retransmission.""" packet_num: int data: bytes sent_at: float retransmit_count: int = 0 acked: bool = False class SSU2Connection: """Post-handshake SSU2 session state. After Noise_XK completes, both sides have symmetric cipher states. Data packets use short headers (16 bytes) + encrypted payload. """ def __init__(self, keys: HandshakeKeys, src_conn_id: int, dest_conn_id: int, remote_address: tuple[str, int], is_initiator: bool): # Connection IDs self.src_conn_id = src_conn_id self.dest_conn_id = dest_conn_id self.remote_address = remote_address self.is_initiator = is_initiator # Cipher states self._send_cipher = ChaCha20Poly1305(keys.send_cipher_key) self._recv_cipher = ChaCha20Poly1305(keys.recv_cipher_key) self._send_header_key = keys.send_header_key self._recv_header_key = keys.recv_header_key # Header protection uses send/recv header keys # For encrypting outbound: we protect with our send header key # For decrypting inbound: peer protected with their send header key = our recv header key self._send_header_prot = SSU2HeaderProtection( keys.send_header_key, keys.send_header_key) self._recv_header_prot = SSU2HeaderProtection( keys.recv_header_key, keys.recv_header_key) # Packet numbering self._next_send_num: int = 0 self._next_expected_recv: int = 0 # ACK tracking self._recv_bitfield = SSU2Bitfield() self._sent_packets: dict[int, SentPacket] = {} self._acked_through: int = -1 # Timing self.created_at = time.monotonic() self.last_send_time: float = 0.0 self.last_recv_time: float = 0.0 # State self._closed = False self._valid_frames_received: int = 0 def _build_short_header(self, pkt_num: int) -> bytes: """Build a 16-byte short header for a DATA packet. Layout: dest_conn_id(8) | pkt_num(4) | type(1) | flags(3) """ flags_bytes = b"\x00\x00\x00" return struct.pack("!QIB", self.dest_conn_id, pkt_num, PKT_DATA) + flags_bytes def _build_nonce(self, packet_num: int) -> bytes: """Build 12-byte nonce from packet number (big-endian, zero-padded).""" return b"\x00" * 4 + struct.pack("!Q", packet_num) def encrypt_data_packet(self, payload_blocks: Sequence[SSU2PayloadBlock]) -> bytes: """Build and encrypt a data packet. 1. Allocate next packet number 2. Build short header (16 bytes) 3. Serialize payload blocks 4. Encrypt payload with ChaCha20-Poly1305 (nonce = packet number) 5. Apply header protection Returns complete packet bytes. """ if self._next_send_num >= 0xFFFFFFFF: raise RuntimeError("SSU2 packet number exhausted — connection must be replaced") pkt_num = self._next_send_num self._next_send_num += 1 header = self._build_short_header(pkt_num) plaintext = build_payload(payload_blocks) nonce = self._build_nonce(pkt_num) # Encrypt with associated data = header ciphertext = self._send_cipher.encrypt(nonce, plaintext, header) packet = bytearray(header + ciphertext) # Apply header protection (XOR first 8 bytes using keystream from ciphertext) if len(ciphertext) >= 12: self._send_header_prot.encrypt_short_header(packet) result = bytes(packet) # Track for retransmission now = time.monotonic() self._sent_packets[pkt_num] = SentPacket( packet_num=pkt_num, data=result, sent_at=now) self.last_send_time = now return result def decrypt_data_packet(self, packet: bytes) -> tuple[int, list[SSU2PayloadBlock]]: """Decrypt a received data packet. 1. Remove header protection 2. Parse short header (extract packet number) 3. Decrypt payload 4. Parse payload blocks 5. Update recv bitfield Returns (packet_number, list_of_blocks). """ buf = bytearray(packet) # Remove header protection if len(buf) > SHORT_HEADER_SIZE + 12: self._recv_header_prot.decrypt_short_header(buf) # Parse short header dest_conn_id, pkt_num, pkt_type = struct.unpack("!QIB", buf[:13]) # bytes 13-15 are flags, ignored for now header = bytes(buf[:SHORT_HEADER_SIZE]) ciphertext = bytes(buf[SHORT_HEADER_SIZE:]) # Decrypt nonce = self._build_nonce(pkt_num) plaintext = self._recv_cipher.decrypt(nonce, ciphertext, header) # Parse payload blocks blocks = parse_payload(plaintext) # Update recv tracking self._recv_bitfield.set(pkt_num) self._valid_frames_received += 1 self.last_recv_time = time.monotonic() return pkt_num, blocks def build_ack_block(self) -> AckBlock: """Build an ACK block reflecting what we've received.""" highest = self._recv_bitfield.get_highest() if highest < 0: return AckBlock(ack_through=0, ack_count=0, ranges=[]) ranges = self._recv_bitfield.to_ack_blocks() return AckBlock( ack_through=highest, ack_count=len(ranges), ranges=ranges, ) def process_ack(self, ack_block: AckBlock) -> list[int]: """Process received ACK block. Returns list of newly acked packet numbers. Reconstructs the bitfield from the ACK block and marks matching sent packets as acked. """ peer_bf = SSU2Bitfield.from_ack_blocks(ack_block.ack_through, ack_block.ranges) newly_acked: list[int] = [] for pkt_num, sp in self._sent_packets.items(): if not sp.acked and peer_bf.get(pkt_num): sp.acked = True newly_acked.append(pkt_num) if ack_block.ack_through > self._acked_through: self._acked_through = ack_block.ack_through return newly_acked def get_unacked_packets(self, max_age_seconds: float = 5.0) -> list[SentPacket]: """Get packets that need retransmission. Returns sent packets that are not yet acked and are older than max_age_seconds. """ now = time.monotonic() result: list[SentPacket] = [] for sp in self._sent_packets.values(): if not sp.acked and (now - sp.sent_at) >= max_age_seconds: result.append(sp) return result def send_path_challenge(self) -> bytes: """Build a data packet containing a PathChallenge block.""" challenge_data = os.urandom(8) blocks: list[SSU2PayloadBlock] = [ PathChallengeBlock(challenge_data=challenge_data), ] return self.encrypt_data_packet(blocks) def process_path_challenge(self, challenge_data: bytes) -> bytes: """Build PathResponse packet for received challenge.""" blocks: list[SSU2PayloadBlock] = [ PathResponseBlock(response_data=challenge_data), ] return self.encrypt_data_packet(blocks) def close(self, reason: int = 0) -> bytes: """Build termination packet.""" blocks: list[SSU2PayloadBlock] = [ TerminationBlock( reason=reason, valid_frames_received=self._valid_frames_received, ), ] packet = self.encrypt_data_packet(blocks) self._closed = True return packet @property def is_established(self) -> bool: return not self._closed @property def idle_time(self) -> float: """Seconds since last activity.""" last_activity = max(self.last_send_time, self.last_recv_time) if last_activity == 0.0: return time.monotonic() - self.created_at return time.monotonic() - last_activity