OSDN Git Service

77595d43b9e97dce502ddceebb18a77db52d1bff
[bytom/bytom.git] / p2p / connection / connection.go
1 package connection
2
3 import (
4         "bufio"
5         "fmt"
6         "io"
7         "math"
8         "net"
9         "runtime/debug"
10         "sync/atomic"
11         "time"
12
13         log "github.com/sirupsen/logrus"
14         wire "github.com/tendermint/go-wire"
15         cmn "github.com/tendermint/tmlibs/common"
16         flow "github.com/tendermint/tmlibs/flowrate"
17 )
18
// Tuning knobs for MConnection and its channels.
const (
	// numBatchMsgPackets is the most msgPackets one sendSomeMsgPackets call writes.
	numBatchMsgPackets = 10
	// Sizes of the bufio reader/writer wrapped around the raw net.Conn.
	minReadBufferSize  = 1024
	minWriteBufferSize = 65536
	// updateState is the period of the channel-statistics ticker.
	updateState = 2 * time.Second
	// pingTimeout is the period of the ping ticker.
	pingTimeout = 40 * time.Second
	// flushThrottle bounds how often buffered writes are flushed.
	flushThrottle = 100 * time.Millisecond

	// Channel defaults (applied by ChannelDescriptor.FillDefaults), connection
	// rate limits, and the timeout of a blocking Send.
	defaultSendQueueCapacity   = 1
	defaultSendRate            = int64(500 * 1024) // 500KB/s
	defaultRecvBufferCapacity  = 4096
	defaultRecvMessageCapacity = 21 * 1024 * 1024 // 21MB
	defaultRecvRate            = int64(500 * 1024) // 500KB/s
	defaultSendTimeout         = 10 * time.Second
)
34
// Callback signatures supplied to NewMConnection.
type (
	// receiveCbFunc receives each fully reassembled inbound message together
	// with the ID of the channel it arrived on.
	receiveCbFunc func(chID byte, msgBytes []byte)
	// errorCbFunc is invoked by stopForError with the value that caused the
	// connection to fail.
	errorCbFunc func(interface{})
)
37
38 /*
39 Each peer has one `MConnection` (multiplex connection) instance.
40
41 __multiplex__ *noun* a system or signal involving simultaneous transmission of
42 several messages along a single channel of communication.
43
44 Each `MConnection` handles message transmission on multiple abstract communication
45 `Channel`s.  Each channel has a globally unique byte id.
46 The byte id and the relative priorities of each `Channel` are configured upon
47 initialization of the connection.
48
49 There are two methods for sending messages:
50         func (m MConnection) Send(chID byte, msg interface{}) bool {}
51         func (m MConnection) TrySend(chID byte, msg interface{}) bool {}
52
53 `Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued
54 for the channel with the given id byte `chID`, or until the request times out.
55 The message `msg` is serialized using the `tendermint/wire` submodule's
56 `WriteBinary()` reflection routine.
57
58 `TrySend(chID, msg)` is a nonblocking call that returns false if the channel's
59 queue is full.
60
61 Inbound message bytes are handled with an onReceive callback function.
62 */
63 type MConnection struct {
64         cmn.BaseService
65
66         conn        net.Conn
67         bufReader   *bufio.Reader
68         bufWriter   *bufio.Writer
69         sendMonitor *flow.Monitor
70         recvMonitor *flow.Monitor
71         send        chan struct{}
72         pong        chan struct{}
73         channels    []*Channel
74         channelsIdx map[byte]*Channel
75         onReceive   receiveCbFunc
76         onError     errorCbFunc
77         errored     uint32
78         config      *MConnConfig
79
80         quit         chan struct{}
81         flushTimer   *cmn.ThrottleTimer // flush writes as necessary but throttled.
82         pingTimer    *time.Ticker       // send pings periodically
83         chStatsTimer *time.Ticker       // update channel stats periodically
84 }
85
// MConnConfig is a MConnection configuration.
// The send/recv routines read both rates with atomic.LoadInt64, so anyone
// changing them on a live connection should use atomic stores.
// Units are presumably bytes per second (the defaults are annotated "500KB/s").
type MConnConfig struct {
	SendRate int64 `mapstructure:"send_rate"` // outbound throughput limit
	RecvRate int64 `mapstructure:"recv_rate"` // inbound throughput limit
}
91
92 // DefaultMConnConfig returns the default config.
93 func DefaultMConnConfig() *MConnConfig {
94         return &MConnConfig{
95                 SendRate: defaultSendRate,
96                 RecvRate: defaultRecvRate,
97         }
98 }
99
100 // NewMConnection wraps net.Conn and creates multiplex connection
101 func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc) *MConnection {
102         return NewMConnectionWithConfig(
103                 conn,
104                 chDescs,
105                 onReceive,
106                 onError,
107                 DefaultMConnConfig())
108 }
109
110 // NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config
111 func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection {
112         mconn := &MConnection{
113                 conn:        conn,
114                 bufReader:   bufio.NewReaderSize(conn, minReadBufferSize),
115                 bufWriter:   bufio.NewWriterSize(conn, minWriteBufferSize),
116                 sendMonitor: flow.New(0, 0),
117                 recvMonitor: flow.New(0, 0),
118                 send:        make(chan struct{}, 1),
119                 pong:        make(chan struct{}),
120                 onReceive:   onReceive,
121                 onError:     onError,
122                 config:      config,
123
124                 pingTimer:    time.NewTicker(pingTimeout),
125                 chStatsTimer: time.NewTicker(updateState),
126         }
127
128         // Create channels
129         var channelsIdx = map[byte]*Channel{}
130         var channels = []*Channel{}
131
132         for _, desc := range chDescs {
133                 descCopy := *desc // copy the desc else unsafe access across connections
134                 channel := newChannel(mconn, &descCopy)
135                 channelsIdx[channel.id] = channel
136                 channels = append(channels, channel)
137         }
138         mconn.channels = channels
139         mconn.channelsIdx = channelsIdx
140
141         mconn.BaseService = *cmn.NewBaseService(nil, "MConnection", mconn)
142
143         return mconn
144 }
145
146 func (c *MConnection) OnStart() error {
147         c.BaseService.OnStart()
148         c.quit = make(chan struct{})
149         c.flushTimer = cmn.NewThrottleTimer("flush", flushThrottle)
150         go c.sendRoutine()
151         go c.recvRoutine()
152         return nil
153 }
154
155 func (c *MConnection) OnStop() {
156         c.BaseService.OnStop()
157         c.flushTimer.Stop()
158         if c.quit != nil {
159                 close(c.quit)
160         }
161         c.conn.Close()
162         // We can't close pong safely here because
163         // recvRoutine may write to it after we've stopped.
164         // Though it doesn't need to get closed at all,
165         // we close it @ recvRoutine.
166         // close(c.pong)
167 }
168
169 func (c *MConnection) String() string {
170         return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr())
171 }
172
173 func (c *MConnection) flush() {
174         log.WithField("conn", c).Debug("Flush")
175         err := c.bufWriter.Flush()
176         if err != nil {
177                 log.WithField("error", err).Error("MConnection flush failed")
178         }
179 }
180
181 // Catch panics, usually caused by remote disconnects.
182 func (c *MConnection) _recover() {
183         if r := recover(); r != nil {
184                 stack := debug.Stack()
185                 err := cmn.StackError{r, stack}
186                 c.stopForError(err)
187         }
188 }
189
190 func (c *MConnection) stopForError(r interface{}) {
191         c.Stop()
192         if atomic.CompareAndSwapUint32(&c.errored, 0, 1) {
193                 if c.onError != nil {
194                         c.onError(r)
195                 }
196         }
197 }
198
199 // Queues a message to be sent to channel.
200 func (c *MConnection) Send(chID byte, msg interface{}) bool {
201         if !c.IsRunning() {
202                 return false
203         }
204
205         log.WithFields(log.Fields{
206                 "chID": chID,
207                 "conn": c,
208                 "msg":  msg,
209         }).Debug("Send")
210
211         // Send message to channel.
212         channel, ok := c.channelsIdx[chID]
213         if !ok {
214                 log.WithField("chID", chID).Error(cmn.Fmt("Cannot send bytes, unknown channel"))
215                 return false
216         }
217
218         success := channel.sendBytes(wire.BinaryBytes(msg))
219         if success {
220                 // Wake up sendRoutine if necessary
221                 select {
222                 case c.send <- struct{}{}:
223                 default:
224                 }
225         } else {
226                 log.WithFields(log.Fields{
227                         "chID": chID,
228                         "conn": c,
229                         "msg":  msg,
230                 }).Error("Send failed")
231         }
232         return success
233 }
234
235 // Queues a message to be sent to channel.
236 // Nonblocking, returns true if successful.
237 func (c *MConnection) TrySend(chID byte, msg interface{}) bool {
238         if !c.IsRunning() {
239                 return false
240         }
241
242         log.WithFields(log.Fields{
243                 "chID": chID,
244                 "conn": c,
245                 "msg":  msg,
246         }).Debug("TrySend")
247
248         // Send message to channel.
249         channel, ok := c.channelsIdx[chID]
250         if !ok {
251                 log.WithField("chID", chID).Error(cmn.Fmt("cannot send bytes, unknown channel"))
252                 return false
253         }
254
255         ok = channel.trySendBytes(wire.BinaryBytes(msg))
256         if ok {
257                 // Wake up sendRoutine if necessary
258                 select {
259                 case c.send <- struct{}{}:
260                 default:
261                 }
262         }
263
264         return ok
265 }
266
267 // CanSend returns true if you can send more data onto the chID, false
268 // otherwise. Use only as a heuristic.
269 func (c *MConnection) CanSend(chID byte) bool {
270         if !c.IsRunning() {
271                 return false
272         }
273
274         channel, ok := c.channelsIdx[chID]
275         if !ok {
276                 log.WithField("chID", chID).Error(cmn.Fmt("Unknown channel"))
277                 return false
278         }
279         return channel.canSend()
280 }
281
282 // sendRoutine polls for packets to send from channels.
283 func (c *MConnection) sendRoutine() {
284         defer c._recover()
285
286 FOR_LOOP:
287         for {
288                 var n int
289                 var err error
290                 select {
291                 case <-c.flushTimer.Ch:
292                         // NOTE: flushTimer.Set() must be called every time
293                         // something is written to .bufWriter.
294                         c.flush()
295                 case <-c.chStatsTimer.C:
296                         for _, channel := range c.channels {
297                                 channel.updateStats()
298                         }
299                 case <-c.pingTimer.C:
300                         log.Debug("Send Ping")
301                         wire.WriteByte(packetTypePing, c.bufWriter, &n, &err)
302                         c.sendMonitor.Update(int(n))
303                         c.flush()
304                 case <-c.pong:
305                         log.Debug("Send Pong")
306                         wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
307                         c.sendMonitor.Update(int(n))
308                         c.flush()
309                 case <-c.quit:
310                         break FOR_LOOP
311                 case <-c.send:
312                         // Send some msgPackets
313                         eof := c.sendSomeMsgPackets()
314                         if !eof {
315                                 // Keep sendRoutine awake.
316                                 select {
317                                 case c.send <- struct{}{}:
318                                 default:
319                                 }
320                         }
321                 }
322
323                 if !c.IsRunning() {
324                         break FOR_LOOP
325                 }
326                 if err != nil {
327                         log.WithFields(log.Fields{
328                                 "conn":  c,
329                                 "error": err,
330                         }).Error("Connection failed @ sendRoutine")
331                         c.stopForError(err)
332                         break FOR_LOOP
333                 }
334         }
335
336         // Cleanup
337 }
338
339 // Returns true if messages from channels were exhausted.
340 // Blocks in accordance to .sendMonitor throttling.
341 func (c *MConnection) sendSomeMsgPackets() bool {
342         // Block until .sendMonitor says we can write.
343         // Once we're ready we send more than we asked for,
344         // but amortized it should even out.
345         c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true)
346
347         // Now send some msgPackets.
348         for i := 0; i < numBatchMsgPackets; i++ {
349                 if c.sendMsgPacket() {
350                         return true
351                 }
352         }
353         return false
354 }
355
356 // Returns true if messages from channels were exhausted.
357 func (c *MConnection) sendMsgPacket() bool {
358         // Choose a channel to create a msgPacket from.
359         // The chosen channel will be the one whose recentlySent/priority is the least.
360         var leastRatio float32 = math.MaxFloat32
361         var leastChannel *Channel
362         for _, channel := range c.channels {
363                 // If nothing to send, skip this channel
364                 if !channel.isSendPending() {
365                         continue
366                 }
367                 // Get ratio, and keep track of lowest ratio.
368                 ratio := float32(channel.recentlySent) / float32(channel.priority)
369                 if ratio < leastRatio {
370                         leastRatio = ratio
371                         leastChannel = channel
372                 }
373         }
374
375         // Nothing to send?
376         if leastChannel == nil {
377                 return true
378         } else {
379                 // c.Logger.Info("Found a msgPacket to send")
380         }
381
382         // Make & send a msgPacket from this channel
383         n, err := leastChannel.writeMsgPacketTo(c.bufWriter)
384         if err != nil {
385                 log.WithField("error", err).Error("Failed to write msgPacket")
386                 c.stopForError(err)
387                 return true
388         }
389         c.sendMonitor.Update(int(n))
390         c.flushTimer.Set()
391         return false
392 }
393
394 // recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer.
395 // After a whole message has been assembled, it's pushed to onReceive().
396 // Blocks depending on how the connection is throttled.
397 func (c *MConnection) recvRoutine() {
398         defer c._recover()
399
400 FOR_LOOP:
401         for {
402                 // Block until .recvMonitor says we can read.
403                 c.recvMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.RecvRate), true)
404
405                 /*
406                         // Peek into bufReader for debugging
407                         if numBytes := c.bufReader.Buffered(); numBytes > 0 {
408                                 log.Infof("Peek connection buffer numBytes:", numBytes)
409                                 bytes, err := c.bufReader.Peek(cmn.MinInt(numBytes, 100))
410                                 if err == nil {
411                                         log.Infof("bytes:", bytes)
412                                 } else {
413                                         log.Warning("Error peeking connection buffer err:", err)
414                                 }
415                         } else {
416                                 log.Warning("Received bytes number is:", numBytes)
417                         }
418                 */
419
420                 // Read packet type
421                 var n int
422                 var err error
423                 pktType := wire.ReadByte(c.bufReader, &n, &err)
424                 c.recvMonitor.Update(int(n))
425                 if err != nil {
426                         if c.IsRunning() {
427                                 log.WithFields(log.Fields{
428                                         "conn":  c,
429                                         "error": err,
430                                 }).Error("Connection failed @ recvRoutine (reading byte)")
431                                 c.conn.Close()
432                                 c.stopForError(err)
433                         }
434                         break FOR_LOOP
435                 }
436
437                 // Read more depending on packet type.
438                 switch pktType {
439                 case packetTypePing:
440                         // TODO: prevent abuse, as they cause flush()'s.
441                         log.Debug("Receive Ping")
442                         c.pong <- struct{}{}
443                 case packetTypePong:
444                         // do nothing
445                         log.Debug("Receive Pong")
446                 case packetTypeMsg:
447                         pkt, n, err := msgPacket{}, int(0), error(nil)
448                         wire.ReadBinaryPtr(&pkt, c.bufReader, maxMsgPacketTotalSize, &n, &err)
449                         c.recvMonitor.Update(int(n))
450                         if err != nil {
451                                 if c.IsRunning() {
452                                         log.WithFields(log.Fields{
453                                                 "conn":  c,
454                                                 "error": err,
455                                         }).Error("Connection failed @ recvRoutine")
456                                         c.stopForError(err)
457                                 }
458                                 break FOR_LOOP
459                         }
460                         channel, ok := c.channelsIdx[pkt.ChannelID]
461                         if !ok || channel == nil {
462                                 cmn.PanicQ(cmn.Fmt("Unknown channel %X", pkt.ChannelID))
463                         }
464                         msgBytes, err := channel.recvMsgPacket(pkt)
465                         if err != nil {
466                                 if c.IsRunning() {
467                                         log.WithFields(log.Fields{
468                                                 "conn":  c,
469                                                 "error": err,
470                                         }).Error("Connection failed @ recvRoutine")
471                                         c.stopForError(err)
472                                 }
473                                 break FOR_LOOP
474                         }
475                         if msgBytes != nil {
476                                 log.WithFields(log.Fields{
477                                         "channelID": pkt.ChannelID,
478                                         "msgBytes":  msgBytes,
479                                 }).Debug("Received bytes")
480                                 c.onReceive(pkt.ChannelID, msgBytes)
481                         }
482                 default:
483                         cmn.PanicSanity(cmn.Fmt("Unknown message type %X", pktType))
484                 }
485         }
486
487         // Cleanup
488         close(c.pong)
489         for _ = range c.pong {
490                 // Drain
491         }
492 }
493
494 type ConnectionStatus struct {
495         SendMonitor flow.Status
496         RecvMonitor flow.Status
497         Channels    []ChannelStatus
498 }
499
// ChannelStatus describes a single channel within a ConnectionStatus.
type ChannelStatus struct {
	ID                byte  // channel ID
	SendQueueCapacity int   // capacity of the outbound message queue
	SendQueueSize     int   // messages queued or still being sent
	Priority          int   // scheduling weight of the channel
	RecentlySent      int64 // decaying count of bytes recently sent
}
507
508 func (c *MConnection) Status() ConnectionStatus {
509         var status ConnectionStatus
510         status.SendMonitor = c.sendMonitor.Status()
511         status.RecvMonitor = c.recvMonitor.Status()
512         status.Channels = make([]ChannelStatus, len(c.channels))
513         for i, channel := range c.channels {
514                 status.Channels[i] = ChannelStatus{
515                         ID:                channel.id,
516                         SendQueueCapacity: cap(channel.sendQueue),
517                         SendQueueSize:     int(channel.sendQueueSize), // TODO use atomic
518                         Priority:          channel.priority,
519                         RecentlySent:      channel.recentlySent,
520                 }
521         }
522         return status
523 }
524
525 //-----------------------------------------------------------------------------
526
// ChannelDescriptor configures one channel of an MConnection. FillDefaults
// replaces zero-valued capacity fields with package defaults. Priority must
// be positive; newChannel panics otherwise.
type ChannelDescriptor struct {
	ID                  byte // channel ID used on the wire
	Priority            int  // scheduling weight; higher gets more send bandwidth
	SendQueueCapacity   int  // buffered messages in the send queue
	RecvBufferCapacity  int  // initial capacity of the reassembly buffer
	RecvMessageCapacity int  // maximum size of one reassembled message, in bytes
}
534
535 func (chDesc *ChannelDescriptor) FillDefaults() {
536         if chDesc.SendQueueCapacity == 0 {
537                 chDesc.SendQueueCapacity = defaultSendQueueCapacity
538         }
539         if chDesc.RecvBufferCapacity == 0 {
540                 chDesc.RecvBufferCapacity = defaultRecvBufferCapacity
541         }
542         if chDesc.RecvMessageCapacity == 0 {
543                 chDesc.RecvMessageCapacity = defaultRecvMessageCapacity
544         }
545 }
546
547 // TODO: lowercase.
548 // NOTE: not goroutine-safe.
549 type Channel struct {
550         conn          *MConnection
551         desc          *ChannelDescriptor
552         id            byte
553         sendQueue     chan []byte
554         sendQueueSize int32 // atomic.
555         recving       []byte
556         sending       []byte
557         priority      int
558         recentlySent  int64 // exponential moving average
559 }
560
561 func newChannel(conn *MConnection, desc *ChannelDescriptor) *Channel {
562         desc.FillDefaults()
563         if desc.Priority <= 0 {
564                 cmn.PanicSanity("Channel default priority must be a postive integer")
565         }
566         return &Channel{
567                 conn:      conn,
568                 desc:      desc,
569                 id:        desc.ID,
570                 sendQueue: make(chan []byte, desc.SendQueueCapacity),
571                 recving:   make([]byte, 0, desc.RecvBufferCapacity),
572                 priority:  desc.Priority,
573         }
574 }
575
576 // Queues message to send to this channel.
577 // Goroutine-safe
578 // Times out (and returns false) after defaultSendTimeout
579 func (ch *Channel) sendBytes(bytes []byte) bool {
580         select {
581         case ch.sendQueue <- bytes:
582                 atomic.AddInt32(&ch.sendQueueSize, 1)
583                 return true
584         case <-time.After(defaultSendTimeout):
585                 return false
586         }
587 }
588
589 // Queues message to send to this channel.
590 // Nonblocking, returns true if successful.
591 // Goroutine-safe
592 func (ch *Channel) trySendBytes(bytes []byte) bool {
593         select {
594         case ch.sendQueue <- bytes:
595                 atomic.AddInt32(&ch.sendQueueSize, 1)
596                 return true
597         default:
598                 return false
599         }
600 }
601
602 // Goroutine-safe
603 func (ch *Channel) loadSendQueueSize() (size int) {
604         return int(atomic.LoadInt32(&ch.sendQueueSize))
605 }
606
607 // Goroutine-safe
608 // Use only as a heuristic.
609 func (ch *Channel) canSend() bool {
610         return ch.loadSendQueueSize() < defaultSendQueueCapacity
611 }
612
613 // Returns true if any msgPackets are pending to be sent.
614 // Call before calling nextMsgPacket()
615 // Goroutine-safe
616 func (ch *Channel) isSendPending() bool {
617         if len(ch.sending) == 0 {
618                 if len(ch.sendQueue) == 0 {
619                         return false
620                 }
621                 ch.sending = <-ch.sendQueue
622         }
623         return true
624 }
625
626 // Creates a new msgPacket to send.
627 // Not goroutine-safe
628 func (ch *Channel) nextMsgPacket() msgPacket {
629         packet := msgPacket{}
630         packet.ChannelID = byte(ch.id)
631         packet.Bytes = ch.sending[:cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending))]
632         if len(ch.sending) <= maxMsgPacketPayloadSize {
633                 packet.EOF = byte(0x01)
634                 ch.sending = nil
635                 atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize
636         } else {
637                 packet.EOF = byte(0x00)
638                 ch.sending = ch.sending[cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending)):]
639         }
640         return packet
641 }
642
643 // Writes next msgPacket to w.
644 // Not goroutine-safe
645 func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int, err error) {
646         packet := ch.nextMsgPacket()
647         wire.WriteByte(packetTypeMsg, w, &n, &err)
648         wire.WriteBinary(packet, w, &n, &err)
649         if err == nil {
650                 ch.recentlySent += int64(n)
651         }
652         return
653 }
654
655 // Handles incoming msgPackets. Returns a msg bytes if msg is complete.
656 // Not goroutine-safe
657 func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) {
658         if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) {
659                 return nil, wire.ErrBinaryReadOverflow
660         }
661         ch.recving = append(ch.recving, packet.Bytes...)
662         if packet.EOF == byte(0x01) {
663                 msgBytes := ch.recving
664                 // clear the slice without re-allocating.
665                 // http://stackoverflow.com/questions/16971741/how-do-you-clear-a-slice-in-go
666                 //   suggests this could be a memory leak, but we might as well keep the memory for the channel until it closes,
667                 //      at which point the recving slice stops being used and should be garbage collected
668                 ch.recving = ch.recving[:0] // make([]byte, 0, ch.desc.RecvBufferCapacity)
669                 return msgBytes, nil
670         }
671         return nil, nil
672 }
673
674 // Call this periodically to update stats for throttling purposes.
675 // Not goroutine-safe
676 func (ch *Channel) updateStats() {
677         // Exponential decay of stats.
678         // TODO: optimize.
679         ch.recentlySent = int64(float64(ch.recentlySent) * 0.8)
680 }
681
682 //-----------------------------------------------------------------------------
683
// Packet size limits.
const (
	// maxMsgPacketPayloadSize is the most payload bytes carried by one
	// msgPacket; longer messages are split across several packets.
	maxMsgPacketPayloadSize = 1024
	// maxMsgPacketOverheadSize is an upper estimate of per-packet framing.
	maxMsgPacketOverheadSize = 10 // It's actually lower but good enough
	maxMsgPacketTotalSize    = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize
)

// Packet-type tags; each packet on the wire begins with one of these bytes.
const (
	packetTypePing = byte(0x01)
	packetTypePong = byte(0x02)
	packetTypeMsg  = byte(0x03)
)
692
// msgPacket is one chunk of a channel message. Messages are chopped into
// msgPackets so that several channels can be multiplexed.
// NOTE(review): go-wire's reflection encoder presumably writes fields in
// declaration order, so do not reorder them without confirming.
type msgPacket struct {
	ChannelID byte
	EOF       byte // 1 means message ends here.
	Bytes     []byte
}

// String renders the packet's channel ID, payload and EOF flag in hex.
func (p msgPacket) String() string {
	const layout = "MsgPacket{%X:%X T:%X}"
	return fmt.Sprintf(layout, p.ChannelID, p.Bytes, p.EOF)
}