OSDN Git Service

versoin1.1.9 (#594)
[bytom/vapor.git] / vendor / github.com / go-sql-driver / mysql / utils.go
1 // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2 //
3 // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
7 // You can obtain one at http://mozilla.org/MPL/2.0/.
8
9 package mysql
10
11 import (
12         "crypto/tls"
13         "database/sql"
14         "database/sql/driver"
15         "encoding/binary"
16         "errors"
17         "fmt"
18         "io"
19         "strconv"
20         "strings"
21         "sync"
22         "sync/atomic"
23         "time"
24 )
25
26 // Registry for custom tls.Configs
27 var (
28         tlsConfigLock     sync.RWMutex
29         tlsConfigRegistry map[string]*tls.Config
30 )
31
32 // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
33 // Use the key as a value in the DSN where tls=value.
34 //
35 // Note: The provided tls.Config is exclusively owned by the driver after
36 // registering it.
37 //
38 //  rootCertPool := x509.NewCertPool()
39 //  pem, err := ioutil.ReadFile("/path/ca-cert.pem")
40 //  if err != nil {
41 //      log.Fatal(err)
42 //  }
43 //  if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
44 //      log.Fatal("Failed to append PEM.")
45 //  }
46 //  clientCert := make([]tls.Certificate, 0, 1)
47 //  certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
48 //  if err != nil {
49 //      log.Fatal(err)
50 //  }
51 //  clientCert = append(clientCert, certs)
52 //  mysql.RegisterTLSConfig("custom", &tls.Config{
53 //      RootCAs: rootCertPool,
54 //      Certificates: clientCert,
55 //  })
56 //  db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
57 //
58 func RegisterTLSConfig(key string, config *tls.Config) error {
59         if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
60                 return fmt.Errorf("key '%s' is reserved", key)
61         }
62
63         tlsConfigLock.Lock()
64         if tlsConfigRegistry == nil {
65                 tlsConfigRegistry = make(map[string]*tls.Config)
66         }
67
68         tlsConfigRegistry[key] = config
69         tlsConfigLock.Unlock()
70         return nil
71 }
72
73 // DeregisterTLSConfig removes the tls.Config associated with key.
74 func DeregisterTLSConfig(key string) {
75         tlsConfigLock.Lock()
76         if tlsConfigRegistry != nil {
77                 delete(tlsConfigRegistry, key)
78         }
79         tlsConfigLock.Unlock()
80 }
81
82 func getTLSConfigClone(key string) (config *tls.Config) {
83         tlsConfigLock.RLock()
84         if v, ok := tlsConfigRegistry[key]; ok {
85                 config = v.Clone()
86         }
87         tlsConfigLock.RUnlock()
88         return
89 }
90
91 // Returns the bool value of the input.
92 // The 2nd return value indicates if the input was a valid bool value
93 func readBool(input string) (value bool, valid bool) {
94         switch input {
95         case "1", "true", "TRUE", "True":
96                 return true, true
97         case "0", "false", "FALSE", "False":
98                 return false, true
99         }
100
101         // Not a valid bool value
102         return
103 }
104
105 /******************************************************************************
106 *                           Time related utils                                *
107 ******************************************************************************/
108
109 // NullTime represents a time.Time that may be NULL.
110 // NullTime implements the Scanner interface so
111 // it can be used as a scan destination:
112 //
113 //  var nt NullTime
114 //  err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
115 //  ...
116 //  if nt.Valid {
117 //     // use nt.Time
118 //  } else {
119 //     // NULL value
120 //  }
121 //
122 // This NullTime implementation is not driver-specific
123 type NullTime struct {
124         Time  time.Time
125         Valid bool // Valid is true if Time is not NULL
126 }
127
128 // Scan implements the Scanner interface.
129 // The value type must be time.Time or string / []byte (formatted time-string),
130 // otherwise Scan fails.
131 func (nt *NullTime) Scan(value interface{}) (err error) {
132         if value == nil {
133                 nt.Time, nt.Valid = time.Time{}, false
134                 return
135         }
136
137         switch v := value.(type) {
138         case time.Time:
139                 nt.Time, nt.Valid = v, true
140                 return
141         case []byte:
142                 nt.Time, err = parseDateTime(string(v), time.UTC)
143                 nt.Valid = (err == nil)
144                 return
145         case string:
146                 nt.Time, err = parseDateTime(v, time.UTC)
147                 nt.Valid = (err == nil)
148                 return
149         }
150
151         nt.Valid = false
152         return fmt.Errorf("Can't convert %T to time.Time", value)
153 }
154
155 // Value implements the driver Valuer interface.
156 func (nt NullTime) Value() (driver.Value, error) {
157         if !nt.Valid {
158                 return nil, nil
159         }
160         return nt.Time, nil
161 }
162
163 func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
164         base := "0000-00-00 00:00:00.0000000"
165         switch len(str) {
166         case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
167                 if str == base[:len(str)] {
168                         return
169                 }
170                 t, err = time.Parse(timeFormat[:len(str)], str)
171         default:
172                 err = fmt.Errorf("invalid time string: %s", str)
173                 return
174         }
175
176         // Adjust location
177         if err == nil && loc != time.UTC {
178                 y, mo, d := t.Date()
179                 h, mi, s := t.Clock()
180                 t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil
181         }
182
183         return
184 }
185
186 func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
187         switch num {
188         case 0:
189                 return time.Time{}, nil
190         case 4:
191                 return time.Date(
192                         int(binary.LittleEndian.Uint16(data[:2])), // year
193                         time.Month(data[2]),                       // month
194                         int(data[3]),                              // day
195                         0, 0, 0, 0,
196                         loc,
197                 ), nil
198         case 7:
199                 return time.Date(
200                         int(binary.LittleEndian.Uint16(data[:2])), // year
201                         time.Month(data[2]),                       // month
202                         int(data[3]),                              // day
203                         int(data[4]),                              // hour
204                         int(data[5]),                              // minutes
205                         int(data[6]),                              // seconds
206                         0,
207                         loc,
208                 ), nil
209         case 11:
210                 return time.Date(
211                         int(binary.LittleEndian.Uint16(data[:2])), // year
212                         time.Month(data[2]),                       // month
213                         int(data[3]),                              // day
214                         int(data[4]),                              // hour
215                         int(data[5]),                              // minutes
216                         int(data[6]),                              // seconds
217                         int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds
218                         loc,
219                 ), nil
220         }
221         return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
222 }
223
224 // zeroDateTime is used in formatBinaryDateTime to avoid an allocation
225 // if the DATE or DATETIME has the zero value.
226 // It must never be changed.
227 // The current behavior depends on database/sql copying the result.
228 var zeroDateTime = []byte("0000-00-00 00:00:00.000000")
229
230 const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
231 const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
232
233 func appendMicrosecs(dst, src []byte, decimals int) []byte {
234         if decimals <= 0 {
235                 return dst
236         }
237         if len(src) == 0 {
238                 return append(dst, ".000000"[:decimals+1]...)
239         }
240
241         microsecs := binary.LittleEndian.Uint32(src[:4])
242         p1 := byte(microsecs / 10000)
243         microsecs -= 10000 * uint32(p1)
244         p2 := byte(microsecs / 100)
245         microsecs -= 100 * uint32(p2)
246         p3 := byte(microsecs)
247
248         switch decimals {
249         default:
250                 return append(dst, '.',
251                         digits10[p1], digits01[p1],
252                         digits10[p2], digits01[p2],
253                         digits10[p3], digits01[p3],
254                 )
255         case 1:
256                 return append(dst, '.',
257                         digits10[p1],
258                 )
259         case 2:
260                 return append(dst, '.',
261                         digits10[p1], digits01[p1],
262                 )
263         case 3:
264                 return append(dst, '.',
265                         digits10[p1], digits01[p1],
266                         digits10[p2],
267                 )
268         case 4:
269                 return append(dst, '.',
270                         digits10[p1], digits01[p1],
271                         digits10[p2], digits01[p2],
272                 )
273         case 5:
274                 return append(dst, '.',
275                         digits10[p1], digits01[p1],
276                         digits10[p2], digits01[p2],
277                         digits10[p3],
278                 )
279         }
280 }
281
282 func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
283         // length expects the deterministic length of the zero value,
284         // negative time and 100+ hours are automatically added if needed
285         if len(src) == 0 {
286                 return zeroDateTime[:length], nil
287         }
288         var dst []byte      // return value
289         var p1, p2, p3 byte // current digit pair
290
291         switch length {
292         case 10, 19, 21, 22, 23, 24, 25, 26:
293         default:
294                 t := "DATE"
295                 if length > 10 {
296                         t += "TIME"
297                 }
298                 return nil, fmt.Errorf("illegal %s length %d", t, length)
299         }
300         switch len(src) {
301         case 4, 7, 11:
302         default:
303                 t := "DATE"
304                 if length > 10 {
305                         t += "TIME"
306                 }
307                 return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
308         }
309         dst = make([]byte, 0, length)
310         // start with the date
311         year := binary.LittleEndian.Uint16(src[:2])
312         pt := year / 100
313         p1 = byte(year - 100*uint16(pt))
314         p2, p3 = src[2], src[3]
315         dst = append(dst,
316                 digits10[pt], digits01[pt],
317                 digits10[p1], digits01[p1], '-',
318                 digits10[p2], digits01[p2], '-',
319                 digits10[p3], digits01[p3],
320         )
321         if length == 10 {
322                 return dst, nil
323         }
324         if len(src) == 4 {
325                 return append(dst, zeroDateTime[10:length]...), nil
326         }
327         dst = append(dst, ' ')
328         p1 = src[4] // hour
329         src = src[5:]
330
331         // p1 is 2-digit hour, src is after hour
332         p2, p3 = src[0], src[1]
333         dst = append(dst,
334                 digits10[p1], digits01[p1], ':',
335                 digits10[p2], digits01[p2], ':',
336                 digits10[p3], digits01[p3],
337         )
338         return appendMicrosecs(dst, src[2:], int(length)-20), nil
339 }
340
341 func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
342         // length expects the deterministic length of the zero value,
343         // negative time and 100+ hours are automatically added if needed
344         if len(src) == 0 {
345                 return zeroDateTime[11 : 11+length], nil
346         }
347         var dst []byte // return value
348
349         switch length {
350         case
351                 8,                      // time (can be up to 10 when negative and 100+ hours)
352                 10, 11, 12, 13, 14, 15: // time with fractional seconds
353         default:
354                 return nil, fmt.Errorf("illegal TIME length %d", length)
355         }
356         switch len(src) {
357         case 8, 12:
358         default:
359                 return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
360         }
361         // +2 to enable negative time and 100+ hours
362         dst = make([]byte, 0, length+2)
363         if src[0] == 1 {
364                 dst = append(dst, '-')
365         }
366         days := binary.LittleEndian.Uint32(src[1:5])
367         hours := int64(days)*24 + int64(src[5])
368
369         if hours >= 100 {
370                 dst = strconv.AppendInt(dst, hours, 10)
371         } else {
372                 dst = append(dst, digits10[hours], digits01[hours])
373         }
374
375         min, sec := src[6], src[7]
376         dst = append(dst, ':',
377                 digits10[min], digits01[min], ':',
378                 digits10[sec], digits01[sec],
379         )
380         return appendMicrosecs(dst, src[8:], int(length)-9), nil
381 }
382
383 /******************************************************************************
384 *                       Convert from and to bytes                             *
385 ******************************************************************************/
386
387 func uint64ToBytes(n uint64) []byte {
388         return []byte{
389                 byte(n),
390                 byte(n >> 8),
391                 byte(n >> 16),
392                 byte(n >> 24),
393                 byte(n >> 32),
394                 byte(n >> 40),
395                 byte(n >> 48),
396                 byte(n >> 56),
397         }
398 }
399
400 func uint64ToString(n uint64) []byte {
401         var a [20]byte
402         i := 20
403
404         // U+0030 = 0
405         // ...
406         // U+0039 = 9
407
408         var q uint64
409         for n >= 10 {
410                 i--
411                 q = n / 10
412                 a[i] = uint8(n-q*10) + 0x30
413                 n = q
414         }
415
416         i--
417         a[i] = uint8(n) + 0x30
418
419         return a[i:]
420 }
421
422 // treats string value as unsigned integer representation
423 func stringToInt(b []byte) int {
424         val := 0
425         for i := range b {
426                 val *= 10
427                 val += int(b[i] - 0x30)
428         }
429         return val
430 }
431
432 // returns the string read as a bytes slice, wheter the value is NULL,
433 // the number of bytes read and an error, in case the string is longer than
434 // the input slice
435 func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
436         // Get length
437         num, isNull, n := readLengthEncodedInteger(b)
438         if num < 1 {
439                 return b[n:n], isNull, n, nil
440         }
441
442         n += int(num)
443
444         // Check data length
445         if len(b) >= n {
446                 return b[n-int(num) : n : n], false, n, nil
447         }
448         return nil, false, n, io.EOF
449 }
450
451 // returns the number of bytes skipped and an error, in case the string is
452 // longer than the input slice
453 func skipLengthEncodedString(b []byte) (int, error) {
454         // Get length
455         num, _, n := readLengthEncodedInteger(b)
456         if num < 1 {
457                 return n, nil
458         }
459
460         n += int(num)
461
462         // Check data length
463         if len(b) >= n {
464                 return n, nil
465         }
466         return n, io.EOF
467 }
468
469 // returns the number read, whether the value is NULL and the number of bytes read
470 func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
471         // See issue #349
472         if len(b) == 0 {
473                 return 0, true, 1
474         }
475
476         switch b[0] {
477         // 251: NULL
478         case 0xfb:
479                 return 0, true, 1
480
481         // 252: value of following 2
482         case 0xfc:
483                 return uint64(b[1]) | uint64(b[2])<<8, false, 3
484
485         // 253: value of following 3
486         case 0xfd:
487                 return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
488
489         // 254: value of following 8
490         case 0xfe:
491                 return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
492                                 uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
493                                 uint64(b[7])<<48 | uint64(b[8])<<56,
494                         false, 9
495         }
496
497         // 0-250: value of first byte
498         return uint64(b[0]), false, 1
499 }
500
501 // encodes a uint64 value and appends it to the given bytes slice
502 func appendLengthEncodedInteger(b []byte, n uint64) []byte {
503         switch {
504         case n <= 250:
505                 return append(b, byte(n))
506
507         case n <= 0xffff:
508                 return append(b, 0xfc, byte(n), byte(n>>8))
509
510         case n <= 0xffffff:
511                 return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
512         }
513         return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
514                 byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
515 }
516
517 // reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
518 // If cap(buf) is not enough, reallocate new buffer.
519 func reserveBuffer(buf []byte, appendSize int) []byte {
520         newSize := len(buf) + appendSize
521         if cap(buf) < newSize {
522                 // Grow buffer exponentially
523                 newBuf := make([]byte, len(buf)*2+appendSize)
524                 copy(newBuf, buf)
525                 buf = newBuf
526         }
527         return buf[:newSize]
528 }
529
530 // escapeBytesBackslash escapes []byte with backslashes (\)
531 // This escapes the contents of a string (provided as []byte) by adding backslashes before special
532 // characters, and turning others into specific escape sequences, such as
533 // turning newlines into \n and null bytes into \0.
534 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
535 func escapeBytesBackslash(buf, v []byte) []byte {
536         pos := len(buf)
537         buf = reserveBuffer(buf, len(v)*2)
538
539         for _, c := range v {
540                 switch c {
541                 case '\x00':
542                         buf[pos] = '\\'
543                         buf[pos+1] = '0'
544                         pos += 2
545                 case '\n':
546                         buf[pos] = '\\'
547                         buf[pos+1] = 'n'
548                         pos += 2
549                 case '\r':
550                         buf[pos] = '\\'
551                         buf[pos+1] = 'r'
552                         pos += 2
553                 case '\x1a':
554                         buf[pos] = '\\'
555                         buf[pos+1] = 'Z'
556                         pos += 2
557                 case '\'':
558                         buf[pos] = '\\'
559                         buf[pos+1] = '\''
560                         pos += 2
561                 case '"':
562                         buf[pos] = '\\'
563                         buf[pos+1] = '"'
564                         pos += 2
565                 case '\\':
566                         buf[pos] = '\\'
567                         buf[pos+1] = '\\'
568                         pos += 2
569                 default:
570                         buf[pos] = c
571                         pos++
572                 }
573         }
574
575         return buf[:pos]
576 }
577
578 // escapeStringBackslash is similar to escapeBytesBackslash but for string.
579 func escapeStringBackslash(buf []byte, v string) []byte {
580         pos := len(buf)
581         buf = reserveBuffer(buf, len(v)*2)
582
583         for i := 0; i < len(v); i++ {
584                 c := v[i]
585                 switch c {
586                 case '\x00':
587                         buf[pos] = '\\'
588                         buf[pos+1] = '0'
589                         pos += 2
590                 case '\n':
591                         buf[pos] = '\\'
592                         buf[pos+1] = 'n'
593                         pos += 2
594                 case '\r':
595                         buf[pos] = '\\'
596                         buf[pos+1] = 'r'
597                         pos += 2
598                 case '\x1a':
599                         buf[pos] = '\\'
600                         buf[pos+1] = 'Z'
601                         pos += 2
602                 case '\'':
603                         buf[pos] = '\\'
604                         buf[pos+1] = '\''
605                         pos += 2
606                 case '"':
607                         buf[pos] = '\\'
608                         buf[pos+1] = '"'
609                         pos += 2
610                 case '\\':
611                         buf[pos] = '\\'
612                         buf[pos+1] = '\\'
613                         pos += 2
614                 default:
615                         buf[pos] = c
616                         pos++
617                 }
618         }
619
620         return buf[:pos]
621 }
622
623 // escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
624 // This escapes the contents of a string by doubling up any apostrophes that
625 // it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
626 // effect on the server.
627 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
628 func escapeBytesQuotes(buf, v []byte) []byte {
629         pos := len(buf)
630         buf = reserveBuffer(buf, len(v)*2)
631
632         for _, c := range v {
633                 if c == '\'' {
634                         buf[pos] = '\''
635                         buf[pos+1] = '\''
636                         pos += 2
637                 } else {
638                         buf[pos] = c
639                         pos++
640                 }
641         }
642
643         return buf[:pos]
644 }
645
646 // escapeStringQuotes is similar to escapeBytesQuotes but for string.
647 func escapeStringQuotes(buf []byte, v string) []byte {
648         pos := len(buf)
649         buf = reserveBuffer(buf, len(v)*2)
650
651         for i := 0; i < len(v); i++ {
652                 c := v[i]
653                 if c == '\'' {
654                         buf[pos] = '\''
655                         buf[pos+1] = '\''
656                         pos += 2
657                 } else {
658                         buf[pos] = c
659                         pos++
660                 }
661         }
662
663         return buf[:pos]
664 }
665
666 /******************************************************************************
667 *                               Sync utils                                    *
668 ******************************************************************************/
669
670 // noCopy may be embedded into structs which must not be copied
671 // after the first use.
672 //
673 // See https://github.com/golang/go/issues/8005#issuecomment-190753527
674 // for details.
675 type noCopy struct{}
676
677 // Lock is a no-op used by -copylocks checker from `go vet`.
678 func (*noCopy) Lock() {}
679
680 // atomicBool is a wrapper around uint32 for usage as a boolean value with
681 // atomic access.
682 type atomicBool struct {
683         _noCopy noCopy
684         value   uint32
685 }
686
687 // IsSet returns wether the current boolean value is true
688 func (ab *atomicBool) IsSet() bool {
689         return atomic.LoadUint32(&ab.value) > 0
690 }
691
692 // Set sets the value of the bool regardless of the previous value
693 func (ab *atomicBool) Set(value bool) {
694         if value {
695                 atomic.StoreUint32(&ab.value, 1)
696         } else {
697                 atomic.StoreUint32(&ab.value, 0)
698         }
699 }
700
701 // TrySet sets the value of the bool and returns wether the value changed
702 func (ab *atomicBool) TrySet(value bool) bool {
703         if value {
704                 return atomic.SwapUint32(&ab.value, 1) == 0
705         }
706         return atomic.SwapUint32(&ab.value, 0) > 0
707 }
708
709 // atomicError is a wrapper for atomically accessed error values
710 type atomicError struct {
711         _noCopy noCopy
712         value   atomic.Value
713 }
714
715 // Set sets the error value regardless of the previous value.
716 // The value must not be nil
717 func (ae *atomicError) Set(value error) {
718         ae.value.Store(value)
719 }
720
721 // Value returns the current error value
722 func (ae *atomicError) Value() error {
723         if v := ae.value.Load(); v != nil {
724                 // this will panic if the value doesn't implement the error interface
725                 return v.(error)
726         }
727         return nil
728 }
729
730 func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
731         dargs := make([]driver.Value, len(named))
732         for n, param := range named {
733                 if len(param.Name) > 0 {
734                         // TODO: support the use of Named Parameters #561
735                         return nil, errors.New("mysql: driver does not support the use of Named Parameters")
736                 }
737                 dargs[n] = param.Value
738         }
739         return dargs, nil
740 }
741
742 func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
743         switch sql.IsolationLevel(level) {
744         case sql.LevelRepeatableRead:
745                 return "REPEATABLE READ", nil
746         case sql.LevelReadCommitted:
747                 return "READ COMMITTED", nil
748         case sql.LevelReadUncommitted:
749                 return "READ UNCOMMITTED", nil
750         case sql.LevelSerializable:
751                 return "SERIALIZABLE", nil
752         default:
753                 return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
754         }
755 }