A Python port of the Invisible Internet Project (I2P)
"""Tests for the NTCP2 Noise handshake, frame codec, and message fragmenter.

TDD: these tests were written before the implementation.
"""
5
6import pytest
7
8from i2p_crypto.x25519 import X25519DH
9from i2p_crypto.noise import CipherState
10from i2p_transport.ntcp2 import NTCP2Frame, FrameType
11from i2p_transport.ntcp2_handshake import (
12 NTCP2Handshake,
13 NTCP2FrameCodec,
14 NTCP2MessageFragmenter,
15)
16
17
18# ---------------------------------------------------------------------------
19# Helpers
20# ---------------------------------------------------------------------------
21
def _make_keypair():
    """Return a freshly generated X25519 static key pair for a test peer.

    Callers use index ``[1]`` of the result as the public half: it is passed
    as ``peer_static_pub`` and compared against ``remote_static_key()``.
    """
    fresh_pair = X25519DH.generate_keypair()
    return fresh_pair
24
25
def _do_full_handshake():
    """Drive a complete three-message Noise_XK handshake between two peers.

    Alice is the initiator and is configured with Bob's static public key up
    front. Bob is the responder and starts without Alice's static key.

    Returns:
        A 4-tuple ``(alice, bob, alice_static, bob_static)``. The first two
        items are the ``NTCP2Handshake`` objects after all three messages have
        been exchanged. The last two are the key pairs each side was built with.
    """
    initiator_keys = _make_keypair()
    responder_keys = _make_keypair()

    initiator = NTCP2Handshake(
        our_static=initiator_keys,
        # XK pattern: the initiator already knows the responder's static key.
        peer_static_pub=responder_keys[1],
        initiator=True,
    )
    responder = NTCP2Handshake(
        our_static=responder_keys,
        # The responder only learns the initiator's static key during the handshake.
        peer_static_pub=None,
        initiator=False,
    )

    # Message 1 (initiator -> responder) carries an options payload.
    first = initiator.create_message_1(options=b"hello")
    # The responder consumes message 1 and answers with message 2.
    second = responder.process_message_1(first)
    # The initiator consumes message 2 and answers with message 3.
    third = initiator.process_message_2(second)
    # The responder consumes message 3, which completes the handshake.
    responder.process_message_3(third)

    return initiator, responder, initiator_keys, responder_keys
56 return alice, bob, alice_static, bob_static
57
58
59# ---------------------------------------------------------------------------
60# Handshake tests
61# ---------------------------------------------------------------------------
62
class TestNTCP2Handshake:
    """End-to-end behaviour of ``NTCP2Handshake`` and the ciphers from ``split()``."""

    @staticmethod
    def _new_initiator():
        """Return an initiator that has not yet produced any handshake message."""
        alice_static = _make_keypair()
        bob_static = _make_keypair()
        return NTCP2Handshake(
            our_static=alice_static,
            peer_static_pub=bob_static[1],
            initiator=True,
        )

    def test_full_handshake_completes(self):
        alice, bob, _, _ = _do_full_handshake()
        assert alice.is_complete()
        assert bob.is_complete()

    def test_split_produces_cipher_pair(self):
        alice, bob, _, _ = _do_full_handshake()
        # Unpacking also checks that split() yields exactly two ciphers per side.
        a_send, a_recv = alice.split()
        b_send, b_recv = bob.split()
        for cipher in (a_send, a_recv, b_send, b_recv):
            assert isinstance(cipher, CipherState)

    def test_transport_encryption_after_handshake(self):
        """Initiator encrypts, responder decrypts."""
        alice, bob, _, _ = _do_full_handshake()
        a_send, _ = alice.split()
        _, b_recv = bob.split()

        plaintext = b"I2P rocks"
        ct = a_send.encrypt_with_ad(b"", plaintext)
        pt = b_recv.decrypt_with_ad(b"", ct)
        assert pt == plaintext

    def test_bidirectional_encryption(self):
        """Both directions work: alice->bob and bob->alice."""
        alice, bob, _, _ = _do_full_handshake()
        a_send, a_recv = alice.split()
        b_send, b_recv = bob.split()

        # Alice -> Bob
        ct1 = a_send.encrypt_with_ad(b"", b"from alice")
        assert b_recv.decrypt_with_ad(b"", ct1) == b"from alice"

        # Bob -> Alice
        ct2 = b_send.encrypt_with_ad(b"", b"from bob")
        assert a_recv.decrypt_with_ad(b"", ct2) == b"from bob"

    def test_remote_static_key_recovery(self):
        """Responder learns initiator's static key through XK pattern."""
        alice, bob, alice_static, bob_static = _do_full_handshake()
        # Responder should now know Alice's static public key.
        assert bob.remote_static_key() == alice_static[1]
        # Initiator already knew Bob's static key.
        assert alice.remote_static_key() == bob_static[1]

    def test_handshake_out_of_order_error(self):
        """Calling handshake methods out of order should raise."""
        alice = self._new_initiator()
        # process_message_2 is not valid before create_message_1.
        with pytest.raises(RuntimeError):
            alice.process_message_2(b"fake_msg")

    def test_decrypt_with_wrong_cipher_fails(self):
        """Decrypting with mismatched cipher must fail."""
        alice, bob, _, _ = _do_full_handshake()
        a_send, _ = alice.split()
        b_send, b_recv = bob.split()

        ct = a_send.encrypt_with_ad(b"", b"secret")
        # b_send is Bob's outbound (bob->alice) cipher, so it must reject a
        # ciphertext produced by Alice's outbound (alice->bob) cipher.
        with pytest.raises(Exception):
            b_send.decrypt_with_ad(b"", ct)
        # Positive control: the matching cipher accepts the very same bytes.
        # Without this, the broad pytest.raises(Exception) above would also
        # pass on unrelated failures (e.g. a malformed ciphertext).
        assert b_recv.decrypt_with_ad(b"", ct) == b"secret"

    def test_is_complete_false_before_handshake(self):
        assert not self._new_initiator().is_complete()
147
148
149# ---------------------------------------------------------------------------
150# FrameCodec tests
151# ---------------------------------------------------------------------------
152
class TestNTCP2FrameCodec:
    """Encrypt/decrypt round-trips of ``NTCP2Frame`` through ``NTCP2FrameCodec``."""

    def _make_cipher_pair(self):
        """Return ``(alice_send, bob_recv)`` from a freshly completed handshake."""
        initiator, responder, _, _ = _do_full_handshake()
        sender, _unused_recv = initiator.split()
        _unused_send, receiver = responder.split()
        return sender, receiver

    def _round_trip(self, frame_type, payload):
        """Encrypt one frame with a fresh codec and cipher pair, then decrypt it."""
        codec = NTCP2FrameCodec()
        sender, receiver = self._make_cipher_pair()
        wire = codec.encrypt_frame(sender, NTCP2Frame(frame_type, payload))
        return codec.decrypt_frame(receiver, wire)

    def test_encrypt_decrypt_data_frame(self):
        result = self._round_trip(FrameType.DATA, b"some data payload")
        assert result.frame_type == FrameType.DATA
        assert result.payload == b"some data payload"

    def test_encrypt_decrypt_i2np_frame(self):
        result = self._round_trip(FrameType.I2NP, b"\x01\x02\x03\x04")
        assert result.frame_type == FrameType.I2NP
        assert result.payload == b"\x01\x02\x03\x04"

    def test_encrypt_decrypt_padding_frame(self):
        zeros = b"\x00" * 16
        result = self._round_trip(FrameType.PADDING, zeros)
        assert result.frame_type == FrameType.PADDING
        assert result.payload == zeros
193
194
195# ---------------------------------------------------------------------------
196# MessageFragmenter tests
197# ---------------------------------------------------------------------------
198
class TestNTCP2MessageFragmenter:
    """Splitting messages into I2NP frames and joining them back together."""

    def test_single_chunk_fits_one_frame(self):
        frag = NTCP2MessageFragmenter()
        data = b"short message"
        frames = frag.fragment(data, max_payload=65535)
        assert len(frames) == 1
        assert frames[0].frame_type == FrameType.I2NP
        assert frames[0].payload == data

    def test_large_message_split_across_frames(self):
        frag = NTCP2MessageFragmenter()
        data = b"A" * 1000
        frames = frag.fragment(data, max_payload=300)
        # ceil(1000 / 300) = 4 frames, carrying 300 + 300 + 300 + 100 bytes.
        assert len(frames) == 4
        assert [len(f.payload) for f in frames] == [300, 300, 300, 100]
        for f in frames:
            assert f.frame_type == FrameType.I2NP

    def test_reassemble_single_frame(self):
        frag = NTCP2MessageFragmenter()
        data = b"hello world"
        frames = frag.fragment(data)
        result = frag.reassemble(frames)
        assert result == data

    def test_reassemble_multiple_frames(self):
        frag = NTCP2MessageFragmenter()
        data = b"B" * 1000
        frames = frag.fragment(data, max_payload=256)
        result = frag.reassemble(frames)
        assert result == data

    def test_fragment_exact_multiple(self):
        frag = NTCP2MessageFragmenter()
        data = b"C" * 600
        frames = frag.fragment(data, max_payload=200)
        # An exact multiple must not produce a trailing empty frame.
        assert len(frames) == 3
        assert [len(f.payload) for f in frames] == [200, 200, 200]
        assert frag.reassemble(frames) == data

    def test_empty_message(self):
        frag = NTCP2MessageFragmenter()
        frames = frag.fragment(b"")
        # An empty message still yields one (empty) frame rather than none.
        assert len(frames) == 1
        assert frames[0].payload == b""
        assert frag.reassemble(frames) == b""