OSDN Git Service

Merge pull request #201 from Bytom/v0.1
[bytom/vapor.git] / vendor / github.com / go-sql-driver / mysql / connection.go
diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go
new file mode 100644 (file)
index 0000000..911be20
--- /dev/null
@@ -0,0 +1,654 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+       "context"
+       "database/sql"
+       "database/sql/driver"
+       "io"
+       "net"
+       "strconv"
+       "strings"
+       "time"
+)
+
+// a copy of context.Context for Go 1.7 and earlier
+type mysqlContext interface {
+       Done() <-chan struct{}
+       Err() error
+
+       // defined in context.Context, but not used in this driver:
+       // Deadline() (deadline time.Time, ok bool)
+       // Value(key interface{}) interface{}
+}
+
+type mysqlConn struct {
+       buf              buffer
+       netConn          net.Conn
+       affectedRows     uint64
+       insertId         uint64
+       cfg              *Config
+       maxAllowedPacket int
+       maxWriteSize     int
+       writeTimeout     time.Duration
+       flags            clientFlag
+       status           statusFlag
+       sequence         uint8
+       parseTime        bool
+
+       // for context support (Go 1.8+)
+       watching bool
+       watcher  chan<- mysqlContext
+       closech  chan struct{}
+       finished chan<- struct{}
+       canceled atomicError // set non-nil if conn is canceled
+       closed   atomicBool  // set when conn is closed, before closech is closed
+}
+
+// Handles parameters set in DSN after the connection is established
+func (mc *mysqlConn) handleParams() (err error) {
+       for param, val := range mc.cfg.Params {
+               switch param {
+               // Charset
+               case "charset":
+                       charsets := strings.Split(val, ",")
+                       for i := range charsets {
+                               // ignore errors here - a charset may not exist
+                               err = mc.exec("SET NAMES " + charsets[i])
+                               if err == nil {
+                                       break
+                               }
+                       }
+                       if err != nil {
+                               return
+                       }
+
+               // System Vars
+               default:
+                       err = mc.exec("SET " + param + "=" + val + "")
+                       if err != nil {
+                               return
+                       }
+               }
+       }
+
+       return
+}
+
+func (mc *mysqlConn) markBadConn(err error) error {
+       if mc == nil {
+               return err
+       }
+       if err != errBadConnNoWrite {
+               return err
+       }
+       return driver.ErrBadConn
+}
+
+func (mc *mysqlConn) Begin() (driver.Tx, error) {
+       return mc.begin(false)
+}
+
+func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
+       if mc.closed.IsSet() {
+               errLog.Print(ErrInvalidConn)
+               return nil, driver.ErrBadConn
+       }
+       var q string
+       if readOnly {
+               q = "START TRANSACTION READ ONLY"
+       } else {
+               q = "START TRANSACTION"
+       }
+       err := mc.exec(q)
+       if err == nil {
+               return &mysqlTx{mc}, err
+       }
+       return nil, mc.markBadConn(err)
+}
+
+func (mc *mysqlConn) Close() (err error) {
+       // Makes Close idempotent
+       if !mc.closed.IsSet() {
+               err = mc.writeCommandPacket(comQuit)
+       }
+
+       mc.cleanup()
+
+       return
+}
+
+// Closes the network connection and unsets internal variables. Do not call this
+// function after successfully authentication, call Close instead. This function
+// is called before auth or on auth failure because MySQL will have already
+// closed the network connection.
+func (mc *mysqlConn) cleanup() {
+       if !mc.closed.TrySet(true) {
+               return
+       }
+
+       // Makes cleanup idempotent
+       close(mc.closech)
+       if mc.netConn == nil {
+               return
+       }
+       if err := mc.netConn.Close(); err != nil {
+               errLog.Print(err)
+       }
+}
+
+func (mc *mysqlConn) error() error {
+       if mc.closed.IsSet() {
+               if err := mc.canceled.Value(); err != nil {
+                       return err
+               }
+               return ErrInvalidConn
+       }
+       return nil
+}
+
+func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
+       if mc.closed.IsSet() {
+               errLog.Print(ErrInvalidConn)
+               return nil, driver.ErrBadConn
+       }
+       // Send command
+       err := mc.writeCommandPacketStr(comStmtPrepare, query)
+       if err != nil {
+               return nil, mc.markBadConn(err)
+       }
+
+       stmt := &mysqlStmt{
+               mc: mc,
+       }
+
+       // Read Result
+       columnCount, err := stmt.readPrepareResultPacket()
+       if err == nil {
+               if stmt.paramCount > 0 {
+                       if err = mc.readUntilEOF(); err != nil {
+                               return nil, err
+                       }
+               }
+
+               if columnCount > 0 {
+                       err = mc.readUntilEOF()
+               }
+       }
+
+       return stmt, err
+}
+
+func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
+       // Number of ? should be same to len(args)
+       if strings.Count(query, "?") != len(args) {
+               return "", driver.ErrSkip
+       }
+
+       buf := mc.buf.takeCompleteBuffer()
+       if buf == nil {
+               // can not take the buffer. Something must be wrong with the connection
+               errLog.Print(ErrBusyBuffer)
+               return "", ErrInvalidConn
+       }
+       buf = buf[:0]
+       argPos := 0
+
+       for i := 0; i < len(query); i++ {
+               q := strings.IndexByte(query[i:], '?')
+               if q == -1 {
+                       buf = append(buf, query[i:]...)
+                       break
+               }
+               buf = append(buf, query[i:i+q]...)
+               i += q
+
+               arg := args[argPos]
+               argPos++
+
+               if arg == nil {
+                       buf = append(buf, "NULL"...)
+                       continue
+               }
+
+               switch v := arg.(type) {
+               case int64:
+                       buf = strconv.AppendInt(buf, v, 10)
+               case float64:
+                       buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
+               case bool:
+                       if v {
+                               buf = append(buf, '1')
+                       } else {
+                               buf = append(buf, '0')
+                       }
+               case time.Time:
+                       if v.IsZero() {
+                               buf = append(buf, "'0000-00-00'"...)
+                       } else {
+                               v := v.In(mc.cfg.Loc)
+                               v = v.Add(time.Nanosecond * 500) // To round under microsecond
+                               year := v.Year()
+                               year100 := year / 100
+                               year1 := year % 100
+                               month := v.Month()
+                               day := v.Day()
+                               hour := v.Hour()
+                               minute := v.Minute()
+                               second := v.Second()
+                               micro := v.Nanosecond() / 1000
+
+                               buf = append(buf, []byte{
+                                       '\'',
+                                       digits10[year100], digits01[year100],
+                                       digits10[year1], digits01[year1],
+                                       '-',
+                                       digits10[month], digits01[month],
+                                       '-',
+                                       digits10[day], digits01[day],
+                                       ' ',
+                                       digits10[hour], digits01[hour],
+                                       ':',
+                                       digits10[minute], digits01[minute],
+                                       ':',
+                                       digits10[second], digits01[second],
+                               }...)
+
+                               if micro != 0 {
+                                       micro10000 := micro / 10000
+                                       micro100 := micro / 100 % 100
+                                       micro1 := micro % 100
+                                       buf = append(buf, []byte{
+                                               '.',
+                                               digits10[micro10000], digits01[micro10000],
+                                               digits10[micro100], digits01[micro100],
+                                               digits10[micro1], digits01[micro1],
+                                       }...)
+                               }
+                               buf = append(buf, '\'')
+                       }
+               case []byte:
+                       if v == nil {
+                               buf = append(buf, "NULL"...)
+                       } else {
+                               buf = append(buf, "_binary'"...)
+                               if mc.status&statusNoBackslashEscapes == 0 {
+                                       buf = escapeBytesBackslash(buf, v)
+                               } else {
+                                       buf = escapeBytesQuotes(buf, v)
+                               }
+                               buf = append(buf, '\'')
+                       }
+               case string:
+                       buf = append(buf, '\'')
+                       if mc.status&statusNoBackslashEscapes == 0 {
+                               buf = escapeStringBackslash(buf, v)
+                       } else {
+                               buf = escapeStringQuotes(buf, v)
+                       }
+                       buf = append(buf, '\'')
+               default:
+                       return "", driver.ErrSkip
+               }
+
+               if len(buf)+4 > mc.maxAllowedPacket {
+                       return "", driver.ErrSkip
+               }
+       }
+       if argPos != len(args) {
+               return "", driver.ErrSkip
+       }
+       return string(buf), nil
+}
+
+func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+       if mc.closed.IsSet() {
+               errLog.Print(ErrInvalidConn)
+               return nil, driver.ErrBadConn
+       }
+       if len(args) != 0 {
+               if !mc.cfg.InterpolateParams {
+                       return nil, driver.ErrSkip
+               }
+               // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
+               prepared, err := mc.interpolateParams(query, args)
+               if err != nil {
+                       return nil, err
+               }
+               query = prepared
+       }
+       mc.affectedRows = 0
+       mc.insertId = 0
+
+       err := mc.exec(query)
+       if err == nil {
+               return &mysqlResult{
+                       affectedRows: int64(mc.affectedRows),
+                       insertId:     int64(mc.insertId),
+               }, err
+       }
+       return nil, mc.markBadConn(err)
+}
+
+// Internal function to execute commands
+func (mc *mysqlConn) exec(query string) error {
+       // Send command
+       if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
+               return mc.markBadConn(err)
+       }
+
+       // Read Result
+       resLen, err := mc.readResultSetHeaderPacket()
+       if err != nil {
+               return err
+       }
+
+       if resLen > 0 {
+               // columns
+               if err := mc.readUntilEOF(); err != nil {
+                       return err
+               }
+
+               // rows
+               if err := mc.readUntilEOF(); err != nil {
+                       return err
+               }
+       }
+
+       return mc.discardResults()
+}
+
+func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+       return mc.query(query, args)
+}
+
+func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
+       if mc.closed.IsSet() {
+               errLog.Print(ErrInvalidConn)
+               return nil, driver.ErrBadConn
+       }
+       if len(args) != 0 {
+               if !mc.cfg.InterpolateParams {
+                       return nil, driver.ErrSkip
+               }
+               // try client-side prepare to reduce roundtrip
+               prepared, err := mc.interpolateParams(query, args)
+               if err != nil {
+                       return nil, err
+               }
+               query = prepared
+       }
+       // Send command
+       err := mc.writeCommandPacketStr(comQuery, query)
+       if err == nil {
+               // Read Result
+               var resLen int
+               resLen, err = mc.readResultSetHeaderPacket()
+               if err == nil {
+                       rows := new(textRows)
+                       rows.mc = mc
+
+                       if resLen == 0 {
+                               rows.rs.done = true
+
+                               switch err := rows.NextResultSet(); err {
+                               case nil, io.EOF:
+                                       return rows, nil
+                               default:
+                                       return nil, err
+                               }
+                       }
+
+                       // Columns
+                       rows.rs.columns, err = mc.readColumns(resLen)
+                       return rows, err
+               }
+       }
+       return nil, mc.markBadConn(err)
+}
+
+// Gets the value of the given MySQL System Variable
+// The returned byte slice is only valid until the next read
+func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
+       // Send command
+       if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
+               return nil, err
+       }
+
+       // Read Result
+       resLen, err := mc.readResultSetHeaderPacket()
+       if err == nil {
+               rows := new(textRows)
+               rows.mc = mc
+               rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
+
+               if resLen > 0 {
+                       // Columns
+                       if err := mc.readUntilEOF(); err != nil {
+                               return nil, err
+                       }
+               }
+
+               dest := make([]driver.Value, resLen)
+               if err = rows.readRow(dest); err == nil {
+                       return dest[0].([]byte), mc.readUntilEOF()
+               }
+       }
+       return nil, err
+}
+
+// finish is called when the query has canceled.
+func (mc *mysqlConn) cancel(err error) {
+       mc.canceled.Set(err)
+       mc.cleanup()
+}
+
+// finish is called when the query has succeeded.
+func (mc *mysqlConn) finish() {
+       if !mc.watching || mc.finished == nil {
+               return
+       }
+       select {
+       case mc.finished <- struct{}{}:
+               mc.watching = false
+       case <-mc.closech:
+       }
+}
+
+// Ping implements driver.Pinger interface
+func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
+       if mc.closed.IsSet() {
+               errLog.Print(ErrInvalidConn)
+               return driver.ErrBadConn
+       }
+
+       if err = mc.watchCancel(ctx); err != nil {
+               return
+       }
+       defer mc.finish()
+
+       if err = mc.writeCommandPacket(comPing); err != nil {
+               return
+       }
+
+       return mc.readResultOK()
+}
+
+// BeginTx implements driver.ConnBeginTx interface
+func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
+       if err := mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+       defer mc.finish()
+
+       if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
+               level, err := mapIsolationLevel(opts.Isolation)
+               if err != nil {
+                       return nil, err
+               }
+               err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
+               if err != nil {
+                       return nil, err
+               }
+       }
+
+       return mc.begin(opts.ReadOnly)
+}
+
+func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
+       dargs, err := namedValueToValue(args)
+       if err != nil {
+               return nil, err
+       }
+
+       if err := mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+
+       rows, err := mc.query(query, dargs)
+       if err != nil {
+               mc.finish()
+               return nil, err
+       }
+       rows.finish = mc.finish
+       return rows, err
+}
+
+func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
+       dargs, err := namedValueToValue(args)
+       if err != nil {
+               return nil, err
+       }
+
+       if err := mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+       defer mc.finish()
+
+       return mc.Exec(query, dargs)
+}
+
+func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+       if err := mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+
+       stmt, err := mc.Prepare(query)
+       mc.finish()
+       if err != nil {
+               return nil, err
+       }
+
+       select {
+       default:
+       case <-ctx.Done():
+               stmt.Close()
+               return nil, ctx.Err()
+       }
+       return stmt, nil
+}
+
+func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+       dargs, err := namedValueToValue(args)
+       if err != nil {
+               return nil, err
+       }
+
+       if err := stmt.mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+
+       rows, err := stmt.query(dargs)
+       if err != nil {
+               stmt.mc.finish()
+               return nil, err
+       }
+       rows.finish = stmt.mc.finish
+       return rows, err
+}
+
+func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+       dargs, err := namedValueToValue(args)
+       if err != nil {
+               return nil, err
+       }
+
+       if err := stmt.mc.watchCancel(ctx); err != nil {
+               return nil, err
+       }
+       defer stmt.mc.finish()
+
+       return stmt.Exec(dargs)
+}
+
+func (mc *mysqlConn) watchCancel(ctx context.Context) error {
+       if mc.watching {
+               // Reach here if canceled,
+               // so the connection is already invalid
+               mc.cleanup()
+               return nil
+       }
+       if ctx.Done() == nil {
+               return nil
+       }
+
+       mc.watching = true
+       select {
+       default:
+       case <-ctx.Done():
+               return ctx.Err()
+       }
+       if mc.watcher == nil {
+               return nil
+       }
+
+       mc.watcher <- ctx
+
+       return nil
+}
+
+func (mc *mysqlConn) startWatcher() {
+       watcher := make(chan mysqlContext, 1)
+       mc.watcher = watcher
+       finished := make(chan struct{})
+       mc.finished = finished
+       go func() {
+               for {
+                       var ctx mysqlContext
+                       select {
+                       case ctx = <-watcher:
+                       case <-mc.closech:
+                               return
+                       }
+
+                       select {
+                       case <-ctx.Done():
+                               mc.cancel(ctx.Err())
+                       case <-finished:
+                       case <-mc.closech:
+                               return
+                       }
+               }
+       }()
+}
+
+func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
+       nv.Value, err = converter{}.ConvertValue(nv.Value)
+       return
+}
+
+// ResetSession implements driver.SessionResetter.
+// (From Go 1.10)
+func (mc *mysqlConn) ResetSession(ctx context.Context) error {
+       if mc.closed.IsSet() {
+               return driver.ErrBadConn
+       }
+       return nil
+}