OSDN Git Service

netsync add test case (#365)
[bytom/vapor.git] / netsync / chainmgr / block_keeper_test.go
1 package chainmgr
2
3 import (
4         "encoding/json"
5         "io/ioutil"
6         "os"
7         "testing"
8         "time"
9
10         "github.com/vapor/consensus"
11         dbm "github.com/vapor/database/leveldb"
12         "github.com/vapor/errors"
13         msgs "github.com/vapor/netsync/messages"
14         "github.com/vapor/netsync/peers"
15         "github.com/vapor/protocol"
16         "github.com/vapor/protocol/bc"
17         "github.com/vapor/protocol/bc/types"
18         "github.com/vapor/test/mock"
19         "github.com/vapor/testutil"
20 )
21
22 func TestCheckSyncType(t *testing.T) {
23         tmp, err := ioutil.TempDir(".", "")
24         if err != nil {
25                 t.Fatalf("failed to create temporary data folder: %v", err)
26         }
27         fastSyncDB := dbm.NewDB("testdb", "leveldb", tmp)
28         defer func() {
29                 fastSyncDB.Close()
30                 os.RemoveAll(tmp)
31         }()
32
33         blocks := mockBlocks(nil, 50)
34         chain := mock.NewChain(nil)
35         chain.SetBestBlockHeader(&blocks[len(blocks)-1].BlockHeader)
36         for _, block := range blocks {
37                 chain.SetBlockByHeight(block.Height, block)
38         }
39
40         type syncPeer struct {
41                 peer               *P2PPeer
42                 bestHeight         uint64
43                 irreversibleHeight uint64
44         }
45
46         cases := []struct {
47                 peers    []*syncPeer
48                 syncType int
49         }{
50                 {
51                         peers:    []*syncPeer{},
52                         syncType: noNeedSync,
53                 },
54                 {
55                         peers: []*syncPeer{
56                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 500},
57                                 {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 50, irreversibleHeight: 50},
58                         },
59                         syncType: fastSyncType,
60                 },
61                 {
62                         peers: []*syncPeer{
63                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 100},
64                                 {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 500, irreversibleHeight: 50},
65                         },
66                         syncType: regularSyncType,
67                 },
68                 {
69                         peers: []*syncPeer{
70                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 51, irreversibleHeight: 50},
71                         },
72                         syncType: regularSyncType,
73                 },
74                 {
75                         peers: []*syncPeer{
76                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 30, irreversibleHeight: 30},
77                         },
78                         syncType: noNeedSync,
79                 },
80                 {
81                         peers: []*syncPeer{
82                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode}, bestHeight: 1000, irreversibleHeight: 1000},
83                         },
84                         syncType: regularSyncType,
85                 },
86                 {
87                         peers: []*syncPeer{
88                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 50},
89                                 {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 800, irreversibleHeight: 800},
90                         },
91                         syncType: fastSyncType,
92                 },
93         }
94
95         for i, c := range cases {
96                 peers := peers.NewPeerSet(NewPeerSet())
97                 blockKeeper := newBlockKeeper(chain, peers, fastSyncDB)
98                 for _, syncPeer := range c.peers {
99                         blockKeeper.peers.AddPeer(syncPeer.peer)
100                         blockKeeper.peers.SetStatus(syncPeer.peer.id, syncPeer.bestHeight, nil)
101                         blockKeeper.peers.SetIrreversibleStatus(syncPeer.peer.id, syncPeer.irreversibleHeight, nil)
102                 }
103                 gotType := blockKeeper.checkSyncType()
104                 if c.syncType != gotType {
105                         t.Errorf("case %d: got %d want %d", i, gotType, c.syncType)
106                 }
107         }
108 }
109
110 func TestRegularBlockSync(t *testing.T) {
111         baseChain := mockBlocks(nil, 50)
112         chainX := append(baseChain, mockBlocks(baseChain[50], 60)...)
113         chainY := append(baseChain, mockBlocks(baseChain[50], 70)...)
114         chainZ := append(baseChain, mockBlocks(baseChain[50], 200)...)
115         chainE := append(baseChain, mockErrorBlocks(baseChain[50], 200, 60)...)
116
117         cases := []struct {
118                 syncTimeout time.Duration
119                 aBlocks     []*types.Block
120                 bBlocks     []*types.Block
121                 want        []*types.Block
122                 err         error
123         }{
124                 {
125                         syncTimeout: 30 * time.Second,
126                         aBlocks:     baseChain[:20],
127                         bBlocks:     baseChain[:50],
128                         want:        baseChain[:50],
129                         err:         nil,
130                 },
131                 {
132                         syncTimeout: 30 * time.Second,
133                         aBlocks:     chainX,
134                         bBlocks:     chainY,
135                         want:        chainY,
136                         err:         nil,
137                 },
138                 {
139                         syncTimeout: 30 * time.Second,
140                         aBlocks:     chainX[:52],
141                         bBlocks:     chainY[:53],
142                         want:        chainY[:53],
143                         err:         nil,
144                 },
145                 {
146                         syncTimeout: 30 * time.Second,
147                         aBlocks:     chainX[:52],
148                         bBlocks:     chainZ,
149                         want:        chainZ[:180],
150                         err:         nil,
151                 },
152                 {
153                         syncTimeout: 0 * time.Second,
154                         aBlocks:     chainX[:52],
155                         bBlocks:     chainZ,
156                         want:        chainX[:52],
157                         err:         errRequestTimeout,
158                 },
159                 {
160                         syncTimeout: 30 * time.Second,
161                         aBlocks:     chainX[:52],
162                         bBlocks:     chainE,
163                         want:        chainE[:60],
164                         err:         protocol.ErrBadStateRoot,
165                 },
166         }
167         tmp, err := ioutil.TempDir(".", "")
168         if err != nil {
169                 t.Fatalf("failed to create temporary data folder: %v", err)
170         }
171         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
172         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
173         defer func() {
174                 testDBA.Close()
175                 testDBB.Close()
176                 os.RemoveAll(tmp)
177         }()
178
179         for i, c := range cases {
180                 a := mockSync(c.aBlocks, nil, testDBA)
181                 b := mockSync(c.bBlocks, nil, testDBB)
182                 netWork := NewNetWork()
183                 netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
184                 netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
185                 if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
186                         t.Errorf("fail on peer hands shake %v", err)
187                 } else {
188                         go B2A.postMan()
189                         go A2B.postMan()
190                 }
191
192                 requireBlockTimeout = c.syncTimeout
193                 a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
194                 if err := a.blockKeeper.regularBlockSync(); errors.Root(err) != c.err {
195                         t.Errorf("case %d: got %v want %v", i, err, c.err)
196                 }
197
198                 got := []*types.Block{}
199                 for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ {
200                         block, err := a.chain.GetBlockByHeight(i)
201                         if err != nil {
202                                 t.Errorf("case %d got err %v", i, err)
203                         }
204                         got = append(got, block)
205                 }
206
207                 if !testutil.DeepEqual(got, c.want) {
208                         t.Errorf("case %d: got %v want %v", i, got, c.want)
209                 }
210         }
211 }
212
213 func TestRequireBlock(t *testing.T) {
214         tmp, err := ioutil.TempDir(".", "")
215         if err != nil {
216                 t.Fatalf("failed to create temporary data folder: %v", err)
217         }
218         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
219         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
220         defer func() {
221                 testDBB.Close()
222                 testDBA.Close()
223                 os.RemoveAll(tmp)
224         }()
225
226         blocks := mockBlocks(nil, 5)
227         a := mockSync(blocks[:1], nil, testDBA)
228         b := mockSync(blocks[:5], nil, testDBB)
229         netWork := NewNetWork()
230         netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
231         netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
232         if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
233                 t.Errorf("fail on peer hands shake %v", err)
234         } else {
235                 go B2A.postMan()
236                 go A2B.postMan()
237         }
238
239         a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
240         b.blockKeeper.syncPeer = b.peers.GetPeer("test node A")
241         cases := []struct {
242                 syncTimeout   time.Duration
243                 testNode      *Manager
244                 requireHeight uint64
245                 want          *types.Block
246                 err           error
247         }{
248                 {
249                         syncTimeout:   30 * time.Second,
250                         testNode:      a,
251                         requireHeight: 4,
252                         want:          blocks[4],
253                         err:           nil,
254                 },
255                 {
256                         syncTimeout:   1 * time.Millisecond,
257                         testNode:      b,
258                         requireHeight: 4,
259                         want:          nil,
260                         err:           errRequestTimeout,
261                 },
262         }
263
264         defer func() {
265                 requireBlockTimeout = 20 * time.Second
266         }()
267
268         for i, c := range cases {
269                 requireBlockTimeout = c.syncTimeout
270                 got, err := c.testNode.blockKeeper.msgFetcher.requireBlock(c.testNode.blockKeeper.syncPeer.ID(), c.requireHeight)
271                 if !testutil.DeepEqual(got, c.want) {
272                         t.Errorf("case %d: got %v want %v", i, got, c.want)
273                 }
274                 if errors.Root(err) != c.err {
275                         t.Errorf("case %d: got %v want %v", i, err, c.err)
276                 }
277         }
278 }
279
280 func TestSendMerkleBlock(t *testing.T) {
281         tmp, err := ioutil.TempDir(".", "")
282         if err != nil {
283                 t.Fatalf("failed to create temporary data folder: %v", err)
284         }
285
286         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
287         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
288         defer func() {
289                 testDBA.Close()
290                 testDBB.Close()
291                 os.RemoveAll(tmp)
292         }()
293
294         cases := []struct {
295                 txCount        int
296                 relatedTxIndex []int
297         }{
298                 {
299                         txCount:        10,
300                         relatedTxIndex: []int{0, 2, 5},
301                 },
302                 {
303                         txCount:        0,
304                         relatedTxIndex: []int{},
305                 },
306                 {
307                         txCount:        10,
308                         relatedTxIndex: []int{},
309                 },
310                 {
311                         txCount:        5,
312                         relatedTxIndex: []int{0, 1, 2, 3, 4},
313                 },
314                 {
315                         txCount:        20,
316                         relatedTxIndex: []int{1, 6, 3, 9, 10, 19},
317                 },
318         }
319
320         for _, c := range cases {
321                 blocks := mockBlocks(nil, 2)
322                 targetBlock := blocks[1]
323                 txs, bcTxs := mockTxs(c.txCount)
324                 var err error
325
326                 targetBlock.Transactions = txs
327                 if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil {
328                         t.Fatal(err)
329                 }
330
331                 spvNode := mockSync(blocks, nil, testDBA)
332                 blockHash := targetBlock.Hash()
333                 var statusResult *bc.TransactionStatus
334                 if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
335                         t.Fatal(err)
336                 }
337
338                 if targetBlock.TransactionStatusHash, err = types.TxStatusMerkleRoot(statusResult.VerifyStatus); err != nil {
339                         t.Fatal(err)
340                 }
341
342                 fullNode := mockSync(blocks, nil, testDBB)
343                 netWork := NewNetWork()
344                 netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
345                 netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
346
347                 var F2S *P2PPeer
348                 if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil {
349                         t.Errorf("fail on peer hands shake %v", err)
350                 }
351
352                 completed := make(chan error)
353                 go func() {
354                         msgBytes := <-F2S.msgCh
355                         _, msg, _ := decodeMessage(msgBytes)
356                         switch m := msg.(type) {
357                         case *msgs.MerkleBlockMessage:
358                                 var relatedTxIDs []*bc.Hash
359                                 for _, rawTx := range m.RawTxDatas {
360                                         tx := &types.Tx{}
361                                         if err := tx.UnmarshalText(rawTx); err != nil {
362                                                 completed <- err
363                                         }
364
365                                         relatedTxIDs = append(relatedTxIDs, &tx.ID)
366                                 }
367                                 var txHashes []*bc.Hash
368                                 for _, hashByte := range m.TxHashes {
369                                         hash := bc.NewHash(hashByte)
370                                         txHashes = append(txHashes, &hash)
371                                 }
372                                 if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok {
373                                         completed <- errors.New("validate tx fail")
374                                 }
375
376                                 var statusHashes []*bc.Hash
377                                 for _, statusByte := range m.StatusHashes {
378                                         hash := bc.NewHash(statusByte)
379                                         statusHashes = append(statusHashes, &hash)
380                                 }
381                                 var relatedStatuses []*bc.TxVerifyResult
382                                 for _, statusByte := range m.RawTxStatuses {
383                                         status := &bc.TxVerifyResult{}
384                                         err := json.Unmarshal(statusByte, status)
385                                         if err != nil {
386                                                 completed <- err
387                                         }
388                                         relatedStatuses = append(relatedStatuses, status)
389                                 }
390                                 if ok := types.ValidateStatusMerkleTreeProof(statusHashes, m.Flags, relatedStatuses, targetBlock.TransactionStatusHash); !ok {
391                                         completed <- errors.New("validate status fail")
392                                 }
393
394                                 completed <- nil
395                         }
396                 }()
397
398                 spvPeer := fullNode.peers.GetPeer("spv_node")
399                 for i := 0; i < len(c.relatedTxIndex); i++ {
400                         spvPeer.AddFilterAddress(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram())
401                 }
402                 msg := &msgs.GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
403                 fullNode.handleGetMerkleBlockMsg(spvPeer, msg)
404                 if err := <-completed; err != nil {
405                         t.Fatal(err)
406                 }
407         }
408 }
409
410 func TestLocateBlocks(t *testing.T) {
411         maxNumOfBlocksPerMsg = 5
412         blocks := mockBlocks(nil, 100)
413         cases := []struct {
414                 locator    []uint64
415                 stopHash   bc.Hash
416                 wantHeight []uint64
417                 wantErr    error
418         }{
419                 {
420                         locator:    []uint64{20},
421                         stopHash:   blocks[100].Hash(),
422                         wantHeight: []uint64{20, 21, 22, 23, 24},
423                         wantErr:    nil,
424                 },
425                 {
426                         locator:    []uint64{20},
427                         stopHash:   bc.NewHash([32]byte{0x01, 0x02}),
428                         wantHeight: []uint64{},
429                         wantErr:    mock.ErrFoundHeaderByHash,
430                 },
431         }
432
433         mockChain := mock.NewChain(nil)
434         bk := &blockKeeper{chain: mockChain}
435         for _, block := range blocks {
436                 mockChain.SetBlockByHeight(block.Height, block)
437         }
438
439         for i, c := range cases {
440                 locator := []*bc.Hash{}
441                 for _, i := range c.locator {
442                         hash := blocks[i].Hash()
443                         locator = append(locator, &hash)
444                 }
445
446                 want := []*types.Block{}
447                 for _, i := range c.wantHeight {
448                         want = append(want, blocks[i])
449                 }
450
451                 got, err := bk.locateBlocks(locator, &c.stopHash)
452                 if err != c.wantErr {
453                         t.Errorf("case %d: got %v want err = %v", i, err, c.wantErr)
454                 }
455
456                 if !testutil.DeepEqual(got, want) {
457                         t.Errorf("case %d: got %v want %v", i, got, want)
458                 }
459         }
460 }
461
462 func TestLocateHeaders(t *testing.T) {
463         defer func() {
464                 maxNumOfHeadersPerMsg = 1000
465         }()
466         maxNumOfHeadersPerMsg = 10
467         blocks := mockBlocks(nil, 150)
468         blocksHash := []bc.Hash{}
469         for _, block := range blocks {
470                 blocksHash = append(blocksHash, block.Hash())
471         }
472
473         cases := []struct {
474                 chainHeight uint64
475                 locator     []uint64
476                 stopHash    *bc.Hash
477                 skip        uint64
478                 wantHeight  []uint64
479                 err         error
480         }{
481                 {
482                         chainHeight: 100,
483                         locator:     []uint64{90},
484                         stopHash:    &blocksHash[100],
485                         skip:        0,
486                         wantHeight:  []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99},
487                         err:         nil,
488                 },
489                 {
490                         chainHeight: 100,
491                         locator:     []uint64{20},
492                         stopHash:    &blocksHash[24],
493                         skip:        0,
494                         wantHeight:  []uint64{20, 21, 22, 23, 24},
495                         err:         nil,
496                 },
497                 {
498                         chainHeight: 100,
499                         locator:     []uint64{20},
500                         stopHash:    &blocksHash[20],
501                         wantHeight:  []uint64{20},
502                         err:         nil,
503                 },
504                 {
505                         chainHeight: 100,
506                         locator:     []uint64{20},
507                         stopHash:    &blocksHash[120],
508                         wantHeight:  []uint64{},
509                         err:         mock.ErrFoundHeaderByHash,
510                 },
511                 {
512                         chainHeight: 100,
513                         locator:     []uint64{120, 70},
514                         stopHash:    &blocksHash[78],
515                         wantHeight:  []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78},
516                         err:         nil,
517                 },
518                 {
519                         chainHeight: 100,
520                         locator:     []uint64{15},
521                         stopHash:    &blocksHash[10],
522                         skip:        10,
523                         wantHeight:  []uint64{},
524                         err:         nil,
525                 },
526                 {
527                         chainHeight: 100,
528                         locator:     []uint64{15},
529                         stopHash:    &blocksHash[80],
530                         skip:        10,
531                         wantHeight:  []uint64{15, 26, 37, 48, 59, 70, 80},
532                         err:         nil,
533                 },
534                 {
535                         chainHeight: 100,
536                         locator:     []uint64{0},
537                         stopHash:    &blocksHash[100],
538                         skip:        9,
539                         wantHeight:  []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90},
540                         err:         nil,
541                 },
542         }
543
544         for i, c := range cases {
545                 mockChain := mock.NewChain(nil)
546                 bk := &blockKeeper{chain: mockChain}
547                 for i := uint64(0); i <= c.chainHeight; i++ {
548                         mockChain.SetBlockByHeight(i, blocks[i])
549                 }
550
551                 locator := []*bc.Hash{}
552                 for _, i := range c.locator {
553                         hash := blocks[i].Hash()
554                         locator = append(locator, &hash)
555                 }
556
557                 want := []*types.BlockHeader{}
558                 for _, i := range c.wantHeight {
559                         want = append(want, &blocks[i].BlockHeader)
560                 }
561
562                 got, err := bk.locateHeaders(locator, c.stopHash, c.skip, maxNumOfHeadersPerMsg)
563                 if err != c.err {
564                         t.Errorf("case %d: got %v want err = %v", i, err, c.err)
565                 }
566                 if !testutil.DeepEqual(got, want) {
567                         t.Errorf("case %d: got %v want %v", i, got, want)
568                 }
569         }
570 }