# A Python port of the Invisible Internet Project (I2P)
# (repository viewer header: at main, 293 lines, 10 kB)
"""Session key and tag management for AES+SessionTag protocol.

Ported from net.i2p.crypto.SessionKeyManager and net.i2p.crypto.TagSetHandle.

SessionTags are single-use 32-byte tokens that identify which session key
to use for decrypting an incoming message. The SessionKeyManager maintains
O(1) tag-to-key lookup via a flat dictionary, and groups tags into TagSets
that can be expired collectively.
"""

from __future__ import annotations

import os
import time


def _now_ms() -> int:
    """Return the current wall-clock time in whole milliseconds."""
    return int(time.time() * 1000)


class TagSet:
    """A group of SessionTags bound to a single session key.

    Tags within a TagSet share the same expiration time and session key.
    Each tag is single-use: once consumed, it is removed from the set.
    """

    DEFAULT_LIFETIME_MS = 720000  # 12 minutes

    __slots__ = ("_session_key", "_tags", "_creation_ms", "_expiration_ms")

    def __init__(
        self,
        session_key: bytes,
        tags: list[bytes],
        creation_ms: int,
        expiration_ms: int,
    ) -> None:
        """Bind *tags* to *session_key* with the given lifetime bounds.

        *tags* is copied into an internal set, so duplicates collapse and
        later mutation of the caller's list has no effect on this TagSet.
        """
        self._session_key = session_key
        self._tags: set[bytes] = set(tags)
        self._creation_ms = creation_ms
        self._expiration_ms = expiration_ms

    @property
    def session_key(self) -> bytes:
        """The session key every tag in this set maps to."""
        return self._session_key

    @property
    def expiration_ms(self) -> int:
        """Absolute expiry time in epoch milliseconds."""
        return self._expiration_ms

    def remaining(self) -> int:
        """Return the number of unconsumed tags."""
        return len(self._tags)

    def consume(self, tag: bytes) -> bool:
        """Remove *tag* if present. Return True if it was found."""
        try:
            self._tags.remove(tag)
        except KeyError:
            return False
        return True

    def pop_any(self) -> bytes | None:
        """Remove and return an arbitrary remaining tag, or None if empty."""
        if not self._tags:
            return None
        return self._tags.pop()

    def is_expired(self, now_ms: int) -> bool:
        """Return True if this TagSet has expired.

        The expiry instant itself counts as expired (``now_ms >= expiry``).
        """
        return now_ms >= self._expiration_ms

    def tags(self) -> frozenset[bytes]:
        """Return an immutable snapshot of remaining tags."""
        return frozenset(self._tags)


class SessionKeyManager:
    """Manages session keys and their associated tags.

    Provides O(1) tag lookup via ``_tag_to_key`` and tracks per-destination
    sessions via ``_dest_to_tagsets``.
    """

    NUM_TAGS_PER_SESSION = 20

    def __init__(self) -> None:
        # O(1) tag -> session_key mapping
        self._tag_to_key: dict[bytes, bytes] = {}
        # destination_hash -> list of TagSets
        self._dest_to_tagsets: dict[bytes, list[TagSet]] = {}
        # session_key -> list of TagSets (for add_tags without a destination)
        self._key_to_tagsets: dict[bytes, list[TagSet]] = {}
        # destination_hash -> session_key (most recent)
        self._dest_to_key: dict[bytes, bytes] = {}
        # token -> (dest_hash, session_key, tags) for pending delivery ACKs
        self._pending_delivery: dict[int, tuple[bytes, bytes, list[bytes]]] = {}

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_session(
        self, destination_hash: bytes
    ) -> tuple[bytes, list[bytes]]:
        """Create a new session for *destination_hash*.

        Returns a tuple of (session_key, list_of_tags) where the session
        key is 32 random bytes and the tag list contains
        ``NUM_TAGS_PER_SESSION`` random 32-byte tags. The new session key
        becomes the destination's most recent key, and its tags expire
        ``TagSet.DEFAULT_LIFETIME_MS`` from now.
        """
        session_key = os.urandom(32)
        tags = [os.urandom(32) for _ in range(self.NUM_TAGS_PER_SESSION)]

        now_ms = _now_ms()
        tagset = self._register_tagset(
            session_key,
            tags,
            creation_ms=now_ms,
            expiration_ms=now_ms + TagSet.DEFAULT_LIFETIME_MS,
        )

        # Register in destination map
        self._dest_to_tagsets.setdefault(destination_hash, []).append(tagset)
        self._dest_to_key[destination_hash] = session_key

        return session_key, tags

    def consume_tag(self, tag: bytes) -> bytes | None:
        """Look up the session key for *tag* in O(1).

        If the tag exists and the owning TagSet is not expired, the tag is
        consumed (removed) and the session key is returned. Otherwise
        returns ``None``. A tag found only in expired TagSets is still
        removed, since it can never become valid again.
        """
        session_key = self._tag_to_key.pop(tag, None)
        if session_key is None:
            return None

        # Also remove from the TagSet so remaining() stays accurate, and
        # honour the TagSet's expiry rather than trusting the index alone.
        now_ms = _now_ms()
        owned = False
        for ts in self._key_to_tagsets.get(session_key, ()):
            if ts.consume(tag):
                owned = True
                if not ts.is_expired(now_ms):
                    return session_key

        # Owned only by expired TagSets -> reject. If no TagSet claims the
        # tag at all, fall back to the index entry.
        return None if owned else session_key

    def add_tags(
        self,
        session_key: bytes,
        tags: list[bytes],
        expiration_ms: int,
    ) -> None:
        """Store tags received from a remote peer for future encryption.

        The tags are grouped into a new TagSet under *session_key* that
        expires at *expiration_ms* (absolute epoch milliseconds) and are
        added to the tag index. No destination is associated with them.
        """
        self._register_tagset(
            session_key,
            tags,
            creation_ms=_now_ms(),
            expiration_ms=expiration_ms,
        )

    def expire_old(self, now_ms: int) -> None:
        """Remove all expired TagSets and their tags from the index.

        A destination left with no live TagSets also loses its entry in
        the most-recent-key map, so ``has_session`` and
        ``get_session_key`` agree after expiry.
        """
        for dest in self._prune_expired(self._dest_to_tagsets, now_ms):
            self._dest_to_key.pop(dest, None)
        self._prune_expired(self._key_to_tagsets, now_ms)

    def has_session(self, destination_hash: bytes) -> bool:
        """Return True if there is at least one session for *destination_hash*."""
        return destination_hash in self._dest_to_tagsets

    def get_session_key(self, destination_hash: bytes) -> bytes | None:
        """Return the most recent session key for *destination_hash*, or None."""
        return self._dest_to_key.get(destination_hash)

    # ------------------------------------------------------------------
    # Extended API (Tier 0 — message routing support)
    # ------------------------------------------------------------------

    def get_current_or_new_key(
        self, destination_hash: bytes
    ) -> tuple[bytes, bool]:
        """Get existing session key or create a new session.

        Returns (session_key, is_new) where is_new is True if a
        fresh session was created.
        """
        existing = self._dest_to_key.get(destination_hash)
        if existing is not None:
            return existing, False
        key, _ = self.create_session(destination_hash)
        return key, True

    def consume_next_available_tag(
        self, destination_hash: bytes
    ) -> bytes | None:
        """Consume and return the next available tag for a destination.

        TagSets are tried in registration order and expired ones are
        skipped. Which tag is taken from a TagSet is arbitrary.
        Returns None if no tags are available.
        """
        now_ms = _now_ms()
        for ts in self._dest_to_tagsets.get(destination_hash, ()):
            if ts.is_expired(now_ms):
                continue
            tag = ts.pop_any()
            if tag is not None:
                # Keep the flat index in step with the TagSet.
                self._tag_to_key.pop(tag, None)
                return tag
        return None

    def should_send_tags(
        self, destination_hash: bytes, low_threshold: int = 20
    ) -> bool:
        """Return True if available tags are below the threshold.

        Only tags in non-expired TagSets of *destination_hash* are counted.
        """
        tagsets = self._dest_to_tagsets.get(destination_hash, ())
        now_ms = _now_ms()
        total = sum(
            ts.remaining() for ts in tagsets if not ts.is_expired(now_ms)
        )
        return total < low_threshold

    def tags_delivered(
        self,
        destination_hash: bytes,
        session_key: bytes,
        tags: list[bytes],
        token: int,
    ) -> None:
        """Register tags as pending delivery (awaiting ACK).

        The tags are stored under *token* and stay unusable until
        ``tags_acked`` is called with the same token. Re-using a token
        replaces the earlier pending entry.
        """
        self._pending_delivery[token] = (
            destination_hash,
            session_key,
            list(tags),
        )

    def tags_acked(self, token: int) -> None:
        """Confirm tag delivery — make tags available for consumption.

        Unknown tokens are ignored. The acknowledged tags get a fresh
        ``TagSet.DEFAULT_LIFETIME_MS`` lifetime starting now and are
        tracked under both their session key and their destination.
        """
        pending = self._pending_delivery.pop(token, None)
        if pending is None:
            return
        dest_hash, session_key, tags = pending
        now_ms = _now_ms()
        tagset = self._register_tagset(
            session_key,
            tags,
            creation_ms=now_ms,
            expiration_ms=now_ms + TagSet.DEFAULT_LIFETIME_MS,
        )
        # Also track under destination
        self._dest_to_tagsets.setdefault(dest_hash, []).append(tagset)

    def fail_tags(self, token: int) -> None:
        """Remove failed tags from pending — they are not usable."""
        self._pending_delivery.pop(token, None)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _register_tagset(
        self,
        session_key: bytes,
        tags: list[bytes],
        creation_ms: int,
        expiration_ms: int,
    ) -> TagSet:
        """Create a TagSet, index it by key and by tag, and return it."""
        tags = tuple(tags)  # iterate twice safely, even for one-shot iterables
        tagset = TagSet(
            session_key=session_key,
            tags=tags,
            creation_ms=creation_ms,
            expiration_ms=expiration_ms,
        )
        self._key_to_tagsets.setdefault(session_key, []).append(tagset)

        # Register each tag for O(1) lookup
        for tag in tags:
            self._tag_to_key[tag] = session_key
        return tagset

    def _prune_expired(
        self, index: dict[bytes, list[TagSet]], now_ms: int
    ) -> list[bytes]:
        """Drop expired TagSets from *index*; return the keys left empty.

        Keys whose TagSet list becomes empty are deleted from *index*.
        """
        emptied: list[bytes] = []
        for owner in list(index):
            surviving = []
            for ts in index[owner]:
                if ts.is_expired(now_ms):
                    self._remove_tagset_entries(ts)
                else:
                    surviving.append(ts)
            if surviving:
                index[owner] = surviving
            else:
                del index[owner]
                emptied.append(owner)
        return emptied

    def _remove_tagset_entries(self, tagset: TagSet) -> None:
        """Remove all remaining tags of *tagset* from the O(1) index.

        An index entry is left alone when the tag has since been
        re-registered under a different session key.
        """
        key = tagset.session_key
        for tag in tagset.tags():
            if self._tag_to_key.get(tag) == key:
                del self._tag_to_key[tag]