# A Python port of the Invisible Internet Project (I2P).
# Repository file view: branch "main", 204 lines, 6.5 kB.
"""Tests for TransportManager — multi-transport coordinator.

TDD: tests written before implementation.
"""

import asyncio

import pytest

from i2p_transport.transport_base import (
    Transport, TransportBid, TransportStyle, ReachabilityStatus,
)
from i2p_transport.manager import TransportManager

# Peer hash shared by the send/bid tests: 32 bytes of 0x01.
PEER_HASH = b"\x01" * 32


class MockTransport(Transport):
    """Mock transport for testing TransportManager.

    Records every ``send`` call in ``_sent`` and answers ``bid`` with a
    fixed, configurable latency so tests can steer transport selection.
    """

    def __init__(
        self,
        style: TransportStyle,
        bid_latency: int = 100,
        reachability: ReachabilityStatus = ReachabilityStatus.OK,
        address: dict | None = None,
    ) -> None:
        """Configure the mock.

        Args:
            style: Value reported by the ``style`` property.
            bid_latency: Latency placed in every bid; pass
                ``TransportBid.WILL_NOT_SEND`` to make the mock decline.
            reachability: Value reported by the ``reachability`` property.
            address: Value reported by ``current_address``. Any falsy value
                (including ``None``) is replaced by a localhost default, so
                a test that needs "no address" must set ``_address`` to
                ``None`` after construction.
        """
        self._style = style
        self._bid_latency = bid_latency
        self._reachability = reachability
        self._running = False
        # (peer_hash, data) pairs, in the order send() was called.
        self._sent: list[tuple[bytes, bytes]] = []
        self._address = address or {"host": "127.0.0.1", "port": 15000}
        # Value returned by send(); tests may flip it to simulate failure.
        self._send_result = True

    @property
    def style(self) -> TransportStyle:
        """Return the transport style given at construction."""
        return self._style

    async def start(self) -> None:
        """Mark the mock as running."""
        self._running = True

    async def stop(self) -> None:
        """Mark the mock as stopped."""
        self._running = False

    @property
    def is_running(self) -> bool:
        """Return True between ``start()`` and ``stop()``."""
        return self._running

    async def bid(self, peer_hash: bytes) -> TransportBid:
        """Return a bid carrying the configured latency.

        A configured latency of ``TransportBid.WILL_NOT_SEND`` is passed
        through unchanged, so declining needs no separate branch.
        """
        return TransportBid(latency_ms=self._bid_latency, transport=self)

    async def send(self, peer_hash: bytes, data: bytes) -> bool:
        """Record the call and return the preset ``_send_result``."""
        self._sent.append((peer_hash, data))
        return self._send_result

    @property
    def reachability(self) -> ReachabilityStatus:
        """Return the reachability status given at construction."""
        return self._reachability

    @property
    def current_address(self) -> dict | None:
        """Return the configured address, or None if a test cleared it."""
        return self._address


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

class TestTransportManager:
    """Behavioural tests for TransportManager driven through MockTransport."""

    def test_register_transport(self) -> None:
        """A registered transport is retrievable by its style."""
        mgr = TransportManager()
        t = MockTransport(TransportStyle.NTCP2)
        mgr.register(t)
        assert mgr.get_transport(TransportStyle.NTCP2) is t

    def test_register_multiple(self) -> None:
        """Transports of different styles are stored independently."""
        mgr = TransportManager()
        ntcp2 = MockTransport(TransportStyle.NTCP2)
        ssu2 = MockTransport(TransportStyle.SSU2)
        mgr.register(ntcp2)
        mgr.register(ssu2)
        assert mgr.get_transport(TransportStyle.NTCP2) is ntcp2
        assert mgr.get_transport(TransportStyle.SSU2) is ssu2

    @pytest.mark.asyncio
    async def test_start_stop_all(self) -> None:
        """start_all/stop_all toggle every transport and the manager itself."""
        mgr = TransportManager()
        ntcp2 = MockTransport(TransportStyle.NTCP2)
        ssu2 = MockTransport(TransportStyle.SSU2)
        mgr.register(ntcp2)
        mgr.register(ssu2)

        await mgr.start_all()
        assert ntcp2.is_running
        assert ssu2.is_running
        assert mgr.is_running

        await mgr.stop_all()
        assert not ntcp2.is_running
        assert not ssu2.is_running
        assert not mgr.is_running

    @pytest.mark.asyncio
    async def test_send_selects_best_bid(self) -> None:
        """send() routes the message through the lowest-latency bidder only."""
        mgr = TransportManager()
        slow = MockTransport(TransportStyle.NTCP2, bid_latency=200)
        fast = MockTransport(TransportStyle.SSU2, bid_latency=50)
        mgr.register(slow)
        mgr.register(fast)
        await mgr.start_all()

        result = await mgr.send(PEER_HASH, b"hello")
        assert result is True
        # Fast transport should have been selected
        assert len(fast._sent) == 1
        assert len(slow._sent) == 0
        assert fast._sent[0] == (PEER_HASH, b"hello")

    @pytest.mark.asyncio
    async def test_send_no_transport(self) -> None:
        """send() returns False when no transport is registered."""
        mgr = TransportManager()
        result = await mgr.send(PEER_HASH, b"hello")
        assert result is False

    @pytest.mark.asyncio
    async def test_send_will_not_send(self) -> None:
        """send() returns False and sends nothing if the only bid declines."""
        mgr = TransportManager()
        t = MockTransport(
            TransportStyle.NTCP2, bid_latency=TransportBid.WILL_NOT_SEND,
        )
        mgr.register(t)
        await mgr.start_all()

        result = await mgr.send(PEER_HASH, b"hello")
        assert result is False
        assert len(t._sent) == 0

    def test_get_addresses(self) -> None:
        """get_addresses() returns one entry per transport with an address."""
        mgr = TransportManager()
        ntcp2 = MockTransport(
            TransportStyle.NTCP2,
            address={"host": "1.2.3.4", "port": 15000},
        )
        ssu2 = MockTransport(
            TransportStyle.SSU2,
            address={"host": "1.2.3.4", "port": 15001},
        )
        mgr.register(ntcp2)
        mgr.register(ssu2)
        addrs = mgr.get_addresses()
        assert len(addrs) == 2
        ports = {a["port"] for a in addrs}
        assert ports == {15000, 15001}

    def test_reachability_best_of_all(self) -> None:
        """With one FIREWALLED and one OK transport, the manager reports OK."""
        mgr = TransportManager()
        firewalled = MockTransport(
            TransportStyle.NTCP2, reachability=ReachabilityStatus.FIREWALLED,
        )
        ok = MockTransport(
            TransportStyle.SSU2, reachability=ReachabilityStatus.OK,
        )
        mgr.register(firewalled)
        mgr.register(ok)
        assert mgr.reachability == ReachabilityStatus.OK

    def test_reachability_no_transports(self) -> None:
        """With no transports registered, reachability is UNKNOWN."""
        mgr = TransportManager()
        assert mgr.reachability == ReachabilityStatus.UNKNOWN

    def test_transport_count(self) -> None:
        """transport_count tracks the number of registered transports."""
        mgr = TransportManager()
        assert mgr.transport_count == 0
        mgr.register(MockTransport(TransportStyle.NTCP2))
        assert mgr.transport_count == 1
        mgr.register(MockTransport(TransportStyle.SSU2))
        assert mgr.transport_count == 2

    @pytest.mark.asyncio
    async def test_get_best_transport(self) -> None:
        """get_best_transport() returns the transport with the lowest bid."""
        mgr = TransportManager()
        slow = MockTransport(TransportStyle.NTCP2, bid_latency=200)
        fast = MockTransport(TransportStyle.SSU2, bid_latency=50)
        mgr.register(slow)
        mgr.register(fast)
        await mgr.start_all()

        best = await mgr.get_best_transport(PEER_HASH)
        assert best is fast

    def test_get_addresses_skips_none(self) -> None:
        """Transports with no address should not appear in address list."""
        mgr = TransportManager()
        t = MockTransport(TransportStyle.NTCP2, address=None)
        # The constructor swaps a falsy address for a localhost default,
        # so the "no address" state has to be forced after construction.
        t._address = None
        mgr.register(t)
        addrs = mgr.get_addresses()
        assert len(addrs) == 0