OSDN Git Service

add fast sync func (#204)
[bytom/vapor.git] / netsync / chainmgr / block_keeper.go
index 112fd50..d1407e9 100644 (file)
@@ -1,7 +1,6 @@
 package chainmgr
 
 import (
-       "container/list"
        "time"
 
        log "github.com/sirupsen/logrus"
@@ -15,22 +14,34 @@ import (
 )
 
 const (
-       syncCycle            = 5 * time.Second
-       blockProcessChSize   = 1024
-       blocksProcessChSize  = 128
-       headersProcessChSize = 1024
+       syncCycle = 5 * time.Second
+
+       noNeedSync = iota
+       fastSyncType
+       regularSyncType
 )
 
 var (
-       maxBlockPerMsg        = uint64(128)
-       maxBlockHeadersPerMsg = uint64(2048)
-       syncTimeout           = 30 * time.Second
+       syncTimeout = 30 * time.Second
 
-       errAppendHeaders  = errors.New("fail to append list due to order dismatch")
        errRequestTimeout = errors.New("request timeout")
        errPeerDropped    = errors.New("Peer dropped")
 )
 
+type FastSync interface {
+       locateBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error)
+       locateHeaders(locator []*bc.Hash, stopHash *bc.Hash, skip uint64, maxNum uint64) ([]*types.BlockHeader, error)
+       process() error
+       setSyncPeer(peer *peers.Peer)
+}
+
+type Fetcher interface {
+       processBlock(peerID string, block *types.Block)
+       processBlocks(peerID string, blocks []*types.Block)
+       processHeaders(peerID string, headers []*types.BlockHeader)
+       requireBlock(peerID string, height uint64) (*types.Block, error)
+}
+
 type blockMsg struct {
        block  *types.Block
        peerID string
@@ -47,218 +58,60 @@ type headersMsg struct {
 }
 
 type blockKeeper struct {
-       chain Chain
-       peers *peers.PeerSet
-
-       syncPeer         *peers.Peer
-       blockProcessCh   chan *blockMsg
-       blocksProcessCh  chan *blocksMsg
-       headersProcessCh chan *headersMsg
+       chain      Chain
+       fastSync   FastSync
+       msgFetcher Fetcher
+       peers      *peers.PeerSet
+       syncPeer   *peers.Peer
 
-       headerList *list.List
+       quit chan struct{}
 }
 
 func newBlockKeeper(chain Chain, peers *peers.PeerSet) *blockKeeper {
-       bk := &blockKeeper{
-               chain:            chain,
-               peers:            peers,
-               blockProcessCh:   make(chan *blockMsg, blockProcessChSize),
-               blocksProcessCh:  make(chan *blocksMsg, blocksProcessChSize),
-               headersProcessCh: make(chan *headersMsg, headersProcessChSize),
-               headerList:       list.New(),
-       }
-       bk.resetHeaderState()
-       go bk.syncWorker()
-       return bk
-}
-
-func (bk *blockKeeper) appendHeaderList(headers []*types.BlockHeader) error {
-       for _, header := range headers {
-               prevHeader := bk.headerList.Back().Value.(*types.BlockHeader)
-               if prevHeader.Hash() != header.PreviousBlockHash {
-                       return errAppendHeaders
-               }
-               bk.headerList.PushBack(header)
-       }
-       return nil
-}
-
-func (bk *blockKeeper) blockLocator() []*bc.Hash {
-       header := bk.chain.BestBlockHeader()
-       locator := []*bc.Hash{}
-
-       step := uint64(1)
-       for header != nil {
-               headerHash := header.Hash()
-               locator = append(locator, &headerHash)
-               if header.Height == 0 {
-                       break
-               }
-
-               var err error
-               if header.Height < step {
-                       header, err = bk.chain.GetHeaderByHeight(0)
-               } else {
-                       header, err = bk.chain.GetHeaderByHeight(header.Height - step)
-               }
-               if err != nil {
-                       log.WithFields(log.Fields{"module": logModule, "err": err}).Error("blockKeeper fail on get blockLocator")
-                       break
-               }
-
-               if len(locator) >= 9 {
-                       step *= 2
-               }
-       }
-       return locator
-}
-
-func (bk *blockKeeper) fastBlockSync(checkPoint *consensus.Checkpoint) error {
-       bk.resetHeaderState()
-       lastHeader := bk.headerList.Back().Value.(*types.BlockHeader)
-       for ; lastHeader.Hash() != checkPoint.Hash; lastHeader = bk.headerList.Back().Value.(*types.BlockHeader) {
-               if lastHeader.Height >= checkPoint.Height {
-                       return errors.Wrap(peers.ErrPeerMisbehave, "peer is not in the checkpoint branch")
-               }
-
-               lastHash := lastHeader.Hash()
-               headers, err := bk.requireHeaders([]*bc.Hash{&lastHash}, &checkPoint.Hash)
-               if err != nil {
-                       return err
-               }
-
-               if len(headers) == 0 {
-                       return errors.Wrap(peers.ErrPeerMisbehave, "requireHeaders return empty list")
-               }
-
-               if err := bk.appendHeaderList(headers); err != nil {
-                       return err
-               }
+       msgFetcher := newMsgFetcher(peers)
+       return &blockKeeper{
+               chain:      chain,
+               fastSync:   newFastSync(chain, msgFetcher, peers),
+               msgFetcher: msgFetcher,
+               peers:      peers,
+               quit:       make(chan struct{}),
        }
-
-       fastHeader := bk.headerList.Front()
-       for bk.chain.BestBlockHeight() < checkPoint.Height {
-               locator := bk.blockLocator()
-               blocks, err := bk.requireBlocks(locator, &checkPoint.Hash)
-               if err != nil {
-                       return err
-               }
-
-               if len(blocks) == 0 {
-                       return errors.Wrap(peers.ErrPeerMisbehave, "requireBlocks return empty list")
-               }
-
-               for _, block := range blocks {
-                       if fastHeader = fastHeader.Next(); fastHeader == nil {
-                               return errors.New("get block than is higher than checkpoint")
-                       }
-
-                       if _, err = bk.chain.ProcessBlock(block); err != nil {
-                               return errors.Wrap(err, "fail on fastBlockSync process block")
-                       }
-               }
-       }
-       return nil
 }
 
 func (bk *blockKeeper) locateBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error) {
-       headers, err := bk.locateHeaders(locator, stopHash)
-       if err != nil {
-               return nil, err
-       }
-
-       blocks := []*types.Block{}
-       for i, header := range headers {
-               if uint64(i) >= maxBlockPerMsg {
-                       break
-               }
-
-               headerHash := header.Hash()
-               block, err := bk.chain.GetBlockByHash(&headerHash)
-               if err != nil {
-                       return nil, err
-               }
-
-               blocks = append(blocks, block)
-       }
-       return blocks, nil
+       return bk.fastSync.locateBlocks(locator, stopHash)
 }
 
-func (bk *blockKeeper) locateHeaders(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.BlockHeader, error) {
-       stopHeader, err := bk.chain.GetHeaderByHash(stopHash)
-       if err != nil {
-               return nil, err
-       }
-
-       startHeader, err := bk.chain.GetHeaderByHeight(0)
-       if err != nil {
-               return nil, err
-       }
-
-       for _, hash := range locator {
-               header, err := bk.chain.GetHeaderByHash(hash)
-               if err == nil && bk.chain.InMainChain(header.Hash()) {
-                       startHeader = header
-                       break
-               }
-       }
-
-       totalHeaders := stopHeader.Height - startHeader.Height
-       if totalHeaders > maxBlockHeadersPerMsg {
-               totalHeaders = maxBlockHeadersPerMsg
-       }
-
-       headers := []*types.BlockHeader{}
-       for i := uint64(1); i <= totalHeaders; i++ {
-               header, err := bk.chain.GetHeaderByHeight(startHeader.Height + i)
-               if err != nil {
-                       return nil, err
-               }
-
-               headers = append(headers, header)
-       }
-       return headers, nil
-}
-
-func (bk *blockKeeper) nextCheckpoint() *consensus.Checkpoint {
-       height := bk.chain.BestBlockHeader().Height
-       checkpoints := consensus.ActiveNetParams.Checkpoints
-       if len(checkpoints) == 0 || height >= checkpoints[len(checkpoints)-1].Height {
-               return nil
-       }
-
-       nextCheckpoint := &checkpoints[len(checkpoints)-1]
-       for i := len(checkpoints) - 2; i >= 0; i-- {
-               if height >= checkpoints[i].Height {
-                       break
-               }
-               nextCheckpoint = &checkpoints[i]
-       }
-       return nextCheckpoint
+func (bk *blockKeeper) locateHeaders(locator []*bc.Hash, stopHash *bc.Hash, skip uint64, maxNum uint64) ([]*types.BlockHeader, error) {
+       return bk.fastSync.locateHeaders(locator, stopHash, skip, maxNum)
 }
 
 func (bk *blockKeeper) processBlock(peerID string, block *types.Block) {
-       bk.blockProcessCh <- &blockMsg{block: block, peerID: peerID}
+       bk.msgFetcher.processBlock(peerID, block)
 }
 
 func (bk *blockKeeper) processBlocks(peerID string, blocks []*types.Block) {
-       bk.blocksProcessCh <- &blocksMsg{blocks: blocks, peerID: peerID}
+       bk.msgFetcher.processBlocks(peerID, blocks)
 }
 
 func (bk *blockKeeper) processHeaders(peerID string, headers []*types.BlockHeader) {
-       bk.headersProcessCh <- &headersMsg{headers: headers, peerID: peerID}
+       bk.msgFetcher.processHeaders(peerID, headers)
 }
 
-func (bk *blockKeeper) regularBlockSync(wantHeight uint64) error {
-       i := bk.chain.BestBlockHeight() + 1
-       for i <= wantHeight {
-               block, err := bk.requireBlock(i)
+func (bk *blockKeeper) regularBlockSync() error {
+       peerHeight := bk.syncPeer.Height()
+       bestHeight := bk.chain.BestBlockHeight()
+       i := bestHeight + 1
+       for i <= peerHeight {
+               block, err := bk.msgFetcher.requireBlock(bk.syncPeer.ID(), i)
                if err != nil {
+                       bk.peers.ErrorHandler(bk.syncPeer.ID(), security.LevelConnException, err)
                        return err
                }
 
                isOrphan, err := bk.chain.ProcessBlock(block)
                if err != nil {
+                       bk.peers.ErrorHandler(bk.syncPeer.ID(), security.LevelMsgIllegal, err)
                        return err
                }
 
@@ -268,115 +121,62 @@ func (bk *blockKeeper) regularBlockSync(wantHeight uint64) error {
                }
                i = bk.chain.BestBlockHeight() + 1
        }
+       log.WithFields(log.Fields{"module": logModule, "height": bk.chain.BestBlockHeight()}).Info("regular sync success")
        return nil
 }
 
-func (bk *blockKeeper) requireBlock(height uint64) (*types.Block, error) {
-       if ok := bk.syncPeer.GetBlockByHeight(height); !ok {
-               return nil, errPeerDropped
-       }
-
-       timeout := time.NewTimer(syncTimeout)
-       defer timeout.Stop()
-
-       for {
-               select {
-               case msg := <-bk.blockProcessCh:
-                       if msg.peerID != bk.syncPeer.ID() {
-                               continue
-                       }
-                       if msg.block.Height != height {
-                               continue
-                       }
-                       return msg.block, nil
-               case <-timeout.C:
-                       return nil, errors.Wrap(errRequestTimeout, "requireBlock")
-               }
-       }
+func (bk *blockKeeper) start() {
+       go bk.syncWorker()
 }
 
-func (bk *blockKeeper) requireBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error) {
-       if ok := bk.syncPeer.GetBlocks(locator, stopHash); !ok {
-               return nil, errPeerDropped
+func (bk *blockKeeper) checkSyncType() int {
+       peer := bk.peers.BestIrreversiblePeer(consensus.SFFullNode | consensus.SFFastSync)
+       if peer == nil {
+               log.WithFields(log.Fields{"module": logModule}).Debug("can't find fast sync peer")
+               return noNeedSync
        }
 
-       timeout := time.NewTimer(syncTimeout)
-       defer timeout.Stop()
+       bestHeight := bk.chain.BestBlockHeight()
 
-       for {
-               select {
-               case msg := <-bk.blocksProcessCh:
-                       if msg.peerID != bk.syncPeer.ID() {
-                               continue
-                       }
-                       return msg.blocks, nil
-               case <-timeout.C:
-                       return nil, errors.Wrap(errRequestTimeout, "requireBlocks")
-               }
+       if peerIrreversibleHeight := peer.IrreversibleHeight(); peerIrreversibleHeight >= bestHeight+minGapStartFastSync {
+               bk.fastSync.setSyncPeer(peer)
+               return fastSyncType
        }
-}
 
-func (bk *blockKeeper) requireHeaders(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.BlockHeader, error) {
-       if ok := bk.syncPeer.GetHeaders(locator, stopHash); !ok {
-               return nil, errPeerDropped
+       peer = bk.peers.BestPeer(consensus.SFFullNode)
+       if peer == nil {
+               log.WithFields(log.Fields{"module": logModule}).Debug("can't find sync peer")
+               return noNeedSync
        }
 
-       timeout := time.NewTimer(syncTimeout)
-       defer timeout.Stop()
-
-       for {
-               select {
-               case msg := <-bk.headersProcessCh:
-                       if msg.peerID != bk.syncPeer.ID() {
-                               continue
-                       }
-                       return msg.headers, nil
-               case <-timeout.C:
-                       return nil, errors.Wrap(errRequestTimeout, "requireHeaders")
-               }
+       peerHeight := peer.Height()
+       if peerHeight > bestHeight {
+               bk.syncPeer = peer
+               return regularSyncType
        }
-}
 
-// resetHeaderState sets the headers-first mode state to values appropriate for
-// syncing from a new peer.
-func (bk *blockKeeper) resetHeaderState() {
-       header := bk.chain.BestBlockHeader()
-       bk.headerList.Init()
-       if bk.nextCheckpoint() != nil {
-               bk.headerList.PushBack(header)
-       }
+       return noNeedSync
 }
 
 func (bk *blockKeeper) startSync() bool {
-       checkPoint := bk.nextCheckpoint()
-       peer := bk.peers.BestPeer(consensus.SFFastSync | consensus.SFFullNode)
-       if peer != nil && checkPoint != nil && peer.Height() >= checkPoint.Height {
-               bk.syncPeer = peer
-               if err := bk.fastBlockSync(checkPoint); err != nil {
-                       log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on fastBlockSync")
-                       bk.peers.ErrorHandler(peer.ID(), security.LevelMsgIllegal, err)
+       switch bk.checkSyncType() {
+       case fastSyncType:
+               if err := bk.fastSync.process(); err != nil {
+                       log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("failed on fast sync")
                        return false
                }
-               return true
-       }
-
-       blockHeight := bk.chain.BestBlockHeight()
-       peer = bk.peers.BestPeer(consensus.SFFullNode)
-       if peer != nil && peer.Height() > blockHeight {
-               bk.syncPeer = peer
-               targetHeight := blockHeight + maxBlockPerMsg
-               if targetHeight > peer.Height() {
-                       targetHeight = peer.Height()
-               }
-
-               if err := bk.regularBlockSync(targetHeight); err != nil {
+       case regularSyncType:
+               if err := bk.regularBlockSync(); err != nil {
                        log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on regularBlockSync")
-                       bk.peers.ErrorHandler(peer.ID(),security.LevelMsgIllegal, err)
                        return false
                }
-               return true
        }
-       return false
+
+       return true
+}
+
+func (bk *blockKeeper) stop() {
+       close(bk.quit)
 }
 
 func (bk *blockKeeper) syncWorker() {
@@ -384,18 +184,17 @@ func (bk *blockKeeper) syncWorker() {
        defer syncTicker.Stop()
 
        for {
-               <-syncTicker.C
-               if update := bk.startSync(); !update {
-                       continue
-               }
-
-               block, err := bk.chain.GetBlockByHeight(bk.chain.BestBlockHeight())
-               if err != nil {
-                       log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on syncWorker get best block")
-               }
+               select {
+               case <-syncTicker.C:
+                       if update := bk.startSync(); !update {
+                               continue
+                       }
 
-               if err = bk.peers.BroadcastNewStatus(block); err != nil {
-                       log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on syncWorker broadcast new status")
+                       if err := bk.peers.BroadcastNewStatus(bk.chain.BestBlockHeader(), bk.chain.BestIrreversibleHeader()); err != nil {
+                               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on syncWorker broadcast new status")
+                       }
+               case <-bk.quit:
+                       return
                }
        }
 }