OSDN Git Service

add parallel fast sync support (#238)
[bytom/vapor.git] / netsync / chainmgr / fast_sync.go
index 6a52223..06c67c8 100644 (file)
@@ -1,55 +1,48 @@
 package chainmgr
 
 import (
+       "sync"
+
        log "github.com/sirupsen/logrus"
 
        "github.com/vapor/errors"
        "github.com/vapor/netsync/peers"
-       "github.com/vapor/p2p/security"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
 )
 
 var (
-       maxBlocksPerMsg      = uint64(1000)
-       maxHeadersPerMsg     = uint64(1000)
-       fastSyncPivotGap     = uint64(64)
-       minGapStartFastSync  = uint64(128)
-       maxFastSyncBlocksNum = uint64(10000)
+       maxNumOfSkeletonPerSync = uint64(10)
+       numOfBlocksSkeletonGap  = maxNumOfBlocksPerMsg
+       maxNumOfBlocksPerSync   = numOfBlocksSkeletonGap * maxNumOfSkeletonPerSync
+       fastSyncPivotGap        = uint64(64)
+       minGapStartFastSync     = uint64(128)
 
-       errOrphanBlock = errors.New("fast sync block is orphan")
+       errNoSyncPeer = errors.New("can't find sync peer")
 )
 
-type MsgFetcher interface {
-       requireBlock(peerID string, height uint64) (*types.Block, error)
-       requireBlocks(peerID string, locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error)
-}
-
 type fastSync struct {
-       chain      Chain
-       msgFetcher MsgFetcher
-       peers      *peers.PeerSet
-       syncPeer   *peers.Peer
-       stopHeader *types.BlockHeader
-       length     uint64
-
-       quite chan struct{}
+       chain          Chain
+       msgFetcher     MsgFetcher
+       blockProcessor BlockProcessor
+       peers          *peers.PeerSet
+       mainSyncPeer   *peers.Peer
 }
 
-func newFastSync(chain Chain, msgFether MsgFetcher, peers *peers.PeerSet) *fastSync {
+func newFastSync(chain Chain, msgFetcher MsgFetcher, storage Storage, peers *peers.PeerSet) *fastSync {
        return &fastSync{
-               chain:      chain,
-               msgFetcher: msgFether,
-               peers:      peers,
-               quite:      make(chan struct{}),
+               chain:          chain,
+               msgFetcher:     msgFetcher,
+               blockProcessor: newBlockProcessor(chain, storage, peers),
+               peers:          peers,
        }
 }
 
 func (fs *fastSync) blockLocator() []*bc.Hash {
        header := fs.chain.BestBlockHeader()
        locator := []*bc.Hash{}
-
        step := uint64(1)
+
        for header != nil {
                headerHash := header.Hash()
                locator = append(locator, &headerHash)
@@ -75,118 +68,88 @@ func (fs *fastSync) blockLocator() []*bc.Hash {
        return locator
 }
 
-func (fs *fastSync) process() error {
-       if err := fs.findFastSyncRange(); err != nil {
-               return err
-       }
-
-       stopHash := fs.stopHeader.Hash()
-       for fs.chain.BestBlockHeight() < fs.stopHeader.Height {
-               blocks, err := fs.msgFetcher.requireBlocks(fs.syncPeer.ID(), fs.blockLocator(), &stopHash)
-               if err != nil {
-                       fs.peers.ErrorHandler(fs.syncPeer.ID(), security.LevelConnException, err)
-                       return err
-               }
-
-               if err := fs.verifyBlocks(blocks); err != nil {
-                       fs.peers.ErrorHandler(fs.syncPeer.ID(), security.LevelMsgIllegal, err)
-                       return err
-               }
-       }
-
-       log.WithFields(log.Fields{"module": logModule, "height": fs.chain.BestBlockHeight()}).Info("fast sync success")
-       return nil
-}
-
-func (fs *fastSync) findFastSyncRange() error {
-       bestHeight := fs.chain.BestBlockHeight()
-       fs.length = fs.syncPeer.IrreversibleHeight() - fastSyncPivotGap - bestHeight
-       if fs.length > maxFastSyncBlocksNum {
-               fs.length = maxFastSyncBlocksNum
+// createFetchBlocksTasks get the skeleton and assign tasks according to the skeleton.
+func (fs *fastSync) createFetchBlocksTasks(stopBlock *types.Block) ([]*fetchBlocksWork, error) {
+       // Find peers that meet the height requirements.
+       peers := fs.peers.GetPeersByHeight(stopBlock.Height + fastSyncPivotGap)
+       if len(peers) == 0 {
+               return nil, errNoSyncPeer
        }
 
-       stopBlock, err := fs.msgFetcher.requireBlock(fs.syncPeer.ID(), bestHeight+fs.length)
-       if err != nil {
-               return err
+       // parallel fetch the skeleton from peers.
+       stopHash := stopBlock.Hash()
+       skeletonMap := fs.msgFetcher.parallelFetchHeaders(peers, fs.blockLocator(), &stopHash, numOfBlocksSkeletonGap-1)
+       if len(skeletonMap) == 0 {
+               return nil, errors.New("No skeleton found")
        }
 
-       fs.stopHeader = &stopBlock.BlockHeader
-       return nil
-}
-
-func (fs *fastSync) locateBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error) {
-       headers, err := fs.locateHeaders(locator, stopHash, 0, maxBlocksPerMsg)
-       if err != nil {
-               return nil, err
+       mainSkeleton, ok := skeletonMap[fs.mainSyncPeer.ID()]
+       if !ok {
+               return nil, errors.New("No main skeleton found")
        }
 
-       blocks := []*types.Block{}
-       for _, header := range headers {
-               headerHash := header.Hash()
-               block, err := fs.chain.GetBlockByHash(&headerHash)
-               if err != nil {
-                       return nil, err
+       // collect peers that match the skeleton of the primary sync peer
+       fs.msgFetcher.addSyncPeer(fs.mainSyncPeer.ID())
+       delete(skeletonMap, fs.mainSyncPeer.ID())
+       for peerID, skeleton := range skeletonMap {
+               if len(skeleton) != len(mainSkeleton) {
+                       log.WithFields(log.Fields{"module": logModule, "main skeleton": len(mainSkeleton), "got skeleton": len(skeleton)}).Warn("different skeleton length")
+                       continue
                }
 
-               blocks = append(blocks, block)
+               for i, header := range skeleton {
+                       if header.Hash() != mainSkeleton[i].Hash() {
+                               log.WithFields(log.Fields{"module": logModule, "header index": i, "main skeleton": mainSkeleton[i].Hash(), "got skeleton": header.Hash()}).Warn("different skeleton hash")
+                               continue
+                       }
+               }
+               fs.msgFetcher.addSyncPeer(peerID)
        }
-       return blocks, nil
-}
 
-func (fs *fastSync) locateHeaders(locator []*bc.Hash, stopHash *bc.Hash, skip uint64, maxNum uint64) ([]*types.BlockHeader, error) {
-       startHeader, err := fs.chain.GetHeaderByHeight(0)
-       if err != nil {
-               return nil, err
+       blockFetchTasks := make([]*fetchBlocksWork, 0)
+       // create download task
+       for i := 0; i < len(mainSkeleton)-1; i++ {
+               blockFetchTasks = append(blockFetchTasks, &fetchBlocksWork{startHeader: mainSkeleton[i], stopHeader: mainSkeleton[i+1]})
        }
 
-       for _, hash := range locator {
-               header, err := fs.chain.GetHeaderByHash(hash)
-               if err == nil && fs.chain.InMainChain(header.Hash()) {
-                       startHeader = header
-                       break
-               }
-       }
+       return blockFetchTasks, nil
+}
 
-       headers := make([]*types.BlockHeader, 0)
-       stopHeader, err := fs.chain.GetHeaderByHash(stopHash)
+func (fs *fastSync) process() error {
+       stopBlock, err := fs.findSyncRange()
        if err != nil {
-               return headers, nil
+               return err
        }
 
-       if !fs.chain.InMainChain(*stopHash) {
-               return headers, nil
+       tasks, err := fs.createFetchBlocksTasks(stopBlock)
+       if err != nil {
+               return err
        }
 
-       num := uint64(0)
-       for i := startHeader.Height; i <= stopHeader.Height && num < maxNum; i += skip + 1 {
-               header, err := fs.chain.GetHeaderByHeight(i)
-               if err != nil {
-                       return nil, err
-               }
+       downloadNotifyCh := make(chan struct{}, 1)
+       processStopCh := make(chan struct{})
+       var wg sync.WaitGroup
+       wg.Add(2)
+       go fs.msgFetcher.parallelFetchBlocks(tasks, downloadNotifyCh, processStopCh, &wg)
+       go fs.blockProcessor.process(downloadNotifyCh, processStopCh, &wg)
+       wg.Wait()
+       fs.msgFetcher.resetParameter()
+       log.WithFields(log.Fields{"module": logModule, "height": fs.chain.BestBlockHeight()}).Info("fast sync complete")
+       return nil
+}
 
-               headers = append(headers, header)
-               num++
+// findSyncRange asks the main sync peer for the block that ends this sync round.
+// The sync length cannot be greater than maxNumOfBlocksPerSync.
+func (fs *fastSync) findSyncRange() (*types.Block, error) {
+       bestHeight := fs.chain.BestBlockHeight()
+       length := fs.mainSyncPeer.IrreversibleHeight() - fastSyncPivotGap - bestHeight
+       if length > maxNumOfBlocksPerSync {
+               length = maxNumOfBlocksPerSync
        }
 
-       return headers, nil
+       return fs.msgFetcher.requireBlock(fs.mainSyncPeer.ID(), bestHeight+length)
 }
 
 func (fs *fastSync) setSyncPeer(peer *peers.Peer) {
-       fs.syncPeer = peer
-}
-
-func (fs *fastSync) verifyBlocks(blocks []*types.Block) error {
-       for _, block := range blocks {
-               isOrphan, err := fs.chain.ProcessBlock(block)
-               if err != nil {
-                       return err
-               }
-
-               if isOrphan {
-                       log.WithFields(log.Fields{"module": logModule, "height": block.Height, "hash": block.Hash()}).Error("fast sync block is orphan")
-                       return errOrphanBlock
-               }
-       }
-
-       return nil
+       fs.mainSyncPeer = peer
 }