This repo has no description.
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at 20bc691d73a47fecbfa97590f9b88b3842d9df07 553 lines 20 kB view raw
import os
import re
import sys
import logging
from typing import Optional, Dict, Any, List, Tuple
from atproto_client import Client, Session, SessionEvent, models

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("bluesky_session_handler")

# Load the environment variables
import dotenv

dotenv.load_dotenv(override=True)

import yaml
import json

# Strip fields. A list of fields to remove from a JSON object
STRIP_FIELDS = [
    "cid",
    "rev",
    "did",
    "uri",
    "langs",
    "threadgate",
    "py_type",
    "labels",
    "facets",
    "avatar",
    "viewer",
    "indexed_at",
    "tags",
    "associated",
    "thread_context",
    "aspect_ratio",
    "thumb",
    "fullsize",
    "root",
    "created_at",
    "verification",
    "like_count",
    "quote_count",
    "reply_count",
    "repost_count",
    "embedding_disabled",
    "thread_muted",
    "reply_disabled",
    "pinned",
    "like",
    "repost",
    "blocked_by",
    "blocking",
    "blocking_by_list",
    "followed_by",
    "following",
    "known_followers",
    "muted",
    "muted_by_list",
    "root_author_like",
    "entities",
    "ref",
    "mime_type",
    "size",
]

# Byte-level patterns used to detect rich-text facets. They run on the UTF-8
# encoded text because facet indices are byte offsets. The leading non-capturing
# group lets a match start at the beginning of the text or after a non-word
# character; group 1 is the mention/URL itself.
_MENTION_REGEX = re.compile(
    rb"(?:^|[$|\W])(@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)"
)
_URL_REGEX = re.compile(
    rb"(?:^|[$|\W])(https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*[-a-zA-Z0-9@%_\+~#//=])?)"
)


def convert_to_basic_types(obj):
    """Convert complex Python objects to basic types for JSON/YAML serialization.

    Objects exposing ``__dict__`` (such as SDK model instances) are converted
    through their attribute dictionary. Dicts, lists and tuples are converted
    element-wise (tuples become lists). ``str``/``int``/``float``/``bool`` and
    ``None`` pass through unchanged. Anything else is converted with ``str()``.

    Args:
        obj: Arbitrary value to convert.

    Returns:
        A structure built only from dicts, lists and scalar values.
    """
    if hasattr(obj, "__dict__"):
        # Convert objects with __dict__ to their dictionary representation
        return convert_to_basic_types(obj.__dict__)
    if isinstance(obj, dict):
        return {key: convert_to_basic_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Tuples are walked like lists; str(tuple) would flatten nested objects
        # into their reprs and lose structure.
        return [convert_to_basic_types(item) for item in obj]
    if isinstance(obj, (str, int, float, bool)) or obj is None:
        return obj
    # For other types, fall back to their string form
    return str(obj)


def strip_fields(obj, strip_field_list):
    """Recursively strip fields from a JSON-like object, in place.

    For dicts, this removes every key listed in ``strip_field_list`` and every
    string key starting with ``"__"`` (pydantic/Python metadata). It then
    recurses into the remaining values and drops any value that ends up
    ``None``, an empty dict, an empty list or a blank string. For lists, it
    recurses into each item and drops ``None`` items.

    Args:
        obj: Dict, list or scalar to clean. Dicts and lists are mutated.
        strip_field_list: Collection of key names to remove.

    Returns:
        The same object (mutated), or the scalar unchanged.
    """
    if isinstance(obj, dict):
        # Remove fields from strip list and pydantic metadata. The isinstance
        # check keeps non-string keys from raising on startswith().
        keys_flagged_for_removal = [
            field
            for field in obj
            if field in strip_field_list
            or (isinstance(field, str) and field.startswith("__"))
        ]
        for key in keys_flagged_for_removal:
            obj.pop(key, None)

        # Recursively process remaining values (iterate a snapshot because
        # entries may be popped below).
        for key, value in list(obj.items()):
            cleaned = strip_fields(value, strip_field_list)
            obj[key] = cleaned
            # Remove empty/null values after processing
            if (
                cleaned is None
                or (isinstance(cleaned, dict) and len(cleaned) == 0)
                or (isinstance(cleaned, list) and len(cleaned) == 0)
                or (isinstance(cleaned, str) and cleaned.strip() == "")
            ):
                obj.pop(key, None)

    elif isinstance(obj, list):
        for i, value in enumerate(obj):
            obj[i] = strip_fields(value, strip_field_list)
        # Remove None values from list
        obj[:] = [item for item in obj if item is not None]

    return obj


def flatten_thread_structure(thread_data):
    """
    Flatten a nested thread structure into a list of posts.

    Only the focal post and its ancestor chain (``node.parent`` links) are
    collected, oldest first. Replies *below* the focal post are not included.

    Args:
        thread_data: The thread data from get_post_thread

    Returns:
        Dict with 'posts' key containing a list of post dicts, root first
    """
    posts = []

    def traverse_thread(node):
        """Recursively collect the ancestors of ``node`` and then ``node`` itself."""
        if not node:
            return

        # Visit the parent first so ancestors end up before descendants.
        if hasattr(node, 'parent') and node.parent:
            traverse_thread(node.parent)

        # Then add this node's post (nodes such as not-found/blocked
        # placeholders have no post and are skipped).
        if hasattr(node, 'post') and node.post:
            if hasattr(node.post, '__dict__'):
                post_dict = node.post.__dict__.copy()
            elif isinstance(node.post, dict):
                post_dict = node.post.copy()
            else:
                post_dict = {}

            posts.append(post_dict)

    # Handle the thread structure
    if hasattr(thread_data, 'thread'):
        traverse_thread(thread_data.thread)
    elif hasattr(thread_data, '__dict__') and 'thread' in thread_data.__dict__:
        traverse_thread(thread_data.__dict__['thread'])

    return {'posts': posts}


def thread_to_yaml_string(thread, strip_metadata=True):
    """
    Convert thread data to a YAML-formatted string for LLM parsing.

    Args:
        thread: The thread data from get_post_thread
        strip_metadata: Whether to strip STRIP_FIELDS and empty values for
            cleaner output

    Returns:
        YAML-formatted string representation of the thread
    """
    # First flatten the thread structure to avoid deep nesting
    flattened = flatten_thread_structure(thread)

    # Convert complex objects to basic types
    basic_thread = convert_to_basic_types(flattened)

    if strip_metadata:
        cleaned_thread = strip_fields(basic_thread, STRIP_FIELDS)
    else:
        cleaned_thread = basic_thread

    return yaml.dump(cleaned_thread, indent=2, allow_unicode=True, default_flow_style=False)


def get_session(username: str) -> Optional[str]:
    """Return the saved session string for ``username``, or None if no session file exists."""
    try:
        with open(f"session_{username}.txt", encoding="UTF-8") as f:
            return f.read()
    except FileNotFoundError:
        logger.debug(f"No existing session found for {username}")
        return None


def save_session(username: str, session_string: str) -> None:
    """Write ``session_string`` to ``session_<username>.txt`` in the working directory."""
    with open(f"session_{username}.txt", "w", encoding="UTF-8") as f:
        f.write(session_string)
    logger.debug(f"Session saved for {username}")


def on_session_change(username: str, event: SessionEvent, session: Session) -> None:
    """Persist the session whenever the client creates or refreshes it."""
    # Only the event is logged: the Session object carries the access and
    # refresh tokens, which must not end up in log output.
    logger.debug(f"Session changed: {event} for {username}")
    if event in (SessionEvent.CREATE, SessionEvent.REFRESH):
        logger.debug(f"Saving changed session for {username}")
        save_session(username, session.export())


def init_client(username: str, password: str) -> Client:
    """Create a logged-in client, reusing a saved session when possible.

    The PDS is taken from the ``PDS_URI`` environment variable (default
    ``https://bsky.social``). If a saved session exists it is tried first. If
    it cannot be used (for example it has expired or been revoked), a fresh
    username/password login is performed instead.

    Args:
        username: Account handle or identifier used for login and for naming
            the session file.
        password: Account (app) password.

    Returns:
        An authenticated Client.
    """
    pds_uri = os.getenv("PDS_URI")
    if pds_uri is None:
        logger.warning(
            "No PDS URI provided. Falling back to bsky.social. Note! If you are on a non-Bluesky PDS, this can cause logins to fail. Please provide a PDS URI using the PDS_URI environment variable."
        )
        pds_uri = "https://bsky.social"

    logger.debug(f"Using PDS URI: {pds_uri}")

    def new_client() -> Client:
        # Every client persists session changes for this user.
        c = Client(pds_uri)
        c.on_session_change(
            lambda event, session: on_session_change(username, event, session)
        )
        return c

    session_string = get_session(username)
    if session_string:
        logger.debug(f"Reusing existing session for {username}")
        client = new_client()
        try:
            client.login(session_string=session_string)
            return client
        except Exception as e:
            # Without this fallback a stale session file would block every
            # start-up until someone deletes it by hand.
            logger.warning(
                f"Saved session for {username} could not be reused ({e}); logging in with password"
            )

    logger.debug(f"Creating new session for {username}")
    client = new_client()
    client.login(username, password)
    return client


def default_login() -> Client:
    """Log in using the BSKY_USERNAME / BSKY_PASSWORD environment variables.

    Exits the process with status 1 if either variable is missing or empty.
    """
    username = os.getenv("BSKY_USERNAME")
    password = os.getenv("BSKY_PASSWORD")

    if not username:
        logger.error(
            "No username provided. Please provide a username using the BSKY_USERNAME environment variable."
        )
        # Non-zero status so supervisors/CI treat this as a failure.
        sys.exit(1)

    if not password:
        logger.error(
            "No password provided. Please provide a password using the BSKY_PASSWORD environment variable."
        )
        sys.exit(1)

    return init_client(username, password)


def _build_facets(client: Client, text: str) -> List[Any]:
    """Detect @-mentions and URLs in ``text`` and return rich-text facets for them.

    Mentions whose handle cannot be resolved to a DID are skipped (logged at
    debug level). Facet indices are UTF-8 byte offsets.
    """
    facets = []
    text_bytes = text.encode("UTF-8")

    for m in _MENTION_REGEX.finditer(text_bytes):
        handle = m.group(1)[1:].decode("UTF-8")  # Remove @ prefix
        # Group 1 excludes the optional leading separator, so its span is
        # exactly the mention's byte range.
        mention_start = m.start(1)
        mention_end = m.end(1)
        try:
            # Resolve handle to DID using the API
            resolve_resp = client.app.bsky.actor.get_profile({'actor': handle})
            if resolve_resp and hasattr(resolve_resp, 'did'):
                facets.append(
                    models.AppBskyRichtextFacet.Main(
                        index=models.AppBskyRichtextFacet.ByteSlice(
                            byteStart=mention_start,
                            byteEnd=mention_end
                        ),
                        features=[models.AppBskyRichtextFacet.Mention(did=resolve_resp.did)]
                    )
                )
        except Exception as e:
            logger.debug(f"Failed to resolve handle {handle}: {e}")
            continue

    for m in _URL_REGEX.finditer(text_bytes):
        url = m.group(1).decode("UTF-8")
        facets.append(
            models.AppBskyRichtextFacet.Main(
                index=models.AppBskyRichtextFacet.ByteSlice(
                    byteStart=m.start(1),
                    byteEnd=m.end(1)
                ),
                features=[models.AppBskyRichtextFacet.Link(uri=url)]
            )
        )

    return facets


def reply_to_post(client: Client, text: str, reply_to_uri: str, reply_to_cid: str, root_uri: Optional[str] = None, root_cid: Optional[str] = None, lang: Optional[str] = None) -> Dict[str, Any]:
    """
    Reply to a post on Bluesky with rich text support (mentions and links).

    Args:
        client: Authenticated Bluesky client
        text: The reply text
        reply_to_uri: The URI of the post being replied to (parent)
        reply_to_cid: The CID of the post being replied to (parent)
        root_uri: The URI of the root post (if replying to a reply). If None, uses reply_to_uri
        root_cid: The CID of the root post (if replying to a reply). If None, uses reply_to_cid
        lang: Language code for the post (e.g., 'en-US', 'es', 'ja')

    Returns:
        The response from ``client.send_post`` (exposes ``uri`` and ``cid``)

    Raises:
        Whatever ``client.send_post`` raises; errors are not swallowed here.
    """
    # If root is not provided, this is a reply to the root post
    if root_uri is None:
        root_uri = reply_to_uri
        root_cid = reply_to_cid

    parent_ref = models.create_strong_ref(models.ComAtprotoRepoStrongRef.Main(uri=reply_to_uri, cid=reply_to_cid))
    root_ref = models.create_strong_ref(models.ComAtprotoRepoStrongRef.Main(uri=root_uri, cid=root_cid))

    facets = _build_facets(client, text)

    # facets=None when nothing was detected, matching the no-facets call.
    response = client.send_post(
        text=text,
        reply_to=models.AppBskyFeedPost.ReplyRef(parent=parent_ref, root=root_ref),
        facets=facets or None,
        langs=[lang] if lang else None
    )

    logger.info(f"Reply sent successfully: {response.uri}")
    return response


def get_post_thread(client: Client, uri: str) -> Optional[Any]:
    """
    Get the thread containing a post (up to 60 ancestors and 10 levels of replies).

    Args:
        client: Authenticated Bluesky client
        uri: The URI of the post

    Returns:
        The get_post_thread response, or None if the request failed
    """
    try:
        thread = client.app.bsky.feed.get_post_thread({'uri': uri, 'parent_height': 60, 'depth': 10})
        return thread
    except Exception as e:
        logger.error(f"Error fetching post thread: {e}")
        return None


def _notification_uri_cid(notification: Any) -> Tuple[Optional[str], Optional[str]]:
    """Return the (uri, cid) of the post a notification refers to; (None, None) if absent."""
    if isinstance(notification, dict):
        return notification.get('uri'), notification.get('cid')
    if hasattr(notification, 'uri') and hasattr(notification, 'cid'):
        return notification.uri, notification.cid
    return None, None


def _find_thread_root(client: Client, post_uri: str, post_cid: str) -> Tuple[str, str]:
    """Return the (uri, cid) of the root of the thread containing ``post_uri``.

    The root strong-ref stored in the post's own reply record is preferred. That
    ref stays correct even when the thread is deeper than the fetched ancestor
    window or an ancestor is blocked or deleted. If it is unavailable, the
    fetched parent chain is walked and the top-most visible post is used.
    Falls back to the post itself when the thread cannot be fetched.
    """
    thread_data = get_post_thread(client, post_uri)
    thread = getattr(thread_data, 'thread', None) if thread_data else None
    if thread is None:
        return post_uri, post_cid

    record = getattr(getattr(thread, 'post', None), 'record', None)
    root = getattr(getattr(record, 'reply', None), 'root', None)
    record_root_uri = getattr(root, 'uri', None)
    record_root_cid = getattr(root, 'cid', None)
    if record_root_uri and record_root_cid:
        return record_root_uri, record_root_cid

    root_uri, root_cid = post_uri, post_cid
    current = thread
    while hasattr(current, 'parent') and current.parent:
        current = current.parent
        post = getattr(current, 'post', None)
        if hasattr(post, 'uri') and hasattr(post, 'cid'):
            root_uri, root_cid = post.uri, post.cid
    return root_uri, root_cid


def reply_to_notification(client: Client, notification: Any, reply_text: str, lang: str = "en-US") -> Optional[Dict[str, Any]]:
    """
    Reply to a notification (mention or reply), keeping the reply in the original thread.

    Args:
        client: Authenticated Bluesky client
        notification: The notification object (or dict) from list_notifications
        reply_text: The text to reply with
        lang: Language code for the post (defaults to "en-US")

    Returns:
        The response from sending the reply or None if failed
    """
    try:
        post_uri, post_cid = _notification_uri_cid(notification)
        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        root_uri, root_cid = _find_thread_root(client, post_uri, post_cid)

        return reply_to_post(
            client=client,
            text=reply_text,
            reply_to_uri=post_uri,
            reply_to_cid=post_cid,
            root_uri=root_uri,
            root_cid=root_cid,
            lang=lang
        )

    except Exception as e:
        logger.error(f"Error replying to notification: {e}")
        return None


def _try_reply_to_post(client: Client, text: str, parent_uri: str, parent_cid: str, root_uri: str, root_cid: str, lang: str):
    """Call reply_to_post, logging and returning None instead of raising on failure."""
    try:
        return reply_to_post(
            client=client,
            text=text,
            reply_to_uri=parent_uri,
            reply_to_cid=parent_cid,
            root_uri=root_uri,
            root_cid=root_cid,
            lang=lang
        )
    except Exception as e:
        logger.error(f"Error sending reply: {e}")
        return None


def reply_with_thread_to_notification(client: Client, notification: Any, reply_messages: List[str], lang: str = "en-US") -> Optional[List[Dict[str, Any]]]:
    """
    Reply to a notification with a threaded chain of messages (max 4).

    Each message replies to the previous one. If a message fails to post, a
    system-failure placeholder is posted in its place and the chain continues
    from that placeholder. If even the placeholder fails, posting stops.

    Args:
        client: Authenticated Bluesky client
        notification: The notification object (or dict) from list_notifications
        reply_messages: List of reply texts (max 4 messages, each max 300 chars;
            the length limit is not checked here)
        lang: Language code for the posts (defaults to "en-US")

    Returns:
        List of responses for the posts that were sent. Returns None if input
        validation failed, nothing could be sent, or an unexpected error occurred.
    """
    try:
        if not reply_messages:
            logger.error("Reply messages list cannot be empty")
            return None
        if len(reply_messages) > 4:
            logger.error(f"Cannot send more than 4 reply messages (got {len(reply_messages)})")
            return None

        post_uri, post_cid = _notification_uri_cid(notification)
        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        root_uri, root_cid = _find_thread_root(client, post_uri, post_cid)

        responses = []
        current_parent_uri = post_uri
        current_parent_cid = post_cid

        for i, message in enumerate(reply_messages):
            logger.info(f"Sending reply {i+1}/{len(reply_messages)}: {message[:50]}...")

            # _try_reply_to_post never raises, so a failed send reaches the
            # fallback below instead of aborting the whole chain.
            response = _try_reply_to_post(
                client, message, current_parent_uri, current_parent_cid, root_uri, root_cid, lang
            )

            if not response:
                logger.error(f"Failed to send reply {i+1}, posting system failure message")
                response = _try_reply_to_post(
                    client,
                    "[SYSTEM FAILURE: COULD NOT POST MESSAGE, PLEASE TRY AGAIN]",
                    current_parent_uri,
                    current_parent_cid,
                    root_uri,
                    root_cid,
                    lang
                )
                if not response:
                    logger.error("Could not even send system failure message, stopping thread")
                    return responses if responses else None

            responses.append(response)
            # Chain the next message under whatever was just posted.
            current_parent_uri = response.uri
            current_parent_cid = response.cid

        logger.info(f"Successfully sent {len(responses)} threaded replies")
        return responses

    except Exception as e:
        logger.error(f"Error sending threaded reply to notification: {e}")
        return None


if __name__ == "__main__":
    client = default_login()
    # do something with the client
    logger.info("Client is ready to use!")