+12
-26
database/migrations.py
+12
-26
database/migrations.py
···
1
import sqlite3
2
from pathlib import Path
3
4
from database.connection import get_conn
5
from util.util import LOGGER
6
-
7
8
class DatabaseMigrator:
9
def __init__(self, db_path: Path, migrations_folder: Path) -> None:
10
self.db_path: Path = db_path
11
self.migrations_folder: Path = migrations_folder
12
self.conn: sqlite3.Connection = get_conn(db_path)
13
14
def close(self):
15
self.conn.close()
···
24
_ = cursor.execute(f"PRAGMA user_version = {version}")
25
self.conn.commit()
26
27
-
def get_migrations(self) -> list[tuple[int, Path]]:
28
-
if not self.migrations_folder.exists():
29
-
return []
30
-
31
-
files: list[tuple[int, Path]] = []
32
-
for f in self.migrations_folder.glob("*.sql"):
33
-
try:
34
-
version = int(f.stem.split("_")[0])
35
-
files.append((version, f))
36
-
except (ValueError, IndexError):
37
-
LOGGER.warning("Warning: Skipping invalid migration file: %", f.name)
38
-
39
-
return sorted(files, key=lambda x: x[0])
40
-
41
-
def apply_migration(self, version: int, path: Path):
42
-
with open(path, "r") as f:
43
-
sql = f.read()
44
-
45
-
cursor = self.conn.cursor()
46
try:
47
-
_ = cursor.executescript(sql)
48
self.set_version(version)
49
-
LOGGER.info("Applied migration: %s", path.name)
50
except sqlite3.Error as e:
51
self.conn.rollback()
52
-
raise Exception(f"Error applying migration {path.name}: {e}")
53
54
def migrate(self):
55
current_version = self.get_version()
56
-
migrations = self.get_migrations()
57
58
if not migrations:
59
LOGGER.warning("No migration files found.")
···
64
LOGGER.info("No pending migrations.")
65
return
66
67
-
for version, filepath in pending:
68
-
self.apply_migration(version, filepath)
···
1
import sqlite3
2
from pathlib import Path
3
+
from typing import Callable
4
5
from database.connection import get_conn
6
from util.util import LOGGER
7
8
class DatabaseMigrator:
9
def __init__(self, db_path: Path, migrations_folder: Path) -> None:
10
self.db_path: Path = db_path
11
self.migrations_folder: Path = migrations_folder
12
self.conn: sqlite3.Connection = get_conn(db_path)
13
+
_ = self.conn.execute("PRAGMA foreign_keys = OFF;")
14
+
self.conn.autocommit = False
15
16
def close(self):
17
self.conn.close()
···
26
_ = cursor.execute(f"PRAGMA user_version = {version}")
27
self.conn.commit()
28
29
+
def apply_migration(self, version: int, filename: str, migration: Callable[[sqlite3.Connection], None]):
30
try:
31
+
_ = migration(self.conn)
32
self.set_version(version)
33
+
self.conn.commit()
34
+
LOGGER.info("Applied migration: %s..", filename)
35
except sqlite3.Error as e:
36
self.conn.rollback()
37
+
raise Exception(f"Error applying migration {filename}: {e}")
38
39
def migrate(self):
40
current_version = self.get_version()
41
+
from migrations._registry import load_migrations
42
+
migrations = load_migrations(self.migrations_folder)
43
44
if not migrations:
45
LOGGER.warning("No migration files found.")
···
50
LOGGER.info("No pending migrations.")
51
return
52
53
+
for version, filename, migration in pending:
54
+
self.apply_migration(version, filename, migration)
-16
migrations/001_initdb.sql
-16
migrations/001_initdb.sql
···
1
-
CREATE TABLE IF NOT EXISTS posts (
2
-
id INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT,
3
-
user TEXT NOT NULL,
4
-
service TEXT NOT NULL,
5
-
identifier TEXT NOT NULL,
6
-
parent INTEGER NULL REFERENCES posts(id),
7
-
root INTEGER NULL REFERENCES posts(id),
8
-
reposted INTEGER NULL REFERENCES posts(id),
9
-
extra_data TEXT NULL
10
-
);
11
-
12
-
CREATE TABLE IF NOT EXISTS mappings (
13
-
original INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
14
-
mapped INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
15
-
UNIQUE(original, mapped)
16
-
);
···
+21
migrations/001_initdb_v1.py
+21
migrations/001_initdb_v1.py
···
···
1
+
import sqlite3
2
+
3
+
4
+
def migrate(conn: sqlite3.Connection):
5
+
_ = conn.execute("""
6
+
CREATE TABLE IF NOT EXISTS posts (
7
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
8
+
user_id TEXT NOT NULL,
9
+
service TEXT NOT NULL,
10
+
identifier TEXT NOT NULL,
11
+
parent_id INTEGER NULL REFERENCES posts(id) ON DELETE SET NULL,
12
+
root_id INTEGER NULL REFERENCES posts(id) ON DELETE SET NULL
13
+
);
14
+
""")
15
+
_ = conn.execute("""
16
+
CREATE TABLE IF NOT EXISTS mappings (
17
+
original_post_id INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
18
+
mapped_post_id INTEGER NOT NULL
19
+
);
20
+
""")
21
+
pass
-5
migrations/002_add_indexes.sql
-5
migrations/002_add_indexes.sql
+11
migrations/002_add_reposted_column_v1.py
+11
migrations/002_add_reposted_column_v1.py
···
···
1
+
import sqlite3
2
+
3
+
4
+
def migrate(conn: sqlite3.Connection):
5
+
columns = conn.execute("PRAGMA table_info(posts)")
6
+
column_names = [col[1] for col in columns]
7
+
if "reposted_id" not in column_names:
8
+
_ = conn.execute("""
9
+
ALTER TABLE posts
10
+
ADD COLUMN reposted_id INTEGER NULL REFERENCES posts(id) ON DELETE SET NULL
11
+
""")
+22
migrations/003_add_extra_data_column_v1.py
+22
migrations/003_add_extra_data_column_v1.py
···
···
1
+
import json
2
+
import sqlite3
3
+
4
+
5
+
def migrate(conn: sqlite3.Connection):
6
+
columns = conn.execute("PRAGMA table_info(posts)")
7
+
column_names = [col[1] for col in columns]
8
+
if "extra_data" not in column_names:
9
+
_ = conn.execute("""
10
+
ALTER TABLE posts
11
+
ADD COLUMN extra_data TEXT NULL
12
+
""")
13
+
14
+
# migrate old bsky identifiers from json to uri as id and cid in extra_data
15
+
data = conn.execute("SELECT id, identifier FROM posts WHERE service = 'https://bsky.app';").fetchall()
16
+
rewrites: list[tuple[str, str, int]] = []
17
+
for row in data:
18
+
if row[1][0] == '{' and row[1][-1] == '}':
19
+
data = json.loads(row[1])
20
+
rewrites.append((data['uri'], json.dumps({'cid': data['cid']}), row[0]))
21
+
if rewrites:
22
+
_ = conn.executemany("UPDATE posts SET identifier = ?, extra_data = ? WHERE id = ?;", rewrites)
+52
migrations/004_initdb_next.py
+52
migrations/004_initdb_next.py
···
···
1
+
import sqlite3
2
+
3
+
4
+
def migrate(conn: sqlite3.Connection):
5
+
cursor = conn.cursor()
6
+
7
+
old_posts = cursor.execute("SELECT * FROM posts;").fetchall()
8
+
old_mappings = cursor.execute("SELECT * FROM mappings;").fetchall()
9
+
10
+
_ = cursor.execute("DROP TABLE posts;")
11
+
_ = cursor.execute("DROP TABLE mappings;")
12
+
13
+
_ = cursor.execute("""
14
+
CREATE TABLE posts (
15
+
id INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT,
16
+
user TEXT NOT NULL,
17
+
service TEXT NOT NULL,
18
+
identifier TEXT NOT NULL,
19
+
parent INTEGER NULL REFERENCES posts(id),
20
+
root INTEGER NULL REFERENCES posts(id),
21
+
reposted INTEGER NULL REFERENCES posts(id),
22
+
extra_data TEXT NULL
23
+
);
24
+
""")
25
+
26
+
_ = cursor.execute("""
27
+
CREATE TABLE mappings (
28
+
original INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
29
+
mapped INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
30
+
UNIQUE(original, mapped)
31
+
);
32
+
""")
33
+
34
+
for old_post in old_posts:
35
+
_ = cursor.execute(
36
+
"""
37
+
INSERT INTO posts (id, user, service, identifier, parent, root, reposted, extra_data)
38
+
VALUES (:id, :user_id, :service, :identifier, :parent_id, :root_id, :reposted_id, :extra_data)
39
+
""",
40
+
dict(old_post),
41
+
)
42
+
43
+
for mapping in old_mappings:
44
+
original, mapped = mapping["original_post_id"], mapping["mapped_post_id"]
45
+
_ = cursor.execute(
46
+
"INSERT OR IGNORE INTO mappings (original, mapped) VALUES (?, ?)",
47
+
(original, mapped),
48
+
)
49
+
_ = cursor.execute(
50
+
"INSERT OR IGNORE INTO mappings (original, mapped) VALUES (?, ?)",
51
+
(mapped, original),
52
+
)
+12
migrations/005_add_indexes.py
+12
migrations/005_add_indexes.py
···
···
1
+
import sqlite3
2
+
3
+
4
+
def migrate(conn: sqlite3.Connection):
5
+
_ = conn.execute("""
6
+
CREATE INDEX IF NOT EXISTS idx_posts_service_user_identifier
7
+
ON posts (service, user, identifier);
8
+
""")
9
+
_ = conn.execute("""
10
+
CREATE UNIQUE INDEX IF NOT EXISTS ux_mappings_original_mapped
11
+
ON mappings (original, mapped);
12
+
""")
+35
migrations/_registry.py
+35
migrations/_registry.py
···
···
1
+
import importlib.util
2
+
from pathlib import Path
3
+
import sqlite3
4
+
from typing import Callable
5
+
6
+
7
+
def load_migrations(path: Path) -> list[tuple[int, str, Callable[[sqlite3.Connection], None]]]:
8
+
migrations: list[tuple[int, str, Callable[[sqlite3.Connection], None]]] = []
9
+
migration_files = sorted(
10
+
[f for f in path.glob("*.py") if not f.stem.startswith("_")]
11
+
)
12
+
13
+
for filepath in migration_files:
14
+
filename = filepath.stem
15
+
version_str = filename.split("_")[0]
16
+
17
+
try:
18
+
version = int(version_str)
19
+
except ValueError:
20
+
raise ValueError('migrations must start with a number!!')
21
+
22
+
spec = importlib.util.spec_from_file_location(filepath.stem, filepath)
23
+
if not spec or not spec.loader:
24
+
raise Exception(f"Failed to load spec from file: {filepath}")
25
+
26
+
module = importlib.util.module_from_spec(spec)
27
+
spec.loader.exec_module(module)
28
+
29
+
if hasattr(module, "migrate"):
30
+
migrations.append((version, filename, module.migrate))
31
+
else:
32
+
raise ValueError(f"Migration {filepath.name} missing 'migrate' function")
33
+
34
+
migrations.sort(key=lambda x: x[0])
35
+
return migrations