OSDN Git Service

Merge branch 'dev' into bvm
[bytom/bytom.git] / p2p / connection.go
1 package p2p
2
3 import (
4         "bufio"
5         "fmt"
6         "io"
7         "math"
8         "net"
9         "runtime/debug"
10         "sync/atomic"
11         "time"
12
13         log "github.com/sirupsen/logrus"
14         wire "github.com/tendermint/go-wire"
15         cmn "github.com/tendermint/tmlibs/common"
16         flow "github.com/tendermint/tmlibs/flowrate"
17 )
18
const (
	// Send batching and bufio buffer sizes.
	numBatchMsgPackets = 10
	minReadBufferSize  = 1024
	minWriteBufferSize = 64 * 1024

	// Periods of the connection's timers (chStats, ping, flush throttle).
	updateState   = 2 * time.Second
	pingTimeout   = 40 * time.Second
	flushThrottle = 100 * time.Millisecond

	// Defaults applied to channel descriptors and rate limits.
	defaultSendQueueCapacity   = 1
	defaultSendRate            = int64(500 * 1024) // 500KB/s
	defaultRecvBufferCapacity  = 4 * 1024
	defaultRecvMessageCapacity = 21 * 1024 * 1024  // 21MB
	defaultRecvRate            = int64(500 * 1024) // 500KB/s
	defaultSendTimeout         = 10 * time.Second
)
34
type (
	// receiveCbFunc is called with the channel ID and the bytes of every
	// fully reassembled inbound message.
	receiveCbFunc func(chID byte, msgBytes []byte)

	// errorCbFunc is called with the cause when the connection stops because
	// of an error.
	errorCbFunc func(interface{})
)
37
/*
MConnection is a multiplex connection. Each peer has one `MConnection` instance.

__multiplex__ *noun* a system or signal involving simultaneous transmission of
several messages along a single channel of communication.

Each `MConnection` handles message transmission on multiple abstract communication
`Channel`s. Each channel has a byte id that is unique within the connection.
The byte id and the relative priorities of each `Channel` are configured upon
initialization of the connection.

There are two methods for sending messages:
	func (c *MConnection) Send(chID byte, msg interface{}) bool {}
	func (c *MConnection) TrySend(chID byte, msg interface{}) bool {}

`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued
for the channel with the given id byte `chID`, or until the request times out.
The message `msg` is serialized with the `tendermint/go-wire` package's
`BinaryBytes()` routine.

`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's
queue is full.

Inbound message bytes are handled with an onReceive callback function.
*/
type MConnection struct {
	cmn.BaseService

	conn        net.Conn
	bufReader   *bufio.Reader
	bufWriter   *bufio.Writer
	sendMonitor *flow.Monitor // measures/limits outgoing bytes against config.SendRate
	recvMonitor *flow.Monitor // measures/limits incoming bytes against config.RecvRate
	send        chan struct{} // wakeup signal for sendRoutine after a message is queued
	pong        chan struct{} // recvRoutine asks sendRoutine to reply to a ping
	channels    []*Channel
	channelsIdx map[byte]*Channel
	onReceive   receiveCbFunc
	onError     errorCbFunc
	errored     uint32 // CAS'd 0->1 by stopForError so onError runs at most once
	config      *MConnConfig

	quit         chan struct{}      // closed by OnStop to make sendRoutine exit
	flushTimer   *cmn.ThrottleTimer // flush writes as necessary but throttled.
	pingTimer    *cmn.RepeatTimer   // send pings periodically
	chStatsTimer *cmn.RepeatTimer   // update channel stats periodically

	LocalAddress  *NetAddress
	RemoteAddress *NetAddress
}
88
// MConnConfig is a MConnection configuration.
//
// Both rates are read with atomic.LoadInt64 on every throttling check by the
// send/receive routines (presumably so they can be changed while the
// connection runs — confirm before relying on it). The defaults are annotated
// as 500KB/s, i.e. the values are bytes per second.
type MConnConfig struct {
	SendRate int64 `mapstructure:"send_rate"` // outgoing byte rate limit
	RecvRate int64 `mapstructure:"recv_rate"` // incoming byte rate limit
}
94
95 // DefaultMConnConfig returns the default config.
96 func DefaultMConnConfig() *MConnConfig {
97         return &MConnConfig{
98                 SendRate: defaultSendRate,
99                 RecvRate: defaultRecvRate,
100         }
101 }
102
103 // NewMConnection wraps net.Conn and creates multiplex connection
104 func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc) *MConnection {
105         return NewMConnectionWithConfig(
106                 conn,
107                 chDescs,
108                 onReceive,
109                 onError,
110                 DefaultMConnConfig())
111 }
112
113 // NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config
114 func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection {
115         mconn := &MConnection{
116                 conn:        conn,
117                 bufReader:   bufio.NewReaderSize(conn, minReadBufferSize),
118                 bufWriter:   bufio.NewWriterSize(conn, minWriteBufferSize),
119                 sendMonitor: flow.New(0, 0),
120                 recvMonitor: flow.New(0, 0),
121                 send:        make(chan struct{}, 1),
122                 pong:        make(chan struct{}),
123                 onReceive:   onReceive,
124                 onError:     onError,
125                 config:      config,
126
127                 LocalAddress:  NewNetAddress(conn.LocalAddr()),
128                 RemoteAddress: NewNetAddress(conn.RemoteAddr()),
129         }
130
131         // Create channels
132         var channelsIdx = map[byte]*Channel{}
133         var channels = []*Channel{}
134
135         for _, desc := range chDescs {
136                 descCopy := *desc // copy the desc else unsafe access across connections
137                 channel := newChannel(mconn, &descCopy)
138                 channelsIdx[channel.id] = channel
139                 channels = append(channels, channel)
140         }
141         mconn.channels = channels
142         mconn.channelsIdx = channelsIdx
143
144         mconn.BaseService = *cmn.NewBaseService(nil, "MConnection", mconn)
145
146         return mconn
147 }
148
149 func (c *MConnection) OnStart() error {
150         c.BaseService.OnStart()
151         c.quit = make(chan struct{})
152         c.flushTimer = cmn.NewThrottleTimer("flush", flushThrottle)
153         c.pingTimer = cmn.NewRepeatTimer("ping", pingTimeout)
154         c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateState)
155         go c.sendRoutine()
156         go c.recvRoutine()
157         return nil
158 }
159
160 func (c *MConnection) OnStop() {
161         c.BaseService.OnStop()
162         c.flushTimer.Stop()
163         c.pingTimer.Stop()
164         c.chStatsTimer.Stop()
165         if c.quit != nil {
166                 close(c.quit)
167         }
168         c.conn.Close()
169         // We can't close pong safely here because
170         // recvRoutine may write to it after we've stopped.
171         // Though it doesn't need to get closed at all,
172         // we close it @ recvRoutine.
173         // close(c.pong)
174 }
175
176 func (c *MConnection) String() string {
177         return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr())
178 }
179
180 func (c *MConnection) flush() {
181         log.WithField("conn", c).Debug("Flush")
182         err := c.bufWriter.Flush()
183         if err != nil {
184                 log.WithField("error", err).Error("MConnection flush failed")
185         }
186 }
187
188 // Catch panics, usually caused by remote disconnects.
189 func (c *MConnection) _recover() {
190         if r := recover(); r != nil {
191                 stack := debug.Stack()
192                 err := cmn.StackError{r, stack}
193                 c.stopForError(err)
194         }
195 }
196
197 func (c *MConnection) stopForError(r interface{}) {
198         c.Stop()
199         if atomic.CompareAndSwapUint32(&c.errored, 0, 1) {
200                 if c.onError != nil {
201                         c.onError(r)
202                 }
203         }
204 }
205
206 // Queues a message to be sent to channel.
207 func (c *MConnection) Send(chID byte, msg interface{}) bool {
208         if !c.IsRunning() {
209                 return false
210         }
211
212         log.WithFields(log.Fields{
213                 "chID": chID,
214                 "conn": c,
215                 "msg":  msg,
216         }).Debug("Send")
217
218         // Send message to channel.
219         channel, ok := c.channelsIdx[chID]
220         if !ok {
221                 log.WithField("chID", chID).Error(cmn.Fmt("Cannot send bytes, unknown channel"))
222                 return false
223         }
224
225         success := channel.sendBytes(wire.BinaryBytes(msg))
226         if success {
227                 // Wake up sendRoutine if necessary
228                 select {
229                 case c.send <- struct{}{}:
230                 default:
231                 }
232         } else {
233                 log.WithFields(log.Fields{
234                         "chID": chID,
235                         "conn": c,
236                         "msg":  msg,
237                 }).Error("Send failed")
238         }
239         return success
240 }
241
242 // Queues a message to be sent to channel.
243 // Nonblocking, returns true if successful.
244 func (c *MConnection) TrySend(chID byte, msg interface{}) bool {
245         if !c.IsRunning() {
246                 return false
247         }
248
249         log.WithFields(log.Fields{
250                 "chID": chID,
251                 "conn": c,
252                 "msg":  msg,
253         }).Debug("TrySend")
254
255         // Send message to channel.
256         channel, ok := c.channelsIdx[chID]
257         if !ok {
258                 log.WithField("chID", chID).Error(cmn.Fmt("cannot send bytes, unknown channel"))
259                 return false
260         }
261
262         ok = channel.trySendBytes(wire.BinaryBytes(msg))
263         if ok {
264                 // Wake up sendRoutine if necessary
265                 select {
266                 case c.send <- struct{}{}:
267                 default:
268                 }
269         }
270
271         return ok
272 }
273
274 // CanSend returns true if you can send more data onto the chID, false
275 // otherwise. Use only as a heuristic.
276 func (c *MConnection) CanSend(chID byte) bool {
277         if !c.IsRunning() {
278                 return false
279         }
280
281         channel, ok := c.channelsIdx[chID]
282         if !ok {
283                 log.WithField("chID", chID).Error(cmn.Fmt("Unknown channel"))
284                 return false
285         }
286         return channel.canSend()
287 }
288
289 // sendRoutine polls for packets to send from channels.
290 func (c *MConnection) sendRoutine() {
291         defer c._recover()
292
293 FOR_LOOP:
294         for {
295                 var n int
296                 var err error
297                 select {
298                 case <-c.flushTimer.Ch:
299                         // NOTE: flushTimer.Set() must be called every time
300                         // something is written to .bufWriter.
301                         c.flush()
302                 case <-c.chStatsTimer.Ch:
303                         for _, channel := range c.channels {
304                                 channel.updateStats()
305                         }
306                 case <-c.pingTimer.Ch:
307                         log.Debug("Send Ping")
308                         wire.WriteByte(packetTypePing, c.bufWriter, &n, &err)
309                         c.sendMonitor.Update(int(n))
310                         c.flush()
311                 case <-c.pong:
312                         log.Debug("Send Pong")
313                         wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
314                         c.sendMonitor.Update(int(n))
315                         c.flush()
316                 case <-c.quit:
317                         break FOR_LOOP
318                 case <-c.send:
319                         // Send some msgPackets
320                         eof := c.sendSomeMsgPackets()
321                         if !eof {
322                                 // Keep sendRoutine awake.
323                                 select {
324                                 case c.send <- struct{}{}:
325                                 default:
326                                 }
327                         }
328                 }
329
330                 if !c.IsRunning() {
331                         break FOR_LOOP
332                 }
333                 if err != nil {
334                         log.WithFields(log.Fields{
335                                 "conn":  c,
336                                 "error": err,
337                         }).Error("Connection failed @ sendRoutine")
338                         c.stopForError(err)
339                         break FOR_LOOP
340                 }
341         }
342
343         // Cleanup
344 }
345
346 // Returns true if messages from channels were exhausted.
347 // Blocks in accordance to .sendMonitor throttling.
348 func (c *MConnection) sendSomeMsgPackets() bool {
349         // Block until .sendMonitor says we can write.
350         // Once we're ready we send more than we asked for,
351         // but amortized it should even out.
352         c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true)
353
354         // Now send some msgPackets.
355         for i := 0; i < numBatchMsgPackets; i++ {
356                 if c.sendMsgPacket() {
357                         return true
358                 }
359         }
360         return false
361 }
362
363 // Returns true if messages from channels were exhausted.
364 func (c *MConnection) sendMsgPacket() bool {
365         // Choose a channel to create a msgPacket from.
366         // The chosen channel will be the one whose recentlySent/priority is the least.
367         var leastRatio float32 = math.MaxFloat32
368         var leastChannel *Channel
369         for _, channel := range c.channels {
370                 // If nothing to send, skip this channel
371                 if !channel.isSendPending() {
372                         continue
373                 }
374                 // Get ratio, and keep track of lowest ratio.
375                 ratio := float32(channel.recentlySent) / float32(channel.priority)
376                 if ratio < leastRatio {
377                         leastRatio = ratio
378                         leastChannel = channel
379                 }
380         }
381
382         // Nothing to send?
383         if leastChannel == nil {
384                 return true
385         } else {
386                 // c.Logger.Info("Found a msgPacket to send")
387         }
388
389         // Make & send a msgPacket from this channel
390         n, err := leastChannel.writeMsgPacketTo(c.bufWriter)
391         if err != nil {
392                 log.WithField("error", err).Error("Failed to write msgPacket")
393                 c.stopForError(err)
394                 return true
395         }
396         c.sendMonitor.Update(int(n))
397         c.flushTimer.Set()
398         return false
399 }
400
401 // recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer.
402 // After a whole message has been assembled, it's pushed to onReceive().
403 // Blocks depending on how the connection is throttled.
404 func (c *MConnection) recvRoutine() {
405         defer c._recover()
406
407 FOR_LOOP:
408         for {
409                 // Block until .recvMonitor says we can read.
410                 c.recvMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.RecvRate), true)
411
412                 /*
413                         // Peek into bufReader for debugging
414                         if numBytes := c.bufReader.Buffered(); numBytes > 0 {
415                                 log.Info("Peek connection buffer", "numBytes", numBytes, "bytes", log15.Lazy{func() []byte {
416                                         bytes, err := c.bufReader.Peek(MinInt(numBytes, 100))
417                                         if err == nil {
418                                                 return bytes
419                                         } else {
420                                                 log.Warn("Error peeking connection buffer", "error", err)
421                                                 return nil
422                                         }
423                                 }})
424                         }
425                 */
426
427                 // Read packet type
428                 var n int
429                 var err error
430                 pktType := wire.ReadByte(c.bufReader, &n, &err)
431                 c.recvMonitor.Update(int(n))
432                 if err != nil {
433                         if c.IsRunning() {
434                                 log.WithFields(log.Fields{
435                                         "conn":  c,
436                                         "error": err,
437                                 }).Error("Connection failed @ recvRoutine (reading byte)")
438                                 c.stopForError(err)
439                         }
440                         break FOR_LOOP
441                 }
442
443                 // Read more depending on packet type.
444                 switch pktType {
445                 case packetTypePing:
446                         // TODO: prevent abuse, as they cause flush()'s.
447                         log.Debug("Receive Ping")
448                         c.pong <- struct{}{}
449                 case packetTypePong:
450                         // do nothing
451                         log.Debug("Receive Pong")
452                 case packetTypeMsg:
453                         pkt, n, err := msgPacket{}, int(0), error(nil)
454                         wire.ReadBinaryPtr(&pkt, c.bufReader, maxMsgPacketTotalSize, &n, &err)
455                         c.recvMonitor.Update(int(n))
456                         if err != nil {
457                                 if c.IsRunning() {
458                                         log.WithFields(log.Fields{
459                                                 "conn":  c,
460                                                 "error": err,
461                                         }).Error("Connection failed @ recvRoutine")
462                                         c.stopForError(err)
463                                 }
464                                 break FOR_LOOP
465                         }
466                         channel, ok := c.channelsIdx[pkt.ChannelID]
467                         if !ok || channel == nil {
468                                 cmn.PanicQ(cmn.Fmt("Unknown channel %X", pkt.ChannelID))
469                         }
470                         msgBytes, err := channel.recvMsgPacket(pkt)
471                         if err != nil {
472                                 if c.IsRunning() {
473                                         log.WithFields(log.Fields{
474                                                 "conn":  c,
475                                                 "error": err,
476                                         }).Error("Connection failed @ recvRoutine")
477                                         c.stopForError(err)
478                                 }
479                                 break FOR_LOOP
480                         }
481                         if msgBytes != nil {
482                                 log.WithFields(log.Fields{
483                                         "channelID": pkt.ChannelID,
484                                         "msgBytes":  msgBytes,
485                                 }).Debug("Received bytes")
486                                 c.onReceive(pkt.ChannelID, msgBytes)
487                         }
488                 default:
489                         cmn.PanicSanity(cmn.Fmt("Unknown message type %X", pktType))
490                 }
491
492                 // TODO: shouldn't this go in the sendRoutine?
493                 // Better to send a ping packet when *we* haven't sent anything for a while.
494                 c.pingTimer.Reset()
495         }
496
497         // Cleanup
498         close(c.pong)
499         for _ = range c.pong {
500                 // Drain
501         }
502 }
503
// ConnectionStatus is a snapshot of an MConnection's flow statistics and of
// each of its channels, as returned by MConnection.Status.
type ConnectionStatus struct {
	SendMonitor flow.Status
	RecvMonitor flow.Status
	Channels    []ChannelStatus
}
509
// ChannelStatus describes one channel inside a ConnectionStatus snapshot.
type ChannelStatus struct {
	ID                byte
	SendQueueCapacity int   // capacity of the channel's send queue
	SendQueueSize     int   // messages queued or partially sent
	Priority          int   // relative send priority from the descriptor
	RecentlySent      int64 // decaying count of bytes recently written
}
517
518 func (c *MConnection) Status() ConnectionStatus {
519         var status ConnectionStatus
520         status.SendMonitor = c.sendMonitor.Status()
521         status.RecvMonitor = c.recvMonitor.Status()
522         status.Channels = make([]ChannelStatus, len(c.channels))
523         for i, channel := range c.channels {
524                 status.Channels[i] = ChannelStatus{
525                         ID:                channel.id,
526                         SendQueueCapacity: cap(channel.sendQueue),
527                         SendQueueSize:     int(channel.sendQueueSize), // TODO use atomic
528                         Priority:          channel.priority,
529                         RecentlySent:      channel.recentlySent,
530                 }
531         }
532         return status
533 }
534
535 //-----------------------------------------------------------------------------
536
// ChannelDescriptor configures one channel of an MConnection. Zero-valued
// capacities are replaced by package defaults (see FillDefaults).
type ChannelDescriptor struct {
	ID                  byte // channel id; must be unique within a connection
	Priority            int  // relative send priority; must be > 0 (newChannel panics otherwise)
	SendQueueCapacity   int  // buffer size of the send queue; 0 => defaultSendQueueCapacity
	RecvBufferCapacity  int  // initial capacity of the reassembly buffer; 0 => defaultRecvBufferCapacity
	RecvMessageCapacity int  // max bytes of one reassembled message; 0 => defaultRecvMessageCapacity
}
544
545 func (chDesc *ChannelDescriptor) FillDefaults() {
546         if chDesc.SendQueueCapacity == 0 {
547                 chDesc.SendQueueCapacity = defaultSendQueueCapacity
548         }
549         if chDesc.RecvBufferCapacity == 0 {
550                 chDesc.RecvBufferCapacity = defaultRecvBufferCapacity
551         }
552         if chDesc.RecvMessageCapacity == 0 {
553                 chDesc.RecvMessageCapacity = defaultRecvMessageCapacity
554         }
555 }
556
// Channel is one logical stream multiplexed over an MConnection. Outbound
// messages are queued in sendQueue and chopped into msgPackets. Inbound
// msgPackets are appended to recving until a packet marked EOF completes the
// message.
//
// TODO: lowercase.
// NOTE: not goroutine-safe. Only the methods marked Goroutine-safe may be
// called concurrently.
type Channel struct {
	conn          *MConnection
	desc          *ChannelDescriptor
	id            byte
	sendQueue     chan []byte
	sendQueueSize int32 // atomic. messages queued or partially sent
	recving       []byte
	sending       []byte // remaining bytes of the message currently being sent
	priority      int
	recentlySent  int64 // exponential moving average
}
570
571 func newChannel(conn *MConnection, desc *ChannelDescriptor) *Channel {
572         desc.FillDefaults()
573         if desc.Priority <= 0 {
574                 cmn.PanicSanity("Channel default priority must be a postive integer")
575         }
576         return &Channel{
577                 conn:      conn,
578                 desc:      desc,
579                 id:        desc.ID,
580                 sendQueue: make(chan []byte, desc.SendQueueCapacity),
581                 recving:   make([]byte, 0, desc.RecvBufferCapacity),
582                 priority:  desc.Priority,
583         }
584 }
585
586 // Queues message to send to this channel.
587 // Goroutine-safe
588 // Times out (and returns false) after defaultSendTimeout
589 func (ch *Channel) sendBytes(bytes []byte) bool {
590         select {
591         case ch.sendQueue <- bytes:
592                 atomic.AddInt32(&ch.sendQueueSize, 1)
593                 return true
594         case <-time.After(defaultSendTimeout):
595                 return false
596         }
597 }
598
599 // Queues message to send to this channel.
600 // Nonblocking, returns true if successful.
601 // Goroutine-safe
602 func (ch *Channel) trySendBytes(bytes []byte) bool {
603         select {
604         case ch.sendQueue <- bytes:
605                 atomic.AddInt32(&ch.sendQueueSize, 1)
606                 return true
607         default:
608                 return false
609         }
610 }
611
612 // Goroutine-safe
613 func (ch *Channel) loadSendQueueSize() (size int) {
614         return int(atomic.LoadInt32(&ch.sendQueueSize))
615 }
616
617 // Goroutine-safe
618 // Use only as a heuristic.
619 func (ch *Channel) canSend() bool {
620         return ch.loadSendQueueSize() < defaultSendQueueCapacity
621 }
622
623 // Returns true if any msgPackets are pending to be sent.
624 // Call before calling nextMsgPacket()
625 // Goroutine-safe
626 func (ch *Channel) isSendPending() bool {
627         if len(ch.sending) == 0 {
628                 if len(ch.sendQueue) == 0 {
629                         return false
630                 }
631                 ch.sending = <-ch.sendQueue
632         }
633         return true
634 }
635
636 // Creates a new msgPacket to send.
637 // Not goroutine-safe
638 func (ch *Channel) nextMsgPacket() msgPacket {
639         packet := msgPacket{}
640         packet.ChannelID = byte(ch.id)
641         packet.Bytes = ch.sending[:cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending))]
642         if len(ch.sending) <= maxMsgPacketPayloadSize {
643                 packet.EOF = byte(0x01)
644                 ch.sending = nil
645                 atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize
646         } else {
647                 packet.EOF = byte(0x00)
648                 ch.sending = ch.sending[cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending)):]
649         }
650         return packet
651 }
652
653 // Writes next msgPacket to w.
654 // Not goroutine-safe
655 func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int, err error) {
656         packet := ch.nextMsgPacket()
657         wire.WriteByte(packetTypeMsg, w, &n, &err)
658         wire.WriteBinary(packet, w, &n, &err)
659         if err == nil {
660                 ch.recentlySent += int64(n)
661         }
662         return
663 }
664
665 // Handles incoming msgPackets. Returns a msg bytes if msg is complete.
666 // Not goroutine-safe
667 func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) {
668         if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) {
669                 return nil, wire.ErrBinaryReadOverflow
670         }
671         ch.recving = append(ch.recving, packet.Bytes...)
672         if packet.EOF == byte(0x01) {
673                 msgBytes := ch.recving
674                 // clear the slice without re-allocating.
675                 // http://stackoverflow.com/questions/16971741/how-do-you-clear-a-slice-in-go
676                 //   suggests this could be a memory leak, but we might as well keep the memory for the channel until it closes,
677                 //      at which point the recving slice stops being used and should be garbage collected
678                 ch.recving = ch.recving[:0] // make([]byte, 0, ch.desc.RecvBufferCapacity)
679                 return msgBytes, nil
680         }
681         return nil, nil
682 }
683
684 // Call this periodically to update stats for throttling purposes.
685 // Not goroutine-safe
686 func (ch *Channel) updateStats() {
687         // Exponential decay of stats.
688         // TODO: optimize.
689         ch.recentlySent = int64(float64(ch.recentlySent) * 0.8)
690 }
691
692 //-----------------------------------------------------------------------------
693
// Framing limits for a single msgPacket.
const (
	maxMsgPacketPayloadSize  = 1024
	maxMsgPacketOverheadSize = 10 // It's actually lower but good enough
	maxMsgPacketTotalSize    = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize
)

// Packet type bytes. Every packet on the connection begins with one of these.
const (
	packetTypePing = byte(0x01)
	packetTypePong = byte(0x02)
	packetTypeMsg  = byte(0x03)
)
702
// Messages in channels are chopped into smaller msgPackets for multiplexing.
//
// NOTE(review): msgPacket is encoded/decoded with go-wire
// (wire.WriteBinary / wire.ReadBinaryPtr). Field order presumably defines the
// binary layout, so do not reorder fields without a protocol change.
type msgPacket struct {
	ChannelID byte
	EOF       byte // 1 means message ends here.
	Bytes     []byte
}
709
710 func (p msgPacket) String() string {
711         return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF)
712 }