A Python port of the Invisible Internet Project (I2P)
1"""Tests for TransportManager — multi-transport coordinator.
2
3TDD: tests written before implementation.
4"""
5
6import asyncio
7import pytest
8
9from i2p_transport.transport_base import (
10 Transport, TransportBid, TransportStyle, ReachabilityStatus,
11)
12from i2p_transport.manager import TransportManager
13
14
class MockTransport(Transport):
    """In-memory stand-in for a real transport, used to drive TransportManager.

    The mock never performs I/O. It answers ``bid`` with a fixed latency
    chosen at construction and records every ``send`` call in ``_sent``.
    Each ``send`` returns a configurable result. The mock also reports a
    fixed reachability status and address.

    Args:
        style: Value reported by :attr:`style`.
        bid_latency: Latency (ms) placed in the bid returned by :meth:`bid`.
            Pass ``TransportBid.WILL_NOT_SEND`` to make the transport decline.
        reachability: Value reported by :attr:`reachability`.
        address: Dict reported by :attr:`current_address`. ``None`` (or any
            falsy value such as ``{}``) falls back to
            ``{"host": "127.0.0.1", "port": 15000}``. To simulate a transport
            without an address, set ``_address = None`` after construction.
        send_result: Value returned from :meth:`send`. Defaults to ``True``.
    """

    def __init__(
        self,
        style: TransportStyle,
        bid_latency: int = 100,
        reachability: ReachabilityStatus = ReachabilityStatus.OK,
        address: dict | None = None,
        send_result: bool = True,
    ):
        self._style = style
        self._bid_latency = bid_latency
        self._reachability = reachability
        self._running = False
        # Every (peer_hash, data) pair passed to send(), in call order.
        self._sent: list[tuple[bytes, bytes]] = []
        # NOTE: `or` means both None and an empty dict fall back to the default
        # address; kept as-is so existing callers see the same behaviour.
        self._address = address or {"host": "127.0.0.1", "port": 15000}
        self._send_result = send_result

    @property
    def style(self) -> TransportStyle:
        """Transport style given at construction."""
        return self._style

    async def start(self) -> None:
        """Mark the transport as running (no real I/O is started)."""
        self._running = True

    async def stop(self) -> None:
        """Mark the transport as stopped."""
        self._running = False

    @property
    def is_running(self) -> bool:
        """True between :meth:`start` and :meth:`stop`."""
        return self._running

    async def bid(self, peer_hash: bytes) -> TransportBid:
        """Return a bid carrying the configured latency for any peer.

        The ``TransportBid.WILL_NOT_SEND`` sentinel needs no special case.
        It is passed through unchanged as the bid's latency.
        """
        return TransportBid(latency_ms=self._bid_latency, transport=self)

    async def send(self, peer_hash: bytes, data: bytes) -> bool:
        """Record the message in ``_sent`` and return the configured result.

        The call is recorded even when the configured result is ``False``.
        """
        self._sent.append((peer_hash, data))
        return self._send_result

    @property
    def reachability(self) -> ReachabilityStatus:
        """Reachability status given at construction."""
        return self._reachability

    @property
    def current_address(self) -> dict | None:
        """Address dict, or None if ``_address`` was cleared by the caller."""
        return self._address
66
67
68# ---------------------------------------------------------------------------
69# Tests
70# ---------------------------------------------------------------------------
71
class TestTransportManager:
    """Behavioural tests for TransportManager driven by MockTransport instances."""

    # Arbitrary 32-byte peer hash shared by the send/bid tests.
    PEER = b"\x01" * 32

    def test_register_transport(self):
        """A registered transport is retrievable by its style."""
        mgr = TransportManager()
        transport = MockTransport(TransportStyle.NTCP2)
        mgr.register(transport)
        assert mgr.get_transport(TransportStyle.NTCP2) is transport

    def test_register_multiple(self):
        """Transports of different styles are stored independently."""
        mgr = TransportManager()
        ntcp2 = MockTransport(TransportStyle.NTCP2)
        ssu2 = MockTransport(TransportStyle.SSU2)
        mgr.register(ntcp2)
        mgr.register(ssu2)
        assert mgr.get_transport(TransportStyle.NTCP2) is ntcp2
        assert mgr.get_transport(TransportStyle.SSU2) is ssu2

    @pytest.mark.asyncio
    async def test_start_stop_all(self):
        """start_all/stop_all toggle every transport and the manager itself."""
        mgr = TransportManager()
        ntcp2 = MockTransport(TransportStyle.NTCP2)
        ssu2 = MockTransport(TransportStyle.SSU2)
        mgr.register(ntcp2)
        mgr.register(ssu2)

        await mgr.start_all()
        assert ntcp2.is_running
        assert ssu2.is_running
        assert mgr.is_running

        await mgr.stop_all()
        assert not ntcp2.is_running
        assert not ssu2.is_running
        assert not mgr.is_running

    @pytest.mark.asyncio
    async def test_send_selects_best_bid(self):
        """send() routes through the transport with the lowest bid latency."""
        mgr = TransportManager()
        slow = MockTransport(TransportStyle.NTCP2, bid_latency=200)
        fast = MockTransport(TransportStyle.SSU2, bid_latency=50)
        mgr.register(slow)
        mgr.register(fast)
        await mgr.start_all()

        result = await mgr.send(self.PEER, b"hello")
        assert result is True
        # Only the fast transport should have been used.
        assert len(fast._sent) == 1
        assert len(slow._sent) == 0
        assert fast._sent[0] == (self.PEER, b"hello")

    @pytest.mark.asyncio
    async def test_send_no_transport(self):
        """With no transports registered, send() reports failure."""
        mgr = TransportManager()
        result = await mgr.send(self.PEER, b"hello")
        assert result is False

    @pytest.mark.asyncio
    async def test_send_will_not_send(self):
        """A transport bidding WILL_NOT_SEND is never used for delivery."""
        mgr = TransportManager()
        transport = MockTransport(
            TransportStyle.NTCP2, bid_latency=TransportBid.WILL_NOT_SEND,
        )
        mgr.register(transport)
        await mgr.start_all()

        result = await mgr.send(self.PEER, b"hello")
        assert result is False
        assert len(transport._sent) == 0

    def test_get_addresses(self):
        """get_addresses() returns one entry per registered transport."""
        mgr = TransportManager()
        ntcp2 = MockTransport(
            TransportStyle.NTCP2,
            address={"host": "1.2.3.4", "port": 15000},
        )
        ssu2 = MockTransport(
            TransportStyle.SSU2,
            address={"host": "1.2.3.4", "port": 15001},
        )
        mgr.register(ntcp2)
        mgr.register(ssu2)
        addrs = mgr.get_addresses()
        assert len(addrs) == 2
        assert {a["port"] for a in addrs} == {15000, 15001}

    def test_reachability_best_of_all(self):
        """One OK transport outweighs a firewalled one in the aggregate status."""
        mgr = TransportManager()
        firewalled = MockTransport(
            TransportStyle.NTCP2, reachability=ReachabilityStatus.FIREWALLED,
        )
        ok = MockTransport(
            TransportStyle.SSU2, reachability=ReachabilityStatus.OK,
        )
        mgr.register(firewalled)
        mgr.register(ok)
        assert mgr.reachability == ReachabilityStatus.OK

    def test_reachability_no_transports(self):
        """An empty manager reports UNKNOWN reachability."""
        mgr = TransportManager()
        assert mgr.reachability == ReachabilityStatus.UNKNOWN

    def test_transport_count(self):
        """transport_count grows by one per registration of a new style."""
        mgr = TransportManager()
        assert mgr.transport_count == 0
        mgr.register(MockTransport(TransportStyle.NTCP2))
        assert mgr.transport_count == 1
        mgr.register(MockTransport(TransportStyle.SSU2))
        assert mgr.transport_count == 2

    @pytest.mark.asyncio
    async def test_get_best_transport(self):
        """get_best_transport() returns the lowest-latency bidder."""
        mgr = TransportManager()
        slow = MockTransport(TransportStyle.NTCP2, bid_latency=200)
        fast = MockTransport(TransportStyle.SSU2, bid_latency=50)
        mgr.register(slow)
        mgr.register(fast)
        await mgr.start_all()

        best = await mgr.get_best_transport(self.PEER)
        assert best is fast

    def test_get_addresses_skips_none(self):
        """Transports with no address should not appear in address list."""
        mgr = TransportManager()
        transport = MockTransport(TransportStyle.NTCP2)
        # MockTransport substitutes a default when `address` is None, so the
        # attribute is cleared directly to simulate a transport with no address.
        transport._address = None
        mgr.register(transport)
        addrs = mgr.get_addresses()
        assert len(addrs) == 0