+49
-14
atproto/identity.py
+49
-14
atproto/identity.py
···
1
1
from typing import Any
2
+
2
3
import dns.resolver
3
4
import requests
5
+
6
+
import env
4
7
from util.cache import TTLCache
5
-
from util.util import LOGGER
8
+
from util.util import LOGGER, normalize_service_url
9
+
10
+
11
+
class DidDocument:
12
+
def __init__(self, raw_doc: dict[str, Any]) -> None:
13
+
self.raw: dict[str, Any] = raw_doc
14
+
self.atproto_pds: str | None = None
15
+
16
+
def get_atproto_pds(self) -> str | None:
17
+
if self.atproto_pds:
18
+
return self.atproto_pds
19
+
20
+
services = self.raw.get("service")
21
+
if not services:
22
+
return None
23
+
24
+
for service in services:
25
+
if (
26
+
service.get("id") == "#atproto_pds"
27
+
and service.get("type") == "AtprotoPersonalDataServer"
28
+
):
29
+
endpoint = service.get("serviceEndpoint")
30
+
if endpoint:
31
+
url = normalize_service_url(endpoint)
32
+
self.atproto_pds = url
33
+
return url
34
+
self.atproto_pds = ""
35
+
return None
36
+
6
37
7
38
class DidResolver:
    """Resolve ``did:plc`` and ``did:web`` identifiers to DID documents.

    Successfully resolved documents are kept in a TTL cache (12 hours) keyed
    by the DID string.
    """

    def __init__(self, plc_host: str) -> None:
        """Remember the PLC host used for ``did:plc`` lookups."""
        self.plc_host: str = plc_host
        self.__cache: TTLCache[str, DidDocument] = TTLCache(ttl_seconds=12 * 60 * 60)

    def _fetch_document(self, url: str) -> DidDocument | None:
        """GET *url* and wrap the JSON body in a :class:`DidDocument`.

        Returns ``None`` for 404/410 (not registered, tombstoned or gone).
        Other error statuses are raised by ``raise_for_status``.
        """
        response = requests.get(url, timeout=10, allow_redirects=True)

        if response.status_code == 200:
            return DidDocument(response.json())
        if response.status_code in (404, 410):
            return None

        response.raise_for_status()
        # Non-200 status that is not an HTTP error: nothing usable was returned.
        return None

    def try_resolve_plc(self, did: str) -> DidDocument | None:
        """Look *did* up at the configured PLC host; ``None`` if unknown."""
        return self._fetch_document(f"{self.plc_host}/{did}")

    def try_resolve_web(self, did: str) -> DidDocument | None:
        """Fetch the ``did:web`` document from the host's well-known path.

        The document is fetched over HTTPS: it decides where the account
        lives, so it must not travel over plain HTTP.
        """
        host = did[len("did:web:") :]
        return self._fetch_document(f"https://{host}/.well-known/did.json")

    def resolve_did(self, did: str) -> DidDocument:
        """Return the DID document for *did*, using the cache when possible.

        Raises:
            Exception: the DID method is unsupported or no document was found.
        """
        cached = self.__cache.get(did)
        if cached:
            return cached

        if did.startswith("did:plc:"):
            document = self.try_resolve_plc(did)
        elif did.startswith("did:web:"):
            document = self.try_resolve_web(did)
        else:
            document = None

        if document:
            self.__cache.set(did, document)
            return document

        raise Exception(f"Failed to resolve {did}!")
81
+
50
82
51
83
class HandleResolver:
52
84
def __init__(self) -> None:
···
59
91
60
92
for rdata in answers:
61
93
for txt_data in rdata.strings:
62
-
did = txt_data.decode('utf-8').strip()
94
+
did = txt_data.decode("utf-8").strip()
63
95
if did.startswith("did="):
64
96
return did[4:]
65
97
except dns.resolver.NXDOMAIN:
···
82
114
else:
83
115
response.raise_for_status()
84
116
85
-
86
117
def resolve_handle(self, handle: str) -> str:
87
118
cached = self.__cache.get(handle)
88
119
if cached:
···
99
130
return from_http
100
131
101
132
raise Exception(f"Failed to resolve handle {handle}!")
133
+
134
+
135
+
# Shared resolver instances, created once when this module is imported.
# The DID resolver uses the PLC host taken from env.PLC_HOST.
handle_resolver = HandleResolver()
136
+
did_resolver = DidResolver(env.PLC_HOST)
+35
-3
bluesky/info.py
+35
-3
bluesky/info.py
···
1
-
from abc import ABC
1
+
from abc import ABC, abstractmethod
2
2
from typing import Any
3
+
4
+
from atproto.identity import did_resolver, handle_resolver
3
5
from cross.service import Service
4
-
from util.util import normalize_service_url
6
+
from util.util import LOGGER, normalize_service_url
7
+
8
+
SERVICE = "https://bsky.app"
9
+
5
10
6
11
def validate_and_transform(data: dict[str, Any]):
7
12
if not data["handle"] and not data["did"]:
···
14
19
15
20
if "pds" in data:
16
21
data["pds"] = normalize_service_url(data["pds"])
22
+
17
23
18
24
class BlueskyService(ABC, Service):
    """Base for Bluesky services that need a resolved DID and PDS.

    Subclasses provide the configured handle, DID and PDS through
    :meth:`get_identity_options`. :meth:`_init_identity` resolves whatever is
    missing and stores the result on ``self.did`` and ``self.pds``.
    """

    pds: str
    did: str

    def _init_identity(self) -> None:
        """Populate ``self.did`` and ``self.pds`` from the configured options.

        A missing DID is resolved from the handle. A missing PDS is read
        from the DID document of the (configured or resolved) DID. Values
        that were configured explicitly are used as-is.

        Raises:
            KeyError: neither a DID nor a handle was configured.
            Exception: the DID document has no atproto PDS entry.
        """
        handle, did, pds = self.get_identity_options()

        if not did:
            if not handle:
                raise KeyError("No did: or atproto handle provided!")
            LOGGER.info("Resolving ATP identity for %s...", handle)
            did = handle_resolver.resolve_handle(handle)

        if not pds:
            # Log and resolve with the effective DID, which may have just
            # been derived from the handle above.
            LOGGER.info("Resolving PDS from %s DID document...", did)
            pds = did_resolver.resolve_did(did).get_atproto_pds()
            if not pds:
                raise Exception(f"Failed to resolve atproto pds for {did}")

        # Assign both unconditionally so that every combination of configured
        # and resolved values leaves the instance fully initialised.
        self.did = did
        self.pds = pds

    @abstractmethod
    def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
        """Return the configured ``(handle, did, pds)``; each may be ``None``."""
        pass
+56
-7
bluesky/input.py
+56
-7
bluesky/input.py
···
1
+
import asyncio
2
+
import re
1
3
from abc import ABC
2
4
from dataclasses import dataclass, field
3
-
import re
4
5
from typing import Any, Callable, override
5
6
6
-
from bluesky.info import BlueskyService, validate_and_transform
7
+
import websockets
8
+
9
+
from bluesky.info import SERVICE, BlueskyService, validate_and_transform
7
10
from cross.service import InputService, OutputService
11
+
from database.connection import DatabasePool
12
+
from util.util import LOGGER, normalize_service_url
8
13
9
14
10
15
@dataclass(kw_only=True)
11
16
class BlueskyInputOptions:
12
-
handle: str | None
13
-
did: str | None
14
-
pds: str | None
17
+
handle: str | None = None
18
+
did: str | None = None
19
+
pds: str | None = None
15
20
filters: list[re.Pattern[str]] = field(default_factory=lambda: [])
16
21
17
22
@classmethod
···
24
29
return BlueskyInputOptions(**data)
25
30
26
31
32
+
@dataclass(kw_only=True)
class BlueskyJetstreamInputOptions(BlueskyInputOptions):
    """Bluesky input options extended with the Jetstream endpoint to subscribe to."""

    jetstream: str = "wss://jetstream2.us-west.bsky.network/subscribe"

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BlueskyJetstreamInputOptions":
        """Build the options from a settings mapping.

        The ``jetstream`` key is removed from *data* (the mapping is mutated)
        before the remaining keys are handed to
        ``BlueskyInputOptions.from_dict``. A truthy endpoint is passed through
        ``normalize_service_url``; otherwise the field default is kept.
        """
        endpoint = data.pop("jetstream", None)

        # Start from the parent's parsed fields and add ours on top.
        fields = dict(vars(BlueskyInputOptions.from_dict(data)))
        if endpoint:
            fields["jetstream"] = normalize_service_url(endpoint)

        return BlueskyJetstreamInputOptions(**fields)
45
+
46
+
27
47
class BlueskyBaseInputService(BlueskyService, InputService, ABC):
    """Common base for Bluesky input services."""

    def __init__(self, db: DatabasePool) -> None:
        """Initialise the service with the module-level ``SERVICE`` URL and *db*."""
        super().__init__(SERVICE, db)
29
50
30
51
31
52
class BlueskyJetstreamInputService(BlueskyBaseInputService):
    """Bluesky input that subscribes to a Jetstream websocket for one DID."""

    def __init__(self, db: DatabasePool, options: BlueskyJetstreamInputOptions) -> None:
        """Store *options* and resolve the account identity right away."""
        super().__init__(db)
        self.options: BlueskyJetstreamInputOptions = options
        self._init_identity()

    @override
    def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
        """Return ``(handle, did, pds)`` exactly as configured in the options."""
        opts = self.options
        return (opts.handle, opts.did, opts.pds)

    @override
    async def listen(
        self,
        outputs: list[OutputService],
        submitter: Callable[[Callable[[], None]], None],
    ):
        """Connect to the Jetstream endpoint and log every incoming message.

        The subscription asks for post and repost records of ``self.did``.
        ``websockets.connect`` is used as an async iterator, so a new
        connection is tried after the current one ends.
        """
        query = "&".join(
            (
                "wantedCollections=app.bsky.feed.post",
                "wantedCollections=app.bsky.feed.repost",
                f"wantedDids={self.did}",
            )
        )
        url = f"{self.options.jetstream}?{query}"

        async for ws in websockets.connect(url):
            try:
                LOGGER.info("Listening to %s...", self.options.jetstream)

                async def pump_messages():
                    async for msg in ws:
                        LOGGER.info(msg)  # TODO

                reader = asyncio.create_task(pump_messages())

                _ = await asyncio.gather(reader)
            except websockets.ConnectionClosedError as err:
                LOGGER.error(err, stack_info=True, exc_info=True)
                LOGGER.info("Reconnecting to %s...", self.options.jetstream)
                continue
+6
-9
cross/service.py
+6
-9
cross/service.py
···
1
1
import sqlite3
2
2
from abc import ABC, abstractmethod
3
-
from pathlib import Path
4
3
from typing import Callable, cast
5
4
6
5
from cross.post import Post
7
-
from database.connection import get_conn
6
+
from database.connection import DatabasePool
8
7
from util.util import LOGGER
9
8
10
9
11
10
class Service:
12
-
def __init__(self, url: str, db: Path) -> None:
11
+
def __init__(self, url: str, db: DatabasePool) -> None:
13
12
self.url: str = url
14
-
self.conn: sqlite3.Connection = get_conn(db)
13
+
self.db: DatabasePool = db
14
+
#self._lock: threading.Lock = threading.Lock()
15
15
16
16
def get_post(self, url: str, user: str, identifier: str) -> sqlite3.Row | None:
17
-
cursor = self.conn.cursor()
17
+
cursor = self.db.get_conn().cursor()
18
18
_ = cursor.execute(
19
19
"""
20
20
SELECT * FROM posts
···
27
27
return cast(sqlite3.Row, cursor.fetchone())
28
28
29
29
def get_post_by_id(self, id: int) -> sqlite3.Row | None:
30
-
cursor = self.conn.cursor()
30
+
cursor = self.db.get_conn().cursor()
31
31
_ = cursor.execute("SELECT * FROM posts WHERE id = ?", (id,))
32
32
return cast(sqlite3.Row, cursor.fetchone())
33
-
34
-
def close(self):
35
-
self.conn.close()
36
33
37
34
38
35
class OutputService(Service):
+17
database/connection.py
+17
database/connection.py
···
1
1
import sqlite3
2
+
import threading
2
3
from pathlib import Path
3
4
5
+
6
+
class DatabasePool:
    """Hands out SQLite connections, caching one per thread via ``threading.local``.

    Every connection that was opened is also recorded so that :meth:`close`
    can close all of them.
    """

    def __init__(self, db: Path) -> None:
        """Remember the database path; no connection is opened yet."""
        self.db: Path = db
        self._local: threading.local = threading.local()
        self._conns: list[sqlite3.Connection] = []

    def get_conn(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first use."""
        conn = getattr(self._local, "conn", None)
        if conn is None:
            conn = get_conn(self.db)
            self._local.conn = conn
            self._conns.append(conn)
        return conn

    def close(self):
        """Close every connection this pool has opened."""
        for conn in self._conns:
            conn.close()
4
21
5
22
def get_conn(db: Path) -> sqlite3.Connection:
6
23
conn = sqlite3.connect(db, autocommit=True, check_same_thread=False)
+1
-1
env.py
+1
-1
env.py
···
3
3
DEV = bool(os.environ.get("DEV")) or False
4
4
DATA_DIR = os.environ.get("DATA_DIR") or "./data"
5
5
MIGRATIONS_DIR = os.environ.get("MIGRATIONS_DIR") or "./migrations"
6
-
PLC_HOST = os.environ.get("PLC_HOST") or "http://plc.directory"
6
+
PLC_HOST = os.environ.get("PLC_HOST") or "https://plc.wtf"
+7
-3
main.py
+7
-3
main.py
···
5
5
from pathlib import Path
6
6
from typing import Callable
7
7
8
+
from database.connection import DatabasePool
8
9
import env
9
10
from database.migrations import DatabaseMigrator
10
11
from registry import create_input_service, create_output_service
···
34
35
finally:
35
36
migrator.close()
36
37
38
+
db_pool = DatabasePool(database_path)
39
+
37
40
LOGGER.info("Bootstrapping registries...")
38
41
bootstrap()
39
42
···
48
51
if "outputs" not in settings:
49
52
raise KeyError("No `outputs` spicified in settings!")
50
53
51
-
input = create_input_service(database_path, settings["input"])
54
+
input = create_input_service(db_pool, settings["input"])
52
55
outputs = [
53
-
create_output_service(database_path, data) for data in settings["outputs"]
56
+
create_output_service(db_pool, data) for data in settings["outputs"]
54
57
]
55
58
56
59
LOGGER.info("Starting task worker...")
···
72
75
thread = threading.Thread(target=worker, args=(task_queue,), daemon=True)
73
76
thread.start()
74
77
75
-
LOGGER.info("Connecting to %s...", "TODO") # TODO
78
+
LOGGER.info("Connecting to %s...", input.url)
76
79
try:
77
80
asyncio.run(input.listen(outputs, lambda c: task_queue.put(c)))
78
81
except KeyboardInterrupt:
···
81
84
task_queue.join()
82
85
task_queue.put(None)
83
86
thread.join()
87
+
db_pool.close()
84
88
85
89
86
90
if __name__ == "__main__":
+2
-2
mastodon/input.py
+2
-2
mastodon/input.py
···
1
1
import asyncio
2
2
import re
3
3
from dataclasses import dataclass, field
4
-
from pathlib import Path
5
4
from typing import Any, Callable, override
6
5
7
6
import websockets
8
7
9
8
from cross.service import InputService, OutputService
9
+
from database.connection import DatabasePool
10
10
from mastodon.info import MastodonService, validate_and_transform
11
11
from util.util import LOGGER
12
12
···
38
38
39
39
40
40
class MastodonInputService(MastodonService, InputService):
41
-
def __init__(self, db: Path, options: MastodonInputOptions) -> None:
41
+
def __init__(self, db: DatabasePool, options: MastodonInputOptions) -> None:
42
42
super().__init__(options.instance, db)
43
43
self.options: MastodonInputOptions = options
44
44
+2
-2
mastodon/output.py
+2
-2
mastodon/output.py
···
1
1
from dataclasses import dataclass
2
-
from pathlib import Path
3
2
from typing import Any, override
4
3
5
4
from cross.service import OutputService
5
+
from database.connection import DatabasePool
6
6
from mastodon.info import InstanceInfo, MastodonService, validate_and_transform
7
7
from util.util import LOGGER
8
8
···
28
28
29
29
# TODO
30
30
class MastodonOutputService(MastodonService, OutputService):
31
-
def __init__(self, db: Path, options: MastodonOutputOptions) -> None:
31
+
def __init__(self, db: DatabasePool, options: MastodonOutputOptions) -> None:
32
32
super().__init__(options.instance, db)
33
33
self.options: MastodonOutputOptions = options
34
34
+2
-2
misskey/input.py
+2
-2
misskey/input.py
···
3
3
import re
4
4
import uuid
5
5
from dataclasses import dataclass, field
6
-
from pathlib import Path
7
6
from typing import Any, Callable, override
8
7
9
8
import websockets
10
9
11
10
from cross.service import InputService, OutputService
11
+
from database.connection import DatabasePool
12
12
from misskey.info import MisskeyService
13
13
from util.util import LOGGER, normalize_service_url
14
14
···
40
40
41
41
42
42
class MisskeyInputService(MisskeyService, InputService):
43
-
def __init__(self, db: Path, options: MisskeyInputOptions) -> None:
43
+
def __init__(self, db: DatabasePool, options: MisskeyInputOptions) -> None:
44
44
super().__init__(options.instance, db)
45
45
self.options: MisskeyInputOptions = options
46
46
+5
-4
registry.py
+5
-4
registry.py
···
2
2
from typing import Any, Callable
3
3
4
4
from cross.service import InputService, OutputService
5
+
from database.connection import DatabasePool
5
6
6
-
input_factories: dict[str, Callable[[Path, dict[str, Any]], InputService]] = {}
7
-
output_factories: dict[str, Callable[[Path, dict[str, Any]], OutputService]] = {}
7
+
input_factories: dict[str, Callable[[DatabasePool, dict[str, Any]], InputService]] = {}
8
+
output_factories: dict[str, Callable[[DatabasePool, dict[str, Any]], OutputService]] = {}
8
9
9
10
10
-
def create_input_service(db: Path, data: dict[str, Any]) -> InputService:
11
+
def create_input_service(db: DatabasePool, data: dict[str, Any]) -> InputService:
11
12
if "type" not in data:
12
13
raise ValueError("No `type` field in input data!")
13
14
type: str = str(data["type"])
···
19
20
return factory(db, data)
20
21
21
22
22
-
def create_output_service(db: Path, data: dict[str, Any]) -> OutputService:
23
+
def create_output_service(db: DatabasePool, data: dict[str, Any]) -> OutputService:
23
24
if "type" not in data:
24
25
raise ValueError("No `type` field in input data!")
25
26
type: str = str(data["type"])
+5
-2
registry_bootstrap.py
+5
-2
registry_bootstrap.py
···
1
-
from pathlib import Path
2
1
from typing import Any
3
2
3
+
from database.connection import DatabasePool
4
4
from registry import input_factories, output_factories
5
5
6
6
···
10
10
self.class_name: str = class_name
11
11
self.options_class_name: str = options_class_name
12
12
13
-
def __call__(self, db: Path, d: dict[str, Any]):
13
+
def __call__(self, db: DatabasePool, d: dict[str, Any]):
14
14
module = __import__(
15
15
self.module_path, fromlist=[self.class_name, self.options_class_name]
16
16
)
···
26
26
input_factories["misskey-wss"] = LazyFactory(
27
27
"misskey.input", "MisskeyInputService", "MisskeyInputOptions"
28
28
)
29
+
input_factories["bluesky-jetstream"] = LazyFactory(
30
+
"bluesky.input", "BlueskyJetstreamInputService", "BlueskyJetstreamInputOptions"
31
+
)