1import sqlite3
2from logging import Logger
3from sqlite3 import Connection
4from typing import override
5
6from flask import Flask, g
7
8from src.atproto.kv import KV as BaseKV
9
10
class KV(BaseKV):
    """Key/value store backed by the ``keyval`` SQLite table, namespaced by ``prefix``.

    Rows are read and written by ``(prefix, key)``, so several instances with
    different prefixes can share one table. ``insert or replace`` presumably
    relies on a unique constraint over those columns in ``schema.sql``
    (not visible here) — NOTE(review): confirm.
    """

    db: Connection  # SQLite connection used for every read and write
    logger: Logger  # receives debug messages for cache hits and writes
    prefix: str  # namespace value stored in / matched against the ``prefix`` column

    def __init__(self, app: Connection | Flask, logger: Logger, prefix: str):
        """Bind the store to a connection and a namespace.

        Args:
            app: Either an open SQLite connection, used as-is, or a Flask app,
                in which case the connection comes from ``get_db(app)``. That
                function reads ``flask.g``, so it needs an active app context.
            logger: Logger for debug output.
            prefix: Namespace written to and matched against the ``prefix`` column.
        """
        self.db = app if isinstance(app, Connection) else get_db(app)
        self.logger = logger
        self.prefix = prefix

    @override
    def get(self, key: str) -> str | None:
        """Return the value stored for ``key`` under this prefix, or ``None`` if absent."""
        row: sqlite3.Row | tuple[str] | None = self.db.execute(
            "select value from keyval where prefix = ? and key = ?",
            (self.prefix, key),
        ).fetchone()
        if row is None:
            return None
        self.logger.debug("returning cached %s(%s)", self.prefix, key)
        # Index positionally. A Connection passed straight to __init__ may not
        # have row_factory = sqlite3.Row; then rows are plain tuples and
        # row["value"] would raise TypeError. row[0] works for both row types.
        return row[0]

    @override
    def set(self, key: str, value: str):
        """Store ``value`` for ``key`` under this prefix.

        Uses ``insert or replace``, so a row that conflicts on a unique
        constraint (presumably ``(prefix, key)``) is replaced.
        """
        self.logger.debug("caching %s(%s): %s", self.prefix, key, value)
        # The connection context manager commits on success and rolls back if
        # the statement raises, so a failed write leaves no open transaction.
        with self.db:
            _ = self.db.execute(
                "insert or replace into keyval (prefix, key, value) values (?, ?, ?)",
                (self.prefix, key, value),
            )
42
43
def get_db(app: Flask) -> sqlite3.Connection:
    """Return the SQLite connection cached on ``flask.g``, opening it on first use.

    The database path comes from ``app.config["DATABASE_URL"]`` and falls
    back to ``"ligoat.db"``. A newly opened connection is created with
    ``check_same_thread=False`` and returns ``sqlite3.Row`` rows.
    """
    cached: sqlite3.Connection | None = g.get("db", None)
    if cached is not None:
        return cached

    database_path: str = app.config.get("DATABASE_URL", "ligoat.db")
    connection = sqlite3.connect(database_path, check_same_thread=False)
    # sqlite3.Row lets callers index result columns by name.
    connection.row_factory = sqlite3.Row
    g.db = connection
    return connection
52
53
def close_db_connection(_exception: BaseException | None = None):
    """Close the SQLite connection stored on ``flask.g``, if there is one.

    The signature matches a Flask teardown callback, which receives the
    exception that ended the context (or ``None``). The argument is unused.
    NOTE(review): registration with ``app.teardown_appcontext`` is presumably
    done elsewhere — confirm.
    """
    # Pop rather than get: if the closed handle stayed on ``g``, a later
    # get_db() in the same context would return it and fail with
    # sqlite3.ProgrammingError instead of opening a fresh connection.
    db: sqlite3.Connection | None = g.pop("db", None)
    if db is not None:
        db.close()
58
59
def init_db(app: Flask):
    """Run the app's bundled ``schema.sql`` against the database and commit.

    Pushes an app context so ``get_db`` can use ``flask.g``. The connection is
    left on ``g``; whether it gets closed when the context ends depends on a
    teardown handler registered elsewhere (NOTE(review): confirm).
    """
    with app.app_context():
        connection = get_db(app)
        with app.open_resource("schema.sql", mode="r") as schema_file:
            schema_sql = schema_file.read()
        # executescript runs every statement in the script in one call.
        _ = connection.cursor().executescript(schema_sql)
        connection.commit()