social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
1import base64
2import json
3import sqlite3
4import time
5from dataclasses import dataclass
6from functools import cached_property
7from typing import Any
8
9from database.connection import DatabasePool
10
11
def _decode_jwt_payload(token: str) -> dict[str, Any]:
    """Best-effort decode of the claims (middle) segment of a JWT.

    The token is split on ``"."`` into exactly three segments and only the
    middle one is read: it is base64url-decoded (after restoring any stripped
    ``=`` padding) and parsed as JSON. The other two segments are ignored, so
    nothing here verifies the token's signature.

    Args:
        token: The dot-separated token string.

    Returns:
        The decoded claims. An empty dict is returned when the token does not
        have three segments, cannot be decoded or parsed, or decodes to JSON
        that is not an object.
    """
    try:
        _, claims, _ = token.split(".")
        # base64 input must be a multiple of 4 characters long; re-add the
        # padding that is missing from the segment.
        padded = claims + "=" * (-len(claims) % 4)
        payload = json.loads(base64.urlsafe_b64decode(padded))
    except Exception:
        # Deliberately broad: any malformed token simply yields "no claims".
        return {}
    # Valid JSON is not necessarily an object; only a dict honours the
    # declared return type.
    return payload if isinstance(payload, dict) else {}
19
20
21@dataclass
22class Session:
23 access_jwt: str
24 refresh_jwt: str
25 handle: str
26 did: str
27 pds: str
28 email: str | None = None
29 email_confirmed: bool = False
30 email_auth_factor: bool = False
31 active: bool = True
32 status: str | None = None
33
34 @cached_property
35 def access_payload(self) -> dict[str, Any]:
36 return _decode_jwt_payload(self.access_jwt)
37
38 @cached_property
39 def refresh_payload(self) -> dict[str, Any]:
40 return _decode_jwt_payload(self.refresh_jwt)
41
42 def is_access_token_expired(self, buffer_seconds: int = 60) -> bool:
43 exp = self.access_payload.get("exp", 0)
44 return bool(time.time() >= (exp - buffer_seconds))
45
46 def is_refresh_token_expired(self, buffer_seconds: int = 60) -> bool:
47 exp = self.refresh_payload.get("exp", 0)
48 return bool(time.time() >= (exp - buffer_seconds))
49
50 @classmethod
51 def from_row(cls, row: sqlite3.Row) -> "Session":
52 return cls(
53 access_jwt=row["access_jwt"],
54 refresh_jwt=row["refresh_jwt"],
55 handle=row["handle"],
56 did=row["did"],
57 pds=row["pds"],
58 email=row["email"],
59 email_confirmed=bool(row["email_confirmed"]),
60 email_auth_factor=bool(row["email_auth_factor"]),
61 active=bool(row["active"]),
62 status=row["status"],
63 )
64
65 @classmethod
66 def from_dict(cls, data: dict[str, Any], pds: str) -> "Session":
67 return cls(
68 access_jwt=data["accessJwt"],
69 refresh_jwt=data["refreshJwt"],
70 handle=data["handle"],
71 did=data["did"],
72 pds=pds,
73 email=data.get("email"),
74 email_confirmed=data.get("emailConfirmed", False),
75 email_auth_factor=data.get("emailAuthFactor", False),
76 active=data.get("active", True),
77 status=data.get("status"),
78 )
79
80
@dataclass
class IdentityInfo:
    """Identity record: a DID with its handle, pds and signing key."""

    did: str
    handle: str
    pds: str
    signing_key: str

    @classmethod
    def from_row(cls, row: sqlite3.Row) -> "IdentityInfo":
        """Build an identity from a row exposing the four field columns."""
        columns = ("did", "handle", "pds", "signing_key")
        return cls(**{column: row[column] for column in columns})

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "IdentityInfo":
        """Build an identity from a dict holding the four field keys.

        All four keys are required; extra keys are ignored.
        """
        wanted = ("did", "handle", "pds", "signing_key")
        values = {key: data[key] for key in wanted}
        return cls(**values)
105
106
class AtprotoStore:
    """SQLite-backed storage for sessions and cached identity lookups.

    Reads and writes the ``atproto_sessions`` and ``atproto_identities``
    tables (their schema is defined elsewhere). Every write runs inside
    ``with self.db:`` so it is committed on success and rolled back if a
    statement raises.
    """

    def __init__(
        self,
        db: sqlite3.Connection,
        identity_ttl: int = 12 * 60 * 60,
    ) -> None:
        """Wrap ``db`` and remember the identity cache lifetime.

        Args:
            db: Connection to use. Its ``row_factory`` is set to
                ``sqlite3.Row``, which also affects any other user of the
                same connection object.
            identity_ttl: Seconds an identity row stays valid after its
                ``created_at`` timestamp (default: 12 hours).
        """
        self.db = db
        self.db.row_factory = sqlite3.Row
        self.identity_ttl = identity_ttl

    def get_session(self, did: str) -> Session | None:
        """Return the session stored for ``did``, or None if there is none."""
        row = self.db.execute(
            "SELECT * FROM atproto_sessions WHERE did = ?", (did,)
        ).fetchone()
        return Session.from_row(row) if row else None

    def set_session(self, session: Session) -> None:
        """Insert or replace the stored row for ``session``.

        ``created_at`` is set to the current time on every call.
        """
        now = time.time()
        with self.db:
            self.db.execute(
                """
                INSERT OR REPLACE INTO atproto_sessions
                (did, pds, handle, access_jwt, refresh_jwt, email, email_confirmed,
                email_auth_factor, active, status, created_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    session.did,
                    session.pds,
                    session.handle,
                    session.access_jwt,
                    session.refresh_jwt,
                    session.email,
                    session.email_confirmed,
                    session.email_auth_factor,
                    session.active,
                    session.status,
                    now,
                ),
            )

    def get_session_by_pds(self, pds: str, identifier: str) -> Session | None:
        """Return a session on ``pds`` whose did or handle equals ``identifier``.

        Returns None when no row matches.
        """
        row = self.db.execute(
            """
            SELECT * FROM atproto_sessions
            WHERE pds = ? AND (did = ? OR handle = ?)
            """,
            (pds, identifier, identifier),
        ).fetchone()
        return Session.from_row(row) if row else None

    def list_sessions_by_pds(self, pds: str) -> list[Session]:
        """Return every stored session whose ``pds`` column equals ``pds``."""
        rows = self.db.execute(
            "SELECT * FROM atproto_sessions WHERE pds = ?", (pds,)
        ).fetchall()
        return [Session.from_row(row) for row in rows]

    def remove_session(self, did: str) -> None:
        """Delete the session row for ``did`` (no error if it is absent)."""
        with self.db:
            self.db.execute("DELETE FROM atproto_sessions WHERE did = ?", (did,))

    def get_identity(self, identifier: str) -> IdentityInfo | None:
        """Return the cached identity for ``identifier`` if still fresh.

        A row is fresh while ``created_at + identity_ttl`` is greater than
        the current time; otherwise None is returned.
        """
        row = self.db.execute(
            "SELECT * FROM atproto_identities WHERE identifier = ? AND created_at + ? > ?",
            (identifier, self.identity_ttl, time.time()),
        ).fetchone()
        return IdentityInfo.from_row(row) if row else None

    def set_identity(self, identifier: str, identity: IdentityInfo) -> None:
        """Cache ``identity`` under ``identifier``, its did and its handle.

        All rows share one ``created_at`` timestamp and are written in a
        single transaction.
        """
        now = time.time()
        # The lookup key frequently equals the did or the handle; drop the
        # duplicates (order preserved) so each key is written once.
        keys = dict.fromkeys((identifier, identity.did, identity.handle))
        rows = [
            (
                key,
                identity.did,
                identity.handle,
                identity.pds,
                identity.signing_key,
                now,
            )
            for key in keys
        ]
        with self.db:
            self.db.executemany(
                """
                INSERT OR REPLACE INTO atproto_identities
                (identifier, did, handle, pds, signing_key, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                rows,
            )

    def remove_identity(self, identifier: str) -> None:
        """Delete the identity row keyed by ``identifier``, if any."""
        with self.db:
            self.db.execute(
                "DELETE FROM atproto_identities WHERE identifier = ?", (identifier,)
            )

    def cleanup_expired(self) -> None:
        """Delete every identity row that ``get_identity`` would reject.

        The predicate is the exact complement of the freshness test in
        ``get_identity`` (``created_at + ttl > now``). The TTL is applied
        once, to ``created_at``; subtracting it from the current time as well
        would keep expired rows around for a second TTL period.
        """
        now = time.time()
        with self.db:
            self.db.execute(
                "DELETE FROM atproto_identities WHERE created_at + ? <= ?",
                (self.identity_ttl, now),
            )

    def flush_all(self) -> tuple[int, int]:
        """Delete all sessions and identities.

        Returns:
            ``(sessions, identities)``: the row counts of the two tables
            taken just before they were emptied.
        """
        with self.db:
            sessions = self.db.execute(
                "SELECT COUNT(*) FROM atproto_sessions"
            ).fetchone()[0]
            identities = self.db.execute(
                "SELECT COUNT(*) FROM atproto_identities"
            ).fetchone()[0]
            self.db.execute("DELETE FROM atproto_sessions")
            self.db.execute("DELETE FROM atproto_identities")
        return sessions, identities
220
221
# Module-wide store instance; stays None until get_store() creates it.
_store: AtprotoStore | None = None


def get_store(db: DatabasePool) -> AtprotoStore:
    """Return the module-wide ``AtprotoStore``, creating it on first use.

    The store is built from ``db.get_conn()`` the first time this is called.
    Later calls return that same instance and do not use ``db``.
    """
    global _store
    if _store is not None:
        return _store
    _store = AtprotoStore(db.get_conn())
    return _store
230
231
def flush_caches() -> tuple[int, int]:
    """Empty the module-wide store, if one has been created.

    Returns:
        The result of the store's ``flush_all()``, or ``(0, 0)`` when no
        store exists yet.
    """
    store = _store
    if store is None:
        return 0, 0
    return store.flush_all()