CLI + TUI to publish to Leaflet (WIP) and manage tasks, notes, and watch/read lists 🍃
charm
leaflet
readability
golang
1package store
2
3import (
4 "database/sql"
5 "fmt"
6 "io/fs"
7 "sort"
8 "strings"
9)
10
// Migration describes a single schema migration: its version and name, the
// SQL that applies and reverts it, and whether (and when) it was applied.
type Migration struct {
	Version   string // filename prefix before the first "_", e.g. "0001"
	Name      string // descriptive middle part of the filename
	UpSQL     string // contents of the matching *_up.sql file
	DownSQL   string // contents of the matching *_down.sql file
	Applied   bool   // set when a row for Version was read from the migrations table
	AppliedAt string // applied_at value scanned from the migrations table
}
20
// FileSystem is the read-only file access a MigrationRunner needs in order to
// discover and load migration SQL files. An embed.FS satisfies it, as does any
// other type providing these two methods.
type FileSystem interface {
	// ReadFile returns the full contents of the named file.
	ReadFile(name string) ([]byte, error)
	// ReadDir lists the entries of the named directory.
	ReadDir(name string) ([]fs.DirEntry, error)
}
26
27// MigrationRunner handles database migrations
28type MigrationRunner struct {
29 db *sql.DB
30 migrationFiles FileSystem
31 runFn func() error // inject for testing
32}
33
34// CreateMigrationRunner creates a new migration runner
35func CreateMigrationRunner(db *sql.DB, files FileSystem) *MigrationRunner {
36 mr := &MigrationRunner{
37 db: db,
38 migrationFiles: files,
39 }
40 mr.runFn = mr.defaultRunMigrations
41 return mr
42}
43
44// RunMigrations applies all pending migrations (delegates to runFn)
45func (mr *MigrationRunner) RunMigrations() error {
46 if mr.runFn != nil {
47 return mr.runFn()
48 }
49 return nil
50}
51
52func (mr *MigrationRunner) defaultRunMigrations() error {
53 entries, err := mr.migrationFiles.ReadDir("sql/migrations")
54 if err != nil {
55 return fmt.Errorf("failed to read migrations directory: %w", err)
56 }
57
58 var upMigrations []string
59 for _, entry := range entries {
60 if strings.HasSuffix(entry.Name(), "_up.sql") {
61 upMigrations = append(upMigrations, entry.Name())
62 }
63 }
64 sort.Strings(upMigrations)
65
66 for _, migrationFile := range upMigrations {
67 version := extractVersionFromFilename(migrationFile)
68
69 var count int
70 err := mr.db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='migrations'").Scan(&count)
71 if err != nil {
72 return fmt.Errorf("failed to check migrations table: %w", err)
73 }
74
75 if count == 0 && version != "0000" {
76 continue
77 }
78
79 if count > 0 {
80 var applied int
81 err = mr.db.QueryRow("SELECT COUNT(*) FROM migrations WHERE version = ?", version).Scan(&applied)
82 if err != nil {
83 return fmt.Errorf("failed to check migration %s: %w", version, err)
84 }
85 if applied > 0 {
86 continue
87 }
88 }
89
90 content, err := mr.migrationFiles.ReadFile("sql/migrations/" + migrationFile)
91 if err != nil {
92 return fmt.Errorf("failed to read migration %s: %w", migrationFile, err)
93 }
94
95 if _, err := mr.db.Exec(string(content)); err != nil {
96 return fmt.Errorf("failed to execute migration %s: %w", migrationFile, err)
97 }
98
99 if count > 0 || version == "0000" {
100 if _, err := mr.db.Exec("INSERT INTO migrations (version) VALUES (?)", version); err != nil {
101 return fmt.Errorf("failed to record migration %s: %w", version, err)
102 }
103 }
104 }
105
106 return nil
107}
108
109// GetAppliedMigrations returns a list of all applied migrations
110func (mr *MigrationRunner) GetAppliedMigrations() ([]Migration, error) {
111 var count int
112 err := mr.db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='migrations'").Scan(&count)
113 if err != nil {
114 return nil, fmt.Errorf("failed to check migrations table: %w", err)
115 }
116
117 if count == 0 {
118 return []Migration{}, nil
119 }
120
121 rows, err := mr.db.Query("SELECT version, applied_at FROM migrations ORDER BY version")
122 if err != nil {
123 return nil, fmt.Errorf("failed to query migrations: %w", err)
124 }
125 defer rows.Close()
126
127 var migrations []Migration
128 for rows.Next() {
129 var m Migration
130 if err := rows.Scan(&m.Version, &m.AppliedAt); err != nil {
131 return nil, fmt.Errorf("failed to scan migration: %w", err)
132 }
133 m.Applied = true
134 migrations = append(migrations, m)
135 }
136
137 return migrations, nil
138}
139
140// GetAvailableMigrations returns all available migrations from embedded files
141func (mr *MigrationRunner) GetAvailableMigrations() ([]Migration, error) {
142 entries, err := mr.migrationFiles.ReadDir("sql/migrations")
143 if err != nil {
144 return nil, fmt.Errorf("failed to read migrations directory: %w", err)
145 }
146
147 migrationMap := make(map[string]*Migration)
148
149 for _, entry := range entries {
150 version := extractVersionFromFilename(entry.Name())
151 if version == "" {
152 continue
153 }
154
155 if migrationMap[version] == nil {
156 migrationMap[version] = &Migration{
157 Version: version,
158 Name: extractNameFromFilename(entry.Name()),
159 }
160 }
161
162 content, err := mr.migrationFiles.ReadFile("sql/migrations/" + entry.Name())
163 if err != nil {
164 return nil, fmt.Errorf("failed to read migration file %s: %w", entry.Name(), err)
165 }
166
167 if strings.HasSuffix(entry.Name(), "_up.sql") {
168 migrationMap[version].UpSQL = string(content)
169 } else if strings.HasSuffix(entry.Name(), "_down.sql") {
170 migrationMap[version].DownSQL = string(content)
171 }
172 }
173
174 var migrations []Migration
175 for _, m := range migrationMap {
176 migrations = append(migrations, *m)
177 }
178 sort.Slice(migrations, func(i, j int) bool {
179 return migrations[i].Version < migrations[j].Version
180 })
181
182 return migrations, nil
183}
184
185// Rollback rolls back the last applied migration
186func (mr *MigrationRunner) Rollback() error {
187 var version string
188 err := mr.db.QueryRow("SELECT version FROM migrations ORDER BY version DESC LIMIT 1").Scan(&version)
189 if err != nil {
190 if err == sql.ErrNoRows {
191 return fmt.Errorf("no migrations to rollback")
192 }
193 return fmt.Errorf("failed to get last migration: %w", err)
194 }
195
196 entries, err := mr.migrationFiles.ReadDir("sql/migrations")
197 if err != nil {
198 return fmt.Errorf("failed to read migrations directory: %w", err)
199 }
200
201 var downContent []byte
202 for _, entry := range entries {
203 if strings.HasPrefix(entry.Name(), version) && strings.HasSuffix(entry.Name(), "_down.sql") {
204 downContent, err = mr.migrationFiles.ReadFile("sql/migrations/" + entry.Name())
205 if err != nil {
206 return fmt.Errorf("failed to read down migration: %w", err)
207 }
208 break
209 }
210 }
211
212 if downContent == nil {
213 return fmt.Errorf("down migration not found for version %s", version)
214 }
215
216 if _, err := mr.db.Exec(string(downContent)); err != nil {
217 return fmt.Errorf("failed to execute down migration: %w", err)
218 }
219
220 if _, err := mr.db.Exec("DELETE FROM migrations WHERE version = ?", version); err != nil {
221 return fmt.Errorf("failed to remove migration record: %w", err)
222 }
223
224 return nil
225}
226
// extractVersionFromFilename returns the part of a [Migration] filename before
// its first underscore, i.e. the version prefix ("0001" for
// "0001_create_tasks_up.sql"). A filename without an underscore is returned
// unchanged, and one starting with an underscore yields "". The prefix is not
// checked to be numeric or of any particular width.
func extractVersionFromFilename(filename string) string {
	cut := strings.IndexByte(filename, '_')
	if cut < 0 {
		return filename
	}
	return filename[:cut]
}
235
// extractNameFromFilename returns the descriptive middle of a migration
// filename: the text between its first and last underscore, with a trailing
// "_up" removed if present. Both "0001_create_tasks_up.sql" and
// "0001_create_tasks_down.sql" yield "create_tasks". Filenames containing
// fewer than two underscores yield "".
func extractNameFromFilename(filename string) string {
	first := strings.IndexByte(filename, '_')
	last := strings.LastIndexByte(filename, '_')
	if first < 0 || first == last {
		return ""
	}
	return strings.TrimSuffix(filename[first+1:last], "_up")
}