OSDN Git Service

Hulk did something
[bytom/vapor.git] / p2p / connection / connection.go
diff --git a/p2p/connection/connection.go b/p2p/connection/connection.go
new file mode 100644 (file)
index 0000000..30c900f
--- /dev/null
@@ -0,0 +1,418 @@
+package connection
+
+import (
+       "bufio"
+       "fmt"
+       "math"
+       "net"
+       "runtime/debug"
+       "sync/atomic"
+       "time"
+
+       log "github.com/sirupsen/logrus"
+       wire "github.com/tendermint/go-wire"
+       cmn "github.com/tendermint/tmlibs/common"
+       "github.com/tendermint/tmlibs/flowrate"
+)
+
+const (
+       packetTypePing           = byte(0x01)
+       packetTypePong           = byte(0x02)
+       packetTypeMsg            = byte(0x03)
+       maxMsgPacketPayloadSize  = 1024
+       maxMsgPacketOverheadSize = 10 // It's actually lower but good enough
+       maxMsgPacketTotalSize    = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize
+
+       numBatchMsgPackets = 10
+       minReadBufferSize  = 1024
+       minWriteBufferSize = 65536
+       updateState        = 2 * time.Second
+       pingTimeout        = 40 * time.Second
+       flushThrottle      = 100 * time.Millisecond
+
+       defaultSendQueueCapacity   = 1
+       defaultSendRate            = int64(512000) // 500KB/s
+       defaultRecvBufferCapacity  = 4096
+       defaultRecvMessageCapacity = 22020096      // 21MB
+       defaultRecvRate            = int64(512000) // 500KB/s
+       defaultSendTimeout         = 10 * time.Second
+       logModule                  = "p2p/conn"
+)
+
+type receiveCbFunc func(chID byte, msgBytes []byte)
+type errorCbFunc func(interface{})
+
+// Messages in channels are chopped into smaller msgPackets for multiplexing.
+type msgPacket struct {
+       ChannelID byte
+       EOF       byte // 1 means message ends here.
+       Bytes     []byte
+}
+
+func (p msgPacket) String() string {
+       return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF)
+}
+
+/*
+MConnection handles message transmission on multiple abstract communication
+`Channel`s.  Each channel has a globally unique byte id.
+The byte id and the relative priorities of each `Channel` are configured upon
+initialization of the connection.
+
+There are two methods for sending messages:
+       func (m MConnection) Send(chID byte, msg interface{}) bool {}
+       func (m MConnection) TrySend(chID byte, msg interface{}) bool {}
+
+`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued
+for the channel with the given id byte `chID`, or until the request times out.
+The message `msg` is serialized using the `tendermint/wire` submodule's
+`WriteBinary()` reflection routine.
+
+`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's
+queue is full.
+
+Inbound message bytes are handled with an onReceive callback function.
+*/
+type MConnection struct {
+       cmn.BaseService
+
+       conn        net.Conn
+       bufReader   *bufio.Reader
+       bufWriter   *bufio.Writer
+       sendMonitor *flowrate.Monitor
+       recvMonitor *flowrate.Monitor
+       send        chan struct{}
+       pong        chan struct{}
+       channels    []*channel
+       channelsIdx map[byte]*channel
+       onReceive   receiveCbFunc
+       onError     errorCbFunc
+       errored     uint32
+       config      *MConnConfig
+
+       quit         chan struct{}
+       flushTimer   *cmn.ThrottleTimer // flush writes as necessary but throttled.
+       pingTimer    *time.Ticker       // send pings periodically
+       chStatsTimer *time.Ticker       // update channel stats periodically
+}
+
+// MConnConfig is a MConnection configuration.
+type MConnConfig struct {
+       SendRate int64 `mapstructure:"send_rate"`
+       RecvRate int64 `mapstructure:"recv_rate"`
+}
+
+// DefaultMConnConfig returns the default config.
+func DefaultMConnConfig() *MConnConfig {
+       return &MConnConfig{
+               SendRate: defaultSendRate,
+               RecvRate: defaultRecvRate,
+       }
+}
+
+// NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config
+func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection {
+       mconn := &MConnection{
+               conn:        conn,
+               bufReader:   bufio.NewReaderSize(conn, minReadBufferSize),
+               bufWriter:   bufio.NewWriterSize(conn, minWriteBufferSize),
+               sendMonitor: flowrate.New(0, 0),
+               recvMonitor: flowrate.New(0, 0),
+               send:        make(chan struct{}, 1),
+               pong:        make(chan struct{}, 1),
+               channelsIdx: map[byte]*channel{},
+               channels:    []*channel{},
+               onReceive:   onReceive,
+               onError:     onError,
+               config:      config,
+
+               pingTimer:    time.NewTicker(pingTimeout),
+               chStatsTimer: time.NewTicker(updateState),
+       }
+
+       for _, desc := range chDescs {
+               descCopy := *desc // copy the desc else unsafe access across connections
+               channel := newChannel(mconn, &descCopy)
+               mconn.channelsIdx[channel.id] = channel
+               mconn.channels = append(mconn.channels, channel)
+       }
+       mconn.BaseService = *cmn.NewBaseService(nil, "MConnection", mconn)
+       return mconn
+}
+
+// OnStart implements BaseService
+func (c *MConnection) OnStart() error {
+       c.BaseService.OnStart()
+       c.quit = make(chan struct{})
+       c.flushTimer = cmn.NewThrottleTimer("flush", flushThrottle)
+       go c.sendRoutine()
+       go c.recvRoutine()
+       return nil
+}
+
+// OnStop implements BaseService
+func (c *MConnection) OnStop() {
+       c.BaseService.OnStop()
+       c.flushTimer.Stop()
+       c.pingTimer.Stop()
+       c.chStatsTimer.Stop()
+       if c.quit != nil {
+               close(c.quit)
+       }
+       c.conn.Close()
+       // We can't close pong safely here because recvRoutine may write to it after we've
+       // stopped. Though it doesn't need to get closed at all, we close it @ recvRoutine.
+}
+
+// CanSend returns true if you can send more data onto the chID, false otherwise
+func (c *MConnection) CanSend(chID byte) bool {
+       if !c.IsRunning() {
+               return false
+       }
+
+       channel, ok := c.channelsIdx[chID]
+       if !ok {
+               return false
+       }
+       return channel.canSend()
+}
+
+// Send will queues a message to be sent to channel(blocking).
+func (c *MConnection) Send(chID byte, msg interface{}) bool {
+       if !c.IsRunning() {
+               return false
+       }
+
+       channel, ok := c.channelsIdx[chID]
+       if !ok {
+               log.WithFields(log.Fields{"module": logModule, "chID": chID}).Error("cannot send bytes due to unknown channel")
+               return false
+       }
+
+       if !channel.sendBytes(wire.BinaryBytes(msg)) {
+               log.WithFields(log.Fields{"module": logModule, "chID": chID, "conn": c, "msg": msg}).Error("MConnection send failed")
+               return false
+       }
+
+       select {
+       case c.send <- struct{}{}:
+       default:
+       }
+       return true
+}
+
+// TrafficStatus return the in and out traffic status
+func (c *MConnection) TrafficStatus() (*flowrate.Status, *flowrate.Status) {
+       sentStatus := c.sendMonitor.Status()
+       receivedStatus := c.recvMonitor.Status()
+       return &sentStatus, &receivedStatus
+}
+
+// TrySend queues a message to be sent to channel(Nonblocking).
+func (c *MConnection) TrySend(chID byte, msg interface{}) bool {
+       if !c.IsRunning() {
+               return false
+       }
+
+       channel, ok := c.channelsIdx[chID]
+       if !ok {
+               log.WithFields(log.Fields{"module": logModule, "chID": chID}).Error("cannot send bytes due to unknown channel")
+               return false
+       }
+
+       ok = channel.trySendBytes(wire.BinaryBytes(msg))
+       if ok {
+               select {
+               case c.send <- struct{}{}:
+               default:
+               }
+       }
+       return ok
+}
+
+func (c *MConnection) String() string {
+       return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr())
+}
+
+func (c *MConnection) flush() {
+       if err := c.bufWriter.Flush(); err != nil {
+               log.WithFields(log.Fields{"module": logModule, "error": err}).Error("MConnection flush failed")
+       }
+}
+
+// Catch panics, usually caused by remote disconnects.
+func (c *MConnection) _recover() {
+       if r := recover(); r != nil {
+               stack := debug.Stack()
+               err := cmn.StackError{r, stack}
+               c.stopForError(err)
+       }
+}
+
+// recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer.
+// After a whole message has been assembled, it's pushed to onReceive().
+// Blocks depending on how the connection is throttled.
+func (c *MConnection) recvRoutine() {
+       defer c._recover()
+       defer close(c.pong)
+
+       for {
+               // Block until .recvMonitor says we can read.
+               c.recvMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.RecvRate), true)
+
+               // Read packet type
+               var n int
+               var err error
+               pktType := wire.ReadByte(c.bufReader, &n, &err)
+               c.recvMonitor.Update(int(n))
+               if err != nil {
+                       if c.IsRunning() {
+                               log.WithFields(log.Fields{"module": logModule, "conn": c, "error": err}).Error("Connection failed @ recvRoutine (reading byte)")
+                               c.conn.Close()
+                               c.stopForError(err)
+                       }
+                       return
+               }
+
+               // Read more depending on packet type.
+               switch pktType {
+               case packetTypePing:
+                       log.WithFields(log.Fields{"module": logModule, "conn": c}).Debug("receive Ping")
+                       select {
+                       case c.pong <- struct{}{}:
+                       default:
+                       }
+
+               case packetTypePong:
+                       log.WithFields(log.Fields{"module": logModule, "conn": c}).Debug("receive Pong")
+
+               case packetTypeMsg:
+                       pkt, n, err := msgPacket{}, int(0), error(nil)
+                       wire.ReadBinaryPtr(&pkt, c.bufReader, maxMsgPacketTotalSize, &n, &err)
+                       c.recvMonitor.Update(int(n))
+                       if err != nil {
+                               if c.IsRunning() {
+                                       log.WithFields(log.Fields{"module": logModule, "conn": c, "error": err}).Error("failed on recvRoutine")
+                                       c.stopForError(err)
+                               }
+                               return
+                       }
+
+                       channel, ok := c.channelsIdx[pkt.ChannelID]
+                       if !ok || channel == nil {
+                               cmn.PanicQ(cmn.Fmt("Unknown channel %X", pkt.ChannelID))
+                       }
+
+                       msgBytes, err := channel.recvMsgPacket(pkt)
+                       if err != nil {
+                               if c.IsRunning() {
+                                       log.WithFields(log.Fields{"module": logModule, "conn": c, "error": err}).Error("failed on recvRoutine")
+                                       c.stopForError(err)
+                               }
+                               return
+                       }
+
+                       if msgBytes != nil {
+                               c.onReceive(pkt.ChannelID, msgBytes)
+                       }
+
+               default:
+                       cmn.PanicSanity(cmn.Fmt("Unknown message type %X", pktType))
+               }
+       }
+}
+
+// Returns true if messages from channels were exhausted.
+func (c *MConnection) sendMsgPacket() bool {
+       var leastRatio float32 = math.MaxFloat32
+       var leastChannel *channel
+       for _, channel := range c.channels {
+               if !channel.isSendPending() {
+                       continue
+               }
+               if ratio := float32(channel.recentlySent) / float32(channel.priority); ratio < leastRatio {
+                       leastRatio = ratio
+                       leastChannel = channel
+               }
+       }
+       if leastChannel == nil {
+               return true
+       }
+
+       n, err := leastChannel.writeMsgPacketTo(c.bufWriter)
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "error": err}).Error("failed to write msgPacket")
+               c.stopForError(err)
+               return true
+       }
+       c.sendMonitor.Update(int(n))
+       c.flushTimer.Set()
+       return false
+}
+
+// sendRoutine polls for packets to send from channels.
+func (c *MConnection) sendRoutine() {
+       defer c._recover()
+
+       for {
+               var n int
+               var err error
+               select {
+               case <-c.flushTimer.Ch:
+                       c.flush()
+               case <-c.chStatsTimer.C:
+                       for _, channel := range c.channels {
+                               channel.updateStats()
+                       }
+               case <-c.pingTimer.C:
+                       log.WithFields(log.Fields{"module": logModule, "conn": c}).Debug("send Ping")
+                       wire.WriteByte(packetTypePing, c.bufWriter, &n, &err)
+                       c.sendMonitor.Update(int(n))
+                       c.flush()
+               case <-c.pong:
+                       log.WithFields(log.Fields{"module": logModule, "conn": c}).Debug("send Pong")
+                       wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
+                       c.sendMonitor.Update(int(n))
+                       c.flush()
+               case <-c.quit:
+                       return
+               case <-c.send:
+                       if eof := c.sendSomeMsgPackets(); !eof {
+                               select {
+                               case c.send <- struct{}{}:
+                               default:
+                               }
+                       }
+               }
+
+               if !c.IsRunning() {
+                       return
+               }
+               if err != nil {
+                       log.WithFields(log.Fields{"module": logModule, "conn": c, "error": err}).Error("Connection failed @ sendRoutine")
+                       c.stopForError(err)
+                       return
+               }
+       }
+}
+
+// Returns true if messages from channels were exhausted.
+func (c *MConnection) sendSomeMsgPackets() bool {
+       // Block until .sendMonitor says we can write.
+       // Once we're ready we send more than we asked for,
+       // but amortized it should even out.
+       c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true)
+       for i := 0; i < numBatchMsgPackets; i++ {
+               if c.sendMsgPacket() {
+                       return true
+               }
+       }
+       return false
+}
+
+func (c *MConnection) stopForError(r interface{}) {
+       c.Stop()
+       if atomic.CompareAndSwapUint32(&c.errored, 0, 1) && c.onError != nil {
+               c.onError(r)
+       }
+}