"""Session key and tag management for the AES+SessionTag protocol.

Ported from net.i2p.crypto.SessionKeyManager and net.i2p.crypto.TagSetHandle.

SessionTags are single-use 32-byte tokens that identify which session key to
use for decrypting an incoming message. The SessionKeyManager maintains O(1)
tag-to-key lookup via a flat dictionary, and groups tags into TagSets that can
be expired collectively.
"""

from __future__ import annotations

import os
import time


def _now_ms() -> int:
    """Return the current wall-clock time (``time.time()``) in integer ms."""
    return int(time.time() * 1000)


class TagSet:
    """A group of SessionTags bound to a single session key.

    Tags within a TagSet share the same expiration time and session key.
    Each tag is single-use: once consumed, it is removed from the set.
    """

    DEFAULT_LIFETIME_MS = 720000  # 12 minutes

    __slots__ = ("_session_key", "_tags", "_creation_ms", "_expiration_ms")

    def __init__(
        self,
        session_key: bytes,
        tags: list[bytes],
        creation_ms: int,
        expiration_ms: int,
    ) -> None:
        self._session_key = session_key
        # A set gives O(1) membership/removal; duplicate tags collapse to one.
        self._tags: set[bytes] = set(tags)
        self._creation_ms = creation_ms
        self._expiration_ms = expiration_ms

    @property
    def session_key(self) -> bytes:
        """The session key that every tag in this set maps to."""
        return self._session_key

    @property
    def expiration_ms(self) -> int:
        """Expiration time in ms, on the same clock as ``now_ms`` arguments."""
        return self._expiration_ms

    def remaining(self) -> int:
        """Return the number of unconsumed tags."""
        return len(self._tags)

    def consume(self, tag: bytes) -> bool:
        """Remove *tag* if present. Return True if it was found."""
        try:
            self._tags.remove(tag)
        except KeyError:
            return False
        return True

    def pop_any(self) -> bytes | None:
        """Remove and return an arbitrary unconsumed tag, or None if empty."""
        return self._tags.pop() if self._tags else None

    def is_expired(self, now_ms: int) -> bool:
        """Return True if this TagSet has expired at time *now_ms*."""
        return now_ms >= self._expiration_ms

    def tags(self) -> frozenset[bytes]:
        """Return an immutable snapshot of remaining tags."""
        return frozenset(self._tags)


class SessionKeyManager:
    """Manages session keys and their associated tags.

    Provides O(1) tag lookup via ``_tag_to_key`` and tracks per-destination
    sessions via ``_dest_to_tagsets``. A TagSet created for a destination is
    registered in both ``_dest_to_tagsets`` and ``_key_to_tagsets`` (the same
    object), so consuming a tag through either path keeps counts consistent.
    """

    NUM_TAGS_PER_SESSION = 20

    def __init__(self) -> None:
        # O(1) tag -> session_key mapping
        self._tag_to_key: dict[bytes, bytes] = {}
        # destination_hash -> list of TagSets
        self._dest_to_tagsets: dict[bytes, list[TagSet]] = {}
        # session_key -> list of TagSets (for add_tags without a destination)
        self._key_to_tagsets: dict[bytes, list[TagSet]] = {}
        # destination_hash -> session_key (most recent)
        self._dest_to_key: dict[bytes, bytes] = {}
        # token -> (dest_hash, session_key, tags) for pending delivery ACKs
        self._pending_delivery: dict[int, tuple[bytes, bytes, list[bytes]]] = {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_session(
        self, destination_hash: bytes
    ) -> tuple[bytes, list[bytes]]:
        """Create a new session for *destination_hash*.

        Returns a tuple of (session_key, list_of_tags) where the session key
        is 32 random bytes and the tag list contains ``NUM_TAGS_PER_SESSION``
        random 32-byte tags. The new key becomes the destination's current key.
        """
        session_key = os.urandom(32)
        tags = [os.urandom(32) for _ in range(self.NUM_TAGS_PER_SESSION)]
        now_ms = _now_ms()
        tagset = self._register_tagset(
            session_key,
            tags,
            expiration_ms=now_ms + TagSet.DEFAULT_LIFETIME_MS,
            creation_ms=now_ms,
        )
        self._dest_to_tagsets.setdefault(destination_hash, []).append(tagset)
        self._dest_to_key[destination_hash] = session_key
        return session_key, tags

    def consume_tag(
        self, tag: bytes, now_ms: int | None = None
    ) -> bytes | None:
        """Look up and consume *tag*, returning its session key.

        The tag is single-use: it is removed from the index and from its
        owning TagSet whether or not the lookup succeeds.

        Args:
            tag: the 32-byte SessionTag to look up.
            now_ms: current time in ms; defaults to the wall clock.

        Returns:
            The session key, or ``None`` if the tag is unknown or its owning
            TagSet has expired (even if ``expire_old`` has not yet run).
        """
        session_key = self._tag_to_key.pop(tag, None)
        if session_key is None:
            return None
        if now_ms is None:
            now_ms = _now_ms()
        # Also remove from the TagSet so remaining() stays accurate, and
        # refuse tags whose owning TagSet has already passed its expiry.
        for ts in self._key_to_tagsets.get(session_key, ()):
            if ts.consume(tag):
                if ts.is_expired(now_ms):
                    return None
                break
        return session_key

    def add_tags(
        self,
        session_key: bytes,
        tags: list[bytes],
        expiration_ms: int,
    ) -> None:
        """Store tags received from a remote peer for future encryption."""
        self._register_tagset(session_key, tags, expiration_ms)

    def expire_old(self, now_ms: int) -> None:
        """Remove all expired TagSets and their tags from the index.

        Destinations left with no TagSets also lose their current session
        key, so ``has_session`` and ``get_session_key`` stay in agreement and
        ``_dest_to_key`` does not grow without bound.
        """
        self._prune_expired(self._dest_to_tagsets, now_ms)
        self._prune_expired(self._key_to_tagsets, now_ms)
        orphaned = [
            dest for dest in self._dest_to_key
            if dest not in self._dest_to_tagsets
        ]
        for dest in orphaned:
            del self._dest_to_key[dest]

    def has_session(self, destination_hash: bytes) -> bool:
        """Return True if at least one TagSet is tracked for the destination.

        Expired TagSets still count until :meth:`expire_old` prunes them.
        """
        return destination_hash in self._dest_to_tagsets

    def get_session_key(self, destination_hash: bytes) -> bytes | None:
        """Return the most recent session key for *destination_hash*, or None."""
        return self._dest_to_key.get(destination_hash)

    # ------------------------------------------------------------------
    # Extended API (Tier 0 — message routing support)
    # ------------------------------------------------------------------

    def get_current_or_new_key(
        self, destination_hash: bytes
    ) -> tuple[bytes, bool]:
        """Get existing session key or create a new session.

        Returns (session_key, is_new) where is_new is True if a fresh session
        was created.
        """
        existing = self._dest_to_key.get(destination_hash)
        if existing is not None:
            return existing, False
        key, _ = self.create_session(destination_hash)
        return key, True

    def consume_next_available_tag(
        self, destination_hash: bytes
    ) -> bytes | None:
        """Consume and return the next available tag for a destination.

        Expired TagSets are skipped. Returns None if no tags are available.
        """
        now_ms = _now_ms()
        for ts in self._dest_to_tagsets.get(destination_hash, ()):
            if ts.is_expired(now_ms):
                continue
            tag = ts.pop_any()
            if tag is not None:
                self._tag_to_key.pop(tag, None)
                return tag
        return None

    def should_send_tags(
        self, destination_hash: bytes, low_threshold: int = 20
    ) -> bool:
        """Return True if unexpired tags for the destination are below threshold."""
        now_ms = _now_ms()
        total = sum(
            ts.remaining()
            for ts in self._dest_to_tagsets.get(destination_hash, ())
            if not ts.is_expired(now_ms)
        )
        return total < low_threshold

    def tags_delivered(
        self,
        destination_hash: bytes,
        session_key: bytes,
        tags: list[bytes],
        token: int,
    ) -> None:
        """Register tags as pending delivery (awaiting ACK).

        The entry is kept until :meth:`tags_acked` or :meth:`fail_tags` is
        called with the same *token*; nothing here expires it.
        """
        self._pending_delivery[token] = (
            destination_hash,
            session_key,
            list(tags),
        )

    def tags_acked(self, token: int) -> None:
        """Confirm tag delivery — make tags available for consumption.

        Unknown tokens are ignored. The acknowledged tags get a fresh
        ``DEFAULT_LIFETIME_MS`` lifetime and are tracked under both the
        session key and the destination.
        """
        pending = self._pending_delivery.pop(token, None)
        if pending is None:
            return
        dest_hash, session_key, tags = pending
        tagset = self._register_tagset(
            session_key, tags, _now_ms() + TagSet.DEFAULT_LIFETIME_MS
        )
        self._dest_to_tagsets.setdefault(dest_hash, []).append(tagset)
        # A destination that now has TagSets must also have a known key;
        # an existing (more recent) key is left untouched.
        self._dest_to_key.setdefault(dest_hash, session_key)

    def fail_tags(self, token: int) -> None:
        """Remove failed tags from pending — they are not usable."""
        self._pending_delivery.pop(token, None)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _register_tagset(
        self,
        session_key: bytes,
        tags: list[bytes],
        expiration_ms: int,
        creation_ms: int | None = None,
    ) -> TagSet:
        """Create a TagSet, index it under *session_key*, and return it."""
        tagset = TagSet(
            session_key=session_key,
            tags=tags,
            creation_ms=_now_ms() if creation_ms is None else creation_ms,
            expiration_ms=expiration_ms,
        )
        self._key_to_tagsets.setdefault(session_key, []).append(tagset)
        for tag in tags:
            self._tag_to_key[tag] = session_key
        return tagset

    def _prune_expired(
        self, index: dict[bytes, list[TagSet]], now_ms: int
    ) -> None:
        """Drop expired TagSets from *index*, deleting owners left empty."""
        for owner in list(index):
            surviving = []
            for ts in index[owner]:
                if ts.is_expired(now_ms):
                    self._remove_tagset_entries(ts)
                else:
                    surviving.append(ts)
            if surviving:
                index[owner] = surviving
            else:
                del index[owner]

    def _remove_tagset_entries(self, tagset: TagSet) -> None:
        """Remove all remaining tags of *tagset* from the O(1) index.

        A tag is only removed while it still maps to this TagSet's key, so a
        tag later re-registered under another session key stays usable.
        """
        key = tagset.session_key
        for tag in tagset.tags():
            if self._tag_to_key.get(tag) == key:
                del self._tag_to_key[tag]