Fix fast sync pending when all request blocks timeout (#347)
diff --git a/netsync/chainmgr/msg_fetcher.go b/netsync/chainmgr/msg_fetcher.go
index f635667..1435602 100644
--- a/netsync/chainmgr/msg_fetcher.go
+++ b/netsync/chainmgr/msg_fetcher.go
 package chainmgr
 
 import (
+       "sync"
        "time"
 
+       log "github.com/sirupsen/logrus"
+
        "github.com/vapor/errors"
        "github.com/vapor/netsync/peers"
+       "github.com/vapor/p2p/security"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
 )
 
 const (
-       blockProcessChSize   = 1024
-       blocksProcessChSize  = 128
-       headersProcessChSize = 1024
+       maxNumOfParallelFetchBlocks = 7
+       blockProcessChSize          = 1024
+       blocksProcessChSize         = 128
+       headersProcessChSize        = 1024
+       maxNumOfFastSyncPeers       = 128
 )
 
-type msgFetcher struct {
-       peers *peers.PeerSet
+var (
+       requireBlockTimeout      = 20 * time.Second
+       requireHeadersTimeout    = 30 * time.Second
+       requireBlocksTimeout     = 50 * time.Second
+       checkSyncPeerNumInterval = 5 * time.Second
+
+       errRequestBlocksTimeout = errors.New("request blocks timeout")
+       errRequestTimeout       = errors.New("request timeout")
+       errPeerDropped          = errors.New("peer dropped")
+       errSendMsg              = errors.New("send message error")
+)
+
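+// MsgFetcher is the interface the fast-sync manager uses to request
+// blocks and headers from remote peers.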
+type MsgFetcher interface {
+       resetParameter()
+       addSyncPeer(peerID string)
+       requireBlock(peerID string, height uint64) (*types.Block, error)
+       parallelFetchBlocks(works []*fetchBlocksWork, downloadNotifyCh chan struct{}, processStopCh chan struct{}, wg *sync.WaitGroup)
+       parallelFetchHeaders(peers []*peers.Peer, locator []*bc.Hash, stopHash *bc.Hash, skip uint64) map[string][]*types.BlockHeader
+}
 
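+// fetchBlocksWork is a blocks-download task bounded by a start and a stop header.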
+type fetchBlocksWork struct {
+       startHeader, stopHeader *types.BlockHeader
+}
+
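+// fetchBlocksResult reports the height range a worker handled and any error it hit.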
+type fetchBlocksResult struct {
+       startHeight, stopHeight uint64
+       err                     error
+}
+
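+// msgFetcher tracks fast-sync peers and routes incoming block and header
+// messages to the requests waiting on them.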
+type msgFetcher struct {
+       storage          Storage
+       syncPeers        *fastSyncPeers
+       peers            *peers.PeerSet
        blockProcessCh   chan *blockMsg
        blocksProcessCh  chan *blocksMsg
        headersProcessCh chan *headersMsg
+       blocksMsgChanMap map[string]chan []*types.Block
+       mux              sync.RWMutex
 }
 
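+// newMsgFetcher returns a msgFetcher backed by the given storage and peer set.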
-func newMsgFetcher(peers *peers.PeerSet) *msgFetcher {
+func newMsgFetcher(storage Storage, peers *peers.PeerSet) *msgFetcher {
        return &msgFetcher{
+               storage:          storage,
+               syncPeers:        newFastSyncPeers(),
                peers:            peers,
                blockProcessCh:   make(chan *blockMsg, blockProcessChSize),
                blocksProcessCh:  make(chan *blocksMsg, blocksProcessChSize),
                headersProcessCh: make(chan *headersMsg, headersProcessChSize),
+               blocksMsgChanMap: make(map[string]chan []*types.Block),
+       }
+}
+
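+// addSyncPeer registers a peer as a fast-sync candidate.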
+func (mf *msgFetcher) addSyncPeer(peerID string) {
+       mf.syncPeers.add(peerID)
+}
+
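+// collectResultLoop gathers worker results, feeding idle peers back to the
+// worker pool; it returns on the first failed work item, when all work is
+// done, or when no sync peers remain, closing workerCloseCh to stop workers.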
+func (mf *msgFetcher) collectResultLoop(peerCh chan string, quit chan struct{}, resultCh chan *fetchBlocksResult, workerCloseCh chan struct{}, workSize int) {
+       defer close(workerCloseCh)
+       ticker := time.NewTicker(checkSyncPeerNumInterval)
+       defer ticker.Stop()
+
+       // collect fetch results
+       for resultCount := 0; resultCount < workSize && mf.syncPeers.size() > 0; {
+               select {
+               case result := <-resultCh:
+                       resultCount++
+                       if result.err != nil {
+                               log.WithFields(log.Fields{"module": logModule, "startHeight": result.startHeight, "stopHeight": result.stopHeight, "err": result.err}).Error("failed on fetch blocks")
+                               return
+                       }
+
+                       peer, err := mf.syncPeers.selectIdlePeer()
+                       if err != nil {
+                       log.WithFields(log.Fields{"module": logModule, "err": err}).Warn("failed on find fast sync peer")
+                               break
+                       }
+                       peerCh <- peer
+               case <-ticker.C:
+                       if mf.syncPeers.size() == 0 {
+                               log.WithFields(log.Fields{"module": logModule}).Warn("num of fast sync peer is 0")
+                               return
+                       }
+               case _, ok := <-quit:
+                       if !ok {
+                               return
+                       }
+               }
+       }
+}
+
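+// fetchBlocks downloads and verifies the blocks covering work from one peer;
+// a peer that fails is removed from the sync set and reported.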
+func (mf *msgFetcher) fetchBlocks(work *fetchBlocksWork, peerID string) ([]*types.Block, error) {
+       defer mf.syncPeers.setIdle(peerID)
+       startHash := work.startHeader.Hash()
+       stopHash := work.stopHeader.Hash()
+       blocks, err := mf.requireBlocks(peerID, []*bc.Hash{&startHash}, &stopHash)
+       if err != nil {
+               mf.syncPeers.delete(peerID)
+               mf.peers.ProcessIllegal(peerID, security.LevelConnException, err.Error())
+               return nil, err
+       }
+
+       if err := mf.verifyBlocksMsg(blocks, work.startHeader, work.stopHeader); err != nil {
+               mf.syncPeers.delete(peerID)
+               mf.peers.ProcessIllegal(peerID, security.LevelConnException, err.Error())
+               return nil, err
+       }
+
+       return blocks, nil
+}
+
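+// fetchBlocksProcess keeps requesting batches from available peers until the
+// work's stop height is reached or closeCh fires.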
+func (mf *msgFetcher) fetchBlocksProcess(work *fetchBlocksWork, peerCh chan string, downloadNotifyCh chan struct{}, closeCh chan struct{}) error {
+       for {
+               select {
+               case peerID := <-peerCh:
+                       for {
+                               blocks, err := mf.fetchBlocks(work, peerID)
+                               if err != nil {
+                                       log.WithFields(log.Fields{"module": logModule, "startHeight": work.startHeader.Height, "stopHeight": work.stopHeader.Height, "error": err}).Info("failed on fetch blocks")
+                                       break
+                               }
+
+                               if err := mf.storage.writeBlocks(peerID, blocks); err != nil {
+                                       log.WithFields(log.Fields{"module": logModule, "error": err}).Info("write block error")
+                                       return err
+                               }
+
+                               // send to block process pool
+                               select {
+                               case downloadNotifyCh <- struct{}{}:
+                               default:
+                               }
+
+                               // work completed
+                               if blocks[len(blocks)-1].Height >= work.stopHeader.Height-1 {
+                                       return nil
+                               }
+
+                               // unfinished work, continue
+                               work.startHeader = &blocks[len(blocks)-1].BlockHeader
+                       }
+               case <-closeCh:
+                       return nil
+               }
+       }
+}
+
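+// fetchBlocksWorker consumes work items and reports a result for each until
+// closeCh is closed.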
+func (mf *msgFetcher) fetchBlocksWorker(workCh chan *fetchBlocksWork, peerCh chan string, resultCh chan *fetchBlocksResult, closeCh chan struct{}, downloadNotifyCh chan struct{}, wg *sync.WaitGroup) {
+       for {
+               select {
+               case work := <-workCh:
+                       err := mf.fetchBlocksProcess(work, peerCh, downloadNotifyCh, closeCh)
+                       resultCh <- &fetchBlocksResult{startHeight: work.startHeader.Height, stopHeight: work.stopHeader.Height, err: err}
+               case <-closeCh:
+                       wg.Done()
+                       return
+               }
+       }
+}
+
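+// parallelFetchBlocks fans the given work ranges out to a bounded pool of
+// workers and sync peers, returning once all workers have stopped.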
+func (mf *msgFetcher) parallelFetchBlocks(works []*fetchBlocksWork, downloadNotifyCh chan struct{}, processStopCh chan struct{}, wg *sync.WaitGroup) {
+       workSize := len(works)
+       workCh := make(chan *fetchBlocksWork, workSize)
+       peerCh := make(chan string, maxNumOfFastSyncPeers)
+       resultCh := make(chan *fetchBlocksResult, workSize)
+       closeCh := make(chan struct{})
+
+       for _, work := range works {
+               workCh <- work
+       }
+       syncPeers := mf.syncPeers.selectIdlePeers()
+       for i := 0; i < len(syncPeers) && i < maxNumOfFastSyncPeers; i++ {
+               peerCh <- syncPeers[i]
+       }
+
+       var workWg sync.WaitGroup
+       for i := 0; i < maxNumOfParallelFetchBlocks && i < workSize; i++ {
+               workWg.Add(1)
+               go mf.fetchBlocksWorker(workCh, peerCh, resultCh, closeCh, downloadNotifyCh, &workWg)
+       }
+
+       go mf.collectResultLoop(peerCh, processStopCh, resultCh, closeCh, workSize)
+
+       workWg.Wait()
+       close(resultCh)
+       close(peerCh)
+       close(workCh)
+       close(downloadNotifyCh)
+       wg.Done()
+}
+
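+// parallelFetchHeaders asks each given peer for the same range of headers and
+// returns whatever arrived before the timeout, keyed by peer ID.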
+func (mf *msgFetcher) parallelFetchHeaders(peers []*peers.Peer, locator []*bc.Hash, stopHash *bc.Hash, skip uint64) map[string][]*types.BlockHeader {
+       result := make(map[string][]*types.BlockHeader)
+       response := make(map[string]bool)
+       for _, peer := range peers {
+               if ok := peer.GetHeaders(locator, stopHash, skip); !ok {
+                       continue
+               }
+               result[peer.ID()] = nil
+       }
+
+       timeout := time.NewTimer(requireHeadersTimeout)
+       defer timeout.Stop()
+       for {
+               select {
+               case msg := <-mf.headersProcessCh:
+                       if _, ok := result[msg.peerID]; ok {
+                               result[msg.peerID] = append(result[msg.peerID], msg.headers...)
+                               response[msg.peerID] = true
+                               if len(response) == len(result) {
+                                       return result
+                               }
+                       }
+               case <-timeout.C:
+                       log.WithFields(log.Fields{"module": logModule, "err": errRequestTimeout}).Warn("failed on parallel fetch headers")
+                       return result
+               }
        }
 }
 
@@ -38,6 +246,15 @@ func (mf *msgFetcher) processBlock(peerID string, block *types.Block) {
 
 func (mf *msgFetcher) processBlocks(peerID string, blocks []*types.Block) {
        mf.blocksProcessCh <- &blocksMsg{blocks: blocks, peerID: peerID}
+       mf.mux.RLock()
+       blocksMsgChan, ok := mf.blocksMsgChanMap[peerID]
+       mf.mux.RUnlock()
+       if !ok {
+               mf.peers.ProcessIllegal(peerID, security.LevelMsgIllegal, "unsolicited blocks msg")
+               return
+       }
+
+       // non-blocking send: a stale or duplicate reply is dropped instead of stalling the handler
+       select {
+       case blocksMsgChan <- blocks:
+       default:
+       }
 }
 
 func (mf *msgFetcher) processHeaders(peerID string, headers []*types.BlockHeader) {
@@ -51,10 +268,10 @@ func (mf *msgFetcher) requireBlock(peerID string, height uint64) (*types.Block,
        }
 
        if ok := peer.GetBlockByHeight(height); !ok {
-               return nil, errPeerDropped
+               return nil, errSendMsg
        }
 
-       timeout := time.NewTimer(syncTimeout)
+       timeout := time.NewTimer(requireBlockTimeout)
        defer timeout.Stop()
 
        for {
@@ -76,53 +293,66 @@ func (mf *msgFetcher) requireBlock(peerID string, height uint64) (*types.Block,
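+// requireBlocks registers a per-peer reply channel, sends a get-blocks
+// request, and waits for the response or a timeout.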
 func (mf *msgFetcher) requireBlocks(peerID string, locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error) {
        peer := mf.peers.GetPeer(peerID)
        if peer == nil {
+               mf.syncPeers.delete(peerID)
                return nil, errPeerDropped
        }
 
+       receiveCh := make(chan []*types.Block, 1)
+       mf.mux.Lock()
+       mf.blocksMsgChanMap[peerID] = receiveCh
+       mf.mux.Unlock()
+
        if ok := peer.GetBlocks(locator, stopHash); !ok {
-               return nil, errPeerDropped
+               return nil, errSendMsg
        }
 
-       timeout := time.NewTimer(syncTimeout)
+       timeout := time.NewTimer(requireBlocksTimeout)
        defer timeout.Stop()
+       select {
+       case blocks := <-receiveCh:
+               return blocks, nil
+       case <-timeout.C:
+               return nil, errRequestBlocksTimeout
+       }
+}
 
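+// resetParameter clears the per-round fast-sync state and drains any stale
+// messages left in the process channels.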
+func (mf *msgFetcher) resetParameter() {
+       mf.mux.Lock()
+       mf.blocksMsgChanMap = make(map[string]chan []*types.Block)
+       mf.mux.Unlock()
+       mf.syncPeers = newFastSyncPeers()
+       mf.storage.resetParameter()
+       //empty chan
        for {
                select {
-               case msg := <-mf.blocksProcessCh:
-                       if msg.peerID != peerID {
-                               continue
-                       }
-
-                       return msg.blocks, nil
-               case <-timeout.C:
-                       return nil, errors.Wrap(errRequestTimeout, "requireBlocks")
+               case <-mf.blocksProcessCh:
+               case <-mf.headersProcessCh:
+               default:
+                       return
                }
        }
 }
 
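+// verifyBlocksMsg checks that a blocks response is non-empty, no longer than
+// requested, starts at the requested block, and is hash-continuous.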
-func (mf *msgFetcher) requireHeaders(peerID string, locator []*bc.Hash, stopHash *bc.Hash, skip uint64) ([]*types.BlockHeader, error) {
-       peer := mf.peers.GetPeer(peerID)
-       if peer == nil {
-               return nil, errPeerDropped
+func (mf *msgFetcher) verifyBlocksMsg(blocks []*types.Block, startHeader, stopHeader *types.BlockHeader) error {
+       // reject an empty response
+       if len(blocks) == 0 {
+               return errors.New("empty blocks msg")
        }
 
-       if ok := peer.GetHeaders(locator, stopHash, skip); !ok {
-               return nil, errPeerDropped
+       // more blocks than requested
+       if uint64(len(blocks)) > stopHeader.Height-startHeader.Height+1 {
+               return errors.New("blocks msg exceeds requested length")
        }
 
-       timeout := time.NewTimer(syncTimeout)
-       defer timeout.Stop()
-
-       for {
-               select {
-               case msg := <-mf.headersProcessCh:
-                       if msg.peerID != peerID {
-                               continue
-                       }
+       // verify start block
+       if blocks[0].Hash() != startHeader.Hash() {
+               return errors.New("blocks msg does not start at requested block")
+       }
 
-                       return msg.headers, nil
-               case <-timeout.C:
-                       return nil, errors.Wrap(errRequestTimeout, "requireHeaders")
+       // verify blocks continuity
+       for i := 0; i < len(blocks)-1; i++ {
+               if blocks[i].Hash() != blocks[i+1].PreviousBlockHash {
+                       return errors.New("discontinuous blocks msg")
                }
        }
+
+       return nil
 }