OSDN Git Service

Add skeleton size validity check (#354)
[bytom/vapor.git] / netsync / chainmgr / fast_sync.go
1 package chainmgr
2
3 import (
4         "sync"
5
6         log "github.com/sirupsen/logrus"
7
8         "github.com/vapor/errors"
9         "github.com/vapor/netsync/peers"
10         "github.com/vapor/p2p/security"
11         "github.com/vapor/protocol/bc"
12         "github.com/vapor/protocol/bc/types"
13 )
14
15 var (
16         minSizeOfSyncSkeleton  = 2
17         maxSizeOfSyncSkeleton  = 11
18         numOfBlocksSkeletonGap = maxNumOfBlocksPerMsg
19         maxNumOfBlocksPerSync  = numOfBlocksSkeletonGap * uint64(maxSizeOfSyncSkeleton-1)
20         fastSyncPivotGap       = uint64(64)
21         minGapStartFastSync    = uint64(128)
22
23         errNoSyncPeer   = errors.New("can't find sync peer")
24         errSkeletonSize = errors.New("fast sync skeleton size wrong")
25 )
26
27 type fastSync struct {
28         chain          Chain
29         msgFetcher     MsgFetcher
30         blockProcessor BlockProcessor
31         peers          *peers.PeerSet
32         mainSyncPeer   *peers.Peer
33 }
34
35 func newFastSync(chain Chain, msgFetcher MsgFetcher, storage Storage, peers *peers.PeerSet) *fastSync {
36         return &fastSync{
37                 chain:          chain,
38                 msgFetcher:     msgFetcher,
39                 blockProcessor: newBlockProcessor(chain, storage, peers),
40                 peers:          peers,
41         }
42 }
43
44 func (fs *fastSync) blockLocator() []*bc.Hash {
45         header := fs.chain.BestBlockHeader()
46         locator := []*bc.Hash{}
47         step := uint64(1)
48
49         for header != nil {
50                 headerHash := header.Hash()
51                 locator = append(locator, &headerHash)
52                 if header.Height == 0 {
53                         break
54                 }
55
56                 var err error
57                 if header.Height < step {
58                         header, err = fs.chain.GetHeaderByHeight(0)
59                 } else {
60                         header, err = fs.chain.GetHeaderByHeight(header.Height - step)
61                 }
62                 if err != nil {
63                         log.WithFields(log.Fields{"module": logModule, "err": err}).Error("blockKeeper fail on get blockLocator")
64                         break
65                 }
66
67                 if len(locator) >= 9 {
68                         step *= 2
69                 }
70         }
71         return locator
72 }
73
74 // createFetchBlocksTasks get the skeleton and assign tasks according to the skeleton.
75 func (fs *fastSync) createFetchBlocksTasks(stopBlock *types.Block) ([]*fetchBlocksWork, error) {
76         // Find peers that meet the height requirements.
77         peers := fs.peers.GetPeersByHeight(stopBlock.Height + fastSyncPivotGap)
78         if len(peers) == 0 {
79                 return nil, errNoSyncPeer
80         }
81
82         // parallel fetch the skeleton from peers.
83         stopHash := stopBlock.Hash()
84         skeletonMap := fs.msgFetcher.parallelFetchHeaders(peers, fs.blockLocator(), &stopHash, numOfBlocksSkeletonGap-1)
85         if len(skeletonMap) == 0 {
86                 return nil, errors.New("No skeleton found")
87         }
88
89         mainSkeleton, ok := skeletonMap[fs.mainSyncPeer.ID()]
90         if !ok {
91                 return nil, errors.New("No main skeleton found")
92         }
93
94         if len(mainSkeleton) < minSizeOfSyncSkeleton || len(mainSkeleton) > maxSizeOfSyncSkeleton {
95                 fs.peers.ProcessIllegal(fs.mainSyncPeer.ID(), security.LevelMsgIllegal, errSkeletonSize.Error())
96                 return nil, errSkeletonSize
97         }
98
99         // collect peers that match the skeleton of the primary sync peer
100         fs.msgFetcher.addSyncPeer(fs.mainSyncPeer.ID())
101         delete(skeletonMap, fs.mainSyncPeer.ID())
102         for peerID, skeleton := range skeletonMap {
103                 if len(skeleton) != len(mainSkeleton) {
104                         log.WithFields(log.Fields{"module": logModule, "main skeleton": len(mainSkeleton), "got skeleton": len(skeleton)}).Warn("different skeleton length")
105                         continue
106                 }
107
108                 for i, header := range skeleton {
109                         if header.Hash() != mainSkeleton[i].Hash() {
110                                 log.WithFields(log.Fields{"module": logModule, "header index": i, "main skeleton": mainSkeleton[i].Hash(), "got skeleton": header.Hash()}).Warn("different skeleton hash")
111                                 continue
112                         }
113                 }
114                 fs.msgFetcher.addSyncPeer(peerID)
115         }
116
117         blockFetchTasks := make([]*fetchBlocksWork, 0)
118         // create download task
119         for i := 0; i < len(mainSkeleton)-1; i++ {
120                 blockFetchTasks = append(blockFetchTasks, &fetchBlocksWork{startHeader: mainSkeleton[i], stopHeader: mainSkeleton[i+1]})
121         }
122
123         return blockFetchTasks, nil
124 }
125
126 func (fs *fastSync) process() error {
127         stopBlock, err := fs.findSyncRange()
128         if err != nil {
129                 return err
130         }
131
132         tasks, err := fs.createFetchBlocksTasks(stopBlock)
133         if err != nil {
134                 return err
135         }
136
137         downloadNotifyCh := make(chan struct{}, 1)
138         processStopCh := make(chan struct{})
139         var wg sync.WaitGroup
140         wg.Add(2)
141         go fs.msgFetcher.parallelFetchBlocks(tasks, downloadNotifyCh, processStopCh, &wg)
142         go fs.blockProcessor.process(downloadNotifyCh, processStopCh, tasks[0].startHeader.Height, &wg)
143         wg.Wait()
144         fs.msgFetcher.resetParameter()
145         log.WithFields(log.Fields{"module": logModule, "height": fs.chain.BestBlockHeight()}).Info("fast sync complete")
146         return nil
147 }
148
149 // findSyncRange find the start and end of this sync.
150 // sync length cannot be greater than maxFastSyncBlocksNum.
151 func (fs *fastSync) findSyncRange() (*types.Block, error) {
152         bestHeight := fs.chain.BestBlockHeight()
153         length := fs.mainSyncPeer.IrreversibleHeight() - fastSyncPivotGap - bestHeight
154         if length > maxNumOfBlocksPerSync {
155                 length = maxNumOfBlocksPerSync
156         }
157
158         return fs.msgFetcher.requireBlock(fs.mainSyncPeer.ID(), bestHeight+length)
159 }
160
161 func (fs *fastSync) setSyncPeer(peer *peers.Peer) {
162         fs.mainSyncPeer = peer
163 }