# This repository has no description.
import json
import logging
import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple

import dotenv
import yaml
from atproto_client import Client, Session, SessionEvent, models

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("bluesky_session_handler")

# Load the environment variables (override=True: values from .env win over
# variables already present in the process environment)
dotenv.load_dotenv(override=True)


# Strip fields. A list of fields to remove from a JSON object
# (see strip_fields / thread_to_yaml_string).
STRIP_FIELDS = [
    "cid",
    "rev",
    "did",
    "uri",
    "langs",
    "threadgate",
    "py_type",
    "labels",
    "facets",
    "avatar",
    "viewer",
    "indexed_at",
    "tags",
    "associated",
    "thread_context",
    "aspect_ratio",
    "thumb",
    "fullsize",
    "root",
    "created_at",
    "verification",
    "like_count",
    "quote_count",
    "reply_count",
    "repost_count",
    "embedding_disabled",
    "thread_muted",
    "reply_disabled",
    "pinned",
    "like",
    "repost",
    "blocked_by",
    "blocking",
    "blocking_by_list",
    "followed_by",
    "following",
    "known_followers",
    "muted",
    "muted_by_list",
    "root_author_like",
    "entities",
    "ref",
    "mime_type",
    "size",
]

# Maximum number of posts reply_with_thread_to_notification will send.
MAX_THREAD_REPLIES = 15

# Posted in place of a threaded reply that could not be sent.
SYSTEM_FAILURE_MESSAGE = "[SYSTEM FAILURE: COULD NOT POST MESSAGE, PLEASE TRY AGAIN]"

# Rich-text patterns. They run over the UTF-8 *bytes* of the post because
# facet indices (ByteSlice) are byte offsets. Group 1 is the mention / URL
# itself; the non-capturing prefix only anchors it to start-of-text or a
# non-word byte. ('$' and '|' inside the character class are literals.)
_MENTION_REGEX = re.compile(
    rb"(?:^|[$|\W])(@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)"
)
_URL_REGEX = re.compile(
    rb"(?:^|[$|\W])(https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*[-a-zA-Z0-9@%_\+~#//=])?)"
)


def convert_to_basic_types(obj):
    """Convert complex Python objects to basic types for JSON/YAML serialization.

    Objects exposing ``__dict__`` (e.g. SDK model instances) become dicts;
    dicts, lists and tuples are converted element-wise (tuples become lists);
    str/int/float/bool/None pass through unchanged; anything else is
    converted with ``str()``.
    """
    if hasattr(obj, '__dict__'):
        # Convert objects with __dict__ to their dictionary representation
        return convert_to_basic_types(obj.__dict__)
    if isinstance(obj, dict):
        return {key: convert_to_basic_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_basic_types(item) for item in obj]
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    # For other types, fall back to their string form
    return str(obj)


def _is_empty_value(value) -> bool:
    """Return True for values strip_fields drops from dicts: None, {}, [], or blank strings."""
    if value is None:
        return True
    if isinstance(value, (dict, list)):
        return len(value) == 0
    return isinstance(value, str) and value.strip() == ""


def strip_fields(obj, strip_field_list):
    """Recursively strip fields from a JSON-like object, mutating it in place.

    For dicts: keys in ``strip_field_list`` and keys starting with ``"__"``
    (pydantic/private metadata) are removed. Every remaining value is then
    processed recursively and dropped if it ends up None, an empty dict/list,
    or a whitespace-only string.
    For lists: items are processed recursively and None items are removed.

    Args:
        obj: The object to clean (dict, list or scalar).
        strip_field_list: Field names to remove wherever they appear.

    Returns:
        ``obj`` itself (scalars are returned unchanged).
    """
    if isinstance(obj, dict):
        # Guard with isinstance: dicts with non-string keys must not crash here.
        flagged = [
            key for key in obj
            if key in strip_field_list
            or (isinstance(key, str) and key.startswith("__"))
        ]
        for key in flagged:
            obj.pop(key, None)

        # Iterate over a snapshot because entries may be removed below.
        for key, value in list(obj.items()):
            cleaned = strip_fields(value, strip_field_list)
            if _is_empty_value(cleaned):
                obj.pop(key, None)
            else:
                obj[key] = cleaned

    elif isinstance(obj, list):
        processed = (strip_fields(item, strip_field_list) for item in obj)
        obj[:] = [item for item in processed if item is not None]

    return obj


def flatten_thread_structure(thread_data):
    """
    Flatten a nested thread structure into a flat list of posts.

    Only the ancestor chain is collected: the focal post plus every post
    reachable through ``parent`` links. Replies beneath the focal post are
    not included.

    Args:
        thread_data: The thread data from get_post_thread

    Returns:
        Dict with a 'posts' key containing shallow dict copies of the posts,
        ordered from the top-most ancestor down to the focal post.
    """
    if hasattr(thread_data, 'thread'):
        start = thread_data.thread
    elif hasattr(thread_data, '__dict__') and 'thread' in thread_data.__dict__:
        start = thread_data.__dict__['thread']
    else:
        start = None

    # Walk up the parent links iteratively (no recursion-depth limit), then
    # reverse so the top-most ancestor comes first.
    chain = []
    node = start
    while node:
        chain.append(node)
        node = getattr(node, 'parent', None)

    posts = []
    for node in reversed(chain):
        post = getattr(node, 'post', None)
        if not post:
            continue
        # Copy so later stripping does not mutate the SDK objects.
        if hasattr(post, '__dict__'):
            posts.append(post.__dict__.copy())
        elif isinstance(post, dict):
            posts.append(post.copy())
        else:
            posts.append({})

    return {'posts': posts}


def thread_to_yaml_string(thread, strip_metadata=True):
    """
    Convert thread data to a YAML-formatted string for LLM parsing.

    Args:
        thread: The thread data from get_post_thread
        strip_metadata: Whether to strip STRIP_FIELDS and empty values for
            cleaner output

    Returns:
        YAML-formatted string representation of the thread
    """
    # First flatten the thread structure to avoid deep nesting
    flattened = flatten_thread_structure(thread)

    # Convert complex objects to basic types (this also produces fresh
    # containers, so strip_fields' in-place mutation is safe)
    basic_thread = convert_to_basic_types(flattened)

    if strip_metadata:
        cleaned_thread = strip_fields(basic_thread, STRIP_FIELDS)
    else:
        cleaned_thread = basic_thread

    return yaml.dump(cleaned_thread, indent=2, allow_unicode=True, default_flow_style=False)


def get_session(username: str) -> Optional[str]:
    """Return the saved session string for ``username``, or None if no session file exists."""
    try:
        with open(f"session_{username}.txt", encoding="UTF-8") as f:
            return f.read()
    except FileNotFoundError:
        logger.debug(f"No existing session found for {username}")
        return None


def save_session(username: str, session_string: str) -> None:
    """Write ``session_string`` to ``session_<username>.txt``.

    The exported session is a reusable login credential (init_client passes
    it to ``client.login``), so a newly created file is made readable and
    writable by the owner only. Permissions of an existing file are left as-is.
    """
    path = f"session_{username}.txt"
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with open(fd, "w", encoding="UTF-8") as f:
        f.write(session_string)
    logger.debug(f"Session saved for {username}")


def on_session_change(username: str, event: SessionEvent, session: Session) -> None:
    """Persist the session whenever it is created or refreshed.

    The session object itself is deliberately not logged. Its exported form
    is a login credential, and its repr may include tokens.
    """
    logger.debug(f"Session changed for {username}: {event}")
    if event in (SessionEvent.CREATE, SessionEvent.REFRESH):
        logger.debug(f"Saving changed session for {username}")
        save_session(username, session.export())


def init_client(username: str, password: str, pds_uri: Optional[str] = "https://bsky.social") -> Client:
    """
    Create a logged-in client, reusing a saved session when possible.

    Args:
        username: Account handle; also names the session file.
        password: Account password, used when no reusable session exists.
        pds_uri: PDS base URL. None falls back to https://bsky.social with a warning.

    Returns:
        A logged-in Client whose session changes are persisted via save_session.

    Raises:
        Whatever ``Client.login`` raises if the password login fails.
    """
    if pds_uri is None:
        logger.warning(
            "No PDS URI provided. Falling back to bsky.social. Note! If you are on a non-Bluesky PDS, this can cause logins to fail. Please provide a PDS URI using the PDS_URI environment variable."
        )
        pds_uri = "https://bsky.social"

    # Print the PDS URI
    logger.debug(f"Using PDS URI: {pds_uri}")

    def _new_client() -> Client:
        """Build a client that persists every session change for this user."""
        new_client = Client(pds_uri)
        new_client.on_session_change(
            lambda event, session: on_session_change(username, event, session)
        )
        return new_client

    session_string = get_session(username)
    if session_string:
        logger.debug(f"Reusing existing session for {username}")
        client = _new_client()
        try:
            client.login(session_string=session_string)
            return client
        except Exception as e:
            # A stale or revoked saved session must not lock the account out
            # permanently; fall through to a fresh password login.
            logger.warning(
                f"Saved session for {username} could not be reused ({e}), logging in with password")

    logger.debug(f"Creating new session for {username}")
    client = _new_client()
    client.login(username, password)
    return client


def default_login() -> Client:
    """
    Log in using config_loader's Bluesky config, falling back to environment variables.

    Exits the process with status 1 if no username or password is available.
    """
    # Try to load from config first, fall back to environment variables
    try:
        from config_loader import get_bluesky_config
        config = get_bluesky_config()
        username = config['username']
        password = config['password']
        pds_uri = config['pds_uri']
    except (ImportError, FileNotFoundError, KeyError) as e:
        logger.warning(
            f"Could not load from config file ({e}), falling back to environment variables")
        username = os.getenv("BSKY_USERNAME")
        password = os.getenv("BSKY_PASSWORD")
        pds_uri = os.getenv("PDS_URI", "https://bsky.social")

    # `not` also rejects empty strings (e.g. BSKY_USERNAME= in .env).
    if not username:
        logger.error(
            "No username provided. Please provide a username using the BSKY_USERNAME environment variable or config.yaml."
        )
        sys.exit(1)

    if not password:
        logger.error(
            "No password provided. Please provide a password using the BSKY_PASSWORD environment variable or config.yaml."
        )
        sys.exit(1)

    return init_client(username, password, pds_uri)


def remove_outside_quotes(text: str) -> str:
    """
    Remove outside double quotes from response text.

    Only handles double quotes to avoid interfering with contractions:
    - Double quotes: "text" → text
    - Preserves single quotes and internal quotes

    Inputs of length >= 2 are also stripped of surrounding whitespace.
    A lone double quote left after stripping is returned as-is rather
    than being turned into an empty string.

    Args:
        text: The text to process

    Returns:
        Text with outside double quotes removed
    """
    if not text or len(text) < 2:
        return text

    text = text.strip()

    # Need two characters so one '"' is not used as both opening and closing quote.
    if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
        return text[1:-1]

    return text


def _resolve_handle_did(client: Client, handle: str) -> Optional[str]:
    """Resolve ``handle`` to a DID via get_profile; log and return None on failure."""
    try:
        profile = client.app.bsky.actor.get_profile({'actor': handle})
    except Exception as e:
        error_str = str(e)
        if 'Could not find user info' in error_str or 'InvalidRequest' in error_str:
            logger.warning(
                f"User @{handle} not found (account may be deleted/suspended), skipping mention facet")
        elif 'BadRequestError' in error_str:
            logger.warning(
                f"Bad request when resolving @{handle}, skipping mention facet: {e}")
        else:
            logger.debug(f"Failed to resolve handle @{handle}: {e}")
        return None
    if profile and hasattr(profile, 'did'):
        return profile.did
    return None


def reply_to_post(client: Client, text: str, reply_to_uri: str, reply_to_cid: str, root_uri: Optional[str] = None, root_cid: Optional[str] = None, lang: Optional[str] = None) -> Dict[str, Any]:
    """
    Reply to a post on Bluesky with rich text support (mentions and links).

    Args:
        client: Authenticated Bluesky client
        text: The reply text
        reply_to_uri: The URI of the post being replied to (parent)
        reply_to_cid: The CID of the post being replied to (parent)
        root_uri: The URI of the root post (if replying to a reply). If None
            (or if root_cid is None), the parent is used as the root.
        root_cid: The CID of the root post (if replying to a reply).
        lang: Language code for the post (e.g., 'en-US', 'es', 'ja')

    Returns:
        The response from sending the post

    Raises:
        Whatever ``client.send_post`` raises; send failures are not caught here.
    """
    # Without a complete root ref, treat the parent as the thread root.
    if root_uri is None or root_cid is None:
        root_uri = reply_to_uri
        root_cid = reply_to_cid

    parent_ref = models.create_strong_ref(
        models.ComAtprotoRepoStrongRef.Main(uri=reply_to_uri, cid=reply_to_cid))
    root_ref = models.create_strong_ref(
        models.ComAtprotoRepoStrongRef.Main(uri=root_uri, cid=root_cid))

    facets = []
    text_bytes = text.encode("UTF-8")

    # Mentions: only handles that resolve to a DID become facets. Each
    # distinct handle is resolved once even if it appears several times.
    resolved_dids: Dict[str, Optional[str]] = {}
    for match in _MENTION_REGEX.finditer(text_bytes):
        handle = match.group(1)[1:].decode("UTF-8")  # drop the leading '@'
        if handle not in resolved_dids:
            resolved_dids[handle] = _resolve_handle_did(client, handle)
        did = resolved_dids[handle]
        if did:
            facets.append(
                models.AppBskyRichtextFacet.Main(
                    index=models.AppBskyRichtextFacet.ByteSlice(
                        byteStart=match.start(1),
                        byteEnd=match.end(1)
                    ),
                    features=[models.AppBskyRichtextFacet.Mention(did=did)]
                )
            )

    # Links
    for match in _URL_REGEX.finditer(text_bytes):
        facets.append(
            models.AppBskyRichtextFacet.Main(
                index=models.AppBskyRichtextFacet.ByteSlice(
                    byteStart=match.start(1),
                    byteEnd=match.end(1)
                ),
                features=[models.AppBskyRichtextFacet.Link(
                    uri=match.group(1).decode("UTF-8"))]
            )
        )

    response = client.send_post(
        text=text,
        reply_to=models.AppBskyFeedPost.ReplyRef(
            parent=parent_ref, root=root_ref),
        facets=facets or None,
        langs=[lang] if lang else None
    )

    logger.info(f"Reply sent successfully: {response.uri}")
    return response


def get_post_thread(client: Client, uri: str) -> Optional[Any]:
    """
    Get the thread containing a post (up to 60 ancestors, 10 levels of replies).

    Args:
        client: Authenticated Bluesky client
        uri: The URI of the post

    Returns:
        The thread response, or None if it could not be fetched (the error is logged)
    """
    try:
        thread = client.app.bsky.feed.get_post_thread(
            {'uri': uri, 'parent_height': 60, 'depth': 10})
        return thread
    except Exception as e:
        error_str = str(e)
        # Handle specific error cases more gracefully
        if 'Could not find user info' in error_str or 'InvalidRequest' in error_str:
            logger.warning(
                f"User account not found for post URI {uri} (account may be deleted/suspended)")
        elif 'NotFound' in error_str or 'Post not found' in error_str:
            logger.warning(f"Post not found for URI {uri}")
        elif 'BadRequestError' in error_str:
            logger.warning(f"Bad request error for URI {uri}: {e}")
        else:
            logger.error(f"Error fetching post thread: {e}")
        return None


def _get_field(obj: Any, name: str) -> Any:
    """Return ``obj[name]`` for dicts or ``obj.name`` otherwise; None when missing or unreadable."""
    try:
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)
    except (AttributeError, KeyError, TypeError):
        return None


def _extract_post_ref(notification: Any) -> Tuple[Optional[str], Optional[str]]:
    """Return the (uri, cid) of the post a notification (dict or object) refers to."""
    return _get_field(notification, 'uri'), _get_field(notification, 'cid')


def _find_root_ref(client: Client, notification: Any, post_uri: str, post_cid: str) -> Tuple[str, str]:
    """Return the (uri, cid) of the thread root for a notification's post.

    Prefers the root strong ref stored in the post record (``record.reply.root``).
    Unlike the thread walk, it is not limited by how many ancestors
    get_post_thread returns. Otherwise walks the parent chain returned by
    get_post_thread. Falls back to the post itself.
    """
    root_ref = _get_field(_get_field(_get_field(notification, 'record'), 'reply'), 'root')
    record_root_uri = _get_field(root_ref, 'uri')
    record_root_cid = _get_field(root_ref, 'cid')
    if record_root_uri and record_root_cid:
        return record_root_uri, record_root_cid

    root_uri, root_cid = post_uri, post_cid
    thread_data = get_post_thread(client, post_uri)
    if thread_data and hasattr(thread_data, 'thread'):
        current = thread_data.thread
        # Keep the highest ancestor that exposes a post with uri and cid.
        while getattr(current, 'parent', None):
            current = current.parent
            post = getattr(current, 'post', None)
            if hasattr(post, 'uri') and hasattr(post, 'cid'):
                root_uri, root_cid = post.uri, post.cid
    return root_uri, root_cid


def _try_reply_to_post(client: Client, text: str, parent_uri: str, parent_cid: str, root_uri: str, root_cid: str, lang: str) -> Optional[Any]:
    """Call reply_to_post, logging and returning None instead of raising on failure."""
    try:
        return reply_to_post(
            client=client,
            text=text,
            reply_to_uri=parent_uri,
            reply_to_cid=parent_cid,
            root_uri=root_uri,
            root_cid=root_cid,
            lang=lang
        )
    except Exception as e:
        logger.error(f"Error posting threaded reply: {e}")
        return None


def reply_to_notification(client: Client, notification: Any, reply_text: str, lang: str = "en-US") -> Optional[Dict[str, Any]]:
    """
    Reply to a notification (mention or reply).

    Args:
        client: Authenticated Bluesky client
        notification: The notification object (or dict) from list_notifications
        reply_text: The text to reply with
        lang: Language code for the post (defaults to "en-US")

    Returns:
        The response from sending the reply or None if failed
    """
    try:
        post_uri, post_cid = _extract_post_ref(notification)
        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        root_uri, root_cid = _find_root_ref(client, notification, post_uri, post_cid)

        return reply_to_post(
            client=client,
            text=reply_text,
            reply_to_uri=post_uri,
            reply_to_cid=post_cid,
            root_uri=root_uri,
            root_cid=root_cid,
            lang=lang
        )

    except Exception as e:
        logger.error(f"Error replying to notification: {e}")
        return None


def reply_with_thread_to_notification(client: Client, notification: Any, reply_messages: List[str], lang: str = "en-US") -> Optional[List[Dict[str, Any]]]:
    """
    Reply to a notification with a threaded chain of messages (max 15).

    Each message replies to the previous one. If a message fails to post, a
    system-failure notice is posted in its place and the chain continues. If
    even that notice fails, the chain stops.

    Args:
        client: Authenticated Bluesky client
        notification: The notification object (or dict) from list_notifications
        reply_messages: List of reply texts (max 15 messages; length is not
            validated here, and Bluesky limits posts to 300 characters)
        lang: Language code for the posts (defaults to "en-US")

    Returns:
        List of responses for the posts sent (may be partial if the chain
        stopped early), or None if nothing could be sent
    """
    try:
        if not reply_messages:
            logger.error("Reply messages list cannot be empty")
            return None
        if len(reply_messages) > MAX_THREAD_REPLIES:
            logger.error(
                f"Cannot send more than {MAX_THREAD_REPLIES} reply messages (got {len(reply_messages)})")
            return None

        post_uri, post_cid = _extract_post_ref(notification)
        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        root_uri, root_cid = _find_root_ref(client, notification, post_uri, post_cid)

        responses = []
        parent_uri, parent_cid = post_uri, post_cid
        total = len(reply_messages)

        for number, message in enumerate(reply_messages, start=1):
            logger.info(
                f"Sending reply {number}/{total}: {message[:50]}...")

            response = _try_reply_to_post(
                client, message, parent_uri, parent_cid, root_uri, root_cid, lang)

            if not response:
                logger.error(
                    f"Failed to send reply {number}, posting system failure message")
                response = _try_reply_to_post(
                    client, SYSTEM_FAILURE_MESSAGE, parent_uri, parent_cid, root_uri, root_cid, lang)
                if not response:
                    logger.error(
                        "Could not even send system failure message, stopping thread")
                    return responses if responses else None

            responses.append(response)
            # The next message replies to whatever was just posted.
            parent_uri, parent_cid = response.uri, response.cid

        logger.info(f"Successfully sent {len(responses)} threaded replies")
        return responses

    except Exception as e:
        logger.error(f"Error sending threaded reply to notification: {e}")
        return None


if __name__ == "__main__":
    client = default_login()
    # do something with the client
    logger.info("Client is ready to use!")