OSDN Git Service

Peer add announces new block message num limit (#378)
[bytom/vapor.git] / netsync / consensusmgr / block_fetcher.go
index 8c28ff9..96f7fed 100644 (file)
@@ -1,7 +1,7 @@
 package consensusmgr
 
 import (
-       "github.com/sirupsen/logrus"
+       log "github.com/sirupsen/logrus"
        "gopkg.in/karalabe/cookiejar.v2/collections/prque"
 
        "github.com/vapor/p2p/security"
@@ -10,8 +10,8 @@ import (
 
 const (
        maxBlockDistance = 64
-       maxMsgSetSize    = 128
        newBlockChSize   = 64
+       msgLimit         = 128 // peer message number limit
 )
 
 // blockFetcher is responsible for accumulating block announcements from various peers
@@ -21,24 +21,24 @@ type blockFetcher struct {
        peers Peers
 
        newBlockCh chan *blockMsg
-       queue      *prque.Prque
-       msgSet     map[bc.Hash]*blockMsg
+       queue      *prque.Prque          // block import priority queue
+       msgSet     map[bc.Hash]*blockMsg // already queued blocks
+       msgCounter map[string]int        // per peer msg counter to prevent DOS
 }
 
 //NewBlockFetcher creates a block fetcher to retrieve blocks of the new propose.
 func newBlockFetcher(chain Chain, peers Peers) *blockFetcher {
-       f := &blockFetcher{
+       return &blockFetcher{
                chain:      chain,
                peers:      peers,
                newBlockCh: make(chan *blockMsg, newBlockChSize),
                queue:      prque.New(),
                msgSet:     make(map[bc.Hash]*blockMsg),
+               msgCounter: make(map[string]int),
        }
-       go f.blockProcessor()
-       return f
 }
 
-func (f *blockFetcher) blockProcessor() {
+func (f *blockFetcher) blockProcessorLoop() {
        for {
                for !f.queue.Empty() {
                        msg := f.queue.PopItem().(*blockMsg)
@@ -49,14 +49,25 @@ func (f *blockFetcher) blockProcessor() {
 
                        f.insert(msg)
                        delete(f.msgSet, msg.block.Hash())
+                       f.msgCounter[msg.peerID]--
+                       if f.msgCounter[msg.peerID] <= 0 {
+                               delete(f.msgCounter, msg.peerID)
+                       }
                }
-               f.add(<-f.newBlockCh)
+               f.add(<-f.newBlockCh, msgLimit)
        }
 }
 
-func (f *blockFetcher) add(msg *blockMsg) {
+func (f *blockFetcher) add(msg *blockMsg, limit int) {
+       // prevent DOS
+       count := f.msgCounter[msg.peerID] + 1
+       if count > limit {
+               log.WithFields(log.Fields{"module": logModule, "peer": msg.peerID, "limit": limit}).Warn("The number of peer messages exceeds the limit")
+               return
+       }
+
        bestHeight := f.chain.BestBlockHeight()
-       if len(f.msgSet) > maxMsgSetSize || bestHeight > msg.block.Height || msg.block.Height-bestHeight > maxBlockDistance {
+       if bestHeight > msg.block.Height || msg.block.Height-bestHeight > maxBlockDistance {
                return
        }
 
@@ -64,7 +75,8 @@ func (f *blockFetcher) add(msg *blockMsg) {
        if _, ok := f.msgSet[blockHash]; !ok {
                f.msgSet[blockHash] = msg
                f.queue.Push(msg, -float32(msg.block.Height))
-               logrus.WithFields(logrus.Fields{
+               f.msgCounter[msg.peerID] = count
+               log.WithFields(log.Fields{
                        "module":       logModule,
                        "block height": msg.block.Height,
                        "block hash":   blockHash.String(),
@@ -79,7 +91,6 @@ func (f *blockFetcher) insert(msg *blockMsg) {
                if peer == nil {
                        return
                }
-
                f.peers.ProcessIllegal(msg.peerID, security.LevelMsgIllegal, err.Error())
                return
        }
@@ -90,12 +101,12 @@ func (f *blockFetcher) insert(msg *blockMsg) {
 
        proposeMsg, err := NewBlockProposeMsg(msg.block)
        if err != nil {
-               logrus.WithFields(logrus.Fields{"module": logModule, "err": err}).Error("failed on create BlockProposeMsg")
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("failed on create BlockProposeMsg")
                return
        }
 
        if err := f.peers.BroadcastMsg(NewBroadcastMsg(proposeMsg, consensusChannel)); err != nil {
-               logrus.WithFields(logrus.Fields{"module": logModule, "err": err}).Error("failed on broadcast proposed block")
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("failed on broadcast proposed block")
                return
        }
 }