1package orm
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "log/slog"
8 "reflect"
9 "strings"
10)
11
// migrationFn applies one schema migration using the transaction it is given.
// Returning an error aborts the migration and its transaction is rolled back.
type migrationFn = func(*sql.Tx) error

// RunMigration applies the migration called name at most once, using
// context.Background(). See RunMigrationContext for the full contract.
func RunMigration(c *sql.Conn, logger *slog.Logger, name string, migrationFn migrationFn) error {
	return RunMigrationContext(context.Background(), c, logger, name, migrationFn)
}

// RunMigrationContext applies the migration called name at most once.
//
// Everything happens in a single transaction on c:
//  1. create the bookkeeping table "migrations" if it does not exist;
//  2. check whether name is already recorded there;
//  3. if not, call fn and record name.
//
// The transaction is committed only if every step succeeds. A failing
// migration therefore leaves neither its own changes nor a bookkeeping row
// behind. An already-applied migration is skipped and nil is returned.
//
// A nil logger falls back to slog.Default(). Returned errors name the migration
// and the failing step, and wrap the underlying cause (use errors.Is/errors.As).
//
// NOTE(review): the DDL uses `autoincrement` and `?` placeholders, which look
// SQLite-specific; confirm before using with another database.
func RunMigrationContext(ctx context.Context, c *sql.Conn, logger *slog.Logger, name string, fn migrationFn) error {
	if logger == nil {
		logger = slog.Default()
	}
	logger = logger.With("migration", name)

	wrap := func(step string, err error) error {
		return fmt.Errorf("migration %q: %s: %w", name, step, err)
	}

	tx, err := c.BeginTx(ctx, nil)
	if err != nil {
		return wrap("beginning transaction", err)
	}
	// After a successful Commit, Rollback is a no-op that returns
	// sql.ErrTxDone, so its error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, `
		create table if not exists migrations (
			id integer primary key autoincrement,
			name text unique
		);
	`); err != nil {
		return wrap("creating migrations table", err)
	}

	var applied bool
	if err := tx.QueryRowContext(ctx, "select exists (select 1 from migrations where name = ?)", name).Scan(&applied); err != nil {
		return wrap("checking migration status", err)
	}
	if applied {
		logger.Warn("skipped migration, already applied")
		return nil
	}

	if err := fn(tx); err != nil {
		return wrap("running migration", err)
	}
	if _, err := tx.ExecContext(ctx, "insert into migrations (name) values (?)", name); err != nil {
		return wrap("marking migration as complete", err)
	}
	if err := tx.Commit(); err != nil {
		return wrap("committing", err)
	}

	logger.Info("migration applied successfully")
	return nil
}
66
// Filter is a single SQL predicate of the form "<Key> <Cmp> <arg>", created
// with one of the FilterXxx constructors. Condition renders the SQL fragment
// and Arg returns the bind arguments for its placeholders.
//
// Key and Cmp are written into the SQL text verbatim; only arg is passed as a
// bind parameter. Never build a Filter key from untrusted input.
type Filter struct {
	Key string
	arg any
	Cmp string
}

// newFilter returns a Filter comparing key against arg with the SQL operator cmp.
func newFilter(key, cmp string, arg any) Filter {
	return Filter{Key: key, arg: arg, Cmp: cmp}
}

// FilterEq matches rows where key = arg.
func FilterEq(key string, arg any) Filter { return newFilter(key, "=", arg) }

// FilterNotEq matches rows where key <> arg.
func FilterNotEq(key string, arg any) Filter { return newFilter(key, "<>", arg) }

// FilterGte matches rows where key >= arg.
func FilterGte(key string, arg any) Filter { return newFilter(key, ">=", arg) }

// FilterLte matches rows where key <= arg.
func FilterLte(key string, arg any) Filter { return newFilter(key, "<=", arg) }

// FilterIs renders "key is ?"; pass nil as arg for an IS NULL check.
func FilterIs(key string, arg any) Filter { return newFilter(key, "is", arg) }

// FilterIsNot renders "key is not ?"; pass nil as arg for an IS NOT NULL check.
func FilterIsNot(key string, arg any) Filter { return newFilter(key, "is not", arg) }

// FilterIn matches rows where key equals one of the elements of arg. arg should
// be a slice or array; it is expanded to one placeholder per element (see
// Condition). A non-list arg renders as "key in ?".
func FilterIn(key string, arg any) Filter { return newFilter(key, "in", arg) }

// FilterLike matches rows where key LIKE arg; arg is used as the pattern as-is.
func FilterLike(key string, arg any) Filter { return newFilter(key, "like", arg) }

// FilterNotLike matches rows where key NOT LIKE arg; arg is used as the pattern as-is.
func FilterNotLike(key string, arg any) Filter { return newFilter(key, "not like", arg) }

// FilterContains matches rows where key contains the %v-formatted text of arg,
// by wrapping it as "%arg%" for a LIKE comparison. '%' and '_' inside arg are
// NOT escaped, so they still act as LIKE wildcards.
func FilterContains(key string, arg any) Filter {
	return newFilter(key, "like", fmt.Sprintf("%%%v%%", arg))
}

// filterListValue reports whether arg is expanded into one placeholder per
// element, and returns its reflect.Value. Condition and Arg both go through it
// so the number of placeholders always matches the number of bind arguments.
func filterListValue(arg any) (reflect.Value, bool) {
	rv := reflect.ValueOf(arg)
	switch rv.Kind() {
	case reflect.Array:
		return rv, true
	case reflect.Slice:
		// Slices with uint8 elements (e.g. []byte) are bound as one value.
		return rv, rv.Type().Elem().Kind() != reflect.Uint8
	default:
		return rv, false
	}
}

// Condition returns the SQL fragment for f, using "?" placeholders.
//
// A slice or array arg is expanded to one placeholder per element, so
// FilterIn("k", []int{1, 2, 3}) becomes "k in (?, ?, ?)". Slices whose element
// kind is uint8 (such as []byte) are the exception: they are bound as one value.
// An empty list yields the always-false "1 = 0". Any other arg yields
// "<Key> <Cmp> ?".
func (f Filter) Condition() string {
	rv, isList := filterListValue(f.arg)
	if !isList {
		return fmt.Sprintf("%s %s ?", f.Key, f.Cmp)
	}
	n := rv.Len()
	if n == 0 {
		// Nothing can match an empty list.
		return "1 = 0"
	}
	placeholders := strings.TrimSuffix(strings.Repeat("?, ", n), ", ")
	return fmt.Sprintf("%s %s (%s)", f.Key, f.Cmp, placeholders)
}

// Arg returns the bind arguments for the placeholders in Condition:
//   - one entry per element for an expanded list;
//   - nil for an empty list;
//   - otherwise a single entry holding arg itself.
func (f Filter) Arg() []any {
	rv, isList := filterListValue(f.arg)
	if !isList {
		return []any{f.arg}
	}
	n := rv.Len()
	if n == 0 {
		return nil
	}
	args := make([]any, n)
	for i := range args {
		args[i] = rv.Index(i).Interface()
	}
	return args
}