OSDN Git Service

576be24b482729f2ca0da328e7dfea9168918370
[bytom/vapor.git] / p2p / connection / connection.go
1 package connection
2
3 import (
4         "bufio"
5         "fmt"
6         "math"
7         "net"
8         "runtime/debug"
9         "sync/atomic"
10         "time"
11
12         log "github.com/sirupsen/logrus"
13         wire "github.com/tendermint/go-wire"
14         cmn "github.com/tendermint/tmlibs/common"
15         "github.com/tendermint/tmlibs/flowrate"
16
17         "github.com/vapor/common/compression"
18 )
19
// Wire packet types and framing/size/timing constants.
const (
	packetTypePing           = byte(0x01) // keepalive request
	packetTypePong           = byte(0x02) // keepalive reply
	packetTypeMsg            = byte(0x03) // payload-carrying packet
	maxMsgPacketPayloadSize  = 1024
	maxMsgPacketOverheadSize = 10 // It's actually lower but good enough
	maxMsgPacketTotalSize    = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize

	numBatchMsgPackets = 10    // max packets written per send signal
	minReadBufferSize  = 1024  // bufReader size
	minWriteBufferSize = 65536 // bufWriter size
	updateState        = 2 * time.Second        // channel-stats ticker period
	pingTimeout        = 40 * time.Second       // ping ticker period
	flushThrottle      = 100 * time.Millisecond // min interval between writer flushes

	defaultSendQueueCapacity   = 1
	defaultSendRate            = int64(104857600) // 100MB/s
	defaultRecvBufferCapacity  = 4096
	defaultRecvMessageCapacity = 22020096         // 21MB
	defaultRecvRate            = int64(104857600) // 100MB/s
	defaultSendTimeout         = 10 * time.Second
	logModule                  = "p2p/conn"
)
43
// receiveCbFunc is called with each fully reassembled message for a channel.
type receiveCbFunc func(chID byte, msgBytes []byte)

// errorCbFunc is called (at most once, see stopForError) with the fatal error
// or panic value that stopped the connection.
type errorCbFunc func(interface{})
46
// Messages in channels are chopped into smaller msgPackets for multiplexing.
type msgPacket struct {
	ChannelID byte
	EOF       byte // 1 means message ends here.
	Bytes     []byte
}

// String renders the packet for debug logging as MsgPacket{channel:bytes T:eof}.
func (p msgPacket) String() string {
	out := fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF)
	return out
}
57
/*
MConnection handles message transmission on multiple abstract communication
`Channel`s.  Each channel has a globally unique byte id.
The byte id and the relative priorities of each `Channel` are configured upon
initialization of the connection.

There are two methods for sending messages:
	func (m MConnection) Send(chID byte, msg interface{}) bool {}
	func (m MConnection) TrySend(chID byte, msg interface{}) bool {}

`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued
for the channel with the given id byte `chID`, or until the request times out.
The message `msg` is serialized using the `tendermint/wire` submodule's
`WriteBinary()` reflection routine.

`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's
queue is full.

Inbound message bytes are handled with an onReceive callback function.
*/
type MConnection struct {
	cmn.BaseService

	conn        net.Conn
	bufReader   *bufio.Reader     // buffered reads from conn (recvRoutine only)
	bufWriter   *bufio.Writer     // buffered writes to conn (sendRoutine only)
	sendMonitor *flowrate.Monitor // rate-limits and tracks outbound bytes
	recvMonitor *flowrate.Monitor // rate-limits and tracks inbound bytes
	send        chan struct{}     // cap-1 signal: some channel has data pending
	pong        chan struct{}     // cap-1 signal: a pong should be written
	channels    []*channel
	channelsIdx map[byte]*channel // channel id -> channel
	onReceive   receiveCbFunc     // invoked with each fully assembled message
	onError     errorCbFunc       // invoked at most once on fatal error
	errored     uint32            // set via CAS to guarantee single onError call
	config      *MConnConfig

	quit         chan struct{}      // closed in OnStop to end sendRoutine
	flushTimer   *cmn.ThrottleTimer // flush writes as necessary but throttled.
	pingTimer    *time.Ticker       // send pings periodically
	chStatsTimer *time.Ticker       // update channel stats periodically

	compression compression.Compression // codec applied to message payloads
}
102
// MConnConfig is a MConnection configuration.
type MConnConfig struct {
	SendRate    int64  `mapstructure:"send_rate"`           // outbound rate limit, bytes/sec
	RecvRate    int64  `mapstructure:"recv_rate"`           // inbound rate limit, bytes/sec
	Compression string `mapstructure:"compression_backend"` // name passed to compression.NewCompression
}
109
110 // DefaultMConnConfig returns the default config.
111 func DefaultMConnConfig(compression string) *MConnConfig {
112         return &MConnConfig{
113                 SendRate:    defaultSendRate,
114                 RecvRate:    defaultRecvRate,
115                 Compression: compression,
116         }
117 }
118
119 // NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config
120 func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection {
121         mconn := &MConnection{
122                 conn:        conn,
123                 bufReader:   bufio.NewReaderSize(conn, minReadBufferSize),
124                 bufWriter:   bufio.NewWriterSize(conn, minWriteBufferSize),
125                 sendMonitor: flowrate.New(0, 0),
126                 recvMonitor: flowrate.New(0, 0),
127                 send:        make(chan struct{}, 1),
128                 pong:        make(chan struct{}, 1),
129                 channelsIdx: map[byte]*channel{},
130                 channels:    []*channel{},
131                 onReceive:   onReceive,
132                 onError:     onError,
133                 config:      config,
134
135                 pingTimer:    time.NewTicker(pingTimeout),
136                 chStatsTimer: time.NewTicker(updateState),
137                 compression:  compression.NewCompression(config.Compression),
138         }
139
140         for _, desc := range chDescs {
141                 descCopy := *desc // copy the desc else unsafe access across connections
142                 channel := newChannel(mconn, &descCopy)
143                 mconn.channelsIdx[channel.id] = channel
144                 mconn.channels = append(mconn.channels, channel)
145         }
146         mconn.BaseService = *cmn.NewBaseService(nil, "MConnection", mconn)
147         return mconn
148 }
149
150 // OnStart implements BaseService
151 func (c *MConnection) OnStart() error {
152         c.BaseService.OnStart()
153         c.quit = make(chan struct{})
154         c.flushTimer = cmn.NewThrottleTimer("flush", flushThrottle)
155         go c.sendRoutine()
156         go c.recvRoutine()
157         return nil
158 }
159
// OnStop implements BaseService
func (c *MConnection) OnStop() {
	c.BaseService.OnStop()
	c.flushTimer.Stop()
	c.pingTimer.Stop()
	c.chStatsTimer.Stop()
	// quit is nil if OnStart was never called; closing it ends sendRoutine.
	if c.quit != nil {
		close(c.quit)
	}
	// Closing the conn unblocks any pending read in recvRoutine.
	c.conn.Close()
	// We can't close pong safely here because recvRoutine may write to it after we've
	// stopped. Though it doesn't need to get closed at all, we close it @ recvRoutine.
}
173
174 // CanSend returns true if you can send more data onto the chID, false otherwise
175 func (c *MConnection) CanSend(chID byte) bool {
176         if !c.IsRunning() {
177                 return false
178         }
179
180         channel, ok := c.channelsIdx[chID]
181         if !ok {
182                 return false
183         }
184         return channel.canSend()
185 }
186
187 // Send will queues a message to be sent to channel(blocking).
188 func (c *MConnection) Send(chID byte, msg interface{}) bool {
189         if !c.IsRunning() {
190                 return false
191         }
192
193         channel, ok := c.channelsIdx[chID]
194         if !ok {
195                 log.WithFields(log.Fields{"module": logModule, "chID": chID}).Error("cannot send bytes due to unknown channel")
196                 return false
197         }
198
199         compressData := c.compression.CompressBytes(wire.BinaryBytes(msg))
200
201         if !channel.sendBytes(compressData) {
202                 log.WithFields(log.Fields{"module": logModule, "chID": chID, "conn": c, "msg": msg}).Error("MConnection send failed")
203                 return false
204         }
205
206         select {
207         case c.send <- struct{}{}:
208         default:
209         }
210         return true
211 }
212
213 // TrafficStatus return the in and out traffic status
214 func (c *MConnection) TrafficStatus() (*flowrate.Status, *flowrate.Status) {
215         sentStatus := c.sendMonitor.Status()
216         receivedStatus := c.recvMonitor.Status()
217         return &sentStatus, &receivedStatus
218 }
219
220 // TrySend queues a message to be sent to channel(Nonblocking).
221 func (c *MConnection) TrySend(chID byte, msg interface{}) bool {
222         if !c.IsRunning() {
223                 return false
224         }
225
226         channel, ok := c.channelsIdx[chID]
227         if !ok {
228                 log.WithFields(log.Fields{"module": logModule, "chID": chID}).Error("cannot send bytes due to unknown channel")
229                 return false
230         }
231
232         compressData := c.compression.CompressBytes(wire.BinaryBytes(msg))
233
234         ok = channel.trySendBytes(compressData)
235         if ok {
236                 select {
237                 case c.send <- struct{}{}:
238                 default:
239                 }
240         }
241         return ok
242 }
243
244 func (c *MConnection) String() string {
245         return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr())
246 }
247
248 func (c *MConnection) flush() {
249         if err := c.bufWriter.Flush(); err != nil {
250                 log.WithFields(log.Fields{"module": logModule, "error": err}).Error("MConnection flush failed")
251         }
252 }
253
// Catch panics, usually caused by remote disconnects.
// Installed via defer at the top of sendRoutine/recvRoutine; converts a panic
// into a stopForError shutdown instead of crashing the process.
func (c *MConnection) _recover() {
	if r := recover(); r != nil {
		stack := debug.Stack()
		err := cmn.StackError{r, stack} // pair the panic value with its stack trace
		c.stopForError(err)
	}
}
262
263 // recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer.
264 // After a whole message has been assembled, it's pushed to onReceive().
265 // Blocks depending on how the connection is throttled.
266 func (c *MConnection) recvRoutine() {
267         defer c._recover()
268         defer close(c.pong)
269
270         for {
271                 // Block until .recvMonitor says we can read.
272                 c.recvMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.RecvRate), true)
273
274                 // Read packet type
275                 var n int
276                 var err error
277                 pktType := wire.ReadByte(c.bufReader, &n, &err)
278                 c.recvMonitor.Update(int(n))
279                 if err != nil {
280                         if c.IsRunning() {
281                                 log.WithFields(log.Fields{"module": logModule, "conn": c, "error": err}).Error("Connection failed @ recvRoutine (reading byte)")
282                                 c.conn.Close()
283                                 c.stopForError(err)
284                         }
285                         return
286                 }
287
288                 // Read more depending on packet type.
289                 switch pktType {
290                 case packetTypePing:
291                         log.WithFields(log.Fields{"module": logModule, "conn": c}).Debug("receive Ping")
292                         select {
293                         case c.pong <- struct{}{}:
294                         default:
295                         }
296
297                 case packetTypePong:
298                         log.WithFields(log.Fields{"module": logModule, "conn": c}).Debug("receive Pong")
299
300                 case packetTypeMsg:
301                         pkt, n, err := msgPacket{}, int(0), error(nil)
302                         wire.ReadBinaryPtr(&pkt, c.bufReader, maxMsgPacketTotalSize, &n, &err)
303                         c.recvMonitor.Update(int(n))
304                         if err != nil {
305                                 if c.IsRunning() {
306                                         log.WithFields(log.Fields{"module": logModule, "conn": c, "error": err}).Error("failed on recvRoutine")
307                                         c.stopForError(err)
308                                 }
309                                 return
310                         }
311
312                         channel, ok := c.channelsIdx[pkt.ChannelID]
313                         if !ok || channel == nil {
314                                 cmn.PanicQ(cmn.Fmt("Unknown channel %X", pkt.ChannelID))
315                         }
316
317                         msgBytes, err := channel.recvMsgPacket(pkt)
318                         if err != nil {
319                                 if c.IsRunning() {
320                                         log.WithFields(log.Fields{"module": logModule, "conn": c, "error": err}).Error("failed on recvRoutine")
321                                         c.stopForError(err)
322                                 }
323                                 return
324                         }
325
326                         if msgBytes != nil {
327                                 data, err := c.compression.DecompressBytes(msgBytes)
328                                 if err != nil {
329                                         log.WithFields(log.Fields{"module": logModule, "conn": c, "error": err}).Error("failed decompress bytes")
330                                         return
331                                 }
332                                 c.onReceive(pkt.ChannelID, data)
333                         }
334
335                 default:
336                         cmn.PanicSanity(cmn.Fmt("Unknown message type %X", pktType))
337                 }
338         }
339 }
340
341 // Returns true if messages from channels were exhausted.
342 func (c *MConnection) sendMsgPacket() bool {
343         var leastRatio float32 = math.MaxFloat32
344         var leastChannel *channel
345         for _, channel := range c.channels {
346                 if !channel.isSendPending() {
347                         continue
348                 }
349                 if ratio := float32(channel.recentlySent) / float32(channel.priority); ratio < leastRatio {
350                         leastRatio = ratio
351                         leastChannel = channel
352                 }
353         }
354         if leastChannel == nil {
355                 return true
356         }
357
358         n, err := leastChannel.writeMsgPacketTo(c.bufWriter)
359         if err != nil {
360                 log.WithFields(log.Fields{"module": logModule, "error": err}).Error("failed to write msgPacket")
361                 c.stopForError(err)
362                 return true
363         }
364         c.sendMonitor.Update(int(n))
365         c.flushTimer.Set()
366         return false
367 }
368
// sendRoutine polls for packets to send from channels.
// It is the single writer goroutine for bufWriter; it exits when quit closes,
// when the service stops, or when a ping/pong write fails.
func (c *MConnection) sendRoutine() {
	defer c._recover()

	for {
		var n int
		var err error
		select {
		case <-c.flushTimer.Ch:
			// Throttled flush of buffered packets queued by sendMsgPacket.
			c.flush()
		case <-c.chStatsTimer.C:
			// Periodically update per-channel stats.
			for _, channel := range c.channels {
				channel.updateStats()
			}
		case <-c.pingTimer.C:
			log.WithFields(log.Fields{"module": logModule, "conn": c}).Debug("send Ping")
			wire.WriteByte(packetTypePing, c.bufWriter, &n, &err)
			c.sendMonitor.Update(int(n))
			c.flush()
		case <-c.pong:
			// recvRoutine observed a ping; answer it.
			log.WithFields(log.Fields{"module": logModule, "conn": c}).Debug("send Pong")
			wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
			c.sendMonitor.Update(int(n))
			c.flush()
		case <-c.quit:
			return
		case <-c.send:
			// Send a batch of msgPackets; re-signal ourselves if work remains.
			if eof := c.sendSomeMsgPackets(); !eof {
				select {
				case c.send <- struct{}{}:
				default:
				}
			}
		}

		if !c.IsRunning() {
			return
		}
		// err can only have been set by the ping/pong writes above.
		if err != nil {
			log.WithFields(log.Fields{"module": logModule, "conn": c, "error": err}).Error("Connection failed @ sendRoutine")
			c.stopForError(err)
			return
		}
	}
}
414
415 // Returns true if messages from channels were exhausted.
416 func (c *MConnection) sendSomeMsgPackets() bool {
417         // Block until .sendMonitor says we can write.
418         // Once we're ready we send more than we asked for,
419         // but amortized it should even out.
420         c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true)
421         for i := 0; i < numBatchMsgPackets; i++ {
422                 if c.sendMsgPacket() {
423                         return true
424                 }
425         }
426         return false
427 }
428
429 func (c *MConnection) stopForError(r interface{}) {
430         c.Stop()
431         if atomic.CompareAndSwapUint32(&c.errored, 0, 1) && c.onError != nil {
432                 c.onError(r)
433         }
434 }