package chainmgr

import (
	"container/list"
	"encoding/json"
	"testing"
	"time"

	"github.com/vapor/consensus"
	"github.com/vapor/errors"
	msgs "github.com/vapor/netsync/messages"
	"github.com/vapor/protocol/bc"
	"github.com/vapor/protocol/bc/types"
	"github.com/vapor/test/mock"
	"github.com/vapor/testutil"
)

// TestAppendHeaderList checks blockKeeper.appendHeaderList: input headers that
// continue the existing list are appended; the first header that does not
// continue it stops the append with errAppendHeaders, and any headers already
// appended before that point stay in the list.
func TestAppendHeaderList(t *testing.T) {
	blocks := mockBlocks(nil, 7)
	cases := []struct {
		originalHeaders []*types.BlockHeader
		inputHeaders    []*types.BlockHeader
		wantHeaders     []*types.BlockHeader
		err             error
	}{
		{
			originalHeaders: []*types.BlockHeader{&blocks[0].BlockHeader},
			inputHeaders:    []*types.BlockHeader{&blocks[1].BlockHeader, &blocks[2].BlockHeader},
			wantHeaders:     []*types.BlockHeader{&blocks[0].BlockHeader, &blocks[1].BlockHeader, &blocks[2].BlockHeader},
			err:             nil,
		},
		{
			originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
			inputHeaders:    []*types.BlockHeader{&blocks[6].BlockHeader},
			wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader, &blocks[6].BlockHeader},
			err:             nil,
		},
		{
			originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
			inputHeaders:    []*types.BlockHeader{&blocks[7].BlockHeader},
			wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader},
			err:             errAppendHeaders,
		},
		{
			originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
			inputHeaders:    []*types.BlockHeader{&blocks[7].BlockHeader, &blocks[6].BlockHeader},
			wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader},
			err:             errAppendHeaders,
		},
		{
			originalHeaders: []*types.BlockHeader{&blocks[2].BlockHeader},
			inputHeaders:    []*types.BlockHeader{&blocks[3].BlockHeader, &blocks[4].BlockHeader, &blocks[6].BlockHeader},
			wantHeaders:     []*types.BlockHeader{&blocks[2].BlockHeader, &blocks[3].BlockHeader, &blocks[4].BlockHeader},
			err:             errAppendHeaders,
		},
	}

	for i, c := range cases {
		bk := &blockKeeper{headerList: list.New()}
		for _, header := range c.originalHeaders {
			bk.headerList.PushBack(header)
		}

		if err := bk.appendHeaderList(c.inputHeaders); err != c.err {
			t.Errorf("case %d: got error %v want error %v", i, err, c.err)
		}

		gotHeaders := []*types.BlockHeader{}
		for e := bk.headerList.Front(); e != nil; e = e.Next() {
			gotHeaders = append(gotHeaders, e.Value.(*types.BlockHeader))
		}

		if !testutil.DeepEqual(gotHeaders, c.wantHeaders) {
			t.Errorf("case %d: got %v want %v", i, gotHeaders, c.wantHeaders)
		}
	}
}

// TestBlockLocator checks the block hashes returned by blockKeeper.blockLocator
// for several best heights. The expected heights list the most recent blocks
// one by one and then step back with growing gaps, always ending at genesis.
func TestBlockLocator(t *testing.T) {
	blocks := mockBlocks(nil, 500)
	cases := []struct {
		bestHeight uint64
		wantHeight []uint64
	}{
		{
			bestHeight: 0,
			wantHeight: []uint64{0},
		},
		{
			bestHeight: 1,
			wantHeight: []uint64{1, 0},
		},
		{
			bestHeight: 7,
			wantHeight: []uint64{7, 6, 5, 4, 3, 2, 1, 0},
		},
		{
			bestHeight: 10,
			wantHeight: []uint64{10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
		},
		{
			bestHeight: 100,
			wantHeight: []uint64{100, 99, 98, 97, 96, 95, 94, 93, 92, 91, 89, 85, 77, 61, 29, 0},
		},
		{
			bestHeight: 500,
			wantHeight: []uint64{500, 499, 498, 497, 496, 495, 494, 493, 492, 491, 489, 485, 477, 461, 429, 365, 237, 0},
		},
	}

	for i, c := range cases {
		mockChain := mock.NewChain()
		bk := &blockKeeper{chain: mockChain}
		mockChain.SetBestBlockHeader(&blocks[c.bestHeight].BlockHeader)
		for height := uint64(0); height <= c.bestHeight; height++ {
			mockChain.SetBlockByHeight(height, blocks[height])
		}

		want := []*bc.Hash{}
		for _, height := range c.wantHeight {
			hash := blocks[height].Hash()
			want = append(want, &hash)
		}

		if got := bk.blockLocator(); !testutil.DeepEqual(got, want) {
			t.Errorf("case %d: got %v want %v", i, got, want)
		}
	}
}

// TestFastBlockSync connects node A to node B and runs fastBlockSync towards a
// checkpoint, then compares A's resulting chain with the expected blocks. When
// B does not have the checkpoint block and the timeout is tiny, the sync is
// expected to fail with errRequestTimeout.
func TestFastBlockSync(t *testing.T) {
	// Restore the package-level tunables overridden below so that other tests
	// in the package do not depend on execution order.
	origMaxBlockPerMsg, origMaxBlockHeadersPerMsg, origSyncTimeout := maxBlockPerMsg, maxBlockHeadersPerMsg, syncTimeout
	defer func() {
		maxBlockPerMsg = origMaxBlockPerMsg
		maxBlockHeadersPerMsg = origMaxBlockHeadersPerMsg
		syncTimeout = origSyncTimeout
	}()

	maxBlockPerMsg = 5
	maxBlockHeadersPerMsg = 10
	baseChain := mockBlocks(nil, 300)

	cases := []struct {
		syncTimeout time.Duration
		aBlocks     []*types.Block
		bBlocks     []*types.Block
		checkPoint  *consensus.Checkpoint
		want        []*types.Block
		err         error
	}{
		{
			syncTimeout: 30 * time.Second,
			aBlocks:     baseChain[:100],
			bBlocks:     baseChain[:301],
			checkPoint: &consensus.Checkpoint{
				Height: baseChain[250].Height,
				Hash:   baseChain[250].Hash(),
			},
			want: baseChain[:251],
			err:  nil,
		},
		{
			syncTimeout: 30 * time.Second,
			aBlocks:     baseChain[:100],
			bBlocks:     baseChain[:301],
			checkPoint: &consensus.Checkpoint{
				Height: baseChain[100].Height,
				Hash:   baseChain[100].Hash(),
			},
			want: baseChain[:101],
			err:  nil,
		},
		{
			syncTimeout: 1 * time.Millisecond,
			aBlocks:     baseChain[:100],
			bBlocks:     baseChain[:100],
			checkPoint: &consensus.Checkpoint{
				Height: baseChain[200].Height,
				Hash:   baseChain[200].Hash(),
			},
			want: baseChain[:100],
			err:  errRequestTimeout,
		},
	}

	for i, c := range cases {
		syncTimeout = c.syncTimeout
		a := mockSync(c.aBlocks)
		b := mockSync(c.bBlocks)
		netWork := NewNetWork()
		netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
		netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)

		// Without a handshake the peers below are unusable; stop immediately.
		b2a, a2b, err := netWork.HandsShake(a, b)
		if err != nil {
			t.Fatalf("case %d: fail on peer hands shake %v", i, err)
		}
		go b2a.postMan()
		go a2b.postMan()

		a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
		if err := a.blockKeeper.fastBlockSync(c.checkPoint); errors.Root(err) != c.err {
			t.Errorf("case %d: got %v want %v", i, err, c.err)
		}

		got := []*types.Block{}
		for height := uint64(0); height <= a.chain.BestBlockHeight(); height++ {
			block, err := a.chain.GetBlockByHeight(height)
			if err != nil {
				t.Errorf("case %d: get block at height %d: %v", i, height, err)
			}
			got = append(got, block)
		}

		if !testutil.DeepEqual(got, c.want) {
			t.Errorf("case %d: got %v want %v", i, got, c.want)
		}
	}
}

// TestLocateBlocks checks that locateBlocks returns at most maxBlockPerMsg
// blocks following the locator, bounded by the stop hash.
func TestLocateBlocks(t *testing.T) {
	origMaxBlockPerMsg := maxBlockPerMsg
	defer func() { maxBlockPerMsg = origMaxBlockPerMsg }()

	maxBlockPerMsg = 5
	blocks := mockBlocks(nil, 100)
	cases := []struct {
		locator    []uint64
		stopHash   bc.Hash
		wantHeight []uint64
	}{
		{
			locator:    []uint64{20},
			stopHash:   blocks[100].Hash(),
			wantHeight: []uint64{21, 22, 23, 24, 25},
		},
	}

	mockChain := mock.NewChain()
	bk := &blockKeeper{chain: mockChain}
	for _, block := range blocks {
		mockChain.SetBlockByHeight(block.Height, block)
	}

	for i, c := range cases {
		locator := []*bc.Hash{}
		for _, height := range c.locator {
			hash := blocks[height].Hash()
			locator = append(locator, &hash)
		}

		want := []*types.Block{}
		for _, height := range c.wantHeight {
			want = append(want, blocks[height])
		}

		got, err := bk.locateBlocks(locator, &c.stopHash)
		if err != nil {
			t.Errorf("case %d: unexpected error %v", i, err)
		}
		if !testutil.DeepEqual(got, want) {
			t.Errorf("case %d: got %v want %v", i, got, want)
		}
	}
}

// TestLocateHeaders checks locateHeaders against locators and stop hashes:
// the result starts after the first locator hash found on the chain, is
// capped at maxBlockHeadersPerMsg headers, stops at the stop hash, and an
// unknown stop hash is reported as an error.
func TestLocateHeaders(t *testing.T) {
	origMaxBlockHeadersPerMsg := maxBlockHeadersPerMsg
	defer func() { maxBlockHeadersPerMsg = origMaxBlockHeadersPerMsg }()

	maxBlockHeadersPerMsg = 10
	blocks := mockBlocks(nil, 150)
	cases := []struct {
		chainHeight uint64
		locator     []uint64
		stopHash    bc.Hash
		wantHeight  []uint64
		err         bool
	}{
		{
			chainHeight: 100,
			locator:     []uint64{},
			stopHash:    blocks[100].Hash(),
			wantHeight:  []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
			err:         false,
		},
		{
			chainHeight: 100,
			locator:     []uint64{20},
			stopHash:    blocks[100].Hash(),
			wantHeight:  []uint64{21, 22, 23, 24, 25, 26, 27, 28, 29, 30},
			err:         false,
		},
		{
			chainHeight: 100,
			locator:     []uint64{20},
			stopHash:    blocks[24].Hash(),
			wantHeight:  []uint64{21, 22, 23, 24},
			err:         false,
		},
		{
			chainHeight: 100,
			locator:     []uint64{20},
			stopHash:    blocks[20].Hash(),
			wantHeight:  []uint64{},
			err:         false,
		},
		{
			chainHeight: 100,
			locator:     []uint64{20},
			stopHash:    bc.Hash{},
			wantHeight:  []uint64{},
			err:         true,
		},
		{
			chainHeight: 100,
			locator:     []uint64{120, 70},
			stopHash:    blocks[78].Hash(),
			wantHeight:  []uint64{71, 72, 73, 74, 75, 76, 77, 78},
			err:         false,
		},
	}

	for i, c := range cases {
		mockChain := mock.NewChain()
		bk := &blockKeeper{chain: mockChain}
		for height := uint64(0); height <= c.chainHeight; height++ {
			mockChain.SetBlockByHeight(height, blocks[height])
		}

		locator := []*bc.Hash{}
		for _, height := range c.locator {
			hash := blocks[height].Hash()
			locator = append(locator, &hash)
		}

		want := []*types.BlockHeader{}
		for _, height := range c.wantHeight {
			want = append(want, &blocks[height].BlockHeader)
		}

		got, err := bk.locateHeaders(locator, &c.stopHash)
		if (err != nil) != c.err {
			t.Errorf("case %d: got %v want err = %v", i, err, c.err)
		}
		if !testutil.DeepEqual(got, want) {
			t.Errorf("case %d: got %v want %v", i, got, want)
		}
	}
}

// TestNextCheckpoint checks that nextCheckpoint returns the first configured
// checkpoint strictly above the best height, or nil when there is none.
func TestNextCheckpoint(t *testing.T) {
	// The cases overwrite the global checkpoint list; put the original list
	// back afterwards so other tests are unaffected.
	origCheckpoints := consensus.ActiveNetParams.Checkpoints
	defer func() { consensus.ActiveNetParams.Checkpoints = origCheckpoints }()

	cases := []struct {
		checkPoints []consensus.Checkpoint
		bestHeight  uint64
		want        *consensus.Checkpoint
	}{
		{
			checkPoints: []consensus.Checkpoint{},
			bestHeight:  5000,
			want:        nil,
		},
		{
			checkPoints: []consensus.Checkpoint{
				{Height: 10000, Hash: bc.Hash{V0: 1}},
			},
			bestHeight: 5000,
			want:       &consensus.Checkpoint{Height: 10000, Hash: bc.Hash{V0: 1}},
		},
		{
			checkPoints: []consensus.Checkpoint{
				{Height: 10000, Hash: bc.Hash{V0: 1}},
				{Height: 20000, Hash: bc.Hash{V0: 2}},
				{Height: 30000, Hash: bc.Hash{V0: 3}},
			},
			bestHeight: 15000,
			want:       &consensus.Checkpoint{Height: 20000, Hash: bc.Hash{V0: 2}},
		},
		{
			checkPoints: []consensus.Checkpoint{
				{Height: 10000, Hash: bc.Hash{V0: 1}},
				{Height: 20000, Hash: bc.Hash{V0: 2}},
				{Height: 30000, Hash: bc.Hash{V0: 3}},
			},
			bestHeight: 10000,
			want:       &consensus.Checkpoint{Height: 20000, Hash: bc.Hash{V0: 2}},
		},
		{
			checkPoints: []consensus.Checkpoint{
				{Height: 10000, Hash: bc.Hash{V0: 1}},
				{Height: 20000, Hash: bc.Hash{V0: 2}},
				{Height: 30000, Hash: bc.Hash{V0: 3}},
			},
			bestHeight: 35000,
			want:       nil,
		},
	}

	mockChain := mock.NewChain()
	for i, c := range cases {
		consensus.ActiveNetParams.Checkpoints = c.checkPoints
		mockChain.SetBestBlockHeader(&types.BlockHeader{Height: c.bestHeight})
		bk := &blockKeeper{chain: mockChain}

		if got := bk.nextCheckpoint(); !testutil.DeepEqual(got, c.want) {
			t.Errorf("case %d: got %v want %v", i, got, c.want)
		}
	}
}

// TestRegularBlockSync connects node A to node B and runs regularBlockSync up
// to syncHeight, then compares A's resulting chain with the expected blocks.
// chainX and chainY share baseChain and then continue with separately mocked
// blocks (presumably forks built on top of baseChain[50]).
func TestRegularBlockSync(t *testing.T) {
	origSyncTimeout := syncTimeout
	defer func() { syncTimeout = origSyncTimeout }()

	baseChain := mockBlocks(nil, 50)
	// Cap the base slice so that each append copies into a fresh backing array.
	// A plain append(baseChain, ...) could write into baseChain's spare capacity
	// and make chainX and chainY share, and clobber, their tail elements.
	base := baseChain[:len(baseChain):len(baseChain)]
	chainX := append(base, mockBlocks(baseChain[50], 60)...)
	chainY := append(base, mockBlocks(baseChain[50], 70)...)

	cases := []struct {
		syncTimeout time.Duration
		aBlocks     []*types.Block
		bBlocks     []*types.Block
		syncHeight  uint64
		want        []*types.Block
		err         error
	}{
		{
			syncTimeout: 30 * time.Second,
			aBlocks:     baseChain[:20],
			bBlocks:     baseChain[:50],
			syncHeight:  45,
			want:        baseChain[:46],
			err:         nil,
		},
		{
			syncTimeout: 30 * time.Second,
			aBlocks:     chainX,
			bBlocks:     chainY,
			syncHeight:  70,
			want:        chainY,
			err:         nil,
		},
		{
			syncTimeout: 30 * time.Second,
			aBlocks:     chainX[:52],
			bBlocks:     chainY[:53],
			syncHeight:  52,
			want:        chainY[:53],
			err:         nil,
		},
		{
			syncTimeout: 1 * time.Millisecond,
			aBlocks:     baseChain,
			bBlocks:     baseChain,
			syncHeight:  52,
			want:        baseChain,
			err:         errRequestTimeout,
		},
	}

	for i, c := range cases {
		syncTimeout = c.syncTimeout
		a := mockSync(c.aBlocks)
		b := mockSync(c.bBlocks)
		netWork := NewNetWork()
		netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
		netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)

		// Without a handshake the peers below are unusable; stop immediately.
		b2a, a2b, err := netWork.HandsShake(a, b)
		if err != nil {
			t.Fatalf("case %d: fail on peer hands shake %v", i, err)
		}
		go b2a.postMan()
		go a2b.postMan()

		a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
		if err := a.blockKeeper.regularBlockSync(c.syncHeight); errors.Root(err) != c.err {
			t.Errorf("case %d: got %v want %v", i, err, c.err)
		}

		got := []*types.Block{}
		for height := uint64(0); height <= a.chain.BestBlockHeight(); height++ {
			block, err := a.chain.GetBlockByHeight(height)
			if err != nil {
				t.Errorf("case %d: get block at height %d: %v", i, height, err)
			}
			got = append(got, block)
		}

		if !testutil.DeepEqual(got, c.want) {
			t.Errorf("case %d: got %v want %v", i, got, c.want)
		}
	}
}

// TestRequireBlock checks that requireBlock fetches a block from the sync peer
// when the peer has it, and times out with errRequestTimeout when it does not.
func TestRequireBlock(t *testing.T) {
	origSyncTimeout := syncTimeout
	defer func() { syncTimeout = origSyncTimeout }()

	blocks := mockBlocks(nil, 5)
	a := mockSync(blocks[:1])
	b := mockSync(blocks[:5])
	netWork := NewNetWork()
	netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
	netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)

	// Without a handshake the peers below are unusable; stop immediately.
	b2a, a2b, err := netWork.HandsShake(a, b)
	if err != nil {
		t.Fatalf("fail on peer hands shake %v", err)
	}
	go b2a.postMan()
	go a2b.postMan()

	a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
	b.blockKeeper.syncPeer = b.peers.GetPeer("test node A")
	cases := []struct {
		syncTimeout   time.Duration
		testNode      *Manager
		requireHeight uint64
		want          *types.Block
		err           error
	}{
		{
			syncTimeout:   30 * time.Second,
			testNode:      a,
			requireHeight: 4,
			want:          blocks[4],
			err:           nil,
		},
		{
			syncTimeout:   1 * time.Millisecond,
			testNode:      b,
			requireHeight: 4,
			want:          nil,
			err:           errRequestTimeout,
		},
	}

	for i, c := range cases {
		syncTimeout = c.syncTimeout
		got, err := c.testNode.blockKeeper.requireBlock(c.requireHeight)
		if !testutil.DeepEqual(got, c.want) {
			t.Errorf("case %d: got %v want %v", i, got, c.want)
		}
		if errors.Root(err) != c.err {
			t.Errorf("case %d: got %v want %v", i, err, c.err)
		}
	}
}

// TestSendMerkleBlock has the full node answer a GetMerkleBlockMessage from an
// SPV peer. It checks that the reply is a MerkleBlockMessage whose transaction
// and status Merkle proofs validate against the target block's roots.
func TestSendMerkleBlock(t *testing.T) {
	cases := []struct {
		txCount        int
		relatedTxIndex []int
	}{
		{
			txCount:        10,
			relatedTxIndex: []int{0, 2, 5},
		},
		{
			txCount:        0,
			relatedTxIndex: []int{},
		},
		{
			txCount:        10,
			relatedTxIndex: []int{},
		},
		{
			txCount:        5,
			relatedTxIndex: []int{0, 1, 2, 3, 4},
		},
		{
			txCount:        20,
			relatedTxIndex: []int{1, 6, 3, 9, 10, 19},
		},
	}

	for i, c := range cases {
		blocks := mockBlocks(nil, 2)
		targetBlock := blocks[1]
		txs, bcTxs := mockTxs(c.txCount)
		var err error

		targetBlock.Transactions = txs
		if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil {
			t.Fatal(err)
		}

		spvNode := mockSync(blocks)
		blockHash := targetBlock.Hash()
		var statusResult *bc.TransactionStatus
		if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
			t.Fatal(err)
		}

		if targetBlock.TransactionStatusHash, err = types.TxStatusMerkleRoot(statusResult.VerifyStatus); err != nil {
			t.Fatal(err)
		}

		fullNode := mockSync(blocks)
		netWork := NewNetWork()
		netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
		netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)

		// F2S must be valid before it is read from below; a nil peer would
		// make the receive block forever.
		var F2S *P2PPeer
		if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil {
			t.Fatalf("case %d: fail on peer hands shake %v", i, err)
		}

		// checkMsg waits for one message on F2S and returns nil only if it is
		// a MerkleBlockMessage whose proofs validate. Every failure path
		// returns, so exactly one result is produced.
		checkMsg := func() error {
			msgBytes := <-F2S.msgCh
			_, msg, err := decodeMessage(msgBytes)
			if err != nil {
				return err
			}

			m, ok := msg.(*msgs.MerkleBlockMessage)
			if !ok {
				return errors.New("unexpected message type")
			}

			var relatedTxIDs []*bc.Hash
			for _, rawTx := range m.RawTxDatas {
				tx := &types.Tx{}
				if err := tx.UnmarshalText(rawTx); err != nil {
					return err
				}
				relatedTxIDs = append(relatedTxIDs, &tx.ID)
			}

			var txHashes []*bc.Hash
			for _, hashByte := range m.TxHashes {
				hash := bc.NewHash(hashByte)
				txHashes = append(txHashes, &hash)
			}
			if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok {
				return errors.New("validate tx fail")
			}

			var statusHashes []*bc.Hash
			for _, statusByte := range m.StatusHashes {
				hash := bc.NewHash(statusByte)
				statusHashes = append(statusHashes, &hash)
			}

			var relatedStatuses []*bc.TxVerifyResult
			for _, statusByte := range m.RawTxStatuses {
				status := &bc.TxVerifyResult{}
				if err := json.Unmarshal(statusByte, status); err != nil {
					return err
				}
				relatedStatuses = append(relatedStatuses, status)
			}

			if ok := types.ValidateStatusMerkleTreeProof(statusHashes, m.Flags, relatedStatuses, targetBlock.TransactionStatusHash); !ok {
				return errors.New("validate status fail")
			}
			return nil
		}

		// Buffered so the checker goroutine can always deliver its single
		// result and exit, even if the test has already given up on it.
		completed := make(chan error, 1)
		go func() { completed <- checkMsg() }()

		spvPeer := fullNode.peers.GetPeer("spv_node")
		for _, txIndex := range c.relatedTxIndex {
			spvPeer.AddFilterAddress(txs[txIndex].Outputs[0].ControlProgram())
		}

		msg := &msgs.GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
		fullNode.handleGetMerkleBlockMsg(spvPeer, msg)

		// Guard against a reply that never arrives turning into a hung test.
		select {
		case err := <-completed:
			if err != nil {
				t.Fatalf("case %d: %v", i, err)
			}
		case <-time.After(10 * time.Second):
			t.Fatalf("case %d: timed out waiting for merkle block message", i)
		}
	}
}