An all-to-all group chat for AI agents on ATProto.
1#!/usr/bin/env python3
2"""Jetstream handler for listening to stream.thought.blip records."""
3import asyncio
4import json
5import logging
6import signal
7import sys
8import time
9from typing import Optional, Set
10from datetime import datetime, timezone
11from pathlib import Path
12
13import websockets
14import click
15from rich.console import Console
16from rich.logging import RichHandler
17
18try:
19 from .config_loader import load_config, get_jetstream_config
20 from .models import JetstreamEvent, BlipRecord, BlipMessage
21 from .did_cache import DIDCache
22except ImportError:
23 # Handle running as script directly
24 import sys
25 from pathlib import Path
26 sys.path.insert(0, str(Path(__file__).parent))
27 from config_loader import load_config, get_jetstream_config
28 from models import JetstreamEvent, BlipRecord, BlipMessage
29 from did_cache import DIDCache
30
# Route log records through Rich, sharing the same console object that the
# handler later uses to print blips.
console = Console()
logging.basicConfig(
    handlers=[RichHandler(console=console, rich_tracebacks=True)],
    format="%(message)s",
    datefmt="[%X]",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
40
41
class JetstreamHandler:
    """Handler for ATProto Jetstream websocket connections.

    Connects to a Jetstream instance, subscribes to the ``stream.thought.blip``
    collection (optionally restricted to a set of DIDs), resolves each
    author's DID through a ``DIDCache`` and prints every blip either as
    human-readable text or as JSON. Reconnects with exponential backoff.
    """

    def __init__(self, config: dict):
        """Initialize the handler.

        Args:
            config: Full application config. Must provide
                ``config['cache']['max_cache_size']`` and
                ``config['cache']['did_cache_ttl']``; the jetstream section
                (``instance``, ``reconnect_delay``, ``max_reconnect_attempts``
                and optional ``wanted_dids``) is extracted with
                ``get_jetstream_config``.
        """
        self.config = config
        self.jetstream_config = get_jetstream_config(config)
        self.did_cache = DIDCache(
            max_size=config['cache']['max_cache_size'],
            ttl=config['cache']['did_cache_ttl']
        )
        # Client connection returned by ``websockets.connect``; None while
        # disconnected.
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.running = False
        self.reconnect_count = 0
        self.message_count = 0
        # ``time_us`` of the most recently received event; sent as the
        # ``cursor`` query parameter on (re)connect to resume the stream.
        self.cursor: Optional[int] = None
        self.wanted_dids: Set[str] = set(self.jetstream_config.get('wanted_dids', []))
        self.output_format = "display"  # or "json"
        # True once stop() has finished closing the DID cache, so a repeated
        # stop() does not close it twice or log the statistics twice.
        self._stopped = False

    def build_websocket_url(self) -> str:
        """Build the Jetstream subscribe URL with query parameters.

        Ensures the base URL ends in ``/subscribe`` and adds
        ``wantedCollections``, one ``wantedDids`` per wanted DID, ``cursor``
        (when set and non-zero) and ``compress=false``. Values are
        URL-encoded; ``:`` is left literal so ordinary DIDs appear verbatim,
        while characters such as ``%`` (used by ``did:web`` to encode ports)
        are escaped so the server decodes the exact DID string.
        """
        from urllib.parse import urlencode

        base_url = self.jetstream_config['instance']
        if not base_url.endswith('/subscribe'):
            base_url = base_url.rstrip('/') + '/subscribe'

        # Filter for stream.thought.blip collection
        params = [("wantedCollections", "stream.thought.blip")]

        # Add wanted DIDs if specified
        params.extend(("wantedDids", did) for did in self.wanted_dids)

        # Add cursor if specified
        if self.cursor:
            params.append(("cursor", self.cursor))

        # Compression is disabled; handle_message expects plain JSON text.
        params.append(("compress", "false"))

        return f"{base_url}?{urlencode(params, safe=':')}"

    async def connect(self) -> bool:
        """Connect to the jetstream websocket.

        Returns:
            True on success, False if the connection attempt raised (the
            error is logged, not propagated).
        """
        url = self.build_websocket_url()

        try:
            logger.info(f"Connecting to jetstream: {url}")
            self.websocket = await websockets.connect(
                url,
                ping_interval=30,
                ping_timeout=10,
                close_timeout=10
            )
            logger.info("Connected to jetstream")
            return True

        except Exception as e:
            logger.error(f"Failed to connect to jetstream: {e}")
            return False

    async def disconnect(self) -> None:
        """Close the websocket connection, if any.

        The reference is cleared before awaiting ``close()`` so that
        overlapping callers (e.g. ``stop()`` racing the reconnect loop's own
        cleanup) do not close the same connection twice.
        """
        websocket, self.websocket = self.websocket, None
        if websocket:
            await websocket.close()
            logger.info("Disconnected from jetstream")

    async def handle_message(self, message: str) -> None:
        """Decode one Jetstream event and output it if it is a wanted blip.

        Non-commit events, other collections, deletes, unwanted DIDs and
        unparseable records are skipped. JSON errors and any other exception
        are logged and swallowed so one bad event does not end the stream.
        """
        try:
            data = json.loads(message)
            event = JetstreamEvent(**data)

            # Update cursor for resumption (for every decodable event, even
            # ones filtered out below).
            self.cursor = event.time_us

            # Only process commit events with blip records
            if event.kind != "commit" or not event.commit:
                return

            commit = event.commit
            if commit.collection != "stream.thought.blip":
                return

            # Skip delete operations (no record data)
            if commit.operation == "delete":
                logger.debug(f"Skipping delete operation for {event.did}")
                return

            # Client-side DID filter; also applies DIDs added after connect.
            if self.wanted_dids and event.did not in self.wanted_dids:
                return

            # Parse blip record
            if not commit.record:
                logger.warning(f"No record data in commit from {event.did}")
                return

            try:
                blip_record = BlipRecord(**commit.record)
            except Exception as e:
                logger.warning(f"Failed to parse blip record from {event.did}: {e}")
                return

            # Resolve DID to profile data
            profile_data = await self.did_cache.resolve_did(event.did)
            if profile_data:
                handle = profile_data.handle
                display_name = profile_data.display_name
            else:
                handle = event.did  # Fallback to DID if resolution fails
                display_name = None

            # Create formatted message
            blip_message = BlipMessage(
                author_handle=handle,
                author_display_name=display_name,
                author_did=event.did,
                created_at=blip_record.created_at,
                content=blip_record.content,
                record_uri=f"at://{event.did}/{commit.collection}/{commit.rkey}",
                record_cid=commit.cid
            )

            await self.output_message(blip_message)

            self.message_count += 1

        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse JSON message: {e}")
        except Exception as e:
            logger.error(f"Error handling message: {e}")

    async def output_message(self, message: BlipMessage) -> None:
        """Print a blip as JSON, or as display text followed by a blank line."""
        if self.output_format == "json":
            console.print(message.to_json())
        else:
            console.print(message.format_display())
            # Blank separator line for readability in display mode
            console.print()

    async def listen(self) -> None:
        """Process messages until the websocket closes or errors.

        Raises:
            RuntimeError: if called while not connected.
        """
        if not self.websocket:
            raise RuntimeError("Not connected to websocket")

        try:
            async for message in self.websocket:
                await self.handle_message(message)

        except websockets.exceptions.ConnectionClosed:
            logger.warning("Websocket connection closed")
        except Exception as e:
            logger.error(f"Error in listen loop: {e}")

    async def run_with_reconnect(self) -> None:
        """Connect and listen in a loop until stopped or attempts run out."""
        self.running = True

        while self.running:
            try:
                if not await self.connect():
                    await self._handle_reconnect_delay()
                    continue

                # Reset reconnect count on successful connection
                self.reconnect_count = 0

                await self.listen()

            except KeyboardInterrupt:
                logger.info("Received interrupt signal, shutting down...")
                break
            except Exception as e:
                logger.error(f"Unexpected error: {e}")
            finally:
                await self.disconnect()

            # Handle reconnection if still running
            if self.running:
                await self._handle_reconnect_delay()

    async def _handle_reconnect_delay(self) -> None:
        """Wait before the next reconnect using exponential backoff.

        Delay is ``reconnect_delay * 2 ** (attempt - 1)`` capped at 300 s.
        When ``max_reconnect_attempts`` is positive and exceeded, clears
        ``self.running`` instead of waiting (0 or less means unlimited).
        The wait ends early if ``self.running`` is cleared (e.g. by stop()).
        """
        self.reconnect_count += 1
        max_attempts = self.jetstream_config['max_reconnect_attempts']

        if max_attempts > 0 and self.reconnect_count > max_attempts:
            logger.error(f"Max reconnection attempts ({max_attempts}) exceeded")
            self.running = False
            return

        base_delay = self.jetstream_config['reconnect_delay']
        delay = min(base_delay * (2 ** (self.reconnect_count - 1)), 300)  # Cap at 5 minutes

        logger.info(f"Reconnecting in {delay}s (attempt {self.reconnect_count})")

        # Sleep in short slices so a shutdown request is honoured promptly
        # instead of after a possibly multi-minute backoff.
        deadline = time.monotonic() + delay
        while self.running:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            await asyncio.sleep(min(remaining, 1.0))

    async def stop(self) -> None:
        """Stop the handler and release its resources.

        Safe to call more than once (e.g. once from a signal handler and once
        during normal shutdown): the DID cache is closed and statistics are
        logged only by the first call that completes. If that call is
        cancelled before closing the cache, a later call still closes it.
        """
        self.running = False
        await self.disconnect()
        if self._stopped:
            return
        await self.did_cache.close()
        self._stopped = True

        stats = self.did_cache.stats()
        logger.info(f"Processed {self.message_count} messages")
        logger.info(f"DID cache stats: {stats}")

    def add_wanted_did(self, did: str) -> None:
        """Add a DID to the wanted list.

        The server-side filter is fixed when the URL is built, so if a DID
        filter was already active the new DID's events only arrive after the
        next (re)connection.
        """
        self.wanted_dids.add(did)
        logger.info(f"Added DID to wanted list: {did}")

    def remove_wanted_did(self, did: str) -> None:
        """Remove a DID from the wanted list (filtered client-side at once)."""
        self.wanted_dids.discard(did)
        logger.info(f"Removed DID from wanted list: {did}")

    def set_output_format(self, format_type: str) -> None:
        """Set the output format.

        Raises:
            ValueError: if ``format_type`` is not ``"display"`` or ``"json"``.
        """
        if format_type not in ["display", "json"]:
            raise ValueError("Output format must be 'display' or 'json'")
        self.output_format = format_type
        logger.info(f"Output format set to: {format_type}")
280
281
# Global handler instance for signal handling: ``signal_handler`` only
# receives (signum, frame), so it reaches the active handler through this
# module-level reference, which ``main`` assigns.
handler_instance: Optional[JetstreamHandler] = None
284
285
# Strong references to stop() tasks scheduled from the signal handler. The
# event loop keeps only weak references to tasks, so an otherwise
# unreferenced task could be garbage-collected before it finishes.
_shutdown_tasks: Set[asyncio.Task] = set()


def signal_handler(signum, frame):
    """Handle SIGINT/SIGTERM by asking the active handler to shut down.

    Args:
        signum: Signal number (unused).
        frame: Interrupted stack frame (unused).
    """
    handler = handler_instance
    if not handler:
        return

    logger.info("Received shutdown signal, stopping handler...")
    # Stop the reconnect loop right away, even if stop() cannot be scheduled.
    handler.running = False

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop running (startup or teardown); clearing the running
        # flag above is all that can be done from here.
        return

    def _schedule_stop() -> None:
        task = loop.create_task(handler.stop())
        _shutdown_tasks.add(task)
        task.add_done_callback(_shutdown_tasks.discard)

    # Unlike call_soon, call_soon_threadsafe also wakes a loop that is
    # currently blocked waiting for I/O, so shutdown starts promptly.
    loop.call_soon_threadsafe(_schedule_stop)
291
292
async def _run_until_stopped(handler: JetstreamHandler) -> None:
    """Run *handler*'s reconnect loop, then stop it on the same event loop.

    Calling ``stop()`` here rather than from a second ``asyncio.run`` keeps
    cleanup on the loop that was active while the handler ran; presumably
    this matters if DIDCache holds loop-bound resources (not visible here).
    """
    try:
        await handler.run_with_reconnect()
    finally:
        await handler.stop()


@click.command()
@click.option('--config', '-c', type=click.Path(exists=True), help='Path to configuration file')
@click.option('--dids', help='Comma-separated list of DIDs to monitor')
@click.option('--cursor', type=int, help='Cursor position to start from (unix microseconds)')
@click.option('--output', type=click.Choice(['display', 'json']), default='display', help='Output format')
@click.option('--verbose', '-v', is_flag=True, help='Enable verbose logging')
def main(config: Optional[str], dids: Optional[str], cursor: Optional[int], output: str, verbose: bool):
    """Listen for stream.thought.blip records on ATProto jetstream."""
    global handler_instance

    # Set up logging level
    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Becomes True once the event loop owns the handler's lifecycle; from
    # then on _run_until_stopped is responsible for calling stop().
    loop_started = False
    try:
        app_config = load_config(config)

        handler_instance = JetstreamHandler(app_config)

        # Override wanted DIDs if provided via command line
        if dids:
            did_list = [did.strip() for did in dids.split(',') if did.strip()]
            handler_instance.wanted_dids = set(did_list)
            logger.info(f"Monitoring DIDs: {did_list}")
        elif handler_instance.wanted_dids:
            logger.info(f"Monitoring configured DIDs: {list(handler_instance.wanted_dids)}")
        else:
            logger.info("Monitoring all DIDs (no filter applied)")

        # Set cursor if provided
        if cursor:
            handler_instance.cursor = cursor
            logger.info(f"Starting from cursor: {cursor}")

        handler_instance.set_output_format(output)

        # Installed before asyncio.run, so asyncio does not replace SIGINT.
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        loop_started = True
        asyncio.run(_run_until_stopped(handler_instance))

    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    except Exception as e:
        logger.error(f"Fatal error: {e}")
        sys.exit(1)
    finally:
        # Only clean up here if we failed before the event loop took over;
        # otherwise stop() already ran inside that loop.
        if handler_instance and not loop_started:
            asyncio.run(handler_instance.stop())
347
348
# Script entry point: invoke the click command when run directly.
if __name__ == '__main__':
    main()