OSDN Git Service

52b206c79eaa557ebdb67bfbd12a17b9d51d38c9
[bytom/vapor.git] / netsync / chainmgr / block_keeper_test.go
1 package chainmgr
2
3 import (
4         "encoding/json"
5         "io/ioutil"
6         "os"
7         "testing"
8         "time"
9
10         "github.com/bytom/vapor/consensus"
11         dbm "github.com/bytom/vapor/database/leveldb"
12         "github.com/bytom/vapor/errors"
13         msgs "github.com/bytom/vapor/netsync/messages"
14         "github.com/bytom/vapor/netsync/peers"
15         "github.com/bytom/vapor/protocol"
16         "github.com/bytom/vapor/protocol/bc"
17         "github.com/bytom/vapor/protocol/bc/types"
18         "github.com/bytom/vapor/test/mock"
19         "github.com/bytom/vapor/testutil"
20 )
21
22 func TestCheckSyncType(t *testing.T) {
23         tmp, err := ioutil.TempDir(".", "")
24         if err != nil {
25                 t.Fatalf("failed to create temporary data folder: %v", err)
26         }
27         fastSyncDB := dbm.NewDB("testdb", "leveldb", tmp)
28         defer func() {
29                 fastSyncDB.Close()
30                 os.RemoveAll(tmp)
31         }()
32
33         blocks := mockBlocks(nil, 50)
34         chain := mock.NewChain(nil)
35         chain.SetBestBlockHeader(&blocks[len(blocks)-1].BlockHeader)
36         for _, block := range blocks {
37                 chain.SetBlockByHeight(block.Height, block)
38         }
39
40         type syncPeer struct {
41                 peer               *P2PPeer
42                 bestHeight         uint64
43                 irreversibleHeight uint64
44         }
45
46         cases := []struct {
47                 peers    []*syncPeer
48                 syncType int
49         }{
50                 {
51                         peers:    []*syncPeer{},
52                         syncType: noNeedSync,
53                 },
54                 {
55                         peers: []*syncPeer{
56                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 500},
57                                 {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 50, irreversibleHeight: 50},
58                         },
59                         syncType: fastSyncType,
60                 },
61                 {
62                         peers: []*syncPeer{
63                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 100},
64                                 {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 500, irreversibleHeight: 50},
65                         },
66                         syncType: regularSyncType,
67                 },
68                 {
69                         peers: []*syncPeer{
70                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 51, irreversibleHeight: 50},
71                         },
72                         syncType: regularSyncType,
73                 },
74                 {
75                         peers: []*syncPeer{
76                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 30, irreversibleHeight: 30},
77                         },
78                         syncType: noNeedSync,
79                 },
80                 {
81                         peers: []*syncPeer{
82                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode}, bestHeight: 1000, irreversibleHeight: 1000},
83                         },
84                         syncType: regularSyncType,
85                 },
86                 {
87                         peers: []*syncPeer{
88                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 50},
89                                 {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 800, irreversibleHeight: 800},
90                         },
91                         syncType: fastSyncType,
92                 },
93         }
94
95         for i, c := range cases {
96                 peers := peers.NewPeerSet(NewPeerSet())
97                 blockKeeper := newBlockKeeper(chain, peers, fastSyncDB)
98                 for _, syncPeer := range c.peers {
99                         blockKeeper.peers.AddPeer(syncPeer.peer)
100                         blockKeeper.peers.SetStatus(syncPeer.peer.id, syncPeer.bestHeight, nil)
101                         blockKeeper.peers.SetIrreversibleStatus(syncPeer.peer.id, syncPeer.irreversibleHeight, nil)
102                 }
103                 gotType := blockKeeper.checkSyncType()
104                 if c.syncType != gotType {
105                         t.Errorf("case %d: got %d want %d", i, gotType, c.syncType)
106                 }
107         }
108 }
109
110 func TestRegularBlockSync(t *testing.T) {
111         baseChain := mockBlocks(nil, 50)
112         chainX := append(baseChain, mockBlocks(baseChain[50], 60)...)
113         chainY := append(baseChain, mockBlocks(baseChain[50], 70)...)
114         chainZ := append(baseChain, mockBlocks(baseChain[50], 200)...)
115         chainE := append(baseChain, mockErrorBlocks(baseChain[50], 200, 60)...)
116
117         cases := []struct {
118                 syncTimeout time.Duration
119                 aBlocks     []*types.Block
120                 bBlocks     []*types.Block
121                 want        []*types.Block
122                 err         error
123         }{
124                 {
125                         syncTimeout: 30 * time.Second,
126                         aBlocks:     baseChain[:20],
127                         bBlocks:     baseChain[:50],
128                         want:        baseChain[:50],
129                         err:         nil,
130                 },
131                 {
132                         syncTimeout: 30 * time.Second,
133                         aBlocks:     chainX,
134                         bBlocks:     chainY,
135                         want:        chainY,
136                         err:         nil,
137                 },
138                 {
139                         syncTimeout: 30 * time.Second,
140                         aBlocks:     chainX[:52],
141                         bBlocks:     chainY[:53],
142                         want:        chainY[:53],
143                         err:         nil,
144                 },
145                 {
146                         syncTimeout: 30 * time.Second,
147                         aBlocks:     chainX[:52],
148                         bBlocks:     chainZ,
149                         want:        chainZ[:180],
150                         err:         nil,
151                 },
152                 {
153                         syncTimeout: 0 * time.Second,
154                         aBlocks:     chainX[:52],
155                         bBlocks:     chainZ,
156                         want:        chainX[:52],
157                         err:         errRequestTimeout,
158                 },
159                 {
160                         syncTimeout: 30 * time.Second,
161                         aBlocks:     chainX[:52],
162                         bBlocks:     chainE,
163                         want:        chainE[:60],
164                         err:         protocol.ErrBadStateRoot,
165                 },
166         }
167         tmp, err := ioutil.TempDir(".", "")
168         if err != nil {
169                 t.Fatalf("failed to create temporary data folder: %v", err)
170         }
171         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
172         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
173         defer func() {
174                 testDBA.Close()
175                 testDBB.Close()
176                 os.RemoveAll(tmp)
177         }()
178
179         for i, c := range cases {
180                 a := mockSync(c.aBlocks, nil, testDBA)
181                 b := mockSync(c.bBlocks, nil, testDBB)
182                 netWork := NewNetWork()
183                 netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
184                 netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
185                 if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
186                         t.Errorf("fail on peer hands shake %v", err)
187                 } else {
188                         go B2A.postMan()
189                         go A2B.postMan()
190                 }
191
192                 requireBlockTimeout = c.syncTimeout
193                 a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
194                 if err := a.blockKeeper.regularBlockSync(); errors.Root(err) != c.err {
195                         t.Errorf("case %d: got %v want %v", i, err, c.err)
196                 }
197
198                 got := []*types.Block{}
199                 for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ {
200                         block, err := a.chain.GetBlockByHeight(i)
201                         if err != nil {
202                                 t.Errorf("case %d got err %v", i, err)
203                         }
204                         got = append(got, block)
205                 }
206
207                 if !testutil.DeepEqual(got, c.want) {
208                         t.Errorf("case %d: got %v want %v", i, got, c.want)
209                 }
210         }
211 }
212
213 func TestRequireBlock(t *testing.T) {
214         tmp, err := ioutil.TempDir(".", "")
215         if err != nil {
216                 t.Fatalf("failed to create temporary data folder: %v", err)
217         }
218         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
219         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
220         defer func() {
221                 testDBB.Close()
222                 testDBA.Close()
223                 os.RemoveAll(tmp)
224         }()
225
226         blocks := mockBlocks(nil, 5)
227         a := mockSync(blocks[:1], nil, testDBA)
228         b := mockSync(blocks[:5], nil, testDBB)
229         netWork := NewNetWork()
230         netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
231         netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
232         if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
233                 t.Errorf("fail on peer hands shake %v", err)
234         } else {
235                 go B2A.postMan()
236                 go A2B.postMan()
237         }
238
239         a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
240         b.blockKeeper.syncPeer = b.peers.GetPeer("test node A")
241         cases := []struct {
242                 syncTimeout   time.Duration
243                 testNode      *Manager
244                 requireHeight uint64
245                 want          *types.Block
246                 err           error
247         }{
248                 {
249                         syncTimeout:   30 * time.Second,
250                         testNode:      a,
251                         requireHeight: 4,
252                         want:          blocks[4],
253                         err:           nil,
254                 },
255                 {
256                         syncTimeout:   1 * time.Millisecond,
257                         testNode:      b,
258                         requireHeight: 4,
259                         want:          nil,
260                         err:           errRequestTimeout,
261                 },
262         }
263
264         defer func() {
265                 requireBlockTimeout = 20 * time.Second
266         }()
267
268         for i, c := range cases {
269                 requireBlockTimeout = c.syncTimeout
270                 got, err := c.testNode.blockKeeper.msgFetcher.requireBlock(c.testNode.blockKeeper.syncPeer.ID(), c.requireHeight)
271                 if !testutil.DeepEqual(got, c.want) {
272                         t.Errorf("case %d: got %v want %v", i, got, c.want)
273                 }
274                 if errors.Root(err) != c.err {
275                         t.Errorf("case %d: got %v want %v", i, err, c.err)
276                 }
277         }
278 }
279
280 func TestSendMerkleBlock(t *testing.T) {
281         tmp, err := ioutil.TempDir(".", "")
282         if err != nil {
283                 t.Fatalf("failed to create temporary data folder: %v", err)
284         }
285
286         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
287         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
288         defer func() {
289                 testDBA.Close()
290                 testDBB.Close()
291                 os.RemoveAll(tmp)
292         }()
293
294         cases := []struct {
295                 txCount        int
296                 relatedTxIndex []int
297         }{
298                 {
299                         txCount:        10,
300                         relatedTxIndex: []int{0, 2, 5},
301                 },
302                 {
303                         txCount:        0,
304                         relatedTxIndex: []int{},
305                 },
306                 {
307                         txCount:        10,
308                         relatedTxIndex: []int{},
309                 },
310                 {
311                         txCount:        5,
312                         relatedTxIndex: []int{0, 1, 2, 3, 4},
313                 },
314                 {
315                         txCount:        20,
316                         relatedTxIndex: []int{1, 6, 3, 9, 10, 19},
317                 },
318         }
319
320         for _, c := range cases {
321                 blocks := mockBlocks(nil, 2)
322                 targetBlock := blocks[1]
323                 txs, bcTxs := mockTxs(c.txCount)
324                 var err error
325
326                 targetBlock.Transactions = txs
327                 if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil {
328                         t.Fatal(err)
329                 }
330
331                 spvNode := mockSync(blocks, nil, testDBA)
332                 blockHash := targetBlock.Hash()
333                 var statusResult *bc.TransactionStatus
334                 if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
335                         t.Fatal(err)
336                 }
337
338                 if targetBlock.TransactionStatusHash, err = types.TxStatusMerkleRoot(statusResult.VerifyStatus); err != nil {
339                         t.Fatal(err)
340                 }
341
342                 fullNode := mockSync(blocks, nil, testDBB)
343                 netWork := NewNetWork()
344                 netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
345                 netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
346
347                 var F2S *P2PPeer
348                 if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil {
349                         t.Errorf("fail on peer hands shake %v", err)
350                 }
351
352                 completed := make(chan error)
353                 go func() {
354                         msgBytes := <-F2S.msgCh
355                         _, msg, _ := decodeMessage(msgBytes)
356                         switch m := msg.(type) {
357                         case *msgs.MerkleBlockMessage:
358                                 var relatedTxIDs []*bc.Hash
359                                 for _, rawTx := range m.RawTxDatas {
360                                         tx := &types.Tx{}
361                                         if err := tx.UnmarshalText(rawTx); err != nil {
362                                                 completed <- err
363                                         }
364
365                                         relatedTxIDs = append(relatedTxIDs, &tx.ID)
366                                 }
367                                 var txHashes []*bc.Hash
368                                 for _, hashByte := range m.TxHashes {
369                                         hash := bc.NewHash(hashByte)
370                                         txHashes = append(txHashes, &hash)
371                                 }
372                                 if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok {
373                                         completed <- errors.New("validate tx fail")
374                                 }
375
376                                 var statusHashes []*bc.Hash
377                                 for _, statusByte := range m.StatusHashes {
378                                         hash := bc.NewHash(statusByte)
379                                         statusHashes = append(statusHashes, &hash)
380                                 }
381                                 var relatedStatuses []*bc.TxVerifyResult
382                                 for _, statusByte := range m.RawTxStatuses {
383                                         status := &bc.TxVerifyResult{}
384                                         err := json.Unmarshal(statusByte, status)
385                                         if err != nil {
386                                                 completed <- err
387                                         }
388                                         relatedStatuses = append(relatedStatuses, status)
389                                 }
390                                 if ok := types.ValidateStatusMerkleTreeProof(statusHashes, m.Flags, relatedStatuses, targetBlock.TransactionStatusHash); !ok {
391                                         completed <- errors.New("validate status fail")
392                                 }
393
394                                 completed <- nil
395                         }
396                 }()
397
398                 spvPeer := fullNode.peers.GetPeer("spv_node")
399                 for i := 0; i < len(c.relatedTxIndex); i++ {
400                         spvPeer.AddFilterAddress(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram())
401                 }
402                 msg := &msgs.GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
403                 fullNode.handleGetMerkleBlockMsg(spvPeer, msg)
404                 if err := <-completed; err != nil {
405                         t.Fatal(err)
406                 }
407         }
408 }
409
410 func TestLocateBlocks(t *testing.T) {
411         maxNumOfBlocksPerMsg = 5
412         blocks := mockBlocks(nil, 100)
413         cases := []struct {
414                 locator    []uint64
415                 stopHash   bc.Hash
416                 wantHeight []uint64
417                 wantErr    error
418         }{
419                 {
420                         locator:    []uint64{20},
421                         stopHash:   blocks[100].Hash(),
422                         wantHeight: []uint64{20, 21, 22, 23, 24},
423                         wantErr:    nil,
424                 },
425                 {
426                         locator:    []uint64{20},
427                         stopHash:   bc.NewHash([32]byte{0x01, 0x02}),
428                         wantHeight: []uint64{},
429                         wantErr:    mock.ErrFoundHeaderByHash,
430                 },
431         }
432
433         mockChain := mock.NewChain(nil)
434         bk := &blockKeeper{chain: mockChain}
435         for _, block := range blocks {
436                 mockChain.SetBlockByHeight(block.Height, block)
437         }
438
439         for i, c := range cases {
440                 locator := []*bc.Hash{}
441                 for _, i := range c.locator {
442                         hash := blocks[i].Hash()
443                         locator = append(locator, &hash)
444                 }
445
446                 want := []*types.Block{}
447                 for _, i := range c.wantHeight {
448                         want = append(want, blocks[i])
449                 }
450
451                 mockTimeout := func() bool { return false }
452                 got, err := bk.locateBlocks(locator, &c.stopHash, mockTimeout)
453                 if err != c.wantErr {
454                         t.Errorf("case %d: got %v want err = %v", i, err, c.wantErr)
455                 }
456
457                 if !testutil.DeepEqual(got, want) {
458                         t.Errorf("case %d: got %v want %v", i, got, want)
459                 }
460         }
461 }
462
463 func TestLocateHeaders(t *testing.T) {
464         defer func() {
465                 maxNumOfHeadersPerMsg = 1000
466         }()
467         maxNumOfHeadersPerMsg = 10
468         blocks := mockBlocks(nil, 150)
469         blocksHash := []bc.Hash{}
470         for _, block := range blocks {
471                 blocksHash = append(blocksHash, block.Hash())
472         }
473
474         cases := []struct {
475                 chainHeight uint64
476                 locator     []uint64
477                 stopHash    *bc.Hash
478                 skip        uint64
479                 wantHeight  []uint64
480                 err         error
481         }{
482                 {
483                         chainHeight: 100,
484                         locator:     []uint64{90},
485                         stopHash:    &blocksHash[100],
486                         skip:        0,
487                         wantHeight:  []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99},
488                         err:         nil,
489                 },
490                 {
491                         chainHeight: 100,
492                         locator:     []uint64{20},
493                         stopHash:    &blocksHash[24],
494                         skip:        0,
495                         wantHeight:  []uint64{20, 21, 22, 23, 24},
496                         err:         nil,
497                 },
498                 {
499                         chainHeight: 100,
500                         locator:     []uint64{20},
501                         stopHash:    &blocksHash[20],
502                         wantHeight:  []uint64{20},
503                         err:         nil,
504                 },
505                 {
506                         chainHeight: 100,
507                         locator:     []uint64{20},
508                         stopHash:    &blocksHash[120],
509                         wantHeight:  []uint64{},
510                         err:         mock.ErrFoundHeaderByHash,
511                 },
512                 {
513                         chainHeight: 100,
514                         locator:     []uint64{120, 70},
515                         stopHash:    &blocksHash[78],
516                         wantHeight:  []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78},
517                         err:         nil,
518                 },
519                 {
520                         chainHeight: 100,
521                         locator:     []uint64{15},
522                         stopHash:    &blocksHash[10],
523                         skip:        10,
524                         wantHeight:  []uint64{},
525                         err:         nil,
526                 },
527                 {
528                         chainHeight: 100,
529                         locator:     []uint64{15},
530                         stopHash:    &blocksHash[80],
531                         skip:        10,
532                         wantHeight:  []uint64{15, 26, 37, 48, 59, 70, 80},
533                         err:         nil,
534                 },
535                 {
536                         chainHeight: 100,
537                         locator:     []uint64{0},
538                         stopHash:    &blocksHash[100],
539                         skip:        9,
540                         wantHeight:  []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90},
541                         err:         nil,
542                 },
543         }
544
545         for i, c := range cases {
546                 mockChain := mock.NewChain(nil)
547                 bk := &blockKeeper{chain: mockChain}
548                 for i := uint64(0); i <= c.chainHeight; i++ {
549                         mockChain.SetBlockByHeight(i, blocks[i])
550                 }
551
552                 locator := []*bc.Hash{}
553                 for _, i := range c.locator {
554                         hash := blocks[i].Hash()
555                         locator = append(locator, &hash)
556                 }
557
558                 want := []*types.BlockHeader{}
559                 for _, i := range c.wantHeight {
560                         want = append(want, &blocks[i].BlockHeader)
561                 }
562
563                 got, err := bk.locateHeaders(locator, c.stopHash, c.skip, maxNumOfHeadersPerMsg)
564                 if err != c.err {
565                         t.Errorf("case %d: got %v want err = %v", i, err, c.err)
566                 }
567                 if !testutil.DeepEqual(got, want) {
568                         t.Errorf("case %d: got %v want %v", i, got, want)
569                 }
570         }
571 }