OSDN Git Service

add fast sync func (#204)
[bytom/vapor.git] / netsync / chainmgr / fast_sync.go
1 package chainmgr
2
3 import (
4         log "github.com/sirupsen/logrus"
5
6         "github.com/vapor/errors"
7         "github.com/vapor/netsync/peers"
8         "github.com/vapor/p2p/security"
9         "github.com/vapor/protocol/bc"
10         "github.com/vapor/protocol/bc/types"
11 )
12
13 var (
14         maxBlocksPerMsg      = uint64(1000)
15         maxHeadersPerMsg     = uint64(1000)
16         fastSyncPivotGap     = uint64(64)
17         minGapStartFastSync  = uint64(128)
18         maxFastSyncBlocksNum = uint64(10000)
19
20         errOrphanBlock = errors.New("fast sync block is orphan")
21 )
22
23 type MsgFetcher interface {
24         requireBlock(peerID string, height uint64) (*types.Block, error)
25         requireBlocks(peerID string, locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error)
26 }
27
28 type fastSync struct {
29         chain      Chain
30         msgFetcher MsgFetcher
31         peers      *peers.PeerSet
32         syncPeer   *peers.Peer
33         stopHeader *types.BlockHeader
34         length     uint64
35
36         quite chan struct{}
37 }
38
39 func newFastSync(chain Chain, msgFether MsgFetcher, peers *peers.PeerSet) *fastSync {
40         return &fastSync{
41                 chain:      chain,
42                 msgFetcher: msgFether,
43                 peers:      peers,
44                 quite:      make(chan struct{}),
45         }
46 }
47
48 func (fs *fastSync) blockLocator() []*bc.Hash {
49         header := fs.chain.BestBlockHeader()
50         locator := []*bc.Hash{}
51
52         step := uint64(1)
53         for header != nil {
54                 headerHash := header.Hash()
55                 locator = append(locator, &headerHash)
56                 if header.Height == 0 {
57                         break
58                 }
59
60                 var err error
61                 if header.Height < step {
62                         header, err = fs.chain.GetHeaderByHeight(0)
63                 } else {
64                         header, err = fs.chain.GetHeaderByHeight(header.Height - step)
65                 }
66                 if err != nil {
67                         log.WithFields(log.Fields{"module": logModule, "err": err}).Error("blockKeeper fail on get blockLocator")
68                         break
69                 }
70
71                 if len(locator) >= 9 {
72                         step *= 2
73                 }
74         }
75         return locator
76 }
77
78 func (fs *fastSync) process() error {
79         if err := fs.findFastSyncRange(); err != nil {
80                 return err
81         }
82
83         stopHash := fs.stopHeader.Hash()
84         for fs.chain.BestBlockHeight() < fs.stopHeader.Height {
85                 blocks, err := fs.msgFetcher.requireBlocks(fs.syncPeer.ID(), fs.blockLocator(), &stopHash)
86                 if err != nil {
87                         fs.peers.ErrorHandler(fs.syncPeer.ID(), security.LevelConnException, err)
88                         return err
89                 }
90
91                 if err := fs.verifyBlocks(blocks); err != nil {
92                         fs.peers.ErrorHandler(fs.syncPeer.ID(), security.LevelMsgIllegal, err)
93                         return err
94                 }
95         }
96
97         log.WithFields(log.Fields{"module": logModule, "height": fs.chain.BestBlockHeight()}).Info("fast sync success")
98         return nil
99 }
100
101 func (fs *fastSync) findFastSyncRange() error {
102         bestHeight := fs.chain.BestBlockHeight()
103         fs.length = fs.syncPeer.IrreversibleHeight() - fastSyncPivotGap - bestHeight
104         if fs.length > maxFastSyncBlocksNum {
105                 fs.length = maxFastSyncBlocksNum
106         }
107
108         stopBlock, err := fs.msgFetcher.requireBlock(fs.syncPeer.ID(), bestHeight+fs.length)
109         if err != nil {
110                 return err
111         }
112
113         fs.stopHeader = &stopBlock.BlockHeader
114         return nil
115 }
116
117 func (fs *fastSync) locateBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error) {
118         headers, err := fs.locateHeaders(locator, stopHash, 0, maxBlocksPerMsg)
119         if err != nil {
120                 return nil, err
121         }
122
123         blocks := []*types.Block{}
124         for _, header := range headers {
125                 headerHash := header.Hash()
126                 block, err := fs.chain.GetBlockByHash(&headerHash)
127                 if err != nil {
128                         return nil, err
129                 }
130
131                 blocks = append(blocks, block)
132         }
133         return blocks, nil
134 }
135
136 func (fs *fastSync) locateHeaders(locator []*bc.Hash, stopHash *bc.Hash, skip uint64, maxNum uint64) ([]*types.BlockHeader, error) {
137         startHeader, err := fs.chain.GetHeaderByHeight(0)
138         if err != nil {
139                 return nil, err
140         }
141
142         for _, hash := range locator {
143                 header, err := fs.chain.GetHeaderByHash(hash)
144                 if err == nil && fs.chain.InMainChain(header.Hash()) {
145                         startHeader = header
146                         break
147                 }
148         }
149
150         headers := make([]*types.BlockHeader, 0)
151         stopHeader, err := fs.chain.GetHeaderByHash(stopHash)
152         if err != nil {
153                 return headers, nil
154         }
155
156         if !fs.chain.InMainChain(*stopHash) {
157                 return headers, nil
158         }
159
160         num := uint64(0)
161         for i := startHeader.Height; i <= stopHeader.Height && num < maxNum; i += skip + 1 {
162                 header, err := fs.chain.GetHeaderByHeight(i)
163                 if err != nil {
164                         return nil, err
165                 }
166
167                 headers = append(headers, header)
168                 num++
169         }
170
171         return headers, nil
172 }
173
174 func (fs *fastSync) setSyncPeer(peer *peers.Peer) {
175         fs.syncPeer = peer
176 }
177
178 func (fs *fastSync) verifyBlocks(blocks []*types.Block) error {
179         for _, block := range blocks {
180                 isOrphan, err := fs.chain.ProcessBlock(block)
181                 if err != nil {
182                         return err
183                 }
184
185                 if isOrphan {
186                         log.WithFields(log.Fields{"module": logModule, "height": block.Height, "hash": block.Hash()}).Error("fast sync block is orphan")
187                         return errOrphanBlock
188                 }
189         }
190
191         return nil
192 }