A Python port of the Invisible Internet Project (I2P)
1"""SSU2 I2NP message fragmentation and reassembly.
2
3Ported from:
4- net.i2p.router.transport.udp.OutboundMessageFragments
5- net.i2p.router.transport.udp.InboundMessageFragments
6- net.i2p.router.transport.udp.InboundMessageState
7
8Large I2NP messages are split into MTU-sized fragments. Each fragment
9is a FirstFragment or FollowOnFragment block inside a DATA packet.
10"""
11
12from __future__ import annotations
13
14import time
15from dataclasses import dataclass, field
16
17from i2p_transport.ssu2_payload import FirstFragmentBlock, FollowOnFragmentBlock
18
# Maximum number of payload bytes placed in a single fragment. 1280 is a
# deliberately conservative figure (it equals the IPv6 minimum MTU).
# NOTE(review): no packet/block header overhead is subtracted from this value
# here — confirm callers account for headers when picking a fragment size.
DEFAULT_MAX_FRAGMENT_SIZE: int = 1280
21
22
@dataclass
class OutboundMessage:
    """Fragmentation and acknowledgement state of one outbound I2NP message.

    ``data`` is cut into consecutive slices of ``fragment_size`` bytes;
    fragment ``i`` covers ``data[i * fragment_size:(i + 1) * fragment_size]``.
    Fragment numbers that have been handed out for sending are recorded in
    ``fragments_sent``; acknowledged ones in ``fragments_acked``.

    Raises:
        ValueError: if ``fragment_size`` is not positive.
    """

    msg_id: int
    data: bytes
    fragment_size: int = DEFAULT_MAX_FRAGMENT_SIZE
    fragments_sent: set[int] = field(default_factory=set)
    fragments_acked: set[int] = field(default_factory=set)
    # Creation time on the time.monotonic() clock (not wall-clock time).
    created_at: float = field(default_factory=time.monotonic)

    def __post_init__(self) -> None:
        # A non-positive size would make total_fragments divide by zero or
        # yield a nonsensical count, so reject it at construction time.
        if self.fragment_size <= 0:
            raise ValueError(
                f"fragment_size must be positive, got {self.fragment_size}"
            )

    @property
    def total_fragments(self) -> int:
        """Number of fragments needed; at least 1, even for empty ``data``."""
        return max(1, (len(self.data) + self.fragment_size - 1) // self.fragment_size)

    @property
    def is_complete(self) -> bool:
        """True once every fragment number in ``range(total_fragments)`` is acked.

        Counting acks is not enough: acks for fragment numbers outside that
        range would otherwise inflate the count and mark the message complete
        while real fragments are still unacknowledged.
        """
        return self.fragments_acked.issuperset(range(self.total_fragments))

    def get_fragment(self, fragment_num: int) -> bytes:
        """Return the payload slice for fragment ``fragment_num``.

        ``fragment_num`` is expected to be in ``range(total_fragments)``;
        numbers past the end yield ``b""``.
        """
        start = fragment_num * self.fragment_size
        end = min(start + self.fragment_size, len(self.data))
        return self.data[start:end]

    def get_unsent_fragments(self) -> list[int]:
        """Return, in ascending order, fragment numbers not yet handed out for sending."""
        return [i for i in range(self.total_fragments) if i not in self.fragments_sent]

    def get_unacked_fragments(self) -> list[int]:
        """Return, in ascending order, fragment numbers sent but not yet acked."""
        return [i for i in sorted(self.fragments_sent) if i not in self.fragments_acked]
56
57
class OutboundFragmenter:
    """Fragments queued outbound I2NP messages into SSU2 fragment blocks.

    Each queued payload is tracked as an ``OutboundMessage`` keyed by a locally
    assigned, increasing ``msg_id``. ``get_next_blocks`` emits blocks only for
    fragments that have never been sent; this class does not itself schedule
    retransmission of sent-but-unacked fragments.
    """

    def __init__(self, max_fragment_size: int = DEFAULT_MAX_FRAGMENT_SIZE):
        # msg_id -> per-message fragmentation state, in insertion order.
        self._messages: dict[int, OutboundMessage] = {}
        self._next_msg_id = 0
        self._max_fragment_size = max_fragment_size

    def add_message(self, data: bytes) -> int:
        """Queue ``data`` for fragmentation and return the assigned msg_id."""
        msg_id = self._next_msg_id
        self._next_msg_id += 1
        self._messages[msg_id] = OutboundMessage(
            msg_id=msg_id, data=data,
            fragment_size=self._max_fragment_size,
        )
        return msg_id

    def get_next_blocks(self, max_blocks: int = 4) -> list[FirstFragmentBlock | FollowOnFragmentBlock]:
        """Return up to ``max_blocks`` fragment blocks for not-yet-sent fragments.

        Messages are visited in queue (insertion) order and fully acked ones
        are skipped. Every fragment emitted is recorded as sent. Fragment 0
        becomes a ``FirstFragmentBlock`` carrying the total fragment count;
        later fragments become ``FollowOnFragmentBlock`` with ``is_last`` set
        on the final fragment.
        """
        blocks: list[FirstFragmentBlock | FollowOnFragmentBlock] = []
        for msg in self._messages.values():
            if msg.is_complete:
                continue
            for frag_num in msg.get_unsent_fragments():
                if len(blocks) >= max_blocks:
                    return blocks
                frag_data = msg.get_fragment(frag_num)
                msg.fragments_sent.add(frag_num)

                if frag_num == 0:
                    blocks.append(FirstFragmentBlock(
                        msg_id=msg.msg_id,
                        total_fragments=msg.total_fragments,
                        fragment_data=frag_data,
                    ))
                else:
                    is_last = (frag_num == msg.total_fragments - 1)
                    blocks.append(FollowOnFragmentBlock(
                        msg_id=msg.msg_id,
                        fragment_num=frag_num,
                        is_last=is_last,
                        fragment_data=frag_data,
                    ))
        return blocks

    def ack_fragment(self, msg_id: int, fragment_num: int) -> None:
        """Record an ACK for fragment ``fragment_num`` of message ``msg_id``.

        ACKs for unknown msg_ids (e.g. already cleaned up) and for fragment
        numbers outside ``range(total_fragments)`` are ignored, so a bogus
        fragment number cannot make a message appear fully acknowledged.
        """
        msg = self._messages.get(msg_id)
        if msg is None:
            return
        if 0 <= fragment_num < msg.total_fragments:
            msg.fragments_acked.add(fragment_num)

    def get_completed_messages(self) -> list[int]:
        """Return msg_ids of messages whose fragments are all acked."""
        return [mid for mid, msg in self._messages.items() if msg.is_complete]

    def cleanup_completed(self) -> None:
        """Drop the state of every fully acked message."""
        # Collect first: deleting while iterating the dict would raise.
        for mid in self.get_completed_messages():
            del self._messages[mid]
124
125
@dataclass
class InboundMessage:
    """Reassembly state of one inbound, possibly fragmented, I2NP message.

    ``total_fragments`` stays 0 until it is learned, either from the first
    fragment or inferred from a follow-on fragment flagged as the last one.
    """

    msg_id: int
    total_fragments: int = 0
    # Fragment number -> fragment payload.
    fragments: dict[int, bytes] = field(default_factory=dict)
    # Creation time on the time.monotonic() clock.
    received_at: float = field(default_factory=time.monotonic)

    @property
    def is_complete(self) -> bool:
        """True when the count is known and fragments ``0..total-1`` are all present.

        Comparing ``len(fragments)`` with the total is not sufficient: a
        fragment numbered at or beyond ``total_fragments`` would inflate the
        count while a real fragment is still missing, and ``reassemble`` would
        then fail on the gap.
        """
        return self.total_fragments > 0 and all(
            i in self.fragments for i in range(self.total_fragments)
        )

    def add_first_fragment(self, total_fragments: int, data: bytes) -> None:
        """Store fragment 0 and take the fragment count it announces."""
        self.total_fragments = total_fragments
        self.fragments[0] = data

    def add_follow_on_fragment(self, fragment_num: int, data: bytes, is_last: bool) -> None:
        """Store a follow-on fragment.

        If it is flagged as last and the first fragment has not yet supplied
        the count, the total is inferred as ``fragment_num + 1``.
        """
        self.fragments[fragment_num] = data
        if is_last and self.total_fragments == 0:
            self.total_fragments = fragment_num + 1

    def reassemble(self) -> bytes:
        """Concatenate fragments ``0..total_fragments-1`` in order.

        Raises:
            ValueError: if the message is not complete.
        """
        if not self.is_complete:
            raise ValueError(
                f"Cannot reassemble: have {len(self.fragments)}/{self.total_fragments} fragments"
            )
        # join is linear; repeated bytes += is quadratic in the message size.
        return b"".join(self.fragments[i] for i in range(self.total_fragments))
161
162
class InboundReassembler:
    """Reassembles inbound fragmented I2NP messages.

    Partial messages are kept per ``msg_id``. As soon as a message is complete
    its bytes are returned and its state is dropped; incomplete state is only
    discarded by :meth:`cleanup_stale`, which the owner must call.

    NOTE(review): because completed state is deleted, a retransmitted fragment
    carrying an already-delivered ``msg_id`` starts a fresh entry (and a
    retransmitted single-fragment message is returned again). No duplicate
    suppression happens here — confirm whether callers handle it.
    """

    def __init__(self, timeout_seconds: float = 60.0):
        # msg_id -> partially reassembled message.
        self._messages: dict[int, InboundMessage] = {}
        # Age in seconds (time.monotonic clock) after which partial state is stale.
        self._timeout = timeout_seconds

    def _get_or_create(self, msg_id: int) -> InboundMessage:
        """Return the pending state for ``msg_id``, creating it on first sight."""
        msg = self._messages.get(msg_id)
        if msg is None:
            msg = InboundMessage(msg_id=msg_id)
            self._messages[msg_id] = msg
        return msg

    def _pop_if_complete(self, msg_id: int, msg: InboundMessage) -> bytes | None:
        """Reassemble ``msg`` and forget it if complete; otherwise return None."""
        if not msg.is_complete:
            return None
        result = msg.reassemble()
        del self._messages[msg_id]
        return result

    def process_first_fragment(self, block: FirstFragmentBlock) -> bytes | None:
        """Process a FirstFragment block.

        Returns the complete message if this block completes it (e.g. a
        single-fragment message), otherwise None.
        """
        msg = self._get_or_create(block.msg_id)
        msg.add_first_fragment(block.total_fragments, block.fragment_data)
        return self._pop_if_complete(block.msg_id, msg)

    def process_follow_on_fragment(self, block: FollowOnFragmentBlock) -> bytes | None:
        """Process a FollowOnFragment block.

        Returns the complete message if this block completes it, otherwise None.
        """
        msg = self._get_or_create(block.msg_id)
        msg.add_follow_on_fragment(block.fragment_num, block.fragment_data, block.is_last)
        return self._pop_if_complete(block.msg_id, msg)

    def cleanup_stale(self) -> int:
        """Drop incomplete messages older than the timeout; return how many were dropped.

        Age is measured from each entry's ``received_at`` timestamp.
        """
        now = time.monotonic()
        # Collect first: deleting while iterating the dict would raise.
        stale = [
            mid for mid, msg in self._messages.items()
            if (now - msg.received_at) >= self._timeout
        ]
        for mid in stale:
            del self._messages[mid]
        return len(stale)