# A Python port of the Invisible Internet Project (I2P).
"""Session key and tag management for the AES+SessionTag protocol.

Ported from net.i2p.crypto.SessionKeyManager and net.i2p.crypto.TagSetHandle.

SessionTags are single-use 32-byte tokens that identify which session key
to use for decrypting an incoming message. The SessionKeyManager maintains
O(1) tag-to-key lookup via a flat dictionary and groups tags into TagSets
that can be expired collectively.
"""
10
11from __future__ import annotations
12
13import os
14import time
15
16
class TagSet:
    """Bundle of single-use SessionTags that all unlock one session key.

    Every tag in the bundle shares the bundle's session key and expiration
    time. Consuming a tag removes it from the bundle for good.
    """

    DEFAULT_LIFETIME_MS = 720000  # 12 minutes (12 * 60 * 1000 ms)

    __slots__ = ("_session_key", "_tags", "_creation_ms", "_expiration_ms")

    def __init__(
        self,
        session_key: bytes,
        tags: list[bytes],
        creation_ms: int,
        expiration_ms: int,
    ) -> None:
        # Tags are kept in a set, so duplicate entries in *tags* collapse.
        self._tags: set[bytes] = {*tags}
        self._session_key = session_key
        self._creation_ms = creation_ms
        self._expiration_ms = expiration_ms

    @property
    def session_key(self) -> bytes:
        """The session key every tag in this bundle maps to."""
        return self._session_key

    @property
    def expiration_ms(self) -> int:
        """Absolute expiry timestamp, in milliseconds."""
        return self._expiration_ms

    def remaining(self) -> int:
        """Count the tags that have not been consumed yet."""
        unconsumed = self._tags
        return len(unconsumed)

    def consume(self, tag: bytes) -> bool:
        """Take *tag* out of the bundle.

        Returns True when the tag was present (and is now removed), False
        when it was never here or has already been consumed.
        """
        if tag not in self._tags:
            return False
        self._tags.discard(tag)
        return True

    def is_expired(self, now_ms: int) -> bool:
        """Tell whether the bundle is past its expiry at time *now_ms*.

        The expiry instant itself already counts as expired.
        """
        return self._expiration_ms <= now_ms

    def tags(self) -> frozenset[bytes]:
        """Return a frozen copy of the unconsumed tags."""
        snapshot = frozenset(self._tags)
        return snapshot
67
68
69class SessionKeyManager:
70 """Manages session keys and their associated tags.
71
72 Provides O(1) tag lookup via ``_tag_to_key`` and tracks per-destination
73 sessions via ``_dest_to_tagsets``.
74 """
75
76 NUM_TAGS_PER_SESSION = 20
77
78 def __init__(self) -> None:
79 # O(1) tag -> session_key mapping
80 self._tag_to_key: dict[bytes, bytes] = {}
81 # destination_hash -> list of TagSets
82 self._dest_to_tagsets: dict[bytes, list[TagSet]] = {}
83 # session_key -> list of TagSets (for add_tags without a destination)
84 self._key_to_tagsets: dict[bytes, list[TagSet]] = {}
85 # destination_hash -> session_key (most recent)
86 self._dest_to_key: dict[bytes, bytes] = {}
87 # token -> (dest_hash, session_key, tags) for pending delivery ACKs
88 self._pending_delivery: dict[int, tuple[bytes, bytes, list[bytes]]] = {}
89
90 # ------------------------------------------------------------------
91 # Public API
92 # ------------------------------------------------------------------
93
94 def create_session(
95 self, destination_hash: bytes
96 ) -> tuple[bytes, list[bytes]]:
97 """Create a new session for *destination_hash*.
98
99 Returns a tuple of (session_key, list_of_tags) where the session
100 key is 32 random bytes and the tag list contains
101 ``NUM_TAGS_PER_SESSION`` random 32-byte tags.
102 """
103 session_key = os.urandom(32)
104 tags = [os.urandom(32) for _ in range(self.NUM_TAGS_PER_SESSION)]
105
106 now_ms = int(time.time() * 1000)
107 tagset = TagSet(
108 session_key=session_key,
109 tags=list(tags),
110 creation_ms=now_ms,
111 expiration_ms=now_ms + TagSet.DEFAULT_LIFETIME_MS,
112 )
113
114 # Register in destination map
115 self._dest_to_tagsets.setdefault(destination_hash, []).append(tagset)
116 self._dest_to_key[destination_hash] = session_key
117
118 # Register in key map
119 self._key_to_tagsets.setdefault(session_key, []).append(tagset)
120
121 # Register each tag for O(1) lookup
122 for tag in tags:
123 self._tag_to_key[tag] = session_key
124
125 return session_key, tags
126
127 def consume_tag(self, tag: bytes) -> bytes | None:
128 """Look up the session key for *tag* in O(1).
129
130 If the tag exists and the owning TagSet is not expired, the tag is
131 consumed (removed) and the session key is returned. Otherwise
132 returns ``None``.
133 """
134 session_key = self._tag_to_key.pop(tag, None)
135 if session_key is None:
136 return None
137
138 # Also remove from the TagSet so remaining() stays accurate
139 for tagset_list in (
140 self._key_to_tagsets.get(session_key, []),
141 ):
142 for ts in tagset_list:
143 if ts.consume(tag):
144 break
145
146 return session_key
147
148 def add_tags(
149 self,
150 session_key: bytes,
151 tags: list[bytes],
152 expiration_ms: int,
153 ) -> None:
154 """Store tags received from a remote peer for future encryption."""
155 now_ms = int(time.time() * 1000)
156 tagset = TagSet(
157 session_key=session_key,
158 tags=list(tags),
159 creation_ms=now_ms,
160 expiration_ms=expiration_ms,
161 )
162 self._key_to_tagsets.setdefault(session_key, []).append(tagset)
163
164 for tag in tags:
165 self._tag_to_key[tag] = session_key
166
167 def expire_old(self, now_ms: int) -> None:
168 """Remove all expired TagSets and their tags from the index."""
169 # Expire in dest_to_tagsets
170 for dest in list(self._dest_to_tagsets):
171 surviving = []
172 for ts in self._dest_to_tagsets[dest]:
173 if ts.is_expired(now_ms):
174 self._remove_tagset_entries(ts)
175 else:
176 surviving.append(ts)
177 if surviving:
178 self._dest_to_tagsets[dest] = surviving
179 else:
180 del self._dest_to_tagsets[dest]
181
182 # Expire in key_to_tagsets
183 for key in list(self._key_to_tagsets):
184 surviving = []
185 for ts in self._key_to_tagsets[key]:
186 if ts.is_expired(now_ms):
187 self._remove_tagset_entries(ts)
188 else:
189 surviving.append(ts)
190 if surviving:
191 self._key_to_tagsets[key] = surviving
192 else:
193 del self._key_to_tagsets[key]
194
195 def has_session(self, destination_hash: bytes) -> bool:
196 """Return True if there is at least one session for *destination_hash*."""
197 return destination_hash in self._dest_to_tagsets
198
199 def get_session_key(self, destination_hash: bytes) -> bytes | None:
200 """Return the most recent session key for *destination_hash*, or None."""
201 return self._dest_to_key.get(destination_hash)
202
203 # ------------------------------------------------------------------
204 # Extended API (Tier 0 — message routing support)
205 # ------------------------------------------------------------------
206
207 def get_current_or_new_key(
208 self, destination_hash: bytes
209 ) -> tuple[bytes, bool]:
210 """Get existing session key or create a new session.
211
212 Returns (session_key, is_new) where is_new is True if a
213 fresh session was created.
214 """
215 existing = self._dest_to_key.get(destination_hash)
216 if existing is not None:
217 return existing, False
218 key, _ = self.create_session(destination_hash)
219 return key, True
220
221 def consume_next_available_tag(
222 self, destination_hash: bytes
223 ) -> bytes | None:
224 """Consume and return the next available tag for a destination.
225
226 Returns None if no tags are available.
227 """
228 tagsets = self._dest_to_tagsets.get(destination_hash, [])
229 now_ms = int(time.time() * 1000)
230 for ts in tagsets:
231 if ts.is_expired(now_ms):
232 continue
233 remaining = ts.tags()
234 if remaining:
235 tag = next(iter(remaining))
236 ts.consume(tag)
237 self._tag_to_key.pop(tag, None)
238 return tag
239 return None
240
241 def should_send_tags(
242 self, destination_hash: bytes, low_threshold: int = 20
243 ) -> bool:
244 """Return True if available tags are below the threshold."""
245 tagsets = self._dest_to_tagsets.get(destination_hash, [])
246 now_ms = int(time.time() * 1000)
247 total = sum(
248 ts.remaining() for ts in tagsets if not ts.is_expired(now_ms)
249 )
250 return total < low_threshold
251
252 def tags_delivered(
253 self,
254 destination_hash: bytes,
255 session_key: bytes,
256 tags: list[bytes],
257 token: int,
258 ) -> None:
259 """Register tags as pending delivery (awaiting ACK)."""
260 self._pending_delivery[token] = (
261 destination_hash,
262 session_key,
263 list(tags),
264 )
265
266 def tags_acked(self, token: int) -> None:
267 """Confirm tag delivery — make tags available for consumption."""
268 pending = self._pending_delivery.pop(token, None)
269 if pending is None:
270 return
271 dest_hash, session_key, tags = pending
272 now_ms = int(time.time() * 1000)
273 self.add_tags(session_key, tags, now_ms + TagSet.DEFAULT_LIFETIME_MS)
274 # Also track under destination
275 tagsets = self._key_to_tagsets.get(session_key, [])
276 if tagsets:
277 dest_list = self._dest_to_tagsets.setdefault(dest_hash, [])
278 latest = tagsets[-1]
279 if latest not in dest_list:
280 dest_list.append(latest)
281
282 def fail_tags(self, token: int) -> None:
283 """Remove failed tags from pending — they are not usable."""
284 self._pending_delivery.pop(token, None)
285
286 # ------------------------------------------------------------------
287 # Internal helpers
288 # ------------------------------------------------------------------
289
290 def _remove_tagset_entries(self, tagset: TagSet) -> None:
291 """Remove all remaining tags of *tagset* from the O(1) index."""
292 for tag in tagset.tags():
293 self._tag_to_key.pop(tag, None)