
add parallel fast sync support (#238)
author    yahtoo <yahtoo.ma@gmail.com>
          Thu, 11 Jul 2019 06:04:09 +0000 (14:04 +0800)
committer Paladz <yzhu101@uottawa.ca>
          Thu, 11 Jul 2019 06:04:09 +0000 (14:04 +0800)
* Add parallel block requests

* Add fast sync block storage and processing

* Add fetch blocks timeout handling

* Fix oops bug

* Add timeoutQueue function

* Fix review bug

* Optimize code format

* Modify parallelFetchHeaders function

* Modify sync peer selection logic

* Delete unused code

* Add blocksTasks struct

* Modify block parallel download mode

* Optimize code format

* Fix test case error

* Optimize code format

* Add fast sync peer error handling

* Fix test case error

* Fix review bugs

* Fix review bugs

* Fix review bugs

* Add test file

* Fix review bug

* Fix lost stopHeader error in fetch headers

* Fix locate headers bug

* Optimize code format

* Fix review bug

16 files changed:
netsync/chainmgr/block_keeper.go
netsync/chainmgr/block_keeper_test.go
netsync/chainmgr/block_process.go [new file with mode: 0644]
netsync/chainmgr/block_process_test.go [new file with mode: 0644]
netsync/chainmgr/fast_sync.go
netsync/chainmgr/fast_sync_test.go
netsync/chainmgr/handle.go
netsync/chainmgr/msg_fetcher.go
netsync/chainmgr/peers.go [new file with mode: 0644]
netsync/chainmgr/storage.go [new file with mode: 0644]
netsync/chainmgr/storage_test.go [new file with mode: 0644]
netsync/chainmgr/tool_test.go
netsync/chainmgr/tx_keeper_test.go
netsync/peers/peer.go
netsync/sync_manager.go
node/node.go

diff --git a/netsync/chainmgr/block_keeper.go b/netsync/chainmgr/block_keeper.go
index 24b4ef3..43f8521 100644
--- a/netsync/chainmgr/block_keeper.go
+++ b/netsync/chainmgr/block_keeper.go
@@ -6,7 +6,7 @@ import (
        log "github.com/sirupsen/logrus"
 
        "github.com/vapor/consensus"
-       "github.com/vapor/errors"
+       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/netsync/peers"
        "github.com/vapor/p2p/security"
        "github.com/vapor/protocol/bc"
@@ -22,15 +22,11 @@ const (
 )
 
 var (
-       syncTimeout = 30 * time.Second
-
-       errRequestTimeout = errors.New("request timeout")
-       errPeerDropped    = errors.New("Peer dropped")
+       maxNumOfBlocksPerMsg  = uint64(1000)
+       maxNumOfHeadersPerMsg = uint64(1000)
 )
 
 type FastSync interface {
-       locateBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error)
-       locateHeaders(locator []*bc.Hash, stopHash *bc.Hash, skip uint64, maxNum uint64) ([]*types.BlockHeader, error)
        process() error
        setSyncPeer(peer *peers.Peer)
 }
@@ -67,11 +63,12 @@ type blockKeeper struct {
        quit chan struct{}
 }
 
-func newBlockKeeper(chain Chain, peers *peers.PeerSet) *blockKeeper {
-       msgFetcher := newMsgFetcher(peers)
+func newBlockKeeper(chain Chain, peers *peers.PeerSet, fastSyncDB dbm.DB) *blockKeeper {
+       storage := newStorage(fastSyncDB)
+       msgFetcher := newMsgFetcher(storage, peers)
        return &blockKeeper{
                chain:      chain,
-               fastSync:   newFastSync(chain, msgFetcher, peers),
+               fastSync:   newFastSync(chain, msgFetcher, storage, peers),
                msgFetcher: msgFetcher,
                peers:      peers,
                quit:       make(chan struct{}),
@@ -79,11 +76,69 @@ func newBlockKeeper(chain Chain, peers *peers.PeerSet) *blockKeeper {
 }
 
 func (bk *blockKeeper) locateBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error) {
-       return bk.fastSync.locateBlocks(locator, stopHash)
+       headers, err := bk.locateHeaders(locator, stopHash, 0, maxNumOfBlocksPerMsg)
+       if err != nil {
+               return nil, err
+       }
+
+       blocks := []*types.Block{}
+       for _, header := range headers {
+               headerHash := header.Hash()
+               block, err := bk.chain.GetBlockByHash(&headerHash)
+               if err != nil {
+                       return nil, err
+               }
+
+               blocks = append(blocks, block)
+       }
+       return blocks, nil
 }
 
 func (bk *blockKeeper) locateHeaders(locator []*bc.Hash, stopHash *bc.Hash, skip uint64, maxNum uint64) ([]*types.BlockHeader, error) {
-       return bk.fastSync.locateHeaders(locator, stopHash, skip, maxNum)
+       startHeader, err := bk.chain.GetHeaderByHeight(0)
+       if err != nil {
+               return nil, err
+       }
+
+       for _, hash := range locator {
+               header, err := bk.chain.GetHeaderByHash(hash)
+               if err == nil && bk.chain.InMainChain(header.Hash()) {
+                       startHeader = header
+                       break
+               }
+       }
+
+       headers := make([]*types.BlockHeader, 0)
+       stopHeader, err := bk.chain.GetHeaderByHash(stopHash)
+       if err != nil {
+               return headers, nil
+       }
+
+       if !bk.chain.InMainChain(*stopHash) || stopHeader.Height < startHeader.Height {
+               return headers, nil
+       }
+
+       headers = append(headers, startHeader)
+       if stopHeader.Height == startHeader.Height {
+               return headers, nil
+       }
+
+       for num, index := uint64(0), startHeader.Height; num < maxNum-1; num++ {
+               index += skip + 1
+               if index >= stopHeader.Height {
+                       headers = append(headers, stopHeader)
+                       break
+               }
+
+               header, err := bk.chain.GetHeaderByHeight(index)
+               if err != nil {
+                       return nil, err
+               }
+
+               headers = append(headers, header)
+       }
+
+       return headers, nil
 }
 
 func (bk *blockKeeper) processBlock(peerID string, block *types.Block) {
@@ -105,13 +160,13 @@ func (bk *blockKeeper) regularBlockSync() error {
        for i <= peerHeight {
                block, err := bk.msgFetcher.requireBlock(bk.syncPeer.ID(), i)
                if err != nil {
-                       bk.peers.ErrorHandler(bk.syncPeer.ID(), security.LevelConnException, err)
+                       bk.peers.ProcessIllegal(bk.syncPeer.ID(), security.LevelConnException, err.Error())
                        return err
                }
 
                isOrphan, err := bk.chain.ProcessBlock(block)
                if err != nil {
-                       bk.peers.ErrorHandler(bk.syncPeer.ID(), security.LevelMsgIllegal, err)
+                       bk.peers.ProcessIllegal(bk.syncPeer.ID(), security.LevelMsgIllegal, err.Error())
                        return err
                }
 
@@ -137,7 +192,6 @@ func (bk *blockKeeper) checkSyncType() int {
        }
 
        bestHeight := bk.chain.BestBlockHeight()
-
        if peerIrreversibleHeight := peer.IrreversibleHeight(); peerIrreversibleHeight >= bestHeight+minGapStartFastSync {
                bk.fastSync.setSyncPeer(peer)
                return fastSyncType
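
The relocated locateHeaders above starts from the best locator match, advances skip+1 heights per step, emits at most maxNum headers, and closes the range with the stop header once a step overshoots it. A minimal standalone sketch of just that height-selection loop (hypothetical names, not code from this commit):

package main

import "fmt"

// selectHeights mirrors the stepping loop in locateHeaders: starting from
// startHeight it advances by skip+1 per iteration, emits at most maxNum
// heights, and closes the range with stopHeight once a step overshoots.
func selectHeights(startHeight, stopHeight, skip, maxNum uint64) []uint64 {
	heights := []uint64{startHeight}
	if stopHeight == startHeight {
		return heights
	}

	for num, index := uint64(0), startHeight; num < maxNum-1; num++ {
		index += skip + 1
		if index >= stopHeight {
			heights = append(heights, stopHeight)
			break
		}
		heights = append(heights, index)
	}
	return heights
}

func main() {
	// skip=10 lands on 15, 26, 37, ... and finishes on 80, matching the
	// {15, 26, 37, 48, 59, 70, 80} case in TestLocateHeaders below.
	fmt.Println(selectHeights(15, 80, 10, 10))
}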
diff --git a/netsync/chainmgr/block_keeper_test.go b/netsync/chainmgr/block_keeper_test.go
index d855ed0..4a12346 100644
--- a/netsync/chainmgr/block_keeper_test.go
+++ b/netsync/chainmgr/block_keeper_test.go
@@ -2,14 +2,18 @@ package chainmgr
 
 import (
        "encoding/json"
+       "io/ioutil"
+       "os"
        "testing"
        "time"
 
        "github.com/vapor/consensus"
+       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/errors"
        msgs "github.com/vapor/netsync/messages"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
+       "github.com/vapor/test/mock"
        "github.com/vapor/testutil"
 )
 
@@ -55,11 +59,21 @@ func TestRegularBlockSync(t *testing.T) {
                        err:         nil,
                },
        }
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
+       }
+       testDBA := dbm.NewDB("testdba", "leveldb", tmp)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
+       defer func() {
+               testDBA.Close()
+               testDBB.Close()
+               os.RemoveAll(tmp)
+       }()
 
        for i, c := range cases {
-               syncTimeout = c.syncTimeout
-               a := mockSync(c.aBlocks, nil)
-               b := mockSync(c.bBlocks, nil)
+               a := mockSync(c.aBlocks, nil, testDBA)
+               b := mockSync(c.bBlocks, nil, testDBB)
                netWork := NewNetWork()
                netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
                netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
@@ -91,9 +105,21 @@ func TestRegularBlockSync(t *testing.T) {
 }
 
 func TestRequireBlock(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
+       }
+       testDBA := dbm.NewDB("testdba", "leveldb", tmp)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
+       defer func() {
+               testDBB.Close()
+               testDBA.Close()
+               os.RemoveAll(tmp)
+       }()
+
        blocks := mockBlocks(nil, 5)
-       a := mockSync(blocks[:1], nil)
-       b := mockSync(blocks[:5], nil)
+       a := mockSync(blocks[:1], nil, testDBA)
+       b := mockSync(blocks[:5], nil, testDBB)
        netWork := NewNetWork()
        netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
        netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
@@ -129,8 +155,12 @@ func TestRequireBlock(t *testing.T) {
                },
        }
 
+       defer func() {
+               requireBlockTimeout = 20 * time.Second
+       }()
+
        for i, c := range cases {
-               syncTimeout = c.syncTimeout
+               requireBlockTimeout = c.syncTimeout
                got, err := c.testNode.blockKeeper.msgFetcher.requireBlock(c.testNode.blockKeeper.syncPeer.ID(), c.requireHeight)
                if !testutil.DeepEqual(got, c.want) {
                        t.Errorf("case %d: got %v want %v", i, got, c.want)
@@ -142,6 +172,19 @@ func TestRequireBlock(t *testing.T) {
 }
 
 func TestSendMerkleBlock(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
+       }
+
+       testDBA := dbm.NewDB("testdba", "leveldb", tmp)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
+       defer func() {
+               testDBA.Close()
+               testDBB.Close()
+               os.RemoveAll(tmp)
+       }()
+
        cases := []struct {
                txCount        int
                relatedTxIndex []int
@@ -179,7 +222,7 @@ func TestSendMerkleBlock(t *testing.T) {
                        t.Fatal(err)
                }
 
-               spvNode := mockSync(blocks, nil)
+               spvNode := mockSync(blocks, nil, testDBA)
                blockHash := targetBlock.Hash()
                var statusResult *bc.TransactionStatus
                if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
@@ -190,7 +233,7 @@ func TestSendMerkleBlock(t *testing.T) {
                        t.Fatal(err)
                }
 
-               fullNode := mockSync(blocks, nil)
+               fullNode := mockSync(blocks, nil, testDBB)
                netWork := NewNetWork()
                netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
                netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
@@ -257,3 +300,153 @@ func TestSendMerkleBlock(t *testing.T) {
                }
        }
 }
+
+func TestLocateBlocks(t *testing.T) {
+       maxNumOfBlocksPerMsg = 5
+       blocks := mockBlocks(nil, 100)
+       cases := []struct {
+               locator    []uint64
+               stopHash   bc.Hash
+               wantHeight []uint64
+       }{
+               {
+                       locator:    []uint64{20},
+                       stopHash:   blocks[100].Hash(),
+                       wantHeight: []uint64{20, 21, 22, 23, 24},
+               },
+       }
+
+       mockChain := mock.NewChain(nil)
+       bk := &blockKeeper{chain: mockChain}
+       for _, block := range blocks {
+               mockChain.SetBlockByHeight(block.Height, block)
+       }
+
+       for i, c := range cases {
+               locator := []*bc.Hash{}
+               for _, i := range c.locator {
+                       hash := blocks[i].Hash()
+                       locator = append(locator, &hash)
+               }
+
+               want := []*types.Block{}
+               for _, i := range c.wantHeight {
+                       want = append(want, blocks[i])
+               }
+
+               got, _ := bk.locateBlocks(locator, &c.stopHash)
+               if !testutil.DeepEqual(got, want) {
+                       t.Errorf("case %d: got %v want %v", i, got, want)
+               }
+       }
+}
+
+func TestLocateHeaders(t *testing.T) {
+       defer func() {
+               maxNumOfHeadersPerMsg = 1000
+       }()
+       maxNumOfHeadersPerMsg = 10
+       blocks := mockBlocks(nil, 150)
+       blocksHash := []bc.Hash{}
+       for _, block := range blocks {
+               blocksHash = append(blocksHash, block.Hash())
+       }
+
+       cases := []struct {
+               chainHeight uint64
+               locator     []uint64
+               stopHash    *bc.Hash
+               skip        uint64
+               wantHeight  []uint64
+               err         bool
+       }{
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{90},
+                       stopHash:    &blocksHash[100],
+                       skip:        0,
+                       wantHeight:  []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    &blocksHash[24],
+                       skip:        0,
+                       wantHeight:  []uint64{20, 21, 22, 23, 24},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    &blocksHash[20],
+                       wantHeight:  []uint64{20},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    &blocksHash[120],
+                       wantHeight:  []uint64{},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{120, 70},
+                       stopHash:    &blocksHash[78],
+                       wantHeight:  []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{15},
+                       stopHash:    &blocksHash[10],
+                       skip:        10,
+                       wantHeight:  []uint64{},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{15},
+                       stopHash:    &blocksHash[80],
+                       skip:        10,
+                       wantHeight:  []uint64{15, 26, 37, 48, 59, 70, 80},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{0},
+                       stopHash:    &blocksHash[100],
+                       skip:        9,
+                       wantHeight:  []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90},
+                       err:         false,
+               },
+       }
+
+       for i, c := range cases {
+               mockChain := mock.NewChain(nil)
+               bk := &blockKeeper{chain: mockChain}
+               for i := uint64(0); i <= c.chainHeight; i++ {
+                       mockChain.SetBlockByHeight(i, blocks[i])
+               }
+
+               locator := []*bc.Hash{}
+               for _, i := range c.locator {
+                       hash := blocks[i].Hash()
+                       locator = append(locator, &hash)
+               }
+
+               want := []*types.BlockHeader{}
+               for _, i := range c.wantHeight {
+                       want = append(want, &blocks[i].BlockHeader)
+               }
+
+               got, err := bk.locateHeaders(locator, c.stopHash, c.skip, maxNumOfHeadersPerMsg)
+               if err != nil != c.err {
+                       t.Errorf("case %d: got %v want err = %v", i, err, c.err)
+               }
+               if !testutil.DeepEqual(got, want) {
+                       t.Errorf("case %d: got %v want %v", i, got, want)
+               }
+       }
+}
diff --git a/netsync/chainmgr/block_process.go b/netsync/chainmgr/block_process.go
new file mode 100644
index 0000000..4caf9f7
--- /dev/null
+++ b/netsync/chainmgr/block_process.go
@@ -0,0 +1,64 @@
+package chainmgr
+
+import (
+       "sync"
+
+       log "github.com/sirupsen/logrus"
+
+       "github.com/vapor/netsync/peers"
+       "github.com/vapor/p2p/security"
+)
+
+type BlockProcessor interface {
+       process(chan struct{}, chan struct{}, *sync.WaitGroup)
+}
+
+type blockProcessor struct {
+       chain   Chain
+       storage Storage
+       peers   *peers.PeerSet
+}
+
+func newBlockProcessor(chain Chain, storage Storage, peers *peers.PeerSet) *blockProcessor {
+       return &blockProcessor{
+               chain:   chain,
+               peers:   peers,
+               storage: storage,
+       }
+}
+
+func (bp *blockProcessor) insert(blockStorage *blockStorage) error {
+       isOrphan, err := bp.chain.ProcessBlock(blockStorage.block)
+       if err != nil {
+               bp.peers.ProcessIllegal(blockStorage.peerID, security.LevelMsgIllegal, err.Error())
+       } else if isOrphan {
+               // an orphan block is reported as illegal even though
+               // ProcessBlock returned no error
+               bp.peers.ProcessIllegal(blockStorage.peerID, security.LevelMsgIllegal, "orphan block")
+       }
+       return err
+}
+
+func (bp *blockProcessor) process(downloadNotifyCh chan struct{}, ProcessStop chan struct{}, wg *sync.WaitGroup) {
+       defer func() {
+               close(ProcessStop)
+               wg.Done()
+       }()
+
+       for {
+               for {
+                       nextHeight := bp.chain.BestBlockHeight() + 1
+                       block, err := bp.storage.readBlock(nextHeight)
+                       if err != nil {
+                               break
+                       }
+
+                       if err := bp.insert(block); err != nil {
+                               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("failed on process block")
+                               return
+                       }
+
+                       bp.storage.deleteBlock(nextHeight)
+               }
+
+               if _, ok := <-downloadNotifyCh; !ok {
+                       return
+               }
+       }
+}
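
The process loop above is a producer/consumer handoff: parallelFetchBlocks signals on downloadNotifyCh (buffered, non-blocking send) whenever new blocks reach storage, and the processor drains storage from BestBlockHeight()+1 upward until it hits a gap, then parks on the channel; a closed channel means downloading has finished. A reduced sketch of the same pattern, with a plain map standing in for storage (illustrative only):

package main

import (
	"fmt"
	"sync"
)

// drainLoop consumes items in strict height order, sleeping on notify
// between bursts; a closed notify channel ends the loop, just as the
// downloader closes downloadNotifyCh when it finishes.
func drainLoop(store map[uint64]string, notify <-chan struct{}, wg *sync.WaitGroup) {
	defer wg.Done()
	next := uint64(1)
	for {
		for {
			block, ok := store[next]
			if !ok {
				break // gap: wait for the downloader to fill it
			}
			fmt.Println("processed", next, block)
			delete(store, next)
			next++
		}
		if _, ok := <-notify; !ok {
			return
		}
	}
}

func main() {
	store := map[uint64]string{1: "blockA", 2: "blockB"}
	notify := make(chan struct{}, 1)
	var wg sync.WaitGroup
	wg.Add(1)
	go drainLoop(store, notify, &wg)
	notify <- struct{}{} // a download finished
	close(notify)        // downloading is done
	wg.Wait()
}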
diff --git a/netsync/chainmgr/block_process_test.go b/netsync/chainmgr/block_process_test.go
new file mode 100644
index 0000000..4f1e024
--- /dev/null
+++ b/netsync/chainmgr/block_process_test.go
@@ -0,0 +1,50 @@
+package chainmgr
+
+import (
+       "io/ioutil"
+       "os"
+       "sync"
+       "testing"
+       "time"
+
+       dbm "github.com/vapor/database/leveldb"
+       "github.com/vapor/test/mock"
+)
+
+func TestBlockProcess(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(tmp)
+
+       testDB := dbm.NewDB("testdb", "leveldb", tmp)
+       defer testDB.Close()
+
+       s := newStorage(testDB)
+       mockChain := mock.NewChain(nil)
+       blockNum := 200
+       blocks := mockBlocks(nil, uint64(blockNum))
+       for i := 0; i <= blockNum/2; i++ {
+               mockChain.SetBlockByHeight(uint64(i), blocks[i])
+               mockChain.SetBestBlockHeader(&blocks[i].BlockHeader)
+       }
+
+       if err := s.writeBlocks("testPeer", blocks); err != nil {
+               t.Fatal(err)
+       }
+
+       bp := newBlockProcessor(mockChain, s, nil)
+       downloadNotifyCh := make(chan struct{}, 1)
+       ProcessStopCh := make(chan struct{})
+       var wg sync.WaitGroup
+       go func() {
+               time.Sleep(1 * time.Second)
+               close(downloadNotifyCh)
+       }()
+       wg.Add(1)
+       bp.process(downloadNotifyCh, ProcessStopCh, &wg)
+       if bp.chain.BestBlockHeight() != uint64(blockNum) {
+               t.Fatalf("TestBlockProcess fail: got %d want %d", bp.chain.BestBlockHeight(), blockNum)
+       }
+}
diff --git a/netsync/chainmgr/fast_sync.go b/netsync/chainmgr/fast_sync.go
index 6a52223..06c67c8 100644
--- a/netsync/chainmgr/fast_sync.go
+++ b/netsync/chainmgr/fast_sync.go
@@ -1,55 +1,48 @@
 package chainmgr
 
 import (
+       "sync"
+
        log "github.com/sirupsen/logrus"
 
        "github.com/vapor/errors"
        "github.com/vapor/netsync/peers"
-       "github.com/vapor/p2p/security"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
 )
 
 var (
-       maxBlocksPerMsg      = uint64(1000)
-       maxHeadersPerMsg     = uint64(1000)
-       fastSyncPivotGap     = uint64(64)
-       minGapStartFastSync  = uint64(128)
-       maxFastSyncBlocksNum = uint64(10000)
+       maxNumOfSkeletonPerSync = uint64(10)
+       numOfBlocksSkeletonGap  = maxNumOfBlocksPerMsg
+       maxNumOfBlocksPerSync   = numOfBlocksSkeletonGap * maxNumOfSkeletonPerSync
+       fastSyncPivotGap        = uint64(64)
+       minGapStartFastSync     = uint64(128)
 
-       errOrphanBlock = errors.New("fast sync block is orphan")
+       errNoSyncPeer = errors.New("can't find sync peer")
 )
 
-type MsgFetcher interface {
-       requireBlock(peerID string, height uint64) (*types.Block, error)
-       requireBlocks(peerID string, locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error)
-}
-
 type fastSync struct {
-       chain      Chain
-       msgFetcher MsgFetcher
-       peers      *peers.PeerSet
-       syncPeer   *peers.Peer
-       stopHeader *types.BlockHeader
-       length     uint64
-
-       quite chan struct{}
+       chain          Chain
+       msgFetcher     MsgFetcher
+       blockProcessor BlockProcessor
+       peers          *peers.PeerSet
+       mainSyncPeer   *peers.Peer
 }
 
-func newFastSync(chain Chain, msgFether MsgFetcher, peers *peers.PeerSet) *fastSync {
+func newFastSync(chain Chain, msgFetcher MsgFetcher, storage Storage, peers *peers.PeerSet) *fastSync {
        return &fastSync{
-               chain:      chain,
-               msgFetcher: msgFether,
-               peers:      peers,
-               quite:      make(chan struct{}),
+               chain:          chain,
+               msgFetcher:     msgFetcher,
+               blockProcessor: newBlockProcessor(chain, storage, peers),
+               peers:          peers,
        }
 }
 
 func (fs *fastSync) blockLocator() []*bc.Hash {
        header := fs.chain.BestBlockHeader()
        locator := []*bc.Hash{}
-
        step := uint64(1)
+
        for header != nil {
                headerHash := header.Hash()
                locator = append(locator, &headerHash)
@@ -75,118 +68,88 @@ func (fs *fastSync) blockLocator() []*bc.Hash {
        return locator
 }
 
-func (fs *fastSync) process() error {
-       if err := fs.findFastSyncRange(); err != nil {
-               return err
-       }
-
-       stopHash := fs.stopHeader.Hash()
-       for fs.chain.BestBlockHeight() < fs.stopHeader.Height {
-               blocks, err := fs.msgFetcher.requireBlocks(fs.syncPeer.ID(), fs.blockLocator(), &stopHash)
-               if err != nil {
-                       fs.peers.ErrorHandler(fs.syncPeer.ID(), security.LevelConnException, err)
-                       return err
-               }
-
-               if err := fs.verifyBlocks(blocks); err != nil {
-                       fs.peers.ErrorHandler(fs.syncPeer.ID(), security.LevelMsgIllegal, err)
-                       return err
-               }
-       }
-
-       log.WithFields(log.Fields{"module": logModule, "height": fs.chain.BestBlockHeight()}).Info("fast sync success")
-       return nil
-}
-
-func (fs *fastSync) findFastSyncRange() error {
-       bestHeight := fs.chain.BestBlockHeight()
-       fs.length = fs.syncPeer.IrreversibleHeight() - fastSyncPivotGap - bestHeight
-       if fs.length > maxFastSyncBlocksNum {
-               fs.length = maxFastSyncBlocksNum
+// createFetchBlocksTasks fetches the skeleton and assigns download tasks based on it.
+func (fs *fastSync) createFetchBlocksTasks(stopBlock *types.Block) ([]*fetchBlocksWork, error) {
+       // Find peers that meet the height requirements.
+       peers := fs.peers.GetPeersByHeight(stopBlock.Height + fastSyncPivotGap)
+       if len(peers) == 0 {
+               return nil, errNoSyncPeer
        }
 
-       stopBlock, err := fs.msgFetcher.requireBlock(fs.syncPeer.ID(), bestHeight+fs.length)
-       if err != nil {
-               return err
+       // fetch the skeleton from peers in parallel.
+       stopHash := stopBlock.Hash()
+       skeletonMap := fs.msgFetcher.parallelFetchHeaders(peers, fs.blockLocator(), &stopHash, numOfBlocksSkeletonGap-1)
+       if len(skeletonMap) == 0 {
+               return nil, errors.New("No skeleton found")
        }
 
-       fs.stopHeader = &stopBlock.BlockHeader
-       return nil
-}
-
-func (fs *fastSync) locateBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error) {
-       headers, err := fs.locateHeaders(locator, stopHash, 0, maxBlocksPerMsg)
-       if err != nil {
-               return nil, err
+       mainSkeleton, ok := skeletonMap[fs.mainSyncPeer.ID()]
+       if !ok {
+               return nil, errors.New("No main skeleton found")
        }
 
-       blocks := []*types.Block{}
-       for _, header := range headers {
-               headerHash := header.Hash()
-               block, err := fs.chain.GetBlockByHash(&headerHash)
-               if err != nil {
-                       return nil, err
+       // collect peers that match the skeleton of the primary sync peer
+       fs.msgFetcher.addSyncPeer(fs.mainSyncPeer.ID())
+       delete(skeletonMap, fs.mainSyncPeer.ID())
+       for peerID, skeleton := range skeletonMap {
+               if len(skeleton) != len(mainSkeleton) {
+                       log.WithFields(log.Fields{"module": logModule, "main skeleton": len(mainSkeleton), "got skeleton": len(skeleton)}).Warn("different skeleton length")
+                       continue
                }
 
-               blocks = append(blocks, block)
+               isMatch := true
+               for i, header := range skeleton {
+                       if header.Hash() != mainSkeleton[i].Hash() {
+                               log.WithFields(log.Fields{"module": logModule, "header index": i, "main skeleton": mainSkeleton[i].Hash(), "got skeleton": header.Hash()}).Warn("different skeleton hash")
+                               isMatch = false
+                               break
+                       }
+               }
+               if isMatch {
+                       fs.msgFetcher.addSyncPeer(peerID)
+               }
        }
-       return blocks, nil
-}
 
-func (fs *fastSync) locateHeaders(locator []*bc.Hash, stopHash *bc.Hash, skip uint64, maxNum uint64) ([]*types.BlockHeader, error) {
-       startHeader, err := fs.chain.GetHeaderByHeight(0)
-       if err != nil {
-               return nil, err
+       blockFetchTasks := make([]*fetchBlocksWork, 0)
+       // create download tasks
+       for i := 0; i < len(mainSkeleton)-1; i++ {
+               blockFetchTasks = append(blockFetchTasks, &fetchBlocksWork{startHeader: mainSkeleton[i], stopHeader: mainSkeleton[i+1]})
        }
 
-       for _, hash := range locator {
-               header, err := fs.chain.GetHeaderByHash(hash)
-               if err == nil && fs.chain.InMainChain(header.Hash()) {
-                       startHeader = header
-                       break
-               }
-       }
+       return blockFetchTasks, nil
+}
 
-       headers := make([]*types.BlockHeader, 0)
-       stopHeader, err := fs.chain.GetHeaderByHash(stopHash)
+func (fs *fastSync) process() error {
+       stopBlock, err := fs.findSyncRange()
        if err != nil {
-               return headers, nil
+               return err
        }
 
-       if !fs.chain.InMainChain(*stopHash) {
-               return headers, nil
+       tasks, err := fs.createFetchBlocksTasks(stopBlock)
+       if err != nil {
+               return err
        }
 
-       num := uint64(0)
-       for i := startHeader.Height; i <= stopHeader.Height && num < maxNum; i += skip + 1 {
-               header, err := fs.chain.GetHeaderByHeight(i)
-               if err != nil {
-                       return nil, err
-               }
+       downloadNotifyCh := make(chan struct{}, 1)
+       processStopCh := make(chan struct{})
+       var wg sync.WaitGroup
+       wg.Add(2)
+       go fs.msgFetcher.parallelFetchBlocks(tasks, downloadNotifyCh, processStopCh, &wg)
+       go fs.blockProcessor.process(downloadNotifyCh, processStopCh, &wg)
+       wg.Wait()
+       fs.msgFetcher.resetParameter()
+       log.WithFields(log.Fields{"module": logModule, "height": fs.chain.BestBlockHeight()}).Info("fast sync complete")
+       return nil
+}
 
-               headers = append(headers, header)
-               num++
+// findSyncRange finds the start and end of this sync.
+// The sync length cannot be greater than maxNumOfBlocksPerSync.
+func (fs *fastSync) findSyncRange() (*types.Block, error) {
+       bestHeight := fs.chain.BestBlockHeight()
+       length := fs.mainSyncPeer.IrreversibleHeight() - fastSyncPivotGap - bestHeight
+       if length > maxNumOfBlocksPerSync {
+               length = maxNumOfBlocksPerSync
        }
 
-       return headers, nil
+       return fs.msgFetcher.requireBlock(fs.mainSyncPeer.ID(), bestHeight+length)
 }
 
 func (fs *fastSync) setSyncPeer(peer *peers.Peer) {
-       fs.syncPeer = peer
-}
-
-func (fs *fastSync) verifyBlocks(blocks []*types.Block) error {
-       for _, block := range blocks {
-               isOrphan, err := fs.chain.ProcessBlock(block)
-               if err != nil {
-                       return err
-               }
-
-               if isOrphan {
-                       log.WithFields(log.Fields{"module": logModule, "height": block.Height, "hash": block.Hash()}).Error("fast sync block is orphan")
-                       return errOrphanBlock
-               }
-       }
-
-       return nil
+       fs.mainSyncPeer = peer
 }
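
createFetchBlocksTasks above turns the agreed skeleton into independent download ranges: consecutive skeleton headers, numOfBlocksSkeletonGap blocks apart, become start/stop pairs that workers fetch concurrently. A small sketch of that partitioning (stand-in types, not the commit's code):

package main

import "fmt"

// header stands in for types.BlockHeader; only Height matters here.
type header struct{ Height uint64 }

// fetchWork pairs consecutive skeleton headers, as createFetchBlocksTasks
// does; each pair later becomes one worker's download range.
type fetchWork struct{ start, stop header }

func makeTasks(skeleton []header) []fetchWork {
	if len(skeleton) < 2 {
		return nil
	}

	tasks := make([]fetchWork, 0, len(skeleton)-1)
	for i := 0; i < len(skeleton)-1; i++ {
		tasks = append(tasks, fetchWork{start: skeleton[i], stop: skeleton[i+1]})
	}
	return tasks
}

func main() {
	// Skeleton entries 1000 blocks apart yield ranges [0,1000], [1000,2000],
	// [2000,3000] that can be fetched from different peers in parallel.
	for _, t := range makeTasks([]header{{0}, {1000}, {2000}, {3000}}) {
		fmt.Printf("fetch blocks %d..%d\n", t.start.Height, t.stop.Height)
	}
}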
diff --git a/netsync/chainmgr/fast_sync_test.go b/netsync/chainmgr/fast_sync_test.go
index 0ff3701..efd5f5b 100644
--- a/netsync/chainmgr/fast_sync_test.go
+++ b/netsync/chainmgr/fast_sync_test.go
@@ -1,10 +1,13 @@
 package chainmgr
 
 import (
+       "io/ioutil"
+       "os"
        "testing"
        "time"
 
        "github.com/vapor/consensus"
+       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/errors"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
@@ -65,9 +68,33 @@ func TestBlockLocator(t *testing.T) {
 }
 
 func TestFastBlockSync(t *testing.T) {
-       maxBlocksPerMsg = 10
-       maxHeadersPerMsg = 10
-       maxFastSyncBlocksNum = 200
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
+       }
+       testDBA := dbm.NewDB("testdba", "leveldb", tmp)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
+       defer func() {
+               testDBA.Close()
+               testDBB.Close()
+               os.RemoveAll(tmp)
+       }()
+
+       maxNumOfSkeletonPerSync = 10
+       numOfBlocksSkeletonGap = 10
+       maxNumOfBlocksPerSync = maxNumOfSkeletonPerSync * maxNumOfSkeletonPerSync
+       fastSyncPivotGap = uint64(5)
+       minGapStartFastSync = uint64(6)
+
+       defer func() {
+               maxNumOfSkeletonPerSync = 10
+               numOfBlocksSkeletonGap = maxNumOfBlocksPerMsg
+               maxNumOfBlocksPerSync = maxNumOfSkeletonPerSync * maxNumOfSkeletonPerSync
+               fastSyncPivotGap = uint64(64)
+               minGapStartFastSync = uint64(128)
+
+       }()
+
        baseChain := mockBlocks(nil, 300)
 
        cases := []struct {
@@ -81,22 +108,42 @@ func TestFastBlockSync(t *testing.T) {
                        syncTimeout: 30 * time.Second,
                        aBlocks:     baseChain[:50],
                        bBlocks:     baseChain[:301],
-                       want:        baseChain[:237],
+                       want:        baseChain[:150],
                        err:         nil,
                },
                {
                        syncTimeout: 30 * time.Second,
                        aBlocks:     baseChain[:2],
                        bBlocks:     baseChain[:300],
-                       want:        baseChain[:202],
+                       want:        baseChain[:102],
+                       err:         nil,
+               },
+               {
+                       syncTimeout: 30 * time.Second,
+                       aBlocks:     baseChain[:2],
+                       bBlocks:     baseChain[:53],
+                       want:        baseChain[:48],
+                       err:         nil,
+               },
+               {
+                       syncTimeout: 30 * time.Second,
+                       aBlocks:     baseChain[:2],
+                       bBlocks:     baseChain[:53],
+                       want:        baseChain[:48],
+                       err:         nil,
+               },
+               {
+                       syncTimeout: 30 * time.Second,
+                       aBlocks:     baseChain[:2],
+                       bBlocks:     baseChain[:10],
+                       want:        baseChain[:5],
                        err:         nil,
                },
        }
 
        for i, c := range cases {
-               syncTimeout = c.syncTimeout
-               a := mockSync(c.aBlocks, nil)
-               b := mockSync(c.bBlocks, nil)
+               a := mockSync(c.aBlocks, nil, testDBA)
+               b := mockSync(c.bBlocks, nil, testDBB)
                netWork := NewNetWork()
                netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode|consensus.SFFastSync)
                netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode|consensus.SFFastSync)
@@ -126,150 +173,3 @@ func TestFastBlockSync(t *testing.T) {
                }
        }
 }
-
-func TestLocateBlocks(t *testing.T) {
-       maxBlocksPerMsg = 5
-       blocks := mockBlocks(nil, 100)
-       cases := []struct {
-               locator    []uint64
-               stopHash   bc.Hash
-               wantHeight []uint64
-       }{
-               {
-                       locator:    []uint64{20},
-                       stopHash:   blocks[100].Hash(),
-                       wantHeight: []uint64{20, 21, 22, 23, 24},
-               },
-       }
-
-       mockChain := mock.NewChain(nil)
-       fs := &fastSync{chain: mockChain}
-       for _, block := range blocks {
-               mockChain.SetBlockByHeight(block.Height, block)
-       }
-
-       for i, c := range cases {
-               locator := []*bc.Hash{}
-               for _, i := range c.locator {
-                       hash := blocks[i].Hash()
-                       locator = append(locator, &hash)
-               }
-
-               want := []*types.Block{}
-               for _, i := range c.wantHeight {
-                       want = append(want, blocks[i])
-               }
-
-               got, _ := fs.locateBlocks(locator, &c.stopHash)
-               if !testutil.DeepEqual(got, want) {
-                       t.Errorf("case %d: got %v want %v", i, got, want)
-               }
-       }
-}
-
-func TestLocateHeaders(t *testing.T) {
-       maxHeadersPerMsg = 10
-       blocks := mockBlocks(nil, 150)
-       blocksHash := []bc.Hash{}
-       for _, block := range blocks {
-               blocksHash = append(blocksHash, block.Hash())
-       }
-
-       cases := []struct {
-               chainHeight uint64
-               locator     []uint64
-               stopHash    *bc.Hash
-               skip        uint64
-               wantHeight  []uint64
-               err         bool
-       }{
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{90},
-                       stopHash:    &blocksHash[100],
-                       skip:        0,
-                       wantHeight:  []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{20},
-                       stopHash:    &blocksHash[24],
-                       skip:        0,
-                       wantHeight:  []uint64{20, 21, 22, 23, 24},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{20},
-                       stopHash:    &blocksHash[20],
-                       wantHeight:  []uint64{20},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{20},
-                       stopHash:    &blocksHash[120],
-                       wantHeight:  []uint64{},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{120, 70},
-                       stopHash:    &blocksHash[78],
-                       wantHeight:  []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{15},
-                       stopHash:    &blocksHash[10],
-                       skip:        10,
-                       wantHeight:  []uint64{},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{15},
-                       stopHash:    &blocksHash[80],
-                       skip:        10,
-                       wantHeight:  []uint64{15, 26, 37, 48, 59, 70},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{0},
-                       stopHash:    &blocksHash[100],
-                       skip:        9,
-                       wantHeight:  []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90},
-                       err:         false,
-               },
-       }
-
-       for i, c := range cases {
-               mockChain := mock.NewChain(nil)
-               fs := &fastSync{chain: mockChain}
-               for i := uint64(0); i <= c.chainHeight; i++ {
-                       mockChain.SetBlockByHeight(i, blocks[i])
-               }
-
-               locator := []*bc.Hash{}
-               for _, i := range c.locator {
-                       hash := blocks[i].Hash()
-                       locator = append(locator, &hash)
-               }
-
-               want := []*types.BlockHeader{}
-               for _, i := range c.wantHeight {
-                       want = append(want, &blocks[i].BlockHeader)
-               }
-
-               got, err := fs.locateHeaders(locator, c.stopHash, c.skip, maxHeadersPerMsg)
-               if err != nil != c.err {
-                       t.Errorf("case %d: got %v want err = %v", i, err, c.err)
-               }
-               if !testutil.DeepEqual(got, want) {
-                       t.Errorf("case %d: got %v want %v", i, got, want)
-               }
-       }
-}
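
The want slices in TestFastBlockSync follow from findSyncRange: the stop height is bestHeight plus min(peerIrreversibleHeight - fastSyncPivotGap - bestHeight, maxNumOfBlocksPerSync), with the test overrides fastSyncPivotGap=5 and maxNumOfBlocksPerSync=100. A quick check of that arithmetic (assuming the mock peer's irreversible height equals its best height):

package main

import "fmt"

// syncStop reproduces the findSyncRange arithmetic behind the test's want
// slices: sync at most maxPerSync blocks, and stay pivotGap blocks below
// the peer's irreversible height.
func syncStop(best, irreversible, pivotGap, maxPerSync uint64) uint64 {
	length := irreversible - pivotGap - best
	if length > maxPerSync {
		length = maxPerSync
	}
	return best + length
}

func main() {
	// case 1: aBlocks[:50] vs bBlocks[:301] => stop 149, so want baseChain[:150]
	fmt.Println(syncStop(49, 300, 5, 100))
	// case 3: aBlocks[:2] vs bBlocks[:53] => stop 47, so want baseChain[:48]
	fmt.Println(syncStop(1, 52, 5, 100))
	// case 5: aBlocks[:2] vs bBlocks[:10] => stop 4, so want baseChain[:5]
	fmt.Println(syncStop(1, 9, 5, 100))
}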
diff --git a/netsync/chainmgr/handle.go b/netsync/chainmgr/handle.go
index 6e37389..81baace 100644
--- a/netsync/chainmgr/handle.go
+++ b/netsync/chainmgr/handle.go
@@ -8,6 +8,7 @@ import (
 
        cfg "github.com/vapor/config"
        "github.com/vapor/consensus"
+       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/event"
        msgs "github.com/vapor/netsync/messages"
        "github.com/vapor/netsync/peers"
@@ -68,12 +69,12 @@ type Manager struct {
 }
 
 //NewChainManager create a chain sync manager.
-func NewManager(config *cfg.Config, sw Switch, chain Chain, mempool Mempool, dispatcher *event.Dispatcher, peers *peers.PeerSet) (*Manager, error) {
+func NewManager(config *cfg.Config, sw Switch, chain Chain, mempool Mempool, dispatcher *event.Dispatcher, peers *peers.PeerSet, fastSyncDB dbm.DB) (*Manager, error) {
        manager := &Manager{
                sw:              sw,
                mempool:         mempool,
                chain:           chain,
-               blockKeeper:     newBlockKeeper(chain, peers),
+               blockKeeper:     newBlockKeeper(chain, peers, fastSyncDB),
                peers:           peers,
                txSyncCh:        make(chan *txSyncMsg),
                quit:            make(chan struct{}),
@@ -182,7 +183,7 @@ func (m *Manager) handleGetBlocksMsg(peer *peers.Peer, msg *msgs.GetBlocksMessag
 }
 
 func (m *Manager) handleGetHeadersMsg(peer *peers.Peer, msg *msgs.GetHeadersMessage) {
-       headers, err := m.blockKeeper.locateHeaders(msg.GetBlockLocator(), msg.GetStopHash(), msg.GetSkip(), maxHeadersPerMsg)
+       headers, err := m.blockKeeper.locateHeaders(msg.GetBlockLocator(), msg.GetStopHash(), msg.GetSkip(), maxNumOfHeadersPerMsg)
        if err != nil || len(headers) == 0 {
                log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleGetHeadersMsg locateHeaders")
                return
diff --git a/netsync/chainmgr/msg_fetcher.go b/netsync/chainmgr/msg_fetcher.go
index f635667..2cfda5e 100644
--- a/netsync/chainmgr/msg_fetcher.go
+++ b/netsync/chainmgr/msg_fetcher.go
 package chainmgr
 
 import (
+       "sync"
        "time"
 
+       log "github.com/sirupsen/logrus"
+
        "github.com/vapor/errors"
        "github.com/vapor/netsync/peers"
+       "github.com/vapor/p2p/security"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
 )
 
 const (
-       blockProcessChSize   = 1024
-       blocksProcessChSize  = 128
-       headersProcessChSize = 1024
+       maxNumOfParallelFetchBlocks = 7
+       blockProcessChSize          = 1024
+       blocksProcessChSize         = 128
+       headersProcessChSize        = 1024
+       maxNumOfFastSyncPeers       = 128
 )
 
-type msgFetcher struct {
-       peers *peers.PeerSet
+var (
+       requireBlockTimeout   = 20 * time.Second
+       requireHeadersTimeout = 30 * time.Second
+       requireBlocksTimeout  = 50 * time.Second
+
+       errRequestBlocksTimeout = errors.New("request blocks timeout")
+       errRequestTimeout       = errors.New("request timeout")
+       errPeerDropped          = errors.New("Peer dropped")
+       errSendMsg              = errors.New("send message error")
+)
+
+type MsgFetcher interface {
+       resetParameter()
+       addSyncPeer(peerID string)
+       requireBlock(peerID string, height uint64) (*types.Block, error)
+       parallelFetchBlocks(work []*fetchBlocksWork, downloadNotifyCh chan struct{}, ProcessStopCh chan struct{}, wg *sync.WaitGroup)
+       parallelFetchHeaders(peers []*peers.Peer, locator []*bc.Hash, stopHash *bc.Hash, skip uint64) map[string][]*types.BlockHeader
+}
 
+type fetchBlocksWork struct {
+       startHeader, stopHeader *types.BlockHeader
+}
+
+type fetchBlocksResult struct {
+       startHeight, stopHeight uint64
+       err                     error
+}
+
+type msgFetcher struct {
+       storage          Storage
+       syncPeers        *fastSyncPeers
+       peers            *peers.PeerSet
        blockProcessCh   chan *blockMsg
        blocksProcessCh  chan *blocksMsg
        headersProcessCh chan *headersMsg
+       blocksMsgChanMap map[string]chan []*types.Block
+       mux              sync.RWMutex
 }
 
-func newMsgFetcher(peers *peers.PeerSet) *msgFetcher {
+func newMsgFetcher(storage Storage, peers *peers.PeerSet) *msgFetcher {
        return &msgFetcher{
+               storage:          storage,
+               syncPeers:        newFastSyncPeers(),
                peers:            peers,
                blockProcessCh:   make(chan *blockMsg, blockProcessChSize),
                blocksProcessCh:  make(chan *blocksMsg, blocksProcessChSize),
                headersProcessCh: make(chan *headersMsg, headersProcessChSize),
+               blocksMsgChanMap: make(map[string]chan []*types.Block),
+       }
+}
+
+func (mf *msgFetcher) addSyncPeer(peerID string) {
+       mf.syncPeers.add(peerID)
+}
+
+func (mf *msgFetcher) collectResultLoop(peerCh chan string, quit chan struct{}, resultCh chan *fetchBlocksResult, workerCloseCh chan struct{}, workSize int) {
+       defer close(workerCloseCh)
+       // collect fetch results
+       for resultCount := 0; resultCount < workSize && mf.syncPeers.size() > 0; resultCount++ {
+               select {
+               case result := <-resultCh:
+                       if result.err != nil {
+                               log.WithFields(log.Fields{"module": logModule, "startHeight": result.startHeight, "stopHeight": result.stopHeight, "err": result.err}).Error("failed on fetch blocks")
+                               return
+                       }
+
+                       peer, err := mf.syncPeers.selectIdlePeer()
+                       if err != nil {
+                               log.WithFields(log.Fields{"module": logModule, "err": result.err}).Warn("failed on find fast sync peer")
+                               break
+                       }
+                       peerCh <- peer
+               case _, ok := <-quit:
+                       if !ok {
+                               return
+                       }
+               }
+       }
+}
+
+func (mf *msgFetcher) fetchBlocks(work *fetchBlocksWork, peerID string) ([]*types.Block, error) {
+       defer mf.syncPeers.setIdle(peerID)
+       startHash := work.startHeader.Hash()
+       stopHash := work.stopHeader.Hash()
+       blocks, err := mf.requireBlocks(peerID, []*bc.Hash{&startHash}, &stopHash)
+       if err != nil {
+               mf.peers.ProcessIllegal(peerID, security.LevelConnException, err.Error())
+               return nil, err
+       }
+
+       if err := mf.verifyBlocksMsg(blocks, work.startHeader, work.stopHeader); err != nil {
+               mf.peers.ProcessIllegal(peerID, security.LevelConnException, err.Error())
+               return nil, err
+       }
+
+       return blocks, nil
+}
+
+func (mf *msgFetcher) fetchBlocksProcess(work *fetchBlocksWork, peerCh chan string, downloadNotifyCh chan struct{}, closeCh chan struct{}) error {
+       for {
+               select {
+               case peerID := <-peerCh:
+                       for {
+                               blocks, err := mf.fetchBlocks(work, peerID)
+                               if err != nil {
+                                       log.WithFields(log.Fields{"module": logModule, "startHeight": work.startHeader.Height, "stopHeight": work.stopHeader.Height, "error": err}).Info("failed on fetch blocks")
+                                       break
+                               }
+
+                               if err := mf.storage.writeBlocks(peerID, blocks); err != nil {
+                                       log.WithFields(log.Fields{"module": logModule, "error": err}).Info("write block error")
+                                       return err
+                               }
+
+                               // send to block process pool
+                               select {
+                               case downloadNotifyCh <- struct{}{}:
+                               default:
+                               }
+
+                               // work completed
+                               if blocks[len(blocks)-1].Height >= work.stopHeader.Height-1 {
+                                       return nil
+                               }
+
+                               // unfinished work, continue
+                               work.startHeader = &blocks[len(blocks)-1].BlockHeader
+                       }
+               case <-closeCh:
+                       return nil
+               }
+       }
+}
+
+func (mf *msgFetcher) fetchBlocksWorker(workCh chan *fetchBlocksWork, peerCh chan string, resultCh chan *fetchBlocksResult, closeCh chan struct{}, downloadNotifyCh chan struct{}, wg *sync.WaitGroup) {
+       for {
+               select {
+               case work := <-workCh:
+                       err := mf.fetchBlocksProcess(work, peerCh, downloadNotifyCh, closeCh)
+                       resultCh <- &fetchBlocksResult{startHeight: work.startHeader.Height, stopHeight: work.stopHeader.Height, err: err}
+               case <-closeCh:
+                       wg.Done()
+                       return
+               }
+       }
+}
+
+func (mf *msgFetcher) parallelFetchBlocks(works []*fetchBlocksWork, downloadNotifyCh chan struct{}, ProcessStopCh chan struct{}, wg *sync.WaitGroup) {
+       workSize := len(works)
+       workCh := make(chan *fetchBlocksWork, workSize)
+       peerCh := make(chan string, maxNumOfFastSyncPeers)
+       resultCh := make(chan *fetchBlocksResult, workSize)
+       closeCh := make(chan struct{})
+
+       for _, work := range works {
+               workCh <- work
+       }
+       syncPeers := mf.syncPeers.selectIdlePeers()
+       for i := 0; i < len(syncPeers) && i < maxNumOfFastSyncPeers; i++ {
+               peerCh <- syncPeers[i]
+       }
+
+       var workWg sync.WaitGroup
+       for i := 0; i <= maxNumOfParallelFetchBlocks && i < workSize; i++ {
+               workWg.Add(1)
+               go mf.fetchBlocksWorker(workCh, peerCh, resultCh, closeCh, downloadNotifyCh, &workWg)
+       }
+
+       go mf.collectResultLoop(peerCh, ProcessStopCh, resultCh, closeCh, workSize)
+
+       workWg.Wait()
+       close(resultCh)
+       close(peerCh)
+       close(workCh)
+       close(downloadNotifyCh)
+       wg.Done()
+}
+
+func (mf *msgFetcher) parallelFetchHeaders(peers []*peers.Peer, locator []*bc.Hash, stopHash *bc.Hash, skip uint64) map[string][]*types.BlockHeader {
+       result := make(map[string][]*types.BlockHeader)
+       response := make(map[string]bool)
+       for _, peer := range peers {
+               if ok := peer.GetHeaders(locator, stopHash, skip); !ok {
+                       continue
+               }
+               result[peer.ID()] = nil
+       }
+
+       timeout := time.NewTimer(requireHeadersTimeout)
+       defer timeout.Stop()
+       for {
+               select {
+               case msg := <-mf.headersProcessCh:
+                       if _, ok := result[msg.peerID]; ok {
+                               result[msg.peerID] = append(result[msg.peerID], msg.headers[:]...)
+                               response[msg.peerID] = true
+                               if len(response) == len(result) {
+                                       return result
+                               }
+                       }
+               case <-timeout.C:
+                       log.WithFields(log.Fields{"module": logModule, "err": errRequestTimeout}).Warn("failed on parallel fetch headers")
+                       return result
+               }
        }
 }
 
@@ -38,6 +234,15 @@ func (mf *msgFetcher) processBlock(peerID string, block *types.Block) {
 
 func (mf *msgFetcher) processBlocks(peerID string, blocks []*types.Block) {
        mf.blocksProcessCh <- &blocksMsg{blocks: blocks, peerID: peerID}
+       mf.mux.RLock()
+       blocksMsgChan, ok := mf.blocksMsgChanMap[peerID]
+       mf.mux.RUnlock()
+       if !ok {
+               mf.peers.ProcessIllegal(peerID, security.LevelMsgIllegal, "msg from unsolicited peer")
+               return
+       }
+
+       blocksMsgChan <- blocks
 }
 
 func (mf *msgFetcher) processHeaders(peerID string, headers []*types.BlockHeader) {
@@ -51,10 +256,10 @@ func (mf *msgFetcher) requireBlock(peerID string, height uint64) (*types.Block,
        }
 
        if ok := peer.GetBlockByHeight(height); !ok {
-               return nil, errPeerDropped
+               return nil, errSendMsg
        }
 
-       timeout := time.NewTimer(syncTimeout)
+       timeout := time.NewTimer(requireBlockTimeout)
        defer timeout.Stop()
 
        for {
@@ -76,53 +281,66 @@ func (mf *msgFetcher) requireBlock(peerID string, height uint64) (*types.Block,
 func (mf *msgFetcher) requireBlocks(peerID string, locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error) {
        peer := mf.peers.GetPeer(peerID)
        if peer == nil {
+               mf.syncPeers.delete(peerID)
                return nil, errPeerDropped
        }
 
+       receiveCh := make(chan []*types.Block, 1)
+       mf.mux.Lock()
+       mf.blocksMsgChanMap[peerID] = receiveCh
+       mf.mux.Unlock()
+
        if ok := peer.GetBlocks(locator, stopHash); !ok {
-               return nil, errPeerDropped
+               return nil, errSendMsg
        }
 
-       timeout := time.NewTimer(syncTimeout)
+       timeout := time.NewTimer(requireBlocksTimeout)
        defer timeout.Stop()
+       select {
+       case blocks := <-receiveCh:
+               return blocks, nil
+       case <-timeout.C:
+               return nil, errRequestBlocksTimeout
+       }
+}
 
+func (mf *msgFetcher) resetParameter() {
+       mf.blocksMsgChanMap = make(map[string]chan []*types.Block)
+       mf.syncPeers = newFastSyncPeers()
+       mf.storage.resetParameter()
+       // drain stale messages left in the channels
        for {
                select {
-               case msg := <-mf.blocksProcessCh:
-                       if msg.peerID != peerID {
-                               continue
-                       }
-
-                       return msg.blocks, nil
-               case <-timeout.C:
-                       return nil, errors.Wrap(errRequestTimeout, "requireBlocks")
+               case <-mf.blocksProcessCh:
+               case <-mf.headersProcessCh:
+               default:
+                       return
                }
        }
 }
 
-func (mf *msgFetcher) requireHeaders(peerID string, locator []*bc.Hash, stopHash *bc.Hash, skip uint64) ([]*types.BlockHeader, error) {
-       peer := mf.peers.GetPeer(peerID)
-       if peer == nil {
-               return nil, errPeerDropped
+func (mf *msgFetcher) verifyBlocksMsg(blocks []*types.Block, startHeader, stopHeader *types.BlockHeader) error {
+       // reject an empty blocks message
+       if len(blocks) == 0 {
+               return errors.New("null blocks msg")
        }
 
-       if ok := peer.GetHeaders(locator, stopHash, skip); !ok {
-               return nil, errPeerDropped
+       // reject more blocks than requested
+       if uint64(len(blocks)) > stopHeader.Height-startHeader.Height+1 {
+               return errors.New("exceed length blocks msg")
        }
 
-       timeout := time.NewTimer(syncTimeout)
-       defer timeout.Stop()
-
-       for {
-               select {
-               case msg := <-mf.headersProcessCh:
-                       if msg.peerID != peerID {
-                               continue
-                       }
+       // verify start block
+       if blocks[0].Hash() != startHeader.Hash() {
+               return errors.New("get mismatch blocks msg")
+       }
 
-                       return msg.headers, nil
-               case <-timeout.C:
-                       return nil, errors.Wrap(errRequestTimeout, "requireHeaders")
+       // verify blocks continuity
+       for i := 0; i < len(blocks)-1; i++ {
+               if blocks[i].Hash() != blocks[i+1].PreviousBlockHash {
+                       return errors.New("get discontinuous blocks msg")
                }
        }
+
+       return nil
 }
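
The reworked requireBlocks above swaps channel polling for a per-peer response channel plus a one-shot timer: the fetcher registers a channel under the peer's ID in blocksMsgChanMap, sends the request, then selects on the response or requireBlocksTimeout. The shape of that pattern in isolation (hypothetical helper, not the commit's API):

package main

import (
	"errors"
	"fmt"
	"time"
)

var errTimeout = errors.New("request blocks timeout")

// await models the requireBlocks flow above: register a response channel,
// fire the request, then block on the response or a one-shot timer.
func await(receiveCh <-chan string, timeout time.Duration) (string, error) {
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case resp := <-receiveCh:
		return resp, nil
	case <-timer.C:
		return "", errTimeout
	}
}

func main() {
	receiveCh := make(chan string, 1)
	go func() { receiveCh <- "blocks 0..999" }() // simulated peer response
	fmt.Println(await(receiveCh, 50*time.Millisecond))

	_, err := await(make(chan string), 50*time.Millisecond) // peer never answers
	fmt.Println(err)
}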
diff --git a/netsync/chainmgr/peers.go b/netsync/chainmgr/peers.go
new file mode 100644
index 0000000..4dea4fe
--- /dev/null
+++ b/netsync/chainmgr/peers.go
@@ -0,0 +1,88 @@
+package chainmgr
+
+import (
+       "errors"
+       "sync"
+)
+
+var errNoValidFastSyncPeer = errors.New("no valid fast sync peer")
+
+type fastSyncPeers struct {
+       peers map[string]bool
+       mtx   sync.RWMutex
+}
+
+func newFastSyncPeers() *fastSyncPeers {
+       return &fastSyncPeers{
+               peers: make(map[string]bool),
+       }
+}
+
+func (fs *fastSyncPeers) add(peerID string) {
+       fs.mtx.Lock()
+       defer fs.mtx.Unlock()
+
+       if _, ok := fs.peers[peerID]; ok {
+               return
+       }
+
+       fs.peers[peerID] = false
+}
+
+func (fs *fastSyncPeers) delete(peerID string) {
+       fs.mtx.Lock()
+       defer fs.mtx.Unlock()
+
+       delete(fs.peers, peerID)
+}
+
+func (fs *fastSyncPeers) selectIdlePeers() []string {
+       fs.mtx.Lock()
+       defer fs.mtx.Unlock()
+
+       peers := make([]string, 0)
+       for peerID, isBusy := range fs.peers {
+               if isBusy {
+                       continue
+               }
+
+               fs.peers[peerID] = true
+               peers = append(peers, peerID)
+       }
+
+       return peers
+}
+
+func (fs *fastSyncPeers) selectIdlePeer() (string, error) {
+       fs.mtx.Lock()
+       defer fs.mtx.Unlock()
+
+       for peerID, isBusy := range fs.peers {
+               if isBusy {
+                       continue
+               }
+
+               fs.peers[peerID] = true
+               return peerID, nil
+       }
+
+       return "", errNoValidFastSyncPeer
+}
+
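+// setIdle marks a known peer as idle so it can serve new requests.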
+func (fs *fastSyncPeers) setIdle(peerID string) {
+       fs.mtx.Lock()
+       defer fs.mtx.Unlock()
+
+       if _, ok := fs.peers[peerID]; !ok {
+               return
+       }
+
+       fs.peers[peerID] = false
+}
+
+func (fs *fastSyncPeers) size() int {
+       fs.mtx.RLock()
+       defer fs.mtx.RUnlock()
+
+       return len(fs.peers)
+}
diff --git a/netsync/chainmgr/storage.go b/netsync/chainmgr/storage.go
new file mode 100644 (file)
index 0000000..2bab996
--- /dev/null
@@ -0,0 +1,168 @@
+package chainmgr
+
+import (
+       "encoding/binary"
+       "sync"
+
+       dbm "github.com/vapor/database/leveldb"
+       "github.com/vapor/errors"
+       "github.com/vapor/protocol/bc/types"
+)
+
+var (
+       maxByteOfStorageRAM = 800 * 1024 * 1024 // 800MB
+       errStorageFindBlock = errors.New("can't find block from storage")
+       errDBFindBlock      = errors.New("can't find block from DB")
+)
+
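+// Storage is the interface fast sync uses to buffer downloaded blocks
+// before they are processed.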
+type Storage interface {
+       resetParameter()
+       writeBlocks(peerID string, blocks []*types.Block) error
+       readBlock(height uint64) (*blockStorage, error)
+       deleteBlock(height uint64)
+}
+
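+// LocalStore is the persistent backend used once the RAM budget is full.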
+type LocalStore interface {
+       writeBlock(block *types.Block) error
+       readBlock(height uint64) (*types.Block, error)
+       clearData()
+}
+
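+// blockStorage wraps a downloaded block with its origin peer and records
+// whether the block is resident in RAM or spilled to the local store.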
+type blockStorage struct {
+       block  *types.Block
+       peerID string
+       size   int
+       isRAM  bool
+}
+
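+// storage keeps fast-sync blocks in RAM up to maxByteOfStorageRAM bytes and
+// falls back to the local store for the overflow.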
+type storage struct {
+       actualUsage int
+       blocks      map[uint64]*blockStorage
+       localStore  LocalStore
+       mux         sync.RWMutex
+}
+
+func newStorage(db dbm.DB) *storage {
+       dbStore := newDBStore(db)
+       dbStore.clearData()
+       return &storage{
+               blocks:     make(map[uint64]*blockStorage),
+               localStore: dbStore,
+       }
+}
+
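+// writeBlocks stores downloaded blocks, keeping them in RAM while the
+// memory budget allows and writing the overflow to the local store.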
+func (s *storage) writeBlocks(peerID string, blocks []*types.Block) error {
+       s.mux.Lock()
+       defer s.mux.Unlock()
+
+       for _, block := range blocks {
+               binaryBlock, err := block.MarshalText()
+               if err != nil {
+                       return errors.Wrap(err, "marshal block")
+               }
+
+               if len(binaryBlock)+s.actualUsage < maxByteOfStorageRAM {
+                       s.blocks[block.Height] = &blockStorage{block: block, peerID: peerID, size: len(binaryBlock), isRAM: true}
+                       s.actualUsage += len(binaryBlock)
+                       continue
+               }
+
+               if err := s.localStore.writeBlock(block); err != nil {
+                       return err
+               }
+
+               s.blocks[block.Height] = &blockStorage{peerID: peerID, isRAM: false}
+       }
+
+       return nil
+}
+
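+// readBlock returns the stored block at the given height, loading it from
+// the local store when it is not resident in RAM.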
+func (s *storage) readBlock(height uint64) (*blockStorage, error) {
+       s.mux.RLock()
+       defer s.mux.RUnlock()
+
+       blockStore, ok := s.blocks[height]
+       if !ok {
+               return nil, errStorageFindBlock
+       }
+
+       if blockStore.isRAM {
+               return blockStore, nil
+       }
+
+       block, err := s.localStore.readBlock(height)
+       if err != nil {
+               return nil, err
+       }
+
+       blockStore.block = block
+       return blockStore, nil
+}
+
+// deleteBlock removes a block held in RAM and releases its memory budget;
+// blocks spilled to the local store are left for clearData to remove.
+func (s *storage) deleteBlock(height uint64) {
+       s.mux.Lock()
+       defer s.mux.Unlock()
+
+       blockStore, ok := s.blocks[height]
+       if !ok {
+               return
+       }
+
+       if blockStore.isRAM {
+               s.actualUsage -= blockStore.size
+               delete(s.blocks, height)
+       }
+}
+
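+// resetParameter drops all buffered blocks and clears the local store.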
+func (s *storage) resetParameter() {
+       s.mux.Lock()
+       defer s.mux.Unlock()
+
+       s.blocks = make(map[uint64]*blockStorage)
+       s.actualUsage = 0
+       s.localStore.clearData()
+}
+
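+// levelDBStorage persists overflow blocks in LevelDB, keyed by big-endian
+// block height.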
+type levelDBStorage struct {
+       db dbm.DB
+}
+
+func newDBStore(db dbm.DB) *levelDBStorage {
+       return &levelDBStorage{
+               db: db,
+       }
+}
+
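+// clearData deletes every entry in the fast-sync database.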
+func (ls *levelDBStorage) clearData() {
+       iter := ls.db.Iterator()
+       defer iter.Release()
+
+       for iter.Next() {
+               ls.db.Delete(iter.Key())
+       }
+}
+
+func (ls *levelDBStorage) writeBlock(block *types.Block) error {
+       binaryBlock, err := block.MarshalText()
+       if err != nil {
+               return err
+       }
+
+       key := make([]byte, 8)
+       binary.BigEndian.PutUint64(key, block.Height)
+       ls.db.Set(key, binaryBlock)
+       return nil
+}
+
+func (ls *levelDBStorage) readBlock(height uint64) (*types.Block, error) {
+       key := make([]byte, 8)
+       binary.BigEndian.PutUint64(key, height)
+       binaryBlock := ls.db.Get(key)
+       if binaryBlock == nil {
+               return nil, errDBFindBlock
+       }
+
+       block := &types.Block{}
+       return block, block.UnmarshalText(binaryBlock)
+}
diff --git a/netsync/chainmgr/storage_test.go b/netsync/chainmgr/storage_test.go
new file mode 100644 (file)
index 0000000..c15d14b
--- /dev/null
@@ -0,0 +1,133 @@
+package chainmgr
+
+import (
+       "io/ioutil"
+       "os"
+       "testing"
+
+       "github.com/davecgh/go-spew/spew"
+
+       dbm "github.com/vapor/database/leveldb"
+       "github.com/vapor/protocol/bc/types"
+)
+
+func TestReadWriteBlocks(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(tmp)
+
+       testDB := dbm.NewDB("testdb", "leveldb", tmp)
+       defer testDB.Close()
+
+       s := newStorage(testDB)
+
+       cases := []struct {
+               storageRAMLimit int
+               blocks          []*types.Block
+               peerID          string
+               isRAM           bool
+       }{
+               {
+                       storageRAMLimit: 800 * 1024 * 1024,
+                       blocks:          mockBlocks(nil, 500),
+                       peerID:          "testPeer",
+                       isRAM:           true,
+               },
+               {
+                       storageRAMLimit: 1,
+                       blocks:          mockBlocks(nil, 500),
+                       peerID:          "testPeer",
+                       isRAM:           false,
+               },
+       }
+
+       for index, c := range cases {
+               maxByteOfStorageRAM = c.storageRAMLimit
+               s.writeBlocks(c.peerID, c.blocks)
+
+               for i := 0; i < len(c.blocks); i++ {
+                       blockStorage, err := s.readBlock(uint64(i))
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+
+                       if blockStorage.isRAM != c.isRAM {
+                               t.Fatalf("case %d: TestReadWriteBlocks block %d isRAM: got %t want %t", index, i, blockStorage.isRAM, c.isRAM)
+                       }
+
+                       if blockStorage.block.Hash() != c.blocks[i].Hash() {
+                               t.Fatalf("case %d: TestReadWriteBlocks block %d: got %s want %s", index, i, spew.Sdump(blockStorage.block), spew.Sdump(c.blocks[i]))
+                       }
+               }
+       }
+}
+
+func TestDeleteBlock(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(tmp)
+
+       testDB := dbm.NewDB("testdb", "leveldb", tmp)
+       defer testDB.Close()
+
+       maxByteOfStorageRAM = 1024
+       blocks := mockBlocks(nil, 500)
+       s := newStorage(testDB)
+       for i, block := range blocks {
+               if err := s.writeBlocks("testPeer", []*types.Block{block}); err != nil {
+                       t.Fatal(err)
+               }
+
+               blockStorage, err := s.readBlock(block.Height)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               if !blockStorage.isRAM {
+                       t.Fatalf("TestDeleteBlock block %d isRAM: got %t want %t", i, blockStorage.isRAM, true)
+               }
+
+               s.deleteBlock(block.Height)
+       }
+}
+
+func TestLevelDBStorageReadWrite(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(tmp)
+
+       testDB := dbm.NewDB("testdb", "leveldb", tmp)
+       defer testDB.Close()
+
+       blocks := mockBlocks(nil, 16)
+       s := newDBStore(testDB)
+
+       for i, block := range blocks {
+               err := s.writeBlock(block)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               gotBlock, err := s.readBlock(block.Height)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               if gotBlock.Hash() != block.Hash() {
+                       t.Fatalf("TestLevelDBStorageReadWrite block %d: got %s want %s", i, spew.Sdump(gotBlock), spew.Sdump(block))
+               }
+
+               s.clearData()
+               _, err = s.readBlock(block.Height)
+               if err == nil {
+                       t.Fatalf("TestLevelDBStorageReadWrite: block %d still readable after clearData", i)
+               }
+       }
+}
index 6559675..b9d661f 100644 (file)
@@ -7,6 +7,7 @@ import (
 
        "github.com/tendermint/go-wire"
        "github.com/tendermint/tmlibs/flowrate"
+       dbm "github.com/vapor/database/leveldb"
 
        "github.com/vapor/consensus"
        "github.com/vapor/netsync/peers"
@@ -158,7 +159,7 @@ func mockBlocks(startBlock *types.Block, height uint64) []*types.Block {
        return blocks
 }
 
-func mockSync(blocks []*types.Block, mempool *mock.Mempool) *Manager {
+func mockSync(blocks []*types.Block, mempool *mock.Mempool, fastSyncDB dbm.DB) *Manager {
        chain := mock.NewChain(mempool)
        peers := peers.NewPeerSet(NewPeerSet())
        chain.SetBestBlockHeader(&blocks[len(blocks)-1].BlockHeader)
@@ -168,7 +169,7 @@ func mockSync(blocks []*types.Block, mempool *mock.Mempool) *Manager {
 
        return &Manager{
                chain:       chain,
-               blockKeeper: newBlockKeeper(chain, peers),
+               blockKeeper: newBlockKeeper(chain, peers, fastSyncDB),
                peers:       peers,
                mempool:     mempool,
                txSyncCh:    make(chan *txSyncMsg),
index 7401af2..dd269fd 100644 (file)
@@ -1,6 +1,8 @@
 package chainmgr
 
 import (
+       "io/ioutil"
+       "os"
        "reflect"
        "testing"
        "time"
@@ -8,6 +10,7 @@ import (
        "github.com/davecgh/go-spew/spew"
 
        "github.com/vapor/consensus"
+       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/protocol"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
@@ -44,9 +47,17 @@ func getTransactions() []*types.Tx {
 }
 
 func TestSyncMempool(t *testing.T) {
+       tmpDir, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
+       }
+       defer os.RemoveAll(tmpDir)
+       testDBA := dbm.NewDB("testdba", "leveldb", tmpDir)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmpDir)
+
        blocks := mockBlocks(nil, 5)
-       a := mockSync(blocks, &mock.Mempool{})
-       b := mockSync(blocks, &mock.Mempool{})
+       a := mockSync(blocks, &mock.Mempool{}, testDBA)
+       b := mockSync(blocks, &mock.Mempool{}, testDBB)
 
        netWork := NewNetWork()
        netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
index e74a458..431de06 100644 (file)
@@ -30,6 +30,7 @@ const (
 var (
        errSendStatusMsg = errors.New("send status msg fail")
        ErrPeerMisbehave = errors.New("peer is misbehave")
+       ErrNoValidPeer   = errors.New("can't find valid fast sync peer")
 )
 
 //BasePeer is the interface for connection level peer
@@ -558,6 +559,19 @@ func (ps *PeerSet) GetPeer(id string) *Peer {
        return ps.peers[id]
 }
 
+func (ps *PeerSet) GetPeersByHeight(height uint64) []*Peer {
+       ps.mtx.RLock()
+       defer ps.mtx.RUnlock()
+
+       peers := []*Peer{}
+       for _, peer := range ps.peers {
+               if peer.Height() >= height {
+                       peers = append(peers, peer)
+               }
+       }
+       return peers
+}
+
 func (ps *PeerSet) GetPeerInfos() []*PeerInfo {
        ps.mtx.RLock()
        defer ps.mtx.RUnlock()
@@ -658,3 +672,10 @@ func (ps *PeerSet) SetStatus(peerID string, height uint64, hash *bc.Hash) {
 
        peer.SetBestStatus(height, hash)
 }
+
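+// Size returns the number of peers in the set.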
+func (ps *PeerSet) Size() int {
+       ps.mtx.RLock()
+       defer ps.mtx.RUnlock()
+
+       return len(ps.peers)
+}
index 61645bf..2451749 100644 (file)
@@ -7,6 +7,7 @@ import (
 
        "github.com/vapor/config"
        "github.com/vapor/consensus"
+       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/event"
        "github.com/vapor/netsync/chainmgr"
        "github.com/vapor/netsync/consensusmgr"
@@ -55,14 +56,14 @@ type SyncManager struct {
 }
 
 // NewSyncManager create sync manager and set switch.
-func NewSyncManager(config *config.Config, chain *protocol.Chain, txPool *protocol.TxPool, dispatcher *event.Dispatcher) (*SyncManager, error) {
+func NewSyncManager(config *config.Config, chain *protocol.Chain, txPool *protocol.TxPool, dispatcher *event.Dispatcher, fastSyncDB dbm.DB) (*SyncManager, error) {
        sw, err := p2p.NewSwitch(config)
        if err != nil {
                return nil, err
        }
        peers := peers.NewPeerSet(sw)
 
-       chainManger, err := chainmgr.NewManager(config, sw, chain, txPool, dispatcher, peers)
+       chainManger, err := chainmgr.NewManager(config, sw, chain, txPool, dispatcher, peers, fastSyncDB)
        if err != nil {
                return nil, err
        }
index 764edf4..d4c7fd5 100644 (file)
@@ -124,8 +124,8 @@ func NewNode(config *cfg.Config) *Node {
                        wallet.RescanBlocks()
                }
        }
-
-       syncManager, err := netsync.NewSyncManager(config, chain, txPool, dispatcher)
+       fastSyncDB := dbm.NewDB("fastsync", config.DBBackend, config.DBDir())
+       syncManager, err := netsync.NewSyncManager(config, chain, txPool, dispatcher, fastSyncDB)
        if err != nil {
                cmn.Exit(cmn.Fmt("Failed to create sync manager: %v", err))
        }