cli + tui to publish to leaflet (wip) & manage tasks, notes & watch/read lists 🍃
charm leaflet readability golang
at main 244 lines 6.6 kB view raw
1package store 2 3import ( 4 "database/sql" 5 "fmt" 6 "io/fs" 7 "sort" 8 "strings" 9) 10 11// Migration represents a single database migration 12type Migration struct { 13 Version string 14 Name string 15 UpSQL string 16 DownSQL string 17 Applied bool 18 AppliedAt string 19} 20 21// FileSystem interface for reading migration files 22type FileSystem interface { 23 ReadDir(name string) ([]fs.DirEntry, error) 24 ReadFile(name string) ([]byte, error) 25} 26 27// MigrationRunner handles database migrations 28type MigrationRunner struct { 29 db *sql.DB 30 migrationFiles FileSystem 31 runFn func() error // inject for testing 32} 33 34// CreateMigrationRunner creates a new migration runner 35func CreateMigrationRunner(db *sql.DB, files FileSystem) *MigrationRunner { 36 mr := &MigrationRunner{ 37 db: db, 38 migrationFiles: files, 39 } 40 mr.runFn = mr.defaultRunMigrations 41 return mr 42} 43 44// RunMigrations applies all pending migrations (delegates to runFn) 45func (mr *MigrationRunner) RunMigrations() error { 46 if mr.runFn != nil { 47 return mr.runFn() 48 } 49 return nil 50} 51 52func (mr *MigrationRunner) defaultRunMigrations() error { 53 entries, err := mr.migrationFiles.ReadDir("sql/migrations") 54 if err != nil { 55 return fmt.Errorf("failed to read migrations directory: %w", err) 56 } 57 58 var upMigrations []string 59 for _, entry := range entries { 60 if strings.HasSuffix(entry.Name(), "_up.sql") { 61 upMigrations = append(upMigrations, entry.Name()) 62 } 63 } 64 sort.Strings(upMigrations) 65 66 for _, migrationFile := range upMigrations { 67 version := extractVersionFromFilename(migrationFile) 68 69 var count int 70 err := mr.db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='migrations'").Scan(&count) 71 if err != nil { 72 return fmt.Errorf("failed to check migrations table: %w", err) 73 } 74 75 if count == 0 && version != "0000" { 76 continue 77 } 78 79 if count > 0 { 80 var applied int 81 err = mr.db.QueryRow("SELECT COUNT(*) FROM 
migrations WHERE version = ?", version).Scan(&applied) 82 if err != nil { 83 return fmt.Errorf("failed to check migration %s: %w", version, err) 84 } 85 if applied > 0 { 86 continue 87 } 88 } 89 90 content, err := mr.migrationFiles.ReadFile("sql/migrations/" + migrationFile) 91 if err != nil { 92 return fmt.Errorf("failed to read migration %s: %w", migrationFile, err) 93 } 94 95 if _, err := mr.db.Exec(string(content)); err != nil { 96 return fmt.Errorf("failed to execute migration %s: %w", migrationFile, err) 97 } 98 99 if count > 0 || version == "0000" { 100 if _, err := mr.db.Exec("INSERT INTO migrations (version) VALUES (?)", version); err != nil { 101 return fmt.Errorf("failed to record migration %s: %w", version, err) 102 } 103 } 104 } 105 106 return nil 107} 108 109// GetAppliedMigrations returns a list of all applied migrations 110func (mr *MigrationRunner) GetAppliedMigrations() ([]Migration, error) { 111 var count int 112 err := mr.db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='migrations'").Scan(&count) 113 if err != nil { 114 return nil, fmt.Errorf("failed to check migrations table: %w", err) 115 } 116 117 if count == 0 { 118 return []Migration{}, nil 119 } 120 121 rows, err := mr.db.Query("SELECT version, applied_at FROM migrations ORDER BY version") 122 if err != nil { 123 return nil, fmt.Errorf("failed to query migrations: %w", err) 124 } 125 defer rows.Close() 126 127 var migrations []Migration 128 for rows.Next() { 129 var m Migration 130 if err := rows.Scan(&m.Version, &m.AppliedAt); err != nil { 131 return nil, fmt.Errorf("failed to scan migration: %w", err) 132 } 133 m.Applied = true 134 migrations = append(migrations, m) 135 } 136 137 return migrations, nil 138} 139 140// GetAvailableMigrations returns all available migrations from embedded files 141func (mr *MigrationRunner) GetAvailableMigrations() ([]Migration, error) { 142 entries, err := mr.migrationFiles.ReadDir("sql/migrations") 143 if err != nil { 144 
return nil, fmt.Errorf("failed to read migrations directory: %w", err) 145 } 146 147 migrationMap := make(map[string]*Migration) 148 149 for _, entry := range entries { 150 version := extractVersionFromFilename(entry.Name()) 151 if version == "" { 152 continue 153 } 154 155 if migrationMap[version] == nil { 156 migrationMap[version] = &Migration{ 157 Version: version, 158 Name: extractNameFromFilename(entry.Name()), 159 } 160 } 161 162 content, err := mr.migrationFiles.ReadFile("sql/migrations/" + entry.Name()) 163 if err != nil { 164 return nil, fmt.Errorf("failed to read migration file %s: %w", entry.Name(), err) 165 } 166 167 if strings.HasSuffix(entry.Name(), "_up.sql") { 168 migrationMap[version].UpSQL = string(content) 169 } else if strings.HasSuffix(entry.Name(), "_down.sql") { 170 migrationMap[version].DownSQL = string(content) 171 } 172 } 173 174 var migrations []Migration 175 for _, m := range migrationMap { 176 migrations = append(migrations, *m) 177 } 178 sort.Slice(migrations, func(i, j int) bool { 179 return migrations[i].Version < migrations[j].Version 180 }) 181 182 return migrations, nil 183} 184 185// Rollback rolls back the last applied migration 186func (mr *MigrationRunner) Rollback() error { 187 var version string 188 err := mr.db.QueryRow("SELECT version FROM migrations ORDER BY version DESC LIMIT 1").Scan(&version) 189 if err != nil { 190 if err == sql.ErrNoRows { 191 return fmt.Errorf("no migrations to rollback") 192 } 193 return fmt.Errorf("failed to get last migration: %w", err) 194 } 195 196 entries, err := mr.migrationFiles.ReadDir("sql/migrations") 197 if err != nil { 198 return fmt.Errorf("failed to read migrations directory: %w", err) 199 } 200 201 var downContent []byte 202 for _, entry := range entries { 203 if strings.HasPrefix(entry.Name(), version) && strings.HasSuffix(entry.Name(), "_down.sql") { 204 downContent, err = mr.migrationFiles.ReadFile("sql/migrations/" + entry.Name()) 205 if err != nil { 206 return fmt.Errorf("failed 
to read down migration: %w", err) 207 } 208 break 209 } 210 } 211 212 if downContent == nil { 213 return fmt.Errorf("down migration not found for version %s", version) 214 } 215 216 if _, err := mr.db.Exec(string(downContent)); err != nil { 217 return fmt.Errorf("failed to execute down migration: %w", err) 218 } 219 220 if _, err := mr.db.Exec("DELETE FROM migrations WHERE version = ?", version); err != nil { 221 return fmt.Errorf("failed to remove migration record: %w", err) 222 } 223 224 return nil 225} 226 227// extractVersionFromFilename extracts the 4-digit version from a [Migration] filename 228func extractVersionFromFilename(filename string) string { 229 parts := strings.Split(filename, "_") 230 if len(parts) > 0 { 231 return parts[0] 232 } 233 return "" 234} 235 236func extractNameFromFilename(filename string) string { 237 parts := strings.Split(filename, "_") 238 if len(parts) < 3 { 239 return "" 240 } 241 242 name := strings.Join(parts[1:len(parts)-1], "_") 243 return strings.TrimSuffix(name, "_up") 244}