"""NTCP2 Noise handshake integration.

Wraps the Noise_XK HandshakeState for the NTCP2 3-message handshake, and
provides frame encryption/decryption and message fragmentation after the
handshake completes.
"""

from i2p_crypto.noise import HandshakeState, CipherState
from i2p_crypto.x25519 import X25519DH  # noqa: F401  (not referenced in this module)
from i2p_transport.ntcp2 import NTCP2Frame, FrameType

# Placeholder payload sent by responder/initiator for RouterInfo blocks
_ROUTER_INFO_PLACEHOLDER = b""


class NTCP2Handshake:
    """Wraps HandshakeState('Noise_XK') for the NTCP2 3-message handshake.

    Internal ``_step`` values:
        0 -- fresh; no handshake message produced or consumed yet.
        1 -- initiator has written message 1.
        2 -- responder has read message 1 and written message 2.
        3 -- all three handshake messages handled on this side; ``split()``
             may be called.
    """

    def __init__(
        self,
        our_static: tuple[bytes, bytes],
        peer_static_pub: bytes | None = None,
        initiator: bool = True,
    ):
        """
        Args:
            our_static: (private_key, public_key) tuple.
            peer_static_pub: Remote static public key (needed for initiator).
            initiator: True if we initiate the connection.
        """
        self._initiator = initiator
        self._hs = HandshakeState(
            pattern="Noise_XK",
            initiator=initiator,
            s=our_static,
            rs=peer_static_pub,
        )
        self._step = 0  # tracks which handshake step we are on

    # -- Initiator step 1 --------------------------------------------------

    def create_message_1(self, options: bytes = b"") -> bytes:
        """Initiator sends message 1: write_message(options).

        Args:
            options: Payload carried in message 1.

        Returns:
            The serialized handshake message 1.

        Raises:
            RuntimeError: If called on a responder or out of order.
        """
        if not self._initiator:
            raise RuntimeError("Only initiator can create message 1")
        if self._step != 0:
            raise RuntimeError("create_message_1 must be called first")
        msg = self._hs.write_message(options)
        self._step = 1
        return msg

    # -- Responder step: receive msg1, produce msg2 -------------------------

    def process_message_1(self, msg1: bytes) -> bytes:
        """Responder receives msg1, returns msg2.

        Message 2 carries the RouterInfo placeholder payload.

        Raises:
            RuntimeError: If called on an initiator or out of order.
        """
        if self._initiator:
            raise RuntimeError("Only responder can process message 1")
        if self._step != 0:
            raise RuntimeError("process_message_1 must be called first")
        self._hs.read_message(msg1)
        msg2 = self._hs.write_message(_ROUTER_INFO_PLACEHOLDER)
        self._step = 2
        return msg2

    # -- Initiator step: receive msg2, produce msg3 -------------------------

    def process_message_2(self, msg2: bytes) -> bytes:
        """Initiator receives msg2, returns msg3.

        Message 3 carries the RouterInfo placeholder payload. After this call
        the initiator side of the handshake is finished.

        Raises:
            RuntimeError: If called on a responder or before create_message_1.
        """
        if not self._initiator:
            raise RuntimeError("Only initiator can process message 2")
        if self._step != 1:
            raise RuntimeError("Must call create_message_1 before process_message_2")
        self._hs.read_message(msg2)
        msg3 = self._hs.write_message(_ROUTER_INFO_PLACEHOLDER)
        self._step = 3
        return msg3

    # -- Responder step: receive msg3 ---------------------------------------

    def process_message_3(self, msg3: bytes) -> None:
        """Responder receives msg3.

        After this call the responder side of the handshake is finished.

        Raises:
            RuntimeError: If called on an initiator or before process_message_1.
        """
        if self._initiator:
            raise RuntimeError("Only responder can process message 3")
        if self._step != 2:
            raise RuntimeError("Must call process_message_1 before process_message_3")
        self._hs.read_message(msg3)
        self._step = 3

    # -- Post-handshake -----------------------------------------------------

    def split(self) -> tuple[CipherState, CipherState]:
        """After handshake, return (send_cipher, recv_cipher).

        Noise split() produces (c1, c2). By convention:
          - Initiator sends with c1, receives with c2.
          - Responder sends with c2, receives with c1.

        Raises:
            RuntimeError: If this side has not yet handled all three
                handshake messages. Splitting earlier would derive transport
                keys from an unfinished handshake.
        """
        if self._step != 3:
            raise RuntimeError("Handshake not complete; cannot split")
        c1, c2 = self._hs.split()
        if self._initiator:
            return c1, c2
        return c2, c1

    def is_complete(self) -> bool:
        """Return the underlying HandshakeState's ``complete`` flag."""
        return self._hs.complete

    def remote_static_key(self) -> bytes | None:
        """Return the peer's static public key as known to the HandshakeState."""
        return self._hs.remote_static


class NTCP2FrameCodec:
    """Encrypts/decrypts NTCP2 frames after handshake completion."""

    def encrypt_frame(self, cipher: CipherState, frame: NTCP2Frame) -> bytes:
        """Serialize frame and encrypt with the given CipherState.

        Uses empty associated data.
        """
        frame_bytes = frame.to_bytes()
        return cipher.encrypt_with_ad(b"", frame_bytes)

    def decrypt_frame(self, cipher: CipherState, encrypted: bytes) -> NTCP2Frame:
        """Decrypt (with empty associated data) and parse an NTCP2Frame."""
        frame_bytes = cipher.decrypt_with_ad(b"", encrypted)
        return NTCP2Frame.from_bytes(frame_bytes)


class NTCP2MessageFragmenter:
    """Fragment/reassemble I2NP messages across NTCP2 frames."""

    def fragment(self, i2np_bytes: bytes, max_payload: int = 65535) -> list[NTCP2Frame]:
        """Split i2np_bytes into chunks, each becoming an I2NP frame.

        Args:
            i2np_bytes: The serialized I2NP message to split.
            max_payload: Maximum payload size of each produced frame; must be
                positive.

        Returns:
            A list of I2NP frames whose payloads, concatenated in order,
            equal ``i2np_bytes``. An empty input yields a single frame with
            an empty payload rather than an empty list.

        Raises:
            ValueError: If ``max_payload`` is not positive.
        """
        # A non-positive chunk size can never advance through the input.
        if max_payload <= 0:
            raise ValueError(f"max_payload must be positive, got {max_payload}")
        if not i2np_bytes:
            return [NTCP2Frame(FrameType.I2NP, b"")]
        return [
            NTCP2Frame(FrameType.I2NP, i2np_bytes[offset : offset + max_payload])
            for offset in range(0, len(i2np_bytes), max_payload)
        ]

    def reassemble(self, frames: list[NTCP2Frame]) -> bytes:
        """Concatenate the payloads of ``frames`` in order.

        No filtering by frame type is done here. Callers are expected to pass
        only the I2NP frames belonging to one message.
        """
        return b"".join(f.payload for f in frames)