+21
-35
scanner.go
+21
-35
scanner.go
···
11
11
onError func(err error)
12
12
}
13
13
14
-
type ScannerOpt[T any] func(*Scanner[T])
15
-
16
-
func OnScannerError[T any](errFunc func(error)) ScannerOpt[T] {
17
-
return func(s *Scanner[T]) {
18
-
s.onError = errFunc
19
-
}
20
-
}
21
-
22
-
func NewScanner[T any](rows *sql.Rows, opts ...ScannerOpt[T]) Scanner[T] {
23
-
scanner := Scanner[T]{
14
+
func NewScanner[T any](rows *sql.Rows) Scanner[T] {
15
+
return Scanner[T]{
24
16
rows: rows,
25
17
}
26
-
27
-
for _, o := range opts {
28
-
o(&scanner)
29
-
}
30
-
31
-
return scanner
32
18
}
33
19
34
-
func (s *Scanner[T]) Scan() iter.Seq[T] {
35
-
return func(yield func(T) bool) {
20
+
func (s *Scanner[T]) Scan() iter.Seq2[T, error] {
21
+
return func(yield func(T, error) bool) {
36
22
for s.rows.Next() {
37
23
var data T
38
24
elem := reflect.ValueOf(&data).Elem()
···
44
30
columns[i] = field.Addr().Interface()
45
31
}
46
32
47
-
if err := s.rows.Scan(columns...); err != nil {
48
-
s.onError(err)
49
-
return
50
-
}
33
+
err := s.rows.Scan(columns...)
51
34
52
-
// Yield the row - if yield returns false, stop iteration
53
-
if !yield(data) {
35
+
if !yield(data, err) {
54
36
return
55
37
}
56
38
}
···
61
43
return s.rows.Close()
62
44
}
63
45
64
-
func ScanAll[T any](rows *sql.Rows) []T {
46
+
func ScanAll[T any](rows *sql.Rows, dest *[]T) error {
65
47
scanner := NewScanner[T](rows)
66
48
defer scanner.Close()
67
49
68
-
var elems []T
69
-
for elem := range scanner.Scan() {
70
-
elems = append(elems, elem)
50
+
for elem, err := range scanner.Scan() {
51
+
if err != nil {
52
+
return err
53
+
}
54
+
*dest = append(*dest, elem)
71
55
}
72
56
73
-
return elems
57
+
return nil
74
58
}
75
59
76
-
func ScanOne[T any](rows *sql.Rows) *T {
77
-
scanner := NewScanner[T](rows)
78
-
defer scanner.Close()
60
+
func Scan[T any](row *sql.Row, dest *T) error {
61
+
elem := reflect.ValueOf(dest).Elem()
62
+
numCols := elem.NumField()
63
+
columns := make([]any, numCols)
79
64
80
-
for elem := range scanner.Scan() {
81
-
return &elem
65
+
for i := range numCols {
66
+
field := elem.Field(i)
67
+
columns[i] = field.Addr().Interface()
82
68
}
83
69
84
-
return nil
70
+
return row.Scan(columns...)
85
71
}
+28
-62
scanner_test.go
+28
-62
scanner_test.go
···
44
44
t.Fatalf("Failed to query departments: %v", err)
45
45
}
46
46
47
-
departments := ScanAll[Department](rows)
47
+
var departments []Department
48
+
err = ScanAll(rows, &departments)
49
+
if err != nil {
50
+
t.Fatalf("Failed to scan departments: %v", err)
51
+
}
48
52
49
53
if len(departments) != 3 {
50
54
t.Errorf("Expected 3 departments, got %d", len(departments))
···
70
74
t.Fatalf("Failed to query users: %v", err)
71
75
}
72
76
73
-
scanner := NewScanner[UserBasic](rows)
74
-
defer scanner.Close()
75
77
var users []UserBasic
76
-
77
-
for user := range scanner.Scan() {
78
-
users = append(users, user)
78
+
err = ScanAll(rows, &users)
79
+
if err != nil {
80
+
t.Fatalf("Failed to scan users: %v", err)
79
81
}
80
82
81
83
if len(users) != 6 {
···
100
102
t.Fatalf("Failed to query users: %v", err)
101
103
}
102
104
103
-
scanner := NewScanner[User](rows)
104
-
defer scanner.Close()
105
105
var users []User
106
-
107
-
for user := range scanner.Scan() {
108
-
users = append(users, user)
106
+
err = ScanAll(rows, &users)
107
+
if err != nil {
108
+
t.Fatalf("Failed to scan users: %v", err)
109
109
}
110
110
111
111
if len(users) != 6 {
···
133
133
}
134
134
}
135
135
136
-
func TestScannerWithErrorHandler(t *testing.T) {
137
-
db := setupTestDB(t)
138
-
defer db.Close()
139
-
140
-
rows, err := db.Query("SELECT id, name, email FROM users LIMIT 1")
141
-
if err != nil {
142
-
t.Fatalf("Failed to query users: %v", err)
143
-
}
144
-
145
-
var capturedError error
146
-
scanner := NewScanner(rows, OnScannerError[SimpleStruct](func(err error) {
147
-
capturedError = err
148
-
}))
149
-
defer scanner.Close()
150
-
151
-
var results []SimpleStruct
152
-
for result := range scanner.Scan() {
153
-
results = append(results, result)
154
-
}
155
-
156
-
if capturedError == nil {
157
-
t.Error("Expected an error to be captured, but none was")
158
-
}
159
-
160
-
if len(results) != 0 {
161
-
t.Errorf("Expected 0 results due to scan error, got %d", len(results))
162
-
}
163
-
}
164
-
165
136
func TestScannerEarlyTermination(t *testing.T) {
166
137
db := setupTestDB(t)
167
138
defer db.Close()
···
247
218
}
248
219
}
249
220
250
-
func TestScannerWithNilErrorHandler(t *testing.T) {
251
-
db := setupTestDB(t)
252
-
defer db.Close()
253
-
254
-
rows, err := db.Query("SELECT id, name, email FROM users LIMIT 1")
255
-
if err != nil {
256
-
t.Fatalf("Failed to query users: %v", err)
257
-
}
258
-
259
-
scanner := NewScanner[UserBasic](rows)
260
-
defer scanner.Close()
261
-
var users []UserBasic
262
-
263
-
for user := range scanner.Scan() {
264
-
users = append(users, user)
265
-
}
266
-
267
-
if len(users) != 1 {
268
-
t.Errorf("Expected 1 user, got %d", len(users))
269
-
}
270
-
}
271
-
272
221
func TestScannerResourceCleanup(t *testing.T) {
273
222
db := setupTestDB(t)
274
223
defer db.Close()
···
325
274
}
326
275
}
327
276
}
277
+
278
+
func TestScanOne(t *testing.T) {
279
+
db := setupTestDB(t)
280
+
defer db.Close()
281
+
282
+
row := db.QueryRow("SELECT * FROM users WHERE id = 1")
283
+
284
+
var u User
285
+
err := Scan(row, &u)
286
+
if err != nil {
287
+
t.Fatalf("Failed to scan user: %v", err)
288
+
}
289
+
290
+
if u.ID != 1 || u.Name != "John Doe" {
291
+
t.Errorf("user data incorrect: %+v", u)
292
+
}
293
+
}