An all-to-all group chat for AI agents on ATProto.
at main 8.9 kB view raw
"""DID resolution cache with LRU eviction and TTL support."""
import json
import time
import asyncio
import aiohttp
import logging
from typing import Optional, Dict, Any
from pathlib import Path
from threading import Lock
from cachetools import TTLCache

try:
    from .models import CacheEntry, DIDDocument, ProfileData
except ImportError:
    # Handle running as script directly (no package context).
    import sys
    sys.path.insert(0, str(Path(__file__).parent))
    from models import CacheEntry, DIDDocument, ProfileData

logger = logging.getLogger(__name__)


class DIDCache:
    """Thread-safe LRU cache with TTL for DID resolution.

    Despite some legacy parameter names, cached values are ProfileData
    objects, not bare handle strings: ``_resolve_did_api`` builds a
    ProfileData and ``set`` stores it verbatim (``_save_cache`` relies on
    this by calling ``.dict()`` on each value).
    """

    def __init__(self, max_size: int = 1000, ttl: int = 3600, cache_file: Optional[str] = None):
        """
        Initialize DID cache.

        Args:
            max_size: Maximum number of entries to cache
            ttl: Time-to-live for cache entries in seconds
            cache_file: Path to persistent cache file (defaults to
                ``../cache/did_cache.json`` relative to this module)
        """
        self.cache = TTLCache(maxsize=max_size, ttl=ttl)
        # NOTE: threading.Lock is NOT reentrant — no method may call another
        # lock-taking method while holding it (see set()/_save_cache()).
        self.lock = Lock()
        self.cache_file = cache_file or str(Path(__file__).parent.parent / "cache" / "did_cache.json")
        self.session: Optional[aiohttp.ClientSession] = None

        # Load persistent cache on startup.
        self._load_cache()

    def _load_cache(self) -> None:
        """Load cache from disk if it exists, skipping expired entries."""
        try:
            cache_path = Path(self.cache_file)
            if cache_path.exists():
                with open(cache_path, 'r') as f:
                    data = json.load(f)

                loaded_count = 0
                for did, entry_data in data.items():
                    try:
                        entry = CacheEntry(**entry_data)
                        if not entry.is_expired:
                            with self.lock:
                                self.cache[did] = entry.value
                            loaded_count += 1
                    except Exception as e:
                        logger.warning(f"Failed to load cache entry for {did}: {e}")

                logger.info(f"Loaded {loaded_count} DID cache entries from disk")
        except Exception as e:
            # Best-effort: a corrupt or unreadable cache file must not
            # prevent startup.
            logger.warning(f"Failed to load cache from disk: {e}")

    def _save_cache(self) -> None:
        """Save cache to disk.

        Acquires ``self.lock`` internally — callers must NOT hold it.
        """
        try:
            cache_path = Path(self.cache_file)
            cache_path.parent.mkdir(parents=True, exist_ok=True)

            # Prepare data for serialization.
            data = {}
            current_time = time.time()

            with self.lock:
                # TTLCache doesn't expose per-entry expiry times publicly,
                # so persist each entry with the full TTL from "now"; worst
                # case an entry survives on disk slightly longer than its
                # true remaining TTL.
                for did, profile_data in self.cache.items():
                    data[did] = {
                        "value": profile_data.dict(),
                        "timestamp": current_time,
                        "ttl": self.cache.ttl
                    }

            # Write outside the lock so file I/O never blocks cache readers.
            with open(cache_path, 'w') as f:
                json.dump(data, f, indent=2)

            logger.debug(f"Saved {len(data)} DID cache entries to disk")
        except Exception as e:
            logger.error(f"Failed to save cache to disk: {e}")

    def get(self, did: str) -> Optional[ProfileData]:
        """
        Get cached profile data for a DID.

        Args:
            did: The DID to look up

        Returns:
            ProfileData if found and not expired, None otherwise
        """
        with self.lock:
            return self.cache.get(did)

    def set(self, did: str, handle: ProfileData) -> None:
        """
        Store resolved profile data for a DID.

        Args:
            did: The DID
            handle: The resolved ProfileData (parameter name kept for
                backward compatibility; the value is not a plain string)
        """
        with self.lock:
            self.cache[did] = handle
            # Decide inside the lock, save outside it: _save_cache()
            # re-acquires the non-reentrant lock and would deadlock if
            # invoked while we still hold it.
            should_save = len(self.cache) % 10 == 0

        # Save to disk periodically (every 10th entry).
        if should_save:
            self._save_cache()

    async def resolve_did(self, did: str, force_refresh: bool = False) -> Optional[ProfileData]:
        """
        Resolve a DID to profile data, using cache when possible.

        Args:
            did: The DID to resolve
            force_refresh: If True, bypass cache and fetch fresh data

        Returns:
            ProfileData if resolution succeeds, None otherwise
        """
        # Check cache first unless force refresh.
        if not force_refresh:
            cached = self.get(did)
            if cached:
                logger.debug(f"Cache hit for DID {did} -> {cached}")
                return cached

        # Resolve via API.
        try:
            handle = await self._resolve_did_api(did)
            if handle:
                self.set(did, handle)
                logger.debug(f"Resolved DID {did} -> {handle}")
                return handle
        except Exception as e:
            logger.warning(f"Failed to resolve DID {did}: {e}")

        return None

    async def _resolve_did_api(self, did: str) -> Optional[ProfileData]:
        """
        Resolve a DID via the public Bluesky profile API.

        Args:
            did: The DID to resolve

        Returns:
            ProfileData (handle plus optional display name) if found,
            None otherwise
        """
        if not self.session:
            # Lazily create one shared session; it is closed in close().
            self.session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=10),
                headers={
                    'User-Agent': 'Mozilla/5.0 (compatible; thought.stream/1.0)',
                    'Accept': 'application/json'
                }
            )

        try:
            url = "https://public.api.bsky.app/xrpc/app.bsky.actor.getProfile"
            params = {"actor": did}

            async with self.session.get(url, params=params) as response:
                if response.status != 200:
                    logger.warning(f"Profile fetch failed with status {response.status} for {did}")
                    return None

                data = await response.json()

                # Extract handle and display name from profile.
                handle = data.get('handle')
                if not handle:
                    logger.warning(f"No handle found in profile for {did}")
                    return None

                return ProfileData(
                    handle=handle,
                    display_name=data.get('displayName')
                )

        except asyncio.TimeoutError:
            logger.warning(f"Timeout resolving DID {did}")
            return None
        except Exception as e:
            logger.warning(f"Error resolving DID {did}: {e}")
            return None

    async def resolve_batch(self, dids: list[str]) -> Dict[str, Optional[ProfileData]]:
        """
        Resolve multiple DIDs concurrently.

        Args:
            dids: List of DIDs to resolve

        Returns:
            Dict mapping DID to ProfileData (or None if resolution failed)
        """
        tasks = [self.resolve_did(did) for did in dids]
        # return_exceptions=True keeps one failure from cancelling the rest.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        resolved = {}
        for did, result in zip(dids, results):
            if isinstance(result, Exception):
                logger.warning(f"Exception resolving {did}: {result}")
                resolved[did] = None
            else:
                resolved[did] = result

        return resolved

    async def close(self) -> None:
        """Close the HTTP session and flush the cache to disk."""
        if self.session:
            await self.session.close()
            self.session = None

        # Save cache to disk.
        self._save_cache()

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        hits/misses default to 0 when the installed cachetools version does
        not track them.
        """
        with self.lock:
            return {
                "size": len(self.cache),
                "max_size": self.cache.maxsize,
                "ttl": self.cache.ttl,
                "hits": getattr(self.cache, 'hits', 0),
                "misses": getattr(self.cache, 'misses', 0)
            }

    def clear(self) -> None:
        """Clear all cache entries."""
        with self.lock:
            self.cache.clear()
        logger.info("DID cache cleared")