// Package norm provides an ORM-free SQL experience.
1package norm
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "strings"
8)
9
// SQLiteType is the type name written after a column name in a column
// definition.
type SQLiteType string

// Type names accepted by Column. Each constant's value is the exact keyword
// emitted into the generated statement.
const (
	Integer = SQLiteType("INTEGER")
	Text    = SQLiteType("TEXT")
	Real    = SQLiteType("REAL")
	Blob    = SQLiteType("BLOB")
	Numeric = SQLiteType("NUMERIC")
)
19
// ColumnConstraint is a column-level option that can be passed to Column.
//
// The single method is unexported, so only types declared in this package can
// satisfy the interface.
type ColumnConstraint interface {
	// applyConstraint records the constraint on the given column definition.
	applyConstraint(*columnDef)
}
23
24type constraintFunc func(*columnDef)
25
26func (f constraintFunc) applyConstraint(col *columnDef) {
27 f(col)
28}
29
30var (
31 PrimaryKey = constraintFunc(func(col *columnDef) {
32 col.constraints = append(col.constraints, "PRIMARY KEY")
33 })
34
35 AutoIncrement = constraintFunc(func(col *columnDef) {
36 col.constraints = append(col.constraints, "AUTOINCREMENT")
37 })
38
39 NotNull = constraintFunc(func(col *columnDef) {
40 col.constraints = append(col.constraints, "NOT NULL")
41 })
42
43 Unique = constraintFunc(func(col *columnDef) {
44 col.constraints = append(col.constraints, "UNIQUE")
45 })
46)
47
48func Default(val any) ColumnConstraint {
49 return constraintFunc(func(col *columnDef) {
50 col.constraints = append(col.constraints, fmt.Sprintf("DEFAULT %v", val))
51 })
52}
53
54func Check(expr string) ColumnConstraint {
55 return constraintFunc(func(col *columnDef) {
56 col.constraints = append(col.constraints, fmt.Sprintf("CHECK (%s)", expr))
57 })
58}
59
60func Collate(collation string) ColumnConstraint {
61 return constraintFunc(func(col *columnDef) {
62 col.constraints = append(col.constraints, fmt.Sprintf("COLLATE %s", collation))
63 })
64}
65
// columnDef describes one column of a CREATE TABLE statement.
type columnDef struct {
	// name is the column name, written verbatim into the statement.
	name string
	// dataType is the type name written after the column name.
	dataType SQLiteType
	// constraints holds the column-level constraint clauses, in the order
	// they were applied; Compile writes them space-separated after the type.
	constraints []string
}
71
// createTable is a builder for a CREATE TABLE statement. Its methods use
// value receivers and return the modified copy, so calls can be chained.
type createTable struct {
	// table is the table name; Compile rejects an empty name.
	table string
	// ifNotExists adds "IF NOT EXISTS" before the table name.
	ifNotExists bool
	// columns lists the column definitions in declaration order; Compile
	// requires at least one.
	columns []columnDef
	// tableConstraints holds table-level constraint clauses, emitted after
	// the column definitions.
	tableConstraints []string
	// withoutRowid adds the WITHOUT ROWID table option.
	withoutRowid bool
	// strict adds the STRICT table option.
	strict bool
}
80
81func CreateTable(name string) createTable {
82 return createTable{table: name}
83}
84
85func (c createTable) IfNotExists() createTable {
86 c.ifNotExists = true
87 return c
88}
89
90func (c createTable) Column(name string, dataType SQLiteType, constraints ...ColumnConstraint) createTable {
91 col := columnDef{
92 name: name,
93 dataType: dataType,
94 }
95 for _, constraint := range constraints {
96 constraint.applyConstraint(&col)
97 }
98 c.columns = append(c.columns, col)
99 return c
100}
101
102func (c createTable) PrimaryKey(cols ...string) createTable {
103 c.tableConstraints = append(c.tableConstraints,
104 fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(cols, ", ")))
105 return c
106}
107
108func (c createTable) UniqueConstraint(cols ...string) createTable {
109 c.tableConstraints = append(c.tableConstraints,
110 fmt.Sprintf("UNIQUE (%s)", strings.Join(cols, ", ")))
111 return c
112}
113
114func (c createTable) CheckConstraint(expr string) createTable {
115 c.tableConstraints = append(c.tableConstraints,
116 fmt.Sprintf("CHECK (%s)", expr))
117 return c
118}
119
120func (c createTable) ForeignKey(col, refTable, refCol string) createTable {
121 c.tableConstraints = append(c.tableConstraints,
122 fmt.Sprintf("FOREIGN KEY (%s) REFERENCES %s(%s)", col, refTable, refCol))
123 return c
124}
125
126func (c createTable) WithoutRowid() createTable {
127 c.withoutRowid = true
128 return c
129}
130
131func (c createTable) Strict() createTable {
132 c.strict = true
133 return c
134}
135
136func (c createTable) Compile() (string, []any, error) {
137 var sql strings.Builder
138
139 sql.WriteString("CREATE TABLE ")
140
141 if c.ifNotExists {
142 sql.WriteString("IF NOT EXISTS ")
143 }
144
145 if c.table == "" {
146 return "", nil, fmt.Errorf("table name is required")
147 }
148 sql.WriteString(c.table)
149
150 if len(c.columns) == 0 {
151 return "", nil, fmt.Errorf("at least one column is required")
152 }
153
154 sql.WriteString(" (")
155
156 // Column definitions
157 for i, col := range c.columns {
158 if i > 0 {
159 sql.WriteString(", ")
160 }
161
162 sql.WriteString(col.name)
163 sql.WriteString(" ")
164 sql.WriteString(string(col.dataType))
165
166 for _, constraint := range col.constraints {
167 sql.WriteString(" ")
168 sql.WriteString(constraint)
169 }
170 }
171
172 // Table-level constraints
173 for _, constraint := range c.tableConstraints {
174 sql.WriteString(", ")
175 sql.WriteString(constraint)
176 }
177
178 sql.WriteString(")")
179
180 if c.strict {
181 sql.WriteString(" STRICT")
182 }
183
184 if c.withoutRowid {
185 sql.WriteString(" WITHOUT ROWID")
186 }
187
188 return sql.String(), nil, nil
189}
190
191func (c createTable) MustCompile() (string, []any) {
192 sql, args, err := c.Compile()
193 if err != nil {
194 panic(err)
195 }
196 return sql, args
197}
198
199func (c createTable) Build(p Database) (*sql.Stmt, []any, error) { return Build(c, p) }
200func (c createTable) MustBuild(p Database) (*sql.Stmt, []any) { return MustBuild(c, p) }
201
202func (c createTable) Exec(p Database) (sql.Result, error) { return Exec(c, p) }
203func (c createTable) ExecContext(ctx context.Context, p Database) (sql.Result, error) {
204 return ExecContext(ctx, c, p)
205}
206func (c createTable) MustExec(p Database) sql.Result { return MustExec(c, p) }
207func (c createTable) MustExecContext(ctx context.Context, p Database) sql.Result {
208 return MustExecContext(ctx, c, p)
209}