// ErrBadRequest is the sentinel error that ForQueryRows wraps when it is
// called with invalid arguments (missing or malformed callback).
var ErrBadRequest = errors.New("bad request")

// errorInterface is the reflect.Type of the built-in error interface. It is
// used to check that a callback's result type implements error.
var errorInterface = reflect.TypeOf((*error)(nil)).Elem()
15 // ForQueryRows encapsulates a lot of boilerplate when making db queries.
18 // err = ForQueryRows(ctx, db, query, queryArg1, queryArg2, ..., func(scanVar1 type1, scanVar2 type2, ...) {
19 // ...process a row from the result...
22 // This is equivalent to:
24 // rows, err = db.Query(ctx, query, queryArg1, queryArg2, ...)
34 // err = rows.Scan(&scanVar1, &scanVar2, ...)
38 // ...process a row from the result...
40 // if err = rows.Err(); err != nil {
44 // The callback is invoked once for each row in the result. The
45 // number and types of parameters to the callback must match the
46 // values to be scanned with rows.Scan. The space for the callback's
47 // arguments is not reused between calls. The callback may return a
48 // single error-type value. If any invocation yields a non-nil
49 // result, ForQueryRows will abort and return it.
50 func ForQueryRows(ctx context.Context, db DB, query string, args ...interface{}) error {
52 return errors.Wrap(ErrBadRequest, "too few arguments")
55 fnArg := args[len(args)-1]
56 queryArgs := args[:len(args)-1]
58 fnType := reflect.TypeOf(fnArg)
59 if fnType.Kind() != reflect.Func {
60 return errors.Wrap(ErrBadRequest, "fn arg not a function")
62 if fnType.NumOut() > 1 {
63 return errors.Wrap(ErrBadRequest, "fn arg must return 0 values or 1")
65 if fnType.NumOut() == 1 && !fnType.Out(0).Implements(errorInterface) {
66 return errors.Wrap(ErrBadRequest, "fn arg return type must be error")
69 rows, err := db.QueryContext(ctx, query, queryArgs...)
71 return errors.Wrap(err, "query")
75 fnVal := reflect.ValueOf(fnArg)
77 argPtrVals := make([]reflect.Value, 0, fnType.NumIn())
78 scanArgs := make([]interface{}, 0, fnType.NumIn())
79 fnArgs := make([]reflect.Value, 0, fnType.NumIn())
82 argPtrVals = argPtrVals[:0]
83 scanArgs = scanArgs[:0]
85 for i := 0; i < fnType.NumIn(); i++ {
86 argType := fnType.In(i)
87 argPtrVal := reflect.New(argType)
88 argPtrVals = append(argPtrVals, argPtrVal)
89 scanArgs = append(scanArgs, argPtrVal.Interface())
91 err = rows.Scan(scanArgs...)
93 return errors.Wrap(err, "scan")
95 for _, argPtrVal := range argPtrVals {
96 fnArgs = append(fnArgs, argPtrVal.Elem())
98 res := fnVal.Call(fnArgs)
99 if fnType.NumOut() == 1 && !res[0].IsNil() {
100 return errors.Wrap(res[0].Interface().(error), "callback")
104 return errors.Wrap(rows.Err(), "end scan")