1import sqlite3
2from logging import Logger
3from sqlite3 import Connection
4from typing import Generic, cast, override
5
6from flask import Flask, g
7
8from src.atproto.kv import KV as BaseKV
9from src.atproto.kv import K, V
10
11
class KV(BaseKV, Generic[K, V]):
    """SQLite-backed key/value store scoped to a single ``prefix``.

    Entries live in the ``keyval`` table and are looked up by the pair
    ``(prefix, key)``, so several instances with different prefixes can share
    one table.
    """

    # Connection every query runs on.
    db: Connection
    # Receives debug messages for cache hits and writes.
    logger: Logger
    # Value written to / matched against the ``prefix`` column.
    prefix: str

    def __init__(self, app: Connection | Flask, logger: Logger, prefix: str):
        """Bind the store to a database connection.

        Args:
            app: Either an open ``sqlite3.Connection`` (used as-is) or a Flask
                app, in which case the connection is obtained from ``get_db``
                (which reads ``flask.g``, so an app context must be active).
            logger: Logger for debug output.
            prefix: Namespace applied to every key this instance reads/writes.
        """
        # NOTE(review): for a Flask app this stores the connection get_db()
        # cached on ``g``; close_db_connection() closes that same object, so an
        # instance should not outlive the app context it was created in —
        # confirm how callers scope KV instances.
        self.db = app if isinstance(app, Connection) else get_db(app)
        self.logger = logger
        self.prefix = prefix

    @override
    def get(self, key: K) -> V | None:
        """Return the value stored for ``key`` under this prefix, or ``None``.

        The value is returned exactly as SQLite yields it; ``cast`` only
        informs the type checker and performs no conversion.
        """
        row: tuple[object, ...] | sqlite3.Row | None = self.db.execute(
            "select value from keyval where prefix = ? and key = ?",
            (self.prefix, key),
        ).fetchone()
        if row is None:
            return None
        self.logger.debug("returning cached %s(%s)", self.prefix, key)
        # Index by position rather than by name: a Connection passed in
        # directly may not use the sqlite3.Row factory, and plain tuple rows
        # reject ``row["value"]``.
        return cast(V, row[0])

    @override
    def set(self, key: K, value: V):
        """Store ``value`` under ``key`` for this prefix and commit immediately.

        Uses ``insert or replace``; an existing entry is overwritten only if
        schema.sql declares a uniqueness constraint covering (prefix, key) —
        presumably the case, confirm against the schema.
        """
        self.logger.debug("caching %s(%s): %s", self.prefix, key, value)
        _ = self.db.execute(
            "insert or replace into keyval (prefix, key, value) values (?, ?, ?)",
            (self.prefix, key, value),
        )
        self.db.commit()
43
44
def get_db(app: Flask) -> sqlite3.Connection:
    """Return the SQLite connection cached on ``flask.g``, opening it if needed.

    The database path is ``app.config["DATABASE_URL"]`` (falling back to
    ``"ligoat.db"``) and is handed directly to ``sqlite3.connect``.
    ``check_same_thread=False`` disables sqlite3's same-thread check for this
    connection. Every call (re)applies ``sqlite3.Row`` as the row factory so
    query results support access by column name.
    """
    connection: sqlite3.Connection | None = g.get("db", None)
    if connection is None:
        path: str = app.config.get("DATABASE_URL", "ligoat.db")
        connection = sqlite3.connect(path, check_same_thread=False)
        g.db = connection
    connection.row_factory = sqlite3.Row
    return connection
53
54
def close_db_connection(_exception: BaseException | None):
    """Close the connection ``get_db`` cached on ``flask.g``, if any.

    The single exception argument is ignored; the signature presumably exists
    so this can be registered as a Flask teardown callback — confirm at the
    registration site.
    """
    # Pop rather than get: removing the closed connection from ``g`` means a
    # later get_db() in the same context opens a fresh connection instead of
    # handing back one that has already been closed.
    db: sqlite3.Connection | None = g.pop("db", None)
    if db is not None:
        db.close()
59
60
def init_db(app: Flask):
    """Create the database schema by running the app's ``schema.sql`` resource.

    A fresh app context is pushed so that ``get_db`` can use ``flask.g``.
    The script is executed with ``executescript`` and then committed.
    """
    with app.app_context():
        connection = get_db(app)
        with app.open_resource("schema.sql", mode="r") as schema_file:
            script = schema_file.read()
            _ = connection.cursor().executescript(script)
        connection.commit()