# A Python port of the Invisible Internet Project (I2P).
# File from branch `main` (209 lines, 7.9 kB).
"""SSU2 I2NP message fragmentation and reassembly.

Ported from:
- net.i2p.router.transport.udp.OutboundMessageFragments
- net.i2p.router.transport.udp.InboundMessageFragments
- net.i2p.router.transport.udp.InboundMessageState

Large I2NP messages are split into MTU-sized fragments. Each fragment
is a FirstFragment or FollowOnFragment block inside a DATA packet.
"""

from __future__ import annotations

import time
from dataclasses import dataclass, field

from i2p_transport.ssu2_payload import FirstFragmentBlock, FollowOnFragmentBlock

# Maximum number of message bytes carried by a single fragment block.
# NOTE(review): 1280 equals the minimum SSU2/IPv6 MTU itself; no packet or
# block header overhead is subtracted here, so a full fragment plus headers
# may exceed 1280 bytes unless a smaller size is passed -- confirm intent.
DEFAULT_MAX_FRAGMENT_SIZE = 1280  # Conservative MTU


@dataclass
class OutboundMessage:
    """Tracks fragmentation state of an outbound I2NP message.

    ``data`` is split into consecutive ``fragment_size``-byte slices numbered
    from 0; the last slice may be shorter. An empty message still occupies a
    single (empty) fragment.
    """

    msg_id: int
    data: bytes
    fragment_size: int = DEFAULT_MAX_FRAGMENT_SIZE
    # Fragment numbers already handed out for sending.
    fragments_sent: set[int] = field(default_factory=set)
    # Fragment numbers acknowledged by the peer.
    fragments_acked: set[int] = field(default_factory=set)
    created_at: float = field(default_factory=time.monotonic)

    @property
    def total_fragments(self) -> int:
        """Total number of fragments needed (never less than 1)."""
        # Ceiling division; max(1, ...) keeps an empty message sendable.
        return max(1, (len(self.data) + self.fragment_size - 1) // self.fragment_size)

    @property
    def is_complete(self) -> bool:
        """True once every fragment number in ``range(total_fragments)`` is acked.

        The actual fragment numbers are checked, not just the count, so stray
        out-of-range acks can never mark the message complete early.
        """
        return self.fragments_acked.issuperset(range(self.total_fragments))

    def get_fragment(self, fragment_num: int) -> bytes:
        """Return the payload slice for fragment ``fragment_num``."""
        start = fragment_num * self.fragment_size
        end = min(start + self.fragment_size, len(self.data))
        return self.data[start:end]

    def get_unsent_fragments(self) -> list[int]:
        """Return fragment numbers (ascending) that haven't been sent yet."""
        return [i for i in range(self.total_fragments) if i not in self.fragments_sent]

    def get_unacked_fragments(self) -> list[int]:
        """Return sent but not yet acked fragment numbers, ascending."""
        return [i for i in sorted(self.fragments_sent) if i not in self.fragments_acked]


class OutboundFragmenter:
    """Fragments outbound I2NP messages for SSU2 DATA packets."""

    def __init__(self, max_fragment_size: int = DEFAULT_MAX_FRAGMENT_SIZE):
        """Create an empty fragmenter.

        Args:
            max_fragment_size: Maximum payload bytes per fragment block.

        Raises:
            ValueError: If ``max_fragment_size`` is less than 1. A zero size
                would otherwise surface later as ZeroDivisionError when the
                fragment count is computed.
        """
        if max_fragment_size < 1:
            raise ValueError(f"max_fragment_size must be >= 1, got {max_fragment_size}")
        self._messages: dict[int, OutboundMessage] = {}
        # Locally assigned, monotonically increasing message ids.
        self._next_msg_id = 0
        self._max_fragment_size = max_fragment_size

    def add_message(self, data: bytes) -> int:
        """Queue a message for fragmentation. Returns its assigned msg_id."""
        msg_id = self._next_msg_id
        self._next_msg_id += 1
        self._messages[msg_id] = OutboundMessage(
            msg_id=msg_id,
            data=data,
            fragment_size=self._max_fragment_size,
        )
        return msg_id

    def get_next_blocks(self, max_blocks: int = 4) -> list[FirstFragmentBlock | FollowOnFragmentBlock]:
        """Get the next batch of fragment blocks to send.

        Walks queued messages in insertion order and emits up to
        ``max_blocks`` blocks for fragments that have never been handed out;
        each emitted fragment is recorded in ``fragments_sent``. Sent but
        unacked fragments are NOT re-emitted here (see
        ``OutboundMessage.get_unacked_fragments``).

        Returns:
            A list of FirstFragmentBlock (fragment 0) or FollowOnFragmentBlock
            (fragments 1..N-1, with ``is_last`` set on the final one).
        """
        blocks: list[FirstFragmentBlock | FollowOnFragmentBlock] = []
        for msg in self._messages.values():
            if msg.is_complete:
                continue
            for frag_num in msg.get_unsent_fragments():
                if len(blocks) >= max_blocks:
                    return blocks
                frag_data = msg.get_fragment(frag_num)
                msg.fragments_sent.add(frag_num)

                if frag_num == 0:
                    blocks.append(FirstFragmentBlock(
                        msg_id=msg.msg_id,
                        total_fragments=msg.total_fragments,
                        fragment_data=frag_data,
                    ))
                else:
                    blocks.append(FollowOnFragmentBlock(
                        msg_id=msg.msg_id,
                        fragment_num=frag_num,
                        is_last=(frag_num == msg.total_fragments - 1),
                        fragment_data=frag_data,
                    ))
        return blocks

    def ack_fragment(self, msg_id: int, fragment_num: int) -> None:
        """Mark a fragment as acked.

        Acks for unknown messages, or for fragment numbers outside the
        message's ``range(total_fragments)``, are ignored.
        """
        msg = self._messages.get(msg_id)
        if msg is None:
            return
        if 0 <= fragment_num < msg.total_fragments:
            msg.fragments_acked.add(fragment_num)

    def get_completed_messages(self) -> list[int]:
        """Return msg_ids of fully acked messages."""
        return [mid for mid, msg in self._messages.items() if msg.is_complete]

    def cleanup_completed(self) -> None:
        """Remove fully acked messages from the queue."""
        for mid in self.get_completed_messages():
            del self._messages[mid]


@dataclass
class InboundMessage:
    """Tracks reassembly state of an inbound I2NP message."""

    msg_id: int
    # 0 until known: set by the first fragment, or inferred from an
    # is_last follow-on fragment.
    total_fragments: int = 0
    # fragment number -> fragment payload
    fragments: dict[int, bytes] = field(default_factory=dict)
    # Creation time of this state; used by InboundReassembler.cleanup_stale.
    received_at: float = field(default_factory=time.monotonic)

    @property
    def is_complete(self) -> bool:
        """True once the total is known and fragments 0..total-1 are all present.

        Every index is checked individually: counting entries alone would be
        fooled by out-of-range fragment numbers from the peer and then make
        ``reassemble`` fail with KeyError.
        """
        return self.total_fragments > 0 and all(
            i in self.fragments for i in range(self.total_fragments)
        )

    def add_first_fragment(self, total_fragments: int, data: bytes) -> None:
        """Add fragment 0, which also carries the total fragment count."""
        self.total_fragments = total_fragments
        self.fragments[0] = data

    def add_follow_on_fragment(self, fragment_num: int, data: bytes, is_last: bool) -> None:
        """Add a subsequent fragment (``fragment_num`` >= 1).

        A follow-on claiming fragment number 0 (or below) is ignored: slot 0
        belongs to the first fragment and must not be overwritten.
        """
        if fragment_num < 1:
            return
        self.fragments[fragment_num] = data
        # If is_last and we haven't learned the total yet, infer it.
        if is_last and self.total_fragments == 0:
            self.total_fragments = fragment_num + 1

    def reassemble(self) -> bytes:
        """Concatenate fragments 0..total-1 into the complete message.

        Raises:
            ValueError: If the message is not yet complete.
        """
        if not self.is_complete:
            raise ValueError(
                f"Cannot reassemble: have {len(self.fragments)}/{self.total_fragments} fragments"
            )
        # join is linear; repeated bytes += would be quadratic.
        return b"".join(self.fragments[i] for i in range(self.total_fragments))


class InboundReassembler:
    """Reassembles inbound fragmented I2NP messages.

    Completed messages are removed as soon as they are returned, so a
    fragment that arrives later for the same msg_id starts a fresh entry
    (a repeated single-fragment FirstFragment would be delivered again).
    NOTE(review): there is no duplicate suppression here -- confirm whether
    callers filter retransmitted fragments.
    """

    def __init__(self, timeout_seconds: float = 60.0):
        """Create an empty reassembler.

        Args:
            timeout_seconds: Age, measured from when an entry was created,
                after which ``cleanup_stale`` drops an incomplete message.
        """
        self._messages: dict[int, InboundMessage] = {}
        self._timeout = timeout_seconds

    def _get_or_create(self, msg_id: int) -> InboundMessage:
        """Return the in-progress state for ``msg_id``, creating it if absent."""
        msg = self._messages.get(msg_id)
        if msg is None:
            msg = InboundMessage(msg_id=msg_id)
            self._messages[msg_id] = msg
        return msg

    def _pop_if_complete(self, msg: InboundMessage) -> bytes | None:
        """Reassemble and forget ``msg`` if complete; otherwise return None."""
        if not msg.is_complete:
            return None
        result = msg.reassemble()
        del self._messages[msg.msg_id]
        return result

    def process_first_fragment(self, block: FirstFragmentBlock) -> bytes | None:
        """Process a FirstFragment block.

        Returns the complete message if this fragment completes it (e.g. a
        single-fragment message), otherwise None.
        """
        msg = self._get_or_create(block.msg_id)
        msg.add_first_fragment(block.total_fragments, block.fragment_data)
        return self._pop_if_complete(msg)

    def process_follow_on_fragment(self, block: FollowOnFragmentBlock) -> bytes | None:
        """Process a FollowOnFragment block.

        Returns the complete message if this fragment completes it,
        otherwise None.
        """
        msg = self._get_or_create(block.msg_id)
        msg.add_follow_on_fragment(block.fragment_num, block.fragment_data, block.is_last)
        return self._pop_if_complete(msg)

    def cleanup_stale(self) -> int:
        """Remove timed-out incomplete messages. Returns the count removed."""
        now = time.monotonic()
        stale = [
            mid for mid, msg in self._messages.items()
            if (now - msg.received_at) >= self._timeout
        ]
        for mid in stale:
            del self._messages[mid]
        return len(stale)