OSDN Git Service

Fix sync orphan block system panic (#281)
[bytom/vapor.git] / netsync / chainmgr / block_keeper_test.go
1 package chainmgr
2
3 import (
4         "encoding/json"
5         "io/ioutil"
6         "os"
7         "testing"
8         "time"
9
10         "github.com/vapor/consensus"
11         dbm "github.com/vapor/database/leveldb"
12         "github.com/vapor/errors"
13         msgs "github.com/vapor/netsync/messages"
14         "github.com/vapor/protocol/bc"
15         "github.com/vapor/protocol/bc/types"
16         "github.com/vapor/test/mock"
17         "github.com/vapor/testutil"
18 )
19
20 func TestRegularBlockSync(t *testing.T) {
21         baseChain := mockBlocks(nil, 50)
22         chainX := append(baseChain, mockBlocks(baseChain[50], 60)...)
23         chainY := append(baseChain, mockBlocks(baseChain[50], 70)...)
24         chainZ := append(baseChain, mockBlocks(baseChain[50], 200)...)
25
26         cases := []struct {
27                 syncTimeout time.Duration
28                 aBlocks     []*types.Block
29                 bBlocks     []*types.Block
30                 want        []*types.Block
31                 err         error
32         }{
33                 {
34                         syncTimeout: 30 * time.Second,
35                         aBlocks:     baseChain[:20],
36                         bBlocks:     baseChain[:50],
37                         want:        baseChain[:50],
38                         err:         nil,
39                 },
40                 {
41                         syncTimeout: 30 * time.Second,
42                         aBlocks:     chainX,
43                         bBlocks:     chainY,
44                         want:        chainY,
45                         err:         nil,
46                 },
47                 {
48                         syncTimeout: 30 * time.Second,
49                         aBlocks:     chainX[:52],
50                         bBlocks:     chainY[:53],
51                         want:        chainY[:53],
52                         err:         nil,
53                 },
54                 {
55                         syncTimeout: 30 * time.Second,
56                         aBlocks:     chainX[:52],
57                         bBlocks:     chainZ,
58                         want:        chainZ[:201],
59                         err:         nil,
60                 },
61         }
62         tmp, err := ioutil.TempDir(".", "")
63         if err != nil {
64                 t.Fatalf("failed to create temporary data folder: %v", err)
65         }
66         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
67         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
68         defer func() {
69                 testDBA.Close()
70                 testDBB.Close()
71                 os.RemoveAll(tmp)
72         }()
73
74         for i, c := range cases {
75                 a := mockSync(c.aBlocks, nil, testDBA)
76                 b := mockSync(c.bBlocks, nil, testDBB)
77                 netWork := NewNetWork()
78                 netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
79                 netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
80                 if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
81                         t.Errorf("fail on peer hands shake %v", err)
82                 } else {
83                         go B2A.postMan()
84                         go A2B.postMan()
85                 }
86
87                 a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
88                 if err := a.blockKeeper.regularBlockSync(); errors.Root(err) != c.err {
89                         t.Errorf("case %d: got %v want %v", i, err, c.err)
90                 }
91
92                 got := []*types.Block{}
93                 for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ {
94                         block, err := a.chain.GetBlockByHeight(i)
95                         if err != nil {
96                                 t.Errorf("case %d got err %v", i, err)
97                         }
98                         got = append(got, block)
99                 }
100
101                 if !testutil.DeepEqual(got, c.want) {
102                         t.Errorf("case %d: got %v want %v", i, got, c.want)
103                 }
104         }
105 }
106
107 func TestRequireBlock(t *testing.T) {
108         tmp, err := ioutil.TempDir(".", "")
109         if err != nil {
110                 t.Fatalf("failed to create temporary data folder: %v", err)
111         }
112         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
113         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
114         defer func() {
115                 testDBB.Close()
116                 testDBA.Close()
117                 os.RemoveAll(tmp)
118         }()
119
120         blocks := mockBlocks(nil, 5)
121         a := mockSync(blocks[:1], nil, testDBA)
122         b := mockSync(blocks[:5], nil, testDBB)
123         netWork := NewNetWork()
124         netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
125         netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
126         if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
127                 t.Errorf("fail on peer hands shake %v", err)
128         } else {
129                 go B2A.postMan()
130                 go A2B.postMan()
131         }
132
133         a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
134         b.blockKeeper.syncPeer = b.peers.GetPeer("test node A")
135         cases := []struct {
136                 syncTimeout   time.Duration
137                 testNode      *Manager
138                 requireHeight uint64
139                 want          *types.Block
140                 err           error
141         }{
142                 {
143                         syncTimeout:   30 * time.Second,
144                         testNode:      a,
145                         requireHeight: 4,
146                         want:          blocks[4],
147                         err:           nil,
148                 },
149                 {
150                         syncTimeout:   1 * time.Millisecond,
151                         testNode:      b,
152                         requireHeight: 4,
153                         want:          nil,
154                         err:           errRequestTimeout,
155                 },
156         }
157
158         defer func() {
159                 requireBlockTimeout = 20 * time.Second
160         }()
161
162         for i, c := range cases {
163                 requireBlockTimeout = c.syncTimeout
164                 got, err := c.testNode.blockKeeper.msgFetcher.requireBlock(c.testNode.blockKeeper.syncPeer.ID(), c.requireHeight)
165                 if !testutil.DeepEqual(got, c.want) {
166                         t.Errorf("case %d: got %v want %v", i, got, c.want)
167                 }
168                 if errors.Root(err) != c.err {
169                         t.Errorf("case %d: got %v want %v", i, err, c.err)
170                 }
171         }
172 }
173
174 func TestSendMerkleBlock(t *testing.T) {
175         tmp, err := ioutil.TempDir(".", "")
176         if err != nil {
177                 t.Fatalf("failed to create temporary data folder: %v", err)
178         }
179
180         testDBA := dbm.NewDB("testdba", "leveldb", tmp)
181         testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
182         defer func() {
183                 testDBA.Close()
184                 testDBB.Close()
185                 os.RemoveAll(tmp)
186         }()
187
188         cases := []struct {
189                 txCount        int
190                 relatedTxIndex []int
191         }{
192                 {
193                         txCount:        10,
194                         relatedTxIndex: []int{0, 2, 5},
195                 },
196                 {
197                         txCount:        0,
198                         relatedTxIndex: []int{},
199                 },
200                 {
201                         txCount:        10,
202                         relatedTxIndex: []int{},
203                 },
204                 {
205                         txCount:        5,
206                         relatedTxIndex: []int{0, 1, 2, 3, 4},
207                 },
208                 {
209                         txCount:        20,
210                         relatedTxIndex: []int{1, 6, 3, 9, 10, 19},
211                 },
212         }
213
214         for _, c := range cases {
215                 blocks := mockBlocks(nil, 2)
216                 targetBlock := blocks[1]
217                 txs, bcTxs := mockTxs(c.txCount)
218                 var err error
219
220                 targetBlock.Transactions = txs
221                 if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil {
222                         t.Fatal(err)
223                 }
224
225                 spvNode := mockSync(blocks, nil, testDBA)
226                 blockHash := targetBlock.Hash()
227                 var statusResult *bc.TransactionStatus
228                 if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil {
229                         t.Fatal(err)
230                 }
231
232                 if targetBlock.TransactionStatusHash, err = types.TxStatusMerkleRoot(statusResult.VerifyStatus); err != nil {
233                         t.Fatal(err)
234                 }
235
236                 fullNode := mockSync(blocks, nil, testDBB)
237                 netWork := NewNetWork()
238                 netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
239                 netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
240
241                 var F2S *P2PPeer
242                 if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil {
243                         t.Errorf("fail on peer hands shake %v", err)
244                 }
245
246                 completed := make(chan error)
247                 go func() {
248                         msgBytes := <-F2S.msgCh
249                         _, msg, _ := decodeMessage(msgBytes)
250                         switch m := msg.(type) {
251                         case *msgs.MerkleBlockMessage:
252                                 var relatedTxIDs []*bc.Hash
253                                 for _, rawTx := range m.RawTxDatas {
254                                         tx := &types.Tx{}
255                                         if err := tx.UnmarshalText(rawTx); err != nil {
256                                                 completed <- err
257                                         }
258
259                                         relatedTxIDs = append(relatedTxIDs, &tx.ID)
260                                 }
261                                 var txHashes []*bc.Hash
262                                 for _, hashByte := range m.TxHashes {
263                                         hash := bc.NewHash(hashByte)
264                                         txHashes = append(txHashes, &hash)
265                                 }
266                                 if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok {
267                                         completed <- errors.New("validate tx fail")
268                                 }
269
270                                 var statusHashes []*bc.Hash
271                                 for _, statusByte := range m.StatusHashes {
272                                         hash := bc.NewHash(statusByte)
273                                         statusHashes = append(statusHashes, &hash)
274                                 }
275                                 var relatedStatuses []*bc.TxVerifyResult
276                                 for _, statusByte := range m.RawTxStatuses {
277                                         status := &bc.TxVerifyResult{}
278                                         err := json.Unmarshal(statusByte, status)
279                                         if err != nil {
280                                                 completed <- err
281                                         }
282                                         relatedStatuses = append(relatedStatuses, status)
283                                 }
284                                 if ok := types.ValidateStatusMerkleTreeProof(statusHashes, m.Flags, relatedStatuses, targetBlock.TransactionStatusHash); !ok {
285                                         completed <- errors.New("validate status fail")
286                                 }
287
288                                 completed <- nil
289                         }
290                 }()
291
292                 spvPeer := fullNode.peers.GetPeer("spv_node")
293                 for i := 0; i < len(c.relatedTxIndex); i++ {
294                         spvPeer.AddFilterAddress(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram())
295                 }
296                 msg := &msgs.GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
297                 fullNode.handleGetMerkleBlockMsg(spvPeer, msg)
298                 if err := <-completed; err != nil {
299                         t.Fatal(err)
300                 }
301         }
302 }
303
304 func TestLocateBlocks(t *testing.T) {
305         maxNumOfBlocksPerMsg = 5
306         blocks := mockBlocks(nil, 100)
307         cases := []struct {
308                 locator    []uint64
309                 stopHash   bc.Hash
310                 wantHeight []uint64
311         }{
312                 {
313                         locator:    []uint64{20},
314                         stopHash:   blocks[100].Hash(),
315                         wantHeight: []uint64{20, 21, 22, 23, 24},
316                 },
317         }
318
319         mockChain := mock.NewChain(nil)
320         bk := &blockKeeper{chain: mockChain}
321         for _, block := range blocks {
322                 mockChain.SetBlockByHeight(block.Height, block)
323         }
324
325         for i, c := range cases {
326                 locator := []*bc.Hash{}
327                 for _, i := range c.locator {
328                         hash := blocks[i].Hash()
329                         locator = append(locator, &hash)
330                 }
331
332                 want := []*types.Block{}
333                 for _, i := range c.wantHeight {
334                         want = append(want, blocks[i])
335                 }
336
337                 got, _ := bk.locateBlocks(locator, &c.stopHash)
338                 if !testutil.DeepEqual(got, want) {
339                         t.Errorf("case %d: got %v want %v", i, got, want)
340                 }
341         }
342 }
343
344 func TestLocateHeaders(t *testing.T) {
345         defer func() {
346                 maxNumOfHeadersPerMsg = 1000
347         }()
348         maxNumOfHeadersPerMsg = 10
349         blocks := mockBlocks(nil, 150)
350         blocksHash := []bc.Hash{}
351         for _, block := range blocks {
352                 blocksHash = append(blocksHash, block.Hash())
353         }
354
355         cases := []struct {
356                 chainHeight uint64
357                 locator     []uint64
358                 stopHash    *bc.Hash
359                 skip        uint64
360                 wantHeight  []uint64
361                 err         bool
362         }{
363                 {
364                         chainHeight: 100,
365                         locator:     []uint64{90},
366                         stopHash:    &blocksHash[100],
367                         skip:        0,
368                         wantHeight:  []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99},
369                         err:         false,
370                 },
371                 {
372                         chainHeight: 100,
373                         locator:     []uint64{20},
374                         stopHash:    &blocksHash[24],
375                         skip:        0,
376                         wantHeight:  []uint64{20, 21, 22, 23, 24},
377                         err:         false,
378                 },
379                 {
380                         chainHeight: 100,
381                         locator:     []uint64{20},
382                         stopHash:    &blocksHash[20],
383                         wantHeight:  []uint64{20},
384                         err:         false,
385                 },
386                 {
387                         chainHeight: 100,
388                         locator:     []uint64{20},
389                         stopHash:    &blocksHash[120],
390                         wantHeight:  []uint64{},
391                         err:         false,
392                 },
393                 {
394                         chainHeight: 100,
395                         locator:     []uint64{120, 70},
396                         stopHash:    &blocksHash[78],
397                         wantHeight:  []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78},
398                         err:         false,
399                 },
400                 {
401                         chainHeight: 100,
402                         locator:     []uint64{15},
403                         stopHash:    &blocksHash[10],
404                         skip:        10,
405                         wantHeight:  []uint64{},
406                         err:         false,
407                 },
408                 {
409                         chainHeight: 100,
410                         locator:     []uint64{15},
411                         stopHash:    &blocksHash[80],
412                         skip:        10,
413                         wantHeight:  []uint64{15, 26, 37, 48, 59, 70, 80},
414                         err:         false,
415                 },
416                 {
417                         chainHeight: 100,
418                         locator:     []uint64{0},
419                         stopHash:    &blocksHash[100],
420                         skip:        9,
421                         wantHeight:  []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90},
422                         err:         false,
423                 },
424         }
425
426         for i, c := range cases {
427                 mockChain := mock.NewChain(nil)
428                 bk := &blockKeeper{chain: mockChain}
429                 for i := uint64(0); i <= c.chainHeight; i++ {
430                         mockChain.SetBlockByHeight(i, blocks[i])
431                 }
432
433                 locator := []*bc.Hash{}
434                 for _, i := range c.locator {
435                         hash := blocks[i].Hash()
436                         locator = append(locator, &hash)
437                 }
438
439                 want := []*types.BlockHeader{}
440                 for _, i := range c.wantHeight {
441                         want = append(want, &blocks[i].BlockHeader)
442                 }
443
444                 got, err := bk.locateHeaders(locator, c.stopHash, c.skip, maxNumOfHeadersPerMsg)
445                 if err != nil != c.err {
446                         t.Errorf("case %d: got %v want err = %v", i, err, c.err)
447                 }
448                 if !testutil.DeepEqual(got, want) {
449                         t.Errorf("case %d: got %v want %v", i, got, want)
450                 }
451         }
452 }