A Python port of the Invisible Internet Project (I2P)
1"""KBucketSet — main Kademlia routing table.
2
3Ported from net.i2p.kademlia.KBucketSet.
4"""
5
import heapq
import os
import threading
import time
from typing import Collection, List, Optional, Set

from i2p_kademlia.kbucket import KBucket, KBucketTrimmer
from i2p_kademlia.trimmers import RandomTrimmer
from i2p_kademlia.xor_comparator import XORComparator, xor_bytes, highest_bit
13
14
class KBucketSet:
    """Kademlia routing table using the XOR metric.

    The keyspace is partitioned into numbered *ranges* derived from the XOR
    distance between a key and our own key (range 0 is closest to us).  Each
    ``KBucket`` covers a contiguous span of ranges.  The table starts with a
    single bucket covering every range; a bucket that holds more than
    ``max_per_bucket`` entries while still covering more than one range is
    split in half.

    With ``b_value == 1`` a key's range is the index of the highest differing
    bit.  With ``b_value > 1`` each such bit position is subdivided by the
    next ``b_value - 1`` bits of the distance, giving
    ``key_bits * 2**(b_value - 1)`` ranges in total.

    Changes to the bucket list are made under an internal re-entrant lock.

    Args:
        us: Our own key (bytes).
        max_per_bucket: Kademlia K value.
        b_value: Kademlia B value (1 for BitTorrent, 5 for paper); must be >= 1.
        trimmer: Optional eviction strategy (default: RandomTrimmer).

    Raises:
        ValueError: If ``b_value`` is less than 1.
    """

    def __init__(
        self,
        us: bytes,
        max_per_bucket: int = 20,
        b_value: int = 1,
        trimmer: "Optional[KBucketTrimmer]" = None,
    ) -> None:
        if b_value < 1:
            raise ValueError(f"b_value must be >= 1, got {b_value}")
        self._us = us
        self._key_bits = len(us) * 8
        self._max_per_bucket = max_per_bucket
        self._b_value = b_value
        # Number of sub-ranges per highest-differing-bit position.
        self._b_factor = 1 << (b_value - 1)
        self._num_buckets = self._key_bits * self._b_factor
        self._trimmer = (
            trimmer if trimmer is not None else RandomTrimmer(max_per_bucket)
        )
        self._lock = threading.RLock()

        # Start with one bucket covering the entire range
        initial = KBucket(0, self._num_buckets - 1, max_per_bucket, self._trimmer)
        self._buckets: List[KBucket] = [initial]

    def _get_range(self, key: bytes) -> int:
        """Return the range number for *key*: 0 .. num_buckets-1, or -1 if key is us.

        The base value is the highest differing bit of ``us XOR key``.  When
        ``b_value > 1`` it is shifted left by ``b_value - 1`` and the next
        ``b_value - 1`` bits below the highest bit are added as a sub-range
        (0 if there are not that many lower bits).
        """
        if key == self._us:
            return -1
        distance = xor_bytes(self._us, key)
        hb = highest_bit(distance)
        if hb < 0:
            return -1
        rng = hb
        if self._b_value > 1:
            sub_bits = self._b_value - 1
            rng = hb << sub_bits
            shift = hb - sub_bits
            if shift >= 0:
                # NOTE(review): assumes highest_bit() counts from the least
                # significant bit of the big-endian distance, matching the
                # bit->byte mapping used by _generate_random_key.
                rng += (int.from_bytes(distance, "big") >> shift) & (self._b_factor - 1)
        # Clamp defensively (e.g. for keys longer than ours).
        return min(rng, self._num_buckets - 1)

    def _find_bucket_index(self, range_val: int) -> int:
        """Return the index of the bucket whose span contains ``range_val``.

        Falls back to 0 if no bucket matches; this should not happen because
        the buckets always tile 0 .. num_buckets-1 contiguously.
        """
        with self._lock:
            for i, b in enumerate(self._buckets):
                if b.range_begin <= range_val <= b.range_end:
                    return i
        return 0

    def _get_bucket(self, key: bytes) -> "Optional[KBucket]":
        """Return the bucket responsible for *key*, or None if key is us.

        The index lookup and the list access happen under one lock
        acquisition so a concurrent split cannot invalidate the index.
        """
        r = self._get_range(key)
        if r < 0:
            return None
        with self._lock:
            return self._buckets[self._find_bucket_index(r)]

    def add(self, peer: bytes) -> bool:
        """Add a peer to the routing table.

        Returns:
            The result of the owning bucket's ``add`` (True if the peer was
            new), or False if *peer* is our own key.
        """
        if peer == self._us:
            return False
        r = self._get_range(peer)
        if r < 0:
            return False

        with self._lock:
            idx = self._find_bucket_index(r)
            bucket = self._buckets[idx]
            added = bucket.add(peer)

            # Split while the bucket holding this peer is overfull and still
            # covers more than one range; all entries may land in one half,
            # hence the loop.
            while (
                bucket.key_count > self._max_per_bucket
                and bucket.range_begin != bucket.range_end
            ):
                self._split(idx)
                idx = self._find_bucket_index(r)
                bucket = self._buckets[idx]

        return added

    def _split(self, idx: int) -> None:
        """Replace the bucket at *idx* with its lower and upper halves.

        Entries are redistributed by their range.  Caller must hold
        ``self._lock``.
        """
        bucket = self._buckets[idx]
        begin = bucket.range_begin
        end = bucket.range_end
        if begin == end:
            return

        mid = (begin + end) // 2
        lower = KBucket(begin, mid, self._max_per_bucket, self._trimmer)
        upper = KBucket(mid + 1, end, self._max_per_bucket, self._trimmer)

        for entry in bucket.get_entries():
            if self._get_range(entry) <= mid:
                lower.add(entry)
            else:
                upper.add(entry)

        self._buckets[idx:idx + 1] = [lower, upper]

    def remove(self, entry: bytes) -> bool:
        """Remove an entry; return the owning bucket's ``remove`` result.

        Lookup and removal are done under the lock so a concurrent split
        cannot leave the entry behind in a replaced bucket.
        """
        with self._lock:
            bucket = self._get_bucket(entry)
            if bucket is None:
                return False
            return bucket.remove(entry)

    def size(self) -> int:
        """Total entries across all buckets."""
        with self._lock:
            return sum(b.key_count for b in self._buckets)

    def clear(self) -> None:
        """Remove all entries (the bucket layout is kept)."""
        with self._lock:
            for b in self._buckets:
                b.clear()

    def get_all(self, to_ignore: Optional[Collection[bytes]] = None) -> Set[bytes]:
        """Return a new set of all entries, minus any in *to_ignore*.

        *to_ignore* may be any iterable collection (set, list, tuple, ...).
        """
        with self._lock:
            result: Set[bytes] = {
                entry for b in self._buckets for entry in b.get_entries()
            }
        if to_ignore:
            result.difference_update(to_ignore)
        return result

    def get_closest(
        self,
        key: bytes,
        max_count: int,
        to_ignore: Optional[Collection[bytes]] = None,
    ) -> List[bytes]:
        """Return up to *max_count* entries closest to *key*, closest first.

        Distance is determined by ``XORComparator(key)``.  Entries in
        *to_ignore* are excluded.  A non-positive *max_count* yields ``[]``.
        """
        if max_count <= 0:
            return []
        candidates = self.get_all(to_ignore)
        comp = XORComparator(key)
        # Same result as sorted(...)[:max_count] without sorting everything.
        return heapq.nsmallest(max_count, candidates, key=comp.key_func())

    def get_closest_to_us(
        self,
        max_count: int,
        to_ignore: Optional[Collection[bytes]] = None,
    ) -> List[bytes]:
        """Return up to *max_count* entries closest to our own key."""
        return self.get_closest(self._us, max_count, to_ignore)

    def get_buckets(self) -> "List[KBucket]":
        """Return a shallow copy of the bucket list (for testing)."""
        with self._lock:
            return list(self._buckets)

    def get_explore_keys(self, max_age_ms: int) -> List[bytes]:
        """Generate random lookup keys for stale or sparse buckets.

        A key is produced for each bucket that either has not changed in
        more than *max_age_ms* milliseconds (per ``bucket.last_changed``) or
        holds fewer than 75% of ``max_per_bucket`` entries.  Buckets for which
        no key can be generated are skipped.
        """
        result: List[bytes] = []
        with self._lock:
            now = int(time.time() * 1000)
            for bucket in self._buckets:
                is_stale = (now - bucket.last_changed) > max_age_ms
                is_sparse = bucket.key_count < (self._max_per_bucket * 3 // 4)
                if is_stale or is_sparse:
                    rkey = self._generate_random_key(bucket)
                    if rkey is not None:
                        result.append(rkey)
        return result

    def _generate_random_key(self, bucket: "KBucket") -> Optional[bytes]:
        """Generate a random key whose range falls within *bucket*.

        The target is the midpoint of the bucket's range span.  The XOR
        distance from ``us`` gets its highest set bit at the target's bit
        position, the next ``b_value - 1`` bits set to the target's
        sub-range, and random bits below; the key is ``us XOR distance``.

        Returns:
            The key, or None if the target range lies outside the keyspace or
            no key can map into the bucket (possible only with
            ``b_value > 1`` for the few ranges closest to us).
        """
        target = (bucket.range_begin + bucket.range_end) // 2
        if not 0 <= target < self._num_buckets:
            return None

        sub_bits = self._b_value - 1
        bit = target >> sub_bits
        shift = bit - sub_bits
        if shift >= 0:
            distance = (1 << bit) | ((target & (self._b_factor - 1)) << shift)
            random_bits = shift
        else:
            # Too few lower bits to encode a sub-range: _get_range maps every
            # key with this highest bit to sub-range 0 of this bit position.
            if (bit << sub_bits) < bucket.range_begin:
                return None
            distance = 1 << bit
            random_bits = bit

        key_len = len(self._us)
        noise = int.from_bytes(os.urandom(key_len), "big") & ((1 << random_bits) - 1)
        distance |= noise
        key = int.from_bytes(self._us, "big") ^ distance
        return key.to_bytes(key_len, "big")