--- /dev/null
+package connection
+
+import (
+ "io"
+ "sync/atomic"
+ "time"
+
+ wire "github.com/tendermint/go-wire"
+ cmn "github.com/tendermint/tmlibs/common"
+)
+
+// ChannelDescriptor is the setting of channel
+type ChannelDescriptor struct {
+ ID byte
+ Priority int
+ SendQueueCapacity int
+ RecvBufferCapacity int
+ RecvMessageCapacity int
+}
+
+// FillDefaults set the channel config if empty
+func (chDesc *ChannelDescriptor) FillDefaults() {
+ if chDesc.SendQueueCapacity == 0 {
+ chDesc.SendQueueCapacity = defaultSendQueueCapacity
+ }
+ if chDesc.RecvBufferCapacity == 0 {
+ chDesc.RecvBufferCapacity = defaultRecvBufferCapacity
+ }
+ if chDesc.RecvMessageCapacity == 0 {
+ chDesc.RecvMessageCapacity = defaultRecvMessageCapacity
+ }
+}
+
+type channel struct {
+ conn *MConnection
+ desc *ChannelDescriptor
+ id byte
+ sendQueue chan []byte
+ sendQueueSize int32 // atomic.
+ recving []byte
+ sending []byte
+ priority int
+ recentlySent int64 // exponential moving average
+}
+
+func newChannel(conn *MConnection, desc *ChannelDescriptor) *channel {
+ desc.FillDefaults()
+ if desc.Priority <= 0 {
+ cmn.PanicSanity("Channel default priority must be a postive integer")
+ }
+ return &channel{
+ conn: conn,
+ desc: desc,
+ id: desc.ID,
+ sendQueue: make(chan []byte, desc.SendQueueCapacity),
+ recving: make([]byte, 0, desc.RecvBufferCapacity),
+ priority: desc.Priority,
+ }
+}
+
+// Goroutine-safe
+// Use only as a heuristic.
+func (ch *channel) canSend() bool {
+ return ch.loadSendQueueSize() < defaultSendQueueCapacity
+}
+
+// Returns true if any msgPackets are pending to be sent.
+// Call before calling nextMsgPacket()
+// Goroutine-safe
+func (ch *channel) isSendPending() bool {
+ if len(ch.sending) == 0 {
+ if len(ch.sendQueue) == 0 {
+ return false
+ }
+ ch.sending = <-ch.sendQueue
+ }
+ return true
+}
+
+// Goroutine-safe
+func (ch *channel) loadSendQueueSize() (size int) {
+ return int(atomic.LoadInt32(&ch.sendQueueSize))
+}
+
+// Creates a new msgPacket to send.
+// Not goroutine-safe
+func (ch *channel) nextMsgPacket() msgPacket {
+ packet := msgPacket{
+ ChannelID: byte(ch.id),
+ Bytes: ch.sending[:cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending))],
+ }
+ if len(ch.sending) <= maxMsgPacketPayloadSize {
+ packet.EOF = byte(0x01)
+ ch.sending = nil
+ atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize
+ } else {
+ packet.EOF = byte(0x00)
+ ch.sending = ch.sending[cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending)):]
+ }
+ return packet
+}
+
+// Handles incoming msgPackets. Returns a msg bytes if msg is complete.
+// Not goroutine-safe
+func (ch *channel) recvMsgPacket(packet msgPacket) ([]byte, error) {
+ if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) {
+ return nil, wire.ErrBinaryReadOverflow
+ }
+
+ ch.recving = append(ch.recving, packet.Bytes...)
+ if packet.EOF == byte(0x01) {
+ msgBytes := ch.recving
+ ch.recving = ch.recving[:0] // make([]byte, 0, ch.desc.RecvBufferCapacity)
+ return msgBytes, nil
+ }
+ return nil, nil
+}
+
+// Queues message to send to this channel.
+// Goroutine-safe
+// Times out (and returns false) after defaultSendTimeout
+func (ch *channel) sendBytes(bytes []byte) bool {
+ select {
+ case ch.sendQueue <- bytes:
+ atomic.AddInt32(&ch.sendQueueSize, 1)
+ return true
+ case <-time.After(defaultSendTimeout):
+ return false
+ }
+}
+
+// Queues message to send to this channel.
+// Nonblocking, returns true if successful.
+// Goroutine-safe
+func (ch *channel) trySendBytes(bytes []byte) bool {
+ select {
+ case ch.sendQueue <- bytes:
+ atomic.AddInt32(&ch.sendQueueSize, 1)
+ return true
+ default:
+ return false
+ }
+}
+
+// Writes next msgPacket to w.
+// Not goroutine-safe
+func (ch *channel) writeMsgPacketTo(w io.Writer) (n int, err error) {
+ packet := ch.nextMsgPacket()
+ wire.WriteByte(packetTypeMsg, w, &n, &err)
+ wire.WriteBinary(packet, w, &n, &err)
+ if err == nil {
+ ch.recentlySent += int64(n)
+ }
+ return
+}
+
+// Call this periodically to update stats for throttling purposes.
+// Not goroutine-safe
+func (ch *channel) updateStats() {
+ ch.recentlySent = int64(float64(ch.recentlySent) * 0.8)
+}
import (
"bufio"
"fmt"
- "io"
"math"
"net"
"runtime/debug"
)
const (
+ packetTypePing = byte(0x01)
+ packetTypePong = byte(0x02)
+ packetTypeMsg = byte(0x03)
+ maxMsgPacketPayloadSize = 1024
+ maxMsgPacketOverheadSize = 10 // It's actually lower but good enough
+ maxMsgPacketTotalSize = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize
+
numBatchMsgPackets = 10
minReadBufferSize = 1024
minWriteBufferSize = 65536
type receiveCbFunc func(chID byte, msgBytes []byte)
type errorCbFunc func(interface{})
-/*
-Each peer has one `MConnection` (multiplex connection) instance.
+// Messages in channels are chopped into smaller msgPackets for multiplexing.
+type msgPacket struct {
+ ChannelID byte
+ EOF byte // 1 means message ends here.
+ Bytes []byte
+}
-__multiplex__ *noun* a system or signal involving simultaneous transmission of
-several messages along a single channel of communication.
+func (p msgPacket) String() string {
+ return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF)
+}
-Each `MConnection` handles message transmission on multiple abstract communication
+/*
+MConnection handles message transmission on multiple abstract communication
`Channel`s. Each channel has a globally unique byte id.
The byte id and the relative priorities of each `Channel` are configured upon
initialization of the connection.
recvMonitor *flow.Monitor
send chan struct{}
pong chan struct{}
- channels []*Channel
- channelsIdx map[byte]*Channel
+ channels []*channel
+ channelsIdx map[byte]*channel
onReceive receiveCbFunc
onError errorCbFunc
errored uint32
}
}
-// NewMConnection wraps net.Conn and creates multiplex connection
-func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc) *MConnection {
- return NewMConnectionWithConfig(
- conn,
- chDescs,
- onReceive,
- onError,
- DefaultMConnConfig())
-}
-
// NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config
func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection {
mconn := &MConnection{
recvMonitor: flow.New(0, 0),
send: make(chan struct{}, 1),
pong: make(chan struct{}),
+ channelsIdx: map[byte]*channel{},
+ channels: []*channel{},
onReceive: onReceive,
onError: onError,
config: config,
chStatsTimer: time.NewTicker(updateState),
}
- // Create channels
- var channelsIdx = map[byte]*Channel{}
- var channels = []*Channel{}
-
for _, desc := range chDescs {
descCopy := *desc // copy the desc else unsafe access across connections
channel := newChannel(mconn, &descCopy)
- channelsIdx[channel.id] = channel
- channels = append(channels, channel)
+ mconn.channelsIdx[channel.id] = channel
+ mconn.channels = append(mconn.channels, channel)
}
- mconn.channels = channels
- mconn.channelsIdx = channelsIdx
-
mconn.BaseService = *cmn.NewBaseService(nil, "MConnection", mconn)
-
return mconn
}
+// OnStart implements BaseService
func (c *MConnection) OnStart() error {
c.BaseService.OnStart()
c.quit = make(chan struct{})
return nil
}
+// OnStop implements BaseService
func (c *MConnection) OnStop() {
c.BaseService.OnStop()
c.flushTimer.Stop()
close(c.quit)
}
c.conn.Close()
- // We can't close pong safely here because
- // recvRoutine may write to it after we've stopped.
- // Though it doesn't need to get closed at all,
- // we close it @ recvRoutine.
- // close(c.pong)
-}
-
-func (c *MConnection) String() string {
- return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr())
+ // We can't close pong safely here because recvRoutine may write to it after we've
+ // stopped. Though it doesn't need to get closed at all, we close it @ recvRoutine.
}
-func (c *MConnection) flush() {
- log.WithField("conn", c).Debug("Flush")
- err := c.bufWriter.Flush()
- if err != nil {
- log.WithField("error", err).Error("MConnection flush failed")
- }
-}
-
-// Catch panics, usually caused by remote disconnects.
-func (c *MConnection) _recover() {
- if r := recover(); r != nil {
- stack := debug.Stack()
- err := cmn.StackError{r, stack}
- c.stopForError(err)
+// CanSend returns true if you can send more data onto the chID, false otherwise
+func (c *MConnection) CanSend(chID byte) bool {
+ if !c.IsRunning() {
+ return false
}
-}
-func (c *MConnection) stopForError(r interface{}) {
- c.Stop()
- if atomic.CompareAndSwapUint32(&c.errored, 0, 1) {
- if c.onError != nil {
- c.onError(r)
- }
+ channel, ok := c.channelsIdx[chID]
+ if !ok {
+ return false
}
+ return channel.canSend()
}
-// Queues a message to be sent to channel.
+// Send will queues a message to be sent to channel(blocking).
func (c *MConnection) Send(chID byte, msg interface{}) bool {
if !c.IsRunning() {
return false
}
- log.WithFields(log.Fields{
- "chID": chID,
- "conn": c,
- "msg": msg,
- }).Debug("Send")
-
- // Send message to channel.
channel, ok := c.channelsIdx[chID]
if !ok {
- log.WithField("chID", chID).Error(cmn.Fmt("Cannot send bytes, unknown channel"))
+ log.WithField("chID", chID).Error("cannot send bytes due to unknown channel")
return false
}
- success := channel.sendBytes(wire.BinaryBytes(msg))
- if success {
- // Wake up sendRoutine if necessary
- select {
- case c.send <- struct{}{}:
- default:
- }
- } else {
- log.WithFields(log.Fields{
- "chID": chID,
- "conn": c,
- "msg": msg,
- }).Error("Send failed")
+ if !channel.sendBytes(wire.BinaryBytes(msg)) {
+ log.WithFields(log.Fields{"chID": chID, "conn": c, "msg": msg}).Error("MConnection send failed")
+ return false
}
- return success
+
+ select {
+ case c.send <- struct{}{}:
+ default:
+ }
+ return true
}
-// Queues a message to be sent to channel.
-// Nonblocking, returns true if successful.
+// TrySend queues a message to be sent to channel(Nonblocking).
func (c *MConnection) TrySend(chID byte, msg interface{}) bool {
if !c.IsRunning() {
return false
}
- log.WithFields(log.Fields{
- "chID": chID,
- "conn": c,
- "msg": msg,
- }).Debug("TrySend")
-
- // Send message to channel.
channel, ok := c.channelsIdx[chID]
if !ok {
- log.WithField("chID", chID).Error(cmn.Fmt("cannot send bytes, unknown channel"))
+ log.WithField("chID", chID).Error("cannot send bytes due to unknown channel")
return false
}
ok = channel.trySendBytes(wire.BinaryBytes(msg))
if ok {
- // Wake up sendRoutine if necessary
select {
case c.send <- struct{}{}:
default:
}
}
-
return ok
}
-// CanSend returns true if you can send more data onto the chID, false
-// otherwise. Use only as a heuristic.
-func (c *MConnection) CanSend(chID byte) bool {
- if !c.IsRunning() {
- return false
- }
-
- channel, ok := c.channelsIdx[chID]
- if !ok {
- log.WithField("chID", chID).Error(cmn.Fmt("Unknown channel"))
- return false
- }
- return channel.canSend()
-}
-
-// sendRoutine polls for packets to send from channels.
-func (c *MConnection) sendRoutine() {
- defer c._recover()
-
-FOR_LOOP:
- for {
- var n int
- var err error
- select {
- case <-c.flushTimer.Ch:
- // NOTE: flushTimer.Set() must be called every time
- // something is written to .bufWriter.
- c.flush()
- case <-c.chStatsTimer.C:
- for _, channel := range c.channels {
- channel.updateStats()
- }
- case <-c.pingTimer.C:
- log.Debug("Send Ping")
- wire.WriteByte(packetTypePing, c.bufWriter, &n, &err)
- c.sendMonitor.Update(int(n))
- c.flush()
- case <-c.pong:
- log.Debug("Send Pong")
- wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
- c.sendMonitor.Update(int(n))
- c.flush()
- case <-c.quit:
- break FOR_LOOP
- case <-c.send:
- // Send some msgPackets
- eof := c.sendSomeMsgPackets()
- if !eof {
- // Keep sendRoutine awake.
- select {
- case c.send <- struct{}{}:
- default:
- }
- }
- }
-
- if !c.IsRunning() {
- break FOR_LOOP
- }
- if err != nil {
- log.WithFields(log.Fields{
- "conn": c,
- "error": err,
- }).Error("Connection failed @ sendRoutine")
- c.stopForError(err)
- break FOR_LOOP
- }
- }
-
- // Cleanup
+func (c *MConnection) String() string {
+ return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr())
}
-// Returns true if messages from channels were exhausted.
-// Blocks in accordance to .sendMonitor throttling.
-func (c *MConnection) sendSomeMsgPackets() bool {
- // Block until .sendMonitor says we can write.
- // Once we're ready we send more than we asked for,
- // but amortized it should even out.
- c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true)
-
- // Now send some msgPackets.
- for i := 0; i < numBatchMsgPackets; i++ {
- if c.sendMsgPacket() {
- return true
- }
+func (c *MConnection) flush() {
+ if err := c.bufWriter.Flush(); err != nil {
+ log.WithField("error", err).Error("MConnection flush failed")
}
- return false
}
-// Returns true if messages from channels were exhausted.
-func (c *MConnection) sendMsgPacket() bool {
- // Choose a channel to create a msgPacket from.
- // The chosen channel will be the one whose recentlySent/priority is the least.
- var leastRatio float32 = math.MaxFloat32
- var leastChannel *Channel
- for _, channel := range c.channels {
- // If nothing to send, skip this channel
- if !channel.isSendPending() {
- continue
- }
- // Get ratio, and keep track of lowest ratio.
- ratio := float32(channel.recentlySent) / float32(channel.priority)
- if ratio < leastRatio {
- leastRatio = ratio
- leastChannel = channel
- }
- }
-
- // Nothing to send?
- if leastChannel == nil {
- return true
- } else {
- // c.Logger.Info("Found a msgPacket to send")
- }
-
- // Make & send a msgPacket from this channel
- n, err := leastChannel.writeMsgPacketTo(c.bufWriter)
- if err != nil {
- log.WithField("error", err).Error("Failed to write msgPacket")
+// Catch panics, usually caused by remote disconnects.
+func (c *MConnection) _recover() {
+ if r := recover(); r != nil {
+ stack := debug.Stack()
+ err := cmn.StackError{r, stack}
c.stopForError(err)
- return true
}
- c.sendMonitor.Update(int(n))
- c.flushTimer.Set()
- return false
}
// recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer.
// Blocks depending on how the connection is throttled.
func (c *MConnection) recvRoutine() {
defer c._recover()
+ defer close(c.pong)
-FOR_LOOP:
for {
// Block until .recvMonitor says we can read.
c.recvMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.RecvRate), true)
- /*
- // Peek into bufReader for debugging
- if numBytes := c.bufReader.Buffered(); numBytes > 0 {
- log.Infof("Peek connection buffer numBytes:", numBytes)
- bytes, err := c.bufReader.Peek(cmn.MinInt(numBytes, 100))
- if err == nil {
- log.Infof("bytes:", bytes)
- } else {
- log.Warning("Error peeking connection buffer err:", err)
- }
- } else {
- log.Warning("Received bytes number is:", numBytes)
- }
- */
-
// Read packet type
var n int
var err error
c.recvMonitor.Update(int(n))
if err != nil {
if c.IsRunning() {
- log.WithFields(log.Fields{
- "conn": c,
- "error": err,
- }).Error("Connection failed @ recvRoutine (reading byte)")
+ log.WithFields(log.Fields{"conn": c, "error": err}).Error("Connection failed @ recvRoutine (reading byte)")
c.conn.Close()
c.stopForError(err)
}
- break FOR_LOOP
+ return
}
// Read more depending on packet type.
switch pktType {
case packetTypePing:
- // TODO: prevent abuse, as they cause flush()'s.
- log.Debug("Receive Ping")
+ log.Debug("receive Ping")
c.pong <- struct{}{}
+
case packetTypePong:
- // do nothing
- log.Debug("Receive Pong")
+ log.Debug("receive Pong")
+
case packetTypeMsg:
pkt, n, err := msgPacket{}, int(0), error(nil)
wire.ReadBinaryPtr(&pkt, c.bufReader, maxMsgPacketTotalSize, &n, &err)
c.recvMonitor.Update(int(n))
if err != nil {
if c.IsRunning() {
- log.WithFields(log.Fields{
- "conn": c,
- "error": err,
- }).Error("Connection failed @ recvRoutine")
+ log.WithFields(log.Fields{"conn": c, "error": err}).Error("failed on recvRoutine")
c.stopForError(err)
}
- break FOR_LOOP
+ return
}
+
channel, ok := c.channelsIdx[pkt.ChannelID]
if !ok || channel == nil {
cmn.PanicQ(cmn.Fmt("Unknown channel %X", pkt.ChannelID))
}
+
msgBytes, err := channel.recvMsgPacket(pkt)
if err != nil {
if c.IsRunning() {
- log.WithFields(log.Fields{
- "conn": c,
- "error": err,
- }).Error("Connection failed @ recvRoutine")
+ log.WithFields(log.Fields{"conn": c, "error": err}).Error("failed on recvRoutine")
c.stopForError(err)
}
- break FOR_LOOP
+ return
}
+
if msgBytes != nil {
- log.WithFields(log.Fields{
- "channelID": pkt.ChannelID,
- "msgBytes": msgBytes,
- }).Debug("Received bytes")
c.onReceive(pkt.ChannelID, msgBytes)
}
+
default:
cmn.PanicSanity(cmn.Fmt("Unknown message type %X", pktType))
}
}
-
- // Cleanup
- close(c.pong)
- for _ = range c.pong {
- // Drain
- }
-}
-
-type ConnectionStatus struct {
- SendMonitor flow.Status
- RecvMonitor flow.Status
- Channels []ChannelStatus
-}
-
-type ChannelStatus struct {
- ID byte
- SendQueueCapacity int
- SendQueueSize int
- Priority int
- RecentlySent int64
}
-func (c *MConnection) Status() ConnectionStatus {
- var status ConnectionStatus
- status.SendMonitor = c.sendMonitor.Status()
- status.RecvMonitor = c.recvMonitor.Status()
- status.Channels = make([]ChannelStatus, len(c.channels))
- for i, channel := range c.channels {
- status.Channels[i] = ChannelStatus{
- ID: channel.id,
- SendQueueCapacity: cap(channel.sendQueue),
- SendQueueSize: int(channel.sendQueueSize), // TODO use atomic
- Priority: channel.priority,
- RecentlySent: channel.recentlySent,
+// Returns true if messages from channels were exhausted.
+func (c *MConnection) sendMsgPacket() bool {
+ var leastRatio float32 = math.MaxFloat32
+ var leastChannel *channel
+ for _, channel := range c.channels {
+ if !channel.isSendPending() {
+ continue
+ }
+ if ratio := float32(channel.recentlySent) / float32(channel.priority); ratio < leastRatio {
+ leastRatio = ratio
+ leastChannel = channel
}
}
- return status
-}
-
-//-----------------------------------------------------------------------------
-
-type ChannelDescriptor struct {
- ID byte
- Priority int
- SendQueueCapacity int
- RecvBufferCapacity int
- RecvMessageCapacity int
-}
-
-func (chDesc *ChannelDescriptor) FillDefaults() {
- if chDesc.SendQueueCapacity == 0 {
- chDesc.SendQueueCapacity = defaultSendQueueCapacity
- }
- if chDesc.RecvBufferCapacity == 0 {
- chDesc.RecvBufferCapacity = defaultRecvBufferCapacity
- }
- if chDesc.RecvMessageCapacity == 0 {
- chDesc.RecvMessageCapacity = defaultRecvMessageCapacity
- }
-}
-
-// TODO: lowercase.
-// NOTE: not goroutine-safe.
-type Channel struct {
- conn *MConnection
- desc *ChannelDescriptor
- id byte
- sendQueue chan []byte
- sendQueueSize int32 // atomic.
- recving []byte
- sending []byte
- priority int
- recentlySent int64 // exponential moving average
-}
-
-func newChannel(conn *MConnection, desc *ChannelDescriptor) *Channel {
- desc.FillDefaults()
- if desc.Priority <= 0 {
- cmn.PanicSanity("Channel default priority must be a postive integer")
- }
- return &Channel{
- conn: conn,
- desc: desc,
- id: desc.ID,
- sendQueue: make(chan []byte, desc.SendQueueCapacity),
- recving: make([]byte, 0, desc.RecvBufferCapacity),
- priority: desc.Priority,
- }
-}
-
-// Queues message to send to this channel.
-// Goroutine-safe
-// Times out (and returns false) after defaultSendTimeout
-func (ch *Channel) sendBytes(bytes []byte) bool {
- select {
- case ch.sendQueue <- bytes:
- atomic.AddInt32(&ch.sendQueueSize, 1)
+ if leastChannel == nil {
return true
- case <-time.After(defaultSendTimeout):
- return false
}
-}
-// Queues message to send to this channel.
-// Nonblocking, returns true if successful.
-// Goroutine-safe
-func (ch *Channel) trySendBytes(bytes []byte) bool {
- select {
- case ch.sendQueue <- bytes:
- atomic.AddInt32(&ch.sendQueueSize, 1)
+ n, err := leastChannel.writeMsgPacketTo(c.bufWriter)
+ if err != nil {
+ log.WithField("error", err).Error("failed to write msgPacket")
+ c.stopForError(err)
return true
- default:
- return false
}
+ c.sendMonitor.Update(int(n))
+ c.flushTimer.Set()
+ return false
}
-// Goroutine-safe
-func (ch *Channel) loadSendQueueSize() (size int) {
- return int(atomic.LoadInt32(&ch.sendQueueSize))
-}
-
-// Goroutine-safe
-// Use only as a heuristic.
-func (ch *Channel) canSend() bool {
- return ch.loadSendQueueSize() < defaultSendQueueCapacity
-}
+// sendRoutine polls for packets to send from channels.
+func (c *MConnection) sendRoutine() {
+ defer c._recover()
-// Returns true if any msgPackets are pending to be sent.
-// Call before calling nextMsgPacket()
-// Goroutine-safe
-func (ch *Channel) isSendPending() bool {
- if len(ch.sending) == 0 {
- if len(ch.sendQueue) == 0 {
- return false
+ for {
+ var n int
+ var err error
+ select {
+ case <-c.flushTimer.Ch:
+ c.flush()
+ case <-c.chStatsTimer.C:
+ for _, channel := range c.channels {
+ channel.updateStats()
+ }
+ case <-c.pingTimer.C:
+ log.Debug("send Ping")
+ wire.WriteByte(packetTypePing, c.bufWriter, &n, &err)
+ c.sendMonitor.Update(int(n))
+ c.flush()
+ case <-c.pong:
+ log.Debug("send Pong")
+ wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
+ c.sendMonitor.Update(int(n))
+ c.flush()
+ case <-c.quit:
+ return
+ case <-c.send:
+ if eof := c.sendSomeMsgPackets(); !eof {
+ select {
+ case c.send <- struct{}{}:
+ default:
+ }
+ }
}
- ch.sending = <-ch.sendQueue
- }
- return true
-}
-// Creates a new msgPacket to send.
-// Not goroutine-safe
-func (ch *Channel) nextMsgPacket() msgPacket {
- packet := msgPacket{}
- packet.ChannelID = byte(ch.id)
- packet.Bytes = ch.sending[:cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending))]
- if len(ch.sending) <= maxMsgPacketPayloadSize {
- packet.EOF = byte(0x01)
- ch.sending = nil
- atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize
- } else {
- packet.EOF = byte(0x00)
- ch.sending = ch.sending[cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending)):]
+ if !c.IsRunning() {
+ return
+ }
+ if err != nil {
+ log.WithFields(log.Fields{"conn": c, "error": err}).Error("Connection failed @ sendRoutine")
+ c.stopForError(err)
+ return
+ }
}
- return packet
}
-// Writes next msgPacket to w.
-// Not goroutine-safe
-func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int, err error) {
- packet := ch.nextMsgPacket()
- wire.WriteByte(packetTypeMsg, w, &n, &err)
- wire.WriteBinary(packet, w, &n, &err)
- if err == nil {
- ch.recentlySent += int64(n)
+// Returns true if messages from channels were exhausted.
+func (c *MConnection) sendSomeMsgPackets() bool {
+ // Block until .sendMonitor says we can write.
+ // Once we're ready we send more than we asked for,
+ // but amortized it should even out.
+ c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true)
+ for i := 0; i < numBatchMsgPackets; i++ {
+ if c.sendMsgPacket() {
+ return true
+ }
}
- return
+ return false
}
-// Handles incoming msgPackets. Returns a msg bytes if msg is complete.
-// Not goroutine-safe
-func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) {
- if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) {
- return nil, wire.ErrBinaryReadOverflow
- }
- ch.recving = append(ch.recving, packet.Bytes...)
- if packet.EOF == byte(0x01) {
- msgBytes := ch.recving
- // clear the slice without re-allocating.
- // http://stackoverflow.com/questions/16971741/how-do-you-clear-a-slice-in-go
- // suggests this could be a memory leak, but we might as well keep the memory for the channel until it closes,
- // at which point the recving slice stops being used and should be garbage collected
- ch.recving = ch.recving[:0] // make([]byte, 0, ch.desc.RecvBufferCapacity)
- return msgBytes, nil
+func (c *MConnection) stopForError(r interface{}) {
+ c.Stop()
+ if atomic.CompareAndSwapUint32(&c.errored, 0, 1) && c.onError != nil {
+ c.onError(r)
}
- return nil, nil
-}
-
-// Call this periodically to update stats for throttling purposes.
-// Not goroutine-safe
-func (ch *Channel) updateStats() {
- // Exponential decay of stats.
- // TODO: optimize.
- ch.recentlySent = int64(float64(ch.recentlySent) * 0.8)
-}
-
-//-----------------------------------------------------------------------------
-
-const (
- maxMsgPacketPayloadSize = 1024
- maxMsgPacketOverheadSize = 10 // It's actually lower but good enough
- maxMsgPacketTotalSize = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize
- packetTypePing = byte(0x01)
- packetTypePong = byte(0x02)
- packetTypeMsg = byte(0x03)
-)
-
-// Messages in channels are chopped into smaller msgPackets for multiplexing.
-type msgPacket struct {
- ChannelID byte
- EOF byte // 1 means message ends here.
- Bytes []byte
-}
-
-func (p msgPacket) String() string {
- return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF)
}
-// Uses nacl's secret_box to encrypt a net.Conn.
-// It is (meant to be) an implementation of the STS protocol.
-// Note we do not (yet) assume that a remote peer's pubkey
-// is known ahead of time, and thus we are technically
-// still vulnerable to MITM. (TODO!)
-// See docs/sts-final.pdf for more info
package connection
import (
"golang.org/x/crypto/ripemd160"
"github.com/tendermint/go-crypto"
- "github.com/tendermint/go-wire"
+ wire "github.com/tendermint/go-wire"
cmn "github.com/tendermint/tmlibs/common"
)
-// 2 + 1024 == 1026 total frame size
-const dataLenSize = 2 // uint16 to describe the length, is <= dataMaxSize
-const dataMaxSize = 1024
-const totalFrameSize = dataMaxSize + dataLenSize
-const sealedFrameSize = totalFrameSize + secretbox.Overhead
-const authSigMsgSize = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays
+const (
+ dataLenSize = 2 // uint16 to describe the length, is <= dataMaxSize
+ dataMaxSize = 1024
+ totalFrameSize = dataMaxSize + dataLenSize
+ sealedFrameSize = totalFrameSize + secretbox.Overhead
+ authSigMsgSize = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays
+)
+
+type authSigMessage struct {
+ Key crypto.PubKey
+ Sig crypto.Signature
+}
-// Implements net.Conn
+// SecretConnection implements net.Conn
type SecretConnection struct {
conn io.ReadWriteCloser
recvBuffer []byte
shrSecret *[32]byte // shared secret
}
-// Performs handshake and returns a new authenticated SecretConnection.
-// Returns nil if error in handshake.
-// Caller should call conn.Close()
-// See docs/sts-final.pdf for more information.
+// MakeSecretConnection performs handshake and returns a new authenticated SecretConnection.
func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKeyEd25519) (*SecretConnection, error) {
-
locPubKey := locPrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
// Generate ephemeral keys for perfect forward secrecy.
return nil, errors.New("Challenge verification failed")
}
- // We've authorized.
sc.remPubKey = remPubKey.Unwrap().(crypto.PubKeyEd25519)
return sc, nil
}
-// Returns authenticated remote pubkey
+// CONTRACT: data smaller than dataMaxSize is read atomically.
+func (sc *SecretConnection) Read(data []byte) (n int, err error) {
+ if 0 < len(sc.recvBuffer) {
+ n_ := copy(data, sc.recvBuffer)
+ sc.recvBuffer = sc.recvBuffer[n_:]
+ return
+ }
+
+ sealedFrame := make([]byte, sealedFrameSize)
+ if _, err = io.ReadFull(sc.conn, sealedFrame); err != nil {
+ return
+ }
+
+ // decrypt the frame
+ frame := make([]byte, totalFrameSize)
+ if _, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret); !ok {
+ return n, errors.New("Failed to decrypt SecretConnection")
+ }
+
+ incr2Nonce(sc.recvNonce)
+ chunkLength := binary.BigEndian.Uint16(frame) // read the first two bytes
+ if chunkLength > dataMaxSize {
+ return 0, errors.New("chunkLength is greater than dataMaxSize")
+ }
+
+ chunk := frame[dataLenSize : dataLenSize+chunkLength]
+ n = copy(data, chunk)
+ sc.recvBuffer = chunk[n:]
+ return
+}
+
+// RemotePubKey returns authenticated remote pubkey
func (sc *SecretConnection) RemotePubKey() crypto.PubKeyEd25519 {
return sc.remPubKey
}
// CONTRACT: data smaller than dataMaxSize is read atomically.
func (sc *SecretConnection) Write(data []byte) (n int, err error) {
for 0 < len(data) {
- var frame []byte = make([]byte, totalFrameSize)
var chunk []byte
+ frame := make([]byte, totalFrameSize)
if dataMaxSize < len(data) {
chunk = data[:dataMaxSize]
data = data[dataMaxSize:]
chunk = data
data = nil
}
- chunkLength := len(chunk)
- binary.BigEndian.PutUint16(frame, uint16(chunkLength))
+ binary.BigEndian.PutUint16(frame, uint16(len(chunk)))
copy(frame[dataLenSize:], chunk)
// encrypt the frame
- var sealedFrame = make([]byte, sealedFrameSize)
+ sealedFrame := make([]byte, sealedFrameSize)
secretbox.Seal(sealedFrame[:0], frame, sc.sendNonce, sc.shrSecret)
- // fmt.Printf("secretbox.Seal(sealed:%X,sendNonce:%X,shrSecret:%X\n", sealedFrame, sc.sendNonce, sc.shrSecret)
incr2Nonce(sc.sendNonce)
- // end encryption
- _, err := sc.conn.Write(sealedFrame)
- if err != nil {
+ if _, err := sc.conn.Write(sealedFrame); err != nil {
return n, err
- } else {
- n += len(chunk)
}
+
+ n += len(chunk)
}
return
}
-// CONTRACT: data smaller than dataMaxSize is read atomically.
-func (sc *SecretConnection) Read(data []byte) (n int, err error) {
- if 0 < len(sc.recvBuffer) {
- n_ := copy(data, sc.recvBuffer)
- sc.recvBuffer = sc.recvBuffer[n_:]
- return
- }
-
- sealedFrame := make([]byte, sealedFrameSize)
- _, err = io.ReadFull(sc.conn, sealedFrame)
- if err != nil {
- return
- }
+// Close implements net.Conn
+func (sc *SecretConnection) Close() error { return sc.conn.Close() }
- // decrypt the frame
- var frame = make([]byte, totalFrameSize)
- // fmt.Printf("secretbox.Open(sealed:%X,recvNonce:%X,shrSecret:%X\n", sealedFrame, sc.recvNonce, sc.shrSecret)
- _, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret)
- if !ok {
- return n, errors.New("Failed to decrypt SecretConnection")
- }
- incr2Nonce(sc.recvNonce)
- // end decryption
-
- var chunkLength = binary.BigEndian.Uint16(frame) // read the first two bytes
- if chunkLength > dataMaxSize {
- return 0, errors.New("chunkLength is greater than dataMaxSize")
- }
- var chunk = frame[dataLenSize : dataLenSize+chunkLength]
+// LocalAddr implements net.Conn
+func (sc *SecretConnection) LocalAddr() net.Addr { return sc.conn.(net.Conn).LocalAddr() }
- n = copy(data, chunk)
- sc.recvBuffer = chunk[n:]
- return
-}
+// RemoteAddr implements net.Conn
+func (sc *SecretConnection) RemoteAddr() net.Addr { return sc.conn.(net.Conn).RemoteAddr() }
-// Implements net.Conn
-func (sc *SecretConnection) Close() error { return sc.conn.Close() }
-func (sc *SecretConnection) LocalAddr() net.Addr { return sc.conn.(net.Conn).LocalAddr() }
-func (sc *SecretConnection) RemoteAddr() net.Addr { return sc.conn.(net.Conn).RemoteAddr() }
+// SetDeadline implements net.Conn
func (sc *SecretConnection) SetDeadline(t time.Time) error { return sc.conn.(net.Conn).SetDeadline(t) }
+
+// SetReadDeadline implements net.Conn
func (sc *SecretConnection) SetReadDeadline(t time.Time) error {
return sc.conn.(net.Conn).SetReadDeadline(t)
}
+
+// SetWriteDeadline implements net.Conn
func (sc *SecretConnection) SetWriteDeadline(t time.Time) error {
return sc.conn.(net.Conn).SetWriteDeadline(t)
}
-func genEphKeys() (ephPub, ephPriv *[32]byte) {
- var err error
- ephPub, ephPriv, err = box.GenerateKey(crand.Reader)
- if err != nil {
- cmn.PanicCrisis("Could not generate ephemeral keypairs")
- }
+func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) {
+ shrSecret = new([32]byte)
+ box.Precompute(shrSecret, remPubKey, locPrivKey)
return
}
-func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) {
- var err1, err2 error
-
- cmn.Parallel(
- func() {
- _, err1 = conn.Write(locEphPub[:])
- },
- func() {
- remEphPub = new([32]byte)
- _, err2 = io.ReadFull(conn, remEphPub[:])
- },
- )
-
- if err1 != nil {
- return nil, err1
- }
- if err2 != nil {
- return nil, err2
- }
+func genChallenge(loPubKey, hiPubKey *[32]byte) (challenge *[32]byte) {
+ return hash32(append(loPubKey[:], hiPubKey[:]...))
+}
- return remEphPub, nil
+// increment nonce big-endian by 2 with wraparound.
+func incr2Nonce(nonce *[24]byte) {
+ incrNonce(nonce)
+ incrNonce(nonce)
}
-func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) {
- shrSecret = new([32]byte)
- box.Precompute(shrSecret, remPubKey, locPrivKey)
- return
+// increment nonce big-endian by 1 with wraparound.
+func incrNonce(nonce *[24]byte) {
+ for i := 23; 0 <= i; i-- {
+ nonce[i]++
+ if nonce[i] != 0 {
+ return
+ }
+ }
}
-func sort32(foo, bar *[32]byte) (lo, hi *[32]byte) {
- if bytes.Compare(foo[:], bar[:]) < 0 {
- lo = foo
- hi = bar
- } else {
- lo = bar
- hi = foo
+func genEphKeys() (ephPub, ephPriv *[32]byte) {
+ var err error
+ ephPub, ephPriv, err = box.GenerateKey(crand.Reader)
+ if err != nil {
+ cmn.PanicCrisis("Could not generate ephemeral keypairs")
}
return
}
-func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (recvNonce, sendNonce *[24]byte) {
+func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (*[24]byte, *[24]byte) {
nonce1 := hash24(append(loPubKey[:], hiPubKey[:]...))
nonce2 := new([24]byte)
copy(nonce2[:], nonce1[:])
nonce2[len(nonce2)-1] ^= 0x01
if locIsLo {
- recvNonce = nonce1
- sendNonce = nonce2
- } else {
- recvNonce = nonce2
- sendNonce = nonce1
+ return nonce1, nonce2
}
- return
-}
-
-func genChallenge(loPubKey, hiPubKey *[32]byte) (challenge *[32]byte) {
- return hash32(append(loPubKey[:], hiPubKey[:]...))
+ return nonce2, nonce1
}
func signChallenge(challenge *[32]byte, locPrivKey crypto.PrivKeyEd25519) (signature crypto.SignatureEd25519) {
return
}
-type authSigMessage struct {
- Key crypto.PubKey
- Sig crypto.Signature
-}
-
func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signature crypto.SignatureEd25519) (*authSigMessage, error) {
var recvMsg authSigMessage
var err1, err2 error
}
n := int(0) // not used.
recvMsg = wire.ReadBinary(authSigMessage{}, bytes.NewBuffer(readBuffer), authSigMsgSize, &n, &err2).(authSigMessage)
- })
+ },
+ )
if err1 != nil {
return nil, err1
if err2 != nil {
return nil, err2
}
-
return &recvMsg, nil
}
-func verifyChallengeSignature(challenge *[32]byte, remPubKey crypto.PubKeyEd25519, remSignature crypto.SignatureEd25519) bool {
- return remPubKey.VerifyBytes(challenge[:], remSignature.Wrap())
+func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) {
+ var err1, err2 error
+
+ cmn.Parallel(
+ func() {
+ _, err1 = conn.Write(locEphPub[:])
+ },
+ func() {
+ remEphPub = new([32]byte)
+ _, err2 = io.ReadFull(conn, remEphPub[:])
+ },
+ )
+
+ if err1 != nil {
+ return nil, err1
+ }
+ if err2 != nil {
+ return nil, err2
+ }
+ return remEphPub, nil
}
-//--------------------------------------------------------------------------------
+func sort32(foo, bar *[32]byte) (*[32]byte, *[32]byte) {
+ if bytes.Compare(foo[:], bar[:]) < 0 {
+ return foo, bar
+ }
+ return bar, foo
+}
// sha256
func hash32(input []byte) (res *[32]byte) {
copy(res[:], resSlice)
return
}
-
-// ripemd160
-func hash20(input []byte) (res *[20]byte) {
- hasher := ripemd160.New()
- hasher.Write(input) // does not error
- resSlice := hasher.Sum(nil)
- res = new([20]byte)
- copy(res[:], resSlice)
- return
-}
-
-// increment nonce big-endian by 2 with wraparound.
-func incr2Nonce(nonce *[24]byte) {
- incrNonce(nonce)
- incrNonce(nonce)
-}
-
-// increment nonce big-endian by 1 with wraparound.
-func incrNonce(nonce *[24]byte) {
- for i := 23; 0 <= i; i-- {
- nonce[i] += 1
- if nonce[i] != 0 {
- return
- }
- }
-}