OSDN Git Service

Merge pull request #201 from Bytom/v0.1
[bytom/vapor.git] / vendor / github.com / go-sql-driver / mysql / packets.go
diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go
new file mode 100644 (file)
index 0000000..170aaa0
--- /dev/null
@@ -0,0 +1,1298 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+       "bytes"
+       "crypto/tls"
+       "database/sql/driver"
+       "encoding/binary"
+       "errors"
+       "fmt"
+       "io"
+       "math"
+       "time"
+)
+
+// Packets documentation:
+// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
+
+// Read packet to buffer 'data'
+func (mc *mysqlConn) readPacket() ([]byte, error) {
+       var prevData []byte
+       for {
+               // read packet header
+               data, err := mc.buf.readNext(4)
+               if err != nil {
+                       if cerr := mc.canceled.Value(); cerr != nil {
+                               return nil, cerr
+                       }
+                       errLog.Print(err)
+                       mc.Close()
+                       return nil, ErrInvalidConn
+               }
+
+               // packet length [24 bit]
+               pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
+
+               // check packet sync [8 bit]
+               if data[3] != mc.sequence {
+                       if data[3] > mc.sequence {
+                               return nil, ErrPktSyncMul
+                       }
+                       return nil, ErrPktSync
+               }
+               mc.sequence++
+
+               // packets with length 0 terminate a previous packet which is a
+               // multiple of (2^24)−1 bytes long
+               if pktLen == 0 {
+                       // there was no previous packet
+                       if prevData == nil {
+                               errLog.Print(ErrMalformPkt)
+                               mc.Close()
+                               return nil, ErrInvalidConn
+                       }
+
+                       return prevData, nil
+               }
+
+               // read packet body [pktLen bytes]
+               data, err = mc.buf.readNext(pktLen)
+               if err != nil {
+                       if cerr := mc.canceled.Value(); cerr != nil {
+                               return nil, cerr
+                       }
+                       errLog.Print(err)
+                       mc.Close()
+                       return nil, ErrInvalidConn
+               }
+
+               // return data if this was the last packet
+               if pktLen < maxPacketSize {
+                       // zero allocations for non-split packets
+                       if prevData == nil {
+                               return data, nil
+                       }
+
+                       return append(prevData, data...), nil
+               }
+
+               prevData = append(prevData, data...)
+       }
+}
+
+// Write packet buffer 'data'
+func (mc *mysqlConn) writePacket(data []byte) error {
+       pktLen := len(data) - 4
+
+       if pktLen > mc.maxAllowedPacket {
+               return ErrPktTooLarge
+       }
+
+       for {
+               var size int
+               if pktLen >= maxPacketSize {
+                       data[0] = 0xff
+                       data[1] = 0xff
+                       data[2] = 0xff
+                       size = maxPacketSize
+               } else {
+                       data[0] = byte(pktLen)
+                       data[1] = byte(pktLen >> 8)
+                       data[2] = byte(pktLen >> 16)
+                       size = pktLen
+               }
+               data[3] = mc.sequence
+
+               // Write packet
+               if mc.writeTimeout > 0 {
+                       if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
+                               return err
+                       }
+               }
+
+               n, err := mc.netConn.Write(data[:4+size])
+               if err == nil && n == 4+size {
+                       mc.sequence++
+                       if size != maxPacketSize {
+                               return nil
+                       }
+                       pktLen -= size
+                       data = data[size:]
+                       continue
+               }
+
+               // Handle error
+               if err == nil { // n != len(data)
+                       mc.cleanup()
+                       errLog.Print(ErrMalformPkt)
+               } else {
+                       if cerr := mc.canceled.Value(); cerr != nil {
+                               return cerr
+                       }
+                       if n == 0 && pktLen == len(data)-4 {
+                               // only for the first loop iteration when nothing was written yet
+                               return errBadConnNoWrite
+                       }
+                       mc.cleanup()
+                       errLog.Print(err)
+               }
+               return ErrInvalidConn
+       }
+}
+
+/******************************************************************************
+*                           Initialization Process                            *
+******************************************************************************/
+
+// Handshake Initialization Packet
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
+func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
+       data, err = mc.readPacket()
+       if err != nil {
+               // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
+               // in connection initialization we don't risk retrying non-idempotent actions.
+               if err == ErrInvalidConn {
+                       return nil, "", driver.ErrBadConn
+               }
+               return
+       }
+
+       if data[0] == iERR {
+               return nil, "", mc.handleErrorPacket(data)
+       }
+
+       // protocol version [1 byte]
+       if data[0] < minProtocolVersion {
+               return nil, "", fmt.Errorf(
+                       "unsupported protocol version %d. Version %d or higher is required",
+                       data[0],
+                       minProtocolVersion,
+               )
+       }
+
+       // server version [null terminated string]
+       // connection id [4 bytes]
+       pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
+
+       // first part of the password cipher [8 bytes]
+       authData := data[pos : pos+8]
+
+       // (filler) always 0x00 [1 byte]
+       pos += 8 + 1
+
+       // capability flags (lower 2 bytes) [2 bytes]
+       mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
+       if mc.flags&clientProtocol41 == 0 {
+               return nil, "", ErrOldProtocol
+       }
+       if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
+               return nil, "", ErrNoTLS
+       }
+       pos += 2
+
+       if len(data) > pos {
+               // character set [1 byte]
+               // status flags [2 bytes]
+               // capability flags (upper 2 bytes) [2 bytes]
+               // length of auth-plugin-data [1 byte]
+               // reserved (all [00]) [10 bytes]
+               pos += 1 + 2 + 2 + 1 + 10
+
+               // second part of the password cipher [mininum 13 bytes],
+               // where len=MAX(13, length of auth-plugin-data - 8)
+               //
+               // The web documentation is ambiguous about the length. However,
+               // according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
+               // the 13th byte is "\0 byte, terminating the second part of
+               // a scramble". So the second part of the password cipher is
+               // a NULL terminated string that's at least 13 bytes with the
+               // last byte being NULL.
+               //
+               // The official Python library uses the fixed length 12
+               // which seems to work but technically could have a hidden bug.
+               authData = append(authData, data[pos:pos+12]...)
+               pos += 13
+
+               // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
+               // \NUL otherwise
+               if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
+                       plugin = string(data[pos : pos+end])
+               } else {
+                       plugin = string(data[pos:])
+               }
+
+               // make a memory safe copy of the cipher slice
+               var b [20]byte
+               copy(b[:], authData)
+               return b[:], plugin, nil
+       }
+
+       // make a memory safe copy of the cipher slice
+       var b [8]byte
+       copy(b[:], authData)
+       return b[:], plugin, nil
+}
+
+// Client Authentication Packet
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
+func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
+       // Adjust client flags based on server support
+       clientFlags := clientProtocol41 |
+               clientSecureConn |
+               clientLongPassword |
+               clientTransactions |
+               clientLocalFiles |
+               clientPluginAuth |
+               clientMultiResults |
+               mc.flags&clientLongFlag
+
+       if mc.cfg.ClientFoundRows {
+               clientFlags |= clientFoundRows
+       }
+
+       // To enable TLS / SSL
+       if mc.cfg.tls != nil {
+               clientFlags |= clientSSL
+       }
+
+       if mc.cfg.MultiStatements {
+               clientFlags |= clientMultiStatements
+       }
+
+       // encode length of the auth plugin data
+       var authRespLEIBuf [9]byte
+       authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
+       if len(authRespLEI) > 1 {
+               // if the length can not be written in 1 byte, it must be written as a
+               // length encoded integer
+               clientFlags |= clientPluginAuthLenEncClientData
+       }
+
+       pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
+       if addNUL {
+               pktLen++
+       }
+
+       // To specify a db name
+       if n := len(mc.cfg.DBName); n > 0 {
+               clientFlags |= clientConnectWithDB
+               pktLen += n + 1
+       }
+
+       // Calculate packet length and get buffer with that size
+       data := mc.buf.takeSmallBuffer(pktLen + 4)
+       if data == nil {
+               // cannot take the buffer. Something must be wrong with the connection
+               errLog.Print(ErrBusyBuffer)
+               return errBadConnNoWrite
+       }
+
+       // ClientFlags [32 bit]
+       data[4] = byte(clientFlags)
+       data[5] = byte(clientFlags >> 8)
+       data[6] = byte(clientFlags >> 16)
+       data[7] = byte(clientFlags >> 24)
+
+       // MaxPacketSize [32 bit] (none)
+       data[8] = 0x00
+       data[9] = 0x00
+       data[10] = 0x00
+       data[11] = 0x00
+
+       // Charset [1 byte]
+       var found bool
+       data[12], found = collations[mc.cfg.Collation]
+       if !found {
+               // Note possibility for false negatives:
+               // could be triggered  although the collation is valid if the
+               // collations map does not contain entries the server supports.
+               return errors.New("unknown collation")
+       }
+
+       // SSL Connection Request Packet
+       // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
+       if mc.cfg.tls != nil {
+               // Send TLS / SSL request packet
+               if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
+                       return err
+               }
+
+               // Switch to TLS
+               tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
+               if err := tlsConn.Handshake(); err != nil {
+                       return err
+               }
+               mc.netConn = tlsConn
+               mc.buf.nc = tlsConn
+       }
+
+       // Filler [23 bytes] (all 0x00)
+       pos := 13
+       for ; pos < 13+23; pos++ {
+               data[pos] = 0
+       }
+
+       // User [null terminated string]
+       if len(mc.cfg.User) > 0 {
+               pos += copy(data[pos:], mc.cfg.User)
+       }
+       data[pos] = 0x00
+       pos++
+
+       // Auth Data [length encoded integer]
+       pos += copy(data[pos:], authRespLEI)
+       pos += copy(data[pos:], authResp)
+       if addNUL {
+               data[pos] = 0x00
+               pos++
+       }
+
+       // Databasename [null terminated string]
+       if len(mc.cfg.DBName) > 0 {
+               pos += copy(data[pos:], mc.cfg.DBName)
+               data[pos] = 0x00
+               pos++
+       }
+
+       pos += copy(data[pos:], plugin)
+       data[pos] = 0x00
+
+       // Send Auth packet
+       return mc.writePacket(data)
+}
+
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
+func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
+       pktLen := 4 + len(authData)
+       if addNUL {
+               pktLen++
+       }
+       data := mc.buf.takeSmallBuffer(pktLen)
+       if data == nil {
+               // cannot take the buffer. Something must be wrong with the connection
+               errLog.Print(ErrBusyBuffer)
+               return errBadConnNoWrite
+       }
+
+       // Add the auth data [EOF]
+       copy(data[4:], authData)
+       if addNUL {
+               data[pktLen-1] = 0x00
+       }
+
+       return mc.writePacket(data)
+}
+
+/******************************************************************************
+*                             Command Packets                                 *
+******************************************************************************/
+
+func (mc *mysqlConn) writeCommandPacket(command byte) error {
+       // Reset Packet Sequence
+       mc.sequence = 0
+
+       data := mc.buf.takeSmallBuffer(4 + 1)
+       if data == nil {
+               // cannot take the buffer. Something must be wrong with the connection
+               errLog.Print(ErrBusyBuffer)
+               return errBadConnNoWrite
+       }
+
+       // Add command byte
+       data[4] = command
+
+       // Send CMD packet
+       return mc.writePacket(data)
+}
+
+func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
+       // Reset Packet Sequence
+       mc.sequence = 0
+
+       pktLen := 1 + len(arg)
+       data := mc.buf.takeBuffer(pktLen + 4)
+       if data == nil {
+               // cannot take the buffer. Something must be wrong with the connection
+               errLog.Print(ErrBusyBuffer)
+               return errBadConnNoWrite
+       }
+
+       // Add command byte
+       data[4] = command
+
+       // Add arg
+       copy(data[5:], arg)
+
+       // Send CMD packet
+       return mc.writePacket(data)
+}
+
+func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
+       // Reset Packet Sequence
+       mc.sequence = 0
+
+       data := mc.buf.takeSmallBuffer(4 + 1 + 4)
+       if data == nil {
+               // cannot take the buffer. Something must be wrong with the connection
+               errLog.Print(ErrBusyBuffer)
+               return errBadConnNoWrite
+       }
+
+       // Add command byte
+       data[4] = command
+
+       // Add arg [32 bit]
+       data[5] = byte(arg)
+       data[6] = byte(arg >> 8)
+       data[7] = byte(arg >> 16)
+       data[8] = byte(arg >> 24)
+
+       // Send CMD packet
+       return mc.writePacket(data)
+}
+
+/******************************************************************************
+*                              Result Packets                                 *
+******************************************************************************/
+
+func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
+       data, err := mc.readPacket()
+       if err != nil {
+               return nil, "", err
+       }
+
+       // packet indicator
+       switch data[0] {
+
+       case iOK:
+               return nil, "", mc.handleOkPacket(data)
+
+       case iAuthMoreData:
+               return data[1:], "", err
+
+       case iEOF:
+               if len(data) < 1 {
+                       // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
+                       return nil, "mysql_old_password", nil
+               }
+               pluginEndIndex := bytes.IndexByte(data, 0x00)
+               if pluginEndIndex < 0 {
+                       return nil, "", ErrMalformPkt
+               }
+               plugin := string(data[1:pluginEndIndex])
+               authData := data[pluginEndIndex+1:]
+               return authData, plugin, nil
+
+       default: // Error otherwise
+               return nil, "", mc.handleErrorPacket(data)
+       }
+}
+
+// Returns error if Packet is not an 'Result OK'-Packet
+func (mc *mysqlConn) readResultOK() error {
+       data, err := mc.readPacket()
+       if err != nil {
+               return err
+       }
+
+       if data[0] == iOK {
+               return mc.handleOkPacket(data)
+       }
+       return mc.handleErrorPacket(data)
+}
+
+// Result Set Header Packet
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
+func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
+       data, err := mc.readPacket()
+       if err == nil {
+               switch data[0] {
+
+               case iOK:
+                       return 0, mc.handleOkPacket(data)
+
+               case iERR:
+                       return 0, mc.handleErrorPacket(data)
+
+               case iLocalInFile:
+                       return 0, mc.handleInFileRequest(string(data[1:]))
+               }
+
+               // column count
+               num, _, n := readLengthEncodedInteger(data)
+               if n-len(data) == 0 {
+                       return int(num), nil
+               }
+
+               return 0, ErrMalformPkt
+       }
+       return 0, err
+}
+
+// Error Packet
+// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
+func (mc *mysqlConn) handleErrorPacket(data []byte) error {
+       if data[0] != iERR {
+               return ErrMalformPkt
+       }
+
+       // 0xff [1 byte]
+
+       // Error Number [16 bit uint]
+       errno := binary.LittleEndian.Uint16(data[1:3])
+
+       // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
+       // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
+       if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly {
+               // Oops; we are connected to a read-only connection, and won't be able
+               // to issue any write statements. Since RejectReadOnly is configured,
+               // we throw away this connection hoping this one would have write
+               // permission. This is specifically for a possible race condition
+               // during failover (e.g. on AWS Aurora). See README.md for more.
+               //
+               // We explicitly close the connection before returning
+               // driver.ErrBadConn to ensure that `database/sql` purges this
+               // connection and initiates a new one for next statement next time.
+               mc.Close()
+               return driver.ErrBadConn
+       }
+
+       pos := 3
+
+       // SQL State [optional: # + 5bytes string]
+       if data[3] == 0x23 {
+               //sqlstate := string(data[4 : 4+5])
+               pos = 9
+       }
+
+       // Error Message [string]
+       return &MySQLError{
+               Number:  errno,
+               Message: string(data[pos:]),
+       }
+}
+
+func readStatus(b []byte) statusFlag {
+       return statusFlag(b[0]) | statusFlag(b[1])<<8
+}
+
+// Ok Packet
+// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
+func (mc *mysqlConn) handleOkPacket(data []byte) error {
+       var n, m int
+
+       // 0x00 [1 byte]
+
+       // Affected rows [Length Coded Binary]
+       mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
+
+       // Insert id [Length Coded Binary]
+       mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
+
+       // server_status [2 bytes]
+       mc.status = readStatus(data[1+n+m : 1+n+m+2])
+       if mc.status&statusMoreResultsExists != 0 {
+               return nil
+       }
+
+       // warning count [2 bytes]
+
+       return nil
+}
+
+// Read Packets as Field Packets until EOF-Packet or an Error appears
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
+func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
+       columns := make([]mysqlField, count)
+
+       for i := 0; ; i++ {
+               data, err := mc.readPacket()
+               if err != nil {
+                       return nil, err
+               }
+
+               // EOF Packet
+               if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
+                       if i == count {
+                               return columns, nil
+                       }
+                       return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
+               }
+
+               // Catalog
+               pos, err := skipLengthEncodedString(data)
+               if err != nil {
+                       return nil, err
+               }
+
+               // Database [len coded string]
+               n, err := skipLengthEncodedString(data[pos:])
+               if err != nil {
+                       return nil, err
+               }
+               pos += n
+
+               // Table [len coded string]
+               if mc.cfg.ColumnsWithAlias {
+                       tableName, _, n, err := readLengthEncodedString(data[pos:])
+                       if err != nil {
+                               return nil, err
+                       }
+                       pos += n
+                       columns[i].tableName = string(tableName)
+               } else {
+                       n, err = skipLengthEncodedString(data[pos:])
+                       if err != nil {
+                               return nil, err
+                       }
+                       pos += n
+               }
+
+               // Original table [len coded string]
+               n, err = skipLengthEncodedString(data[pos:])
+               if err != nil {
+                       return nil, err
+               }
+               pos += n
+
+               // Name [len coded string]
+               name, _, n, err := readLengthEncodedString(data[pos:])
+               if err != nil {
+                       return nil, err
+               }
+               columns[i].name = string(name)
+               pos += n
+
+               // Original name [len coded string]
+               n, err = skipLengthEncodedString(data[pos:])
+               if err != nil {
+                       return nil, err
+               }
+               pos += n
+
+               // Filler [uint8]
+               pos++
+
+               // Charset [charset, collation uint8]
+               columns[i].charSet = data[pos]
+               pos += 2
+
+               // Length [uint32]
+               columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
+               pos += 4
+
+               // Field type [uint8]
+               columns[i].fieldType = fieldType(data[pos])
+               pos++
+
+               // Flags [uint16]
+               columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
+               pos += 2
+
+               // Decimals [uint8]
+               columns[i].decimals = data[pos]
+               //pos++
+
+               // Default value [len coded binary]
+               //if pos < len(data) {
+               //      defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
+               //}
+       }
+}
+
+// Read Packets as Field Packets until EOF-Packet or an Error appears
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
+func (rows *textRows) readRow(dest []driver.Value) error {
+       mc := rows.mc
+
+       if rows.rs.done {
+               return io.EOF
+       }
+
+       data, err := mc.readPacket()
+       if err != nil {
+               return err
+       }
+
+       // EOF Packet
+       if data[0] == iEOF && len(data) == 5 {
+               // server_status [2 bytes]
+               rows.mc.status = readStatus(data[3:])
+               rows.rs.done = true
+               if !rows.HasNextResultSet() {
+                       rows.mc = nil
+               }
+               return io.EOF
+       }
+       if data[0] == iERR {
+               rows.mc = nil
+               return mc.handleErrorPacket(data)
+       }
+
+       // RowSet Packet
+       var n int
+       var isNull bool
+       pos := 0
+
+       for i := range dest {
+               // Read bytes and convert to string
+               dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
+               pos += n
+               if err == nil {
+                       if !isNull {
+                               if !mc.parseTime {
+                                       continue
+                               } else {
+                                       switch rows.rs.columns[i].fieldType {
+                                       case fieldTypeTimestamp, fieldTypeDateTime,
+                                               fieldTypeDate, fieldTypeNewDate:
+                                               dest[i], err = parseDateTime(
+                                                       string(dest[i].([]byte)),
+                                                       mc.cfg.Loc,
+                                               )
+                                               if err == nil {
+                                                       continue
+                                               }
+                                       default:
+                                               continue
+                                       }
+                               }
+
+                       } else {
+                               dest[i] = nil
+                               continue
+                       }
+               }
+               return err // err != nil
+       }
+
+       return nil
+}
+
+// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
+func (mc *mysqlConn) readUntilEOF() error {
+       for {
+               data, err := mc.readPacket()
+               if err != nil {
+                       return err
+               }
+
+               switch data[0] {
+               case iERR:
+                       return mc.handleErrorPacket(data)
+               case iEOF:
+                       if len(data) == 5 {
+                               mc.status = readStatus(data[3:])
+                       }
+                       return nil
+               }
+       }
+}
+
+/******************************************************************************
+*                           Prepared Statements                               *
+******************************************************************************/
+
+// Prepare Result Packets
+// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
+func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
+       data, err := stmt.mc.readPacket()
+       if err == nil {
+               // packet indicator [1 byte]
+               if data[0] != iOK {
+                       return 0, stmt.mc.handleErrorPacket(data)
+               }
+
+               // statement id [4 bytes]
+               stmt.id = binary.LittleEndian.Uint32(data[1:5])
+
+               // Column count [16 bit uint]
+               columnCount := binary.LittleEndian.Uint16(data[5:7])
+
+               // Param count [16 bit uint]
+               stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
+
+               // Reserved [8 bit]
+
+               // Warning count [16 bit uint]
+
+               return columnCount, nil
+       }
+       return 0, err
+}
+
+// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
+func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
+       maxLen := stmt.mc.maxAllowedPacket - 1
+       pktLen := maxLen
+
+       // After the header (bytes 0-3) follows before the data:
+       // 1 byte command
+       // 4 bytes stmtID
+       // 2 bytes paramID
+       const dataOffset = 1 + 4 + 2
+
+       // Cannot use the write buffer since
+       // a) the buffer is too small
+       // b) it is in use
+       data := make([]byte, 4+1+4+2+len(arg))
+
+       copy(data[4+dataOffset:], arg)
+
+       for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
+               if dataOffset+argLen < maxLen {
+                       pktLen = dataOffset + argLen
+               }
+
+               stmt.mc.sequence = 0
+               // Add command byte [1 byte]
+               data[4] = comStmtSendLongData
+
+               // Add stmtID [32 bit]
+               data[5] = byte(stmt.id)
+               data[6] = byte(stmt.id >> 8)
+               data[7] = byte(stmt.id >> 16)
+               data[8] = byte(stmt.id >> 24)
+
+               // Add paramID [16 bit]
+               data[9] = byte(paramID)
+               data[10] = byte(paramID >> 8)
+
+               // Send CMD packet
+               err := stmt.mc.writePacket(data[:4+pktLen])
+               if err == nil {
+                       data = data[pktLen-dataOffset:]
+                       continue
+               }
+               return err
+
+       }
+
+       // Reset Packet Sequence
+       stmt.mc.sequence = 0
+       return nil
+}
+
+// Execute Prepared Statement
+// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
+func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
+       if len(args) != stmt.paramCount {
+               return fmt.Errorf(
+                       "argument count mismatch (got: %d; has: %d)",
+                       len(args),
+                       stmt.paramCount,
+               )
+       }
+
+       const minPktLen = 4 + 1 + 4 + 1 + 4
+       mc := stmt.mc
+
+       // Determine threshould dynamically to avoid packet size shortage.
+       longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
+       if longDataSize < 64 {
+               longDataSize = 64
+       }
+
+       // Reset packet-sequence
+       mc.sequence = 0
+
+       var data []byte
+
+       if len(args) == 0 {
+               data = mc.buf.takeBuffer(minPktLen)
+       } else {
+               data = mc.buf.takeCompleteBuffer()
+       }
+       if data == nil {
+               // cannot take the buffer. Something must be wrong with the connection
+               errLog.Print(ErrBusyBuffer)
+               return errBadConnNoWrite
+       }
+
+       // command [1 byte]
+       data[4] = comStmtExecute
+
+       // statement_id [4 bytes]
+       data[5] = byte(stmt.id)
+       data[6] = byte(stmt.id >> 8)
+       data[7] = byte(stmt.id >> 16)
+       data[8] = byte(stmt.id >> 24)
+
+       // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
+       data[9] = 0x00
+
+       // iteration_count (uint32(1)) [4 bytes]
+       data[10] = 0x01
+       data[11] = 0x00
+       data[12] = 0x00
+       data[13] = 0x00
+
+       if len(args) > 0 {
+               pos := minPktLen
+
+               var nullMask []byte
+               if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
+                       // buffer has to be extended but we don't know by how much so
+                       // we depend on append after all data with known sizes fit.
+                       // We stop at that because we deal with a lot of columns here
+                       // which makes the required allocation size hard to guess.
+                       tmp := make([]byte, pos+maskLen+typesLen)
+                       copy(tmp[:pos], data[:pos])
+                       data = tmp
+                       nullMask = data[pos : pos+maskLen]
+                       pos += maskLen
+               } else {
+                       nullMask = data[pos : pos+maskLen]
+                       for i := 0; i < maskLen; i++ {
+                               nullMask[i] = 0
+                       }
+                       pos += maskLen
+               }
+
+               // newParameterBoundFlag 1 [1 byte]
+               data[pos] = 0x01
+               pos++
+
+               // type of each parameter [len(args)*2 bytes]
+               paramTypes := data[pos:]
+               pos += len(args) * 2
+
+               // value of each parameter [n bytes]
+               paramValues := data[pos:pos]
+               valuesCap := cap(paramValues)
+
+               for i, arg := range args {
+                       // build NULL-bitmap
+                       if arg == nil {
+                               nullMask[i/8] |= 1 << (uint(i) & 7)
+                               paramTypes[i+i] = byte(fieldTypeNULL)
+                               paramTypes[i+i+1] = 0x00
+                               continue
+                       }
+
+                       // cache types and values
+                       switch v := arg.(type) {
+                       case int64:
+                               paramTypes[i+i] = byte(fieldTypeLongLong)
+                               paramTypes[i+i+1] = 0x00
+
+                               if cap(paramValues)-len(paramValues)-8 >= 0 {
+                                       paramValues = paramValues[:len(paramValues)+8]
+                                       binary.LittleEndian.PutUint64(
+                                               paramValues[len(paramValues)-8:],
+                                               uint64(v),
+                                       )
+                               } else {
+                                       paramValues = append(paramValues,
+                                               uint64ToBytes(uint64(v))...,
+                                       )
+                               }
+
+                       case float64:
+                               paramTypes[i+i] = byte(fieldTypeDouble)
+                               paramTypes[i+i+1] = 0x00
+
+                               if cap(paramValues)-len(paramValues)-8 >= 0 {
+                                       paramValues = paramValues[:len(paramValues)+8]
+                                       binary.LittleEndian.PutUint64(
+                                               paramValues[len(paramValues)-8:],
+                                               math.Float64bits(v),
+                                       )
+                               } else {
+                                       paramValues = append(paramValues,
+                                               uint64ToBytes(math.Float64bits(v))...,
+                                       )
+                               }
+
+                       case bool:
+                               paramTypes[i+i] = byte(fieldTypeTiny)
+                               paramTypes[i+i+1] = 0x00
+
+                               if v {
+                                       paramValues = append(paramValues, 0x01)
+                               } else {
+                                       paramValues = append(paramValues, 0x00)
+                               }
+
+                       case []byte:
+                               // Common case (non-nil value) first
+                               if v != nil {
+                                       paramTypes[i+i] = byte(fieldTypeString)
+                                       paramTypes[i+i+1] = 0x00
+
+                                       if len(v) < longDataSize {
+                                               paramValues = appendLengthEncodedInteger(paramValues,
+                                                       uint64(len(v)),
+                                               )
+                                               paramValues = append(paramValues, v...)
+                                       } else {
+                                               if err := stmt.writeCommandLongData(i, v); err != nil {
+                                                       return err
+                                               }
+                                       }
+                                       continue
+                               }
+
+                               // Handle []byte(nil) as a NULL value
+                               nullMask[i/8] |= 1 << (uint(i) & 7)
+                               paramTypes[i+i] = byte(fieldTypeNULL)
+                               paramTypes[i+i+1] = 0x00
+
+                       case string:
+                               paramTypes[i+i] = byte(fieldTypeString)
+                               paramTypes[i+i+1] = 0x00
+
+                               if len(v) < longDataSize {
+                                       paramValues = appendLengthEncodedInteger(paramValues,
+                                               uint64(len(v)),
+                                       )
+                                       paramValues = append(paramValues, v...)
+                               } else {
+                                       if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
+                                               return err
+                                       }
+                               }
+
+                       case time.Time:
+                               paramTypes[i+i] = byte(fieldTypeString)
+                               paramTypes[i+i+1] = 0x00
+
+                               var a [64]byte
+                               var b = a[:0]
+
+                               if v.IsZero() {
+                                       b = append(b, "0000-00-00"...)
+                               } else {
+                                       b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
+                               }
+
+                               paramValues = appendLengthEncodedInteger(paramValues,
+                                       uint64(len(b)),
+                               )
+                               paramValues = append(paramValues, b...)
+
+                       default:
+                               return fmt.Errorf("cannot convert type: %T", arg)
+                       }
+               }
+
+               // Check if param values exceeded the available buffer
+               // In that case we must build the data packet with the new values buffer
+               if valuesCap != cap(paramValues) {
+                       data = append(data[:pos], paramValues...)
+                       mc.buf.buf = data
+               }
+
+               pos += len(paramValues)
+               data = data[:pos]
+       }
+
+       return mc.writePacket(data)
+}
+
+func (mc *mysqlConn) discardResults() error {
+       for mc.status&statusMoreResultsExists != 0 {
+               resLen, err := mc.readResultSetHeaderPacket()
+               if err != nil {
+                       return err
+               }
+               if resLen > 0 {
+                       // columns
+                       if err := mc.readUntilEOF(); err != nil {
+                               return err
+                       }
+                       // rows
+                       if err := mc.readUntilEOF(); err != nil {
+                               return err
+                       }
+               }
+       }
+       return nil
+}
+
+// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
+func (rows *binaryRows) readRow(dest []driver.Value) error {
+       data, err := rows.mc.readPacket()
+       if err != nil {
+               return err
+       }
+
+       // packet indicator [1 byte]
+       if data[0] != iOK {
+               // EOF Packet
+               if data[0] == iEOF && len(data) == 5 {
+                       rows.mc.status = readStatus(data[3:])
+                       rows.rs.done = true
+                       if !rows.HasNextResultSet() {
+                               rows.mc = nil
+                       }
+                       return io.EOF
+               }
+               mc := rows.mc
+               rows.mc = nil
+
+               // Error otherwise
+               return mc.handleErrorPacket(data)
+       }
+
+       // NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
+       pos := 1 + (len(dest)+7+2)>>3
+       nullMask := data[1:pos]
+
+       for i := range dest {
+               // Field is NULL
+               // (byte >> bit-pos) % 2 == 1
+               if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
+                       dest[i] = nil
+                       continue
+               }
+
+               // Convert to byte-coded string
+               switch rows.rs.columns[i].fieldType {
+               case fieldTypeNULL:
+                       dest[i] = nil
+                       continue
+
+               // Numeric Types
+               case fieldTypeTiny:
+                       if rows.rs.columns[i].flags&flagUnsigned != 0 {
+                               dest[i] = int64(data[pos])
+                       } else {
+                               dest[i] = int64(int8(data[pos]))
+                       }
+                       pos++
+                       continue
+
+               case fieldTypeShort, fieldTypeYear:
+                       if rows.rs.columns[i].flags&flagUnsigned != 0 {
+                               dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
+                       } else {
+                               dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
+                       }
+                       pos += 2
+                       continue
+
+               case fieldTypeInt24, fieldTypeLong:
+                       if rows.rs.columns[i].flags&flagUnsigned != 0 {
+                               dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
+                       } else {
+                               dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
+                       }
+                       pos += 4
+                       continue
+
+               case fieldTypeLongLong:
+                       if rows.rs.columns[i].flags&flagUnsigned != 0 {
+                               val := binary.LittleEndian.Uint64(data[pos : pos+8])
+                               if val > math.MaxInt64 {
+                                       dest[i] = uint64ToString(val)
+                               } else {
+                                       dest[i] = int64(val)
+                               }
+                       } else {
+                               dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
+                       }
+                       pos += 8
+                       continue
+
+               case fieldTypeFloat:
+                       dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
+                       pos += 4
+                       continue
+
+               case fieldTypeDouble:
+                       dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
+                       pos += 8
+                       continue
+
+               // Length coded Binary Strings
+               case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
+                       fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
+                       fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
+                       fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
+                       var isNull bool
+                       var n int
+                       dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
+                       pos += n
+                       if err == nil {
+                               if !isNull {
+                                       continue
+                               } else {
+                                       dest[i] = nil
+                                       continue
+                               }
+                       }
+                       return err
+
+               case
+                       fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD
+                       fieldTypeTime,                         // Time [-][H]HH:MM:SS[.fractal]
+                       fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
+
+                       num, isNull, n := readLengthEncodedInteger(data[pos:])
+                       pos += n
+
+                       switch {
+                       case isNull:
+                               dest[i] = nil
+                               continue
+                       case rows.rs.columns[i].fieldType == fieldTypeTime:
+                               // database/sql does not support an equivalent to TIME, return a string
+                               var dstlen uint8
+                               switch decimals := rows.rs.columns[i].decimals; decimals {
+                               case 0x00, 0x1f:
+                                       dstlen = 8
+                               case 1, 2, 3, 4, 5, 6:
+                                       dstlen = 8 + 1 + decimals
+                               default:
+                                       return fmt.Errorf(
+                                               "protocol error, illegal decimals value %d",
+                                               rows.rs.columns[i].decimals,
+                                       )
+                               }
+                               dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
+                       case rows.mc.parseTime:
+                               dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
+                       default:
+                               var dstlen uint8
+                               if rows.rs.columns[i].fieldType == fieldTypeDate {
+                                       dstlen = 10
+                               } else {
+                                       switch decimals := rows.rs.columns[i].decimals; decimals {
+                                       case 0x00, 0x1f:
+                                               dstlen = 19
+                                       case 1, 2, 3, 4, 5, 6:
+                                               dstlen = 19 + 1 + decimals
+                                       default:
+                                               return fmt.Errorf(
+                                                       "protocol error, illegal decimals value %d",
+                                                       rows.rs.columns[i].decimals,
+                                               )
+                                       }
+                               }
+                               dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
+                       }
+
+                       if err == nil {
+                               pos += int(num)
+                               continue
+                       } else {
+                               return err
+                       }
+
+               // Please report if this happens!
+               default:
+                       return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
+               }
+       }
+
+       return nil
+}