OSDN Git Service

netsync add test case (#365)
[bytom/vapor.git] / netsync / chainmgr / block_keeper_test.go
index 43e6ec7..9472f21 100644 (file)
 package chainmgr
 
 import (
-       "container/list"
        "encoding/json"
+       "io/ioutil"
+       "os"
        "testing"
        "time"
 
        "github.com/vapor/consensus"
+       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/errors"
        msgs "github.com/vapor/netsync/messages"
+       "github.com/vapor/netsync/peers"
+       "github.com/vapor/protocol"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
        "github.com/vapor/test/mock"
        "github.com/vapor/testutil"
 )
 
-func TestAppendHeaderList(t *testing.T) {
-       blocks := mockBlocks(nil, 7)
-       cases := []struct {
-               originalHeaders []*types.BlockHeader
-               inputHeaders    []*types.BlockHeader
-               wantHeaders     []*types.BlockHeader
-               err             error
-       }{
-               {
-                       originalHeaders: []*types.BlockHeader{&blocks[0].BlockHeader},
-                       inputHeaders:    []*types.BlockHeader{&blocks[1].BlockHeader, &blocks[2].BlockHeader},
-                       wantHeaders:     []*types.BlockHeader{&blocks[0].BlockHeader, &blocks[1].BlockHeader, &blocks[2].BlockHeader},
-                       err:             nil,
-               },
-               {
-                       originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
-                       inputHeaders:    []*types.BlockHeader{&blocks[6].BlockHeader},
-                       wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader, &blocks[6].BlockHeader},
-                       err:             nil,
-               },
-               {
-                       originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
-                       inputHeaders:    []*types.BlockHeader{&blocks[7].BlockHeader},
-                       wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader},
-                       err:             errAppendHeaders,
-               },
-               {
-                       originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
-                       inputHeaders:    []*types.BlockHeader{&blocks[7].BlockHeader, &blocks[6].BlockHeader},
-                       wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader},
-                       err:             errAppendHeaders,
-               },
-               {
-                       originalHeaders: []*types.BlockHeader{&blocks[2].BlockHeader},
-                       inputHeaders:    []*types.BlockHeader{&blocks[3].BlockHeader, &blocks[4].BlockHeader, &blocks[6].BlockHeader},
-                       wantHeaders:     []*types.BlockHeader{&blocks[2].BlockHeader, &blocks[3].BlockHeader, &blocks[4].BlockHeader},
-                       err:             errAppendHeaders,
-               },
-       }
-
-       for i, c := range cases {
-               bk := &blockKeeper{headerList: list.New()}
-               for _, header := range c.originalHeaders {
-                       bk.headerList.PushBack(header)
-               }
-
-               if err := bk.appendHeaderList(c.inputHeaders); err != c.err {
-                       t.Errorf("case %d: got error %v want error %v", i, err, c.err)
-               }
-
-               gotHeaders := []*types.BlockHeader{}
-               for e := bk.headerList.Front(); e != nil; e = e.Next() {
-                       gotHeaders = append(gotHeaders, e.Value.(*types.BlockHeader))
-               }
-
-               if !testutil.DeepEqual(gotHeaders, c.wantHeaders) {
-                       t.Errorf("case %d: got %v want %v", i, gotHeaders, c.wantHeaders)
-               }
+func TestCheckSyncType(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
        }
-}
-
-func TestBlockLocator(t *testing.T) {
-       blocks := mockBlocks(nil, 500)
-       cases := []struct {
-               bestHeight uint64
-               wantHeight []uint64
-       }{
-               {
-                       bestHeight: 0,
-                       wantHeight: []uint64{0},
-               },
-               {
-                       bestHeight: 1,
-                       wantHeight: []uint64{1, 0},
-               },
-               {
-                       bestHeight: 7,
-                       wantHeight: []uint64{7, 6, 5, 4, 3, 2, 1, 0},
-               },
-               {
-                       bestHeight: 10,
-                       wantHeight: []uint64{10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
-               },
-               {
-                       bestHeight: 100,
-                       wantHeight: []uint64{100, 99, 98, 97, 96, 95, 94, 93, 92, 91, 89, 85, 77, 61, 29, 0},
-               },
-               {
-                       bestHeight: 500,
-                       wantHeight: []uint64{500, 499, 498, 497, 496, 495, 494, 493, 492, 491, 489, 485, 477, 461, 429, 365, 237, 0},
-               },
+       fastSyncDB := dbm.NewDB("testdb", "leveldb", tmp)
+       defer func() {
+               fastSyncDB.Close()
+               os.RemoveAll(tmp)
+       }()
+
+       blocks := mockBlocks(nil, 50)
+       chain := mock.NewChain(nil)
+       chain.SetBestBlockHeader(&blocks[len(blocks)-1].BlockHeader)
+       for _, block := range blocks {
+               chain.SetBlockByHeight(block.Height, block)
        }
 
-       for i, c := range cases {
-               mockChain := mock.NewChain(nil)
-               bk := &blockKeeper{chain: mockChain}
-               mockChain.SetBestBlockHeader(&blocks[c.bestHeight].BlockHeader)
-               for i := uint64(0); i <= c.bestHeight; i++ {
-                       mockChain.SetBlockByHeight(i, blocks[i])
-               }
-
-               want := []*bc.Hash{}
-               for _, i := range c.wantHeight {
-                       hash := blocks[i].Hash()
-                       want = append(want, &hash)
-               }
-
-               if got := bk.blockLocator(); !testutil.DeepEqual(got, want) {
-                       t.Errorf("case %d: got %v want %v", i, got, want)
-               }
+       type syncPeer struct {
+               peer               *P2PPeer
+               bestHeight         uint64
+               irreversibleHeight uint64
        }
-}
-
-func TestFastBlockSync(t *testing.T) {
-       maxBlockPerMsg = 5
-       maxBlockHeadersPerMsg = 10
-       baseChain := mockBlocks(nil, 300)
 
        cases := []struct {
-               syncTimeout time.Duration
-               aBlocks     []*types.Block
-               bBlocks     []*types.Block
-               checkPoint  *consensus.Checkpoint
-               want        []*types.Block
-               err         error
+               peers    []*syncPeer
+               syncType int
        }{
                {
-                       syncTimeout: 30 * time.Second,
-                       aBlocks:     baseChain[:100],
-                       bBlocks:     baseChain[:301],
-                       checkPoint: &consensus.Checkpoint{
-                               Height: baseChain[250].Height,
-                               Hash:   baseChain[250].Hash(),
-                       },
-                       want: baseChain[:251],
-                       err:  nil,
+                       peers:    []*syncPeer{},
+                       syncType: noNeedSync,
                },
                {
-                       syncTimeout: 30 * time.Second,
-                       aBlocks:     baseChain[:100],
-                       bBlocks:     baseChain[:301],
-                       checkPoint: &consensus.Checkpoint{
-                               Height: baseChain[100].Height,
-                               Hash:   baseChain[100].Hash(),
+                       peers: []*syncPeer{
+                               {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 500},
+                               {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 50, irreversibleHeight: 50},
                        },
-                       want: baseChain[:101],
-                       err:  nil,
+                       syncType: fastSyncType,
                },
                {
-                       syncTimeout: 1 * time.Millisecond,
-                       aBlocks:     baseChain[:100],
-                       bBlocks:     baseChain[:100],
-                       checkPoint: &consensus.Checkpoint{
-                               Height: baseChain[200].Height,
-                               Hash:   baseChain[200].Hash(),
+                       peers: []*syncPeer{
+                               {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 100},
+                               {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 500, irreversibleHeight: 50},
                        },
-                       want: baseChain[:100],
-                       err:  errRequestTimeout,
-               },
-       }
-
-       for i, c := range cases {
-               syncTimeout = c.syncTimeout
-               a := mockSync(c.aBlocks, nil)
-               b := mockSync(c.bBlocks, nil)
-               netWork := NewNetWork()
-               netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
-               netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
-               if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
-                       t.Errorf("fail on peer hands shake %v", err)
-               } else {
-                       go B2A.postMan()
-                       go A2B.postMan()
-               }
-
-               a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
-               if err := a.blockKeeper.fastBlockSync(c.checkPoint); errors.Root(err) != c.err {
-                       t.Errorf("case %d: got %v want %v", i, err, c.err)
-               }
-
-               got := []*types.Block{}
-               for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ {
-                       block, err := a.chain.GetBlockByHeight(i)
-                       if err != nil {
-                               t.Errorf("case %d got err %v", i, err)
-                       }
-                       got = append(got, block)
-               }
-
-               if !testutil.DeepEqual(got, c.want) {
-                       t.Errorf("case %d: got %v want %v", i, got, c.want)
-               }
-       }
-}
-
-func TestLocateBlocks(t *testing.T) {
-       maxBlockPerMsg = 5
-       blocks := mockBlocks(nil, 100)
-       cases := []struct {
-               locator    []uint64
-               stopHash   bc.Hash
-               wantHeight []uint64
-       }{
-               {
-                       locator:    []uint64{20},
-                       stopHash:   blocks[100].Hash(),
-                       wantHeight: []uint64{21, 22, 23, 24, 25},
-               },
-       }
-
-       mockChain := mock.NewChain(nil)
-       bk := &blockKeeper{chain: mockChain}
-       for _, block := range blocks {
-               mockChain.SetBlockByHeight(block.Height, block)
-       }
-
-       for i, c := range cases {
-               locator := []*bc.Hash{}
-               for _, i := range c.locator {
-                       hash := blocks[i].Hash()
-                       locator = append(locator, &hash)
-               }
-
-               want := []*types.Block{}
-               for _, i := range c.wantHeight {
-                       want = append(want, blocks[i])
-               }
-
-               got, _ := bk.locateBlocks(locator, &c.stopHash)
-               if !testutil.DeepEqual(got, want) {
-                       t.Errorf("case %d: got %v want %v", i, got, want)
-               }
-       }
-}
-
-func TestLocateHeaders(t *testing.T) {
-       maxBlockHeadersPerMsg = 10
-       blocks := mockBlocks(nil, 150)
-       cases := []struct {
-               chainHeight uint64
-               locator     []uint64
-               stopHash    bc.Hash
-               wantHeight  []uint64
-               err         bool
-       }{
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{},
-                       stopHash:    blocks[100].Hash(),
-                       wantHeight:  []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{20},
-                       stopHash:    blocks[100].Hash(),
-                       wantHeight:  []uint64{21, 22, 23, 24, 25, 26, 27, 28, 29, 30},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{20},
-                       stopHash:    blocks[24].Hash(),
-                       wantHeight:  []uint64{21, 22, 23, 24},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{20},
-                       stopHash:    blocks[20].Hash(),
-                       wantHeight:  []uint64{},
-                       err:         false,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{20},
-                       stopHash:    bc.Hash{},
-                       wantHeight:  []uint64{},
-                       err:         true,
-               },
-               {
-                       chainHeight: 100,
-                       locator:     []uint64{120, 70},
-                       stopHash:    blocks[78].Hash(),
-                       wantHeight:  []uint64{71, 72, 73, 74, 75, 76, 77, 78},
-                       err:         false,
-               },
-       }
-
-       for i, c := range cases {
-               mockChain := mock.NewChain(nil)
-               bk := &blockKeeper{chain: mockChain}
-               for i := uint64(0); i <= c.chainHeight; i++ {
-                       mockChain.SetBlockByHeight(i, blocks[i])
-               }
-
-               locator := []*bc.Hash{}
-               for _, i := range c.locator {
-                       hash := blocks[i].Hash()
-                       locator = append(locator, &hash)
-               }
-
-               want := []*types.BlockHeader{}
-               for _, i := range c.wantHeight {
-                       want = append(want, &blocks[i].BlockHeader)
-               }
-
-               got, err := bk.locateHeaders(locator, &c.stopHash)
-               if err != nil != c.err {
-                       t.Errorf("case %d: got %v want err = %v", i, err, c.err)
-               }
-               if !testutil.DeepEqual(got, want) {
-                       t.Errorf("case %d: got %v want %v", i, got, want)
-               }
-       }
-}
-
-func TestNextCheckpoint(t *testing.T) {
-       cases := []struct {
-               checkPoints []consensus.Checkpoint
-               bestHeight  uint64
-               want        *consensus.Checkpoint
-       }{
-               {
-                       checkPoints: []consensus.Checkpoint{},
-                       bestHeight:  5000,
-                       want:        nil,
+                       syncType: regularSyncType,
                },
                {
-                       checkPoints: []consensus.Checkpoint{
-                               {Height: 10000, Hash: bc.Hash{V0: 1}},
+                       peers: []*syncPeer{
+                               {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 51, irreversibleHeight: 50},
                        },
-                       bestHeight: 5000,
-                       want:       &consensus.Checkpoint{Height: 10000, Hash: bc.Hash{V0: 1}},
+                       syncType: regularSyncType,
                },
                {
-                       checkPoints: []consensus.Checkpoint{
-                               {Height: 10000, Hash: bc.Hash{V0: 1}},
-                               {Height: 20000, Hash: bc.Hash{V0: 2}},
-                               {Height: 30000, Hash: bc.Hash{V0: 3}},
+                       peers: []*syncPeer{
+                               {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 30, irreversibleHeight: 30},
                        },
-                       bestHeight: 15000,
-                       want:       &consensus.Checkpoint{Height: 20000, Hash: bc.Hash{V0: 2}},
+                       syncType: noNeedSync,
                },
                {
-                       checkPoints: []consensus.Checkpoint{
-                               {Height: 10000, Hash: bc.Hash{V0: 1}},
-                               {Height: 20000, Hash: bc.Hash{V0: 2}},
-                               {Height: 30000, Hash: bc.Hash{V0: 3}},
+                       peers: []*syncPeer{
+                               {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode}, bestHeight: 1000, irreversibleHeight: 1000},
                        },
-                       bestHeight: 10000,
-                       want:       &consensus.Checkpoint{Height: 20000, Hash: bc.Hash{V0: 2}},
+                       syncType: regularSyncType,
                },
                {
-                       checkPoints: []consensus.Checkpoint{
-                               {Height: 10000, Hash: bc.Hash{V0: 1}},
-                               {Height: 20000, Hash: bc.Hash{V0: 2}},
-                               {Height: 30000, Hash: bc.Hash{V0: 3}},
+                       peers: []*syncPeer{
+                               {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 50},
+                               {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 800, irreversibleHeight: 800},
                        },
-                       bestHeight: 35000,
-                       want:       nil,
+                       syncType: fastSyncType,
                },
        }
 
-       mockChain := mock.NewChain(nil)
        for i, c := range cases {
-               consensus.ActiveNetParams.Checkpoints = c.checkPoints
-               mockChain.SetBestBlockHeader(&types.BlockHeader{Height: c.bestHeight})
-               bk := &blockKeeper{chain: mockChain}
-
-               if got := bk.nextCheckpoint(); !testutil.DeepEqual(got, c.want) {
-                       t.Errorf("case %d: got %v want %v", i, got, c.want)
+               peers := peers.NewPeerSet(NewPeerSet())
+               blockKeeper := newBlockKeeper(chain, peers, fastSyncDB)
+               for _, syncPeer := range c.peers {
+                       blockKeeper.peers.AddPeer(syncPeer.peer)
+                       blockKeeper.peers.SetStatus(syncPeer.peer.id, syncPeer.bestHeight, nil)
+                       blockKeeper.peers.SetIrreversibleStatus(syncPeer.peer.id, syncPeer.irreversibleHeight, nil)
+               }
+               gotType := blockKeeper.checkSyncType()
+               if c.syncType != gotType {
+                       t.Errorf("case %d: got %d want %d", i, gotType, c.syncType)
                }
        }
 }
@@ -395,11 +111,13 @@ func TestRegularBlockSync(t *testing.T) {
        baseChain := mockBlocks(nil, 50)
        chainX := append(baseChain, mockBlocks(baseChain[50], 60)...)
        chainY := append(baseChain, mockBlocks(baseChain[50], 70)...)
+       chainZ := append(baseChain, mockBlocks(baseChain[50], 200)...)
+       chainE := append(baseChain, mockErrorBlocks(baseChain[50], 200, 60)...)
+
        cases := []struct {
                syncTimeout time.Duration
                aBlocks     []*types.Block
                bBlocks     []*types.Block
-               syncHeight  uint64
                want        []*types.Block
                err         error
        }{
@@ -407,15 +125,13 @@ func TestRegularBlockSync(t *testing.T) {
                        syncTimeout: 30 * time.Second,
                        aBlocks:     baseChain[:20],
                        bBlocks:     baseChain[:50],
-                       syncHeight:  45,
-                       want:        baseChain[:46],
+                       want:        baseChain[:50],
                        err:         nil,
                },
                {
                        syncTimeout: 30 * time.Second,
                        aBlocks:     chainX,
                        bBlocks:     chainY,
-                       syncHeight:  70,
                        want:        chainY,
                        err:         nil,
                },
@@ -423,24 +139,46 @@ func TestRegularBlockSync(t *testing.T) {
                        syncTimeout: 30 * time.Second,
                        aBlocks:     chainX[:52],
                        bBlocks:     chainY[:53],
-                       syncHeight:  52,
                        want:        chainY[:53],
                        err:         nil,
                },
                {
-                       syncTimeout: 1 * time.Millisecond,
-                       aBlocks:     baseChain,
-                       bBlocks:     baseChain,
-                       syncHeight:  52,
-                       want:        baseChain,
+                       syncTimeout: 30 * time.Second,
+                       aBlocks:     chainX[:52],
+                       bBlocks:     chainZ,
+                       want:        chainZ[:180],
+                       err:         nil,
+               },
+               {
+                       syncTimeout: 0 * time.Second,
+                       aBlocks:     chainX[:52],
+                       bBlocks:     chainZ,
+                       want:        chainX[:52],
                        err:         errRequestTimeout,
                },
+               {
+                       syncTimeout: 30 * time.Second,
+                       aBlocks:     chainX[:52],
+                       bBlocks:     chainE,
+                       want:        chainE[:60],
+                       err:         protocol.ErrBadStateRoot,
+               },
+       }
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
        }
+       testDBA := dbm.NewDB("testdba", "leveldb", tmp)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
+       defer func() {
+               testDBA.Close()
+               testDBB.Close()
+               os.RemoveAll(tmp)
+       }()
 
        for i, c := range cases {
-               syncTimeout = c.syncTimeout
-               a := mockSync(c.aBlocks, nil)
-               b := mockSync(c.bBlocks, nil)
+               a := mockSync(c.aBlocks, nil, testDBA)
+               b := mockSync(c.bBlocks, nil, testDBB)
                netWork := NewNetWork()
                netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
                netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
@@ -451,8 +189,9 @@ func TestRegularBlockSync(t *testing.T) {
                        go A2B.postMan()
                }
 
+               requireBlockTimeout = c.syncTimeout
                a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
-               if err := a.blockKeeper.regularBlockSync(c.syncHeight); errors.Root(err) != c.err {
+               if err := a.blockKeeper.regularBlockSync(); errors.Root(err) != c.err {
                        t.Errorf("case %d: got %v want %v", i, err, c.err)
                }
 
@@ -472,9 +211,21 @@ func TestRegularBlockSync(t *testing.T) {
 }
 
 func TestRequireBlock(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
+       }
+       testDBA := dbm.NewDB("testdba", "leveldb", tmp)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
+       defer func() {
+               testDBB.Close()
+               testDBA.Close()
+               os.RemoveAll(tmp)
+       }()
+
        blocks := mockBlocks(nil, 5)
-       a := mockSync(blocks[:1], nil)
-       b := mockSync(blocks[:5], nil)
+       a := mockSync(blocks[:1], nil, testDBA)
+       b := mockSync(blocks[:5], nil, testDBB)
        netWork := NewNetWork()
        netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
        netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
@@ -510,9 +261,13 @@ func TestRequireBlock(t *testing.T) {
                },
        }
 
+       defer func() {
+               requireBlockTimeout = 20 * time.Second
+       }()
+
        for i, c := range cases {
-               syncTimeout = c.syncTimeout
-               got, err := c.testNode.blockKeeper.requireBlock(c.requireHeight)
+               requireBlockTimeout = c.syncTimeout
+               got, err := c.testNode.blockKeeper.msgFetcher.requireBlock(c.testNode.blockKeeper.syncPeer.ID(), c.requireHeight)
                if !testutil.DeepEqual(got, c.want) {
                        t.Errorf("case %d: got %v want %v", i, got, c.want)
                }
@@ -523,6 +278,19 @@ func TestRequireBlock(t *testing.T) {
 }
 
 func TestSendMerkleBlock(t *testing.T) {
+       tmp, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatalf("failed to create temporary data folder: %v", err)
+       }
+
+       testDBA := dbm.NewDB("testdba", "leveldb", tmp)
+       testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
+       defer func() {
+               testDBA.Close()
+               testDBB.Close()
+               os.RemoveAll(tmp)
+       }()
+
        cases := []struct {
                txCount        int
                relatedTxIndex []int
@@ -560,7 +328,7 @@ func TestSendMerkleBlock(t *testing.T) {
                        t.Fatal(err)
                }
 
-               spvNode := mockSync(blocks, nil)
+               spvNode := mockSync(blocks, nil, testDBA)
                blockHash := targetBlock.Hash()
                var statusResult *bc.TransactionStatus
                if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
@@ -571,7 +339,7 @@ func TestSendMerkleBlock(t *testing.T) {
                        t.Fatal(err)
                }
 
-               fullNode := mockSync(blocks, nil)
+               fullNode := mockSync(blocks, nil, testDBB)
                netWork := NewNetWork()
                netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
                netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
@@ -638,3 +406,165 @@ func TestSendMerkleBlock(t *testing.T) {
                }
        }
 }
+
+func TestLocateBlocks(t *testing.T) {
+       maxNumOfBlocksPerMsg = 5
+       blocks := mockBlocks(nil, 100)
+       cases := []struct {
+               locator    []uint64
+               stopHash   bc.Hash
+               wantHeight []uint64
+               wantErr    error
+       }{
+               {
+                       locator:    []uint64{20},
+                       stopHash:   blocks[100].Hash(),
+                       wantHeight: []uint64{20, 21, 22, 23, 24},
+                       wantErr:    nil,
+               },
+               {
+                       locator:    []uint64{20},
+                       stopHash:   bc.NewHash([32]byte{0x01, 0x02}),
+                       wantHeight: []uint64{},
+                       wantErr:    mock.ErrFoundHeaderByHash,
+               },
+       }
+
+       mockChain := mock.NewChain(nil)
+       bk := &blockKeeper{chain: mockChain}
+       for _, block := range blocks {
+               mockChain.SetBlockByHeight(block.Height, block)
+       }
+
+       for i, c := range cases {
+               locator := []*bc.Hash{}
+               for _, i := range c.locator {
+                       hash := blocks[i].Hash()
+                       locator = append(locator, &hash)
+               }
+
+               want := []*types.Block{}
+               for _, i := range c.wantHeight {
+                       want = append(want, blocks[i])
+               }
+
+               got, err := bk.locateBlocks(locator, &c.stopHash)
+               if err != c.wantErr {
+                       t.Errorf("case %d: got %v want err = %v", i, err, c.wantErr)
+               }
+
+               if !testutil.DeepEqual(got, want) {
+                       t.Errorf("case %d: got %v want %v", i, got, want)
+               }
+       }
+}
+
+func TestLocateHeaders(t *testing.T) {
+       defer func() {
+               maxNumOfHeadersPerMsg = 1000
+       }()
+       maxNumOfHeadersPerMsg = 10
+       blocks := mockBlocks(nil, 150)
+       blocksHash := []bc.Hash{}
+       for _, block := range blocks {
+               blocksHash = append(blocksHash, block.Hash())
+       }
+
+       cases := []struct {
+               chainHeight uint64
+               locator     []uint64
+               stopHash    *bc.Hash
+               skip        uint64
+               wantHeight  []uint64
+               err         error
+       }{
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{90},
+                       stopHash:    &blocksHash[100],
+                       skip:        0,
+                       wantHeight:  []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99},
+                       err:         nil,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    &blocksHash[24],
+                       skip:        0,
+                       wantHeight:  []uint64{20, 21, 22, 23, 24},
+                       err:         nil,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    &blocksHash[20],
+                       wantHeight:  []uint64{20},
+                       err:         nil,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{20},
+                       stopHash:    &blocksHash[120],
+                       wantHeight:  []uint64{},
+                       err:         mock.ErrFoundHeaderByHash,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{120, 70},
+                       stopHash:    &blocksHash[78],
+                       wantHeight:  []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78},
+                       err:         nil,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{15},
+                       stopHash:    &blocksHash[10],
+                       skip:        10,
+                       wantHeight:  []uint64{},
+                       err:         nil,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{15},
+                       stopHash:    &blocksHash[80],
+                       skip:        10,
+                       wantHeight:  []uint64{15, 26, 37, 48, 59, 70, 80},
+                       err:         nil,
+               },
+               {
+                       chainHeight: 100,
+                       locator:     []uint64{0},
+                       stopHash:    &blocksHash[100],
+                       skip:        9,
+                       wantHeight:  []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90},
+                       err:         nil,
+               },
+       }
+
+       for i, c := range cases {
+               mockChain := mock.NewChain(nil)
+               bk := &blockKeeper{chain: mockChain}
+               for i := uint64(0); i <= c.chainHeight; i++ {
+                       mockChain.SetBlockByHeight(i, blocks[i])
+               }
+
+               locator := []*bc.Hash{}
+               for _, i := range c.locator {
+                       hash := blocks[i].Hash()
+                       locator = append(locator, &hash)
+               }
+
+               want := []*types.BlockHeader{}
+               for _, i := range c.wantHeight {
+                       want = append(want, &blocks[i].BlockHeader)
+               }
+
+               got, err := bk.locateHeaders(locator, c.stopHash, c.skip, maxNumOfHeadersPerMsg)
+               if err != c.err {
+                       t.Errorf("case %d: got %v want err = %v", i, err, c.err)
+               }
+               if !testutil.DeepEqual(got, want) {
+                       t.Errorf("case %d: got %v want %v", i, got, want)
+               }
+       }
+}