OSDN Git Service

fix casper rollback (#1951)
[bytom/bytom.git] / netsync / chainmgr / block_keeper_test.go
1 package chainmgr
2
3 import (
4         "io/ioutil"
5         "os"
6         "testing"
7         "time"
8
9         "github.com/bytom/bytom/consensus"
10         dbm "github.com/bytom/bytom/database/leveldb"
11         "github.com/bytom/bytom/errors"
12         msgs "github.com/bytom/bytom/netsync/messages"
13         "github.com/bytom/bytom/netsync/peers"
14         "github.com/bytom/bytom/protocol"
15         "github.com/bytom/bytom/protocol/bc"
16         "github.com/bytom/bytom/protocol/bc/types"
17         "github.com/bytom/bytom/test/mock"
18         "github.com/bytom/bytom/testcontrol"
19         "github.com/bytom/bytom/testutil"
20 )
21
22 func TestCheckSyncType(t *testing.T) {
23         tmp, err := ioutil.TempDir(".", "")
24         if err != nil {
25                 t.Fatalf("failed to create temporary data folder: %v", err)
26         }
27         fastSyncDB := dbm.NewDB("testdb", "leveldb", tmp)
28         defer func() {
29                 fastSyncDB.Close()
30                 os.RemoveAll(tmp)
31         }()
32
33         blocks := mockBlocks(nil, 50)
34         chain := mock.NewChain()
35         chain.SetBestBlockHeader(&blocks[len(blocks)-1].BlockHeader)
36         for _, block := range blocks {
37                 chain.SetBlockByHeight(block.Height, block)
38         }
39
40         type syncPeer struct {
41                 peer               *P2PPeer
42                 bestHeight         uint64
43                 irreversibleHeight uint64
44         }
45
46         cases := []struct {
47                 peers    []*syncPeer
48                 syncType int
49         }{
50                 {
51                         peers:    []*syncPeer{},
52                         syncType: noNeedSync,
53                 },
54                 {
55                         peers: []*syncPeer{
56                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 500},
57                                 {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 50, irreversibleHeight: 50},
58                         },
59                         syncType: fastSyncType,
60                 },
61                 {
62                         peers: []*syncPeer{
63                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 100},
64                                 {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 500, irreversibleHeight: 50},
65                         },
66                         syncType: regularSyncType,
67                 },
68                 {
69                         peers: []*syncPeer{
70                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 51, irreversibleHeight: 50},
71                         },
72                         syncType: regularSyncType,
73                 },
74                 {
75                         peers: []*syncPeer{
76                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 30, irreversibleHeight: 30},
77                         },
78                         syncType: noNeedSync,
79                 },
80                 {
81                         peers: []*syncPeer{
82                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode}, bestHeight: 1000, irreversibleHeight: 1000},
83                         },
84                         syncType: regularSyncType,
85                 },
86                 {
87                         peers: []*syncPeer{
88                                 {peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 50},
89                                 {peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 800, irreversibleHeight: 800},
90                         },
91                         syncType: fastSyncType,
92                 },
93         }
94
95         for i, c := range cases {
96                 peers := peers.NewPeerSet(NewPeerSet())
97                 blockKeeper := newBlockKeeper(chain, peers, fastSyncDB)
98                 for _, syncPeer := range c.peers {
99                         blockKeeper.peers.AddPeer(syncPeer.peer)
100                         blockKeeper.peers.SetStatus(syncPeer.peer.id, syncPeer.bestHeight, nil)
101                         blockKeeper.peers.SetJustifiedStatus(syncPeer.peer.id, syncPeer.irreversibleHeight, nil)
102                 }
103                 gotType := blockKeeper.checkSyncType()
104                 if c.syncType != gotType {
105                         t.Errorf("case %d: got %d want %d", i, gotType, c.syncType)
106                 }
107         }
108 }
109
110 func TestRegularBlockSync(t *testing.T) {
111         if testcontrol.IgnoreTestTemporary {
112                 return
113         }
114
115         baseChain := mockBlocks(nil, 50)
116         chainX := append(baseChain, mockBlocks(baseChain[50], 60)...)
117         chainY := append(baseChain, mockBlocks(baseChain[50], 70)...)
118         chainZ := append(baseChain, mockBlocks(baseChain[50], 200)...)
119         chainE := append(baseChain, mockErrorBlocks(baseChain[50], 200, 60)...)
120
121         cases := []struct {
122                 syncTimeout time.Duration
123                 aBlocks     []*types.Block
124                 bBlocks     []*types.Block
125                 want        []*types.Block
126                 err         error
127         }{
128                 {
129                         syncTimeout: 30 * time.Second,
130                         aBlocks:     baseChain[:20],
131                         bBlocks:     baseChain[:50],
132                         want:        baseChain[:50],
133                         err:         nil,
134                 },
135                 {
136                         syncTimeout: 30 * time.Second,
137                         aBlocks:     chainX,
138                         bBlocks:     chainY,
139                         want:        chainY,
140                         err:         nil,
141                 },
142                 {
143                         syncTimeout: 30 * time.Second,
144                         aBlocks:     chainX[:52],
145                         bBlocks:     chainY[:53],
146                         want:        chainY[:53],
147                         err:         nil,
148                 },
149                 {
150                         syncTimeout: 30 * time.Second,
151                         aBlocks:     chainX[:52],
152                         bBlocks:     chainZ,
153                         want:        chainZ[:180],
154                         err:         nil,
155                 },
156                 {
157                         syncTimeout: 0 * time.Second,
158                         aBlocks:     chainX[:52],
159                         bBlocks:     chainZ,
160                         want:        chainX[:52],
161                         err:         errRequestTimeout,
162                 },
163                 {
164                         syncTimeout: 30 * time.Second,
165                         aBlocks:     chainX[:52],
166                         bBlocks:     chainE,
167                         want:        chainE[:60],
168                         err:         protocol.ErrBadStateRoot,
169                 },
170         }
171         tmp, err := ioutil.TempDir(".", "")
172         if err != nil {
173                 t.Fatalf("failed to create temporary data folder: %v", err)
174         }
175         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
176         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
177         defer func() {
178                 testDBA.Close()
179                 testDBB.Close()
180                 os.RemoveAll(tmp)
181         }()
182
183         for i, c := range cases {
184                 a := mockSync(c.aBlocks, nil, testDBA)
185                 b := mockSync(c.bBlocks, nil, testDBB)
186                 netWork := NewNetWork()
187                 netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
188                 netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
189                 if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
190                         t.Errorf("fail on peer hands shake %v", err)
191                 } else {
192                         go B2A.postMan()
193                         go A2B.postMan()
194                 }
195
196                 requireBlockTimeout = c.syncTimeout
197                 a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
198                 if err := a.blockKeeper.regularBlockSync(); errors.Root(err) != c.err {
199                         t.Errorf("case %d: got %v want %v", i, err, c.err)
200                 }
201
202                 got := []*types.Block{}
203                 for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ {
204                         block, err := a.chain.GetBlockByHeight(i)
205                         if err != nil {
206                                 t.Errorf("case %d got err %v", i, err)
207                         }
208                         got = append(got, block)
209                 }
210
211                 if !testutil.DeepEqual(got, c.want) {
212                         t.Errorf("case %d: got %v want %v", i, got, c.want)
213                 }
214         }
215 }
216
217 func TestRequireBlock(t *testing.T) {
218         if testcontrol.IgnoreTestTemporary {
219                 return
220         }
221
222         tmp, err := ioutil.TempDir(".", "")
223         if err != nil {
224                 t.Fatalf("failed to create temporary data folder: %v", err)
225         }
226         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
227         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
228         defer func() {
229                 testDBB.Close()
230                 testDBA.Close()
231                 os.RemoveAll(tmp)
232         }()
233
234         blocks := mockBlocks(nil, 5)
235         a := mockSync(blocks[:1], nil, testDBA)
236         b := mockSync(blocks[:5], nil, testDBB)
237         netWork := NewNetWork()
238         netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
239         netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
240         if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
241                 t.Errorf("fail on peer hands shake %v", err)
242         } else {
243                 go B2A.postMan()
244                 go A2B.postMan()
245         }
246
247         a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
248         b.blockKeeper.syncPeer = b.peers.GetPeer("test node A")
249         cases := []struct {
250                 syncTimeout   time.Duration
251                 testNode      *Manager
252                 requireHeight uint64
253                 want          *types.Block
254                 err           error
255         }{
256                 {
257                         syncTimeout:   30 * time.Second,
258                         testNode:      a,
259                         requireHeight: 4,
260                         want:          blocks[4],
261                         err:           nil,
262                 },
263                 {
264                         syncTimeout:   1 * time.Millisecond,
265                         testNode:      b,
266                         requireHeight: 4,
267                         want:          nil,
268                         err:           errRequestTimeout,
269                 },
270         }
271
272         defer func() {
273                 requireBlockTimeout = 20 * time.Second
274         }()
275
276         for i, c := range cases {
277                 requireBlockTimeout = c.syncTimeout
278                 got, err := c.testNode.blockKeeper.msgFetcher.requireBlock(c.testNode.blockKeeper.syncPeer.ID(), c.requireHeight)
279                 if !testutil.DeepEqual(got, c.want) {
280                         t.Errorf("case %d: got %v want %v", i, got, c.want)
281                 }
282                 if errors.Root(err) != c.err {
283                         t.Errorf("case %d: got %v want %v", i, err, c.err)
284                 }
285         }
286 }
287
288 func TestSendMerkleBlock(t *testing.T) {
289         if testcontrol.IgnoreTestTemporary {
290                 return
291         }
292
293         tmp, err := ioutil.TempDir(".", "")
294         if err != nil {
295                 t.Fatalf("failed to create temporary data folder: %v", err)
296         }
297
298         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
299         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
300         defer func() {
301                 testDBA.Close()
302                 testDBB.Close()
303                 os.RemoveAll(tmp)
304         }()
305
306         cases := []struct {
307                 txCount        int
308                 relatedTxIndex []int
309         }{
310                 {
311                         txCount:        10,
312                         relatedTxIndex: []int{0, 2, 5},
313                 },
314                 {
315                         txCount:        0,
316                         relatedTxIndex: []int{},
317                 },
318                 {
319                         txCount:        10,
320                         relatedTxIndex: []int{},
321                 },
322                 {
323                         txCount:        5,
324                         relatedTxIndex: []int{0, 1, 2, 3, 4},
325                 },
326                 {
327                         txCount:        20,
328                         relatedTxIndex: []int{1, 6, 3, 9, 10, 19},
329                 },
330         }
331
332         for _, c := range cases {
333                 blocks := mockBlocks(nil, 2)
334                 targetBlock := blocks[1]
335                 txs, bcTxs := mockTxs(c.txCount)
336                 var err error
337
338                 targetBlock.Transactions = txs
339                 if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil {
340                         t.Fatal(err)
341                 }
342
343                 spvNode := mockSync(blocks, nil, testDBA)
344                 fullNode := mockSync(blocks, nil, testDBB)
345                 netWork := NewNetWork()
346                 netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
347                 netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
348
349                 var F2S *P2PPeer
350                 if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil {
351                         t.Errorf("fail on peer hands shake %v", err)
352                 }
353
354                 completed := make(chan error)
355                 go func() {
356                         msgBytes := <-F2S.msgCh
357                         _, msg, _ := decodeMessage(msgBytes)
358                         switch m := msg.(type) {
359                         case *msgs.MerkleBlockMessage:
360                                 var relatedTxIDs []*bc.Hash
361                                 for _, rawTx := range m.RawTxDatas {
362                                         tx := &types.Tx{}
363                                         if err := tx.UnmarshalText(rawTx); err != nil {
364                                                 completed <- err
365                                         }
366
367                                         relatedTxIDs = append(relatedTxIDs, &tx.ID)
368                                 }
369                                 var txHashes []*bc.Hash
370                                 for _, hashByte := range m.TxHashes {
371                                         hash := bc.NewHash(hashByte)
372                                         txHashes = append(txHashes, &hash)
373                                 }
374                                 if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok {
375                                         completed <- errors.New("validate tx fail")
376                                 }
377                                 completed <- nil
378                         }
379                 }()
380
381                 spvPeer := fullNode.peers.GetPeer("spv_node")
382                 for i := 0; i < len(c.relatedTxIndex); i++ {
383                         spvPeer.AddFilterAddress(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram)
384                 }
385                 msg := &msgs.GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
386                 fullNode.handleGetMerkleBlockMsg(spvPeer, msg)
387                 if err := <-completed; err != nil {
388                         t.Fatal(err)
389                 }
390         }
391 }
392
393 func TestLocateBlocks(t *testing.T) {
394         if testcontrol.IgnoreTestTemporary {
395                 return
396         }
397
398         maxNumOfBlocksPerMsg = 5
399         blocks := mockBlocks(nil, 100)
400         cases := []struct {
401                 locator    []uint64
402                 stopHash   bc.Hash
403                 wantHeight []uint64
404                 wantErr    error
405         }{
406                 {
407                         locator:    []uint64{20},
408                         stopHash:   blocks[100].Hash(),
409                         wantHeight: []uint64{20, 21, 22, 23, 24},
410                         wantErr:    nil,
411                 },
412                 {
413                         locator:    []uint64{20},
414                         stopHash:   bc.NewHash([32]byte{0x01, 0x02}),
415                         wantHeight: []uint64{},
416                         wantErr:    mock.ErrFoundHeaderByHash,
417                 },
418         }
419
420         mockChain := mock.NewChain()
421         bk := &blockKeeper{chain: mockChain}
422         for _, block := range blocks {
423                 mockChain.SetBlockByHeight(block.Height, block)
424         }
425
426         for i, c := range cases {
427                 locator := []*bc.Hash{}
428                 for _, i := range c.locator {
429                         hash := blocks[i].Hash()
430                         locator = append(locator, &hash)
431                 }
432
433                 want := []*types.Block{}
434                 for _, i := range c.wantHeight {
435                         want = append(want, blocks[i])
436                 }
437
438                 mockTimeout := func() bool { return false }
439                 got, err := bk.locateBlocks(locator, &c.stopHash, mockTimeout)
440                 if err != c.wantErr {
441                         t.Errorf("case %d: got %v want err = %v", i, err, c.wantErr)
442                 }
443
444                 if !testutil.DeepEqual(got, want) {
445                         t.Errorf("case %d: got %v want %v", i, got, want)
446                 }
447         }
448 }
449
450 func TestLocateHeaders(t *testing.T) {
451         if testcontrol.IgnoreTestTemporary {
452                 return
453         }
454
455         defer func() {
456                 maxNumOfHeadersPerMsg = 1000
457         }()
458         maxNumOfHeadersPerMsg = 10
459         blocks := mockBlocks(nil, 150)
460         blocksHash := []bc.Hash{}
461         for _, block := range blocks {
462                 blocksHash = append(blocksHash, block.Hash())
463         }
464
465         cases := []struct {
466                 chainHeight uint64
467                 locator     []uint64
468                 stopHash    *bc.Hash
469                 skip        uint64
470                 wantHeight  []uint64
471                 err         error
472         }{
473                 {
474                         chainHeight: 100,
475                         locator:     []uint64{90},
476                         stopHash:    &blocksHash[100],
477                         skip:        0,
478                         wantHeight:  []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99},
479                         err:         nil,
480                 },
481                 {
482                         chainHeight: 100,
483                         locator:     []uint64{20},
484                         stopHash:    &blocksHash[24],
485                         skip:        0,
486                         wantHeight:  []uint64{20, 21, 22, 23, 24},
487                         err:         nil,
488                 },
489                 {
490                         chainHeight: 100,
491                         locator:     []uint64{20},
492                         stopHash:    &blocksHash[20],
493                         wantHeight:  []uint64{20},
494                         err:         nil,
495                 },
496                 {
497                         chainHeight: 100,
498                         locator:     []uint64{20},
499                         stopHash:    &blocksHash[120],
500                         wantHeight:  []uint64{},
501                         err:         mock.ErrFoundHeaderByHash,
502                 },
503                 {
504                         chainHeight: 100,
505                         locator:     []uint64{120, 70},
506                         stopHash:    &blocksHash[78],
507                         wantHeight:  []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78},
508                         err:         nil,
509                 },
510                 {
511                         chainHeight: 100,
512                         locator:     []uint64{15},
513                         stopHash:    &blocksHash[10],
514                         skip:        10,
515                         wantHeight:  []uint64{},
516                         err:         nil,
517                 },
518                 {
519                         chainHeight: 100,
520                         locator:     []uint64{15},
521                         stopHash:    &blocksHash[80],
522                         skip:        10,
523                         wantHeight:  []uint64{15, 26, 37, 48, 59, 70, 80},
524                         err:         nil,
525                 },
526                 {
527                         chainHeight: 100,
528                         locator:     []uint64{0},
529                         stopHash:    &blocksHash[100],
530                         skip:        9,
531                         wantHeight:  []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90},
532                         err:         nil,
533                 },
534         }
535
536         for i, c := range cases {
537                 mockChain := mock.NewChain()
538                 bk := &blockKeeper{chain: mockChain}
539                 for i := uint64(0); i <= c.chainHeight; i++ {
540                         mockChain.SetBlockByHeight(i, blocks[i])
541                 }
542
543                 locator := []*bc.Hash{}
544                 for _, i := range c.locator {
545                         hash := blocks[i].Hash()
546                         locator = append(locator, &hash)
547                 }
548
549                 want := []*types.BlockHeader{}
550                 for _, i := range c.wantHeight {
551                         want = append(want, &blocks[i].BlockHeader)
552                 }
553
554                 got, err := bk.locateHeaders(locator, c.stopHash, c.skip, maxNumOfHeadersPerMsg)
555                 if err != c.err {
556                         t.Errorf("case %d: got %v want err = %v", i, err, c.err)
557                 }
558                 if !testutil.DeepEqual(got, want) {
559                         t.Errorf("case %d: got %v want %v", i, got, want)
560                 }
561         }
562 }