X-Git-Url: http://git.osdn.net/view?p=bytom%2Fvapor.git;a=blobdiff_plain;f=netsync%2Fchainmgr%2Fblock_keeper_test.go;fp=netsync%2Fchainmgr%2Fblock_keeper_test.go;h=4a12346f82f6e54e496609e1d8d06a21a628b814;hp=d855ed06e0a244fac4e77bb785311d84f981007f;hb=068fc645e200e34e38a75dc283e3e4f05ab15d7f;hpb=51100c2a5afb320a9b16674f8c66b067fe760eb3;ds=sidebyside diff --git a/netsync/chainmgr/block_keeper_test.go b/netsync/chainmgr/block_keeper_test.go index d855ed06..4a12346f 100644 --- a/netsync/chainmgr/block_keeper_test.go +++ b/netsync/chainmgr/block_keeper_test.go @@ -2,14 +2,18 @@ package chainmgr import ( "encoding/json" + "io/ioutil" + "os" "testing" "time" "github.com/vapor/consensus" + dbm "github.com/vapor/database/leveldb" "github.com/vapor/errors" msgs "github.com/vapor/netsync/messages" "github.com/vapor/protocol/bc" "github.com/vapor/protocol/bc/types" + "github.com/vapor/test/mock" "github.com/vapor/testutil" ) @@ -55,11 +59,21 @@ func TestRegularBlockSync(t *testing.T) { err: nil, }, } + tmp, err := ioutil.TempDir(".", "") + if err != nil { + t.Fatalf("failed to create temporary data folder: %v", err) + } + testDBA := dbm.NewDB("testdba", "leveldb", tmp) + testDBB := dbm.NewDB("testdbb", "leveldb", tmp) + defer func() { + testDBA.Close() + testDBB.Close() + os.RemoveAll(tmp) + }() for i, c := range cases { - syncTimeout = c.syncTimeout - a := mockSync(c.aBlocks, nil) - b := mockSync(c.bBlocks, nil) + a := mockSync(c.aBlocks, nil, testDBA) + b := mockSync(c.bBlocks, nil, testDBB) netWork := NewNetWork() netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode) netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode) @@ -91,9 +105,21 @@ func TestRegularBlockSync(t *testing.T) { } func TestRequireBlock(t *testing.T) { + tmp, err := ioutil.TempDir(".", "") + if err != nil { + t.Fatalf("failed to create temporary data folder: %v", err) + } + testDBA := dbm.NewDB("testdba", "leveldb", tmp) + testDBB := dbm.NewDB("testdbb", "leveldb", tmp) + defer func() { + testDBB.Close() + testDBA.Close() + os.RemoveAll(tmp) + }() + blocks := mockBlocks(nil, 5) - a := mockSync(blocks[:1], nil) - b := mockSync(blocks[:5], nil) + a := mockSync(blocks[:1], nil, testDBA) + b := mockSync(blocks[:5], nil, testDBB) netWork := NewNetWork() netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode) netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode) @@ -129,8 +155,12 @@ func TestRequireBlock(t *testing.T) { }, } + defer func() { + requireBlockTimeout = 20 * time.Second + }() + for i, c := range cases { - syncTimeout = c.syncTimeout + requireBlockTimeout = c.syncTimeout got, err := c.testNode.blockKeeper.msgFetcher.requireBlock(c.testNode.blockKeeper.syncPeer.ID(), c.requireHeight) if !testutil.DeepEqual(got, c.want) { t.Errorf("case %d: got %v want %v", i, got, c.want) @@ -142,6 +172,19 @@ func TestRequireBlock(t *testing.T) { } func TestSendMerkleBlock(t *testing.T) { + tmp, err := ioutil.TempDir(".", "") + if err != nil { + t.Fatalf("failed to create temporary data folder: %v", err) + } + + testDBA := dbm.NewDB("testdba", "leveldb", tmp) + testDBB := dbm.NewDB("testdbb", "leveldb", tmp) + defer func() { + testDBA.Close() + testDBB.Close() + os.RemoveAll(tmp) + }() + cases := []struct { txCount int relatedTxIndex []int @@ -179,7 +222,7 @@ func TestSendMerkleBlock(t *testing.T) { t.Fatal(err) } - spvNode := mockSync(blocks, nil) + spvNode := mockSync(blocks, nil, testDBA) blockHash := targetBlock.Hash() var statusResult *bc.TransactionStatus if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil { @@ -190,7 +233,7 @@ func TestSendMerkleBlock(t *testing.T) { t.Fatal(err) } - fullNode := mockSync(blocks, nil) + fullNode := mockSync(blocks, nil, testDBB) netWork := NewNetWork() netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync) netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices) @@ -257,3 +300,153 @@ func TestSendMerkleBlock(t *testing.T) { } } } + +func TestLocateBlocks(t *testing.T) { + maxNumOfBlocksPerMsg = 5 + blocks := mockBlocks(nil, 100) + cases := []struct { + locator []uint64 + stopHash bc.Hash + wantHeight []uint64 + }{ + { + locator: []uint64{20}, + stopHash: blocks[100].Hash(), + wantHeight: []uint64{20, 21, 22, 23, 24}, + }, + } + + mockChain := mock.NewChain(nil) + bk := &blockKeeper{chain: mockChain} + for _, block := range blocks { + mockChain.SetBlockByHeight(block.Height, block) + } + + for i, c := range cases { + locator := []*bc.Hash{} + for _, i := range c.locator { + hash := blocks[i].Hash() + locator = append(locator, &hash) + } + + want := []*types.Block{} + for _, i := range c.wantHeight { + want = append(want, blocks[i]) + } + + got, _ := bk.locateBlocks(locator, &c.stopHash) + if !testutil.DeepEqual(got, want) { + t.Errorf("case %d: got %v want %v", i, got, want) + } + } +} + +func TestLocateHeaders(t *testing.T) { + defer func() { + maxNumOfHeadersPerMsg = 1000 + }() + maxNumOfHeadersPerMsg = 10 + blocks := mockBlocks(nil, 150) + blocksHash := []bc.Hash{} + for _, block := range blocks { + blocksHash = append(blocksHash, block.Hash()) + } + + cases := []struct { + chainHeight uint64 + locator []uint64 + stopHash *bc.Hash + skip uint64 + wantHeight []uint64 + err bool + }{ + { + chainHeight: 100, + locator: []uint64{90}, + stopHash: &blocksHash[100], + skip: 0, + wantHeight: []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99}, + err: false, + }, + { + chainHeight: 100, + locator: []uint64{20}, + stopHash: &blocksHash[24], + skip: 0, + wantHeight: []uint64{20, 21, 22, 23, 24}, + err: false, + }, + { + chainHeight: 100, + locator: []uint64{20}, + stopHash: &blocksHash[20], + wantHeight: []uint64{20}, + err: false, + }, + { + chainHeight: 100, + locator: []uint64{20}, + stopHash: &blocksHash[120], + wantHeight: []uint64{}, + err: false, + }, + { + chainHeight: 100, + locator: []uint64{120, 70}, + stopHash: &blocksHash[78], + wantHeight: []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78}, + err: false, + }, + { + chainHeight: 100, + locator: []uint64{15}, + stopHash: &blocksHash[10], + skip: 10, + wantHeight: []uint64{}, + err: false, + }, + { + chainHeight: 100, + locator: []uint64{15}, + stopHash: &blocksHash[80], + skip: 10, + wantHeight: []uint64{15, 26, 37, 48, 59, 70, 80}, + err: false, + }, + { + chainHeight: 100, + locator: []uint64{0}, + stopHash: &blocksHash[100], + skip: 9, + wantHeight: []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90}, + err: false, + }, + } + + for i, c := range cases { + mockChain := mock.NewChain(nil) + bk := &blockKeeper{chain: mockChain} + for i := uint64(0); i <= c.chainHeight; i++ { + mockChain.SetBlockByHeight(i, blocks[i]) + } + + locator := []*bc.Hash{} + for _, i := range c.locator { + hash := blocks[i].Hash() + locator = append(locator, &hash) + } + + want := []*types.BlockHeader{} + for _, i := range c.wantHeight { + want = append(want, &blocks[i].BlockHeader) + } + + got, err := bk.locateHeaders(locator, c.stopHash, c.skip, maxNumOfHeadersPerMsg) + if err != nil != c.err { + t.Errorf("case %d: got %v want err = %v", i, err, c.err) + } + if !testutil.DeepEqual(got, want) { + t.Errorf("case %d: got %v want %v", i, got, want) + } + } +}