A Python port of the Invisible Internet Project (I2P)
at main 155 lines 5.5 kB view raw
"""SSU2 ACK bitfield -- efficient tracking of received packet numbers.

Ported from net.i2p.router.transport.udp.SSU2Bitfield.
"""

from __future__ import annotations


class SSU2Bitfield:
    """Tracks which packet numbers have been received/acked.

    Uses a compact bitfield with an offset that shifts forward
    as older packets are acknowledged. Converts to/from SSU2
    ACK block format (ranges of acked/not-acked).

    Layout: packet ``offset + i`` is stored in bit ``i & 7`` (LSB first)
    of byte ``i >> 3``. The buffer is therefore a little-endian integer
    whose bit ``i`` corresponds to packet ``offset + i``.
    """

    def __init__(self, initial_capacity: int = 256):
        self._bits = bytearray((initial_capacity + 7) // 8)
        # Number of packet slots, starting at _offset, that may hold a bit.
        self._capacity = initial_capacity
        self._offset: int = 0  # Lowest tracked packet number
        self._highest: int = -1  # Highest packet number ever set; -1 if none
        self._count: int = 0  # Number of set bits at or above _offset

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _ensure_capacity(self, packet_num: int) -> None:
        """Grow the bitfield if *packet_num* exceeds current capacity.

        Capacity doubles until it covers *packet_num*. A zero or negative
        capacity is first bumped to 1, otherwise doubling never terminates.
        """
        idx = packet_num - self._offset
        if idx < self._capacity:
            return
        capacity = max(self._capacity, 1)
        while idx >= capacity:
            capacity *= 2
        self._capacity = capacity
        needed_bytes = (capacity + 7) // 8
        if needed_bytes > len(self._bits):
            self._bits.extend(bytes(needed_bytes - len(self._bits)))

    def _bit_index(self, packet_num: int) -> tuple[int, int]:
        """Return (byte_index, bit_mask) for *packet_num*."""
        idx = packet_num - self._offset
        return idx >> 3, 1 << (idx & 7)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def set(self, packet_num: int) -> None:
        """Mark *packet_num* as received.

        Numbers below the current offset are ignored (already shifted out).
        Setting an already-set number does not change the count.
        """
        if packet_num < self._offset:
            return  # Already shifted out
        self._ensure_capacity(packet_num)
        byte_idx, mask = self._bit_index(packet_num)
        if not (self._bits[byte_idx] & mask):
            self._bits[byte_idx] |= mask
            self._count += 1
        if packet_num > self._highest:
            self._highest = packet_num

    def get(self, packet_num: int) -> bool:
        """Return ``True`` if *packet_num* has been received.

        Numbers below the offset or beyond the capacity report ``False``.
        """
        if packet_num < self._offset:
            return False
        idx = packet_num - self._offset
        if idx >= self._capacity:
            return False
        byte_idx, mask = self._bit_index(packet_num)
        if byte_idx >= len(self._bits):
            return False
        return bool(self._bits[byte_idx] & mask)

    def get_highest(self) -> int:
        """Return the highest received packet number, or -1 if none."""
        return self._highest

    def shift_offset(self, new_offset: int) -> None:
        """Advance offset, discarding entries below *new_offset*.

        Does nothing if *new_offset* is not above the current offset.
        Runs in time proportional to the buffer size, not to the distance
        shifted, so large jumps are cheap.
        """
        if new_offset <= self._offset:
            return
        shift = new_offset - self._offset
        nbytes = len(self._bits)

        if shift >= nbytes * 8:
            # Every stored bit falls below the new offset.
            self._count = 0
            self._bits = bytearray((self._capacity + 7) // 8)
        else:
            # The LSB-first byte layout is exactly a little-endian integer,
            # so shifting the window is a single right shift.
            value = int.from_bytes(self._bits, "little")
            dropped = value & ((1 << shift) - 1)
            self._count -= bin(dropped).count("1")
            self._bits = bytearray((value >> shift).to_bytes(nbytes, "little"))

        self._offset = new_offset

    def to_ack_blocks(self) -> list[tuple[int, int]]:
        """Convert to ACK block ranges: list of (ack_count, nack_count).

        Starting from highest, alternates between acked and not-acked ranges
        down to the current offset. The last tuple may have nack_count == 0.
        Returns an empty list when nothing at or above the offset is set.
        """
        if self._highest < self._offset:
            return []

        blocks: list[tuple[int, int]] = []
        pos = self._highest
        while pos >= self._offset:
            # Count acked
            ack_count = 0
            while pos >= self._offset and self.get(pos):
                ack_count += 1
                pos -= 1

            # Count nacked
            nack_count = 0
            while pos >= self._offset and not self.get(pos):
                nack_count += 1
                pos -= 1

            blocks.append((ack_count, nack_count))

        return blocks

    @classmethod
    def from_ack_blocks(cls, through: int, blocks: list[tuple[int, int]]) -> SSU2Bitfield:
        """Reconstruct bitfield from ACK blocks received in a packet.

        The returned bitfield's offset is the lowest packet number the blocks
        reach, clamped at 0. Its storage therefore covers only
        ``[low, through]`` rather than ``[0, through]``. This keeps memory
        bounded by the size of the blocks, not by the (peer-supplied) value
        of *through*, and makes ``to_ack_blocks()`` reproduce *blocks*.

        Args:
            through: Highest acknowledged packet number.
            blocks: List of (ack_count, nack_count) tuples, starting
                from *through* and counting downward.
        """
        blocks = list(blocks)  # iterated twice below

        # Pass 1: find the lowest packet number touched by any range.
        low = pos = through
        for ack_count, nack_count in blocks:
            pos -= max(ack_count, 0)  # a negative ack count sets nothing
            low = min(low, pos + 1)
            pos -= nack_count
            low = min(low, pos + 1)
        low = max(low, 0)  # negative packet numbers are never stored

        bf = cls(initial_capacity=through - low + 1 if through >= 0 else 256)
        bf._offset = low

        # Pass 2: set the acked packets. set() ignores anything below low.
        pos = through
        for ack_count, nack_count in blocks:
            for _ in range(ack_count):
                bf.set(pos)
                pos -= 1
            pos -= nack_count  # Skip nacked
        return bf

    def __len__(self) -> int:
        """Number of received packets tracked."""
        return self._count

    def __contains__(self, packet_num: int) -> bool:
        return self.get(packet_num)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(offset={self._offset}, "
            f"highest={self._highest}, count={self._count})"
        )