+258
query.go
+258
query.go
···
1
+
package norm
2
+
3
+
import (
4
+
"fmt"
5
+
"strconv"
6
+
"strings"
7
+
)
8
+
9
+
type op string
10
+
11
+
type Expr struct {
12
+
kind ExprKind
13
+
14
+
ident *IdentExpr
15
+
binary *BinExpr
16
+
value *ValueExpr
17
+
}
18
+
19
+
func (e Expr) String() string {
20
+
switch e.kind {
21
+
case exprKindIdent:
22
+
return e.ident.String()
23
+
case exprKindBinary:
24
+
return e.binary.String()
25
+
case exprKindValue:
26
+
return e.value.String()
27
+
}
28
+
29
+
// unreachable
30
+
return ""
31
+
}
32
+
33
+
func (e Expr) Binds() []any {
34
+
switch e.kind {
35
+
case exprKindIdent:
36
+
return e.ident.Binds()
37
+
case exprKindBinary:
38
+
return e.binary.Binds()
39
+
case exprKindValue:
40
+
return e.value.Binds()
41
+
}
42
+
43
+
return nil
44
+
}
45
+
46
+
type ExprKind int
47
+
48
+
const (
49
+
exprKindIdent ExprKind = iota
50
+
exprKindBinary
51
+
exprKindValue
52
+
)
53
+
54
+
type IdentExpr string
55
+
56
+
func (i IdentExpr) String() string {
57
+
return string(i)
58
+
}
59
+
60
+
func (i IdentExpr) Binds() []any {
61
+
return nil
62
+
}
63
+
64
+
func (i IdentExpr) AsExpr() Expr {
65
+
return Expr{
66
+
kind: exprKindIdent,
67
+
ident: &i,
68
+
}
69
+
}
70
+
71
+
type ValueExpr struct {
72
+
inner any
73
+
}
74
+
75
+
func (v ValueExpr) String() string {
76
+
return "?"
77
+
}
78
+
79
+
func (v ValueExpr) Binds() []any {
80
+
return []any{v.inner}
81
+
}
82
+
83
+
func (v ValueExpr) AsExpr() Expr {
84
+
return Expr{
85
+
kind: exprKindValue,
86
+
value: &v,
87
+
}
88
+
}
89
+
90
+
type BinExpr struct {
91
+
left Expr
92
+
op op
93
+
right Expr
94
+
}
95
+
96
+
func (b BinExpr) String() string {
97
+
return fmt.Sprintf("(%s) %s (%s)", b.left.String(), b.op, b.right.String())
98
+
}
99
+
100
+
func (b BinExpr) Binds() []any {
101
+
binds := b.left.Binds()
102
+
binds = append(binds, b.right.Binds()...)
103
+
return binds
104
+
}
105
+
106
+
func (b BinExpr) AsExpr() Expr {
107
+
return Expr{
108
+
kind: exprKindBinary,
109
+
binary: &b,
110
+
}
111
+
}
112
+
113
+
func buildBinExpr(left IdentExpr, op op, right any) Expr {
114
+
return BinExpr{left.AsExpr(), op, ValueExpr{right}.AsExpr()}.AsExpr()
115
+
}
116
+
func Eq(left string, right any) Expr { return buildBinExpr(IdentExpr(left), "=", right) }
117
+
func Neq(left string, right any) Expr { return buildBinExpr(IdentExpr(left), "<>", right) }
118
+
func Gt(left string, right any) Expr { return buildBinExpr(IdentExpr(left), ">", right) }
119
+
func Gte(left string, right any) Expr { return buildBinExpr(IdentExpr(left), ">=", right) }
120
+
func Lt(left string, right any) Expr { return buildBinExpr(IdentExpr(left), "<", right) }
121
+
func Lte(left string, right any) Expr { return buildBinExpr(IdentExpr(left), "<=", right) }
122
+
123
+
func (l Expr) And(r Expr) Expr { return BinExpr{l, "and", r}.AsExpr() }
124
+
func (l Expr) Or(r Expr) Expr { return BinExpr{l, "or", r}.AsExpr() }
125
+
126
+
type select_ struct {
127
+
resultColumns []string // TODO: strongly type result columns
128
+
from string // table-or-subquery expr
129
+
where *Expr
130
+
orderBy []orderBy
131
+
groupBy []groupBy
132
+
limit *limit
133
+
}
134
+
135
+
type orderBy struct {
136
+
field string
137
+
direction Direction
138
+
}
139
+
140
+
type groupBy struct {
141
+
field string
142
+
}
143
+
144
+
type limit struct {
145
+
limit int
146
+
}
147
+
148
+
func Select(cols ...string) select_ {
149
+
return select_{
150
+
resultColumns: cols,
151
+
}
152
+
}
153
+
154
+
type SelectOpt func(s *select_)
155
+
156
+
func (s select_) From(table string) select_ {
157
+
s.from = table
158
+
return s
159
+
}
160
+
161
+
func (s select_) Where(expr Expr) select_ {
162
+
s.where = &expr
163
+
return s
164
+
}
165
+
166
+
type Direction string
167
+
168
+
const (
169
+
Ascending Direction = "asc"
170
+
Descending Direction = "desc"
171
+
)
172
+
173
+
func (s select_) OrderBy(field string, direction Direction) select_ {
174
+
s.orderBy = append(s.orderBy, orderBy{
175
+
field: field,
176
+
direction: direction,
177
+
})
178
+
return s
179
+
}
180
+
181
+
func (s select_) GroupBy(field string) select_ {
182
+
s.groupBy = append(s.groupBy, groupBy{
183
+
field: field,
184
+
})
185
+
return s
186
+
}
187
+
188
+
func (s select_) Limit(i int) select_ {
189
+
s.limit = &limit{
190
+
limit: i,
191
+
}
192
+
return s
193
+
}
194
+
195
+
func (s select_) Build() (string, []any, error) {
196
+
var sql strings.Builder
197
+
var args []any
198
+
199
+
sql.WriteString("SELECT ")
200
+
if len(s.resultColumns) == 0 {
201
+
return "", nil, fmt.Errorf("result columns empty")
202
+
} else {
203
+
for i, col := range s.resultColumns {
204
+
if i > 0 {
205
+
sql.WriteString(", ")
206
+
}
207
+
sql.WriteString(col)
208
+
}
209
+
}
210
+
211
+
if s.from == "" {
212
+
return "", nil, fmt.Errorf("FROM clause is required")
213
+
}
214
+
sql.WriteString(" FROM ")
215
+
sql.WriteString(s.from)
216
+
217
+
if s.where != nil {
218
+
sql.WriteString(" WHERE ")
219
+
sql.WriteString(s.where.String())
220
+
221
+
args = s.where.Binds()
222
+
}
223
+
224
+
// GROUP BY clause
225
+
if len(s.groupBy) > 0 {
226
+
sql.WriteString(" GROUP BY ")
227
+
for i, gb := range s.groupBy {
228
+
if i > 0 {
229
+
sql.WriteString(", ")
230
+
}
231
+
sql.WriteString(gb.field)
232
+
}
233
+
}
234
+
235
+
// ORDER BY clause
236
+
if len(s.orderBy) > 0 {
237
+
sql.WriteString(" ORDER BY ")
238
+
for i, ob := range s.orderBy {
239
+
if i > 0 {
240
+
sql.WriteString(", ")
241
+
}
242
+
sql.WriteString(ob.field)
243
+
sql.WriteString(" ")
244
+
sql.WriteString(string(ob.direction))
245
+
}
246
+
}
247
+
248
+
// LIMIT clause
249
+
if s.limit != nil {
250
+
if s.limit.limit <= 0 {
251
+
return "", nil, fmt.Errorf("LIMIT must be positive, got %d", s.limit.limit)
252
+
}
253
+
sql.WriteString(" LIMIT ")
254
+
sql.WriteString(strconv.Itoa(s.limit.limit))
255
+
}
256
+
257
+
return sql.String(), args, nil
258
+
}
+381
query_test.go
+381
query_test.go
···
1
+
package norm
2
+
3
+
import (
4
+
"testing"
5
+
)
6
+
7
+
func TestIdentExpr(t *testing.T) {
8
+
ident := IdentExpr("username")
9
+
10
+
if ident.String() != "username" {
11
+
t.Errorf("Expected 'username', got '%s'", ident.String())
12
+
}
13
+
14
+
if ident.Binds() != nil {
15
+
t.Errorf("Expected nil binds, got %v", ident.Binds())
16
+
}
17
+
18
+
expr := ident.AsExpr()
19
+
if expr.kind != exprKindIdent {
20
+
t.Errorf("Expected exprKindIdent, got %d", expr.kind)
21
+
}
22
+
23
+
if expr.String() != "username" {
24
+
t.Errorf("Expected 'username', got '%s'", expr.String())
25
+
}
26
+
}
27
+
28
+
func TestValueExpr(t *testing.T) {
29
+
value := ValueExpr{inner: "test"}
30
+
31
+
if value.String() != "?" {
32
+
t.Errorf("Expected '?', got '%s'", value.String())
33
+
}
34
+
35
+
if value.Binds()[0] != "test" {
36
+
t.Errorf("Expected %q, got %v", []any{"test"}, value.Binds())
37
+
}
38
+
39
+
expr := value.AsExpr()
40
+
if expr.kind != exprKindValue {
41
+
t.Errorf("Expected exprKindValue, got %d", expr.kind)
42
+
}
43
+
}
44
+
45
+
func TestBinaryExpressions(t *testing.T) {
46
+
tests := []struct {
47
+
name string
48
+
expr Expr
49
+
expected string
50
+
}{
51
+
{"Eq", Eq("age", 25), "(age) = (?)"},
52
+
{"Neq", Neq("status", "active"), "(status) <> (?)"},
53
+
{"Gt", Gt("score", 100), "(score) > (?)"},
54
+
{"Gte", Gte("rating", 4.5), "(rating) >= (?)"},
55
+
{"Lt", Lt("count", 10), "(count) < (?)"},
56
+
{"Lte", Lte("price", 99.99), "(price) <= (?)"},
57
+
}
58
+
59
+
for _, test := range tests {
60
+
t.Run(test.name, func(t *testing.T) {
61
+
if test.expr.String() != test.expected {
62
+
t.Errorf("Expected '%s', got '%s'", test.expected, test.expr.String())
63
+
}
64
+
65
+
if test.expr.kind != exprKindBinary {
66
+
t.Errorf("Expected exprKindBinary, got %d", test.expr.kind)
67
+
}
68
+
})
69
+
}
70
+
}
71
+
72
+
func TestLogicalOperators(t *testing.T) {
73
+
left := Eq("age", 25)
74
+
right := Eq("status", "active")
75
+
76
+
andExpr := left.And(right)
77
+
expectedAnd := "((age) = (?)) and ((status) = (?))"
78
+
if andExpr.String() != expectedAnd {
79
+
t.Errorf("Expected '%s', got '%s'", expectedAnd, andExpr.String())
80
+
}
81
+
82
+
orExpr := left.Or(right)
83
+
expectedOr := "((age) = (?)) or ((status) = (?))"
84
+
if orExpr.String() != expectedOr {
85
+
t.Errorf("Expected '%s', got '%s'", expectedOr, orExpr.String())
86
+
}
87
+
}
88
+
89
+
func TestComplexExpressions(t *testing.T) {
90
+
// Test chained logical operations
91
+
age := Eq("age", 25)
92
+
status := Eq("status", "active")
93
+
score := Gt("score", 100)
94
+
95
+
complex := age.
96
+
And(status).
97
+
Or(score)
98
+
expected := "(((age) = (?)) and ((status) = (?))) or ((score) > (?))"
99
+
100
+
if complex.String() != expected {
101
+
t.Errorf("Expected '%s', got '%s'", expected, complex.String())
102
+
}
103
+
}
104
+
105
+
func TestSelectBasic(t *testing.T) {
106
+
s := Select("name", "age")
107
+
108
+
if len(s.resultColumns) != 2 {
109
+
t.Errorf("Expected 2 columns, got %d", len(s.resultColumns))
110
+
}
111
+
112
+
if s.resultColumns[0] != "name" || s.resultColumns[1] != "age" {
113
+
t.Errorf("Expected columns [name, age], got %v", s.resultColumns)
114
+
}
115
+
}
116
+
117
+
func TestSelectAPI(t *testing.T) {
118
+
// Test fluent API chaining
119
+
s := Select("*").
120
+
From("users").
121
+
Where(Eq("id", 1)).
122
+
OrderBy("name", Ascending).
123
+
GroupBy("department").
124
+
Limit(10)
125
+
126
+
if s.from != "users" {
127
+
t.Errorf("Expected from to be 'users', got '%s'", s.from)
128
+
}
129
+
130
+
if s.where == nil {
131
+
t.Error("Expected where clause to be set")
132
+
}
133
+
134
+
if len(s.orderBy) != 1 {
135
+
t.Errorf("Expected 1 order by clause, got %d", len(s.orderBy))
136
+
}
137
+
138
+
if s.orderBy[0].field != "name" || s.orderBy[0].direction != Ascending {
139
+
t.Errorf("Expected order by name ASC, got %v", s.orderBy[0])
140
+
}
141
+
142
+
if len(s.groupBy) != 1 {
143
+
t.Errorf("Expected 1 group by clause, got %d", len(s.groupBy))
144
+
}
145
+
146
+
if s.groupBy[0].field != "department" {
147
+
t.Errorf("Expected group by department, got %s", s.groupBy[0].field)
148
+
}
149
+
150
+
if s.limit == nil {
151
+
t.Error("Expected limit to be set")
152
+
}
153
+
154
+
if s.limit.limit != 10 {
155
+
t.Errorf("Expected limit 10, got %d", s.limit.limit)
156
+
}
157
+
}
158
+
159
+
func TestSelectBuild_Success(t *testing.T) {
160
+
tests := []struct {
161
+
name string
162
+
builder func() select_
163
+
expectedSql string
164
+
expectedArgs []any
165
+
}{
166
+
{
167
+
name: "Simple select",
168
+
builder: func() select_ {
169
+
return Select("name", "age").From("users")
170
+
},
171
+
expectedSql: "SELECT name, age FROM users",
172
+
expectedArgs: nil,
173
+
},
174
+
{
175
+
name: "Select with where",
176
+
builder: func() select_ {
177
+
return Select("*").
178
+
From("users").
179
+
Where(Eq("active", true))
180
+
},
181
+
expectedSql: "SELECT * FROM users WHERE (active) = (?)",
182
+
expectedArgs: []any{true},
183
+
},
184
+
{
185
+
name: "Select with order by",
186
+
builder: func() select_ {
187
+
return Select("name").
188
+
From("users").
189
+
OrderBy("name", Ascending)
190
+
},
191
+
expectedSql: "SELECT name FROM users ORDER BY name asc",
192
+
expectedArgs: nil,
193
+
},
194
+
{
195
+
name: "Select with multiple order by",
196
+
builder: func() select_ {
197
+
return Select("name", "age").
198
+
From("users").
199
+
OrderBy("name", Ascending).
200
+
OrderBy("age", Descending)
201
+
},
202
+
expectedSql: "SELECT name, age FROM users ORDER BY name asc, age desc",
203
+
expectedArgs: nil,
204
+
},
205
+
{
206
+
name: "Select with group by",
207
+
builder: func() select_ {
208
+
return Select("department", "COUNT(*)").
209
+
From("users").
210
+
GroupBy("department")
211
+
},
212
+
expectedSql: "SELECT department, COUNT(*) FROM users GROUP BY department",
213
+
expectedArgs: nil,
214
+
},
215
+
{
216
+
name: "Select with limit",
217
+
builder: func() select_ {
218
+
return Select("*").
219
+
From("users").
220
+
Limit(5)
221
+
},
222
+
expectedSql: "SELECT * FROM users LIMIT 5",
223
+
expectedArgs: nil,
224
+
},
225
+
{
226
+
name: "Complex select",
227
+
builder: func() select_ {
228
+
return Select("name", "age", "department").
229
+
From("users").
230
+
Where(Eq("active", true).And(Gt("age", 18))).
231
+
GroupBy("department").
232
+
OrderBy("name", Ascending).
233
+
Limit(10)
234
+
},
235
+
expectedSql: "SELECT name, age, department FROM users WHERE ((active) = (?)) and ((age) > (?)) GROUP BY department ORDER BY name asc LIMIT 10",
236
+
expectedArgs: []any{true, 18},
237
+
},
238
+
}
239
+
240
+
for _, test := range tests {
241
+
t.Run(test.name, func(t *testing.T) {
242
+
s := test.builder()
243
+
sql, args, err := s.Build()
244
+
245
+
if err != nil {
246
+
t.Errorf("Expected no error, got %v", err)
247
+
}
248
+
249
+
if sql != test.expectedSql {
250
+
t.Errorf("Expected '%s', got '%s'", test.expectedSql, sql)
251
+
}
252
+
253
+
if len(args) != len(test.expectedArgs) {
254
+
t.Errorf("Expected '%d' args, got '%d' args", len(test.expectedArgs), len(args))
255
+
}
256
+
257
+
for i := range len(args) {
258
+
if args[i] != test.expectedArgs[i] {
259
+
t.Errorf("Expected '%s', got '%s' at index %d", test.expectedArgs[i], args[i], i)
260
+
}
261
+
}
262
+
})
263
+
}
264
+
}
265
+
266
+
func TestSelectBuild_Errors(t *testing.T) {
267
+
tests := []struct {
268
+
name string
269
+
builder func() select_
270
+
expectedError string
271
+
}{
272
+
{
273
+
name: "No columns",
274
+
builder: func() select_ {
275
+
return Select()
276
+
},
277
+
expectedError: "result columns empty",
278
+
},
279
+
{
280
+
name: "No from clause",
281
+
builder: func() select_ {
282
+
return Select("name")
283
+
},
284
+
expectedError: "FROM clause is required",
285
+
},
286
+
{
287
+
name: "Invalid limit",
288
+
builder: func() select_ {
289
+
return Select("name").
290
+
From("users").
291
+
Limit(0)
292
+
},
293
+
expectedError: "LIMIT must be positive, got 0",
294
+
},
295
+
{
296
+
name: "Negative limit",
297
+
builder: func() select_ {
298
+
return Select("name").
299
+
From("users").
300
+
Limit(-5)
301
+
},
302
+
expectedError: "LIMIT must be positive, got -5",
303
+
},
304
+
}
305
+
306
+
for _, test := range tests {
307
+
t.Run(test.name, func(t *testing.T) {
308
+
s := test.builder()
309
+
sql, args, err := s.Build()
310
+
311
+
if err == nil {
312
+
t.Error("Expected error, got nil")
313
+
}
314
+
315
+
if err.Error() != test.expectedError {
316
+
t.Errorf("Expected error '%s', got '%s'", test.expectedError, err.Error())
317
+
}
318
+
319
+
if sql != "" {
320
+
t.Errorf("Expected empty SQL on error, got '%s'", sql)
321
+
}
322
+
323
+
if args != nil {
324
+
t.Errorf("Expected empty args on error, got '%q'", args)
325
+
}
326
+
})
327
+
}
328
+
}
329
+
330
+
func TestDirections(t *testing.T) {
331
+
if Ascending != "asc" {
332
+
t.Errorf("Expected ASC to be 'asc', got '%s'", Ascending)
333
+
}
334
+
335
+
if Descending != "desc" {
336
+
t.Errorf("Expected DESC to be 'desc', got '%s'", Descending)
337
+
}
338
+
}
339
+
340
+
func TestMultipleOrderBy(t *testing.T) {
341
+
s := Select("name", "age").
342
+
From("users").
343
+
OrderBy("name", Ascending).
344
+
OrderBy("age", Descending).
345
+
OrderBy("created_at", Descending)
346
+
347
+
if len(s.orderBy) != 3 {
348
+
t.Errorf("Expected 3 order by clauses, got %d", len(s.orderBy))
349
+
}
350
+
351
+
expected := []orderBy{
352
+
{"name", Ascending},
353
+
{"age", Descending},
354
+
{"created_at", Descending},
355
+
}
356
+
357
+
for i, ob := range s.orderBy {
358
+
if ob.field != expected[i].field || ob.direction != expected[i].direction {
359
+
t.Errorf("Expected order by %v, got %v", expected[i], ob)
360
+
}
361
+
}
362
+
}
363
+
364
+
func TestMultipleGroupBy(t *testing.T) {
365
+
s := Select("department", "status", "COUNT(*)").
366
+
From("users").
367
+
GroupBy("department").
368
+
GroupBy("status")
369
+
370
+
if len(s.groupBy) != 2 {
371
+
t.Errorf("Expected 2 group by clauses, got %d", len(s.groupBy))
372
+
}
373
+
374
+
expected := []string{"department", "status"}
375
+
376
+
for i, gb := range s.groupBy {
377
+
if gb.field != expected[i] {
378
+
t.Errorf("Expected group by %s, got %s", expected[i], gb.field)
379
+
}
380
+
}
381
+
}