package norm

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
)

// SQLiteType is a column type name written verbatim into a column definition.
type SQLiteType string

// Column type names for use with createTable.Column.
//
// Note: SQLite only accepts INT, INTEGER, REAL, TEXT, BLOB and ANY as column
// types in STRICT tables, so Numeric cannot be combined with createTable.Strict.
const (
	Integer SQLiteType = "INTEGER"
	Text    SQLiteType = "TEXT"
	Real    SQLiteType = "REAL"
	Blob    SQLiteType = "BLOB"
	Numeric SQLiteType = "NUMERIC"
)

// ColumnConstraint adds a clause to a column definition built by
// createTable.Column.
type ColumnConstraint interface {
	applyConstraint(*columnDef)
}

// constraintFunc adapts a plain function to the ColumnConstraint interface.
type constraintFunc func(*columnDef)

func (f constraintFunc) applyConstraint(col *columnDef) { f(col) }

// Argument-free column constraints. Each appends its keyword(s) to the column
// definition; clauses are emitted in the order the constraints are passed to
// Column.
var (
	// PrimaryKey appends "PRIMARY KEY" to the column.
	PrimaryKey = constraintFunc(func(col *columnDef) {
		col.constraints = append(col.constraints, "PRIMARY KEY")
	})
	// AutoIncrement appends "AUTOINCREMENT" to the column.
	AutoIncrement = constraintFunc(func(col *columnDef) {
		col.constraints = append(col.constraints, "AUTOINCREMENT")
	})
	// NotNull appends "NOT NULL" to the column.
	NotNull = constraintFunc(func(col *columnDef) {
		col.constraints = append(col.constraints, "NOT NULL")
	})
	// Unique appends "UNIQUE" to the column.
	Unique = constraintFunc(func(col *columnDef) {
		col.constraints = append(col.constraints, "UNIQUE")
	})
)

// Default appends a "DEFAULT <val>" clause to the column.
//
// A nil val renders as NULL and a []byte renders as a hex blob literal
// (X'...'). Every other value, including strings, is formatted with %v and
// inserted verbatim. String literals must therefore carry their own quotes
// (Default("'abc'")), while expressions such as Default("CURRENT_TIMESTAMP")
// work as-is. The value is not escaped, so it must not come from untrusted
// input.
func Default(val any) ColumnConstraint {
	var lit string
	switch v := val.(type) {
	case nil:
		// %v would print "<nil>", which is not valid SQL.
		lit = "NULL"
	case []byte:
		// %v would print "[1 2 3]"; SQLite blob literals are X'hex'.
		lit = fmt.Sprintf("X'%X'", v)
	default:
		lit = fmt.Sprint(v)
	}
	return constraintFunc(func(col *columnDef) {
		col.constraints = append(col.constraints, "DEFAULT "+lit)
	})
}

// Check appends a "CHECK (expr)" clause to the column. expr is inserted
// verbatim.
func Check(expr string) ColumnConstraint {
	return constraintFunc(func(col *columnDef) {
		col.constraints = append(col.constraints, fmt.Sprintf("CHECK (%s)", expr))
	})
}

// Collate appends a "COLLATE collation" clause to the column.
func Collate(collation string) ColumnConstraint {
	return constraintFunc(func(col *columnDef) {
		col.constraints = append(col.constraints, fmt.Sprintf("COLLATE %s", collation))
	})
}

// columnDef is one column of a CREATE TABLE statement.
type columnDef struct {
	name        string
	dataType    SQLiteType
	constraints []string // rendered clauses, emitted after the type in order
}

// createTable builds a CREATE TABLE statement. Every builder method takes and
// returns the builder by value and never writes into storage shared with its
// receiver. A partially configured createTable can therefore serve as a base
// for several independent statements.
//
// Table and column names are written verbatim (not quoted), so they must not
// come from untrusted input.
type createTable struct {
	table            string
	ifNotExists      bool
	columns          []columnDef
	tableConstraints []string
	withoutRowid     bool
	strict           bool
	err              error // first invalid builder call; reported by Compile
}

// CreateTable starts a CREATE TABLE statement for the table called name.
func CreateTable(name string) createTable {
	return createTable{table: name}
}

// IfNotExists adds "IF NOT EXISTS" to the statement.
func (c createTable) IfNotExists() createTable {
	c.ifNotExists = true
	return c
}

// Column appends a column called name of type dataType, followed by the given
// constraints in order. Nil constraints are ignored. An empty dataType omits
// the type, which SQLite permits outside STRICT tables.
func (c createTable) Column(name string, dataType SQLiteType, constraints ...ColumnConstraint) createTable {
	col := columnDef{name: name, dataType: dataType}
	for _, cons := range constraints {
		if cons == nil {
			continue
		}
		cons.applyConstraint(&col)
	}
	// Capping capacity at the length forces append to copy. Otherwise two
	// builders derived from the same base could write into one backing array
	// and overwrite each other's columns.
	n := len(c.columns)
	c.columns = append(c.columns[:n:n], col)
	return c
}

// withTableConstraint returns a copy of c with clause appended to the
// table-level constraints, without sharing the backing array with c.
func (c createTable) withTableConstraint(clause string) createTable {
	n := len(c.tableConstraints)
	c.tableConstraints = append(c.tableConstraints[:n:n], clause)
	return c
}

// withErr returns a copy of c carrying err unless an earlier builder call
// already recorded one; Compile reports the first recorded error.
func (c createTable) withErr(err error) createTable {
	if c.err == nil {
		c.err = err
	}
	return c
}

// PrimaryKey adds a table-level "PRIMARY KEY (cols...)" constraint.
// Compile fails if no columns are given.
func (c createTable) PrimaryKey(cols ...string) createTable {
	if len(cols) == 0 {
		return c.withErr(fmt.Errorf("primary key requires at least one column"))
	}
	return c.withTableConstraint(fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(cols, ", ")))
}

// UniqueConstraint adds a table-level "UNIQUE (cols...)" constraint.
// Compile fails if no columns are given.
func (c createTable) UniqueConstraint(cols ...string) createTable {
	if len(cols) == 0 {
		return c.withErr(fmt.Errorf("unique constraint requires at least one column"))
	}
	return c.withTableConstraint(fmt.Sprintf("UNIQUE (%s)", strings.Join(cols, ", ")))
}

// CheckConstraint adds a table-level "CHECK (expr)" constraint. expr is
// inserted verbatim; Compile fails if it is empty.
func (c createTable) CheckConstraint(expr string) createTable {
	if expr == "" {
		return c.withErr(fmt.Errorf("check constraint requires an expression"))
	}
	return c.withTableConstraint(fmt.Sprintf("CHECK (%s)", expr))
}

// ForeignKey adds "FOREIGN KEY (col) REFERENCES refTable(refCol)".
//
// If refCol is empty, the referenced column list is omitted, so SQLite uses
// the parent table's primary key. col and refTable are required; Compile fails
// if either is empty.
func (c createTable) ForeignKey(col, refTable, refCol string) createTable {
	if col == "" || refTable == "" {
		return c.withErr(fmt.Errorf("foreign key requires a column and a referenced table"))
	}
	clause := fmt.Sprintf("FOREIGN KEY (%s) REFERENCES %s", col, refTable)
	if refCol != "" {
		clause += "(" + refCol + ")"
	}
	return c.withTableConstraint(clause)
}

// WithoutRowid adds the "WITHOUT ROWID" table option.
func (c createTable) WithoutRowid() createTable {
	c.withoutRowid = true
	return c
}

// Strict adds the "STRICT" table option.
func (c createTable) Strict() createTable {
	c.strict = true
	return c
}

// Compile renders the statement.
//
// It returns an error if the table name is empty, no columns were added, a
// column has an empty name, or an earlier builder call was invalid. The
// returned args are always nil, because all values are rendered inline.
func (c createTable) Compile() (string, []any, error) {
	if c.table == "" {
		return "", nil, fmt.Errorf("table name is required")
	}
	if len(c.columns) == 0 {
		return "", nil, fmt.Errorf("at least one column is required")
	}
	if c.err != nil {
		return "", nil, c.err
	}

	var b strings.Builder
	b.WriteString("CREATE TABLE ")
	if c.ifNotExists {
		b.WriteString("IF NOT EXISTS ")
	}
	b.WriteString(c.table)
	b.WriteString(" (")

	// Column definitions.
	for i, col := range c.columns {
		if col.name == "" {
			return "", nil, fmt.Errorf("column %d has no name", i+1)
		}
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(col.name)
		if col.dataType != "" {
			b.WriteString(" ")
			b.WriteString(string(col.dataType))
		}
		for _, cons := range col.constraints {
			b.WriteString(" ")
			b.WriteString(cons)
		}
	}

	// Table-level constraints follow the columns inside the parentheses.
	for _, cons := range c.tableConstraints {
		b.WriteString(", ")
		b.WriteString(cons)
	}
	b.WriteString(")")

	// SQLite separates table options with commas: ") STRICT, WITHOUT ROWID".
	var opts []string
	if c.strict {
		opts = append(opts, "STRICT")
	}
	if c.withoutRowid {
		opts = append(opts, "WITHOUT ROWID")
	}
	if len(opts) > 0 {
		b.WriteString(" ")
		b.WriteString(strings.Join(opts, ", "))
	}

	return b.String(), nil, nil
}

// MustCompile is like Compile but panics if Compile returns an error.
func (c createTable) MustCompile() (string, []any) {
	query, args, err := c.Compile()
	if err != nil {
		panic(err)
	}
	return query, args
}

// Build passes c and p to the package-level Build helper.
func (c createTable) Build(p Database) (*sql.Stmt, []any, error) {
	return Build(c, p)
}

// MustBuild passes c and p to the package-level MustBuild helper.
func (c createTable) MustBuild(p Database) (*sql.Stmt, []any) {
	return MustBuild(c, p)
}

// Exec runs c against p via the package-level Exec helper.
func (c createTable) Exec(p Database) (sql.Result, error) {
	return Exec(c, p)
}

// ExecContext runs c against p with ctx via the package-level ExecContext
// helper.
func (c createTable) ExecContext(ctx context.Context, p Database) (sql.Result, error) {
	return ExecContext(ctx, c, p)
}

// MustExec runs c against p via the package-level MustExec helper.
func (c createTable) MustExec(p Database) sql.Result {
	return MustExec(c, p)
}

// MustExecContext runs c against p with ctx via the package-level
// MustExecContext helper.
func (c createTable) MustExecContext(ctx context.Context, p Database) sql.Result {
	return MustExecContext(ctx, c, p)
}