1package norm
2
3import (
4 "database/sql"
5 "iter"
6 "reflect"
7)
8
9type Scanner[T any] struct {
10 rows *sql.Rows
11 onError func(err error)
12}
13
14func NewScanner[T any](rows *sql.Rows) Scanner[T] {
15 return Scanner[T]{
16 rows: rows,
17 }
18}
19
// Scan returns an iterator that decodes each remaining row into a T.
//
// When T is a struct whose fields are all exported, the columns are stored
// into those fields in declaration order. The query must therefore select
// exactly one column per field, in that order. sql.Null* types have only
// exported fields and are expanded in the same way. Any other T (a scalar,
// time.Time, a struct with unexported fields that implements sql.Scanner, ...)
// is passed to rows.Scan as a single destination, i.e. read from one column.
//
// Every yielded value starts from the zero T. A scan error is yielded together
// with that row's (possibly partially filled) value, and iteration continues
// unless the consumer stops. If reading the rows fails part-way through (an
// error reported by rows.Err), that error is yielded once as the final element,
// paired with the zero T. The iterator does not close the rows; use Close.
func (s *Scanner[T]) Scan() iter.Seq2[T, error] {
	return func(yield func(T, error) bool) {
		var zero, data T

		// Resolve the scan destinations once. They point into data, which is
		// reset and reused for every row. yield receives a copy, so no
		// per-row reflection or allocation is needed.
		targets := []any{&data}
		if v := reflect.ValueOf(&data).Elem(); v.Kind() == reflect.Struct {
			typ := v.Type()
			fields := make([]any, v.NumField())
			expand := true
			for i := range fields {
				// Value.Interface panics on unexported fields, so such structs
				// (e.g. time.Time) are scanned as one opaque value instead.
				if !typ.Field(i).IsExported() {
					expand = false
					break
				}
				fields[i] = v.Field(i).Addr().Interface()
			}
			if expand {
				targets = fields
			}
		}

		for s.rows.Next() {
			data = zero // don't leak values from the previous row
			err := s.rows.Scan(targets...)
			if !yield(data, err) {
				return
			}
		}

		// Next also reports false when the driver fails mid-stream. Only Err
		// tells that apart from a clean end of the result set.
		if err := s.rows.Err(); err != nil {
			yield(zero, err)
		}
	}
}
41
42func (s *Scanner[T]) Close() error {
43 return s.rows.Close()
44}
45
// ScanAll decodes every remaining row of rows with a Scanner and appends the
// values to *dest, keeping whatever *dest already held. dest must not be nil.
// rows is closed before ScanAll returns.
//
// It returns the first error yielded by the scanner. Failing that, it returns
// any iteration error reported by rows.Err, and failing that, the error from
// closing rows. On error, *dest retains the values appended before the failure.
func ScanAll[T any](rows *sql.Rows, dest *[]T) (err error) {
	scanner := NewScanner[T](rows)
	defer func() {
		// Report a Close failure, but never let it mask an earlier error.
		if closeErr := scanner.Close(); err == nil {
			err = closeErr
		}
	}()

	for elem, scanErr := range scanner.Scan() {
		if scanErr != nil {
			return scanErr
		}
		*dest = append(*dest, elem)
	}

	// Next stops on driver errors as well as at the end of the data. Err is
	// the only way to tell the two apart.
	return rows.Err()
}
59
// Scan reads a single row into *dest.
//
// When T is a struct whose fields are all exported, the columns are stored
// into those fields in declaration order. The query must therefore select
// exactly one column per field, in that order. sql.Null* types have only
// exported fields and are expanded in the same way. Any other T (a scalar,
// time.Time, a struct with unexported fields that implements sql.Scanner, ...)
// is passed to row.Scan as a single destination, i.e. read from one column.
//
// The returned error is whatever row.Scan reports. This includes
// sql.ErrNoRows when the query matched nothing, and an error (rather than a
// panic) when dest is nil.
func Scan[T any](row *sql.Row, dest *T) error {
	if dest == nil {
		// Let database/sql reject the nil destination instead of panicking
		// in reflect below.
		return row.Scan(dest)
	}

	v := reflect.ValueOf(dest).Elem()
	if v.Kind() != reflect.Struct {
		return row.Scan(dest)
	}

	typ := v.Type()
	fields := make([]any, v.NumField())
	for i := range fields {
		// Value.Interface panics on unexported fields, so such structs
		// (e.g. time.Time) are scanned as one opaque value instead.
		if !typ.Field(i).IsExported() {
			return row.Scan(dest)
		}
		fields[i] = v.Field(i).Addr().Interface()
	}
	return row.Scan(fields...)
}