"""KBucketSet — main Kademlia routing table.

Ported from net.i2p.kademlia.KBucketSet.
"""

import heapq
import os
import threading
import time
from typing import Collection, List, Optional, Set

from i2p_kademlia.kbucket import KBucket, KBucketTrimmer
from i2p_kademlia.trimmers import RandomTrimmer
from i2p_kademlia.xor_comparator import XORComparator, xor_bytes, highest_bit


class KBucketSet:
    """Kademlia routing table using the XOR metric.

    The keyspace is partitioned into buckets by the bit position of the
    highest bit in which a key differs from our own key. A bucket holding
    more than ``max_per_bucket`` entries that covers more than one range is
    split automatically.

    Every method that reads or modifies the bucket list does so while
    holding a single re-entrant lock, so lookups, splits and removals on
    this table do not interleave.

    Args:
        us: Our own key (bytes); must be non-empty.
        max_per_bucket: Kademlia K value.
        b_value: Kademlia B value (1 for BitTorrent, 5 for paper); must be >= 1.
        trimmer: Optional eviction strategy (default: RandomTrimmer).

    Raises:
        ValueError: If ``us`` is empty or ``b_value`` is less than 1.
    """

    def __init__(
        self,
        us: bytes,
        max_per_bucket: int = 20,
        b_value: int = 1,
        trimmer: Optional[KBucketTrimmer] = None,
    ) -> None:
        if not us:
            # A zero-length key gives zero buckets and a bucket range of 0..-1.
            raise ValueError("us must be a non-empty key")
        if b_value < 1:
            # 1 << (b_value - 1) would otherwise fail with "negative shift count".
            raise ValueError(f"b_value must be >= 1, got {b_value}")
        self._us = us
        self._key_bits = len(us) * 8
        self._max_per_bucket = max_per_bucket
        self._b_value = b_value
        self._b_factor = 1 << (b_value - 1)
        self._num_buckets = self._key_bits * self._b_factor
        self._trimmer = trimmer or RandomTrimmer(max_per_bucket)
        self._lock = threading.RLock()
        # Start with one bucket covering the entire range. Splits only ever
        # replace a bucket [b..e] by [b..mid] + [mid+1..e], so the list always
        # tiles 0 .. num_buckets - 1 contiguously and in ascending order.
        initial = KBucket(0, self._num_buckets - 1, max_per_bucket, self._trimmer)
        self._buckets: List[KBucket] = [initial]

    def _get_range(self, key: bytes) -> int:
        """Return the range number for ``key``, or -1 if ``key`` is us.

        The range is the position of the highest bit in which ``key`` differs
        from our key, as reported by ``highest_bit`` (assumed to count from
        0 = least-significant bit of the last byte, the same convention
        ``_generate_random_key`` uses), clamped to ``num_buckets - 1``.
        """
        if key == self._us:
            return -1
        distance = xor_bytes(self._us, key)
        hb = highest_bit(distance)
        if hb < 0:
            return -1
        # NOTE(review): for b_value > 1 the range is still just the highest
        # differing bit, so only ranges 0 .. key_bits - 1 are ever produced and
        # the buckets above that stay empty. A B>1 mapping that subdivides each
        # bit position has not been ported; confirm this is intended.
        return min(hb, self._num_buckets - 1)

    def _find_bucket_index(self, range_val: int) -> int:
        """Return the index of the bucket whose range contains ``range_val``.

        Because the buckets tile the whole range, a match exists for every
        value ``_get_range`` can return; index 0 is only a defensive fallback.
        """
        with self._lock:
            for i, b in enumerate(self._buckets):
                if b.range_begin <= range_val <= b.range_end:
                    return i
            return 0

    def _get_bucket(self, key: bytes) -> Optional[KBucket]:
        """Return the bucket responsible for ``key``, or None if ``key`` is us."""
        r = self._get_range(key)
        if r < 0:
            return None
        # Find and index under ONE lock acquisition: a split between the two
        # steps would otherwise leave the index pointing at the wrong bucket.
        with self._lock:
            return self._buckets[self._find_bucket_index(r)]

    def add(self, peer: bytes) -> bool:
        """Add a peer to the routing table.

        Our own key is never added. Afterwards, the peer's bucket is split
        repeatedly while it holds more than ``max_per_bucket`` entries and
        still covers more than one range.

        Returns:
            The result of the bucket's ``add`` (presumably True when the peer
            was newly inserted), or False for our own key.
        """
        if peer == self._us:
            return False
        r = self._get_range(peer)
        if r < 0:
            return False
        with self._lock:
            idx = self._find_bucket_index(r)
            bucket = self._buckets[idx]
            added = bucket.add(peer)
            # Try splitting if bucket is overfull and splittable. Each split
            # strictly narrows the bucket containing r, so this terminates.
            while (
                bucket.key_count > self._max_per_bucket
                and bucket.range_begin != bucket.range_end
            ):
                self._split(idx)
                idx = self._find_bucket_index(r)
                bucket = self._buckets[idx]
            return added

    def _split(self, idx: int) -> None:
        """Split the bucket at ``idx`` into two halves of its range.

        Entries are redistributed by their range number. A single-range
        bucket is left untouched.
        """
        with self._lock:  # re-entrant: safe when called from add()
            bucket = self._buckets[idx]
            begin = bucket.range_begin
            end = bucket.range_end
            if begin == end:
                return
            mid = (begin + end) // 2
            lower = KBucket(begin, mid, self._max_per_bucket, self._trimmer)
            upper = KBucket(mid + 1, end, self._max_per_bucket, self._trimmer)
            # Redistribute entries
            for entry in bucket.get_entries():
                if self._get_range(entry) <= mid:
                    lower.add(entry)
                else:
                    upper.add(entry)
            self._buckets[idx:idx + 1] = [lower, upper]

    def remove(self, entry: bytes) -> bool:
        """Remove an entry. Returns the bucket's ``remove`` result (False for us)."""
        # Hold the lock across lookup and removal so a concurrent split cannot
        # copy the entry into a new bucket after we removed it from the old one.
        with self._lock:
            bucket = self._get_bucket(entry)
            if bucket is None:
                return False
            return bucket.remove(entry)

    def size(self) -> int:
        """Total entries across all buckets."""
        with self._lock:
            return sum(b.key_count for b in self._buckets)

    def clear(self) -> None:
        """Remove all entries (the current bucket layout is kept)."""
        with self._lock:
            for b in self._buckets:
                b.clear()

    def get_all(self, to_ignore: Optional[Collection[bytes]] = None) -> Set[bytes]:
        """Return all entries as a new set, excluding any in ``to_ignore``.

        ``to_ignore`` may be any iterable collection (set, list, tuple, ...).
        """
        result: Set[bytes] = set()
        with self._lock:
            for b in self._buckets:
                result.update(b.get_entries())
        if to_ignore:
            # difference_update accepts any iterable, unlike `set -= other`.
            result.difference_update(to_ignore)
        return result

    def get_closest(
        self,
        key: bytes,
        max_count: int,
        to_ignore: Optional[Collection[bytes]] = None,
    ) -> List[bytes]:
        """Return up to ``max_count`` entries closest to ``key`` by XOR distance.

        Results are sorted closest first; entries in ``to_ignore`` are skipped.
        A non-positive ``max_count`` yields an empty list.
        """
        if max_count <= 0:
            # Guard explicitly: a negative slice bound would drop entries from
            # the far end instead of returning nothing.
            return []
        candidates = self.get_all(to_ignore)
        comp = XORComparator(key)
        # Documented as equivalent to sorted(...)[:max_count], without sorting
        # every entry in the table.
        return heapq.nsmallest(max_count, candidates, key=comp.key_func())

    def get_closest_to_us(
        self,
        max_count: int,
        to_ignore: Optional[Collection[bytes]] = None,
    ) -> List[bytes]:
        """Get entries closest to our own key."""
        return self.get_closest(self._us, max_count, to_ignore)

    def get_buckets(self) -> List[KBucket]:
        """Return a copy of the bucket list (for testing)."""
        with self._lock:
            return list(self._buckets)

    def get_explore_keys(self, max_age_ms: int) -> List[bytes]:
        """Generate random keys for stale or sparse buckets.

        A key is produced for each bucket that is either:
          - not changed within the last ``max_age_ms`` (``last_changed`` is
            assumed to be epoch milliseconds), or
          - less than 75% full.

        Buckets for which no in-range key can be generated are skipped.
        """
        result = []
        with self._lock:
            now = int(time.time() * 1000)
            sparse_limit = self._max_per_bucket * 3 // 4
            for bucket in self._buckets:
                is_stale = (now - bucket.last_changed) > max_age_ms
                is_sparse = bucket.key_count < sparse_limit
                if is_stale or is_sparse:
                    rkey = self._generate_random_key(bucket)
                    if rkey is not None:
                        result.append(rkey)
        return result

    def _generate_random_key(self, bucket: KBucket) -> Optional[bytes]:
        """Generate a random key whose range falls inside ``bucket``'s range.

        The target bit position is the midpoint of the bucket's range, clamped
        to the positions ``_get_range`` can produce (0 .. key_bits - 1). The
        XOR distance from us is built with exactly that bit as its highest set
        bit and random bits below it. Keys are treated as big-endian with
        bit 0 the least-significant bit of the last byte.

        Returns:
            The key, or None if the bucket covers no reachable bit position.
        """
        key_len = len(self._us)
        lo = bucket.range_begin
        hi = min(bucket.range_end, self._key_bits - 1)
        if lo < 0 or lo > hi:
            return None
        target_bit = (lo + hi) // 2
        # Random bits strictly below target_bit, then force target_bit itself
        # so no higher bit can be set (the old byte-wise version left bits
        # above target_bit within the same byte random).
        low_bits = int.from_bytes(os.urandom(key_len), "big") & ((1 << target_bit) - 1)
        distance = (1 << target_bit) | low_bits
        us_int = int.from_bytes(self._us, "big")
        return (us_int ^ distance).to_bytes(key_len, "big")