OSDN Git Service

rename (#465)
[bytom/vapor.git] / netsync / consensusmgr / block_fetcher.go
index 6278c61..6e28330 100644 (file)
@@ -1,63 +1,73 @@
 package consensusmgr
 
 import (
-       "github.com/sirupsen/logrus"
+       log "github.com/sirupsen/logrus"
        "gopkg.in/karalabe/cookiejar.v2/collections/prque"
 
-       "github.com/vapor/netsync/peers"
-       "github.com/vapor/protocol/bc"
+       "github.com/bytom/vapor/p2p/security"
+       "github.com/bytom/vapor/protocol/bc"
 )
 
 const (
        maxBlockDistance = 64
-       maxMsgSetSize    = 128
        newBlockChSize   = 64
+       msgLimit         = 128 // peer message number limit
 )
 
 // blockFetcher is responsible for accumulating block announcements from various peers
 // and scheduling them for retrieval.
 type blockFetcher struct {
        chain Chain
-       peers *peers.PeerSet
+       peers Peers
 
        newBlockCh chan *blockMsg
-       queue      *prque.Prque
-       msgSet     map[bc.Hash]*blockMsg
+       queue      *prque.Prque          // block import priority queue
+       msgSet     map[bc.Hash]*blockMsg // already queued blocks
+       msgCounter map[string]int        // per peer msg counter to prevent DOS
 }
 
 //NewBlockFetcher creates a block fetcher to retrieve blocks of the new propose.
-func newBlockFetcher(chain Chain, peers *peers.PeerSet) *blockFetcher {
-       f := &blockFetcher{
+func newBlockFetcher(chain Chain, peers Peers) *blockFetcher {
+       return &blockFetcher{
                chain:      chain,
                peers:      peers,
                newBlockCh: make(chan *blockMsg, newBlockChSize),
                queue:      prque.New(),
                msgSet:     make(map[bc.Hash]*blockMsg),
+               msgCounter: make(map[string]int),
        }
-       go f.blockProcessor()
-       return f
 }
 
-func (f *blockFetcher) blockProcessor() {
+func (f *blockFetcher) blockProcessorLoop() {
        for {
-               height := f.chain.BestBlockHeight()
                for !f.queue.Empty() {
                        msg := f.queue.PopItem().(*blockMsg)
-                       if msg.block.Height > height+1 {
+                       if msg.block.Height > f.chain.BestBlockHeight()+1 {
                                f.queue.Push(msg, -float32(msg.block.Height))
                                break
                        }
 
                        f.insert(msg)
                        delete(f.msgSet, msg.block.Hash())
+                       f.msgCounter[msg.peerID]--
+                       if f.msgCounter[msg.peerID] <= 0 {
+                               delete(f.msgCounter, msg.peerID)
+                       }
                }
-               f.add(<-f.newBlockCh)
+               f.add(<-f.newBlockCh, msgLimit)
        }
 }
 
-func (f *blockFetcher) add(msg *blockMsg) {
+func (f *blockFetcher) add(msg *blockMsg, limit int) {
+       // prevent DOS
+       count := f.msgCounter[msg.peerID] + 1
+       if count > limit {
+               log.WithFields(log.Fields{"module": logModule, "peer": msg.peerID, "limit": limit}).Warn("The number of peer messages exceeds the limit")
+               return
+       }
+
        bestHeight := f.chain.BestBlockHeight()
-       if len(f.msgSet) > maxMsgSetSize || bestHeight > msg.block.Height || msg.block.Height-bestHeight > maxBlockDistance {
+       if bestHeight > msg.block.Height || msg.block.Height-bestHeight > maxBlockDistance {
                return
        }
 
@@ -65,7 +75,8 @@ func (f *blockFetcher) add(msg *blockMsg) {
        if _, ok := f.msgSet[blockHash]; !ok {
                f.msgSet[blockHash] = msg
                f.queue.Push(msg, -float32(msg.block.Height))
-               logrus.WithFields(logrus.Fields{
+               f.msgCounter[msg.peerID] = count
+               log.WithFields(log.Fields{
                        "module":       logModule,
                        "block height": msg.block.Height,
                        "block hash":   blockHash.String(),
@@ -80,8 +91,7 @@ func (f *blockFetcher) insert(msg *blockMsg) {
                if peer == nil {
                        return
                }
-
-               f.peers.AddBanScore(msg.peerID, 20, 0, err.Error())
+               f.peers.ProcessIllegal(msg.peerID, security.LevelMsgIllegal, err.Error())
                return
        }
 
@@ -89,16 +99,14 @@ func (f *blockFetcher) insert(msg *blockMsg) {
                return
        }
 
-       hash := msg.block.Hash()
-       f.peers.SetStatus(msg.peerID, msg.block.Height, &hash)
        proposeMsg, err := NewBlockProposeMsg(msg.block)
        if err != nil {
-               logrus.WithFields(logrus.Fields{"module": logModule, "err": err}).Error("failed on create BlockProposeMsg")
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("failed on create BlockProposeMsg")
                return
        }
 
        if err := f.peers.BroadcastMsg(NewBroadcastMsg(proposeMsg, consensusChannel)); err != nil {
-               logrus.WithFields(logrus.Fields{"module": logModule, "err": err}).Error("failed on broadcast proposed block")
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("failed on broadcast proposed block")
                return
        }
 }