# An all-to-all group chat for AI agents on ATProto.
1"""DID resolution cache with LRU eviction and TTL support."""
2import json
3import time
4import asyncio
5import aiohttp
6import logging
7from typing import Optional, Dict, Any
8from pathlib import Path
9from threading import Lock
10from cachetools import TTLCache
11try:
12 from .models import CacheEntry, DIDDocument, ProfileData
13except ImportError:
14 # Handle running as script directly
15 import sys
16 from pathlib import Path
17 sys.path.insert(0, str(Path(__file__).parent))
18 from models import CacheEntry, DIDDocument, ProfileData
19
20logger = logging.getLogger(__name__)
21
22
class DIDCache:
    """Thread-safe LRU cache with TTL for DID resolution.

    Values stored in the cache are whatever the caller passes to ``set`` --
    in the resolve flow these are ``ProfileData`` objects produced by
    ``_resolve_did_api`` (the ``handle``/``str`` naming in some signatures
    is historical and kept for interface compatibility).
    """

    def __init__(self, max_size: int = 1000, ttl: int = 3600, cache_file: Optional[str] = None):
        """
        Initialize DID cache.

        Args:
            max_size: Maximum number of entries to cache
            ttl: Time-to-live for cache entries in seconds
            cache_file: Path to persistent cache file
        """
        self.cache = TTLCache(maxsize=max_size, ttl=ttl)
        # threading.Lock is NOT reentrant: no method may call another
        # lock-taking method while holding it (see set()/_save_cache()).
        self.lock = Lock()
        self.cache_file = cache_file or str(Path(__file__).parent.parent / "cache" / "did_cache.json")
        self.session: Optional[aiohttp.ClientSession] = None

        # Load persistent cache on startup
        self._load_cache()

    def _load_cache(self) -> None:
        """Load cache from disk if it exists, skipping expired entries."""
        try:
            cache_path = Path(self.cache_file)
            if not cache_path.exists():
                return
            with open(cache_path, 'r') as f:
                data = json.load(f)

            loaded_count = 0
            for did, entry_data in data.items():
                try:
                    entry = CacheEntry(**entry_data)
                    if not entry.is_expired:
                        # NOTE(review): entry.value round-trips through JSON, so
                        # unless CacheEntry re-parses it into ProfileData this
                        # restores a plain dict -- _save_cache tolerates both.
                        with self.lock:
                            self.cache[did] = entry.value
                        loaded_count += 1
                except Exception as e:
                    logger.warning("Failed to load cache entry for %s: %s", did, e)

            logger.info("Loaded %d DID cache entries from disk", loaded_count)
        except Exception as e:
            logger.warning("Failed to load cache from disk: %s", e)

    def _save_cache(self) -> None:
        """Save current cache contents to disk (best effort).

        Must be called WITHOUT self.lock held -- it acquires the lock itself.
        """
        try:
            cache_path = Path(self.cache_file)
            cache_path.parent.mkdir(parents=True, exist_ok=True)

            current_time = time.time()

            # Snapshot entries under the lock; do serialization and file I/O
            # outside it so writers aren't blocked on disk.
            with self.lock:
                snapshot = list(self.cache.items())
                # TTLCache doesn't expose per-entry remaining TTL, so entries
                # are persisted with a fresh timestamp and the full TTL.
                entry_ttl = self.cache.ttl

            data = {}
            for did, value in snapshot:
                data[did] = {
                    # Restored entries may already be plain dicts (see _load_cache).
                    "value": value.dict() if hasattr(value, "dict") else value,
                    "timestamp": current_time,
                    "ttl": entry_ttl,
                }

            with open(cache_path, 'w') as f:
                json.dump(data, f, indent=2)

            logger.debug("Saved %d DID cache entries to disk", len(data))
        except Exception as e:
            logger.error("Failed to save cache to disk: %s", e)

    def get(self, did: str) -> Optional[Any]:
        """
        Get the cached value for a DID.

        Args:
            did: The DID to look up

        Returns:
            The cached value (a ProfileData in the resolve flow) if present
            and not expired, None otherwise.
        """
        with self.lock:
            return self.cache.get(did)

    def set(self, did: str, handle: Any) -> None:
        """
        Store a resolved value for a DID.

        Args:
            did: The DID
            handle: The resolved value (historically a handle string;
                resolve_did actually stores a ProfileData here)
        """
        with self.lock:
            self.cache[did] = handle
            # Decide under the lock but save AFTER releasing it:
            # _save_cache re-acquires self.lock and threading.Lock is
            # non-reentrant, so calling it while held would deadlock.
            should_save = len(self.cache) % 10 == 0

        # Save to disk periodically (every 10th entry)
        if should_save:
            self._save_cache()

    async def resolve_did(self, did: str, force_refresh: bool = False) -> Optional["ProfileData"]:
        """
        Resolve a DID to profile data, using cache when possible.

        Args:
            did: The DID to resolve
            force_refresh: If True, bypass cache and fetch fresh data

        Returns:
            ProfileData if resolution succeeds, None otherwise
        """
        # Check cache first unless force refresh
        if not force_refresh:
            cached = self.get(did)
            if cached:
                logger.debug("Cache hit for DID %s -> %s", did, cached)
                return cached

        # Resolve via API
        try:
            profile = await self._resolve_did_api(did)
            if profile:
                self.set(did, profile)
                logger.debug("Resolved DID %s -> %s", did, profile)
                return profile
        except Exception as e:
            logger.warning("Failed to resolve DID %s: %s", did, e)

        return None

    async def _resolve_did_api(self, did: str) -> Optional["ProfileData"]:
        """
        Resolve a DID via the public Bluesky AppView profile endpoint.

        Args:
            did: The DID to resolve

        Returns:
            ProfileData (handle plus optional display name) if the profile
            has a handle, None otherwise.
        """
        # Lazily create a shared session with a 10s total timeout.
        if not self.session:
            self.session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=10),
                headers={
                    'User-Agent': 'Mozilla/5.0 (compatible; thought.stream/1.0)',
                    'Accept': 'application/json'
                }
            )

        try:
            url = "https://public.api.bsky.app/xrpc/app.bsky.actor.getProfile"
            params = {"actor": did}

            async with self.session.get(url, params=params) as response:
                if response.status != 200:
                    logger.warning("Profile fetch failed with status %s for %s", response.status, did)
                    return None

                data = await response.json()

                # Extract handle and display name from profile
                handle = data.get('handle')
                if not handle:
                    logger.warning("No handle found in profile for %s", did)
                    return None

                return ProfileData(
                    handle=handle,
                    display_name=data.get('displayName')
                )

        except asyncio.TimeoutError:
            logger.warning("Timeout resolving DID %s", did)
            return None
        except Exception as e:
            logger.warning("Error resolving DID %s: %s", did, e)
            return None

    async def resolve_batch(self, dids: list[str]) -> Dict[str, Optional["ProfileData"]]:
        """
        Resolve multiple DIDs concurrently.

        Args:
            dids: List of DIDs to resolve

        Returns:
            Dict mapping DID to ProfileData (or None if resolution failed)
        """
        tasks = [self.resolve_did(did) for did in dids]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        resolved = {}
        for did, result in zip(dids, results):
            if isinstance(result, Exception):
                logger.warning("Exception resolving %s: %s", did, result)
                resolved[did] = None
            else:
                resolved[did] = result

        return resolved

    async def close(self) -> None:
        """Close the HTTP session and flush the cache to disk."""
        if self.session:
            await self.session.close()
            self.session = None

        # Save cache to disk
        self._save_cache()

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        hits/misses default to 0: TTLCache does not track them, so the
        getattr only reports real counters if a subclass adds them.
        """
        with self.lock:
            return {
                "size": len(self.cache),
                "max_size": self.cache.maxsize,
                "ttl": self.cache.ttl,
                "hits": getattr(self.cache, 'hits', 0),
                "misses": getattr(self.cache, 'misses', 0)
            }

    def clear(self) -> None:
        """Clear all cache entries."""
        with self.lock:
            self.cache.clear()
        logger.info("DID cache cleared")