package norm import ( "context" "database/sql" "fmt" "strconv" "strings" ) type select_ struct { resultColumns []string // TODO: strongly type result columns from string // table-or-subquery expr where *Expr orderBy []orderBy groupBy []groupBy limit *limit } type orderBy struct { field string direction Direction } type groupBy struct { field string } type limit struct { limit int } func Select(cols ...string) select_ { return select_{ resultColumns: cols, } } type SelectOpt func(s *select_) func (s select_) From(table string) select_ { s.from = table return s } func (s select_) Where(expr Expr) select_ { s.where = &expr return s } func (s select_) OrderBy(field string, direction Direction) select_ { s.orderBy = append(s.orderBy, orderBy{ field: field, direction: direction, }) return s } func (s select_) GroupBy(field string) select_ { s.groupBy = append(s.groupBy, groupBy{ field: field, }) return s } func (s select_) Limit(i int) select_ { s.limit = &limit{ limit: i, } return s } func (s select_) Compile() (string, []any, error) { var sql strings.Builder var args []any sql.WriteString("SELECT ") if len(s.resultColumns) == 0 { return "", nil, fmt.Errorf("result columns empty") } else { for i, col := range s.resultColumns { if i > 0 { sql.WriteString(", ") } sql.WriteString(col) } } if s.from == "" { return "", nil, fmt.Errorf("FROM clause is required") } sql.WriteString(" FROM ") sql.WriteString(s.from) if s.where != nil { sql.WriteString(" WHERE ") sql.WriteString(s.where.String()) args = s.where.Binds() } // GROUP BY clause if len(s.groupBy) > 0 { sql.WriteString(" GROUP BY ") for i, gb := range s.groupBy { if i > 0 { sql.WriteString(", ") } sql.WriteString(gb.field) } } // ORDER BY clause if len(s.orderBy) > 0 { sql.WriteString(" ORDER BY ") for i, ob := range s.orderBy { if i > 0 { sql.WriteString(", ") } sql.WriteString(ob.field) sql.WriteString(" ") sql.WriteString(string(ob.direction)) } } // LIMIT clause if s.limit != nil { if s.limit.limit <= 0 { return "", nil, fmt.Errorf("LIMIT must be positive, got %d", s.limit.limit) } sql.WriteString(" LIMIT ") sql.WriteString(strconv.Itoa(s.limit.limit)) } return sql.String(), args, nil } func (s select_) MustCompile() (string, []any) { sql, args, err := s.Compile() if err != nil { panic(err) } return sql, args } func (s select_) Build(p Database) (*sql.Stmt, []any, error) { return Build(s, p) } func (s select_) MustBuild(p Database) (*sql.Stmt, []any) { return MustBuild(s, p) } func (s select_) Exec(p Database) (sql.Result, error) { return Exec(s, p) } func (s select_) ExecContext(ctx context.Context, p Database) (sql.Result, error) { return ExecContext(ctx, s, p) } func (s select_) MustExec(p Database) sql.Result { return MustExec(s, p) } func (s select_) MustExecContext(ctx context.Context, p Database) sql.Result { return MustExecContext(ctx, s, p) } func (s select_) Query(p Database) (*sql.Rows, error) { return Query(s, p) } func (s select_) QueryContext(ctx context.Context, p Database) (*sql.Rows, error) { return QueryContext(ctx, s, p) } func (s select_) QueryRow(p Database) (*sql.Row, error) { return QueryRow(s, p) } func (s select_) QueryRowContext(ctx context.Context, p Database) (*sql.Row, error) { return QueryRowContext(ctx, s, p) } func (s select_) MustQuery(p Database) *sql.Rows { return MustQuery(s, p) } func (s select_) MustQueryContext(ctx context.Context, p Database) *sql.Rows { return MustQueryContext(ctx, s, p) } func (s select_) MustQueryRow(p Database) *sql.Row { return MustQueryRow(s, p) } func (s select_) MustQueryRowContext(ctx context.Context, p Database) *sql.Row { return MustQueryRowContext(ctx, s, p) }