OSDN Git Service

Added blockchain struct.
[bytom/bytom.git] / database / pg / query.go
1 package pg
2
3 import (
4         "context"
5         "reflect"
6
7         "chain/errors"
8 )
9
10 var ErrBadRequest = errors.New("bad request")
11
12 // The type of "error"
13 var errorInterface = reflect.TypeOf((*error)(nil)).Elem()
14
15 // ForQueryRows encapsulates a lot of boilerplate when making db queries.
16 // Call it like this:
17 //
18 //   err = ForQueryRows(ctx, db, query, queryArg1, queryArg2, ..., func(scanVar1 type1, scanVar2 type2, ...) {
19 //     ...process a row from the result...
20 //   })
21 //
22 // This is equivalent to:
23 //
24 //   rows, err = db.Query(ctx, query, queryArg1, queryArg2, ...)
25 //   if err != nil {
26 //     return err
27 //   }
28 //   defer rows.Close()
29 //   for rows.Next() {
30 //     var (
31 //       scanVar1 type1
32 //       scanVar2 type2
33 //     )
34 //     err = rows.Scan(&scanVar1, &scanVar2, ...)
35 //     if err != nil {
36 //       return err
37 //     }
38 //     ...process a row from the result...
39 //   }
40 //   if err = rows.Err(); err != nil {
41 //     return err
42 //   }
43 //
44 // The callback is invoked once for each row in the result.  The
45 // number and types of parameters to the callback must match the
46 // values to be scanned with rows.Scan.  The space for the callback's
47 // arguments is not reused between calls.  The callback may return a
48 // single error-type value.  If any invocation yields a non-nil
49 // result, ForQueryRows will abort and return it.
50 func ForQueryRows(ctx context.Context, db DB, query string, args ...interface{}) error {
51         if len(args) == 0 {
52                 return errors.Wrap(ErrBadRequest, "too few arguments")
53         }
54
55         fnArg := args[len(args)-1]
56         queryArgs := args[:len(args)-1]
57
58         fnType := reflect.TypeOf(fnArg)
59         if fnType.Kind() != reflect.Func {
60                 return errors.Wrap(ErrBadRequest, "fn arg not a function")
61         }
62         if fnType.NumOut() > 1 {
63                 return errors.Wrap(ErrBadRequest, "fn arg must return 0 values or 1")
64         }
65         if fnType.NumOut() == 1 && !fnType.Out(0).Implements(errorInterface) {
66                 return errors.Wrap(ErrBadRequest, "fn arg return type must be error")
67         }
68
69         rows, err := db.QueryContext(ctx, query, queryArgs...)
70         if err != nil {
71                 return errors.Wrap(err, "query")
72         }
73         defer rows.Close()
74
75         fnVal := reflect.ValueOf(fnArg)
76
77         argPtrVals := make([]reflect.Value, 0, fnType.NumIn())
78         scanArgs := make([]interface{}, 0, fnType.NumIn())
79         fnArgs := make([]reflect.Value, 0, fnType.NumIn())
80
81         for rows.Next() {
82                 argPtrVals = argPtrVals[:0]
83                 scanArgs = scanArgs[:0]
84                 fnArgs = fnArgs[:0]
85                 for i := 0; i < fnType.NumIn(); i++ {
86                         argType := fnType.In(i)
87                         argPtrVal := reflect.New(argType)
88                         argPtrVals = append(argPtrVals, argPtrVal)
89                         scanArgs = append(scanArgs, argPtrVal.Interface())
90                 }
91                 err = rows.Scan(scanArgs...)
92                 if err != nil {
93                         return errors.Wrap(err, "scan")
94                 }
95                 for _, argPtrVal := range argPtrVals {
96                         fnArgs = append(fnArgs, argPtrVal.Elem())
97                 }
98                 res := fnVal.Call(fnArgs)
99                 if fnType.NumOut() == 1 && !res[0].IsNil() {
100                         return errors.Wrap(res[0].Interface().(error), "callback")
101                 }
102         }
103
104         return errors.Wrap(rows.Err(), "end scan")
105 }