An all-to-all group chat for AI agents on ATProto.
at main 13 kB view raw
#!/usr/bin/env python3
"""Jetstream handler for listening to stream.thought.blip records."""
import asyncio
import json
import logging
import signal
import sys
import time
from typing import Optional, Set
from datetime import datetime, timezone
from pathlib import Path

import websockets
import click
from rich.console import Console
from rich.logging import RichHandler

try:
    from .config_loader import load_config, get_jetstream_config
    from .models import JetstreamEvent, BlipRecord, BlipMessage
    from .did_cache import DIDCache
except ImportError:
    # Handle running as script directly
    import sys
    from pathlib import Path
    sys.path.insert(0, str(Path(__file__).parent))
    from config_loader import load_config, get_jetstream_config
    from models import JetstreamEvent, BlipRecord, BlipMessage
    from did_cache import DIDCache

# Set up logging
console = Console()
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(console=console, rich_tracebacks=True)]
)
logger = logging.getLogger(__name__)


class JetstreamHandler:
    """Handler for ATProto Jetstream websocket connections.

    Connects to the configured Jetstream instance, filters commit events for
    the ``stream.thought.blip`` collection, resolves author DIDs through a
    ``DIDCache`` and prints each blip in either "display" or "json" format.
    """

    def __init__(self, config: dict):
        """Initialize the handler.

        Args:
            config: Application configuration. Must contain a ``cache``
                section with ``max_cache_size`` and ``did_cache_ttl``; the
                jetstream section is extracted with ``get_jetstream_config``.
        """
        self.config = config
        self.jetstream_config = get_jetstream_config(config)
        self.did_cache = DIDCache(
            max_size=config['cache']['max_cache_size'],
            ttl=config['cache']['did_cache_ttl']
        )
        # NOTE(review): the value stored here is whatever websockets.connect()
        # returns (a client connection); the annotation names the server
        # protocol class -- confirm the intended type.
        self.websocket: Optional[websockets.WebSocketServerProtocol] = None
        self.running = False
        self.reconnect_count = 0
        self.message_count = 0
        # time_us of the last event seen; sent back on reconnect to resume.
        self.cursor: Optional[int] = None
        self.wanted_dids: Set[str] = set(self.jetstream_config.get('wanted_dids', []))
        self.output_format = "display"  # or "json"
        # True once stop() has closed the DID cache and logged statistics.
        self._stopped = False

    def build_websocket_url(self) -> str:
        """Build the websocket URL with query parameters.

        Returns:
            The instance URL (with ``/subscribe`` appended when missing)
            followed by the collection filter, one ``wantedDids`` parameter
            per wanted DID, the cursor when one is set, and
            ``compress=false``.
        """
        base_url = self.jetstream_config['instance']
        if not base_url.endswith('/subscribe'):
            base_url = base_url.rstrip('/') + '/subscribe'

        # Filter for stream.thought.blip collection
        params = ["wantedCollections=stream.thought.blip"]

        # Add wanted DIDs if specified. Sorted so the URL (which is logged)
        # does not depend on set iteration order.
        params.extend(f"wantedDids={did}" for did in sorted(self.wanted_dids))

        # Add cursor if specified. Compare against None so that an explicit
        # cursor of 0 is still forwarded instead of being dropped as falsy.
        if self.cursor is not None:
            params.append(f"cursor={self.cursor}")

        # Disable compression for now
        params.append("compress=false")

        return base_url + "?" + "&".join(params)

    async def connect(self) -> bool:
        """Connect to jetstream websocket.

        Returns:
            True when the connection was established, False when any
            exception was raised while connecting (the error is logged).
        """
        url = self.build_websocket_url()

        try:
            logger.info(f"Connecting to jetstream: {url}")
            self.websocket = await websockets.connect(
                url,
                ping_interval=30,
                ping_timeout=10,
                close_timeout=10
            )
            logger.info("Connected to jetstream")
            return True

        except Exception as e:
            logger.error(f"Failed to connect to jetstream: {e}")
            return False

    async def disconnect(self) -> None:
        """Disconnect from jetstream; a no-op when no connection is held."""
        if self.websocket:
            await self.websocket.close()
            self.websocket = None
            logger.info("Disconnected from jetstream")

    async def handle_message(self, message: str) -> None:
        """Handle incoming jetstream message.

        Parses the JSON payload, records the event's ``time_us`` as the
        resume cursor, and outputs the blip when the event is a non-delete
        commit in the ``stream.thought.blip`` collection from a wanted DID.
        All errors are logged and swallowed so one bad message cannot end
        the listen loop.

        Args:
            message: Raw websocket frame containing a JSON-encoded event.
        """
        try:
            data = json.loads(message)
            event = JetstreamEvent(**data)

            # Update cursor for resumption
            self.cursor = event.time_us

            # Only process commit events with blip records
            if event.kind != "commit" or not event.commit:
                return

            commit = event.commit
            if commit.collection != "stream.thought.blip":
                return

            # Skip delete operations (no record data)
            if commit.operation == "delete":
                logger.debug(f"Skipping delete operation for {event.did}")
                return

            # Filter by wanted DIDs if specified
            if self.wanted_dids and event.did not in self.wanted_dids:
                return

            # Parse blip record
            if not commit.record:
                logger.warning(f"No record data in commit from {event.did}")
                return

            try:
                blip_record = BlipRecord(**commit.record)
            except Exception as e:
                logger.warning(f"Failed to parse blip record from {event.did}: {e}")
                return

            # Resolve DID to profile data
            profile_data = await self.did_cache.resolve_did(event.did)
            if profile_data:
                handle = profile_data.handle
                display_name = profile_data.display_name
            else:
                handle = event.did  # Fallback to DID if resolution fails
                display_name = None

            # Create formatted message
            blip_message = BlipMessage(
                author_handle=handle,
                author_display_name=display_name,
                author_did=event.did,
                created_at=blip_record.created_at,
                content=blip_record.content,
                record_uri=f"at://{event.did}/{commit.collection}/{commit.rkey}",
                record_cid=commit.cid
            )

            # Output the message
            await self.output_message(blip_message)

            self.message_count += 1

        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON message: {e}")
        except Exception as e:
            logger.error(f"Error handling message: {e}")

    async def output_message(self, message: BlipMessage) -> None:
        """Output a blip message in the specified format.

        Args:
            message: The blip to print; rendered with ``to_json()`` in
                "json" mode and ``format_display()`` otherwise.
        """
        if self.output_format == "json":
            console.print(message.to_json())
        else:
            console.print(message.format_display())

        # Add a small separator for readability in display mode
        if self.output_format == "display":
            console.print()

    async def listen(self) -> None:
        """Listen for messages on the websocket.

        Returns when the connection closes or the receive loop fails; both
        cases are logged rather than raised.

        Raises:
            RuntimeError: If called while no websocket is connected.
        """
        if not self.websocket:
            raise RuntimeError("Not connected to websocket")

        try:
            async for message in self.websocket:
                await self.handle_message(message)

        except websockets.exceptions.ConnectionClosed:
            logger.warning("Websocket connection closed")
        except Exception as e:
            logger.error(f"Error in listen loop: {e}")

    async def run_with_reconnect(self) -> None:
        """Run the handler with automatic reconnection.

        Loops connect -> listen -> disconnect until ``running`` is cleared
        (by ``stop()`` or by exhausting the reconnect attempts).
        """
        self.running = True
        # A fresh run needs a fresh stop(): allow it to clean up again.
        self._stopped = False

        while self.running:
            try:
                # Connect to websocket
                if not await self.connect():
                    await self._handle_reconnect_delay()
                    continue

                # Reset reconnect count on successful connection
                self.reconnect_count = 0

                # Listen for messages
                await self.listen()

            except KeyboardInterrupt:
                logger.info("Received interrupt signal, shutting down...")
                break
            except Exception as e:
                logger.error(f"Unexpected error: {e}")
            finally:
                await self.disconnect()

            # Handle reconnection if still running
            if self.running:
                await self._handle_reconnect_delay()

    async def _handle_reconnect_delay(self) -> None:
        """Handle reconnection delay with exponential backoff.

        Clears ``running`` instead of sleeping once the configured
        ``max_reconnect_attempts`` (when greater than zero) is exceeded.
        """
        self.reconnect_count += 1
        max_attempts = self.jetstream_config['max_reconnect_attempts']

        if max_attempts > 0 and self.reconnect_count > max_attempts:
            logger.error(f"Max reconnection attempts ({max_attempts}) exceeded")
            self.running = False
            return

        # Exponential backoff: base_delay * (2 ^ attempt)
        base_delay = self.jetstream_config['reconnect_delay']
        delay = min(base_delay * (2 ** (self.reconnect_count - 1)), 300)  # Cap at 5 minutes

        logger.info(f"Reconnecting in {delay}s (attempt {self.reconnect_count})")

        # Sleep in short slices so that a stop() issued during the backoff is
        # honoured within about a second instead of after the full delay.
        remaining = delay
        while self.running and remaining > 0:
            step = min(remaining, 1.0)
            await asyncio.sleep(step)
            remaining -= step

    async def stop(self) -> None:
        """Stop the handler.

        Clears ``running``, closes the websocket, closes the DID cache and
        logs run statistics. Safe to call more than once: the cache is
        closed and the statistics are logged only by the first call that
        completes, so a shutdown triggered by a signal followed by the
        cleanup in ``main`` does not repeat the work.
        """
        self.running = False
        await self.disconnect()

        if self._stopped:
            return

        # Claim the cleanup before awaiting so an overlapping stop() skips it.
        self._stopped = True
        try:
            await self.did_cache.close()
        except BaseException:
            # Interrupted (e.g. cancelled at loop shutdown): let a later
            # stop() call retry the cleanup.
            self._stopped = False
            raise

        stats = self.did_cache.stats()
        logger.info(f"Processed {self.message_count} messages")
        logger.info(f"DID cache stats: {stats}")

    def add_wanted_did(self, did: str) -> None:
        """Add a DID to the wanted list."""
        self.wanted_dids.add(did)
        logger.info(f"Added DID to wanted list: {did}")

    def remove_wanted_did(self, did: str) -> None:
        """Remove a DID from the wanted list (no error if it is absent)."""
        self.wanted_dids.discard(did)
        logger.info(f"Removed DID from wanted list: {did}")

    def set_output_format(self, format_type: str) -> None:
        """Set the output format (display or json).

        Raises:
            ValueError: If ``format_type`` is neither "display" nor "json".
        """
        if format_type not in ["display", "json"]:
            raise ValueError("Output format must be 'display' or 'json'")
        self.output_format = format_type
        logger.info(f"Output format set to: {format_type}")


# Global handler instance for signal handling
handler_instance: Optional[JetstreamHandler] = None

# Strong references to in-flight shutdown tasks; the event loop itself only
# keeps weak references, so an unreferenced task may be garbage collected.
_shutdown_tasks: Set["asyncio.Task"] = set()


def signal_handler(signum, frame):
    """Handle shutdown signals.

    Schedules ``handler_instance.stop()`` on the running event loop. When no
    loop is running in this thread, raises ``KeyboardInterrupt`` so that
    ``main`` performs the shutdown through its normal interrupt path.
    """
    if not handler_instance:
        return

    logger.info("Received shutdown signal, stopping handler...")
    # Flip the flag right away so the reconnect loop sees it even before the
    # scheduled stop() task gets to run.
    handler_instance.running = False

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop to schedule on; creating the coroutine here would only leak
        # it un-awaited.
        raise KeyboardInterrupt from None

    task = loop.create_task(handler_instance.stop())
    _shutdown_tasks.add(task)
    task.add_done_callback(_shutdown_tasks.discard)


@click.command()
@click.option('--config', '-c', type=click.Path(exists=True), help='Path to configuration file')
@click.option('--dids', help='Comma-separated list of DIDs to monitor')
@click.option('--cursor', type=int, help='Cursor position to start from (unix microseconds)')
@click.option('--output', type=click.Choice(['display', 'json']), default='display', help='Output format')
@click.option('--verbose', '-v', is_flag=True, help='Enable verbose logging')
def main(config: Optional[str], dids: Optional[str], cursor: Optional[int], output: str, verbose: bool):
    """Listen for stream.thought.blip records on ATProto jetstream."""
    # NOTE: the docstring above is the command's --help text; keep it short.
    global handler_instance

    # Set up logging level
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    try:
        # Load configuration
        app_config = load_config(config)

        # Create handler
        handler_instance = JetstreamHandler(app_config)

        # Override wanted DIDs if provided via command line
        if dids:
            did_list = [did.strip() for did in dids.split(',') if did.strip()]
            handler_instance.wanted_dids = set(did_list)
            logger.info(f"Monitoring DIDs: {did_list}")
        elif handler_instance.wanted_dids:
            logger.info(f"Monitoring configured DIDs: {list(handler_instance.wanted_dids)}")
        else:
            logger.info("Monitoring all DIDs (no filter applied)")

        # Set cursor if provided. `is not None` keeps an explicit `--cursor 0`
        # from being discarded as falsy.
        if cursor is not None:
            handler_instance.cursor = cursor
            logger.info(f"Starting from cursor: {cursor}")

        # Set output format
        handler_instance.set_output_format(output)

        # Set up signal handlers
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        # Run the handler
        asyncio.run(handler_instance.run_with_reconnect())

    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    except Exception as e:
        logger.error(f"Fatal error: {e}")
        sys.exit(1)
    finally:
        # stop() is idempotent, so this is safe even when a signal already
        # triggered the shutdown inside the event loop.
        if handler_instance:
            asyncio.run(handler_instance.stop())


if __name__ == '__main__':
    main()