+
+func TestLocateBlocks(t *testing.T) {
+ maxNumOfBlocksPerMsg = 5
+ blocks := mockBlocks(nil, 100)
+ cases := []struct {
+ locator []uint64
+ stopHash bc.Hash
+ wantHeight []uint64
+ }{
+ {
+ locator: []uint64{20},
+ stopHash: blocks[100].Hash(),
+ wantHeight: []uint64{20, 21, 22, 23, 24},
+ },
+ }
+
+ mockChain := mock.NewChain(nil)
+ bk := &blockKeeper{chain: mockChain}
+ for _, block := range blocks {
+ mockChain.SetBlockByHeight(block.Height, block)
+ }
+
+ for i, c := range cases {
+ locator := []*bc.Hash{}
+ for _, i := range c.locator {
+ hash := blocks[i].Hash()
+ locator = append(locator, &hash)
+ }
+
+ want := []*types.Block{}
+ for _, i := range c.wantHeight {
+ want = append(want, blocks[i])
+ }
+
+ got, _ := bk.locateBlocks(locator, &c.stopHash)
+ if !testutil.DeepEqual(got, want) {
+ t.Errorf("case %d: got %v want %v", i, got, want)
+ }
+ }
+}
+
+func TestLocateHeaders(t *testing.T) {
+ defer func() {
+ maxNumOfHeadersPerMsg = 1000
+ }()
+ maxNumOfHeadersPerMsg = 10
+ blocks := mockBlocks(nil, 150)
+ blocksHash := []bc.Hash{}
+ for _, block := range blocks {
+ blocksHash = append(blocksHash, block.Hash())
+ }
+
+ cases := []struct {
+ chainHeight uint64
+ locator []uint64
+ stopHash *bc.Hash
+ skip uint64
+ wantHeight []uint64
+ err bool
+ }{
+ {
+ chainHeight: 100,
+ locator: []uint64{90},
+ stopHash: &blocksHash[100],
+ skip: 0,
+ wantHeight: []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99},
+ err: false,
+ },
+ {
+ chainHeight: 100,
+ locator: []uint64{20},
+ stopHash: &blocksHash[24],
+ skip: 0,
+ wantHeight: []uint64{20, 21, 22, 23, 24},
+ err: false,
+ },
+ {
+ chainHeight: 100,
+ locator: []uint64{20},
+ stopHash: &blocksHash[20],
+ wantHeight: []uint64{20},
+ err: false,
+ },
+ {
+ chainHeight: 100,
+ locator: []uint64{20},
+ stopHash: &blocksHash[120],
+ wantHeight: []uint64{},
+ err: false,
+ },
+ {
+ chainHeight: 100,
+ locator: []uint64{120, 70},
+ stopHash: &blocksHash[78],
+ wantHeight: []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78},
+ err: false,
+ },
+ {
+ chainHeight: 100,
+ locator: []uint64{15},
+ stopHash: &blocksHash[10],
+ skip: 10,
+ wantHeight: []uint64{},
+ err: false,
+ },
+ {
+ chainHeight: 100,
+ locator: []uint64{15},
+ stopHash: &blocksHash[80],
+ skip: 10,
+ wantHeight: []uint64{15, 26, 37, 48, 59, 70, 80},
+ err: false,
+ },
+ {
+ chainHeight: 100,
+ locator: []uint64{0},
+ stopHash: &blocksHash[100],
+ skip: 9,
+ wantHeight: []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90},
+ err: false,
+ },
+ }
+
+ for i, c := range cases {
+ mockChain := mock.NewChain(nil)
+ bk := &blockKeeper{chain: mockChain}
+ for i := uint64(0); i <= c.chainHeight; i++ {
+ mockChain.SetBlockByHeight(i, blocks[i])
+ }
+
+ locator := []*bc.Hash{}
+ for _, i := range c.locator {
+ hash := blocks[i].Hash()
+ locator = append(locator, &hash)
+ }
+
+ want := []*types.BlockHeader{}
+ for _, i := range c.wantHeight {
+ want = append(want, &blocks[i].BlockHeader)
+ }
+
+ got, err := bk.locateHeaders(locator, c.stopHash, c.skip, maxNumOfHeadersPerMsg)
+ if err != nil != c.err {
+ t.Errorf("case %d: got %v want err = %v", i, err, c.err)
+ }
+ if !testutil.DeepEqual(got, want) {
+ t.Errorf("case %d: got %v want %v", i, got, want)
+ }
+ }
+}