1import json
2import yaml
3import dotenv
4import os
5import logging
6from typing import Optional, Dict, Any, List
7from atproto_client import Client, Session, SessionEvent, models
8
# Module-wide logging setup: INFO level with a timestamped
# "<time> - <logger name> - <level> - <message>" line format.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Shared logger used by every helper in this module.
logger = logging.getLogger("bluesky_session_handler")
14
# Load environment variables from a .env file at import time.
# override=True is passed so values from the file replace variables that are
# already set in the process environment.
dotenv.load_dotenv(override=True)
17
18
# Keys that strip_fields() removes (at every nesting level) when
# thread_to_yaml_string() is called with strip_metadata=True.
STRIP_FIELDS = [
    "cid",
    "rev",
    "did",
    "uri",
    "langs",
    "threadgate",
    "py_type",
    "labels",
    "facets",
    "avatar",
    "viewer",
    "indexed_at",
    "tags",
    "associated",
    "thread_context",
    "aspect_ratio",
    "thumb",
    "fullsize",
    "root",
    "created_at",
    "verification",
    "like_count",
    "quote_count",
    "reply_count",
    "repost_count",
    "embedding_disabled",
    "thread_muted",
    "reply_disabled",
    "pinned",
    "like",
    "repost",
    "blocked_by",
    "blocking",
    "blocking_by_list",
    "followed_by",
    "following",
    "known_followers",
    "muted",
    "muted_by_list",
    "root_author_like",
    "entities",
    "ref",
    "mime_type",
    "size",
]
66
67
def convert_to_basic_types(obj):
    """Convert complex Python objects to basic types for JSON/YAML serialization.

    Conversion rules, applied recursively:

    - dicts (including subclasses) become plain dicts; values are converted,
      keys are kept as-is
    - lists, tuples, sets and frozensets become plain lists (set ordering is
      whatever iteration yields)
    - any other object exposing ``__dict__`` becomes the converted form of that
      attribute dictionary
    - ``str``/``int``/``float``/``bool``/``None`` are returned unchanged
    - everything else is replaced by ``str(obj)``

    Args:
        obj: Any Python value.

    Returns:
        A structure made only of dicts, lists and scalar values.
    """
    # Containers are checked before ``__dict__``: instances of dict/list
    # subclasses also carry a (usually empty) ``__dict__``, and looking at
    # that first would silently discard their real contents.
    if isinstance(obj, dict):
        return {key: convert_to_basic_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [convert_to_basic_types(item) for item in obj]
    if hasattr(obj, '__dict__'):
        # Generic objects are represented by their attribute dictionary.
        return convert_to_basic_types(obj.__dict__)
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    # Last resort for anything else (bytes, datetimes, ...): string form.
    return str(obj)
82
83
def strip_fields(obj, strip_field_list):
    """Recursively strip fields from a JSON-like object, in place.

    For dicts, a key is removed when it appears in ``strip_field_list`` or is
    a string starting with ``"__"``. Remaining values are processed
    recursively, and a value that ends up as ``None``, an empty dict, an empty
    list, or a blank/whitespace-only string is removed as well.

    For lists, every element is processed recursively and ``None`` elements
    are dropped; other empty elements (``{}``, ``[]``, ``""``) are kept.

    Any other value is returned untouched.

    Args:
        obj: The dict/list structure to clean. Dicts and lists are mutated.
        strip_field_list: Collection of keys to remove at every nesting level.

    Returns:
        The same ``obj`` that was passed in, after cleaning.
    """
    # Build a hash-based lookup once; the recursion below passes it along so
    # nested calls skip this step. Unhashable entries keep the original list.
    if isinstance(strip_field_list, (list, tuple)):
        try:
            strip_field_list = frozenset(strip_field_list)
        except TypeError:
            pass

    if isinstance(obj, dict):
        for key in list(obj):
            # Only string keys can carry the "__" prefix; the isinstance guard
            # keeps non-string keys (e.g. ints) from raising AttributeError.
            is_dunder = isinstance(key, str) and key.startswith("__")
            if key in strip_field_list or is_dunder:
                del obj[key]
                continue

            cleaned = strip_fields(obj[key], strip_field_list)
            is_empty = (
                cleaned is None
                or (isinstance(cleaned, (dict, list)) and len(cleaned) == 0)
                or (isinstance(cleaned, str) and cleaned.strip() == "")
            )
            if is_empty:
                del obj[key]
            else:
                obj[key] = cleaned

    elif isinstance(obj, list):
        cleaned_items = [strip_fields(item, strip_field_list) for item in obj]
        # Slice assignment keeps the caller's list object identity.
        obj[:] = [item for item in cleaned_items if item is not None]

    return obj
117
118
def flatten_thread_structure(thread_data):
    """
    Flatten the ancestor chain of a thread into a flat list of posts.

    Starting at ``thread_data.thread``, the ``parent`` links are followed
    upwards and the ``post`` of every visited node is collected. Replies
    below the starting node are not visited.

    Nodes and posts may be attribute-style objects or plain dicts.

    Args:
        thread_data: The thread data from get_post_thread

    Returns:
        Dict with 'posts' key containing a list of post dicts ordered from the
        top-most reachable ancestor down to the starting post. Each entry is a
        shallow copy of the post's attribute dictionary (or of the post dict).
    """

    def read_field(container, name):
        """Read ``name`` from a dict key or an object attribute (None if absent)."""
        if isinstance(container, dict):
            return container.get(name)
        return getattr(container, name, None)

    # Walk upwards iteratively so that a long parent chain cannot hit the
    # interpreter's recursion limit.
    chain = []
    node = read_field(thread_data, 'thread')
    while node:
        chain.append(node)
        node = read_field(node, 'parent')

    posts = []
    # The chain was collected child-first; emit it ancestor-first.
    for node in reversed(chain):
        post = read_field(node, 'post')
        if not post:
            continue
        if isinstance(post, dict):
            posts.append(post.copy())
        elif hasattr(post, '__dict__'):
            posts.append(post.__dict__.copy())
        else:
            # Unknown post representation: keep a placeholder entry.
            posts.append({})

    return {'posts': posts}
161
162
def thread_to_yaml_string(thread, strip_metadata=True):
    """
    Render thread data as a YAML string intended for LLM consumption.

    The thread is flattened with flatten_thread_structure(), reduced to plain
    dicts/lists/scalars with convert_to_basic_types(), optionally cleaned with
    strip_fields() using STRIP_FIELDS, and finally dumped as block-style YAML.

    Args:
        thread: The thread data from get_post_thread
        strip_metadata: When true, remove the STRIP_FIELDS keys from the output

    Returns:
        YAML-formatted string representation of the thread
    """
    # Flatten first so the YAML does not mirror the nested parent structure.
    payload = convert_to_basic_types(flatten_thread_structure(thread))

    if strip_metadata:
        # convert_to_basic_types() built fresh containers, so the in-place
        # stripping does not touch the caller's thread object.
        payload = strip_fields(payload, STRIP_FIELDS)

    return yaml.dump(
        payload,
        indent=2,
        allow_unicode=True,
        default_flow_style=False,
    )
187
188
def get_session(username: str) -> Optional[str]:
    """Return the stored session string for ``username``, or None if absent.

    The session is read from ``session_<username>.txt`` relative to the
    current working directory.
    """
    session_path = f"session_{username}.txt"
    try:
        session_file = open(session_path, encoding="UTF-8")
    except FileNotFoundError:
        logger.debug(f"No existing session found for {username}")
        return None

    with session_file:
        return session_file.read()
196
197
def save_session(username: str, session_string: str) -> None:
    """Persist the exported session string for ``username``.

    The session is written to ``session_<username>.txt`` relative to the
    current working directory, replacing any previous content. Because the
    file holds login credentials, it is created with owner-only permissions
    (0o600), and an already existing file is tightened to the same mode on a
    best-effort basis.

    Args:
        username: Account name used to build the file name.
        session_string: Exported session data to store.
    """
    session_path = f"session_{username}.txt"

    def private_opener(path, flags):
        # Same flags open() would use, but new files are created with mode
        # 0o600 instead of the umask default.
        return os.open(path, flags, 0o600)

    with open(session_path, "w", encoding="UTF-8", opener=private_opener) as f:
        f.write(session_string)

    # The creation mode only applies to new files; tighten a file that
    # already existed. Failure here must not prevent the session from being
    # saved, so it is only reported.
    try:
        os.chmod(session_path, 0o600)
    except OSError as e:
        logger.warning(f"Could not restrict permissions on {session_path}: {e}")

    logger.debug(f"Session saved for {username}")
202
203
def on_session_change(username: str, event: SessionEvent, session: Session) -> None:
    """Session-change callback: persist the session when it is created or refreshed.

    Args:
        username: Account whose session file should be updated.
        event: The session event reported by the client.
        session: The session object associated with the event.
    """
    logger.debug(f"Session changed: {event} {repr(session)}")

    # Only CREATE and REFRESH events trigger a save; any other event is ignored.
    if event not in (SessionEvent.CREATE, SessionEvent.REFRESH):
        return

    logger.debug(f"Saving changed session for {username}")
    save_session(username, session.export())
209
210
def init_client(username: str, password: str, pds_uri: Optional[str] = "https://bsky.social") -> Client:
    """Create a logged-in client, reusing a stored session when possible.

    A stored session (see get_session) is tried first. If logging in with it
    fails for any reason, a fresh client is created and a normal
    username/password login is performed instead, so a stale session file
    does not block startup. Session changes are persisted through
    on_session_change.

    Args:
        username: Account handle; also selects the session file.
        password: Account password, used when no usable session exists.
        pds_uri: Base URL of the PDS. ``None`` falls back to
            ``https://bsky.social`` with a warning.

    Returns:
        An authenticated client.

    Raises:
        Whatever the client raises when the username/password login fails.
    """
    if pds_uri is None:
        logger.warning(
            "No PDS URI provided. Falling back to bsky.social. Note! If you are on a non-Bluesky PDS, this can cause logins to fail. Please provide a PDS URI using the PDS_URI environment variable."
        )
        pds_uri = "https://bsky.social"

    logger.debug(f"Using PDS URI: {pds_uri}")

    def build_client() -> Client:
        """Create a client that saves its session whenever it changes."""
        new_client = Client(pds_uri)
        new_client.on_session_change(
            lambda event, session: on_session_change(username, event, session)
        )
        return new_client

    session_string = get_session(username)
    if session_string:
        logger.debug(f"Reusing existing session for {username}")
        client = build_client()
        try:
            client.login(session_string=session_string)
            return client
        except Exception as e:
            # Deliberately broad: any failure to reuse the stored session
            # should lead to a regular login rather than a crash.
            logger.warning(
                f"Stored session for {username} could not be reused ({e}); logging in with password")

    logger.debug(f"Creating new session for {username}")
    # NOTE(review): a fresh client is used here on the assumption that a
    # failed session import may leave state on the previous instance — confirm
    # against the SDK.
    client = build_client()
    client.login(username, password)

    return client
235
236
def default_login() -> Client:
    """Log in using credentials from the config file or the environment.

    Credentials are read via ``config_loader.get_bluesky_config()`` (keys
    ``username``, ``password``, ``pds_uri``). If that module or the config
    file is missing, or a key is absent, the environment variables
    ``BSKY_USERNAME``, ``BSKY_PASSWORD`` and ``PDS_URI`` are used instead.

    Returns:
        An authenticated client from init_client.

    Raises:
        SystemExit: With exit status 1 when the username or password is
            missing or empty.
    """
    # Try to load from config first, fall back to environment variables
    try:
        from config_loader import get_bluesky_config
        config = get_bluesky_config()
        username = config['username']
        password = config['password']
        pds_uri = config['pds_uri']
    except (ImportError, FileNotFoundError, KeyError) as e:
        logger.warning(
            f"Could not load from config file ({e}), falling back to environment variables")
        username = os.getenv("BSKY_USERNAME")
        password = os.getenv("BSKY_PASSWORD")
        pds_uri = os.getenv("PDS_URI", "https://bsky.social")

    # Empty strings are treated like missing values. A non-zero exit status
    # lets callers and shell scripts detect the failure.
    if not username:
        logger.error(
            "No username provided. Please provide a username using the BSKY_USERNAME environment variable or config.yaml."
        )
        raise SystemExit(1)

    if not password:
        logger.error(
            "No password provided. Please provide a password using the BSKY_PASSWORD environment variable or config.yaml."
        )
        raise SystemExit(1)

    return init_client(username, password, pds_uri)
265
266
def remove_outside_quotes(text: str) -> str:
    """
    Remove outside double quotes from response text.

    Only handles double quotes to avoid interfering with contractions:
    - Double quotes: "text" → text
    - Preserves single quotes and internal quotes

    Inputs that are empty or shorter than two characters are returned
    unchanged. Otherwise surrounding whitespace is stripped, and one leading
    and one trailing double quote are removed when both are present.

    Args:
        text: The text to process

    Returns:
        Text with outside double quotes removed
    """
    if not text or len(text) < 2:
        return text

    text = text.strip()

    # The length is re-checked after stripping: a lone '"' both starts and
    # ends with a quote, and slicing it would wrongly yield an empty string.
    if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
        return text[1:-1]

    return text
291
292
def _build_mention_facets(client: Client, text_bytes: bytes) -> List[Any]:
    """Build mention facets for every resolvable ``@handle`` in ``text_bytes``.

    Each handle is looked up through the profile endpoint to obtain its DID.
    Handles that cannot be resolved are logged and skipped.
    """
    import re

    # Parse mentions - fixed to handle @ at start of text
    mention_regex = rb"(?:^|[$|\W])(@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)"

    mention_facets = []
    for match in re.finditer(mention_regex, text_bytes):
        # Group 1 is the "@handle" token without the optional leading
        # separator, so its span is the byte range the facet must cover.
        handle = match.group(1)[1:].decode("UTF-8")
        try:
            profile = client.app.bsky.actor.get_profile({'actor': handle})
            if not (profile and hasattr(profile, 'did')):
                continue
            mention_facets.append(
                models.AppBskyRichtextFacet.Main(
                    index=models.AppBskyRichtextFacet.ByteSlice(
                        byteStart=match.start(1),
                        byteEnd=match.end(1)
                    ),
                    features=[models.AppBskyRichtextFacet.Mention(
                        did=profile.did)]
                )
            )
        except Exception as e:
            # A failed lookup only costs the facet; the text is still posted.
            error_str = str(e)
            if 'Could not find user info' in error_str or 'InvalidRequest' in error_str:
                logger.warning(
                    f"User @{handle} not found (account may be deleted/suspended), skipping mention facet")
            elif 'BadRequestError' in error_str:
                logger.warning(
                    f"Bad request when resolving @{handle}, skipping mention facet: {e}")
            else:
                logger.debug(f"Failed to resolve handle @{handle}: {e}")

    return mention_facets


def _build_link_facets(text_bytes: bytes) -> List[Any]:
    """Build link facets for every http(s) URL found in ``text_bytes``."""
    import re

    # Parse URLs - fixed to handle URLs at start of text
    url_regex = rb"(?:^|[$|\W])(https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*[-a-zA-Z0-9@%_\+~#//=])?)"

    link_facets = []
    for match in re.finditer(url_regex, text_bytes):
        # Group 1 is the URL itself, excluding the optional leading separator.
        url = match.group(1).decode("UTF-8")
        link_facets.append(
            models.AppBskyRichtextFacet.Main(
                index=models.AppBskyRichtextFacet.ByteSlice(
                    byteStart=match.start(1),
                    byteEnd=match.end(1)
                ),
                features=[models.AppBskyRichtextFacet.Link(uri=url)]
            )
        )

    return link_facets


def reply_to_post(client: Client, text: str, reply_to_uri: str, reply_to_cid: str, root_uri: Optional[str] = None, root_cid: Optional[str] = None, lang: Optional[str] = None) -> Dict[str, Any]:
    """
    Reply to a post on Bluesky with rich text support.

    Mentions (``@handle``) and http(s) URLs in the text are turned into
    facets whose byte offsets refer to the UTF-8 encoding of ``text``.

    Args:
        client: Authenticated Bluesky client
        text: The reply text
        reply_to_uri: The URI of the post being replied to (parent)
        reply_to_cid: The CID of the post being replied to (parent)
        root_uri: The URI of the root post (if replying to a reply). If None, uses reply_to_uri
        root_cid: The CID of the root post (if replying to a reply). If root_uri is None, reply_to_cid is used
        lang: Language code for the post (e.g., 'en-US', 'es', 'ja')

    Returns:
        The response from sending the post
    """
    # Without an explicit root the parent is treated as the thread root.
    if root_uri is None:
        root_uri, root_cid = reply_to_uri, reply_to_cid

    # Create references for the reply
    parent_ref = models.create_strong_ref(
        models.ComAtprotoRepoStrongRef.Main(uri=reply_to_uri, cid=reply_to_cid))
    root_ref = models.create_strong_ref(
        models.ComAtprotoRepoStrongRef.Main(uri=root_uri, cid=root_cid))

    # Facet offsets are byte positions, so matching runs on the encoded text.
    text_bytes = text.encode("UTF-8")
    facets = _build_mention_facets(client, text_bytes) + _build_link_facets(text_bytes)

    post_kwargs = {
        "text": text,
        "reply_to": models.AppBskyFeedPost.ReplyRef(
            parent=parent_ref, root=root_ref),
        "langs": [lang] if lang else None,
    }
    # The facets argument is only passed when at least one facet was built.
    if facets:
        post_kwargs["facets"] = facets

    response = client.send_post(**post_kwargs)

    logger.info(f"Reply sent successfully: {response.uri}")
    return response
398
399
def get_post_thread(client: Client, uri: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the thread that contains a post.

    The request asks for up to 60 levels of parents and 10 levels of replies.
    Any error is logged (with a severity chosen from the error text) and
    reported to the caller as None instead of being raised.

    Args:
        client: Authenticated Bluesky client
        uri: The URI of the post

    Returns:
        The thread data or None if not found
    """
    request_params = {'uri': uri, 'parent_height': 60, 'depth': 10}

    try:
        return client.app.bsky.feed.get_post_thread(request_params)
    except Exception as e:
        error_str = str(e)
        account_missing = (
            'Could not find user info' in error_str or 'InvalidRequest' in error_str
        )
        post_missing = 'NotFound' in error_str or 'Post not found' in error_str

        # Expected failure modes are logged as warnings; anything else is an error.
        if account_missing:
            logger.warning(
                f"User account not found for post URI {uri} (account may be deleted/suspended)")
        elif post_missing:
            logger.warning(f"Post not found for URI {uri}")
        elif 'BadRequestError' in error_str:
            logger.warning(f"Bad request error for URI {uri}: {e}")
        else:
            logger.error(f"Error fetching post thread: {e}")

    return None
428
429
def _notification_post_ref(notification: Any):
    """Return ``(uri, cid)`` from a dict- or attribute-style notification.

    ``(None, None)`` is returned when an attribute-style notification lacks
    either field; a dict yields ``None`` for each missing key.
    """
    if isinstance(notification, dict):
        return notification.get('uri'), notification.get('cid')
    if hasattr(notification, 'uri') and hasattr(notification, 'cid'):
        return notification.uri, notification.cid
    return None, None


def _find_thread_root(client: Client, post_uri: str, post_cid: str):
    """Return ``(root_uri, root_cid)`` for the thread containing the post.

    Falls back to the post itself when the thread cannot be fetched or no
    root can be determined.
    """
    thread_data = get_post_thread(client, post_uri)
    if not (thread_data and hasattr(thread_data, 'thread')):
        return post_uri, post_cid

    thread = thread_data.thread

    # Prefer the root reference stored on the post record itself: it stays
    # correct even when the fetched parent chain is truncated or an ancestor
    # is unavailable.
    # NOTE(review): assumes the record exposes reply.root.uri / reply.root.cid
    # as in the app.bsky.feed.post schema — confirm against the SDK models.
    record = getattr(getattr(thread, 'post', None), 'record', None)
    declared_root = getattr(getattr(record, 'reply', None), 'root', None)
    declared_uri = getattr(declared_root, 'uri', None)
    declared_cid = getattr(declared_root, 'cid', None)
    if declared_uri and declared_cid:
        return declared_uri, declared_cid

    # Otherwise walk up the parent chain; the highest ancestor that still
    # carries a post with uri/cid is used as the root.
    root_uri, root_cid = post_uri, post_cid
    current = thread
    while hasattr(current, 'parent') and current.parent:
        current = current.parent
        ancestor_post = getattr(current, 'post', None)
        if hasattr(ancestor_post, 'uri') and hasattr(ancestor_post, 'cid'):
            root_uri = ancestor_post.uri
            root_cid = ancestor_post.cid

    return root_uri, root_cid


def _try_reply_to_post(client: Client, **reply_kwargs):
    """Call reply_to_post and return its response, or None if it raised."""
    try:
        return reply_to_post(client=client, **reply_kwargs)
    except Exception as e:
        logger.error(f"Failed to send reply: {e}")
        return None


def reply_to_notification(client: Client, notification: Any, reply_text: str, lang: str = "en-US") -> Optional[Dict[str, Any]]:
    """
    Reply to a notification (mention or reply).

    The reply is attached to the notification's post as parent, and to the
    root of that post's thread as root. If the thread cannot be fetched, the
    notification's post is used as root as well.

    Args:
        client: Authenticated Bluesky client
        notification: The notification object from list_notifications
            (dict or attribute-style, with ``uri`` and ``cid``)
        reply_text: The text to reply with
        lang: Language code for the post (defaults to "en-US")

    Returns:
        The response from sending the reply or None if failed
    """
    try:
        post_uri, post_cid = _notification_post_ref(notification)
        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        root_uri, root_cid = _find_thread_root(client, post_uri, post_cid)

        return reply_to_post(
            client=client,
            text=reply_text,
            reply_to_uri=post_uri,
            reply_to_cid=post_cid,
            root_uri=root_uri,
            root_cid=root_cid,
            lang=lang
        )

    except Exception as e:
        logger.error(f"Error replying to notification: {e}")
        return None


def reply_with_thread_to_notification(client: Client, notification: Any, reply_messages: List[str], lang: str = "en-US") -> Optional[List[Dict[str, Any]]]:
    """
    Reply to a notification with a threaded chain of messages (max 15).

    Messages are posted in order, each one replying to the previously posted
    one. If a message cannot be posted, a system failure notice is posted in
    its place and the chain continues from that notice; if the notice cannot
    be posted either, sending stops and the replies posted so far are
    returned.

    Args:
        client: Authenticated Bluesky client
        notification: The notification object from list_notifications
            (dict or attribute-style, with ``uri`` and ``cid``)
        reply_messages: List of reply texts (max 15 messages; message length
            is not checked here)
        lang: Language code for the posts (defaults to "en-US")

    Returns:
        List of responses for the posts that were sent (failure notices
        included), or None if the input was invalid or nothing was sent
    """
    try:
        # Validate input
        if not reply_messages:
            logger.error("Reply messages list cannot be empty")
            return None
        if len(reply_messages) > 15:
            logger.error(
                f"Cannot send more than 15 reply messages (got {len(reply_messages)})")
            return None

        post_uri, post_cid = _notification_post_ref(notification)
        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        root_uri, root_cid = _find_thread_root(client, post_uri, post_cid)

        # Send replies in sequence, creating a thread
        responses = []
        current_parent_uri = post_uri
        current_parent_cid = post_cid

        for i, message in enumerate(reply_messages):
            logger.info(
                f"Sending reply {i+1}/{len(reply_messages)}: {message[:50]}...")

            # reply_to_post raises on failure; the wrapper turns that into
            # None so that the fallback below actually runs and the replies
            # already sent are not lost.
            response = _try_reply_to_post(
                client,
                text=message,
                reply_to_uri=current_parent_uri,
                reply_to_cid=current_parent_cid,
                root_uri=root_uri,
                root_cid=root_cid,
                lang=lang
            )

            if not response:
                logger.error(
                    f"Failed to send reply {i+1}, posting system failure message")
                response = _try_reply_to_post(
                    client,
                    text="[SYSTEM FAILURE: COULD NOT POST MESSAGE, PLEASE TRY AGAIN]",
                    reply_to_uri=current_parent_uri,
                    reply_to_cid=current_parent_cid,
                    root_uri=root_uri,
                    root_cid=root_cid,
                    lang=lang
                )
                if not response:
                    logger.error(
                        "Could not even send system failure message, stopping thread")
                    return responses if responses else None

            responses.append(response)
            # The next message replies to whatever was just posted.
            current_parent_uri = response.uri
            current_parent_cid = response.cid

        logger.info(f"Successfully sent {len(responses)} threaded replies")
        return responses

    except Exception as e:
        logger.error(f"Error sending threaded reply to notification: {e}")
        return None
614
615
if __name__ == "__main__":
    # Script entry point: log in with the default credentials and report
    # that the client is available.
    client = default_login()
    # do something with the client
    logger.info("Client is ready to use!")