# A Python port of the Invisible Internet Project (I2P)
"""KBucketSet — main Kademlia routing table.

Ported from net.i2p.kademlia.KBucketSet.
"""

import os
import threading
import time
from typing import Collection, List, Optional, Set

from i2p_kademlia.kbucket import KBucket, KBucketTrimmer
from i2p_kademlia.trimmers import RandomTrimmer
from i2p_kademlia.xor_comparator import XORComparator, xor_bytes, highest_bit


class KBucketSet:
    """Kademlia routing table using XOR metric.

    Partitions keyspace into buckets by bit position of highest differing
    bit from our own key. Buckets split automatically when overfull.

    All access to the bucket list is serialized through a re-entrant lock.

    Args:
        us: Our own key (bytes).
        max_per_bucket: Kademlia K value.
        b_value: Kademlia B value (1 for BitTorrent, 5 for paper).
        trimmer: Optional eviction strategy (default: RandomTrimmer).
    """

    def __init__(
        self,
        us: bytes,
        max_per_bucket: int = 20,
        b_value: int = 1,
        trimmer: Optional[KBucketTrimmer] = None,
    ) -> None:
        self._us = us
        self._key_bits = len(us) * 8
        self._max_per_bucket = max_per_bucket
        self._b_value = b_value
        self._b_factor = 1 << (b_value - 1)
        self._num_buckets = self._key_bits * self._b_factor
        self._trimmer = trimmer or RandomTrimmer(max_per_bucket)
        # Re-entrant: public methods call private helpers that lock again.
        self._lock = threading.RLock()

        # Start with one bucket covering the entire range
        initial = KBucket(0, self._num_buckets - 1, max_per_bucket, self._trimmer)
        self._buckets: List[KBucket] = [initial]

    def _get_range(self, key: bytes) -> int:
        """Get the range number for a key.

        Returns:
            A value in 0..num_buckets-1, or -1 if the key is our own key
            (XOR distance of zero).
        """
        if key == self._us:
            return -1
        distance = xor_bytes(self._us, key)
        hb = highest_bit(distance)
        if hb < 0:
            return -1
        # The highest differing bit position is used directly as the range,
        # clamped to the last range.
        # NOTE(review): b_factor is not applied here, so for b_value > 1 the
        # ranges at or above key_bits are never produced — confirm against
        # the Java implementation.
        return min(hb, self._num_buckets - 1)

    def _find_bucket_index(self, range_val: int) -> int:
        """Find the index of the bucket whose range contains ``range_val``.

        Falls back to index 0 when no bucket claims the value.
        """
        with self._lock:
            for i, b in enumerate(self._buckets):
                if b.range_begin <= range_val <= b.range_end:
                    return i
            return 0

    def _get_bucket(self, key: bytes) -> Optional[KBucket]:
        """Get the bucket for a key, or None if the key is our own."""
        r = self._get_range(key)
        if r < 0:
            return None
        # Resolve the index and fetch the bucket in one critical section so a
        # concurrent split cannot shift the list between the two steps.
        with self._lock:
            return self._buckets[self._find_bucket_index(r)]

    def add(self, peer: bytes) -> bool:
        """Add a peer to the routing table.

        Returns:
            The result of ``KBucket.add`` (True if the peer was new); False
            if ``peer`` is our own key.
        """
        if peer == self._us:
            return False
        r = self._get_range(peer)
        if r < 0:
            return False

        with self._lock:
            idx = self._find_bucket_index(r)
            bucket = self._buckets[idx]
            added = bucket.add(peer)

            # Try splitting if bucket is overfull and splittable. After each
            # split, re-resolve the bucket that now holds this peer's range;
            # it may itself still be overfull.
            while (
                bucket.key_count > self._max_per_bucket
                and bucket.range_begin != bucket.range_end
            ):
                self._split(idx)
                idx = self._find_bucket_index(r)
                bucket = self._buckets[idx]

            return added

    def _split(self, idx: int) -> None:
        """Split the bucket at ``idx`` into two halves of its range.

        Does nothing if the bucket covers a single range value. Callers are
        expected to hold the lock.
        """
        bucket = self._buckets[idx]
        begin = bucket.range_begin
        end = bucket.range_end
        if begin == end:
            return

        mid = (begin + end) // 2
        lower = KBucket(begin, mid, self._max_per_bucket, self._trimmer)
        upper = KBucket(mid + 1, end, self._max_per_bucket, self._trimmer)

        # Redistribute entries
        for entry in bucket.get_entries():
            if self._get_range(entry) <= mid:
                lower.add(entry)
            else:
                upper.add(entry)

        # Replace the old bucket in place, keeping the list ordered by range.
        self._buckets[idx:idx + 1] = [lower, upper]

    def remove(self, entry: bytes) -> bool:
        """Remove an entry. Returns True if existed."""
        # Hold the lock across lookup and removal: otherwise a concurrent
        # add() could split the bucket in between, and the removal would hit
        # a bucket that is no longer part of the table.
        with self._lock:
            bucket = self._get_bucket(entry)
            if bucket is None:
                return False
            return bucket.remove(entry)

    def size(self) -> int:
        """Total entries across all buckets."""
        with self._lock:
            return sum(b.key_count for b in self._buckets)

    def clear(self) -> None:
        """Remove all entries (the bucket layout itself is kept)."""
        with self._lock:
            for b in self._buckets:
                b.clear()

    def get_all(self, to_ignore: Optional[Set[bytes]] = None) -> Set[bytes]:
        """Get all entries, optionally excluding a set."""
        result: Set[bytes] = set()
        with self._lock:
            for b in self._buckets:
                result.update(b.get_entries())
        if to_ignore:
            result -= to_ignore
        return result

    def get_closest(
        self,
        key: bytes,
        max_count: int,
        to_ignore: Optional[Collection[bytes]] = None,
    ) -> List[bytes]:
        """Get the closest entries to a key by XOR distance.

        Returns up to max_count entries, sorted closest first. A
        non-positive ``max_count`` yields an empty list.
        """
        # Guard explicitly: a negative slice bound would otherwise mean
        # "all but the last k" rather than "nothing".
        if max_count <= 0:
            return []

        ignore_set = set(to_ignore) if to_ignore else set()
        candidates = self.get_all(ignore_set)

        comp = XORComparator(key)
        sorted_candidates = sorted(candidates, key=comp.key_func())
        return sorted_candidates[:max_count]

    def get_closest_to_us(
        self,
        max_count: int,
        to_ignore: Optional[Collection[bytes]] = None,
    ) -> List[bytes]:
        """Get entries closest to our own key."""
        return self.get_closest(self._us, max_count, to_ignore)

    def get_buckets(self) -> List[KBucket]:
        """Return a copy of the bucket list (for testing)."""
        with self._lock:
            return list(self._buckets)

    def get_explore_keys(self, max_age_ms: int) -> List[bytes]:
        """Generate random keys for stale or sparse buckets.

        Returns keys for buckets that are either:
        - Not updated in max_age_ms
        - Less than 75% full
        """
        result: List[bytes] = []
        with self._lock:
            # presumably bucket.last_changed is epoch milliseconds, matching
            # this clock — verify against KBucket.
            now = int(time.time() * 1000)

            for bucket in self._buckets:
                is_stale = (now - bucket.last_changed) > max_age_ms
                is_sparse = bucket.key_count < (self._max_per_bucket * 3 // 4)
                if is_stale or is_sparse:
                    # Generate a random key that falls in this bucket's range
                    rkey = self._generate_random_key(bucket)
                    if rkey is not None:
                        result.append(rkey)
        return result

    def _generate_random_key(self, bucket: KBucket) -> Optional[bytes]:
        """Generate a random key that falls within a bucket's range.

        The key is built in XOR-distance space: a random distance whose
        highest set bit is exactly the midpoint of the bucket's range, XORed
        with our own key.

        Returns:
            The generated key, or None if the target bit position lies
            outside the key.
        """
        key_len = len(self._us)

        # Bit 0 is taken to be the least significant bit of the last byte
        # (big-endian) — assumed to match highest_bit(); confirm there.
        target_bit = (bucket.range_begin + bucket.range_end) // 2
        byte_idx = key_len - 1 - (target_bit // 8)
        bit_idx = target_bit % 8

        if not 0 <= byte_idx < key_len:
            return None

        distance = bytearray(os.urandom(key_len))
        # Clear every byte above the target byte...
        distance[:byte_idx] = bytes(byte_idx)
        # ...and every bit above the target bit inside the target byte, so
        # the target bit really is the highest one set. Without this mask the
        # key could land in a higher range than the bucket covers.
        distance[byte_idx] &= (1 << bit_idx) - 1
        distance[byte_idx] |= 1 << bit_idx

        # XOR with our key to turn the distance into an actual key
        return bytes(d ^ u for d, u in zip(distance, self._us))