A Python port of the Invisible Internet Project (I2P)
at main 364 lines 13 kB view raw
"""Tests for NTCP2 wire-format frame encryption.

TDD: these tests are written before the implementation. The codec module
(``i2p_transport.ntcp2_wire``) is imported inside each test, presumably so
this file can be collected before that module exists.

The wire format implemented by ``NTCP2WireCodec`` encrypts each frame in two
operations:

1. Encrypt the 2-byte frame length -> 18 bytes (2 + 16-byte AEAD tag).
2. Encrypt the frame bytes (type + length + data) -> N + 16 bytes.

Each encrypt/decrypt consumes one nonce from the CipherState, so every frame
advances the nonce counter by two.

NOTE(review): the published I2P NTCP2 spec obfuscates the data-phase length
field with SipHash rather than AEAD-encrypting it; the two-AEAD scheme tested
here appears to be this project's own design -- confirm that is intended.
"""

import os
import struct

import pytest

from i2p_crypto.x25519 import X25519DH
from i2p_crypto.noise import CipherState
from i2p_transport.ntcp2 import NTCP2Frame, FrameType
from i2p_transport.ntcp2_handshake import NTCP2Handshake


# Size of the encrypted length field on the wire: 2-byte length + 16-byte tag.
_ENC_LEN_SIZE = 18
# Size of the AEAD tag appended to the encrypted frame bytes.
_TAG_SIZE = 16


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_keypair():
    """Return a fresh keypair from ``X25519DH.generate_keypair``."""
    return X25519DH.generate_keypair()


def _make_cipher_pair():
    """Complete a Noise_XK handshake and return (send_cipher, recv_cipher).

    ``send_cipher`` is the initiator's sending CipherState and ``recv_cipher``
    is the responder's receiving CipherState. Frames encrypted with the former
    decrypt with the latter, provided both are used in the same order.
    """
    init_static = _make_keypair()
    resp_static = _make_keypair()

    # The initiator is given the responder's static key up front
    # (resp_static[1] -- presumably the public half of the keypair).
    init_hs = NTCP2Handshake(init_static, resp_static[1], True)
    resp_hs = NTCP2Handshake(resp_static, initiator=False)

    msg1 = init_hs.create_message_1()
    msg2 = resp_hs.process_message_1(msg1)
    msg3 = init_hs.process_message_2(msg2)
    resp_hs.process_message_3(msg3)

    send, _ = init_hs.split()
    _, resp_recv = resp_hs.split()
    return send, resp_recv


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

class TestEncryptFrameByteCount:
    """encrypt_frame produces correct byte count: 18 + len(frame_bytes) + 16."""

    def test_data_frame_size(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, _ = _make_cipher_pair()
        codec = NTCP2WireCodec()
        payload = b"hello world"
        frame = NTCP2Frame(FrameType.DATA, payload)
        frame_bytes = frame.to_bytes()

        encrypted = codec.encrypt_frame(send, frame)

        # Encrypted length field + (frame bytes + tag) for the encrypted payload.
        expected_len = _ENC_LEN_SIZE + len(frame_bytes) + _TAG_SIZE
        assert len(encrypted) == expected_len

    def test_empty_payload_frame_size(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, _ = _make_cipher_pair()
        codec = NTCP2WireCodec()
        frame = NTCP2Frame(FrameType.PADDING, b"")
        frame_bytes = frame.to_bytes()

        encrypted = codec.encrypt_frame(send, frame)

        expected_len = _ENC_LEN_SIZE + len(frame_bytes) + _TAG_SIZE
        assert len(encrypted) == expected_len

    def test_large_payload_frame_size(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, _ = _make_cipher_pair()
        codec = NTCP2WireCodec()
        payload = os.urandom(1024)
        frame = NTCP2Frame(FrameType.I2NP, payload)
        frame_bytes = frame.to_bytes()

        encrypted = codec.encrypt_frame(send, frame)

        expected_len = _ENC_LEN_SIZE + len(frame_bytes) + _TAG_SIZE
        assert len(encrypted) == expected_len


class TestDecryptFrameLength:
    """decrypt_frame_length recovers original length."""

    def test_recover_length(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        payload = b"test payload data"
        frame = NTCP2Frame(FrameType.DATA, payload)
        frame_bytes = frame.to_bytes()

        encrypted = codec.encrypt_frame(send, frame)
        # The leading bytes of the wire output are the encrypted length field.
        encrypted_length = encrypted[:_ENC_LEN_SIZE]

        recovered_length = codec.decrypt_frame_length(recv, encrypted_length)
        assert recovered_length == len(frame_bytes)

    def test_recover_zero_payload_length(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        frame = NTCP2Frame(FrameType.PADDING, b"")
        frame_bytes = frame.to_bytes()

        encrypted = codec.encrypt_frame(send, frame)
        encrypted_length = encrypted[:_ENC_LEN_SIZE]

        recovered_length = codec.decrypt_frame_length(recv, encrypted_length)
        assert recovered_length == len(frame_bytes)


class TestDecryptFramePayload:
    """decrypt_frame_payload recovers original frame."""

    def test_recover_frame(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        payload = b"i2np message content"
        frame = NTCP2Frame(FrameType.I2NP, payload)

        encrypted = codec.encrypt_frame(send, frame)
        # Everything after the encrypted length field is the encrypted payload.
        encrypted_payload = encrypted[_ENC_LEN_SIZE:]

        # The length must be decrypted first so the payload uses the next nonce.
        codec.decrypt_frame_length(recv, encrypted[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, encrypted_payload)

        assert recovered.frame_type == FrameType.I2NP
        assert recovered.payload == payload


class TestFullRoundtrip:
    """Full roundtrip: encrypt -> decrypt_length -> decrypt_payload -> original frame."""

    def test_roundtrip_data_frame(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        original = NTCP2Frame(FrameType.DATA, b"roundtrip test data")

        wire_bytes = codec.encrypt_frame(send, original)

        # Decrypt in wire order: first the length, then the payload.
        frame_len = codec.decrypt_frame_length(recv, wire_bytes[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, wire_bytes[_ENC_LEN_SIZE:])

        assert recovered.frame_type == original.frame_type
        assert recovered.payload == original.payload
        assert frame_len == len(original.to_bytes())

    def test_roundtrip_via_convenience(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        original = NTCP2Frame(FrameType.ROUTER_INFO, os.urandom(200))

        wire_bytes = codec.encrypt_and_get_wire_bytes(send, original)

        frame_len = codec.decrypt_frame_length(recv, wire_bytes[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, wire_bytes[_ENC_LEN_SIZE:])

        assert recovered.frame_type == original.frame_type
        assert recovered.payload == original.payload
        # The convenience wrapper must produce the same length field as
        # encrypt_frame does.
        assert frame_len == len(original.to_bytes())


class TestMultipleFrames:
    """Multiple frames in sequence (nonces advance correctly)."""

    def test_three_sequential_frames(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()

        frames = [
            NTCP2Frame(FrameType.DATETIME, struct.pack("!I", 1710000000)),
            NTCP2Frame(FrameType.I2NP, os.urandom(128)),
            NTCP2Frame(FrameType.PADDING, os.urandom(32)),
        ]

        # Encrypt every frame before decrypting any, as a sender would queue them.
        encrypted_list = [codec.encrypt_frame(send, f) for f in frames]

        for frame, enc in zip(frames, encrypted_list):
            frame_len = codec.decrypt_frame_length(recv, enc[:_ENC_LEN_SIZE])
            recovered = codec.decrypt_frame_payload(recv, enc[_ENC_LEN_SIZE:])
            assert frame_len == len(frame.to_bytes())
            assert recovered.frame_type == frame.frame_type
            assert recovered.payload == frame.payload

    def test_nonces_advance_two_per_frame(self):
        """Each frame uses 2 nonces (one for length, one for payload).
        Encrypting 3 frames should use nonces 0-5 on the send side."""
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, _ = _make_cipher_pair()
        codec = NTCP2WireCodec()

        for _ in range(3):
            frame = NTCP2Frame(FrameType.DATA, b"x")
            codec.encrypt_frame(send, frame)

        # After 3 frames the send cipher has used nonces 0..5, so its
        # (private) nonce counter ``_n`` should now be 6.
        assert send._n == 6

    def test_out_of_order_decrypt_fails(self):
        """Decrypting frames out of order should fail (nonce mismatch)."""
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()

        f1 = NTCP2Frame(FrameType.DATA, b"first")
        f2 = NTCP2Frame(FrameType.DATA, b"second")

        # The first frame is encrypted only to advance the send nonce, so the
        # second frame's length field is encrypted under nonce 2.
        codec.encrypt_frame(send, f1)
        enc2 = codec.encrypt_frame(send, f2)

        # Decrypting enc2 first should fail because recv still expects nonce 0.
        with pytest.raises(Exception) as excinfo:
            codec.decrypt_frame_length(recv, enc2[:_ENC_LEN_SIZE])
        # pytest.raises(Exception) would also accept a missing or mis-called
        # method; require the failure to be something other than such a
        # programming error.
        assert not isinstance(
            excinfo.value, (AttributeError, TypeError, NameError)
        )


class TestDifferentFrameTypes:
    """Different frame types (I2NP, PADDING, TERMINATION)."""

    def test_i2np_frame(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        frame = NTCP2Frame(FrameType.I2NP, os.urandom(256))

        wire = codec.encrypt_frame(send, frame)
        codec.decrypt_frame_length(recv, wire[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, wire[_ENC_LEN_SIZE:])

        assert recovered.frame_type == FrameType.I2NP
        assert recovered.payload == frame.payload

    def test_padding_frame(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        frame = NTCP2Frame(FrameType.PADDING, os.urandom(64))

        wire = codec.encrypt_frame(send, frame)
        codec.decrypt_frame_length(recv, wire[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, wire[_ENC_LEN_SIZE:])

        assert recovered.frame_type == FrameType.PADDING
        assert recovered.payload == frame.payload

    def test_termination_frame(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        # reason(1) + valid_received(8)
        payload = struct.pack("!BQ", 0, 42)
        frame = NTCP2Frame(FrameType.TERMINATION, payload)

        wire = codec.encrypt_frame(send, frame)
        codec.decrypt_frame_length(recv, wire[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, wire[_ENC_LEN_SIZE:])

        assert recovered.frame_type == FrameType.TERMINATION
        assert recovered.payload == payload

    def test_datetime_frame(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        ts = struct.pack("!I", 1710000000)
        frame = NTCP2Frame(FrameType.DATETIME, ts)

        wire = codec.encrypt_frame(send, frame)
        codec.decrypt_frame_length(recv, wire[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, wire[_ENC_LEN_SIZE:])

        assert recovered.frame_type == FrameType.DATETIME
        assert struct.unpack("!I", recovered.payload)[0] == 1710000000

    def test_options_frame(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        frame = NTCP2Frame(FrameType.OPTIONS, b"\x01\x02\x03\x04")

        wire = codec.encrypt_frame(send, frame)
        codec.decrypt_frame_length(recv, wire[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, wire[_ENC_LEN_SIZE:])

        assert recovered.frame_type == FrameType.OPTIONS
        assert recovered.payload == b"\x01\x02\x03\x04"


class TestEmptyPayloadFrame:
    """Empty payload frame."""

    def test_empty_data_frame_roundtrip(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        frame = NTCP2Frame(FrameType.DATA, b"")

        wire = codec.encrypt_frame(send, frame)

        # frame_bytes = type(1) + length(2) + payload(0) = 3 bytes
        assert len(wire) == _ENC_LEN_SIZE + 3 + _TAG_SIZE

        frame_len = codec.decrypt_frame_length(recv, wire[:_ENC_LEN_SIZE])
        assert frame_len == 3  # type(1) + length(2)

        recovered = codec.decrypt_frame_payload(recv, wire[_ENC_LEN_SIZE:])
        assert recovered.frame_type == FrameType.DATA
        assert recovered.payload == b""

    def test_empty_padding_frame_roundtrip(self):
        from i2p_transport.ntcp2_wire import NTCP2WireCodec

        send, recv = _make_cipher_pair()
        codec = NTCP2WireCodec()
        frame = NTCP2Frame(FrameType.PADDING, b"")

        wire = codec.encrypt_frame(send, frame)
        codec.decrypt_frame_length(recv, wire[:_ENC_LEN_SIZE])
        recovered = codec.decrypt_frame_payload(recv, wire[_ENC_LEN_SIZE:])

        assert recovered.frame_type == FrameType.PADDING
        assert recovered.payload == b""