OSDN Git Service

43e6ec7a398e6e886291ee8c47344897e8ce546c
[bytom/vapor.git] / netsync / chainmgr / block_keeper_test.go
1 package chainmgr
2
3 import (
4         "container/list"
5         "encoding/json"
6         "testing"
7         "time"
8
9         "github.com/vapor/consensus"
10         "github.com/vapor/errors"
11         msgs "github.com/vapor/netsync/messages"
12         "github.com/vapor/protocol/bc"
13         "github.com/vapor/protocol/bc/types"
14         "github.com/vapor/test/mock"
15         "github.com/vapor/testutil"
16 )
17
18 func TestAppendHeaderList(t *testing.T) {
19         blocks := mockBlocks(nil, 7)
20         cases := []struct {
21                 originalHeaders []*types.BlockHeader
22                 inputHeaders    []*types.BlockHeader
23                 wantHeaders     []*types.BlockHeader
24                 err             error
25         }{
26                 {
27                         originalHeaders: []*types.BlockHeader{&blocks[0].BlockHeader},
28                         inputHeaders:    []*types.BlockHeader{&blocks[1].BlockHeader, &blocks[2].BlockHeader},
29                         wantHeaders:     []*types.BlockHeader{&blocks[0].BlockHeader, &blocks[1].BlockHeader, &blocks[2].BlockHeader},
30                         err:             nil,
31                 },
32                 {
33                         originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
34                         inputHeaders:    []*types.BlockHeader{&blocks[6].BlockHeader},
35                         wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader, &blocks[6].BlockHeader},
36                         err:             nil,
37                 },
38                 {
39                         originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
40                         inputHeaders:    []*types.BlockHeader{&blocks[7].BlockHeader},
41                         wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader},
42                         err:             errAppendHeaders,
43                 },
44                 {
45                         originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader},
46                         inputHeaders:    []*types.BlockHeader{&blocks[7].BlockHeader, &blocks[6].BlockHeader},
47                         wantHeaders:     []*types.BlockHeader{&blocks[5].BlockHeader},
48                         err:             errAppendHeaders,
49                 },
50                 {
51                         originalHeaders: []*types.BlockHeader{&blocks[2].BlockHeader},
52                         inputHeaders:    []*types.BlockHeader{&blocks[3].BlockHeader, &blocks[4].BlockHeader, &blocks[6].BlockHeader},
53                         wantHeaders:     []*types.BlockHeader{&blocks[2].BlockHeader, &blocks[3].BlockHeader, &blocks[4].BlockHeader},
54                         err:             errAppendHeaders,
55                 },
56         }
57
58         for i, c := range cases {
59                 bk := &blockKeeper{headerList: list.New()}
60                 for _, header := range c.originalHeaders {
61                         bk.headerList.PushBack(header)
62                 }
63
64                 if err := bk.appendHeaderList(c.inputHeaders); err != c.err {
65                         t.Errorf("case %d: got error %v want error %v", i, err, c.err)
66                 }
67
68                 gotHeaders := []*types.BlockHeader{}
69                 for e := bk.headerList.Front(); e != nil; e = e.Next() {
70                         gotHeaders = append(gotHeaders, e.Value.(*types.BlockHeader))
71                 }
72
73                 if !testutil.DeepEqual(gotHeaders, c.wantHeaders) {
74                         t.Errorf("case %d: got %v want %v", i, gotHeaders, c.wantHeaders)
75                 }
76         }
77 }
78
79 func TestBlockLocator(t *testing.T) {
80         blocks := mockBlocks(nil, 500)
81         cases := []struct {
82                 bestHeight uint64
83                 wantHeight []uint64
84         }{
85                 {
86                         bestHeight: 0,
87                         wantHeight: []uint64{0},
88                 },
89                 {
90                         bestHeight: 1,
91                         wantHeight: []uint64{1, 0},
92                 },
93                 {
94                         bestHeight: 7,
95                         wantHeight: []uint64{7, 6, 5, 4, 3, 2, 1, 0},
96                 },
97                 {
98                         bestHeight: 10,
99                         wantHeight: []uint64{10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
100                 },
101                 {
102                         bestHeight: 100,
103                         wantHeight: []uint64{100, 99, 98, 97, 96, 95, 94, 93, 92, 91, 89, 85, 77, 61, 29, 0},
104                 },
105                 {
106                         bestHeight: 500,
107                         wantHeight: []uint64{500, 499, 498, 497, 496, 495, 494, 493, 492, 491, 489, 485, 477, 461, 429, 365, 237, 0},
108                 },
109         }
110
111         for i, c := range cases {
112                 mockChain := mock.NewChain(nil)
113                 bk := &blockKeeper{chain: mockChain}
114                 mockChain.SetBestBlockHeader(&blocks[c.bestHeight].BlockHeader)
115                 for i := uint64(0); i <= c.bestHeight; i++ {
116                         mockChain.SetBlockByHeight(i, blocks[i])
117                 }
118
119                 want := []*bc.Hash{}
120                 for _, i := range c.wantHeight {
121                         hash := blocks[i].Hash()
122                         want = append(want, &hash)
123                 }
124
125                 if got := bk.blockLocator(); !testutil.DeepEqual(got, want) {
126                         t.Errorf("case %d: got %v want %v", i, got, want)
127                 }
128         }
129 }
130
131 func TestFastBlockSync(t *testing.T) {
132         maxBlockPerMsg = 5
133         maxBlockHeadersPerMsg = 10
134         baseChain := mockBlocks(nil, 300)
135
136         cases := []struct {
137                 syncTimeout time.Duration
138                 aBlocks     []*types.Block
139                 bBlocks     []*types.Block
140                 checkPoint  *consensus.Checkpoint
141                 want        []*types.Block
142                 err         error
143         }{
144                 {
145                         syncTimeout: 30 * time.Second,
146                         aBlocks:     baseChain[:100],
147                         bBlocks:     baseChain[:301],
148                         checkPoint: &consensus.Checkpoint{
149                                 Height: baseChain[250].Height,
150                                 Hash:   baseChain[250].Hash(),
151                         },
152                         want: baseChain[:251],
153                         err:  nil,
154                 },
155                 {
156                         syncTimeout: 30 * time.Second,
157                         aBlocks:     baseChain[:100],
158                         bBlocks:     baseChain[:301],
159                         checkPoint: &consensus.Checkpoint{
160                                 Height: baseChain[100].Height,
161                                 Hash:   baseChain[100].Hash(),
162                         },
163                         want: baseChain[:101],
164                         err:  nil,
165                 },
166                 {
167                         syncTimeout: 1 * time.Millisecond,
168                         aBlocks:     baseChain[:100],
169                         bBlocks:     baseChain[:100],
170                         checkPoint: &consensus.Checkpoint{
171                                 Height: baseChain[200].Height,
172                                 Hash:   baseChain[200].Hash(),
173                         },
174                         want: baseChain[:100],
175                         err:  errRequestTimeout,
176                 },
177         }
178
179         for i, c := range cases {
180                 syncTimeout = c.syncTimeout
181                 a := mockSync(c.aBlocks, nil)
182                 b := mockSync(c.bBlocks, nil)
183                 netWork := NewNetWork()
184                 netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
185                 netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
186                 if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
187                         t.Errorf("fail on peer hands shake %v", err)
188                 } else {
189                         go B2A.postMan()
190                         go A2B.postMan()
191                 }
192
193                 a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
194                 if err := a.blockKeeper.fastBlockSync(c.checkPoint); errors.Root(err) != c.err {
195                         t.Errorf("case %d: got %v want %v", i, err, c.err)
196                 }
197
198                 got := []*types.Block{}
199                 for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ {
200                         block, err := a.chain.GetBlockByHeight(i)
201                         if err != nil {
202                                 t.Errorf("case %d got err %v", i, err)
203                         }
204                         got = append(got, block)
205                 }
206
207                 if !testutil.DeepEqual(got, c.want) {
208                         t.Errorf("case %d: got %v want %v", i, got, c.want)
209                 }
210         }
211 }
212
213 func TestLocateBlocks(t *testing.T) {
214         maxBlockPerMsg = 5
215         blocks := mockBlocks(nil, 100)
216         cases := []struct {
217                 locator    []uint64
218                 stopHash   bc.Hash
219                 wantHeight []uint64
220         }{
221                 {
222                         locator:    []uint64{20},
223                         stopHash:   blocks[100].Hash(),
224                         wantHeight: []uint64{21, 22, 23, 24, 25},
225                 },
226         }
227
228         mockChain := mock.NewChain(nil)
229         bk := &blockKeeper{chain: mockChain}
230         for _, block := range blocks {
231                 mockChain.SetBlockByHeight(block.Height, block)
232         }
233
234         for i, c := range cases {
235                 locator := []*bc.Hash{}
236                 for _, i := range c.locator {
237                         hash := blocks[i].Hash()
238                         locator = append(locator, &hash)
239                 }
240
241                 want := []*types.Block{}
242                 for _, i := range c.wantHeight {
243                         want = append(want, blocks[i])
244                 }
245
246                 got, _ := bk.locateBlocks(locator, &c.stopHash)
247                 if !testutil.DeepEqual(got, want) {
248                         t.Errorf("case %d: got %v want %v", i, got, want)
249                 }
250         }
251 }
252
253 func TestLocateHeaders(t *testing.T) {
254         maxBlockHeadersPerMsg = 10
255         blocks := mockBlocks(nil, 150)
256         cases := []struct {
257                 chainHeight uint64
258                 locator     []uint64
259                 stopHash    bc.Hash
260                 wantHeight  []uint64
261                 err         bool
262         }{
263                 {
264                         chainHeight: 100,
265                         locator:     []uint64{},
266                         stopHash:    blocks[100].Hash(),
267                         wantHeight:  []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
268                         err:         false,
269                 },
270                 {
271                         chainHeight: 100,
272                         locator:     []uint64{20},
273                         stopHash:    blocks[100].Hash(),
274                         wantHeight:  []uint64{21, 22, 23, 24, 25, 26, 27, 28, 29, 30},
275                         err:         false,
276                 },
277                 {
278                         chainHeight: 100,
279                         locator:     []uint64{20},
280                         stopHash:    blocks[24].Hash(),
281                         wantHeight:  []uint64{21, 22, 23, 24},
282                         err:         false,
283                 },
284                 {
285                         chainHeight: 100,
286                         locator:     []uint64{20},
287                         stopHash:    blocks[20].Hash(),
288                         wantHeight:  []uint64{},
289                         err:         false,
290                 },
291                 {
292                         chainHeight: 100,
293                         locator:     []uint64{20},
294                         stopHash:    bc.Hash{},
295                         wantHeight:  []uint64{},
296                         err:         true,
297                 },
298                 {
299                         chainHeight: 100,
300                         locator:     []uint64{120, 70},
301                         stopHash:    blocks[78].Hash(),
302                         wantHeight:  []uint64{71, 72, 73, 74, 75, 76, 77, 78},
303                         err:         false,
304                 },
305         }
306
307         for i, c := range cases {
308                 mockChain := mock.NewChain(nil)
309                 bk := &blockKeeper{chain: mockChain}
310                 for i := uint64(0); i <= c.chainHeight; i++ {
311                         mockChain.SetBlockByHeight(i, blocks[i])
312                 }
313
314                 locator := []*bc.Hash{}
315                 for _, i := range c.locator {
316                         hash := blocks[i].Hash()
317                         locator = append(locator, &hash)
318                 }
319
320                 want := []*types.BlockHeader{}
321                 for _, i := range c.wantHeight {
322                         want = append(want, &blocks[i].BlockHeader)
323                 }
324
325                 got, err := bk.locateHeaders(locator, &c.stopHash)
326                 if err != nil != c.err {
327                         t.Errorf("case %d: got %v want err = %v", i, err, c.err)
328                 }
329                 if !testutil.DeepEqual(got, want) {
330                         t.Errorf("case %d: got %v want %v", i, got, want)
331                 }
332         }
333 }
334
335 func TestNextCheckpoint(t *testing.T) {
336         cases := []struct {
337                 checkPoints []consensus.Checkpoint
338                 bestHeight  uint64
339                 want        *consensus.Checkpoint
340         }{
341                 {
342                         checkPoints: []consensus.Checkpoint{},
343                         bestHeight:  5000,
344                         want:        nil,
345                 },
346                 {
347                         checkPoints: []consensus.Checkpoint{
348                                 {Height: 10000, Hash: bc.Hash{V0: 1}},
349                         },
350                         bestHeight: 5000,
351                         want:       &consensus.Checkpoint{Height: 10000, Hash: bc.Hash{V0: 1}},
352                 },
353                 {
354                         checkPoints: []consensus.Checkpoint{
355                                 {Height: 10000, Hash: bc.Hash{V0: 1}},
356                                 {Height: 20000, Hash: bc.Hash{V0: 2}},
357                                 {Height: 30000, Hash: bc.Hash{V0: 3}},
358                         },
359                         bestHeight: 15000,
360                         want:       &consensus.Checkpoint{Height: 20000, Hash: bc.Hash{V0: 2}},
361                 },
362                 {
363                         checkPoints: []consensus.Checkpoint{
364                                 {Height: 10000, Hash: bc.Hash{V0: 1}},
365                                 {Height: 20000, Hash: bc.Hash{V0: 2}},
366                                 {Height: 30000, Hash: bc.Hash{V0: 3}},
367                         },
368                         bestHeight: 10000,
369                         want:       &consensus.Checkpoint{Height: 20000, Hash: bc.Hash{V0: 2}},
370                 },
371                 {
372                         checkPoints: []consensus.Checkpoint{
373                                 {Height: 10000, Hash: bc.Hash{V0: 1}},
374                                 {Height: 20000, Hash: bc.Hash{V0: 2}},
375                                 {Height: 30000, Hash: bc.Hash{V0: 3}},
376                         },
377                         bestHeight: 35000,
378                         want:       nil,
379                 },
380         }
381
382         mockChain := mock.NewChain(nil)
383         for i, c := range cases {
384                 consensus.ActiveNetParams.Checkpoints = c.checkPoints
385                 mockChain.SetBestBlockHeader(&types.BlockHeader{Height: c.bestHeight})
386                 bk := &blockKeeper{chain: mockChain}
387
388                 if got := bk.nextCheckpoint(); !testutil.DeepEqual(got, c.want) {
389                         t.Errorf("case %d: got %v want %v", i, got, c.want)
390                 }
391         }
392 }
393
394 func TestRegularBlockSync(t *testing.T) {
395         baseChain := mockBlocks(nil, 50)
396         chainX := append(baseChain, mockBlocks(baseChain[50], 60)...)
397         chainY := append(baseChain, mockBlocks(baseChain[50], 70)...)
398         cases := []struct {
399                 syncTimeout time.Duration
400                 aBlocks     []*types.Block
401                 bBlocks     []*types.Block
402                 syncHeight  uint64
403                 want        []*types.Block
404                 err         error
405         }{
406                 {
407                         syncTimeout: 30 * time.Second,
408                         aBlocks:     baseChain[:20],
409                         bBlocks:     baseChain[:50],
410                         syncHeight:  45,
411                         want:        baseChain[:46],
412                         err:         nil,
413                 },
414                 {
415                         syncTimeout: 30 * time.Second,
416                         aBlocks:     chainX,
417                         bBlocks:     chainY,
418                         syncHeight:  70,
419                         want:        chainY,
420                         err:         nil,
421                 },
422                 {
423                         syncTimeout: 30 * time.Second,
424                         aBlocks:     chainX[:52],
425                         bBlocks:     chainY[:53],
426                         syncHeight:  52,
427                         want:        chainY[:53],
428                         err:         nil,
429                 },
430                 {
431                         syncTimeout: 1 * time.Millisecond,
432                         aBlocks:     baseChain,
433                         bBlocks:     baseChain,
434                         syncHeight:  52,
435                         want:        baseChain,
436                         err:         errRequestTimeout,
437                 },
438         }
439
440         for i, c := range cases {
441                 syncTimeout = c.syncTimeout
442                 a := mockSync(c.aBlocks, nil)
443                 b := mockSync(c.bBlocks, nil)
444                 netWork := NewNetWork()
445                 netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
446                 netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
447                 if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
448                         t.Errorf("fail on peer hands shake %v", err)
449                 } else {
450                         go B2A.postMan()
451                         go A2B.postMan()
452                 }
453
454                 a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
455                 if err := a.blockKeeper.regularBlockSync(c.syncHeight); errors.Root(err) != c.err {
456                         t.Errorf("case %d: got %v want %v", i, err, c.err)
457                 }
458
459                 got := []*types.Block{}
460                 for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ {
461                         block, err := a.chain.GetBlockByHeight(i)
462                         if err != nil {
463                                 t.Errorf("case %d got err %v", i, err)
464                         }
465                         got = append(got, block)
466                 }
467
468                 if !testutil.DeepEqual(got, c.want) {
469                         t.Errorf("case %d: got %v want %v", i, got, c.want)
470                 }
471         }
472 }
473
474 func TestRequireBlock(t *testing.T) {
475         blocks := mockBlocks(nil, 5)
476         a := mockSync(blocks[:1], nil)
477         b := mockSync(blocks[:5], nil)
478         netWork := NewNetWork()
479         netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
480         netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
481         if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
482                 t.Errorf("fail on peer hands shake %v", err)
483         } else {
484                 go B2A.postMan()
485                 go A2B.postMan()
486         }
487
488         a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
489         b.blockKeeper.syncPeer = b.peers.GetPeer("test node A")
490         cases := []struct {
491                 syncTimeout   time.Duration
492                 testNode      *Manager
493                 requireHeight uint64
494                 want          *types.Block
495                 err           error
496         }{
497                 {
498                         syncTimeout:   30 * time.Second,
499                         testNode:      a,
500                         requireHeight: 4,
501                         want:          blocks[4],
502                         err:           nil,
503                 },
504                 {
505                         syncTimeout:   1 * time.Millisecond,
506                         testNode:      b,
507                         requireHeight: 4,
508                         want:          nil,
509                         err:           errRequestTimeout,
510                 },
511         }
512
513         for i, c := range cases {
514                 syncTimeout = c.syncTimeout
515                 got, err := c.testNode.blockKeeper.requireBlock(c.requireHeight)
516                 if !testutil.DeepEqual(got, c.want) {
517                         t.Errorf("case %d: got %v want %v", i, got, c.want)
518                 }
519                 if errors.Root(err) != c.err {
520                         t.Errorf("case %d: got %v want %v", i, err, c.err)
521                 }
522         }
523 }
524
525 func TestSendMerkleBlock(t *testing.T) {
526         cases := []struct {
527                 txCount        int
528                 relatedTxIndex []int
529         }{
530                 {
531                         txCount:        10,
532                         relatedTxIndex: []int{0, 2, 5},
533                 },
534                 {
535                         txCount:        0,
536                         relatedTxIndex: []int{},
537                 },
538                 {
539                         txCount:        10,
540                         relatedTxIndex: []int{},
541                 },
542                 {
543                         txCount:        5,
544                         relatedTxIndex: []int{0, 1, 2, 3, 4},
545                 },
546                 {
547                         txCount:        20,
548                         relatedTxIndex: []int{1, 6, 3, 9, 10, 19},
549                 },
550         }
551
552         for _, c := range cases {
553                 blocks := mockBlocks(nil, 2)
554                 targetBlock := blocks[1]
555                 txs, bcTxs := mockTxs(c.txCount)
556                 var err error
557
558                 targetBlock.Transactions = txs
559                 if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil {
560                         t.Fatal(err)
561                 }
562
563                 spvNode := mockSync(blocks, nil)
564                 blockHash := targetBlock.Hash()
565                 var statusResult *bc.TransactionStatus
566                 if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
567                         t.Fatal(err)
568                 }
569
570                 if targetBlock.TransactionStatusHash, err = types.TxStatusMerkleRoot(statusResult.VerifyStatus); err != nil {
571                         t.Fatal(err)
572                 }
573
574                 fullNode := mockSync(blocks, nil)
575                 netWork := NewNetWork()
576                 netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
577                 netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
578
579                 var F2S *P2PPeer
580                 if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil {
581                         t.Errorf("fail on peer hands shake %v", err)
582                 }
583
584                 completed := make(chan error)
585                 go func() {
586                         msgBytes := <-F2S.msgCh
587                         _, msg, _ := decodeMessage(msgBytes)
588                         switch m := msg.(type) {
589                         case *msgs.MerkleBlockMessage:
590                                 var relatedTxIDs []*bc.Hash
591                                 for _, rawTx := range m.RawTxDatas {
592                                         tx := &types.Tx{}
593                                         if err := tx.UnmarshalText(rawTx); err != nil {
594                                                 completed <- err
595                                         }
596
597                                         relatedTxIDs = append(relatedTxIDs, &tx.ID)
598                                 }
599                                 var txHashes []*bc.Hash
600                                 for _, hashByte := range m.TxHashes {
601                                         hash := bc.NewHash(hashByte)
602                                         txHashes = append(txHashes, &hash)
603                                 }
604                                 if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok {
605                                         completed <- errors.New("validate tx fail")
606                                 }
607
608                                 var statusHashes []*bc.Hash
609                                 for _, statusByte := range m.StatusHashes {
610                                         hash := bc.NewHash(statusByte)
611                                         statusHashes = append(statusHashes, &hash)
612                                 }
613                                 var relatedStatuses []*bc.TxVerifyResult
614                                 for _, statusByte := range m.RawTxStatuses {
615                                         status := &bc.TxVerifyResult{}
616                                         err := json.Unmarshal(statusByte, status)
617                                         if err != nil {
618                                                 completed <- err
619                                         }
620                                         relatedStatuses = append(relatedStatuses, status)
621                                 }
622                                 if ok := types.ValidateStatusMerkleTreeProof(statusHashes, m.Flags, relatedStatuses, targetBlock.TransactionStatusHash); !ok {
623                                         completed <- errors.New("validate status fail")
624                                 }
625
626                                 completed <- nil
627                         }
628                 }()
629
630                 spvPeer := fullNode.peers.GetPeer("spv_node")
631                 for i := 0; i < len(c.relatedTxIndex); i++ {
632                         spvPeer.AddFilterAddress(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram())
633                 }
634                 msg := &msgs.GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
635                 fullNode.handleGetMerkleBlockMsg(spvPeer, msg)
636                 if err := <-completed; err != nil {
637                         t.Fatal(err)
638                 }
639         }
640 }