"""Tests for TunnelPool and TunnelPoolManager — tunnel pool policies and build orchestration. TDD: tests written first, implementation follows. """ from __future__ import annotations import time from unittest.mock import MagicMock, patch import pytest from i2p_tunnel.pool import TunnelEntry, TunnelPool, TunnelPoolManager # --------------------------------------------------------------------------- # TunnelEntry tests # --------------------------------------------------------------------------- class TestTunnelEntry: def test_not_expired_when_fresh(self): now = time.monotonic() entry = TunnelEntry(tunnel_id=1, hops=[b"\x01" * 32], created_at=now, lifetime_seconds=600.0, is_inbound=False) assert not entry.is_expired def test_expired_after_lifetime(self): now = time.monotonic() entry = TunnelEntry(tunnel_id=1, hops=[b"\x01" * 32], created_at=now - 700, lifetime_seconds=600.0) assert entry.is_expired def test_remaining_seconds(self): now = time.monotonic() entry = TunnelEntry(tunnel_id=1, hops=[], created_at=now, lifetime_seconds=600.0) remaining = entry.remaining_seconds assert 599.0 <= remaining <= 600.0 def test_remaining_seconds_negative_when_expired(self): entry = TunnelEntry(tunnel_id=1, hops=[], created_at=time.monotonic() - 700, lifetime_seconds=600.0) assert entry.remaining_seconds < 0 def test_lifetime_fraction_fresh(self): entry = TunnelEntry(tunnel_id=1, hops=[], created_at=time.monotonic(), lifetime_seconds=600.0) assert 0.0 <= entry.lifetime_fraction <= 0.01 def test_lifetime_fraction_halfway(self): entry = TunnelEntry(tunnel_id=1, hops=[], created_at=time.monotonic() - 300, lifetime_seconds=600.0) assert 0.49 <= entry.lifetime_fraction <= 0.51 def test_lifetime_fraction_expired(self): entry = TunnelEntry(tunnel_id=1, hops=[], created_at=time.monotonic() - 700, lifetime_seconds=600.0) assert entry.lifetime_fraction >= 1.0 # --------------------------------------------------------------------------- # TunnelPool tests # 
# ---------------------------------------------------------------------------
# TunnelPool tests
# ---------------------------------------------------------------------------

class TestTunnelPool:
    """Checks of ``TunnelPool`` counting, expiry cleanup, routing selection and rebuild signals."""

    def _make_entry(self, tunnel_id=1, age=0.0, lifetime=600.0, is_inbound=False):
        """Build a single-hop entry whose ``created_at`` lies ``age`` seconds in the past."""
        created = time.monotonic() - age
        return TunnelEntry(
            tunnel_id=tunnel_id,
            hops=[b"\xaa" * 32],
            created_at=created,
            lifetime_seconds=lifetime,
            is_inbound=is_inbound,
        )

    def _fill(self, pool, *specs):
        """Add one entry per spec (keyword arguments for ``_make_entry``), in order."""
        for spec in specs:
            pool.add_tunnel(self._make_entry(**spec))
        return pool

    def test_initial_state_empty(self):
        assert TunnelPool(target_count=3).active_count == 0

    def test_needs_rebuild_when_empty(self):
        empty = TunnelPool(target_count=3, min_count=1)
        assert empty.needs_rebuild()

    def test_add_tunnel_increases_count(self):
        pool = self._fill(TunnelPool(target_count=3), {"tunnel_id": 1})
        assert pool.active_count == 1

    def test_needs_rebuild_true_below_target(self):
        pool = self._fill(
            TunnelPool(target_count=3, min_count=1),
            {"tunnel_id": 1},
            {"tunnel_id": 2},
        )
        assert pool.needs_rebuild()

    def test_needs_rebuild_false_at_target(self):
        pool = self._fill(
            TunnelPool(target_count=3, min_count=1),
            *({"tunnel_id": n} for n in range(3)),
        )
        assert not pool.needs_rebuild()

    def test_needs_rebuild_false_above_target(self):
        pool = self._fill(
            TunnelPool(target_count=2, min_count=1),
            *({"tunnel_id": n} for n in range(4)),
        )
        assert not pool.needs_rebuild()

    def test_remove_expired(self):
        # One fresh entry and one that is 700 s old against a 600 s lifetime.
        pool = self._fill(
            TunnelPool(target_count=3),
            {"tunnel_id": 1, "age": 0.0},
            {"tunnel_id": 2, "age": 700.0, "lifetime": 600.0},
        )
        dropped = pool.remove_expired()
        assert dropped == 1
        assert pool.active_count == 1

    def test_remove_expired_returns_zero_when_none_expired(self):
        pool = self._fill(TunnelPool(target_count=3), {"tunnel_id": 1})
        assert pool.remove_expired() == 0

    def test_select_for_routing_returns_tunnel(self):
        pool = self._fill(TunnelPool(target_count=3), {"tunnel_id": 42})
        chosen = pool.select_for_routing()
        assert chosen is not None
        assert chosen.tunnel_id == 42

    def test_select_for_routing_none_when_empty(self):
        assert TunnelPool(target_count=3).select_for_routing() is None

    def test_select_for_routing_avoids_nearly_expired(self):
        pool = self._fill(
            TunnelPool(target_count=3, rebuild_threshold=0.80),
            # 570 s of 600 s used (95%) — expected to be passed over.
            {"tunnel_id": 1, "age": 570.0, "lifetime": 600.0},
            # Fresh entry — expected to be the one picked.
            {"tunnel_id": 2, "age": 0.0},
        )
        chosen = pool.select_for_routing()
        assert chosen is not None
        assert chosen.tunnel_id == 2

    def test_select_for_routing_returns_nearly_expired_if_only_option(self):
        """A lone near-expiry tunnel must still be returned rather than ``None``."""
        pool = self._fill(
            TunnelPool(target_count=3, rebuild_threshold=0.80),
            {"tunnel_id": 1, "age": 570.0, "lifetime": 600.0},
        )
        assert pool.select_for_routing() is not None

    def test_get_statistics(self):
        pool = self._fill(
            TunnelPool(target_count=3, min_count=1, is_inbound=True),
            {"tunnel_id": 1, "age": 0.0},
            {"tunnel_id": 2, "age": 700.0, "lifetime": 600.0},
        )
        stats = pool.get_statistics()
        # The expired entry still counts: remove_expired() has not been called.
        assert stats["active_count"] == 2
        assert stats["target_count"] == 3
        assert stats["is_inbound"] is True
        assert "needs_rebuild" in stats

    def test_needs_preemptive_rebuild_when_tunnel_nearing_expiry(self):
        # Pool is at its target of 3, but the third entry is 510 s into a
        # 600 s lifetime (85%), above the 0.80 threshold.
        pool = self._fill(
            TunnelPool(target_count=3, min_count=1, rebuild_threshold=0.80),
            {"tunnel_id": 1, "age": 0.0},
            {"tunnel_id": 2, "age": 0.0},
            {"tunnel_id": 3, "age": 510.0, "lifetime": 600.0},
        )
        assert pool.needs_preemptive_rebuild()

    def test_no_preemptive_rebuild_when_all_fresh(self):
        pool = self._fill(
            TunnelPool(target_count=3, min_count=1, rebuild_threshold=0.80),
            *({"tunnel_id": n, "age": 0.0} for n in range(3)),
        )
        assert not pool.needs_preemptive_rebuild()
# ---------------------------------------------------------------------------
# TunnelPoolManager tests
# ---------------------------------------------------------------------------

class TestTunnelPoolManager:
    """Checks of ``TunnelPoolManager``: pool maintenance, pair selection and build bookkeeping."""

    def _make_selector(self, peer_count=5):
        """Create a mock PeerSelector backed by ``peer_count`` random peers.

        Returns a ``(selector, peers)`` tuple. ``selector.select_hops`` is a
        ``MagicMock`` wrapping a fake that hands out the first ``length``
        non-excluded peers as ``TunnelHopConfig`` objects, so tests can inspect
        ``call_args_list`` afterwards.
        """
        from i2p_peer.hop_config import TunnelHopConfig
        import os
        import struct

        selector = MagicMock()
        peers = [os.urandom(32) for _ in range(peer_count)]

        def fake_select_hops(length, exclude=None):
            exclude = exclude or set()
            chosen = [p for p in peers if p not in exclude][:length]
            # Flag the last hop actually returned as the endpoint; comparing
            # against ``length`` would leave no endpoint when exclusions leave
            # fewer than ``length`` peers available.
            last_index = len(chosen) - 1
            hops = []
            for i, ph in enumerate(chosen):
                # ``or 1`` keeps the random 32-bit tunnel id non-zero.
                tid = struct.unpack("!I", os.urandom(4))[0] or 1
                hops.append(TunnelHopConfig(
                    peer_hash=ph,
                    receive_tunnel_id=tid,
                    layer_key=os.urandom(32),
                    iv_key=os.urandom(32),
                    reply_key=os.urandom(32),
                    reply_iv=os.urandom(16),
                    is_gateway=(i == 0),
                    is_endpoint=(i == last_index),
                ))
            return hops

        selector.select_hops = MagicMock(side_effect=fake_select_hops)
        return selector, peers

    def test_initial_pools_need_rebuild(self):
        selector, _ = self._make_selector()
        mgr = TunnelPoolManager(peer_selector=selector)
        assert mgr.inbound_pool.needs_rebuild()
        assert mgr.outbound_pool.needs_rebuild()

    def test_maintain_pools_returns_hop_configs(self):
        selector, _ = self._make_selector()
        mgr = TunnelPoolManager(peer_selector=selector, tunnel_length=2)
        configs = mgr.maintain_pools()
        # Builds should be requested for both inbound and outbound pools
        # (presumably 6 in total with a default target of 3 each; only
        # non-emptiness is asserted here).
        assert len(configs) > 0
        # Each config is a list of TunnelHopConfig, one per hop.
        for cfg in configs:
            assert len(cfg) == 2  # tunnel_length=2

    def test_select_inbound_outbound_pair_none_when_empty(self):
        selector, _ = self._make_selector()
        mgr = TunnelPoolManager(peer_selector=selector)
        assert mgr.select_inbound_outbound_pair() is None

    def test_select_inbound_outbound_pair(self):
        selector, _ = self._make_selector()
        mgr = TunnelPoolManager(peer_selector=selector)
        # Manually add one entry to each pool.
        inbound = TunnelEntry(tunnel_id=10, hops=[b"\x01" * 32],
                              created_at=time.monotonic(), is_inbound=True)
        outbound = TunnelEntry(tunnel_id=20, hops=[b"\x02" * 32],
                               created_at=time.monotonic(), is_inbound=False)
        mgr.inbound_pool.add_tunnel(inbound)
        mgr.outbound_pool.add_tunnel(outbound)
        pair = mgr.select_inbound_outbound_pair()
        assert pair is not None
        # The pair is ordered (inbound, outbound).
        assert pair[0].is_inbound
        assert not pair[1].is_inbound

    def test_record_build_failure(self):
        selector, peers = self._make_selector()
        mgr = TunnelPoolManager(peer_selector=selector)
        peer = peers[0]
        assert mgr.get_build_failure_count(peer) == 0
        mgr.record_build_failure(peer)
        assert mgr.get_build_failure_count(peer) == 1
        mgr.record_build_failure(peer)
        mgr.record_build_failure(peer)
        assert mgr.get_build_failure_count(peer) == 3

    def test_record_build_success_adds_to_pool(self):
        selector, _ = self._make_selector()
        mgr = TunnelPoolManager(peer_selector=selector)
        entry = TunnelEntry(tunnel_id=99, hops=[b"\xbb" * 32],
                            created_at=time.monotonic(), is_inbound=True)
        mgr.record_build_success(entry)
        assert mgr.inbound_pool.active_count == 1

    def test_record_build_success_outbound(self):
        selector, _ = self._make_selector()
        mgr = TunnelPoolManager(peer_selector=selector)
        entry = TunnelEntry(tunnel_id=99, hops=[b"\xbb" * 32],
                            created_at=time.monotonic(), is_inbound=False)
        mgr.record_build_success(entry)
        assert mgr.outbound_pool.active_count == 1

    def test_maintain_pools_excludes_failed_peers(self):
        selector, peers = self._make_selector()
        mgr = TunnelPoolManager(peer_selector=selector, max_failures=3)
        # Fail a peer 4 times (> max_failures=3)
        for _ in range(4):
            mgr.record_build_failure(peers[0])
        mgr.maintain_pools()
        calls = selector.select_hops.call_args_list
        # Without this guard the loop below would pass vacuously if
        # maintain_pools() never consulted the selector.
        assert calls, "maintain_pools() did not call select_hops"
        # The failed peer must be in the exclusion set of every call, whether
        # ``exclude`` was passed by keyword or as the second positional arg.
        for call in calls:
            exclude = call.kwargs.get("exclude") or (
                call.args[1] if len(call.args) > 1 else set()
            )
            assert peers[0] in exclude

    def test_maintain_pools_handles_preemptive_rebuild(self):
        selector, _ = self._make_selector()
        mgr = TunnelPoolManager(peer_selector=selector, tunnel_length=2,
                                target_count=2, rebuild_threshold=0.80)
        # Fill both pools to target
        for i in range(2):
            mgr.inbound_pool.add_tunnel(TunnelEntry(
                tunnel_id=i, hops=[b"\x01" * 32],
                created_at=time.monotonic(), is_inbound=True))
            mgr.outbound_pool.add_tunnel(TunnelEntry(
                tunnel_id=i + 100, hops=[b"\x02" * 32],
                created_at=time.monotonic(), is_inbound=False))
        # No builds needed — pools at target
        configs = mgr.maintain_pools()
        assert len(configs) == 0
        # Now age one inbound tunnel past threshold (500 s of 600 s is ~83%).
        # NOTE(review): this writes to the pool's private ``_tunnels`` and
        # assumes it supports index assignment — confirm against TunnelPool.
        mgr.inbound_pool._tunnels[0] = TunnelEntry(
            tunnel_id=0, hops=[b"\x01" * 32],
            created_at=time.monotonic() - 500,
            lifetime_seconds=600.0, is_inbound=True)
        configs = mgr.maintain_pools()
        assert len(configs) >= 1  # preemptive rebuild triggered