an ORM-free SQL experience
at main 209 lines 4.8 kB view raw
1package norm 2 3import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "strings" 8) 9 10type SQLiteType string 11 12const ( 13 Integer SQLiteType = "INTEGER" 14 Text SQLiteType = "TEXT" 15 Real SQLiteType = "REAL" 16 Blob SQLiteType = "BLOB" 17 Numeric SQLiteType = "NUMERIC" 18) 19 20type ColumnConstraint interface { 21 applyConstraint(*columnDef) 22} 23 24type constraintFunc func(*columnDef) 25 26func (f constraintFunc) applyConstraint(col *columnDef) { 27 f(col) 28} 29 30var ( 31 PrimaryKey = constraintFunc(func(col *columnDef) { 32 col.constraints = append(col.constraints, "PRIMARY KEY") 33 }) 34 35 AutoIncrement = constraintFunc(func(col *columnDef) { 36 col.constraints = append(col.constraints, "AUTOINCREMENT") 37 }) 38 39 NotNull = constraintFunc(func(col *columnDef) { 40 col.constraints = append(col.constraints, "NOT NULL") 41 }) 42 43 Unique = constraintFunc(func(col *columnDef) { 44 col.constraints = append(col.constraints, "UNIQUE") 45 }) 46) 47 48func Default(val any) ColumnConstraint { 49 return constraintFunc(func(col *columnDef) { 50 col.constraints = append(col.constraints, fmt.Sprintf("DEFAULT %v", val)) 51 }) 52} 53 54func Check(expr string) ColumnConstraint { 55 return constraintFunc(func(col *columnDef) { 56 col.constraints = append(col.constraints, fmt.Sprintf("CHECK (%s)", expr)) 57 }) 58} 59 60func Collate(collation string) ColumnConstraint { 61 return constraintFunc(func(col *columnDef) { 62 col.constraints = append(col.constraints, fmt.Sprintf("COLLATE %s", collation)) 63 }) 64} 65 66type columnDef struct { 67 name string 68 dataType SQLiteType 69 constraints []string 70} 71 72type createTable struct { 73 table string 74 ifNotExists bool 75 columns []columnDef 76 tableConstraints []string 77 withoutRowid bool 78 strict bool 79} 80 81func CreateTable(name string) createTable { 82 return createTable{table: name} 83} 84 85func (c createTable) IfNotExists() createTable { 86 c.ifNotExists = true 87 return c 88} 89 90func (c createTable) Column(name string, dataType SQLiteType, constraints ...ColumnConstraint) createTable { 91 col := columnDef{ 92 name: name, 93 dataType: dataType, 94 } 95 for _, constraint := range constraints { 96 constraint.applyConstraint(&col) 97 } 98 c.columns = append(c.columns, col) 99 return c 100} 101 102func (c createTable) PrimaryKey(cols ...string) createTable { 103 c.tableConstraints = append(c.tableConstraints, 104 fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(cols, ", "))) 105 return c 106} 107 108func (c createTable) UniqueConstraint(cols ...string) createTable { 109 c.tableConstraints = append(c.tableConstraints, 110 fmt.Sprintf("UNIQUE (%s)", strings.Join(cols, ", "))) 111 return c 112} 113 114func (c createTable) CheckConstraint(expr string) createTable { 115 c.tableConstraints = append(c.tableConstraints, 116 fmt.Sprintf("CHECK (%s)", expr)) 117 return c 118} 119 120func (c createTable) ForeignKey(col, refTable, refCol string) createTable { 121 c.tableConstraints = append(c.tableConstraints, 122 fmt.Sprintf("FOREIGN KEY (%s) REFERENCES %s(%s)", col, refTable, refCol)) 123 return c 124} 125 126func (c createTable) WithoutRowid() createTable { 127 c.withoutRowid = true 128 return c 129} 130 131func (c createTable) Strict() createTable { 132 c.strict = true 133 return c 134} 135 136func (c createTable) Compile() (string, []any, error) { 137 var sql strings.Builder 138 139 sql.WriteString("CREATE TABLE ") 140 141 if c.ifNotExists { 142 sql.WriteString("IF NOT EXISTS ") 143 } 144 145 if c.table == "" { 146 return "", nil, fmt.Errorf("table name is required") 147 } 148 sql.WriteString(c.table) 149 150 if len(c.columns) == 0 { 151 return "", nil, fmt.Errorf("at least one column is required") 152 } 153 154 sql.WriteString(" (") 155 156 // Column definitions 157 for i, col := range c.columns { 158 if i > 0 { 159 sql.WriteString(", ") 160 } 161 162 sql.WriteString(col.name) 163 sql.WriteString(" ") 164 sql.WriteString(string(col.dataType)) 165 166 for _, constraint := range col.constraints { 167 sql.WriteString(" ") 168 sql.WriteString(constraint) 169 } 170 } 171 172 // Table-level constraints 173 for _, constraint := range c.tableConstraints { 174 sql.WriteString(", ") 175 sql.WriteString(constraint) 176 } 177 178 sql.WriteString(")") 179 180 if c.strict { 181 sql.WriteString(" STRICT") 182 } 183 184 if c.withoutRowid { 185 sql.WriteString(" WITHOUT ROWID") 186 } 187 188 return sql.String(), nil, nil 189} 190 191func (c createTable) MustCompile() (string, []any) { 192 sql, args, err := c.Compile() 193 if err != nil { 194 panic(err) 195 } 196 return sql, args 197} 198 199func (c createTable) Build(p Database) (*sql.Stmt, []any, error) { return Build(c, p) } 200func (c createTable) MustBuild(p Database) (*sql.Stmt, []any) { return MustBuild(c, p) } 201 202func (c createTable) Exec(p Database) (sql.Result, error) { return Exec(c, p) } 203func (c createTable) ExecContext(ctx context.Context, p Database) (sql.Result, error) { 204 return ExecContext(ctx, c, p) 205} 206func (c createTable) MustExec(p Database) sql.Result { return MustExec(c, p) } 207func (c createTable) MustExecContext(ctx context.Context, p Database) sql.Result { 208 return MustExecContext(ctx, c, p) 209}