# A digital person for Bluesky.
# at toolchange 11 kB view raw
import json
import logging
import os
import sys
from typing import Any, Dict, Optional, Tuple

import dotenv
import yaml
from atproto_client import Client, Session, SessionEvent, models

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("bluesky_session_handler")

# Load the environment variables from a .env file. override=True lets values
# in the file replace variables already set in the process environment.
dotenv.load_dotenv(override=True)

# Keys removed by strip_fields() when thread_to_yaml_string() renders a thread
# with strip_metadata=True, to keep the YAML handed to the LLM compact.
STRIP_FIELDS = [
    "cid",
    "rev",
    "did",
    "uri",
    "langs",
    "threadgate",
    "py_type",
    "labels",
    "facets",
    "avatar",
    "viewer",
    "indexed_at",
    "tags",
    "associated",
    "thread_context",
    "aspect_ratio",
    "thumb",
    "fullsize",
    "root",
    "created_at",
    "verification",
    "like_count",
    "quote_count",
    "reply_count",
    "repost_count",
    "embedding_disabled",
    "thread_muted",
    "reply_disabled",
    "pinned",
    "like",
    "repost",
    "blocked_by",
    "blocking",
    "blocking_by_list",
    "followed_by",
    "following",
    "known_followers",
    "muted",
    "muted_by_list",
    "root_author_like",
    "entities",
    "ref",
    "mime_type",
    "size",
]


def convert_to_basic_types(obj):
    """Convert complex Python objects to basic types for JSON/YAML serialization.

    - Objects exposing ``__dict__`` (such as the model objects returned by
      ``get_post_thread``) are converted through their attribute dictionary.
    - Dicts are converted value by value; lists and tuples element by element
      (tuples become lists).
    - str/int/float/bool/None are returned unchanged.
    - Anything else is converted with ``str()``.

    Every dict/list in the result is newly built, so ``obj`` is never mutated.
    """
    if hasattr(obj, "__dict__"):
        return convert_to_basic_types(obj.__dict__)
    if isinstance(obj, dict):
        return {key: convert_to_basic_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Tuples become lists so they serialize as plain YAML/JSON sequences
        # rather than as their Python repr string.
        return [convert_to_basic_types(item) for item in obj]
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    # For other types, fall back to their string form.
    return str(obj)


def _is_empty(value: Any) -> bool:
    """Return True for values strip_fields() prunes from dicts: None, {}, [] and blank strings."""
    return (
        value is None
        or (isinstance(value, (dict, list)) and len(value) == 0)
        or (isinstance(value, str) and value.strip() == "")
    )


def _strip_fields(obj: Any, fields: frozenset) -> Any:
    """Recursive worker for strip_fields(); ``fields`` is a pre-built set for O(1) lookups."""
    if isinstance(obj, dict):
        # Keys on the strip list, plus dunder keys (internal model metadata).
        # The isinstance guard keeps non-string keys from raising AttributeError.
        doomed = [
            key
            for key in obj
            if key in fields or (isinstance(key, str) and key.startswith("__"))
        ]
        for key in doomed:
            del obj[key]

        # Recurse first, then drop entries that ended up None/empty.
        for key, value in list(obj.items()):
            cleaned = _strip_fields(value, fields)
            if _is_empty(cleaned):
                del obj[key]
            else:
                obj[key] = cleaned

    elif isinstance(obj, list):
        # Only None items are dropped from lists; empty containers are kept.
        obj[:] = [
            item
            for item in (_strip_fields(value, fields) for value in obj)
            if item is not None
        ]

    return obj


def strip_fields(obj, strip_field_list):
    """Recursively strip fields from a JSON-like object, in place.

    Removes every dict key listed in ``strip_field_list`` and every key that
    starts with ``"__"``. After recursing, it also drops dict entries whose
    value is None, an empty dict, an empty list, or a blank string, and drops
    None items from lists.

    Args:
        obj: Nested structure of dicts/lists/scalars; mutated in place.
        strip_field_list: Iterable of key names to remove.

    Returns:
        ``obj`` itself (the same object, after mutation).
    """
    return _strip_fields(obj, frozenset(strip_field_list))


def thread_to_yaml_string(thread, strip_metadata=True):
    """
    Convert thread data to a YAML-formatted string for LLM parsing.

    Args:
        thread: The thread data from get_post_thread
        strip_metadata: Whether to strip metadata fields for cleaner output

    Returns:
        YAML-formatted string representation of the thread
    """
    # First convert complex objects to basic types.
    basic_thread = convert_to_basic_types(thread)

    if strip_metadata:
        # convert_to_basic_types() built a fresh structure, so stripping it in
        # place leaves the caller's ``thread`` object untouched.
        basic_thread = strip_fields(basic_thread, STRIP_FIELDS)

    return yaml.dump(basic_thread, indent=2, allow_unicode=True, default_flow_style=False)


def get_session(username: str) -> Optional[str]:
    """Return the saved session string for ``username``, or None if no file exists."""
    try:
        with open(f"session_{username}.txt", encoding="UTF-8") as f:
            return f.read()
    except FileNotFoundError:
        logger.debug(f"No existing session found for {username}")
        return None


def save_session(username: str, session_string: str) -> None:
    """Write ``session_string`` to ``session_<username>.txt``, overwriting any previous file.

    init_client() can log in from this string without a password, so treat the
    file as a credential.
    """
    with open(f"session_{username}.txt", "w", encoding="UTF-8") as f:
        f.write(session_string)
    logger.debug(f"Session saved for {username}")


def on_session_change(username: str, event: SessionEvent, session: Session) -> None:
    """Client callback: persist the session whenever it is created or refreshed."""
    logger.info(f"Session changed: {event} {repr(session)}")
    if event in (SessionEvent.CREATE, SessionEvent.REFRESH):
        logger.info(f"Saving changed session for {username}")
        save_session(username, session.export())


def init_client(username: str, password: str) -> Client:
    """Create a logged-in Client for ``username``.

    Uses the PDS from the ``PDS_URI`` environment variable (default
    https://bsky.social). Reuses a saved session when one exists and falls
    back to a password login if there is none or it cannot be reused.
    Session changes are persisted through on_session_change().
    """
    pds_uri = os.getenv("PDS_URI")
    if pds_uri is None:
        logger.warning(
            "No PDS URI provided. Falling back to bsky.social. Note! If you are on a non-Bluesky PDS, this can cause logins to fail. Please provide a PDS URI using the PDS_URI environment variable."
        )
        pds_uri = "https://bsky.social"

    # Print the PDS URI
    logger.info(f"Using PDS URI: {pds_uri}")

    client = Client(pds_uri)
    client.on_session_change(
        lambda event, session: on_session_change(username, event, session)
    )

    session_string = get_session(username)
    if session_string:
        logger.info(f"Reusing existing session for {username}")
        try:
            client.login(session_string=session_string)
            return client
        except Exception as e:
            # A saved session can go stale or the file can be corrupted.
            # Fall back to a password login rather than failing on every
            # start until the file is removed by hand.
            logger.warning(
                f"Saved session for {username} could not be reused ({e}); logging in with password"
            )

    logger.info(f"Creating new session for {username}")
    client.login(username, password)
    return client


def default_login() -> Client:
    """Log in with BSKY_USERNAME / BSKY_PASSWORD from the environment.

    Logs an error and exits the process with status 1 if either variable is
    missing or empty.
    """
    username = os.getenv("BSKY_USERNAME")
    password = os.getenv("BSKY_PASSWORD")

    if not username:
        logger.error(
            "No username provided. Please provide a username using the BSKY_USERNAME environment variable."
        )
        # Non-zero status so supervisors/scripts see the failure.
        sys.exit(1)

    if not password:
        logger.error(
            "No password provided. Please provide a password using the BSKY_PASSWORD environment variable."
        )
        sys.exit(1)

    return init_client(username, password)


def reply_to_post(client: Client, text: str, reply_to_uri: str, reply_to_cid: str, root_uri: Optional[str] = None, root_cid: Optional[str] = None) -> Dict[str, Any]:
    """
    Reply to a post on Bluesky.

    Args:
        client: Authenticated Bluesky client
        text: The reply text
        reply_to_uri: The URI of the post being replied to (parent)
        reply_to_cid: The CID of the post being replied to (parent)
        root_uri: The URI of the root post (if replying to a reply). If None, uses reply_to_uri
        root_cid: The CID of the root post (if replying to a reply). If None, uses reply_to_cid

    Returns:
        The response object from client.send_post (it exposes ``.uri``).
    """
    # If root is not provided, the parent is treated as the thread root.
    if root_uri is None:
        root_uri = reply_to_uri
        root_cid = reply_to_cid

    # Create references for the reply
    parent_ref = models.create_strong_ref(models.ComAtprotoRepoStrongRef.Main(uri=reply_to_uri, cid=reply_to_cid))
    root_ref = models.create_strong_ref(models.ComAtprotoRepoStrongRef.Main(uri=root_uri, cid=root_cid))

    # Send the reply
    response = client.send_post(
        text=text,
        reply_to=models.AppBskyFeedPost.ReplyRef(parent=parent_ref, root=root_ref),
    )

    logger.info(f"Reply sent successfully: {response.uri}")
    return response


def get_post_thread(client: Client, uri: str, parent_height: int = 60, depth: int = 10) -> Optional[Dict[str, Any]]:
    """
    Get the thread containing a post to find root post information.

    Args:
        client: Authenticated Bluesky client
        uri: The URI of the post
        parent_height: Number of ancestor levels to request above the post
        depth: Number of reply levels to request below the post

    Returns:
        The thread response, or None if the request failed (the error is logged)
    """
    try:
        return client.app.bsky.feed.get_post_thread(
            {"uri": uri, "parent_height": parent_height, "depth": depth}
        )
    except Exception as e:
        logger.error(f"Error fetching post thread: {e}")
        return None


def _get_field(obj: Any, name: str) -> Any:
    """Return ``obj[name]`` for dicts or ``obj.name`` for other objects; None if absent or obj is None."""
    if isinstance(obj, dict):
        return obj.get(name)
    return getattr(obj, name, None)


def _reply_root_ref(record: Any) -> Optional[Tuple[str, str]]:
    """Return (uri, cid) from ``record.reply.root`` if the post record has one, else None."""
    root = _get_field(_get_field(record, "reply"), "root")
    uri, cid = _get_field(root, "uri"), _get_field(root, "cid")
    if uri and cid:
        return uri, cid
    return None


def _resolve_thread_root(client: Client, notification: Any, post_uri: str, post_cid: str) -> Tuple[str, str]:
    """Work out the (uri, cid) of the thread root a reply to ``post_uri`` must reference.

    Sources, in order of preference:
      1. ``record.reply.root`` on the notification's own record (no API call);
      2. ``record.reply.root`` on the post as returned by get_post_thread;
      3. the top-most ancestor reached by walking ``parent`` links of the
         fetched thread (only as deep as get_post_thread's parent_height);
      4. the post itself.
    Steps 1-2 are exact for any thread depth; step 3 can stop short of the
    real root in very deep threads.
    """
    root = _reply_root_ref(_get_field(notification, "record"))
    if root:
        return root

    thread = _get_field(get_post_thread(client, post_uri), "thread")
    if not thread:
        # No thread data: reply as if the post were the root.
        return post_uri, post_cid

    root = _reply_root_ref(_get_field(_get_field(thread, "post"), "record"))
    if root:
        return root

    root_uri, root_cid = post_uri, post_cid
    ancestor = _get_field(thread, "parent")
    while ancestor:
        ancestor_post = _get_field(ancestor, "post")
        uri, cid = _get_field(ancestor_post, "uri"), _get_field(ancestor_post, "cid")
        # Ancestors without a post (e.g. not-found/blocked stubs) are skipped.
        if uri and cid:
            root_uri, root_cid = uri, cid
        ancestor = _get_field(ancestor, "parent")
    return root_uri, root_cid


def reply_to_notification(client: Client, notification: Any, reply_text: str) -> Optional[Dict[str, Any]]:
    """
    Reply to a notification (mention or reply).

    Args:
        client: Authenticated Bluesky client
        notification: The notification object (or dict) from list_notifications
        reply_text: The text to reply with

    Returns:
        The response from sending the reply, or None if it failed (logged)
    """
    try:
        # Get the post URI and CID from the notification (dict or object).
        post_uri = _get_field(notification, "uri")
        post_cid = _get_field(notification, "cid")

        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        root_uri, root_cid = _resolve_thread_root(client, notification, post_uri, post_cid)

        return reply_to_post(
            client=client,
            text=reply_text,
            reply_to_uri=post_uri,
            reply_to_cid=post_cid,
            root_uri=root_uri,
            root_cid=root_cid,
        )

    except Exception as e:
        logger.error(f"Error replying to notification: {e}")
        return None


if __name__ == "__main__":
    client = default_login()
    # do something with the client
    logger.info("Client is ready to use!")