social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky
at next 235 lines 7.3 kB view raw
import base64
import json
import sqlite3
import time
from dataclasses import dataclass
from functools import cached_property
from typing import Any

from database.connection import DatabasePool


def _decode_jwt_payload(token: str) -> dict[str, Any]:
    """Return the decoded claims segment of a JWT, or ``{}`` if it cannot be read.

    The signature is NOT verified; the claims are only read locally (e.g. the
    ``exp`` claim for expiry checks).  Any malformed token -- wrong number of
    segments, invalid base64, invalid JSON, or a JSON value that is not an
    object -- yields an empty dict.
    """
    try:
        _, claims, _ = token.split(".")
        # JWT segments are unpadded base64url; restore the padding the decoder needs.
        claims += "=" * (-len(claims) % 4)
        payload = json.loads(base64.urlsafe_b64decode(claims))
    except (ValueError, TypeError, AttributeError):
        # ValueError covers the unpacking error, binascii.Error,
        # UnicodeDecodeError and json.JSONDecodeError; TypeError/AttributeError
        # cover a token that is not a str (e.g. None read back from the DB).
        return {}
    # A JSON document that is not an object carries no claims, and returning it
    # would break callers that call ``.get`` on the result.
    return payload if isinstance(payload, dict) else {}


def _is_expired(payload: dict[str, Any], buffer_seconds: int) -> bool:
    """Return True if the ``exp`` claim in *payload* is past, or within *buffer_seconds* of now.

    A missing or non-numeric ``exp`` is treated as already expired.
    """
    exp = payload.get("exp")
    # bool is an int subclass but is never a meaningful timestamp.
    if isinstance(exp, bool) or not isinstance(exp, (int, float)):
        return True
    return time.time() >= exp - buffer_seconds


@dataclass
class Session:
    """An authenticated session on a PDS, as stored in ``atproto_sessions``."""

    access_jwt: str
    refresh_jwt: str
    handle: str
    did: str
    pds: str
    email: str | None = None
    email_confirmed: bool = False
    email_auth_factor: bool = False
    active: bool = True
    status: str | None = None

    # NOTE(review): these are cached per instance; if access_jwt/refresh_jwt are
    # reassigned on an existing Session the cached payloads go stale.  Build a
    # new Session instead of mutating the token fields.
    @cached_property
    def access_payload(self) -> dict[str, Any]:
        """Decoded (unverified) claims of the access token, or ``{}``."""
        return _decode_jwt_payload(self.access_jwt)

    @cached_property
    def refresh_payload(self) -> dict[str, Any]:
        """Decoded (unverified) claims of the refresh token, or ``{}``."""
        return _decode_jwt_payload(self.refresh_jwt)

    def is_access_token_expired(self, buffer_seconds: int = 60) -> bool:
        """Return True if the access token expires within *buffer_seconds* (or has no usable ``exp``)."""
        return _is_expired(self.access_payload, buffer_seconds)

    def is_refresh_token_expired(self, buffer_seconds: int = 60) -> bool:
        """Return True if the refresh token expires within *buffer_seconds* (or has no usable ``exp``)."""
        return _is_expired(self.refresh_payload, buffer_seconds)

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "Session":
        """Build a Session from an ``atproto_sessions`` row (integer flags become bools)."""
        return cls(
            access_jwt=row["access_jwt"],
            refresh_jwt=row["refresh_jwt"],
            handle=row["handle"],
            did=row["did"],
            pds=row["pds"],
            email=row["email"],
            email_confirmed=bool(row["email_confirmed"]),
            email_auth_factor=bool(row["email_auth_factor"]),
            active=bool(row["active"]),
            status=row["status"],
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any], pds: str) -> "Session":
        """Build a Session from a camelCase session response dict plus the *pds* it came from.

        ``accessJwt``, ``refreshJwt``, ``handle`` and ``did`` are required
        (KeyError otherwise); the remaining fields fall back to defaults.
        """
        return cls(
            access_jwt=data["accessJwt"],
            refresh_jwt=data["refreshJwt"],
            handle=data["handle"],
            did=data["did"],
            pds=pds,
            email=data.get("email"),
            email_confirmed=data.get("emailConfirmed", False),
            email_auth_factor=data.get("emailAuthFactor", False),
            active=data.get("active", True),
            status=data.get("status"),
        )


@dataclass
class IdentityInfo:
    """A resolved identity: DID, handle, hosting PDS and signing key."""

    did: str
    handle: str
    pds: str
    signing_key: str

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "IdentityInfo":
        """Build an IdentityInfo from an ``atproto_identities`` row."""
        return cls(
            did=row["did"],
            handle=row["handle"],
            pds=row["pds"],
            signing_key=row["signing_key"],
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "IdentityInfo":
        """Build an IdentityInfo from a dict with ``did``, ``handle``, ``pds`` and ``signing_key`` keys."""
        return cls(
            did=data["did"],
            handle=data["handle"],
            pds=data["pds"],
            signing_key=data["signing_key"],
        )


class AtprotoStore:
    """SQLite-backed store for sessions and a TTL-bounded identity cache.

    Every mutating method commits immediately.
    """

    def __init__(
        self,
        db: sqlite3.Connection,
        identity_ttl: int = 12 * 60 * 60,
    ) -> None:
        """Wrap *db*; identity cache entries are considered fresh for *identity_ttl* seconds."""
        self.db = db
        # NOTE(review): this changes row_factory on the connection itself, which
        # affects any other code sharing the same connection.
        self.db.row_factory = sqlite3.Row
        self.identity_ttl = identity_ttl

    def get_session(self, did: str) -> Session | None:
        """Return the stored session for *did*, or None."""
        row = self.db.execute(
            "SELECT * FROM atproto_sessions WHERE did = ?", (did,)
        ).fetchone()
        return Session.from_row(row) if row else None

    def set_session(self, session: Session) -> None:
        """Insert or replace the session row for ``session.did``, stamping ``created_at`` with now."""
        now = time.time()
        self.db.execute(
            """
            INSERT OR REPLACE INTO atproto_sessions
            (did, pds, handle, access_jwt, refresh_jwt, email, email_confirmed,
             email_auth_factor, active, status, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                session.did,
                session.pds,
                session.handle,
                session.access_jwt,
                session.refresh_jwt,
                session.email,
                session.email_confirmed,
                session.email_auth_factor,
                session.active,
                session.status,
                now,
            ),
        )
        self.db.commit()

    def get_session_by_pds(self, pds: str, identifier: str) -> Session | None:
        """Return a session on *pds* whose DID or handle equals *identifier*, or None."""
        row = self.db.execute(
            """
            SELECT * FROM atproto_sessions
            WHERE pds = ? AND (did = ? OR handle = ?)
            """,
            (pds, identifier, identifier),
        ).fetchone()
        return Session.from_row(row) if row else None

    def list_sessions_by_pds(self, pds: str) -> list[Session]:
        """Return every stored session hosted on *pds*."""
        rows = self.db.execute(
            "SELECT * FROM atproto_sessions WHERE pds = ?", (pds,)
        ).fetchall()
        return [Session.from_row(row) for row in rows]

    def remove_session(self, did: str) -> None:
        """Delete the session stored for *did* (no-op if absent)."""
        self.db.execute("DELETE FROM atproto_sessions WHERE did = ?", (did,))
        self.db.commit()

    def get_identity(self, identifier: str) -> IdentityInfo | None:
        """Return the cached identity for *identifier* if it is younger than ``identity_ttl``."""
        row = self.db.execute(
            "SELECT * FROM atproto_identities WHERE identifier = ? AND created_at + ? > ?",
            (identifier, self.identity_ttl, time.time()),
        ).fetchone()
        return IdentityInfo.from_row(row) if row else None

    def set_identity(self, identifier: str, identity: IdentityInfo) -> None:
        """Cache *identity* under *identifier*, its DID and its handle.

        Duplicate keys (e.g. when *identifier* already is the DID) are written once.
        """
        now = time.time()
        rows = [
            (
                key,
                identity.did,
                identity.handle,
                identity.pds,
                identity.signing_key,
                now,
            )
            # dict.fromkeys de-duplicates while preserving order.
            for key in dict.fromkeys((identifier, identity.did, identity.handle))
        ]
        self.db.executemany(
            """
            INSERT OR REPLACE INTO atproto_identities
            (identifier, did, handle, pds, signing_key, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            rows,
        )
        self.db.commit()

    def remove_identity(self, identifier: str) -> None:
        """Delete the cached identity row keyed by *identifier* (other aliases are kept)."""
        self.db.execute(
            "DELETE FROM atproto_identities WHERE identifier = ?", (identifier,)
        )
        self.db.commit()

    def cleanup_expired(self) -> None:
        """Delete every identity cache row that ``get_identity`` would treat as expired."""
        # Exact complement of get_identity's freshness test
        # (created_at + ttl > now).  The previous version subtracted the TTL
        # twice and only purged rows older than 2 * identity_ttl.
        self.db.execute(
            "DELETE FROM atproto_identities WHERE created_at + ? <= ?",
            (self.identity_ttl, time.time()),
        )
        self.db.commit()

    def flush_all(self) -> tuple[int, int]:
        """Delete all sessions and identities; return ``(sessions_deleted, identities_deleted)``."""
        sessions = self.db.execute("SELECT COUNT(*) FROM atproto_sessions").fetchone()[0]
        identities = self.db.execute(
            "SELECT COUNT(*) FROM atproto_identities"
        ).fetchone()[0]
        self.db.execute("DELETE FROM atproto_sessions")
        self.db.execute("DELETE FROM atproto_identities")
        self.db.commit()
        return sessions, identities


# Process-wide store, created lazily by get_store().
_store: AtprotoStore | None = None


def get_store(db: DatabasePool) -> AtprotoStore:
    """Return the process-wide AtprotoStore, creating it on first call.

    Only the first call uses *db* (via ``db.get_conn()``); later calls return
    the same store regardless of the pool passed in.
    """
    global _store
    if _store is None:
        _store = AtprotoStore(db.get_conn())
    return _store


def flush_caches() -> tuple[int, int]:
    """Flush the process-wide store if one exists; return ``(sessions, identities)`` deleted."""
    if _store is not None:
        return _store.flush_all()
    return 0, 0