OSDN Git Service

add parallel fast sync support (#238)
[bytom/vapor.git] / netsync / chainmgr / block_keeper_test.go
index d855ed0..4a12346 100644 (file)
@@ -2,14 +2,18 @@ package chainmgr
 
 import (
        "encoding/json"
+       "io/ioutil"
+       "os"
        "testing"
        "time"
 
        "github.com/vapor/consensus"
+       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/errors"
        msgs "github.com/vapor/netsync/messages"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
+       "github.com/vapor/test/mock"
        "github.com/vapor/testutil"
 )
 
@@ -55,11 +59,21 @@ func TestRegularBlockSync(t *testing.T) {
                        err:         nil,
                },
        }
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
+       }
+       testDBA := dbm.NewDB("testdba", "leveldb", tmp)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
+       defer func() {
+               testDBA.Close()
+               testDBB.Close()
+               os.RemoveAll(tmp)
+       }()
 
        for i, c := range cases {
-               syncTimeout = c.syncTimeout
-               a := mockSync(c.aBlocks, nil)
-               b := mockSync(c.bBlocks, nil)
+               a := mockSync(c.aBlocks, nil, testDBA)
+               b := mockSync(c.bBlocks, nil, testDBB)
                netWork := NewNetWork()
                netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
                netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
@@ -91,9 +105,21 @@ func TestRegularBlockSync(t *testing.T) {
 }
 
 func TestRequireBlock(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
+       }
+       testDBA := dbm.NewDB("testdba", "leveldb", tmp)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
+       defer func() {
+               testDBB.Close()
+               testDBA.Close()
+               os.RemoveAll(tmp)
+       }()
+
        blocks := mockBlocks(nil, 5)
-       a := mockSync(blocks[:1], nil)
-       b := mockSync(blocks[:5], nil)
+       a := mockSync(blocks[:1], nil, testDBA)
+       b := mockSync(blocks[:5], nil, testDBB)
        netWork := NewNetWork()
        netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
        netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
@@ -129,8 +155,12 @@ func TestRequireBlock(t *testing.T) {
                },
        }
 
+       defer func() {
+               requireBlockTimeout = 20 * time.Second
+       }()
+
        for i, c := range cases {
-               syncTimeout = c.syncTimeout
+               requireBlockTimeout = c.syncTimeout
                got, err := c.testNode.blockKeeper.msgFetcher.requireBlock(c.testNode.blockKeeper.syncPeer.ID(), c.requireHeight)
                if !testutil.DeepEqual(got, c.want) {
                        t.Errorf("case %d: got %v want %v", i, got, c.want)
@@ -142,6 +172,19 @@ func TestRequireBlock(t *testing.T) {
 }
 
 func TestSendMerkleBlock(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
+       }
+
+       testDBA := dbm.NewDB("testdba", "leveldb", tmp)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
+       defer func() {
+               testDBA.Close()
+               testDBB.Close()
+               os.RemoveAll(tmp)
+       }()
+
        cases := []struct {
                txCount        int
                relatedTxIndex []int
@@ -179,7 +222,7 @@ func TestSendMerkleBlock(t *testing.T) {
                        t.Fatal(err)
                }
 
-               spvNode := mockSync(blocks, nil)
+               spvNode := mockSync(blocks, nil, testDBA)
                blockHash := targetBlock.Hash()
                var statusResult *bc.TransactionStatus
                if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
@@ -190,7 +233,7 @@ func TestSendMerkleBlock(t *testing.T) {
                        t.Fatal(err)
                }
 
-               fullNode := mockSync(blocks, nil)
+               fullNode := mockSync(blocks, nil, testDBB)
                netWork := NewNetWork()
                netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
                netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
@@ -257,3 +300,153 @@ func TestSendMerkleBlock(t *testing.T) {
                }
        }
 }
+
+func TestLocateBlocks(t *testing.T) {
+       maxNumOfBlocksPerMsg = 5
+       blocks := mockBlocks(nil, 100)
+       cases := []struct {
+               locator    []uint64
+               stopHash   bc.Hash
+               wantHeight []uint64
+       }{
+               {
+                       locator:    []uint64{20},
+                       stopHash:   blocks[100].Hash(),
+                       wantHeight: []uint64{20, 21, 22, 23, 24},
+               },
+       }
+
+       mockChain := mock.NewChain(nil)
+       bk := &blockKeeper{chain: mockChain}
+       for _, block := range blocks {
+               mockChain.SetBlockByHeight(block.Height, block)
+       }
+
+       for i, c := range cases {
+               locator := []*bc.Hash{}
+               for _, i := range c.locator {
+                       hash := blocks[i].Hash()
+                       locator = append(locator, &hash)
+               }
+
+               want := []*types.Block{}
+               for _, i := range c.wantHeight {
+                       want = append(want, blocks[i])
+               }
+
+               got, _ := bk.locateBlocks(locator, &c.stopHash)
+               if !testutil.DeepEqual(got, want) {
+                       t.Errorf("case %d: got %v want %v", i, got, want)
+               }
+       }
+}
+
+func TestLocateHeaders(t *testing.T) {
+       defer func() {
+               maxNumOfHeadersPerMsg = 1000
+       }()
+       maxNumOfHeadersPerMsg = 10
+       blocks := mockBlocks(nil, 150)
+       blocksHash := []bc.Hash{}
+       for _, block := range blocks {
+               blocksHash = append(blocksHash, block.Hash())
+       }
+
+       cases := []struct {
+               chainHeight uint64
+               locator     []uint64
+               stopHash    *bc.Hash
+               skip        uint64
+               wantHeight  []uint64
+               err         bool
+       }{
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{90},
+                       stopHash:    &blocksHash[100],
+                       skip:        0,
+                       wantHeight:  []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    &blocksHash[24],
+                       skip:        0,
+                       wantHeight:  []uint64{20, 21, 22, 23, 24},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    &blocksHash[20],
+                       wantHeight:  []uint64{20},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    &blocksHash[120],
+                       wantHeight:  []uint64{},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{120, 70},
+                       stopHash:    &blocksHash[78],
+                       wantHeight:  []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{15},
+                       stopHash:    &blocksHash[10],
+                       skip:        10,
+                       wantHeight:  []uint64{},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{15},
+                       stopHash:    &blocksHash[80],
+                       skip:        10,
+                       wantHeight:  []uint64{15, 26, 37, 48, 59, 70, 80},
+                       err:         false,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{0},
+                       stopHash:    &blocksHash[100],
+                       skip:        9,
+                       wantHeight:  []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90},
+                       err:         false,
+               },
+       }
+
+       for i, c := range cases {
+               mockChain := mock.NewChain(nil)
+               bk := &blockKeeper{chain: mockChain}
+               for i := uint64(0); i <= c.chainHeight; i++ {
+                       mockChain.SetBlockByHeight(i, blocks[i])
+               }
+
+               locator := []*bc.Hash{}
+               for _, i := range c.locator {
+                       hash := blocks[i].Hash()
+                       locator = append(locator, &hash)
+               }
+
+               want := []*types.BlockHeader{}
+               for _, i := range c.wantHeight {
+                       want = append(want, &blocks[i].BlockHeader)
+               }
+
+               got, err := bk.locateHeaders(locator, c.stopHash, c.skip, maxNumOfHeadersPerMsg)
+               if err != nil != c.err {
+                       t.Errorf("case %d: got %v want err = %v", i, err, c.err)
+               }
+               if !testutil.DeepEqual(got, want) {
+                       t.Errorf("case %d: got %v want %v", i, got, want)
+               }
+       }
+}