1package norm
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "strconv"
8 "strings"
9)
10
11type select_ struct {
12 resultColumns []string // TODO: strongly type result columns
13 from string // table-or-subquery expr
14 where *Expr
15 orderBy []orderBy
16 groupBy []groupBy
17 limit *limit
18}
19
20type orderBy struct {
21 field string
22 direction Direction
23}
24
25type groupBy struct {
26 field string
27}
28
29type limit struct {
30 limit int
31}
32
33func Select(cols ...string) select_ {
34 return select_{
35 resultColumns: cols,
36 }
37}
38
// SelectOpt is a function that receives a pointer to a select_ query and can
// modify it in place.
//
// NOTE(review): nothing in this file accepts or applies a SelectOpt. Confirm
// it is used elsewhere in the package before relying on it.
type SelectOpt func(s *select_)
40
41func (s select_) From(table string) select_ {
42 s.from = table
43 return s
44}
45
46func (s select_) Where(expr Expr) select_ {
47 s.where = &expr
48 return s
49}
50
51func (s select_) OrderBy(field string, direction Direction) select_ {
52 s.orderBy = append(s.orderBy, orderBy{
53 field: field,
54 direction: direction,
55 })
56 return s
57}
58
59func (s select_) GroupBy(field string) select_ {
60 s.groupBy = append(s.groupBy, groupBy{
61 field: field,
62 })
63 return s
64}
65
66func (s select_) Limit(i int) select_ {
67 s.limit = &limit{
68 limit: i,
69 }
70 return s
71}
72
73func (s select_) Compile() (string, []any, error) {
74 var sql strings.Builder
75 var args []any
76
77 sql.WriteString("SELECT ")
78 if len(s.resultColumns) == 0 {
79 return "", nil, fmt.Errorf("result columns empty")
80 } else {
81 for i, col := range s.resultColumns {
82 if i > 0 {
83 sql.WriteString(", ")
84 }
85 sql.WriteString(col)
86 }
87 }
88
89 if s.from == "" {
90 return "", nil, fmt.Errorf("FROM clause is required")
91 }
92 sql.WriteString(" FROM ")
93 sql.WriteString(s.from)
94
95 if s.where != nil {
96 sql.WriteString(" WHERE ")
97 sql.WriteString(s.where.String())
98
99 args = s.where.Binds()
100 }
101
102 // GROUP BY clause
103 if len(s.groupBy) > 0 {
104 sql.WriteString(" GROUP BY ")
105 for i, gb := range s.groupBy {
106 if i > 0 {
107 sql.WriteString(", ")
108 }
109 sql.WriteString(gb.field)
110 }
111 }
112
113 // ORDER BY clause
114 if len(s.orderBy) > 0 {
115 sql.WriteString(" ORDER BY ")
116 for i, ob := range s.orderBy {
117 if i > 0 {
118 sql.WriteString(", ")
119 }
120 sql.WriteString(ob.field)
121 sql.WriteString(" ")
122 sql.WriteString(string(ob.direction))
123 }
124 }
125
126 // LIMIT clause
127 if s.limit != nil {
128 if s.limit.limit <= 0 {
129 return "", nil, fmt.Errorf("LIMIT must be positive, got %d", s.limit.limit)
130 }
131 sql.WriteString(" LIMIT ")
132 sql.WriteString(strconv.Itoa(s.limit.limit))
133 }
134
135 return sql.String(), args, nil
136}
137
138func (s select_) MustCompile() (string, []any) {
139 sql, args, err := s.Compile()
140 if err != nil {
141 panic(err)
142 }
143
144 return sql, args
145}
146
147func (s select_) Build(p Database) (*sql.Stmt, []any, error) { return Build(s, p) }
148func (s select_) MustBuild(p Database) (*sql.Stmt, []any) { return MustBuild(s, p) }
149
150func (s select_) Exec(p Database) (sql.Result, error) { return Exec(s, p) }
151func (s select_) ExecContext(ctx context.Context, p Database) (sql.Result, error) {
152 return ExecContext(ctx, s, p)
153}
154func (s select_) MustExec(p Database) sql.Result { return MustExec(s, p) }
155func (s select_) MustExecContext(ctx context.Context, p Database) sql.Result {
156 return MustExecContext(ctx, s, p)
157}
158
159func (s select_) Query(p Database) (*sql.Rows, error) { return Query(s, p) }
160func (s select_) QueryContext(ctx context.Context, p Database) (*sql.Rows, error) {
161 return QueryContext(ctx, s, p)
162}
163func (s select_) QueryRow(p Database) (*sql.Row, error) { return QueryRow(s, p) }
164func (s select_) QueryRowContext(ctx context.Context, p Database) (*sql.Row, error) {
165 return QueryRowContext(ctx, s, p)
166}
167
168func (s select_) MustQuery(p Database) *sql.Rows { return MustQuery(s, p) }
169func (s select_) MustQueryContext(ctx context.Context, p Database) *sql.Rows {
170 return MustQueryContext(ctx, s, p)
171}
172func (s select_) MustQueryRow(p Database) *sql.Row { return MustQueryRow(s, p) }
173func (s select_) MustQueryRowContext(ctx context.Context, p Database) *sql.Row {
174 return MustQueryRowContext(ctx, s, p)
175}