OSDN Git Service

add sqlite vendor (#48)
[bytom/vapor.git] / vendor / github.com / mattn / go-sqlite3 / callback.go
1 // Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
2 //
3 // Use of this source code is governed by an MIT-style
4 // license that can be found in the LICENSE file.
5
6 package sqlite3
7
8 // You can't export a Go function to C and have definitions in the C
9 // preamble in the same file, so we have to have callbackTrampoline in
10 // its own file. Because we need a separate file anyway, the support
11 // code for SQLite custom functions is in here.
12
13 /*
14 #ifndef USE_LIBSQLITE3
15 #include <sqlite3-binding.h>
16 #else
17 #include <sqlite3.h>
18 #endif
19 #include <stdlib.h>
20
21 void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
22 void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
23 */
24 import "C"
25
26 import (
27         "errors"
28         "fmt"
29         "math"
30         "reflect"
31         "sync"
32         "unsafe"
33 )
34
35 //export callbackTrampoline
36 func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
37         args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
38         fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo)
39         fi.Call(ctx, args)
40 }
41
42 //export stepTrampoline
43 func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
44         args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
45         ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
46         ai.Step(ctx, args)
47 }
48
49 //export doneTrampoline
50 func doneTrampoline(ctx *C.sqlite3_context) {
51         handle := uintptr(C.sqlite3_user_data(ctx))
52         ai := lookupHandle(handle).(*aggInfo)
53         ai.Done(ctx)
54 }
55
56 //export compareTrampoline
57 func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.char) C.int {
58         cmp := lookupHandle(handlePtr).(func(string, string) int)
59         return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
60 }
61
62 //export commitHookTrampoline
63 func commitHookTrampoline(handle uintptr) int {
64         callback := lookupHandle(handle).(func() int)
65         return callback()
66 }
67
68 //export rollbackHookTrampoline
69 func rollbackHookTrampoline(handle uintptr) {
70         callback := lookupHandle(handle).(func())
71         callback()
72 }
73
74 //export updateHookTrampoline
75 func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
76         callback := lookupHandle(handle).(func(int, string, string, int64))
77         callback(op, C.GoString(db), C.GoString(table), rowid)
78 }
79
80 //export authorizerTrampoline
81 func authorizerTrampoline(handle uintptr, op int, arg1 *C.char, arg2 *C.char, arg3 *C.char) int {
82         callback := lookupHandle(handle).(func(int, string, string, string) int)
83         return callback(op, C.GoString(arg1), C.GoString(arg2), C.GoString(arg3))
84 }
85
86 // Use handles to avoid passing Go pointers to C.
87
88 type handleVal struct {
89         db  *SQLiteConn
90         val interface{}
91 }
92
93 var handleLock sync.Mutex
94 var handleVals = make(map[uintptr]handleVal)
95 var handleIndex uintptr = 100
96
97 func newHandle(db *SQLiteConn, v interface{}) uintptr {
98         handleLock.Lock()
99         defer handleLock.Unlock()
100         i := handleIndex
101         handleIndex++
102         handleVals[i] = handleVal{db, v}
103         return i
104 }
105
106 func lookupHandle(handle uintptr) interface{} {
107         handleLock.Lock()
108         defer handleLock.Unlock()
109         r, ok := handleVals[handle]
110         if !ok {
111                 if handle >= 100 && handle < handleIndex {
112                         panic("deleted handle")
113                 } else {
114                         panic("invalid handle")
115                 }
116         }
117         return r.val
118 }
119
120 func deleteHandles(db *SQLiteConn) {
121         handleLock.Lock()
122         defer handleLock.Unlock()
123         for handle, val := range handleVals {
124                 if val.db == db {
125                         delete(handleVals, handle)
126                 }
127         }
128 }
129
130 // This is only here so that tests can refer to it.
131 type callbackArgRaw C.sqlite3_value
132
133 type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
134
135 type callbackArgCast struct {
136         f   callbackArgConverter
137         typ reflect.Type
138 }
139
140 func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
141         val, err := c.f(v)
142         if err != nil {
143                 return reflect.Value{}, err
144         }
145         if !val.Type().ConvertibleTo(c.typ) {
146                 return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
147         }
148         return val.Convert(c.typ), nil
149 }
150
151 func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
152         if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
153                 return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
154         }
155         return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
156 }
157
158 func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
159         if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
160                 return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
161         }
162         i := int64(C.sqlite3_value_int64(v))
163         val := false
164         if i != 0 {
165                 val = true
166         }
167         return reflect.ValueOf(val), nil
168 }
169
170 func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
171         if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
172                 return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
173         }
174         return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
175 }
176
177 func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
178         switch C.sqlite3_value_type(v) {
179         case C.SQLITE_BLOB:
180                 l := C.sqlite3_value_bytes(v)
181                 p := C.sqlite3_value_blob(v)
182                 return reflect.ValueOf(C.GoBytes(p, l)), nil
183         case C.SQLITE_TEXT:
184                 l := C.sqlite3_value_bytes(v)
185                 c := unsafe.Pointer(C.sqlite3_value_text(v))
186                 return reflect.ValueOf(C.GoBytes(c, l)), nil
187         default:
188                 return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
189         }
190 }
191
192 func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
193         switch C.sqlite3_value_type(v) {
194         case C.SQLITE_BLOB:
195                 l := C.sqlite3_value_bytes(v)
196                 p := (*C.char)(C.sqlite3_value_blob(v))
197                 return reflect.ValueOf(C.GoStringN(p, l)), nil
198         case C.SQLITE_TEXT:
199                 c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
200                 return reflect.ValueOf(C.GoString(c)), nil
201         default:
202                 return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
203         }
204 }
205
206 func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
207         switch C.sqlite3_value_type(v) {
208         case C.SQLITE_INTEGER:
209                 return callbackArgInt64(v)
210         case C.SQLITE_FLOAT:
211                 return callbackArgFloat64(v)
212         case C.SQLITE_TEXT:
213                 return callbackArgString(v)
214         case C.SQLITE_BLOB:
215                 return callbackArgBytes(v)
216         case C.SQLITE_NULL:
217                 // Interpret NULL as a nil byte slice.
218                 var ret []byte
219                 return reflect.ValueOf(ret), nil
220         default:
221                 panic("unreachable")
222         }
223 }
224
225 func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
226         switch typ.Kind() {
227         case reflect.Interface:
228                 if typ.NumMethod() != 0 {
229                         return nil, errors.New("the only supported interface type is interface{}")
230                 }
231                 return callbackArgGeneric, nil
232         case reflect.Slice:
233                 if typ.Elem().Kind() != reflect.Uint8 {
234                         return nil, errors.New("the only supported slice type is []byte")
235                 }
236                 return callbackArgBytes, nil
237         case reflect.String:
238                 return callbackArgString, nil
239         case reflect.Bool:
240                 return callbackArgBool, nil
241         case reflect.Int64:
242                 return callbackArgInt64, nil
243         case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
244                 c := callbackArgCast{callbackArgInt64, typ}
245                 return c.Run, nil
246         case reflect.Float64:
247                 return callbackArgFloat64, nil
248         case reflect.Float32:
249                 c := callbackArgCast{callbackArgFloat64, typ}
250                 return c.Run, nil
251         default:
252                 return nil, fmt.Errorf("don't know how to convert to %s", typ)
253         }
254 }
255
256 func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
257         var args []reflect.Value
258
259         if len(argv) < len(converters) {
260                 return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
261         }
262
263         for i, arg := range argv[:len(converters)] {
264                 v, err := converters[i](arg)
265                 if err != nil {
266                         return nil, err
267                 }
268                 args = append(args, v)
269         }
270
271         if variadic != nil {
272                 for _, arg := range argv[len(converters):] {
273                         v, err := variadic(arg)
274                         if err != nil {
275                                 return nil, err
276                         }
277                         args = append(args, v)
278                 }
279         }
280         return args, nil
281 }
282
283 type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
284
285 func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
286         switch v.Type().Kind() {
287         case reflect.Int64:
288         case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
289                 v = v.Convert(reflect.TypeOf(int64(0)))
290         case reflect.Bool:
291                 b := v.Interface().(bool)
292                 if b {
293                         v = reflect.ValueOf(int64(1))
294                 } else {
295                         v = reflect.ValueOf(int64(0))
296                 }
297         default:
298                 return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
299         }
300
301         C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
302         return nil
303 }
304
305 func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
306         switch v.Type().Kind() {
307         case reflect.Float64:
308         case reflect.Float32:
309                 v = v.Convert(reflect.TypeOf(float64(0)))
310         default:
311                 return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
312         }
313
314         C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
315         return nil
316 }
317
318 func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
319         if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
320                 return fmt.Errorf("cannot convert %s to BLOB", v.Type())
321         }
322         i := v.Interface()
323         if i == nil || len(i.([]byte)) == 0 {
324                 C.sqlite3_result_null(ctx)
325         } else {
326                 bs := i.([]byte)
327                 C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
328         }
329         return nil
330 }
331
332 func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
333         if v.Type().Kind() != reflect.String {
334                 return fmt.Errorf("cannot convert %s to TEXT", v.Type())
335         }
336         C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
337         return nil
338 }
339
340 func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error {
341         return nil
342 }
343
344 func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
345         switch typ.Kind() {
346         case reflect.Interface:
347                 errorInterface := reflect.TypeOf((*error)(nil)).Elem()
348                 if typ.Implements(errorInterface) {
349                         return callbackRetNil, nil
350                 }
351                 fallthrough
352         case reflect.Slice:
353                 if typ.Elem().Kind() != reflect.Uint8 {
354                         return nil, errors.New("the only supported slice type is []byte")
355                 }
356                 return callbackRetBlob, nil
357         case reflect.String:
358                 return callbackRetText, nil
359         case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
360                 return callbackRetInteger, nil
361         case reflect.Float32, reflect.Float64:
362                 return callbackRetFloat, nil
363         default:
364                 return nil, fmt.Errorf("don't know how to convert to %s", typ)
365         }
366 }
367
368 func callbackError(ctx *C.sqlite3_context, err error) {
369         cstr := C.CString(err.Error())
370         defer C.free(unsafe.Pointer(cstr))
371         C.sqlite3_result_error(ctx, cstr, C.int(-1))
372 }
373
374 // Test support code. Tests are not allowed to import "C", so we can't
375 // declare any functions that use C.sqlite3_value.
376 func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
377         return func(*C.sqlite3_value) (reflect.Value, error) {
378                 return v, err
379         }
380 }