social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky

pool db connections, atp identity resolution, bsky jetstream input

zenfyr.dev d58e914e b74b024b

verified
Changed files: +187 -49
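Before the file-by-file diffs, a rough picture of how the three changes named in the commit message (pooled DB connections, ATP identity resolution, the Bluesky Jetstream input) fit together. This is a hypothetical wiring sketch that skips the registry/settings plumbing; the handle and database path are placeholders, not values from this repository.

# Hypothetical sketch: construct the new Jetstream input directly, mirroring
# what main.py does through the registry. did/pds are optional and are
# resolved from the handle at init time.
import asyncio
from pathlib import Path

from bluesky.input import BlueskyJetstreamInputOptions, BlueskyJetstreamInputService
from database.connection import DatabasePool

db_pool = DatabasePool(Path("./data/crosspost.db"))  # placeholder path
options = BlueskyJetstreamInputOptions(handle="example.bsky.social")  # placeholder handle
service = BlueskyJetstreamInputService(db_pool, options)  # resolves DID and PDS up front

# main.py passes the configured output services and a task-queue submitter here.
asyncio.run(service.listen(outputs=[], submitter=lambda task: task()))
db_pool.close()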
atproto/identity.py (+49 -14)
···
  from typing import Any
+
  import dns.resolver
  import requests
+
+ import env
  from util.cache import TTLCache
- from util.util import LOGGER
+ from util.util import LOGGER, normalize_service_url
+
+
+ class DidDocument:
+     def __init__(self, raw_doc: dict[str, Any]) -> None:
+         self.raw: dict[str, Any] = raw_doc
+         self.atproto_pds: str | None = None
+
+     def get_atproto_pds(self) -> str | None:
+         if self.atproto_pds:
+             return self.atproto_pds
+
+         services = self.raw.get("service")
+         if not services:
+             return None
+
+         for service in services:
+             if (
+                 service.get("id") == "#atproto_pds"
+                 and service.get("type") == "AtprotoPersonalDataServer"
+             ):
+                 endpoint = service.get("serviceEndpoint")
+                 if endpoint:
+                     url = normalize_service_url(endpoint)
+                     self.atproto_pds = url
+                     return url
+         self.atproto_pds = ""
+         return None
+

  class DidResolver:
      def __init__(self, plc_host: str) -> None:
          self.plc_host: str = plc_host
-         self.__cache: TTLCache[str, dict[str, Any]] = TTLCache(ttl_seconds=12*60*60)
+         self.__cache: TTLCache[str, DidDocument] = TTLCache(ttl_seconds=12 * 60 * 60)

-     def try_resolve_plc(self, did: str) -> dict[str, Any] | None:
+     def try_resolve_plc(self, did: str) -> DidDocument | None:
          url = f"{self.plc_host}/{did}"
          response = requests.get(url, timeout=10, allow_redirects=True)

          if response.status_code == 200:
-             return response.json()
+             return DidDocument(response.json())
          elif response.status_code == 404 or response.status_code == 410:
-             return None # tombstone or not registered
+             return None  # tombstone or not registered
          else:
              response.raise_for_status()

-     def try_resolve_web(self, did: str) -> dict[str, Any] | None:
-         url = f"http://{did[len('did:web:'):]}/.well-known/did.json"
+     def try_resolve_web(self, did: str) -> DidDocument | None:
+         url = f"http://{did[len('did:web:') :]}/.well-known/did.json"
          response = requests.get(url, timeout=10, allow_redirects=True)

          if response.status_code == 200:
-             return response.json()
+             return DidDocument(response.json())
          elif response.status_code == 404 or response.status_code == 410:
-             return None # tombstone or gone
+             return None  # tombstone or gone
          else:
              response.raise_for_status()

-     def resolve_did(self, did: str) -> dict[str, Any]:
+     def resolve_did(self, did: str) -> DidDocument:
          cached = self.__cache.get(did)
          if cached:
              return cached

-         if did.startswith('did:plc:'):
+         if did.startswith("did:plc:"):
              from_plc = self.try_resolve_plc(did)
              if from_plc:
                  self.__cache.set(did, from_plc)
                  return from_plc
-         elif did.startswith('did:web:'):
+         elif did.startswith("did:web:"):
              from_web = self.try_resolve_web(did)
              if from_web:
                  self.__cache.set(did, from_web)
                  return from_web
          raise Exception(f"Failed to resolve {did}!")
+

  class HandleResolver:
      def __init__(self) -> None:
···

              for rdata in answers:
                  for txt_data in rdata.strings:
-                     did = txt_data.decode('utf-8').strip()
+                     did = txt_data.decode("utf-8").strip()
                      if did.startswith("did="):
                          return did[4:]
          except dns.resolver.NXDOMAIN:
···
          else:
              response.raise_for_status()

-
      def resolve_handle(self, handle: str) -> str:
          cached = self.__cache.get(handle)
          if cached:
···
              return from_http

          raise Exception(f"Failed to resolve handle {handle}!")
+
+
+ handle_resolver = HandleResolver()
+ did_resolver = DidResolver(env.PLC_HOST)
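Taken together, the resolution path this file now provides looks roughly like the sketch below; the handle is a placeholder.

# Sketch of using the new module-level resolvers (placeholder handle).
from atproto.identity import did_resolver, handle_resolver

did = handle_resolver.resolve_handle("example.bsky.social")  # DNS TXT lookup, then HTTP fallback
doc = did_resolver.resolve_did(did)                          # did:plc: or did:web:, cached for 12 hours
pds = doc.get_atproto_pds()                                  # "#atproto_pds" service endpoint, or None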
bluesky/info.py (+35 -3)
···
- from abc import ABC
+ from abc import ABC, abstractmethod
  from typing import Any
+
+ from atproto.identity import did_resolver, handle_resolver
  from cross.service import Service
- from util.util import normalize_service_url
+ from util.util import LOGGER, normalize_service_url
+
+ SERVICE = "https://bsky.app"
+

  def validate_and_transform(data: dict[str, Any]):
      if not data["handle"] and not data["did"]:
···

      if "pds" in data:
          data["pds"] = normalize_service_url(data["pds"])
+

  class BlueskyService(ABC, Service):
-     pass
+     pds: str
+     did: str
+
+     def _init_identity(self) -> None:
+         handle, did, pds = self.get_identity_options()
+
+         if did and pds:
+             self.did = did
+             self.pds = pds
+             return
+
+         if not did:
+             if not handle:
+                 raise KeyError("No did: or atproto handle provided!")
+             LOGGER.info("Resolving ATP identity for %s...", handle)
+             self.did = handle_resolver.resolve_handle(handle)
+
+         if not pds:
+             LOGGER.info("Resolving PDS from %s DID document...", self.did)
+             atp_pds = did_resolver.resolve_did(self.did).get_atproto_pds()
+             if not atp_pds:
+                 raise Exception(f"Failed to resolve atproto pds for {self.did}!")
+             self.pds = atp_pds
+
+     @abstractmethod
+     def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
+         pass
bluesky/input.py (+56 -7)
···
+ import asyncio
+ import re
  from abc import ABC
  from dataclasses import dataclass, field
- import re
  from typing import Any, Callable, override

- from bluesky.info import BlueskyService, validate_and_transform
+ import websockets
+
+ from bluesky.info import SERVICE, BlueskyService, validate_and_transform
  from cross.service import InputService, OutputService
+ from database.connection import DatabasePool
+ from util.util import LOGGER, normalize_service_url


  @dataclass(kw_only=True)
  class BlueskyInputOptions:
-     handle: str | None
-     did: str | None
-     pds: str | None
+     handle: str | None = None
+     did: str | None = None
+     pds: str | None = None
      filters: list[re.Pattern[str]] = field(default_factory=lambda: [])

      @classmethod
···
          return BlueskyInputOptions(**data)


+ @dataclass(kw_only=True)
+ class BlueskyJetstreamInputOptions(BlueskyInputOptions):
+     jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe"
+
+     @classmethod
+     def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions":
+         jetstream = data.pop("jetstream", None)
+
+         base = BlueskyInputOptions.from_dict(data).__dict__.copy()
+         if jetstream:
+             base["jetstream"] = normalize_service_url(jetstream)
+
+         return BlueskyJetstreamInputOptions(**base)
+
+
  class BlueskyBaseInputService(BlueskyService, InputService, ABC):
-     pass
+     def __init__(self, db: DatabasePool) -> None:
+         super().__init__(SERVICE, db)


  class BlueskyJetstreamInputService(BlueskyBaseInputService):
+     def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None:
+         super().__init__(db)
+         self.options: BlueskyJetstreamInputOptions = options
+         self._init_identity()
+
+     @override
+     def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
+         return (self.options.handle, self.options.did, self.options.pds)
+
      @override
      async def listen(
          self,
          outputs: list[OutputService],
          submitter: Callable[[Callable[[], None]], None],
      ):
-         return await super().listen(outputs, submitter)  # TODO
+         url = self.options.jetstream + "?"
+         url += "wantedCollections=app.bsky.feed.post"
+         url += "&wantedCollections=app.bsky.feed.repost"
+         url += f"&wantedDids={self.did}"
+
+         async for ws in websockets.connect(url):
+             try:
+                 LOGGER.info("Listening to %s...", self.options.jetstream)
+
+                 async def listen_for_messages():
+                     async for msg in ws:
+                         LOGGER.info(msg)  # TODO
+
+                 listen = asyncio.create_task(listen_for_messages())
+
+                 _ = await asyncio.gather(listen)
+             except websockets.ConnectionClosedError as e:
+                 LOGGER.error(e, stack_info=True, exc_info=True)
+                 LOGGER.info("Reconnecting to %s...", self.options.jetstream)
+                 continue
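For reference, with the default endpoint the subscribe URL assembled in listen() comes out as wss://jetstream2.us-west.bsky.network/subscribe?wantedCollections=app.bsky.feed.post&wantedCollections=app.bsky.feed.repost&wantedDids=<did>, where <did> stands for the DID resolved during _init_identity(). Incoming events are only logged for now; turning them into crossposts is still marked TODO.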
cross/service.py (+6 -9)
···
  import sqlite3
  from abc import ABC, abstractmethod
- from pathlib import Path
  from typing import Callable, cast

  from cross.post import Post
- from database.connection import get_conn
+ from database.connection import DatabasePool
  from util.util import LOGGER


  class Service:
-     def __init__(self, url: str, db: Path) -> None:
+     def __init__(self, url: str, db: DatabasePool) -> None:
          self.url: str = url
-         self.conn: sqlite3.Connection = get_conn(db)
+         self.db: DatabasePool = db
+         #self._lock: threading.Lock = threading.Lock()

      def get_post(self, url: str, user: str, identifier: str) -> sqlite3.Row | None:
-         cursor = self.conn.cursor()
+         cursor = self.db.get_conn().cursor()
          _ = cursor.execute(
              """
              SELECT * FROM posts
···
          return cast(sqlite3.Row, cursor.fetchone())

      def get_post_by_id(self, id: int) -> sqlite3.Row | None:
-         cursor = self.conn.cursor()
+         cursor = self.db.get_conn().cursor()
          _ = cursor.execute("SELECT * FROM posts WHERE id = ?", (id,))
          return cast(sqlite3.Row, cursor.fetchone())
-
-     def close(self):
-         self.conn.close()


  class OutputService(Service):
database/connection.py (+17)
···
  import sqlite3
+ import threading
  from pathlib import Path

+
+ class DatabasePool:
+     def __init__(self, db: Path) -> None:
+         self.db: Path = db
+         self._local: threading.local = threading.local()
+         self._conns: list[sqlite3.Connection] = []
+
+     def get_conn(self) -> sqlite3.Connection:
+         if getattr(self._local, 'conn', None) is None:
+             self._local.conn = get_conn(self.db)
+             self._conns.append(self._local.conn)
+         return self._local.conn
+
+     def close(self):
+         for c in self._conns:
+             c.close()

  def get_conn(db: Path) -> sqlite3.Connection:
      conn = sqlite3.connect(db, autocommit=True, check_same_thread=False)
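DatabasePool is a thread-local pool rather than a fixed-size one: each thread that calls get_conn() lazily opens its own SQLite connection and reuses it on later calls, and close() shuts down every connection the pool handed out. A minimal sketch of that behavior, with a placeholder database path:

# Sketch only: demonstrates the per-thread connection reuse (placeholder path).
import threading
from pathlib import Path

from database.connection import DatabasePool

pool = DatabasePool(Path("./data/crosspost.db"))

def worker() -> None:
    conn = pool.get_conn()           # first call in this thread opens a connection
    assert conn is pool.get_conn()   # later calls in the same thread reuse it

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
pool.close()  # closes all connections created across threads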
env.py (+1 -1)
···
  DEV = bool(os.environ.get("DEV")) or False
  DATA_DIR = os.environ.get("DATA_DIR") or "./data"
  MIGRATIONS_DIR = os.environ.get("MIGRATIONS_DIR") or "./migrations"
- PLC_HOST = os.environ.get("PLC_HOST") or "http://plc.directory"
+ PLC_HOST = os.environ.get("PLC_HOST") or "https://plc.wtf"
main.py (+7 -3)
···
  from pathlib import Path
  from typing import Callable

+ from database.connection import DatabasePool
  import env
  from database.migrations import DatabaseMigrator
  from registry import create_input_service, create_output_service
···
      finally:
          migrator.close()

+     db_pool = DatabasePool(database_path)
+
      LOGGER.info("Bootstrapping registries...")
      bootstrap()

···
      if "outputs" not in settings:
          raise KeyError("No `outputs` spicified in settings!")

-     input = create_input_service(database_path, settings["input"])
+     input = create_input_service(db_pool, settings["input"])
      outputs = [
-         create_output_service(database_path, data) for data in settings["outputs"]
+         create_output_service(db_pool, data) for data in settings["outputs"]
      ]

      LOGGER.info("Starting task worker...")
···
      thread = threading.Thread(target=worker, args=(task_queue,), daemon=True)
      thread.start()

-     LOGGER.info("Connecting to %s...", "TODO")  # TODO
+     LOGGER.info("Connecting to %s...", input.url)
      try:
          asyncio.run(input.listen(outputs, lambda c: task_queue.put(c)))
      except KeyboardInterrupt:
···
      task_queue.join()
      task_queue.put(None)
      thread.join()
+     db_pool.close()


  if __name__ == "__main__":
mastodon/input.py (+2 -2)
···
  import asyncio
  import re
  from dataclasses import dataclass, field
- from pathlib import Path
  from typing import Any, Callable, override

  import websockets

  from cross.service import InputService, OutputService
+ from database.connection import DatabasePool
  from mastodon.info import MastodonService, validate_and_transform
  from util.util import LOGGER

···


  class MastodonInputService(MastodonService, InputService):
-     def __init__(self, db: Path, options: MastodonInputOptions) -> None:
+     def __init__(self, db: DatabasePool, options: MastodonInputOptions) -> None:
          super().__init__(options.instance, db)
          self.options: MastodonInputOptions = options
mastodon/output.py (+2 -2)
···
  from dataclasses import dataclass
- from pathlib import Path
  from typing import Any, override

  from cross.service import OutputService
+ from database.connection import DatabasePool
  from mastodon.info import InstanceInfo, MastodonService, validate_and_transform
  from util.util import LOGGER

···

  # TODO
  class MastodonOutputService(MastodonService, OutputService):
-     def __init__(self, db: Path, options: MastodonOutputOptions) -> None:
+     def __init__(self, db: DatabasePool, options: MastodonOutputOptions) -> None:
          super().__init__(options.instance, db)
          self.options: MastodonOutputOptions = options
misskey/input.py (+2 -2)
···
  import re
  import uuid
  from dataclasses import dataclass, field
- from pathlib import Path
  from typing import Any, Callable, override

  import websockets

  from cross.service import InputService, OutputService
+ from database.connection import DatabasePool
  from misskey.info import MisskeyService
  from util.util import LOGGER, normalize_service_url

···


  class MisskeyInputService(MisskeyService, InputService):
-     def __init__(self, db: Path, options: MisskeyInputOptions) -> None:
+     def __init__(self, db: DatabasePool, options: MisskeyInputOptions) -> None:
          super().__init__(options.instance, db)
          self.options: MisskeyInputOptions = options
registry.py (+5 -4)
···
  from typing import Any, Callable

  from cross.service import InputService, OutputService
+ from database.connection import DatabasePool

- input_factories: dict[str, Callable[[Path, dict[str, Any]], InputService]] = {}
- output_factories: dict[str, Callable[[Path, dict[str, Any]], OutputService]] = {}
+ input_factories: dict[str, Callable[[DatabasePool, dict[str, Any]], InputService]] = {}
+ output_factories: dict[str, Callable[[DatabasePool, dict[str, Any]], OutputService]] = {}


- def create_input_service(db: Path, data: dict[str, Any]) -> InputService:
+ def create_input_service(db: DatabasePool, data: dict[str, Any]) -> InputService:
      if "type" not in data:
          raise ValueError("No `type` field in input data!")
      type: str = str(data["type"])
···
      return factory(db, data)


- def create_output_service(db: Path, data: dict[str, Any]) -> OutputService:
+ def create_output_service(db: DatabasePool, data: dict[str, Any]) -> OutputService:
      if "type" not in data:
          raise ValueError("No `type` field in input data!")
+5 -2
registry_bootstrap.py
···
- from pathlib import Path
  from typing import Any

+ from database.connection import DatabasePool
  from registry import input_factories, output_factories

···
          self.class_name: str = class_name
          self.options_class_name: str = options_class_name

-     def __call__(self, db: Path, d: dict[str, Any]):
+     def __call__(self, db: DatabasePool, d: dict[str, Any]):
          module = __import__(
              self.module_path, fromlist=[self.class_name, self.options_class_name]
          )
···
  input_factories["misskey-wss"] = LazyFactory(
      "misskey.input", "MisskeyInputService", "MisskeyInputOptions"
  )
+ input_factories["bluesky-jetstream"] = LazyFactory(
+     "bluesky.input", "BlueskyJetstreamInputService", "BlueskyJetstreamInputOptions"
+ )