A Python port of the Invisible Internet Project (I2P)
1"""Tests for SSU2 Noise_XK handshake state machines.
2
3TDD: tests written before implementation.
4"""
5
6import time
7
8import pytest
9
10from i2p_crypto.x25519 import X25519DH
11from i2p_transport.ssu2_handshake import (
12 HandshakeRole,
13 HandshakePhase,
14 SSU2Token,
15 HandshakeKeys,
16 TokenManager,
17 OutboundHandshake,
18 InboundHandshake,
19 derive_header_keys,
20 derive_data_keys,
21)
22
23
24# ---------------------------------------------------------------------------
25# TokenManager tests
26# ---------------------------------------------------------------------------
27
class TestTokenManager:
    """TokenManager issues per-IP tokens, validates them and drops expired ones."""

    @staticmethod
    def _backdate(mgr, remote_ip, token_value):
        """Overwrite the stored token for ``remote_ip`` with one that expired a second ago.

        Producing an expired token via ``token_lifetime_seconds=0`` relies on the
        clock advancing between ``create_token`` and the later check; with a
        coarse ``time.time()`` (or if the implementation compares
        ``now > expires`` strictly) such a test is flaky. Backdating makes the
        expiry unambiguous.

        NOTE(review): assumes ``_tokens`` maps remote IP -> SSU2Token, which is
        the same layout the handshake tests in this module write to directly.
        """
        mgr._tokens[remote_ip] = SSU2Token(
            token=token_value,
            remote_ip=remote_ip,
            expires=time.time() - 1,
        )

    def test_token_manager_create_validate(self):
        """Create a token for an IP and validate it."""
        mgr = TokenManager(token_lifetime_seconds=3600)
        remote_ip = b"\x7f\x00\x00\x01"  # 127.0.0.1
        token = mgr.create_token(remote_ip)

        assert isinstance(token, SSU2Token)
        assert token.remote_ip == remote_ip
        assert token.token != 0
        assert mgr.validate_token(token.token, remote_ip) is True

    def test_token_manager_expired(self):
        """Expired token fails validation."""
        mgr = TokenManager(token_lifetime_seconds=3600)
        remote_ip = b"\x7f\x00\x00\x01"
        token = mgr.create_token(remote_ip)
        # Force expiry deterministically instead of racing the clock.
        self._backdate(mgr, remote_ip, token.token)
        assert mgr.validate_token(token.token, remote_ip) is False

    def test_token_manager_wrong_ip(self):
        """Token for one IP is not valid for another."""
        mgr = TokenManager(token_lifetime_seconds=3600)
        ip1 = b"\x7f\x00\x00\x01"
        ip2 = b"\x0a\x00\x00\x01"
        token = mgr.create_token(ip1)

        assert mgr.validate_token(token.token, ip1) is True
        assert mgr.validate_token(token.token, ip2) is False

    def test_token_manager_cleanup_expired(self):
        """cleanup_expired removes old tokens."""
        mgr = TokenManager(token_lifetime_seconds=3600)
        remote_ip = b"\x7f\x00\x00\x01"
        token = mgr.create_token(remote_ip)
        self._backdate(mgr, remote_ip, token.token)
        mgr.cleanup_expired()
        # After cleanup, the expired token should be gone
        assert len(mgr._tokens) == 0

    def test_token_manager_cleanup_keeps_live_tokens(self):
        """cleanup_expired drops only expired tokens and leaves live ones in place."""
        mgr = TokenManager(token_lifetime_seconds=3600)
        stale_ip = b"\x7f\x00\x00\x01"
        live_ip = b"\x0a\x00\x00\x01"
        stale = mgr.create_token(stale_ip)
        mgr.create_token(live_ip)
        self._backdate(mgr, stale_ip, stale.token)

        mgr.cleanup_expired()

        assert stale_ip not in mgr._tokens
        assert live_ip in mgr._tokens
67
68
69# ---------------------------------------------------------------------------
70# Outbound handshake initial state
71# ---------------------------------------------------------------------------
72
class TestOutboundHandshakeState:
    """An outbound handshake's starting phase depends on whether a token is cached."""

    @staticmethod
    def _new_outbound(token):
        """Build an OutboundHandshake to a fresh peer using an all-zero intro key."""
        own_priv, _ = X25519DH.generate_keypair()
        _, peer_pub = X25519DH.generate_keypair()
        return OutboundHandshake(
            local_static_key=own_priv,
            remote_static_key=peer_pub,
            remote_intro_key=bytes(32),
            token=token,
        )

    def test_outbound_initial_phase(self):
        """Without a token, starts at TOKEN_REQUEST."""
        handshake = self._new_outbound(None)
        assert handshake.phase == HandshakePhase.TOKEN_REQUEST

    def test_outbound_with_token_skips_token_request(self):
        """With a cached token, starts at SESSION_REQUEST."""
        handshake = self._new_outbound(0xDEADBEEF)
        assert handshake.phase == HandshakePhase.SESSION_REQUEST
102
103
104# ---------------------------------------------------------------------------
105# HandshakeKeys structure
106# ---------------------------------------------------------------------------
107
class TestHandshakeKeys:
    """HandshakeKeys carries four symmetric keys for the data phase."""

    _FIELDS = (
        "send_cipher_key",
        "recv_cipher_key",
        "send_header_key",
        "recv_header_key",
    )

    def test_handshake_keys_structure(self):
        """HandshakeKeys has correct field sizes."""
        # Fill each field with a distinct byte pattern: 0x01, 0x02, 0x03, 0x04.
        keys = HandshakeKeys(
            **{name: bytes([fill]) * 32 for fill, name in enumerate(self._FIELDS, start=1)}
        )
        for name in self._FIELDS:
            assert len(getattr(keys, name)) == 32
122
123
124# ---------------------------------------------------------------------------
125# Key derivation
126# ---------------------------------------------------------------------------
127
class TestKeyDerivation:
    """Header-key and data-phase key derivation helpers."""

    def test_derive_header_keys(self):
        """HKDF produces two 32-byte keys."""
        first, second = derive_header_keys(b"\xab" * 32)
        for derived in (first, second):
            assert len(derived) == 32
        assert first != second

    def test_derive_data_keys(self):
        """Produces HandshakeKeys with all fields populated."""
        chaining_key = b"\xcd" * 32
        handshake_hash = b"\xef" * 32
        derived = derive_data_keys(chaining_key, handshake_hash)
        assert isinstance(derived, HandshakeKeys)
        for field_name in (
            "send_cipher_key",
            "recv_cipher_key",
            "send_header_key",
            "recv_header_key",
        ):
            assert len(getattr(derived, field_name)) == 32

    def test_derive_header_keys_deterministic(self):
        """Same input produces same output."""
        chaining_key = b"\x11" * 32
        (first_a, first_b) = derive_header_keys(chaining_key)
        (second_a, second_b) = derive_header_keys(chaining_key)
        assert (first_a, first_b) == (second_a, second_b)
156
157
158# ---------------------------------------------------------------------------
159# get_keys before ESTABLISHED
160# ---------------------------------------------------------------------------
161
class TestGetKeysBeforeEstablished:
    """get_keys() must refuse to return keys before the handshake completes."""

    def test_get_keys_before_established_raises(self):
        """get_keys raises ValueError if not ESTABLISHED."""
        own_priv, _ = X25519DH.generate_keypair()
        _, peer_pub = X25519DH.generate_keypair()
        # Token is cached, but no handshake messages have been exchanged yet.
        fresh = OutboundHandshake(
            local_static_key=own_priv,
            remote_static_key=peer_pub,
            remote_intro_key=bytes(32),
            token=0x1234,
        )
        with pytest.raises(ValueError, match="not.*ESTABLISHED"):
            fresh.get_keys()
178
179
180# ---------------------------------------------------------------------------
181# Full handshake roundtrip tests
182# ---------------------------------------------------------------------------
183
class TestFullHandshake:
    """End-to-end handshakes between Alice (OutboundHandshake) and Bob (InboundHandshake)."""

    def _make_keypairs(self):
        """Generate Alice and Bob keypairs.

        Returns:
            ``(alice_priv, alice_pub, bob_priv, bob_pub)``.
        """
        alice_priv, alice_pub = X25519DH.generate_keypair()
        bob_priv, bob_pub = X25519DH.generate_keypair()
        return alice_priv, alice_pub, bob_priv, bob_pub

    @staticmethod
    def _register_token(token_mgr, remote_ip, token):
        """Seed ``token_mgr`` so ``token`` is accepted from ``remote_ip`` for the next hour.

        Writes directly into the manager's private ``_tokens`` table so a test
        can skip the Token Request / Retry exchange.
        """
        token_mgr._tokens[remote_ip] = SSU2Token(
            token=token,
            remote_ip=remote_ip,
            expires=time.time() + 3600,
        )

    @staticmethod
    def _complete_handshake(alice, bob, remote_ip):
        """Exchange Session Request, Session Created and Session Confirmed.

        Called once Alice holds a token (cached up front or obtained via Retry).
        Each built message is checked to be non-empty before it is delivered.
        """
        session_request = alice.build_session_request()
        assert len(session_request) > 0
        bob.process_session_request(session_request, remote_ip=remote_ip)

        session_created = bob.build_session_created()
        assert len(session_created) > 0
        alice.process_session_created(session_created)

        session_confirmed = alice.build_session_confirmed()
        assert len(session_confirmed) > 0
        bob.process_session_confirmed(session_confirmed)

    @staticmethod
    def _assert_established_with_mirrored_keys(alice, bob):
        """Assert both sides are ESTABLISHED and each side's send keys are the other's recv keys."""
        assert alice.phase == HandshakePhase.ESTABLISHED
        assert bob.phase == HandshakePhase.ESTABLISHED

        alice_keys = alice.get_keys()
        bob_keys = bob.get_keys()
        assert alice_keys.send_cipher_key == bob_keys.recv_cipher_key
        assert alice_keys.recv_cipher_key == bob_keys.send_cipher_key
        assert alice_keys.send_header_key == bob_keys.recv_header_key
        assert alice_keys.recv_header_key == bob_keys.send_header_key

    def test_full_handshake_roundtrip(self):
        """Alice and Bob complete full handshake, both derive same data keys."""
        alice_priv, _, bob_priv, bob_pub = self._make_keypairs()
        bob_intro_key = b"\x42" * 32
        remote_ip = b"\x7f\x00\x00\x01"
        token_mgr = TokenManager()

        # Alice starts with a token (skip token request for simplicity)
        alice = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=bob_intro_key,
            token=0xCAFEBABE,
        )
        bob = InboundHandshake(
            local_static_key=bob_priv,
            local_intro_key=bob_intro_key,
            token_manager=token_mgr,
        )

        # Pre-register the token so Bob accepts it
        self._register_token(token_mgr, remote_ip, 0xCAFEBABE)

        self._complete_handshake(alice, bob, remote_ip)
        self._assert_established_with_mirrored_keys(alice, bob)

    def test_handshake_with_token_request(self):
        """Full flow including token request/retry."""
        alice_priv, _, bob_priv, bob_pub = self._make_keypairs()
        bob_intro_key = b"\x42" * 32
        token_mgr = TokenManager()
        remote_ip = b"\x7f\x00\x00\x01"

        # Alice has no token
        alice = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=bob_intro_key,
            token=None,
        )
        bob = InboundHandshake(
            local_static_key=bob_priv,
            local_intro_key=bob_intro_key,
            token_manager=token_mgr,
        )

        assert alice.phase == HandshakePhase.TOKEN_REQUEST

        # 1. Alice builds Token Request
        token_request = alice.build_token_request()
        assert len(token_request) > 0

        # 2. Bob processes Token Request and returns Retry
        retry = bob.process_token_request(token_request, remote_ip)
        assert len(retry) > 0

        # 3. Alice processes Retry (gets token)
        alice.process_retry(retry)
        assert alice.phase == HandshakePhase.SESSION_REQUEST

        # 4. Continue with normal handshake
        self._complete_handshake(alice, bob, remote_ip)
        self._assert_established_with_mirrored_keys(alice, bob)

    def test_handshake_without_token(self):
        """Full flow with a pre-cached token, so the Token Request/Retry exchange is skipped.

        "without_token" refers to the absence of a token *request*; Alice does
        hold a token here.
        """
        alice_priv, _, bob_priv, bob_pub = self._make_keypairs()
        bob_intro_key = b"\x42" * 32
        token_mgr = TokenManager()
        remote_ip = b"\x0a\x00\x00\x01"

        # Pre-register token
        cached_token = 0xFEEDFACE
        self._register_token(token_mgr, remote_ip, cached_token)

        alice = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=bob_intro_key,
            token=cached_token,
        )
        bob = InboundHandshake(
            local_static_key=bob_priv,
            local_intro_key=bob_intro_key,
            token_manager=token_mgr,
        )

        assert alice.phase == HandshakePhase.SESSION_REQUEST

        self._complete_handshake(alice, bob, remote_ip)
        self._assert_established_with_mirrored_keys(alice, bob)

    def test_session_request_contains_ephemeral(self):
        """Session Request contains a 32-byte ephemeral key."""
        alice_priv, _, _, bob_pub = self._make_keypairs()
        intro_key = b"\x00" * 32

        alice = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=intro_key,
            token=0x1234,
        )
        pkt = alice.build_session_request()
        # The session request long header is 32 bytes, followed by the
        # 32-byte ephemeral key (start of Noise message 1)
        assert len(pkt) >= 64  # at least header + ephemeral key
        # After building, the ephemeral public should be set
        assert alice._ephemeral_public is not None
        assert len(alice._ephemeral_public) == 32

    def test_session_created_contains_ephemeral(self):
        """Session Created contains Bob's ephemeral key."""
        # Use the shared helper rather than concatenating generate_keypair()
        # results, which would silently assume a tuple return type.
        alice_priv, _, bob_priv, bob_pub = self._make_keypairs()
        bob_intro_key = b"\x42" * 32
        token_mgr = TokenManager()
        remote_ip = b"\x7f\x00\x00\x01"

        self._register_token(token_mgr, remote_ip, 0xBEEF)

        alice = OutboundHandshake(
            local_static_key=alice_priv,
            remote_static_key=bob_pub,
            remote_intro_key=bob_intro_key,
            token=0xBEEF,
        )
        bob = InboundHandshake(
            local_static_key=bob_priv,
            local_intro_key=bob_intro_key,
            token_manager=token_mgr,
        )

        session_request = alice.build_session_request()
        bob.process_session_request(session_request, remote_ip)
        session_created = bob.build_session_created()

        # Bob should have generated an ephemeral key
        assert bob._ephemeral_public is not None
        assert len(bob._ephemeral_public) == 32
        # Session created packet should be non-trivial
        assert len(session_created) >= 64
402 assert len(session_created) >= 64