1package norm
2
import (
	"context"
	"database/sql"
	"fmt"
	"sort"
	"strings"
)
9
// update builds a single SQL UPDATE statement. Its builder methods (Or, Set,
// Sets, Where) take the receiver by value and return a modified copy; Compile
// renders it to SQL text plus bind arguments.
type update struct {
	// table is the target table name; Compile writes it into the SQL verbatim
	// and rejects an empty value.
	table string
	// or selects the optional "UPDATE OR <action>" clause; a value whose
	// String() is "" (e.g. UpdateNone) omits the clause.
	or UpdateOr
	// sets holds the SET assignments in the order they were added. Each entry
	// is rendered as "col = ?" with val appended to the bind arguments.
	sets []struct {
		col string
		val any
	}
	// where is the optional WHERE condition; nil means no WHERE clause.
	where *Expr
}
19
// UpdateOr selects the conflict action rendered as "UPDATE OR <action>" when
// an update statement is compiled. The zero value, UpdateNone, renders no OR
// clause at all.
type UpdateOr int

// Conflict actions accepted by update.Or.
const (
	UpdateNone     UpdateOr = 0 // no OR clause
	UpdateAbort    UpdateOr = 1 // UPDATE OR ABORT
	UpdateFail     UpdateOr = 2 // UPDATE OR FAIL
	UpdateIgnore   UpdateOr = 3 // UPDATE OR IGNORE
	UpdateReplace  UpdateOr = 4 // UPDATE OR REPLACE
	UpdateRollback UpdateOr = 5 // UPDATE OR ROLLBACK
)

// String returns the SQL keyword for the action. It returns "" for
// UpdateNone and for any value outside the defined constants.
func (u UpdateOr) String() string {
	// Indexed by the constant itself; slot 0 (UpdateNone) stays "".
	keywords := [...]string{
		UpdateAbort:    "ABORT",
		UpdateFail:     "FAIL",
		UpdateIgnore:   "IGNORE",
		UpdateReplace:  "REPLACE",
		UpdateRollback: "ROLLBACK",
	}
	if u < 0 || int(u) >= len(keywords) {
		return ""
	}
	return keywords[u]
}
47
48func Update(table string) update {
49 return update{table: table}
50}
51
// UpdateOpt is a function that receives a pointer to an update builder
// (presumably to configure it in place).
// NOTE(review): nothing in this section references UpdateOpt; confirm it is
// used elsewhere in the package before relying on it, or remove it.
type UpdateOpt func(s *update)
53
54func (u update) Or(option UpdateOr) update {
55 u.or = option
56 return u
57}
58
59func (u update) Set(column string, value any) update {
60 u.sets = append(u.sets, struct {
61 col string
62 val any
63 }{
64 column, value,
65 })
66 return u
67}
68
69func (u update) Sets(values map[string]any) update {
70 for column, value := range values {
71 u = u.Set(column, value)
72 }
73 return u
74}
75
76func (u update) Where(expr Expr) update {
77 u.where = &expr
78 return u
79}
80
81func (u update) Compile() (string, []any, error) {
82 var sql strings.Builder
83 var args []any
84
85 sql.WriteString("UPDATE ")
86
87 orKw := u.or.String()
88 if orKw != "" {
89 sql.WriteString("OR ")
90 sql.WriteString(u.or.String())
91 sql.WriteString(" ")
92 }
93
94 if u.table == "" {
95 return "", nil, fmt.Errorf("table name is required")
96 }
97 sql.WriteString(u.table)
98
99 if len(u.sets) == 0 {
100 return "", nil, fmt.Errorf("no SET clauses supplied")
101 }
102
103 sql.WriteString(" SET ")
104
105 for i, set := range u.sets {
106 if i != 0 {
107 sql.WriteString(", ")
108 }
109 sql.WriteString(set.col)
110 sql.WriteString(" = ?")
111 args = append(args, set.val)
112 }
113
114 if u.where != nil {
115 sql.WriteString(" WHERE ")
116 sql.WriteString(u.where.String())
117
118 args = append(args, u.where.Binds()...)
119 }
120
121 return sql.String(), args, nil
122}
123
124func (u update) MustCompile() (string, []any) {
125 sql, args, err := u.Compile()
126 if err != nil {
127 panic(err)
128 }
129
130 return sql, args
131}
132
133func (u update) Build(p Database) (*sql.Stmt, []any, error) { return Build(u, p) }
134func (u update) MustBuild(p Database) (*sql.Stmt, []any) { return MustBuild(u, p) }
135
136func (u update) Exec(p Database) (sql.Result, error) { return Exec(u, p) }
137func (u update) ExecContext(ctx context.Context, p Database) (sql.Result, error) {
138 return ExecContext(ctx, u, p)
139}
140func (u update) MustExec(p Database) sql.Result { return MustExec(u, p) }
141func (u update) MustExecContext(ctx context.Context, p Database) sql.Result {
142 return MustExecContext(ctx, u, p)
143}