"""Banlist — temporary and permanent peer banning.

Ported from net.i2p.router.Banlist.

Tracks peers that should be avoided for routing, either globally or for
specific transports. Supports time-based expiration.
"""

from __future__ import annotations

import time
from dataclasses import dataclass


@dataclass
class BanlistEntry:
    """A single ban entry."""

    peer_hash: bytes
    # Absolute expiry in wall-clock epoch milliseconds; 0 = permanent.
    expiration_ms: int
    # Transport the ban applies to; None = all transports (global ban).
    transport: str | None
    # Free-form reason supplied by the caller.
    reason: str
    # Optional free-form cause text supplied by the caller (None if not given).
    cause: str | None = None


class Banlist:
    """Peer ban tracking with expiration and transport-specific bans.

    Entries are keyed by ``(peer_hash, transport)``, where ``transport`` is
    ``None`` for a global ban. Re-banning a peer never shortens a ban that is
    still active: a new entry only replaces an existing active one when it
    lasts at least as long (a permanent ban outlasts any timed ban).
    """

    DEFAULT_DURATION_MS = 3_600_000  # 1 hour

    def __init__(self) -> None:
        # Key: (peer_hash, transport_or_None) -> BanlistEntry
        self._entries: dict[tuple[bytes, str | None], BanlistEntry] = {}

    # -- internal helpers -------------------------------------------------

    @staticmethod
    def _now_ms() -> int:
        """Return the current wall-clock time in epoch milliseconds."""
        return int(time.time() * 1000)

    @staticmethod
    def _outlasts(old: BanlistEntry, new: BanlistEntry) -> bool:
        """Return True if *old* expires strictly later than *new*.

        An expiration of 0 means "never expires", so a permanent ban outlasts
        any timed ban, and two permanent bans are treated as equal (the new
        one may then replace the old one, e.g. to update its reason).
        """
        if old.expiration_ms == 0:
            return new.expiration_ms != 0
        if new.expiration_ms == 0:
            return False
        return old.expiration_ms > new.expiration_ms

    def _store(self, entry: BanlistEntry) -> None:
        """Insert *entry* unless an active ban for the same key outlasts it.

        Without this check a routine timed ban would silently downgrade an
        existing permanent (or longer) ban for the same peer/transport.
        """
        key = (entry.peer_hash, entry.transport)
        old = self._entries.get(key)
        if old is not None and self._is_active(old) and self._outlasts(old, entry):
            return  # keep the longer-lasting existing ban unchanged
        self._entries[key] = entry

    # -- adding / removing bans -------------------------------------------

    def banlist_router(
        self,
        peer_hash: bytes,
        reason: str,
        duration_ms: int = DEFAULT_DURATION_MS,
        cause: str | None = None,
    ) -> None:
        """Ban a peer globally (all transports) for ``duration_ms`` from now.

        Args:
            peer_hash: Key identifying the peer.
            reason: Free-form reason stored on the entry.
            duration_ms: Ban length in milliseconds, measured from now.
            cause: Optional free-form cause stored on the entry.

        If the peer already has an active global ban that lasts longer
        (including a permanent one), that ban is kept unchanged.
        """
        self._store(
            BanlistEntry(
                peer_hash=peer_hash,
                expiration_ms=self._now_ms() + duration_ms,
                transport=None,
                reason=reason,
                cause=cause,
            )
        )

    def banlist_router_forever(
        self,
        peer_hash: bytes,
        reason: str,
        cause: str | None = None,
    ) -> None:
        """Permanently ban a peer on all transports.

        Args:
            peer_hash: Key identifying the peer.
            reason: Free-form reason stored on the entry.
            cause: Optional free-form cause stored on the entry.
        """
        self._store(
            BanlistEntry(
                peer_hash=peer_hash,
                expiration_ms=0,  # 0 = never expires
                transport=None,
                reason=reason,
                cause=cause,
            )
        )

    def banlist_router_transport(
        self,
        peer_hash: bytes,
        transport: str,
        reason: str,
        duration_ms: int = DEFAULT_DURATION_MS,
        cause: str | None = None,
    ) -> None:
        """Ban a peer for a single transport only.

        Args:
            peer_hash: Key identifying the peer.
            transport: Transport name the ban applies to.
            reason: Free-form reason stored on the entry.
            duration_ms: Ban length in milliseconds, measured from now.
            cause: Optional free-form cause stored on the entry.

        An existing active ban for the same peer/transport that lasts longer
        is kept unchanged.
        """
        self._store(
            BanlistEntry(
                peer_hash=peer_hash,
                expiration_ms=self._now_ms() + duration_ms,
                transport=transport,
                reason=reason,
                cause=cause,
            )
        )

    def unbanlist_router(self, peer_hash: bytes) -> None:
        """Remove all bans (global and per-transport) for a peer."""
        # Collect first: deleting while iterating a dict raises RuntimeError.
        keys_to_remove = [k for k in self._entries if k[0] == peer_hash]
        for k in keys_to_remove:
            del self._entries[k]

    # -- queries ----------------------------------------------------------

    def _is_active(self, entry: BanlistEntry) -> bool:
        """Return True if *entry* is permanent or has not yet expired."""
        if entry.expiration_ms == 0:
            return True  # permanent
        return self._now_ms() < entry.expiration_ms

    def is_banlisted(self, peer_hash: bytes) -> bool:
        """Check if a peer is globally banned (transport=None).

        An expired global entry found here is removed as a side effect.
        """
        key = (peer_hash, None)
        entry = self._entries.get(key)
        if entry is None:
            return False
        if not self._is_active(entry):
            del self._entries[key]
            return False
        return True

    def is_banlisted_transport(self, peer_hash: bytes, transport: str) -> bool:
        """Check if a peer is banned for a specific transport.

        Returns True if the peer is globally banned OR banned specifically
        for ``transport``. Expired entries found here are removed.
        """
        # A global ban covers every transport.
        if self.is_banlisted(peer_hash):
            return True
        key = (peer_hash, transport)
        entry = self._entries.get(key)
        if entry is None:
            return False
        if not self._is_active(entry):
            del self._entries[key]
            return False
        return True

    def get_entry(self, peer_hash: bytes) -> BanlistEntry | None:
        """Return the active global ban entry for a peer, or None.

        Unlike ``is_banlisted``, this does not purge an expired entry.
        """
        entry = self._entries.get((peer_hash, None))
        if entry is not None and self._is_active(entry):
            return entry
        return None

    def get_all_entries(self) -> list[BanlistEntry]:
        """Return all active ban entries (global and per-transport)."""
        return [e for e in self._entries.values() if self._is_active(e)]

    def cleanup_expired(self) -> int:
        """Remove expired entries and return how many were removed."""
        expired = [k for k, e in self._entries.items() if not self._is_active(e)]
        for k in expired:
            del self._entries[k]
        return len(expired)

    @property
    def count(self) -> int:
        """Number of active ban entries (global and per-transport)."""
        return sum(1 for e in self._entries.values() if self._is_active(e))