OSDN Git Service

Format netsync module code directory (#88)
[bytom/vapor.git] / netsync / chainmgr / handle.go
1 package chainmgr
2
3 import (
4         "errors"
5         "reflect"
6
7         log "github.com/sirupsen/logrus"
8
9         cfg "github.com/vapor/config"
10         "github.com/vapor/consensus"
11         "github.com/vapor/event"
12         msgs "github.com/vapor/netsync/messages"
13         "github.com/vapor/netsync/peers"
14         "github.com/vapor/p2p"
15         core "github.com/vapor/protocol"
16         "github.com/vapor/protocol/bc"
17         "github.com/vapor/protocol/bc/types"
18 )
19
20 const (
21         logModule = "netsync"
22 )
23
24 // Chain is the interface for Bytom core
25 type Chain interface {
26         BestBlockHeader() *types.BlockHeader
27         BestBlockHeight() uint64
28         GetBlockByHash(*bc.Hash) (*types.Block, error)
29         GetBlockByHeight(uint64) (*types.Block, error)
30         GetHeaderByHash(*bc.Hash) (*types.BlockHeader, error)
31         GetHeaderByHeight(uint64) (*types.BlockHeader, error)
32         GetTransactionStatus(*bc.Hash) (*bc.TransactionStatus, error)
33         InMainChain(bc.Hash) bool
34         ProcessBlock(*types.Block) (bool, error)
35         ValidateTx(*types.Tx) (bool, error)
36 }
37
38 type Switch interface {
39         AddReactor(name string, reactor p2p.Reactor) p2p.Reactor
40         AddBannedPeer(string) error
41         Start() (bool, error)
42         Stop() bool
43         IsListening() bool
44         DialPeerWithAddress(addr *p2p.NetAddress) error
45         Peers() *p2p.PeerSet
46 }
47
48 //ChainManager is responsible for the business layer information synchronization
49 type ChainManager struct {
50         sw          Switch
51         chain       Chain
52         txPool      *core.TxPool
53         blockKeeper *blockKeeper
54         peers       *peers.PeerSet
55
56         txSyncCh chan *txSyncMsg
57         quitSync chan struct{}
58         config   *cfg.Config
59
60         eventDispatcher *event.Dispatcher
61         txMsgSub        *event.Subscription
62 }
63
64 //NewChainManager create a chain sync manager.
65 func NewChainManager(config *cfg.Config, sw Switch, chain Chain, txPool *core.TxPool, dispatcher *event.Dispatcher, peers *peers.PeerSet) (*ChainManager, error) {
66         manager := &ChainManager{
67                 sw:              sw,
68                 txPool:          txPool,
69                 chain:           chain,
70                 blockKeeper:     newBlockKeeper(chain, peers),
71                 peers:           peers,
72                 txSyncCh:        make(chan *txSyncMsg),
73                 quitSync:        make(chan struct{}),
74                 config:          config,
75                 eventDispatcher: dispatcher,
76         }
77
78         if !config.VaultMode {
79                 protocolReactor := NewProtocolReactor(manager)
80                 manager.sw.AddReactor("PROTOCOL", protocolReactor)
81         }
82         return manager, nil
83 }
84
85 func (cm *ChainManager) AddPeer(peer peers.BasePeer) {
86         cm.peers.AddPeer(peer)
87 }
88
89 //IsCaughtUp check wheather the peer finish the sync
90 func (cm *ChainManager) IsCaughtUp() bool {
91         peer := cm.peers.BestPeer(consensus.SFFullNode)
92         return peer == nil || peer.Height() <= cm.chain.BestBlockHeight()
93 }
94
95 func (cm *ChainManager) handleBlockMsg(peer *peers.Peer, msg *msgs.BlockMessage) {
96         block, err := msg.GetBlock()
97         if err != nil {
98                 return
99         }
100         cm.blockKeeper.processBlock(peer.ID(), block)
101 }
102
103 func (cm *ChainManager) handleBlocksMsg(peer *peers.Peer, msg *msgs.BlocksMessage) {
104         blocks, err := msg.GetBlocks()
105         if err != nil {
106                 log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleBlocksMsg GetBlocks")
107                 return
108         }
109
110         cm.blockKeeper.processBlocks(peer.ID(), blocks)
111 }
112
113 func (cm *ChainManager) handleFilterAddMsg(peer *peers.Peer, msg *msgs.FilterAddMessage) {
114         peer.AddFilterAddress(msg.Address)
115 }
116
117 func (cm *ChainManager) handleFilterClearMsg(peer *peers.Peer) {
118         peer.FilterClear()
119 }
120
121 func (cm *ChainManager) handleFilterLoadMsg(peer *peers.Peer, msg *msgs.FilterLoadMessage) {
122         peer.AddFilterAddresses(msg.Addresses)
123 }
124
125 func (cm *ChainManager) handleGetBlockMsg(peer *peers.Peer, msg *msgs.GetBlockMessage) {
126         var block *types.Block
127         var err error
128         if msg.Height != 0 {
129                 block, err = cm.chain.GetBlockByHeight(msg.Height)
130         } else {
131                 block, err = cm.chain.GetBlockByHash(msg.GetHash())
132         }
133         if err != nil {
134                 log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetBlockMsg get block from chain")
135                 return
136         }
137
138         ok, err := peer.SendBlock(block)
139         if !ok {
140                 cm.peers.RemovePeer(peer.ID())
141         }
142         if err != nil {
143                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlockMsg sentBlock")
144         }
145 }
146
147 func (cm *ChainManager) handleGetBlocksMsg(peer *peers.Peer, msg *msgs.GetBlocksMessage) {
148         blocks, err := cm.blockKeeper.locateBlocks(msg.GetBlockLocator(), msg.GetStopHash())
149         if err != nil || len(blocks) == 0 {
150                 return
151         }
152
153         totalSize := 0
154         sendBlocks := []*types.Block{}
155         for _, block := range blocks {
156                 rawData, err := block.MarshalText()
157                 if err != nil {
158                         log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlocksMsg marshal block")
159                         continue
160                 }
161
162                 if totalSize+len(rawData) > msgs.MaxBlockchainResponseSize/2 {
163                         break
164                 }
165                 totalSize += len(rawData)
166                 sendBlocks = append(sendBlocks, block)
167         }
168
169         ok, err := peer.SendBlocks(sendBlocks)
170         if !ok {
171                 cm.peers.RemovePeer(peer.ID())
172         }
173         if err != nil {
174                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlocksMsg sentBlock")
175         }
176 }
177
178 func (cm *ChainManager) handleGetHeadersMsg(peer *peers.Peer, msg *msgs.GetHeadersMessage) {
179         headers, err := cm.blockKeeper.locateHeaders(msg.GetBlockLocator(), msg.GetStopHash())
180         if err != nil || len(headers) == 0 {
181                 log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleGetHeadersMsg locateHeaders")
182                 return
183         }
184
185         ok, err := peer.SendHeaders(headers)
186         if !ok {
187                 cm.peers.RemovePeer(peer.ID())
188         }
189         if err != nil {
190                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetHeadersMsg sentBlock")
191         }
192 }
193
194 func (cm *ChainManager) handleGetMerkleBlockMsg(peer *peers.Peer, msg *msgs.GetMerkleBlockMessage) {
195         var err error
196         var block *types.Block
197         if msg.Height != 0 {
198                 block, err = cm.chain.GetBlockByHeight(msg.Height)
199         } else {
200                 block, err = cm.chain.GetBlockByHash(msg.GetHash())
201         }
202         if err != nil {
203                 log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetMerkleBlockMsg get block from chain")
204                 return
205         }
206
207         blockHash := block.Hash()
208         txStatus, err := cm.chain.GetTransactionStatus(&blockHash)
209         if err != nil {
210                 log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetMerkleBlockMsg get transaction status")
211                 return
212         }
213
214         ok, err := peer.SendMerkleBlock(block, txStatus)
215         if err != nil {
216                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetMerkleBlockMsg sentMerkleBlock")
217                 return
218         }
219
220         if !ok {
221                 cm.peers.RemovePeer(peer.ID())
222         }
223 }
224
225 func (cm *ChainManager) handleHeadersMsg(peer *peers.Peer, msg *msgs.HeadersMessage) {
226         headers, err := msg.GetHeaders()
227         if err != nil {
228                 log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleHeadersMsg GetHeaders")
229                 return
230         }
231
232         cm.blockKeeper.processHeaders(peer.ID(), headers)
233 }
234
235 func (cm *ChainManager) handleStatusMsg(basePeer peers.BasePeer, msg *msgs.StatusMessage) {
236         if peer := cm.peers.GetPeer(basePeer.ID()); peer != nil {
237                 peer.SetStatus(msg.Height, msg.GetHash())
238                 return
239         }
240 }
241
242 func (cm *ChainManager) handleTransactionMsg(peer *peers.Peer, msg *msgs.TransactionMessage) {
243         tx, err := msg.GetTransaction()
244         if err != nil {
245                 cm.peers.AddBanScore(peer.ID(), 0, 10, "fail on get tx from message")
246                 return
247         }
248
249         if isOrphan, err := cm.chain.ValidateTx(tx); err != nil && err != core.ErrDustTx && !isOrphan {
250                 cm.peers.AddBanScore(peer.ID(), 10, 0, "fail on validate tx transaction")
251         }
252         cm.peers.MarkTx(peer.ID(), tx.ID)
253 }
254
255 func (cm *ChainManager) handleTransactionsMsg(peer *peers.Peer, msg *msgs.TransactionsMessage) {
256         txs, err := msg.GetTransactions()
257         if err != nil {
258                 cm.peers.AddBanScore(peer.ID(), 0, 20, "fail on get txs from message")
259                 return
260         }
261
262         if len(txs) > msgs.TxsMsgMaxTxNum {
263                 cm.peers.AddBanScore(peer.ID(), 20, 0, "exceeded the maximum tx number limit")
264                 return
265         }
266
267         for _, tx := range txs {
268                 if isOrphan, err := cm.chain.ValidateTx(tx); err != nil && !isOrphan {
269                         cm.peers.AddBanScore(peer.ID(), 10, 0, "fail on validate tx transaction")
270                         return
271                 }
272                 cm.peers.MarkTx(peer.ID(), tx.ID)
273         }
274 }
275
276 func (cm *ChainManager) processMsg(basePeer peers.BasePeer, msgType byte, msg msgs.BlockchainMessage) {
277         peer := cm.peers.GetPeer(basePeer.ID())
278         if peer == nil {
279                 return
280         }
281
282         log.WithFields(log.Fields{
283                 "module":  logModule,
284                 "peer":    basePeer.Addr(),
285                 "type":    reflect.TypeOf(msg),
286                 "message": msg.String(),
287         }).Info("receive message from peer")
288
289         switch msg := msg.(type) {
290         case *msgs.GetBlockMessage:
291                 cm.handleGetBlockMsg(peer, msg)
292
293         case *msgs.BlockMessage:
294                 cm.handleBlockMsg(peer, msg)
295
296         case *msgs.StatusMessage:
297                 cm.handleStatusMsg(basePeer, msg)
298
299         case *msgs.TransactionMessage:
300                 cm.handleTransactionMsg(peer, msg)
301
302         case *msgs.TransactionsMessage:
303                 cm.handleTransactionsMsg(peer, msg)
304
305         case *msgs.GetHeadersMessage:
306                 cm.handleGetHeadersMsg(peer, msg)
307
308         case *msgs.HeadersMessage:
309                 cm.handleHeadersMsg(peer, msg)
310
311         case *msgs.GetBlocksMessage:
312                 cm.handleGetBlocksMsg(peer, msg)
313
314         case *msgs.BlocksMessage:
315                 cm.handleBlocksMsg(peer, msg)
316
317         case *msgs.FilterLoadMessage:
318                 cm.handleFilterLoadMsg(peer, msg)
319
320         case *msgs.FilterAddMessage:
321                 cm.handleFilterAddMsg(peer, msg)
322
323         case *msgs.FilterClearMessage:
324                 cm.handleFilterClearMsg(peer)
325
326         case *msgs.GetMerkleBlockMessage:
327                 cm.handleGetMerkleBlockMsg(peer, msg)
328
329         default:
330                 log.WithFields(log.Fields{
331                         "module":       logModule,
332                         "peer":         basePeer.Addr(),
333                         "message_type": reflect.TypeOf(msg),
334                 }).Error("unhandled message type")
335         }
336 }
337
338 func (cm *ChainManager) RemovePeer(peerID string) {
339         cm.peers.RemovePeer(peerID)
340 }
341
342 func (cm *ChainManager) SendStatus(peer peers.BasePeer) error {
343         p := cm.peers.GetPeer(peer.ID())
344         if p == nil {
345                 return errors.New("invalid peer")
346         }
347
348         if err := p.SendStatus(cm.chain.BestBlockHeader()); err != nil {
349                 cm.peers.RemovePeer(p.ID())
350                 return err
351         }
352         return nil
353 }
354
355 func (cm *ChainManager) Start() error {
356         var err error
357         cm.txMsgSub, err = cm.eventDispatcher.Subscribe(core.TxMsgEvent{})
358         if err != nil {
359                 return err
360         }
361
362         // broadcast transactions
363         go cm.txBroadcastLoop()
364         go cm.txSyncLoop()
365
366         return nil
367 }
368
369 //Stop stop sync manager
370 func (cm *ChainManager) Stop() {
371         close(cm.quitSync)
372 }