# A Python port of the Invisible Internet Project (I2P)
"""SSU2 ACK bitfield -- efficient tracking of received packet numbers.

Ported from net.i2p.router.transport.udp.SSU2Bitfield.
"""
5
6from __future__ import annotations
7
8
class SSU2Bitfield:
    """Tracks which packet numbers have been received/acked.

    Uses a compact bitfield with an offset that shifts forward
    as older packets are acknowledged. Converts to/from SSU2
    ACK block format (ranges of acked/not-acked).

    Bits are stored LSB-first: packet ``offset + i`` lives in byte
    ``i >> 3`` under mask ``1 << (i & 7)``.
    """

    def __init__(self, initial_capacity: int = 256):
        """Create an empty bitfield.

        Args:
            initial_capacity: Number of packet numbers that can be tracked
                before the storage has to grow. Values below 1 are treated
                as 1, because growth works by doubling and a non-positive
                capacity could never grow.
        """
        capacity = max(initial_capacity, 1)
        self._bits = bytearray((capacity + 7) // 8)
        self._capacity = capacity
        self._offset: int = 0  # Lowest tracked packet number
        self._highest: int = -1  # Highest packet number ever set, -1 if none
        self._count: int = 0  # Number of bits currently set

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _ensure_capacity(self, packet_num: int) -> None:
        """Grow the bitfield if *packet_num* exceeds current capacity."""
        idx = packet_num - self._offset
        # Guard against a non-positive capacity: doubling it would never
        # terminate the loop below.
        capacity = max(self._capacity, 1)
        while idx >= capacity:
            capacity *= 2
        self._capacity = capacity
        missing = (capacity + 7) // 8 - len(self._bits)
        if missing > 0:
            self._bits.extend(bytes(missing))

    def _bit_index(self, packet_num: int) -> tuple[int, int]:
        """Return (byte_index, bit_mask) for *packet_num*."""
        idx = packet_num - self._offset
        return idx >> 3, 1 << (idx & 7)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def set(self, packet_num: int) -> None:
        """Mark *packet_num* as received.

        Packet numbers below the current offset are silently ignored.
        Storage grows as needed for numbers beyond the current capacity.
        """
        if packet_num < self._offset:
            return  # Already shifted out
        self._ensure_capacity(packet_num)
        byte_idx, mask = self._bit_index(packet_num)
        if not (self._bits[byte_idx] & mask):
            self._bits[byte_idx] |= mask
            self._count += 1
            if packet_num > self._highest:
                self._highest = packet_num

    def get(self, packet_num: int) -> bool:
        """Return ``True`` if *packet_num* has been received."""
        idx = packet_num - self._offset
        if idx < 0 or idx >= self._capacity:
            return False
        byte_idx, mask = self._bit_index(packet_num)
        if byte_idx >= len(self._bits):
            return False
        return bool(self._bits[byte_idx] & mask)

    def get_highest(self) -> int:
        """Return the highest received packet number, or -1 if none.

        The value is not reset by :meth:`shift_offset`, so it may be lower
        than the current offset after a shift.
        """
        return self._highest

    def shift_offset(self, new_offset: int) -> None:
        """Advance offset, discarding entries below *new_offset*.

        Does nothing when *new_offset* is not greater than the current
        offset.
        """
        if new_offset <= self._offset:
            return
        shift = new_offset - self._offset

        # Count bits being discarded. Only indices below the capacity can
        # hold a set bit, so bound the scan by the capacity instead of
        # walking the whole (possibly enormous) shift distance.
        scan_end = min(new_offset, self._offset + self._capacity)
        for pn in range(self._offset, scan_end):
            if self.get(pn):
                self._count -= 1

        # Shift the bytearray
        shift_bytes = shift >> 3
        shift_bits = shift & 7

        if shift_bytes >= len(self._bits):
            # Everything is shifted out; start over with cleared storage.
            self._bits = bytearray((self._capacity + 7) // 8)
        else:
            if shift_bytes > 0:
                # Drop whole bytes from the front, pad zeros at the end.
                self._bits = self._bits[shift_bytes:] + bytearray(shift_bytes)
            if shift_bits > 0:
                # Walk from the last byte down so that the low bits of
                # byte i+1 become the high bits of byte i.
                low_mask = (1 << shift_bits) - 1
                carry = 0
                for i in range(len(self._bits) - 1, -1, -1):
                    current = self._bits[i]
                    self._bits[i] = (current >> shift_bits) | carry
                    carry = (current & low_mask) << (8 - shift_bits)

        self._offset = new_offset

    def to_ack_blocks(self) -> list[tuple[int, int]]:
        """Convert to ACK block ranges: list of (ack_count, nack_count).

        Starting from highest, alternates between acked and not-acked ranges
        and stops at the current offset.
        The last tuple may have nack_count == 0.
        Returns an empty list when the highest packet is below the offset.
        """
        if self._highest < self._offset:
            return []

        blocks: list[tuple[int, int]] = []
        pos = self._highest
        while pos >= self._offset:
            # Count acked
            ack_count = 0
            while pos >= self._offset and self.get(pos):
                ack_count += 1
                pos -= 1

            # Count nacked
            nack_count = 0
            while pos >= self._offset and not self.get(pos):
                nack_count += 1
                pos -= 1

            blocks.append((ack_count, nack_count))

        return blocks

    @classmethod
    def from_ack_blocks(cls, through: int, blocks: list[tuple[int, int]]) -> "SSU2Bitfield":
        """Reconstruct bitfield from ACK blocks received in a packet.

        Args:
            through: Highest acknowledged packet number.
            blocks: List of (ack_count, nack_count) tuples, starting
                from *through* and counting downward.

        Returns:
            A new bitfield with offset 0 in which every acked packet number
            that is >= 0 is set.
        """
        # NOTE(review): capacity (and therefore memory) is proportional to
        # *through*; callers handling untrusted packets may want to bound it.
        bf = cls(initial_capacity=through + 1 if through >= 0 else 256)
        pos = through
        for ack_count, nack_count in blocks:
            acked = max(ack_count, 0)
            # Only packet numbers >= 0 can be set, so stop the run at -1
            # rather than iterating ack_count times; a malformed block with
            # a huge ack_count must not spin the CPU.
            for pn in range(pos, max(pos - acked, -1), -1):
                bf.set(pn)
            pos -= acked
            pos -= nack_count  # Skip nacked
        return bf

    def __len__(self) -> int:
        """Number of received packets tracked."""
        return self._count

    def __contains__(self, packet_num: int) -> bool:
        """Return ``True`` if *packet_num* has been received."""
        return self.get(packet_num)