OSDN Git Service

feat: init cross_tx keepers (#146)
[bytom/vapor.git] / vendor / github.com / go-sql-driver / mysql / connection.go
1 // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2 //
3 // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this file,
7 // You can obtain one at http://mozilla.org/MPL/2.0/.
8
9 package mysql
10
11 import (
12         "context"
13         "database/sql"
14         "database/sql/driver"
15         "io"
16         "net"
17         "strconv"
18         "strings"
19         "time"
20 )
21
22 // a copy of context.Context for Go 1.7 and earlier
23 type mysqlContext interface {
24         Done() <-chan struct{}
25         Err() error
26
27         // defined in context.Context, but not used in this driver:
28         // Deadline() (deadline time.Time, ok bool)
29         // Value(key interface{}) interface{}
30 }
31
32 type mysqlConn struct {
33         buf              buffer
34         netConn          net.Conn
35         affectedRows     uint64
36         insertId         uint64
37         cfg              *Config
38         maxAllowedPacket int
39         maxWriteSize     int
40         writeTimeout     time.Duration
41         flags            clientFlag
42         status           statusFlag
43         sequence         uint8
44         parseTime        bool
45
46         // for context support (Go 1.8+)
47         watching bool
48         watcher  chan<- mysqlContext
49         closech  chan struct{}
50         finished chan<- struct{}
51         canceled atomicError // set non-nil if conn is canceled
52         closed   atomicBool  // set when conn is closed, before closech is closed
53 }
54
55 // Handles parameters set in DSN after the connection is established
56 func (mc *mysqlConn) handleParams() (err error) {
57         for param, val := range mc.cfg.Params {
58                 switch param {
59                 // Charset
60                 case "charset":
61                         charsets := strings.Split(val, ",")
62                         for i := range charsets {
63                                 // ignore errors here - a charset may not exist
64                                 err = mc.exec("SET NAMES " + charsets[i])
65                                 if err == nil {
66                                         break
67                                 }
68                         }
69                         if err != nil {
70                                 return
71                         }
72
73                 // System Vars
74                 default:
75                         err = mc.exec("SET " + param + "=" + val + "")
76                         if err != nil {
77                                 return
78                         }
79                 }
80         }
81
82         return
83 }
84
85 func (mc *mysqlConn) markBadConn(err error) error {
86         if mc == nil {
87                 return err
88         }
89         if err != errBadConnNoWrite {
90                 return err
91         }
92         return driver.ErrBadConn
93 }
94
95 func (mc *mysqlConn) Begin() (driver.Tx, error) {
96         return mc.begin(false)
97 }
98
99 func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
100         if mc.closed.IsSet() {
101                 errLog.Print(ErrInvalidConn)
102                 return nil, driver.ErrBadConn
103         }
104         var q string
105         if readOnly {
106                 q = "START TRANSACTION READ ONLY"
107         } else {
108                 q = "START TRANSACTION"
109         }
110         err := mc.exec(q)
111         if err == nil {
112                 return &mysqlTx{mc}, err
113         }
114         return nil, mc.markBadConn(err)
115 }
116
117 func (mc *mysqlConn) Close() (err error) {
118         // Makes Close idempotent
119         if !mc.closed.IsSet() {
120                 err = mc.writeCommandPacket(comQuit)
121         }
122
123         mc.cleanup()
124
125         return
126 }
127
128 // Closes the network connection and unsets internal variables. Do not call this
129 // function after successfully authentication, call Close instead. This function
130 // is called before auth or on auth failure because MySQL will have already
131 // closed the network connection.
132 func (mc *mysqlConn) cleanup() {
133         if !mc.closed.TrySet(true) {
134                 return
135         }
136
137         // Makes cleanup idempotent
138         close(mc.closech)
139         if mc.netConn == nil {
140                 return
141         }
142         if err := mc.netConn.Close(); err != nil {
143                 errLog.Print(err)
144         }
145 }
146
147 func (mc *mysqlConn) error() error {
148         if mc.closed.IsSet() {
149                 if err := mc.canceled.Value(); err != nil {
150                         return err
151                 }
152                 return ErrInvalidConn
153         }
154         return nil
155 }
156
157 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
158         if mc.closed.IsSet() {
159                 errLog.Print(ErrInvalidConn)
160                 return nil, driver.ErrBadConn
161         }
162         // Send command
163         err := mc.writeCommandPacketStr(comStmtPrepare, query)
164         if err != nil {
165                 return nil, mc.markBadConn(err)
166         }
167
168         stmt := &mysqlStmt{
169                 mc: mc,
170         }
171
172         // Read Result
173         columnCount, err := stmt.readPrepareResultPacket()
174         if err == nil {
175                 if stmt.paramCount > 0 {
176                         if err = mc.readUntilEOF(); err != nil {
177                                 return nil, err
178                         }
179                 }
180
181                 if columnCount > 0 {
182                         err = mc.readUntilEOF()
183                 }
184         }
185
186         return stmt, err
187 }
188
189 func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
190         // Number of ? should be same to len(args)
191         if strings.Count(query, "?") != len(args) {
192                 return "", driver.ErrSkip
193         }
194
195         buf := mc.buf.takeCompleteBuffer()
196         if buf == nil {
197                 // can not take the buffer. Something must be wrong with the connection
198                 errLog.Print(ErrBusyBuffer)
199                 return "", ErrInvalidConn
200         }
201         buf = buf[:0]
202         argPos := 0
203
204         for i := 0; i < len(query); i++ {
205                 q := strings.IndexByte(query[i:], '?')
206                 if q == -1 {
207                         buf = append(buf, query[i:]...)
208                         break
209                 }
210                 buf = append(buf, query[i:i+q]...)
211                 i += q
212
213                 arg := args[argPos]
214                 argPos++
215
216                 if arg == nil {
217                         buf = append(buf, "NULL"...)
218                         continue
219                 }
220
221                 switch v := arg.(type) {
222                 case int64:
223                         buf = strconv.AppendInt(buf, v, 10)
224                 case float64:
225                         buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
226                 case bool:
227                         if v {
228                                 buf = append(buf, '1')
229                         } else {
230                                 buf = append(buf, '0')
231                         }
232                 case time.Time:
233                         if v.IsZero() {
234                                 buf = append(buf, "'0000-00-00'"...)
235                         } else {
236                                 v := v.In(mc.cfg.Loc)
237                                 v = v.Add(time.Nanosecond * 500) // To round under microsecond
238                                 year := v.Year()
239                                 year100 := year / 100
240                                 year1 := year % 100
241                                 month := v.Month()
242                                 day := v.Day()
243                                 hour := v.Hour()
244                                 minute := v.Minute()
245                                 second := v.Second()
246                                 micro := v.Nanosecond() / 1000
247
248                                 buf = append(buf, []byte{
249                                         '\'',
250                                         digits10[year100], digits01[year100],
251                                         digits10[year1], digits01[year1],
252                                         '-',
253                                         digits10[month], digits01[month],
254                                         '-',
255                                         digits10[day], digits01[day],
256                                         ' ',
257                                         digits10[hour], digits01[hour],
258                                         ':',
259                                         digits10[minute], digits01[minute],
260                                         ':',
261                                         digits10[second], digits01[second],
262                                 }...)
263
264                                 if micro != 0 {
265                                         micro10000 := micro / 10000
266                                         micro100 := micro / 100 % 100
267                                         micro1 := micro % 100
268                                         buf = append(buf, []byte{
269                                                 '.',
270                                                 digits10[micro10000], digits01[micro10000],
271                                                 digits10[micro100], digits01[micro100],
272                                                 digits10[micro1], digits01[micro1],
273                                         }...)
274                                 }
275                                 buf = append(buf, '\'')
276                         }
277                 case []byte:
278                         if v == nil {
279                                 buf = append(buf, "NULL"...)
280                         } else {
281                                 buf = append(buf, "_binary'"...)
282                                 if mc.status&statusNoBackslashEscapes == 0 {
283                                         buf = escapeBytesBackslash(buf, v)
284                                 } else {
285                                         buf = escapeBytesQuotes(buf, v)
286                                 }
287                                 buf = append(buf, '\'')
288                         }
289                 case string:
290                         buf = append(buf, '\'')
291                         if mc.status&statusNoBackslashEscapes == 0 {
292                                 buf = escapeStringBackslash(buf, v)
293                         } else {
294                                 buf = escapeStringQuotes(buf, v)
295                         }
296                         buf = append(buf, '\'')
297                 default:
298                         return "", driver.ErrSkip
299                 }
300
301                 if len(buf)+4 > mc.maxAllowedPacket {
302                         return "", driver.ErrSkip
303                 }
304         }
305         if argPos != len(args) {
306                 return "", driver.ErrSkip
307         }
308         return string(buf), nil
309 }
310
311 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
312         if mc.closed.IsSet() {
313                 errLog.Print(ErrInvalidConn)
314                 return nil, driver.ErrBadConn
315         }
316         if len(args) != 0 {
317                 if !mc.cfg.InterpolateParams {
318                         return nil, driver.ErrSkip
319                 }
320                 // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
321                 prepared, err := mc.interpolateParams(query, args)
322                 if err != nil {
323                         return nil, err
324                 }
325                 query = prepared
326         }
327         mc.affectedRows = 0
328         mc.insertId = 0
329
330         err := mc.exec(query)
331         if err == nil {
332                 return &mysqlResult{
333                         affectedRows: int64(mc.affectedRows),
334                         insertId:     int64(mc.insertId),
335                 }, err
336         }
337         return nil, mc.markBadConn(err)
338 }
339
340 // Internal function to execute commands
341 func (mc *mysqlConn) exec(query string) error {
342         // Send command
343         if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
344                 return mc.markBadConn(err)
345         }
346
347         // Read Result
348         resLen, err := mc.readResultSetHeaderPacket()
349         if err != nil {
350                 return err
351         }
352
353         if resLen > 0 {
354                 // columns
355                 if err := mc.readUntilEOF(); err != nil {
356                         return err
357                 }
358
359                 // rows
360                 if err := mc.readUntilEOF(); err != nil {
361                         return err
362                 }
363         }
364
365         return mc.discardResults()
366 }
367
368 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
369         return mc.query(query, args)
370 }
371
372 func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
373         if mc.closed.IsSet() {
374                 errLog.Print(ErrInvalidConn)
375                 return nil, driver.ErrBadConn
376         }
377         if len(args) != 0 {
378                 if !mc.cfg.InterpolateParams {
379                         return nil, driver.ErrSkip
380                 }
381                 // try client-side prepare to reduce roundtrip
382                 prepared, err := mc.interpolateParams(query, args)
383                 if err != nil {
384                         return nil, err
385                 }
386                 query = prepared
387         }
388         // Send command
389         err := mc.writeCommandPacketStr(comQuery, query)
390         if err == nil {
391                 // Read Result
392                 var resLen int
393                 resLen, err = mc.readResultSetHeaderPacket()
394                 if err == nil {
395                         rows := new(textRows)
396                         rows.mc = mc
397
398                         if resLen == 0 {
399                                 rows.rs.done = true
400
401                                 switch err := rows.NextResultSet(); err {
402                                 case nil, io.EOF:
403                                         return rows, nil
404                                 default:
405                                         return nil, err
406                                 }
407                         }
408
409                         // Columns
410                         rows.rs.columns, err = mc.readColumns(resLen)
411                         return rows, err
412                 }
413         }
414         return nil, mc.markBadConn(err)
415 }
416
417 // Gets the value of the given MySQL System Variable
418 // The returned byte slice is only valid until the next read
419 func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
420         // Send command
421         if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
422                 return nil, err
423         }
424
425         // Read Result
426         resLen, err := mc.readResultSetHeaderPacket()
427         if err == nil {
428                 rows := new(textRows)
429                 rows.mc = mc
430                 rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
431
432                 if resLen > 0 {
433                         // Columns
434                         if err := mc.readUntilEOF(); err != nil {
435                                 return nil, err
436                         }
437                 }
438
439                 dest := make([]driver.Value, resLen)
440                 if err = rows.readRow(dest); err == nil {
441                         return dest[0].([]byte), mc.readUntilEOF()
442                 }
443         }
444         return nil, err
445 }
446
447 // finish is called when the query has canceled.
448 func (mc *mysqlConn) cancel(err error) {
449         mc.canceled.Set(err)
450         mc.cleanup()
451 }
452
453 // finish is called when the query has succeeded.
454 func (mc *mysqlConn) finish() {
455         if !mc.watching || mc.finished == nil {
456                 return
457         }
458         select {
459         case mc.finished <- struct{}{}:
460                 mc.watching = false
461         case <-mc.closech:
462         }
463 }
464
465 // Ping implements driver.Pinger interface
466 func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
467         if mc.closed.IsSet() {
468                 errLog.Print(ErrInvalidConn)
469                 return driver.ErrBadConn
470         }
471
472         if err = mc.watchCancel(ctx); err != nil {
473                 return
474         }
475         defer mc.finish()
476
477         if err = mc.writeCommandPacket(comPing); err != nil {
478                 return
479         }
480
481         return mc.readResultOK()
482 }
483
484 // BeginTx implements driver.ConnBeginTx interface
485 func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
486         if err := mc.watchCancel(ctx); err != nil {
487                 return nil, err
488         }
489         defer mc.finish()
490
491         if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
492                 level, err := mapIsolationLevel(opts.Isolation)
493                 if err != nil {
494                         return nil, err
495                 }
496                 err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
497                 if err != nil {
498                         return nil, err
499                 }
500         }
501
502         return mc.begin(opts.ReadOnly)
503 }
504
505 func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
506         dargs, err := namedValueToValue(args)
507         if err != nil {
508                 return nil, err
509         }
510
511         if err := mc.watchCancel(ctx); err != nil {
512                 return nil, err
513         }
514
515         rows, err := mc.query(query, dargs)
516         if err != nil {
517                 mc.finish()
518                 return nil, err
519         }
520         rows.finish = mc.finish
521         return rows, err
522 }
523
524 func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
525         dargs, err := namedValueToValue(args)
526         if err != nil {
527                 return nil, err
528         }
529
530         if err := mc.watchCancel(ctx); err != nil {
531                 return nil, err
532         }
533         defer mc.finish()
534
535         return mc.Exec(query, dargs)
536 }
537
538 func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
539         if err := mc.watchCancel(ctx); err != nil {
540                 return nil, err
541         }
542
543         stmt, err := mc.Prepare(query)
544         mc.finish()
545         if err != nil {
546                 return nil, err
547         }
548
549         select {
550         default:
551         case <-ctx.Done():
552                 stmt.Close()
553                 return nil, ctx.Err()
554         }
555         return stmt, nil
556 }
557
558 func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
559         dargs, err := namedValueToValue(args)
560         if err != nil {
561                 return nil, err
562         }
563
564         if err := stmt.mc.watchCancel(ctx); err != nil {
565                 return nil, err
566         }
567
568         rows, err := stmt.query(dargs)
569         if err != nil {
570                 stmt.mc.finish()
571                 return nil, err
572         }
573         rows.finish = stmt.mc.finish
574         return rows, err
575 }
576
577 func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
578         dargs, err := namedValueToValue(args)
579         if err != nil {
580                 return nil, err
581         }
582
583         if err := stmt.mc.watchCancel(ctx); err != nil {
584                 return nil, err
585         }
586         defer stmt.mc.finish()
587
588         return stmt.Exec(dargs)
589 }
590
591 func (mc *mysqlConn) watchCancel(ctx context.Context) error {
592         if mc.watching {
593                 // Reach here if canceled,
594                 // so the connection is already invalid
595                 mc.cleanup()
596                 return nil
597         }
598         if ctx.Done() == nil {
599                 return nil
600         }
601
602         mc.watching = true
603         select {
604         default:
605         case <-ctx.Done():
606                 return ctx.Err()
607         }
608         if mc.watcher == nil {
609                 return nil
610         }
611
612         mc.watcher <- ctx
613
614         return nil
615 }
616
617 func (mc *mysqlConn) startWatcher() {
618         watcher := make(chan mysqlContext, 1)
619         mc.watcher = watcher
620         finished := make(chan struct{})
621         mc.finished = finished
622         go func() {
623                 for {
624                         var ctx mysqlContext
625                         select {
626                         case ctx = <-watcher:
627                         case <-mc.closech:
628                                 return
629                         }
630
631                         select {
632                         case <-ctx.Done():
633                                 mc.cancel(ctx.Err())
634                         case <-finished:
635                         case <-mc.closech:
636                                 return
637                         }
638                 }
639         }()
640 }
641
642 func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
643         nv.Value, err = converter{}.ConvertValue(nv.Value)
644         return
645 }
646
647 // ResetSession implements driver.SessionResetter.
648 // (From Go 1.10)
649 func (mc *mysqlConn) ResetSession(ctx context.Context) error {
650         if mc.closed.IsSet() {
651                 return driver.ErrBadConn
652         }
653         return nil
654 }