From c108d6075cc29d783d2bab575e1bd8dcc8942aa7 Mon Sep 17 00:00:00 2001 From: paladz <453256728@qq.com> Date: Tue, 29 May 2018 18:20:20 +0800 Subject: [PATCH] edit the connection --- p2p/connection/channel.go | 161 +++++++++ p2p/connection/connection.go | 593 ++++++++----------------------- p2p/connection/connection_test.go | 19 +- p2p/connection/secret_connection.go | 269 ++++++-------- p2p/connection/secret_connection_test.go | 2 +- 5 files changed, 429 insertions(+), 615 deletions(-) create mode 100644 p2p/connection/channel.go diff --git a/p2p/connection/channel.go b/p2p/connection/channel.go new file mode 100644 index 00000000..4834e56a --- /dev/null +++ b/p2p/connection/channel.go @@ -0,0 +1,161 @@ +package connection + +import ( + "io" + "sync/atomic" + "time" + + wire "github.com/tendermint/go-wire" + cmn "github.com/tendermint/tmlibs/common" +) + +// ChannelDescriptor is the setting of channel +type ChannelDescriptor struct { + ID byte + Priority int + SendQueueCapacity int + RecvBufferCapacity int + RecvMessageCapacity int +} + +// FillDefaults set the channel config if empty +func (chDesc *ChannelDescriptor) FillDefaults() { + if chDesc.SendQueueCapacity == 0 { + chDesc.SendQueueCapacity = defaultSendQueueCapacity + } + if chDesc.RecvBufferCapacity == 0 { + chDesc.RecvBufferCapacity = defaultRecvBufferCapacity + } + if chDesc.RecvMessageCapacity == 0 { + chDesc.RecvMessageCapacity = defaultRecvMessageCapacity + } +} + +type channel struct { + conn *MConnection + desc *ChannelDescriptor + id byte + sendQueue chan []byte + sendQueueSize int32 // atomic. + recving []byte + sending []byte + priority int + recentlySent int64 // exponential moving average +} + +func newChannel(conn *MConnection, desc *ChannelDescriptor) *channel { + desc.FillDefaults() + if desc.Priority <= 0 { + cmn.PanicSanity("Channel default priority must be a postive integer") + } + return &channel{ + conn: conn, + desc: desc, + id: desc.ID, + sendQueue: make(chan []byte, desc.SendQueueCapacity), + recving: make([]byte, 0, desc.RecvBufferCapacity), + priority: desc.Priority, + } +} + +// Goroutine-safe +// Use only as a heuristic. +func (ch *channel) canSend() bool { + return ch.loadSendQueueSize() < defaultSendQueueCapacity +} + +// Returns true if any msgPackets are pending to be sent. +// Call before calling nextMsgPacket() +// Goroutine-safe +func (ch *channel) isSendPending() bool { + if len(ch.sending) == 0 { + if len(ch.sendQueue) == 0 { + return false + } + ch.sending = <-ch.sendQueue + } + return true +} + +// Goroutine-safe +func (ch *channel) loadSendQueueSize() (size int) { + return int(atomic.LoadInt32(&ch.sendQueueSize)) +} + +// Creates a new msgPacket to send. +// Not goroutine-safe +func (ch *channel) nextMsgPacket() msgPacket { + packet := msgPacket{ + ChannelID: byte(ch.id), + Bytes: ch.sending[:cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending))], + } + if len(ch.sending) <= maxMsgPacketPayloadSize { + packet.EOF = byte(0x01) + ch.sending = nil + atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize + } else { + packet.EOF = byte(0x00) + ch.sending = ch.sending[cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending)):] + } + return packet +} + +// Handles incoming msgPackets. Returns a msg bytes if msg is complete. +// Not goroutine-safe +func (ch *channel) recvMsgPacket(packet msgPacket) ([]byte, error) { + if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) { + return nil, wire.ErrBinaryReadOverflow + } + + ch.recving = append(ch.recving, packet.Bytes...) + if packet.EOF == byte(0x01) { + msgBytes := ch.recving + ch.recving = ch.recving[:0] // make([]byte, 0, ch.desc.RecvBufferCapacity) + return msgBytes, nil + } + return nil, nil +} + +// Queues message to send to this channel. +// Goroutine-safe +// Times out (and returns false) after defaultSendTimeout +func (ch *channel) sendBytes(bytes []byte) bool { + select { + case ch.sendQueue <- bytes: + atomic.AddInt32(&ch.sendQueueSize, 1) + return true + case <-time.After(defaultSendTimeout): + return false + } +} + +// Queues message to send to this channel. +// Nonblocking, returns true if successful. +// Goroutine-safe +func (ch *channel) trySendBytes(bytes []byte) bool { + select { + case ch.sendQueue <- bytes: + atomic.AddInt32(&ch.sendQueueSize, 1) + return true + default: + return false + } +} + +// Writes next msgPacket to w. +// Not goroutine-safe +func (ch *channel) writeMsgPacketTo(w io.Writer) (n int, err error) { + packet := ch.nextMsgPacket() + wire.WriteByte(packetTypeMsg, w, &n, &err) + wire.WriteBinary(packet, w, &n, &err) + if err == nil { + ch.recentlySent += int64(n) + } + return +} + +// Call this periodically to update stats for throttling purposes. +// Not goroutine-safe +func (ch *channel) updateStats() { + ch.recentlySent = int64(float64(ch.recentlySent) * 0.8) +} diff --git a/p2p/connection/connection.go b/p2p/connection/connection.go index 77595d43..09b28105 100644 --- a/p2p/connection/connection.go +++ b/p2p/connection/connection.go @@ -3,7 +3,6 @@ package connection import ( "bufio" "fmt" - "io" "math" "net" "runtime/debug" @@ -17,6 +16,13 @@ import ( ) const ( + packetTypePing = byte(0x01) + packetTypePong = byte(0x02) + packetTypeMsg = byte(0x03) + maxMsgPacketPayloadSize = 1024 + maxMsgPacketOverheadSize = 10 // It's actually lower but good enough + maxMsgPacketTotalSize = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize + numBatchMsgPackets = 10 minReadBufferSize = 1024 minWriteBufferSize = 65536 @@ -35,13 +41,19 @@ const ( type receiveCbFunc func(chID byte, msgBytes []byte) type errorCbFunc func(interface{}) -/* -Each peer has one `MConnection` (multiplex connection) instance. +// Messages in channels are chopped into smaller msgPackets for multiplexing. +type msgPacket struct { + ChannelID byte + EOF byte // 1 means message ends here. + Bytes []byte +} -__multiplex__ *noun* a system or signal involving simultaneous transmission of -several messages along a single channel of communication. +func (p msgPacket) String() string { + return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF) +} -Each `MConnection` handles message transmission on multiple abstract communication +/* +MConnection handles message transmission on multiple abstract communication `Channel`s. Each channel has a globally unique byte id. The byte id and the relative priorities of each `Channel` are configured upon initialization of the connection. @@ -70,8 +82,8 @@ type MConnection struct { recvMonitor *flow.Monitor send chan struct{} pong chan struct{} - channels []*Channel - channelsIdx map[byte]*Channel + channels []*channel + channelsIdx map[byte]*channel onReceive receiveCbFunc onError errorCbFunc errored uint32 @@ -97,16 +109,6 @@ func DefaultMConnConfig() *MConnConfig { } } -// NewMConnection wraps net.Conn and creates multiplex connection -func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc) *MConnection { - return NewMConnectionWithConfig( - conn, - chDescs, - onReceive, - onError, - DefaultMConnConfig()) -} - // NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection { mconn := &MConnection{ @@ -117,6 +119,8 @@ func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onRec recvMonitor: flow.New(0, 0), send: make(chan struct{}, 1), pong: make(chan struct{}), + channelsIdx: map[byte]*channel{}, + channels: []*channel{}, onReceive: onReceive, onError: onError, config: config, @@ -125,24 +129,17 @@ func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onRec chStatsTimer: time.NewTicker(updateState), } - // Create channels - var channelsIdx = map[byte]*Channel{} - var channels = []*Channel{} - for _, desc := range chDescs { descCopy := *desc // copy the desc else unsafe access across connections channel := newChannel(mconn, &descCopy) - channelsIdx[channel.id] = channel - channels = append(channels, channel) + mconn.channelsIdx[channel.id] = channel + mconn.channels = append(mconn.channels, channel) } - mconn.channels = channels - mconn.channelsIdx = channelsIdx - mconn.BaseService = *cmn.NewBaseService(nil, "MConnection", mconn) - return mconn } +// OnStart implements BaseService func (c *MConnection) OnStart() error { c.BaseService.OnStart() c.quit = make(chan struct{}) @@ -152,6 +149,7 @@ func (c *MConnection) OnStart() error { return nil } +// OnStop implements BaseService func (c *MConnection) OnStop() { c.BaseService.OnStop() c.flushTimer.Stop() @@ -159,236 +157,86 @@ func (c *MConnection) OnStop() { close(c.quit) } c.conn.Close() - // We can't close pong safely here because - // recvRoutine may write to it after we've stopped. - // Though it doesn't need to get closed at all, - // we close it @ recvRoutine. - // close(c.pong) -} - -func (c *MConnection) String() string { - return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr()) + // We can't close pong safely here because recvRoutine may write to it after we've + // stopped. Though it doesn't need to get closed at all, we close it @ recvRoutine. } -func (c *MConnection) flush() { - log.WithField("conn", c).Debug("Flush") - err := c.bufWriter.Flush() - if err != nil { - log.WithField("error", err).Error("MConnection flush failed") - } -} - -// Catch panics, usually caused by remote disconnects. -func (c *MConnection) _recover() { - if r := recover(); r != nil { - stack := debug.Stack() - err := cmn.StackError{r, stack} - c.stopForError(err) +// CanSend returns true if you can send more data onto the chID, false otherwise +func (c *MConnection) CanSend(chID byte) bool { + if !c.IsRunning() { + return false } -} -func (c *MConnection) stopForError(r interface{}) { - c.Stop() - if atomic.CompareAndSwapUint32(&c.errored, 0, 1) { - if c.onError != nil { - c.onError(r) - } + channel, ok := c.channelsIdx[chID] + if !ok { + return false } + return channel.canSend() } -// Queues a message to be sent to channel. +// Send will queues a message to be sent to channel(blocking). func (c *MConnection) Send(chID byte, msg interface{}) bool { if !c.IsRunning() { return false } - log.WithFields(log.Fields{ - "chID": chID, - "conn": c, - "msg": msg, - }).Debug("Send") - - // Send message to channel. channel, ok := c.channelsIdx[chID] if !ok { - log.WithField("chID", chID).Error(cmn.Fmt("Cannot send bytes, unknown channel")) + log.WithField("chID", chID).Error("cannot send bytes due to unknown channel") return false } - success := channel.sendBytes(wire.BinaryBytes(msg)) - if success { - // Wake up sendRoutine if necessary - select { - case c.send <- struct{}{}: - default: - } - } else { - log.WithFields(log.Fields{ - "chID": chID, - "conn": c, - "msg": msg, - }).Error("Send failed") + if !channel.sendBytes(wire.BinaryBytes(msg)) { + log.WithFields(log.Fields{"chID": chID, "conn": c, "msg": msg}).Error("MConnection send failed") + return false } - return success + + select { + case c.send <- struct{}{}: + default: + } + return true } -// Queues a message to be sent to channel. -// Nonblocking, returns true if successful. +// TrySend queues a message to be sent to channel(Nonblocking). func (c *MConnection) TrySend(chID byte, msg interface{}) bool { if !c.IsRunning() { return false } - log.WithFields(log.Fields{ - "chID": chID, - "conn": c, - "msg": msg, - }).Debug("TrySend") - - // Send message to channel. channel, ok := c.channelsIdx[chID] if !ok { - log.WithField("chID", chID).Error(cmn.Fmt("cannot send bytes, unknown channel")) + log.WithField("chID", chID).Error("cannot send bytes due to unknown channel") return false } ok = channel.trySendBytes(wire.BinaryBytes(msg)) if ok { - // Wake up sendRoutine if necessary select { case c.send <- struct{}{}: default: } } - return ok } -// CanSend returns true if you can send more data onto the chID, false -// otherwise. Use only as a heuristic. -func (c *MConnection) CanSend(chID byte) bool { - if !c.IsRunning() { - return false - } - - channel, ok := c.channelsIdx[chID] - if !ok { - log.WithField("chID", chID).Error(cmn.Fmt("Unknown channel")) - return false - } - return channel.canSend() -} - -// sendRoutine polls for packets to send from channels. -func (c *MConnection) sendRoutine() { - defer c._recover() - -FOR_LOOP: - for { - var n int - var err error - select { - case <-c.flushTimer.Ch: - // NOTE: flushTimer.Set() must be called every time - // something is written to .bufWriter. - c.flush() - case <-c.chStatsTimer.C: - for _, channel := range c.channels { - channel.updateStats() - } - case <-c.pingTimer.C: - log.Debug("Send Ping") - wire.WriteByte(packetTypePing, c.bufWriter, &n, &err) - c.sendMonitor.Update(int(n)) - c.flush() - case <-c.pong: - log.Debug("Send Pong") - wire.WriteByte(packetTypePong, c.bufWriter, &n, &err) - c.sendMonitor.Update(int(n)) - c.flush() - case <-c.quit: - break FOR_LOOP - case <-c.send: - // Send some msgPackets - eof := c.sendSomeMsgPackets() - if !eof { - // Keep sendRoutine awake. - select { - case c.send <- struct{}{}: - default: - } - } - } - - if !c.IsRunning() { - break FOR_LOOP - } - if err != nil { - log.WithFields(log.Fields{ - "conn": c, - "error": err, - }).Error("Connection failed @ sendRoutine") - c.stopForError(err) - break FOR_LOOP - } - } - - // Cleanup +func (c *MConnection) String() string { + return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr()) } -// Returns true if messages from channels were exhausted. -// Blocks in accordance to .sendMonitor throttling. -func (c *MConnection) sendSomeMsgPackets() bool { - // Block until .sendMonitor says we can write. - // Once we're ready we send more than we asked for, - // but amortized it should even out. - c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true) - - // Now send some msgPackets. - for i := 0; i < numBatchMsgPackets; i++ { - if c.sendMsgPacket() { - return true - } +func (c *MConnection) flush() { + if err := c.bufWriter.Flush(); err != nil { + log.WithField("error", err).Error("MConnection flush failed") } - return false } -// Returns true if messages from channels were exhausted. -func (c *MConnection) sendMsgPacket() bool { - // Choose a channel to create a msgPacket from. - // The chosen channel will be the one whose recentlySent/priority is the least. - var leastRatio float32 = math.MaxFloat32 - var leastChannel *Channel - for _, channel := range c.channels { - // If nothing to send, skip this channel - if !channel.isSendPending() { - continue - } - // Get ratio, and keep track of lowest ratio. - ratio := float32(channel.recentlySent) / float32(channel.priority) - if ratio < leastRatio { - leastRatio = ratio - leastChannel = channel - } - } - - // Nothing to send? - if leastChannel == nil { - return true - } else { - // c.Logger.Info("Found a msgPacket to send") - } - - // Make & send a msgPacket from this channel - n, err := leastChannel.writeMsgPacketTo(c.bufWriter) - if err != nil { - log.WithField("error", err).Error("Failed to write msgPacket") +// Catch panics, usually caused by remote disconnects. +func (c *MConnection) _recover() { + if r := recover(); r != nil { + stack := debug.Stack() + err := cmn.StackError{r, stack} c.stopForError(err) - return true } - c.sendMonitor.Update(int(n)) - c.flushTimer.Set() - return false } // recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer. @@ -396,27 +244,12 @@ func (c *MConnection) sendMsgPacket() bool { // Blocks depending on how the connection is throttled. func (c *MConnection) recvRoutine() { defer c._recover() + defer close(c.pong) -FOR_LOOP: for { // Block until .recvMonitor says we can read. c.recvMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.RecvRate), true) - /* - // Peek into bufReader for debugging - if numBytes := c.bufReader.Buffered(); numBytes > 0 { - log.Infof("Peek connection buffer numBytes:", numBytes) - bytes, err := c.bufReader.Peek(cmn.MinInt(numBytes, 100)) - if err == nil { - log.Infof("bytes:", bytes) - } else { - log.Warning("Error peeking connection buffer err:", err) - } - } else { - log.Warning("Received bytes number is:", numBytes) - } - */ - // Read packet type var n int var err error @@ -424,279 +257,149 @@ FOR_LOOP: c.recvMonitor.Update(int(n)) if err != nil { if c.IsRunning() { - log.WithFields(log.Fields{ - "conn": c, - "error": err, - }).Error("Connection failed @ recvRoutine (reading byte)") + log.WithFields(log.Fields{"conn": c, "error": err}).Error("Connection failed @ recvRoutine (reading byte)") c.conn.Close() c.stopForError(err) } - break FOR_LOOP + return } // Read more depending on packet type. switch pktType { case packetTypePing: - // TODO: prevent abuse, as they cause flush()'s. - log.Debug("Receive Ping") + log.Debug("receive Ping") c.pong <- struct{}{} + case packetTypePong: - // do nothing - log.Debug("Receive Pong") + log.Debug("receive Pong") + case packetTypeMsg: pkt, n, err := msgPacket{}, int(0), error(nil) wire.ReadBinaryPtr(&pkt, c.bufReader, maxMsgPacketTotalSize, &n, &err) c.recvMonitor.Update(int(n)) if err != nil { if c.IsRunning() { - log.WithFields(log.Fields{ - "conn": c, - "error": err, - }).Error("Connection failed @ recvRoutine") + log.WithFields(log.Fields{"conn": c, "error": err}).Error("failed on recvRoutine") c.stopForError(err) } - break FOR_LOOP + return } + channel, ok := c.channelsIdx[pkt.ChannelID] if !ok || channel == nil { cmn.PanicQ(cmn.Fmt("Unknown channel %X", pkt.ChannelID)) } + msgBytes, err := channel.recvMsgPacket(pkt) if err != nil { if c.IsRunning() { - log.WithFields(log.Fields{ - "conn": c, - "error": err, - }).Error("Connection failed @ recvRoutine") + log.WithFields(log.Fields{"conn": c, "error": err}).Error("failed on recvRoutine") c.stopForError(err) } - break FOR_LOOP + return } + if msgBytes != nil { - log.WithFields(log.Fields{ - "channelID": pkt.ChannelID, - "msgBytes": msgBytes, - }).Debug("Received bytes") c.onReceive(pkt.ChannelID, msgBytes) } + default: cmn.PanicSanity(cmn.Fmt("Unknown message type %X", pktType)) } } - - // Cleanup - close(c.pong) - for _ = range c.pong { - // Drain - } -} - -type ConnectionStatus struct { - SendMonitor flow.Status - RecvMonitor flow.Status - Channels []ChannelStatus -} - -type ChannelStatus struct { - ID byte - SendQueueCapacity int - SendQueueSize int - Priority int - RecentlySent int64 } -func (c *MConnection) Status() ConnectionStatus { - var status ConnectionStatus - status.SendMonitor = c.sendMonitor.Status() - status.RecvMonitor = c.recvMonitor.Status() - status.Channels = make([]ChannelStatus, len(c.channels)) - for i, channel := range c.channels { - status.Channels[i] = ChannelStatus{ - ID: channel.id, - SendQueueCapacity: cap(channel.sendQueue), - SendQueueSize: int(channel.sendQueueSize), // TODO use atomic - Priority: channel.priority, - RecentlySent: channel.recentlySent, +// Returns true if messages from channels were exhausted. +func (c *MConnection) sendMsgPacket() bool { + var leastRatio float32 = math.MaxFloat32 + var leastChannel *channel + for _, channel := range c.channels { + if !channel.isSendPending() { + continue + } + if ratio := float32(channel.recentlySent) / float32(channel.priority); ratio < leastRatio { + leastRatio = ratio + leastChannel = channel } } - return status -} - -//----------------------------------------------------------------------------- - -type ChannelDescriptor struct { - ID byte - Priority int - SendQueueCapacity int - RecvBufferCapacity int - RecvMessageCapacity int -} - -func (chDesc *ChannelDescriptor) FillDefaults() { - if chDesc.SendQueueCapacity == 0 { - chDesc.SendQueueCapacity = defaultSendQueueCapacity - } - if chDesc.RecvBufferCapacity == 0 { - chDesc.RecvBufferCapacity = defaultRecvBufferCapacity - } - if chDesc.RecvMessageCapacity == 0 { - chDesc.RecvMessageCapacity = defaultRecvMessageCapacity - } -} - -// TODO: lowercase. -// NOTE: not goroutine-safe. -type Channel struct { - conn *MConnection - desc *ChannelDescriptor - id byte - sendQueue chan []byte - sendQueueSize int32 // atomic. - recving []byte - sending []byte - priority int - recentlySent int64 // exponential moving average -} - -func newChannel(conn *MConnection, desc *ChannelDescriptor) *Channel { - desc.FillDefaults() - if desc.Priority <= 0 { - cmn.PanicSanity("Channel default priority must be a postive integer") - } - return &Channel{ - conn: conn, - desc: desc, - id: desc.ID, - sendQueue: make(chan []byte, desc.SendQueueCapacity), - recving: make([]byte, 0, desc.RecvBufferCapacity), - priority: desc.Priority, - } -} - -// Queues message to send to this channel. -// Goroutine-safe -// Times out (and returns false) after defaultSendTimeout -func (ch *Channel) sendBytes(bytes []byte) bool { - select { - case ch.sendQueue <- bytes: - atomic.AddInt32(&ch.sendQueueSize, 1) + if leastChannel == nil { return true - case <-time.After(defaultSendTimeout): - return false } -} -// Queues message to send to this channel. -// Nonblocking, returns true if successful. -// Goroutine-safe -func (ch *Channel) trySendBytes(bytes []byte) bool { - select { - case ch.sendQueue <- bytes: - atomic.AddInt32(&ch.sendQueueSize, 1) + n, err := leastChannel.writeMsgPacketTo(c.bufWriter) + if err != nil { + log.WithField("error", err).Error("failed to write msgPacket") + c.stopForError(err) return true - default: - return false } + c.sendMonitor.Update(int(n)) + c.flushTimer.Set() + return false } -// Goroutine-safe -func (ch *Channel) loadSendQueueSize() (size int) { - return int(atomic.LoadInt32(&ch.sendQueueSize)) -} - -// Goroutine-safe -// Use only as a heuristic. -func (ch *Channel) canSend() bool { - return ch.loadSendQueueSize() < defaultSendQueueCapacity -} +// sendRoutine polls for packets to send from channels. +func (c *MConnection) sendRoutine() { + defer c._recover() -// Returns true if any msgPackets are pending to be sent. -// Call before calling nextMsgPacket() -// Goroutine-safe -func (ch *Channel) isSendPending() bool { - if len(ch.sending) == 0 { - if len(ch.sendQueue) == 0 { - return false + for { + var n int + var err error + select { + case <-c.flushTimer.Ch: + c.flush() + case <-c.chStatsTimer.C: + for _, channel := range c.channels { + channel.updateStats() + } + case <-c.pingTimer.C: + log.Debug("send Ping") + wire.WriteByte(packetTypePing, c.bufWriter, &n, &err) + c.sendMonitor.Update(int(n)) + c.flush() + case <-c.pong: + log.Debug("send Pong") + wire.WriteByte(packetTypePong, c.bufWriter, &n, &err) + c.sendMonitor.Update(int(n)) + c.flush() + case <-c.quit: + return + case <-c.send: + if eof := c.sendSomeMsgPackets(); !eof { + select { + case c.send <- struct{}{}: + default: + } + } } - ch.sending = <-ch.sendQueue - } - return true -} -// Creates a new msgPacket to send. -// Not goroutine-safe -func (ch *Channel) nextMsgPacket() msgPacket { - packet := msgPacket{} - packet.ChannelID = byte(ch.id) - packet.Bytes = ch.sending[:cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending))] - if len(ch.sending) <= maxMsgPacketPayloadSize { - packet.EOF = byte(0x01) - ch.sending = nil - atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize - } else { - packet.EOF = byte(0x00) - ch.sending = ch.sending[cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending)):] + if !c.IsRunning() { + return + } + if err != nil { + log.WithFields(log.Fields{"conn": c, "error": err}).Error("Connection failed @ sendRoutine") + c.stopForError(err) + return + } } - return packet } -// Writes next msgPacket to w. -// Not goroutine-safe -func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int, err error) { - packet := ch.nextMsgPacket() - wire.WriteByte(packetTypeMsg, w, &n, &err) - wire.WriteBinary(packet, w, &n, &err) - if err == nil { - ch.recentlySent += int64(n) +// Returns true if messages from channels were exhausted. +func (c *MConnection) sendSomeMsgPackets() bool { + // Block until .sendMonitor says we can write. + // Once we're ready we send more than we asked for, + // but amortized it should even out. + c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true) + for i := 0; i < numBatchMsgPackets; i++ { + if c.sendMsgPacket() { + return true + } } - return + return false } -// Handles incoming msgPackets. Returns a msg bytes if msg is complete. -// Not goroutine-safe -func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) { - if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) { - return nil, wire.ErrBinaryReadOverflow - } - ch.recving = append(ch.recving, packet.Bytes...) - if packet.EOF == byte(0x01) { - msgBytes := ch.recving - // clear the slice without re-allocating. - // http://stackoverflow.com/questions/16971741/how-do-you-clear-a-slice-in-go - // suggests this could be a memory leak, but we might as well keep the memory for the channel until it closes, - // at which point the recving slice stops being used and should be garbage collected - ch.recving = ch.recving[:0] // make([]byte, 0, ch.desc.RecvBufferCapacity) - return msgBytes, nil +func (c *MConnection) stopForError(r interface{}) { + c.Stop() + if atomic.CompareAndSwapUint32(&c.errored, 0, 1) && c.onError != nil { + c.onError(r) } - return nil, nil -} - -// Call this periodically to update stats for throttling purposes. -// Not goroutine-safe -func (ch *Channel) updateStats() { - // Exponential decay of stats. - // TODO: optimize. - ch.recentlySent = int64(float64(ch.recentlySent) * 0.8) -} - -//----------------------------------------------------------------------------- - -const ( - maxMsgPacketPayloadSize = 1024 - maxMsgPacketOverheadSize = 10 // It's actually lower but good enough - maxMsgPacketTotalSize = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize - packetTypePing = byte(0x01) - packetTypePong = byte(0x02) - packetTypeMsg = byte(0x03) -) - -// Messages in channels are chopped into smaller msgPackets for multiplexing. -type msgPacket struct { - ChannelID byte - EOF byte // 1 means message ends here. - Bytes []byte -} - -func (p msgPacket) String() string { - return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF) } diff --git a/p2p/connection/connection_test.go b/p2p/connection/connection_test.go index 0123a3e5..1a221260 100644 --- a/p2p/connection/connection_test.go +++ b/p2p/connection/connection_test.go @@ -22,7 +22,7 @@ func createMConnection(conn net.Conn) *MConnection { func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *MConnection { chDescs := []*ChannelDescriptor{&ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} - c := NewMConnection(conn, chDescs, onReceive, onError) + c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, DefaultMConnConfig()) c.SetLogger(log.TestingLogger()) return c } @@ -92,23 +92,6 @@ func TestMConnectionReceive(t *testing.T) { } } -func TestMConnectionStatus(t *testing.T) { - assert, require := assert.New(t), require.New(t) - - server, client := net.Pipe() - defer server.Close() - defer client.Close() - - mconn := createMConnection(client) - _, err := mconn.Start() - require.Nil(err) - defer mconn.Stop() - - status := mconn.Status() - assert.NotNil(status) - assert.Zero(status.Channels[0].SendQueueSize) -} - func TestMConnectionStopsAndReturnsError(t *testing.T) { assert, require := assert.New(t), require.New(t) diff --git a/p2p/connection/secret_connection.go b/p2p/connection/secret_connection.go index 5bd8f9ca..63d135b6 100644 --- a/p2p/connection/secret_connection.go +++ b/p2p/connection/secret_connection.go @@ -1,9 +1,3 @@ -// Uses nacl's secret_box to encrypt a net.Conn. -// It is (meant to be) an implementation of the STS protocol. -// Note we do not (yet) assume that a remote peer's pubkey -// is known ahead of time, and thus we are technically -// still vulnerable to MITM. (TODO!) -// See docs/sts-final.pdf for more info package connection import ( @@ -21,18 +15,24 @@ import ( "golang.org/x/crypto/ripemd160" "github.com/tendermint/go-crypto" - "github.com/tendermint/go-wire" + wire "github.com/tendermint/go-wire" cmn "github.com/tendermint/tmlibs/common" ) -// 2 + 1024 == 1026 total frame size -const dataLenSize = 2 // uint16 to describe the length, is <= dataMaxSize -const dataMaxSize = 1024 -const totalFrameSize = dataMaxSize + dataLenSize -const sealedFrameSize = totalFrameSize + secretbox.Overhead -const authSigMsgSize = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays +const ( + dataLenSize = 2 // uint16 to describe the length, is <= dataMaxSize + dataMaxSize = 1024 + totalFrameSize = dataMaxSize + dataLenSize + sealedFrameSize = totalFrameSize + secretbox.Overhead + authSigMsgSize = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays +) + +type authSigMessage struct { + Key crypto.PubKey + Sig crypto.Signature +} -// Implements net.Conn +// SecretConnection implements net.Conn type SecretConnection struct { conn io.ReadWriteCloser recvBuffer []byte @@ -42,12 +42,8 @@ type SecretConnection struct { shrSecret *[32]byte // shared secret } -// Performs handshake and returns a new authenticated SecretConnection. -// Returns nil if error in handshake. -// Caller should call conn.Close() -// See docs/sts-final.pdf for more information. +// MakeSecretConnection performs handshake and returns a new authenticated SecretConnection. func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKeyEd25519) (*SecretConnection, error) { - locPubKey := locPrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519) // Generate ephemeral keys for perfect forward secrecy. @@ -95,12 +91,42 @@ func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKeyEd25 return nil, errors.New("Challenge verification failed") } - // We've authorized. sc.remPubKey = remPubKey.Unwrap().(crypto.PubKeyEd25519) return sc, nil } -// Returns authenticated remote pubkey +// CONTRACT: data smaller than dataMaxSize is read atomically. +func (sc *SecretConnection) Read(data []byte) (n int, err error) { + if 0 < len(sc.recvBuffer) { + n_ := copy(data, sc.recvBuffer) + sc.recvBuffer = sc.recvBuffer[n_:] + return + } + + sealedFrame := make([]byte, sealedFrameSize) + if _, err = io.ReadFull(sc.conn, sealedFrame); err != nil { + return + } + + // decrypt the frame + frame := make([]byte, totalFrameSize) + if _, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret); !ok { + return n, errors.New("Failed to decrypt SecretConnection") + } + + incr2Nonce(sc.recvNonce) + chunkLength := binary.BigEndian.Uint16(frame) // read the first two bytes + if chunkLength > dataMaxSize { + return 0, errors.New("chunkLength is greater than dataMaxSize") + } + + chunk := frame[dataLenSize : dataLenSize+chunkLength] + n = copy(data, chunk) + sc.recvBuffer = chunk[n:] + return +} + +// RemotePubKey returns authenticated remote pubkey func (sc *SecretConnection) RemotePubKey() crypto.PubKeyEd25519 { return sc.remPubKey } @@ -109,8 +135,8 @@ func (sc *SecretConnection) RemotePubKey() crypto.PubKeyEd25519 { // CONTRACT: data smaller than dataMaxSize is read atomically. func (sc *SecretConnection) Write(data []byte) (n int, err error) { for 0 < len(data) { - var frame []byte = make([]byte, totalFrameSize) var chunk []byte + frame := make([]byte, totalFrameSize) if dataMaxSize < len(data) { chunk = data[:dataMaxSize] data = data[dataMaxSize:] @@ -118,140 +144,89 @@ func (sc *SecretConnection) Write(data []byte) (n int, err error) { chunk = data data = nil } - chunkLength := len(chunk) - binary.BigEndian.PutUint16(frame, uint16(chunkLength)) + binary.BigEndian.PutUint16(frame, uint16(len(chunk))) copy(frame[dataLenSize:], chunk) // encrypt the frame - var sealedFrame = make([]byte, sealedFrameSize) + sealedFrame := make([]byte, sealedFrameSize) secretbox.Seal(sealedFrame[:0], frame, sc.sendNonce, sc.shrSecret) - // fmt.Printf("secretbox.Seal(sealed:%X,sendNonce:%X,shrSecret:%X\n", sealedFrame, sc.sendNonce, sc.shrSecret) incr2Nonce(sc.sendNonce) - // end encryption - _, err := sc.conn.Write(sealedFrame) - if err != nil { + if _, err := sc.conn.Write(sealedFrame); err != nil { return n, err - } else { - n += len(chunk) } + + n += len(chunk) } return } -// CONTRACT: data smaller than dataMaxSize is read atomically. -func (sc *SecretConnection) Read(data []byte) (n int, err error) { - if 0 < len(sc.recvBuffer) { - n_ := copy(data, sc.recvBuffer) - sc.recvBuffer = sc.recvBuffer[n_:] - return - } - - sealedFrame := make([]byte, sealedFrameSize) - _, err = io.ReadFull(sc.conn, sealedFrame) - if err != nil { - return - } +// Close implements net.Conn +func (sc *SecretConnection) Close() error { return sc.conn.Close() } - // decrypt the frame - var frame = make([]byte, totalFrameSize) - // fmt.Printf("secretbox.Open(sealed:%X,recvNonce:%X,shrSecret:%X\n", sealedFrame, sc.recvNonce, sc.shrSecret) - _, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret) - if !ok { - return n, errors.New("Failed to decrypt SecretConnection") - } - incr2Nonce(sc.recvNonce) - // end decryption - - var chunkLength = binary.BigEndian.Uint16(frame) // read the first two bytes - if chunkLength > dataMaxSize { - return 0, errors.New("chunkLength is greater than dataMaxSize") - } - var chunk = frame[dataLenSize : dataLenSize+chunkLength] +// LocalAddr implements net.Conn +func (sc *SecretConnection) LocalAddr() net.Addr { return sc.conn.(net.Conn).LocalAddr() } - n = copy(data, chunk) - sc.recvBuffer = chunk[n:] - return -} +// RemoteAddr implements net.Conn +func (sc *SecretConnection) RemoteAddr() net.Addr { return sc.conn.(net.Conn).RemoteAddr() } -// Implements net.Conn -func (sc *SecretConnection) Close() error { return sc.conn.Close() } -func (sc *SecretConnection) LocalAddr() net.Addr { return sc.conn.(net.Conn).LocalAddr() } -func (sc *SecretConnection) RemoteAddr() net.Addr { return sc.conn.(net.Conn).RemoteAddr() } +// SetDeadline implements net.Conn func (sc *SecretConnection) SetDeadline(t time.Time) error { return sc.conn.(net.Conn).SetDeadline(t) } + +// SetReadDeadline implements net.Conn func (sc *SecretConnection) SetReadDeadline(t time.Time) error { return sc.conn.(net.Conn).SetReadDeadline(t) } + +// SetWriteDeadline implements net.Conn func (sc *SecretConnection) SetWriteDeadline(t time.Time) error { return sc.conn.(net.Conn).SetWriteDeadline(t) } -func genEphKeys() (ephPub, ephPriv *[32]byte) { - var err error - ephPub, ephPriv, err = box.GenerateKey(crand.Reader) - if err != nil { - cmn.PanicCrisis("Could not generate ephemeral keypairs") - } +func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) { + shrSecret = new([32]byte) + box.Precompute(shrSecret, remPubKey, locPrivKey) return } -func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) { - var err1, err2 error - - cmn.Parallel( - func() { - _, err1 = conn.Write(locEphPub[:]) - }, - func() { - remEphPub = new([32]byte) - _, err2 = io.ReadFull(conn, remEphPub[:]) - }, - ) - - if err1 != nil { - return nil, err1 - } - if err2 != nil { - return nil, err2 - } +func genChallenge(loPubKey, hiPubKey *[32]byte) (challenge *[32]byte) { + return hash32(append(loPubKey[:], hiPubKey[:]...)) +} - return remEphPub, nil +// increment nonce big-endian by 2 with wraparound. +func incr2Nonce(nonce *[24]byte) { + incrNonce(nonce) + incrNonce(nonce) } -func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) { - shrSecret = new([32]byte) - box.Precompute(shrSecret, remPubKey, locPrivKey) - return +// increment nonce big-endian by 1 with wraparound. +func incrNonce(nonce *[24]byte) { + for i := 23; 0 <= i; i-- { + nonce[i]++ + if nonce[i] != 0 { + return + } + } } -func sort32(foo, bar *[32]byte) (lo, hi *[32]byte) { - if bytes.Compare(foo[:], bar[:]) < 0 { - lo = foo - hi = bar - } else { - lo = bar - hi = foo +func genEphKeys() (ephPub, ephPriv *[32]byte) { + var err error + ephPub, ephPriv, err = box.GenerateKey(crand.Reader) + if err != nil { + cmn.PanicCrisis("Could not generate ephemeral keypairs") } return } -func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (recvNonce, sendNonce *[24]byte) { +func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (*[24]byte, *[24]byte) { nonce1 := hash24(append(loPubKey[:], hiPubKey[:]...)) nonce2 := new([24]byte) copy(nonce2[:], nonce1[:]) nonce2[len(nonce2)-1] ^= 0x01 if locIsLo { - recvNonce = nonce1 - sendNonce = nonce2 - } else { - recvNonce = nonce2 - sendNonce = nonce1 + return nonce1, nonce2 } - return -} - -func genChallenge(loPubKey, hiPubKey *[32]byte) (challenge *[32]byte) { - return hash32(append(loPubKey[:], hiPubKey[:]...)) + return nonce2, nonce1 } func signChallenge(challenge *[32]byte, locPrivKey crypto.PrivKeyEd25519) (signature crypto.SignatureEd25519) { @@ -259,11 +234,6 @@ func signChallenge(challenge *[32]byte, locPrivKey crypto.PrivKeyEd25519) (signa return } -type authSigMessage struct { - Key crypto.PubKey - Sig crypto.Signature -} - func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signature crypto.SignatureEd25519) (*authSigMessage, error) { var recvMsg authSigMessage var err1, err2 error @@ -281,7 +251,8 @@ func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signa } n := int(0) // not used. recvMsg = wire.ReadBinary(authSigMessage{}, bytes.NewBuffer(readBuffer), authSigMsgSize, &n, &err2).(authSigMessage) - }) + }, + ) if err1 != nil { return nil, err1 @@ -289,15 +260,37 @@ func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signa if err2 != nil { return nil, err2 } - return &recvMsg, nil } -func verifyChallengeSignature(challenge *[32]byte, remPubKey crypto.PubKeyEd25519, remSignature crypto.SignatureEd25519) bool { - return remPubKey.VerifyBytes(challenge[:], remSignature.Wrap()) +func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) { + var err1, err2 error + + cmn.Parallel( + func() { + _, err1 = conn.Write(locEphPub[:]) + }, + func() { + remEphPub = new([32]byte) + _, err2 = io.ReadFull(conn, remEphPub[:]) + }, + ) + + if err1 != nil { + return nil, err1 + } + if err2 != nil { + return nil, err2 + } + return remEphPub, nil } -//-------------------------------------------------------------------------------- +func sort32(foo, bar *[32]byte) (*[32]byte, *[32]byte) { + if bytes.Compare(foo[:], bar[:]) < 0 { + return foo, bar + } + return bar, foo +} // sha256 func hash32(input []byte) (res *[32]byte) { @@ -318,29 +311,3 @@ func hash24(input []byte) (res *[24]byte) { copy(res[:], resSlice) return } - -// ripemd160 -func hash20(input []byte) (res *[20]byte) { - hasher := ripemd160.New() - hasher.Write(input) // does not error - resSlice := hasher.Sum(nil) - res = new([20]byte) - copy(res[:], resSlice) - return -} - -// increment nonce big-endian by 2 with wraparound. -func incr2Nonce(nonce *[24]byte) { - incrNonce(nonce) - incrNonce(nonce) -} - -// increment nonce big-endian by 1 with wraparound. -func incrNonce(nonce *[24]byte) { - for i := 23; 0 <= i; i-- { - nonce[i] += 1 - if nonce[i] != 0 { - return - } - } -} diff --git a/p2p/connection/secret_connection_test.go b/p2p/connection/secret_connection_test.go index fbfdf922..4785bbc6 100644 --- a/p2p/connection/secret_connection_test.go +++ b/p2p/connection/secret_connection_test.go @@ -145,7 +145,7 @@ func TestSecretConnectionReadWrite(t *testing.T) { var readCount = 0 for _, readChunk := range reads { read += readChunk - readCount += 1 + readCount++ if len(write) <= len(read) { break } -- 2.11.0