an ORM-free SQL experience

rework scanner API to handle errors better

Signed-off-by: oppiliappan <me@oppi.li>

oppi.li 4d7450dc 0d2a557e

verified
Changed files
+49 -97
+21 -35
scanner.go
··· 11 onError func(err error) 12 } 13 14 - type ScannerOpt[T any] func(*Scanner[T]) 15 - 16 - func OnScannerError[T any](errFunc func(error)) ScannerOpt[T] { 17 - return func(s *Scanner[T]) { 18 - s.onError = errFunc 19 - } 20 - } 21 - 22 - func NewScanner[T any](rows *sql.Rows, opts ...ScannerOpt[T]) Scanner[T] { 23 - scanner := Scanner[T]{ 24 rows: rows, 25 } 26 - 27 - for _, o := range opts { 28 - o(&scanner) 29 - } 30 - 31 - return scanner 32 } 33 34 - func (s *Scanner[T]) Scan() iter.Seq[T] { 35 - return func(yield func(T) bool) { 36 for s.rows.Next() { 37 var data T 38 elem := reflect.ValueOf(&data).Elem() ··· 44 columns[i] = field.Addr().Interface() 45 } 46 47 - if err := s.rows.Scan(columns...); err != nil { 48 - s.onError(err) 49 - return 50 - } 51 52 - // Yield the row - if yield returns false, stop iteration 53 - if !yield(data) { 54 return 55 } 56 } ··· 61 return s.rows.Close() 62 } 63 64 - func ScanAll[T any](rows *sql.Rows) []T { 65 scanner := NewScanner[T](rows) 66 defer scanner.Close() 67 68 - var elems []T 69 - for elem := range scanner.Scan() { 70 - elems = append(elems, elem) 71 } 72 73 - return elems 74 } 75 76 - func ScanOne[T any](rows *sql.Rows) *T { 77 - scanner := NewScanner[T](rows) 78 - defer scanner.Close() 79 80 - for elem := range scanner.Scan() { 81 - return &elem 82 } 83 84 - return nil 85 }
··· 11 onError func(err error) 12 } 13 14 + func NewScanner[T any](rows *sql.Rows) Scanner[T] { 15 + return Scanner[T]{ 16 rows: rows, 17 } 18 } 19 20 + func (s *Scanner[T]) Scan() iter.Seq2[T, error] { 21 + return func(yield func(T, error) bool) { 22 for s.rows.Next() { 23 var data T 24 elem := reflect.ValueOf(&data).Elem() ··· 30 columns[i] = field.Addr().Interface() 31 } 32 33 + err := s.rows.Scan(columns...) 34 35 + if !yield(data, err) { 36 return 37 } 38 } ··· 43 return s.rows.Close() 44 } 45 46 + func ScanAll[T any](rows *sql.Rows, dest *[]T) error { 47 scanner := NewScanner[T](rows) 48 defer scanner.Close() 49 50 + for elem, err := range scanner.Scan() { 51 + if err != nil { 52 + return err 53 + } 54 + *dest = append(*dest, elem) 55 } 56 57 + return nil 58 } 59 60 + func Scan[T any](row *sql.Row, dest *T) error { 61 + elem := reflect.ValueOf(dest).Elem() 62 + numCols := elem.NumField() 63 + columns := make([]any, numCols) 64 65 + for i := range numCols { 66 + field := elem.Field(i) 67 + columns[i] = field.Addr().Interface() 68 } 69 70 + return row.Scan(columns...) 71 }
+28 -62
scanner_test.go
··· 44 t.Fatalf("Failed to query departments: %v", err) 45 } 46 47 - departments := ScanAll[Department](rows) 48 49 if len(departments) != 3 { 50 t.Errorf("Expected 3 departments, got %d", len(departments)) ··· 70 t.Fatalf("Failed to query users: %v", err) 71 } 72 73 - scanner := NewScanner[UserBasic](rows) 74 - defer scanner.Close() 75 var users []UserBasic 76 - 77 - for user := range scanner.Scan() { 78 - users = append(users, user) 79 } 80 81 if len(users) != 6 { ··· 100 t.Fatalf("Failed to query users: %v", err) 101 } 102 103 - scanner := NewScanner[User](rows) 104 - defer scanner.Close() 105 var users []User 106 - 107 - for user := range scanner.Scan() { 108 - users = append(users, user) 109 } 110 111 if len(users) != 6 { ··· 133 } 134 } 135 136 - func TestScannerWithErrorHandler(t *testing.T) { 137 - db := setupTestDB(t) 138 - defer db.Close() 139 - 140 - rows, err := db.Query("SELECT id, name, email FROM users LIMIT 1") 141 - if err != nil { 142 - t.Fatalf("Failed to query users: %v", err) 143 - } 144 - 145 - var capturedError error 146 - scanner := NewScanner(rows, OnScannerError[SimpleStruct](func(err error) { 147 - capturedError = err 148 - })) 149 - defer scanner.Close() 150 - 151 - var results []SimpleStruct 152 - for result := range scanner.Scan() { 153 - results = append(results, result) 154 - } 155 - 156 - if capturedError == nil { 157 - t.Error("Expected an error to be captured, but none was") 158 - } 159 - 160 - if len(results) != 0 { 161 - t.Errorf("Expected 0 results due to scan error, got %d", len(results)) 162 - } 163 - } 164 - 165 func TestScannerEarlyTermination(t *testing.T) { 166 db := setupTestDB(t) 167 defer db.Close() ··· 247 } 248 } 249 250 - func TestScannerWithNilErrorHandler(t *testing.T) { 251 - db := setupTestDB(t) 252 - defer db.Close() 253 - 254 - rows, err := db.Query("SELECT id, name, email FROM users LIMIT 1") 255 - if err != nil { 256 - t.Fatalf("Failed to query users: %v", err) 257 - } 258 - 259 - scanner := NewScanner[UserBasic](rows) 260 - defer scanner.Close() 261 - var users []UserBasic 262 - 263 - for user := range scanner.Scan() { 264 - users = append(users, user) 265 - } 266 - 267 - if len(users) != 1 { 268 - t.Errorf("Expected 1 user, got %d", len(users)) 269 - } 270 - } 271 - 272 func TestScannerResourceCleanup(t *testing.T) { 273 db := setupTestDB(t) 274 defer db.Close() ··· 325 } 326 } 327 }
··· 44 t.Fatalf("Failed to query departments: %v", err) 45 } 46 47 + var departments []Department 48 + err = ScanAll(rows, &departments) 49 + if err != nil { 50 + t.Fatalf("Failed to scan departments: %v", err) 51 + } 52 53 if len(departments) != 3 { 54 t.Errorf("Expected 3 departments, got %d", len(departments)) ··· 74 t.Fatalf("Failed to query users: %v", err) 75 } 76 77 var users []UserBasic 78 + err = ScanAll(rows, &users) 79 + if err != nil { 80 + t.Fatalf("Failed to scan users: %v", err) 81 } 82 83 if len(users) != 6 { ··· 102 t.Fatalf("Failed to query users: %v", err) 103 } 104 105 var users []User 106 + err = ScanAll(rows, &users) 107 + if err != nil { 108 + t.Fatalf("Failed to scan users: %v", err) 109 } 110 111 if len(users) != 6 { ··· 133 } 134 } 135 136 func TestScannerEarlyTermination(t *testing.T) { 137 db := setupTestDB(t) 138 defer db.Close() ··· 218 } 219 } 220 221 func TestScannerResourceCleanup(t *testing.T) { 222 db := setupTestDB(t) 223 defer db.Close() ··· 274 } 275 } 276 } 277 + 278 + func TestScanOne(t *testing.T) { 279 + db := setupTestDB(t) 280 + defer db.Close() 281 + 282 + row := db.QueryRow("SELECT * FROM users WHERE id = 1") 283 + 284 + var u User 285 + err := Scan(row, &u) 286 + if err != nil { 287 + t.Fatalf("Failed to scan user: %v", err) 288 + } 289 + 290 + if u.ID != 1 || u.Name != "John Doe" { 291 + t.Errorf("user data incorrect: %+v", u) 292 + } 293 + }