OSDN Git Service

Add mempool sync test (#114)
[bytom/vapor.git] / netsync / chainmgr / handle.go
1 package chainmgr
2
3 import (
4         "errors"
5         "reflect"
6
7         log "github.com/sirupsen/logrus"
8
9         cfg "github.com/vapor/config"
10         "github.com/vapor/consensus"
11         "github.com/vapor/event"
12         msgs "github.com/vapor/netsync/messages"
13         "github.com/vapor/netsync/peers"
14         "github.com/vapor/p2p"
15         core "github.com/vapor/protocol"
16         "github.com/vapor/protocol/bc"
17         "github.com/vapor/protocol/bc/types"
18 )
19
20 const (
21         logModule = "netsync"
22 )
23
24 // Chain is the interface for Bytom core
25 type Chain interface {
26         BestBlockHeader() *types.BlockHeader
27         BestBlockHeight() uint64
28         GetBlockByHash(*bc.Hash) (*types.Block, error)
29         GetBlockByHeight(uint64) (*types.Block, error)
30         GetHeaderByHash(*bc.Hash) (*types.BlockHeader, error)
31         GetHeaderByHeight(uint64) (*types.BlockHeader, error)
32         GetTransactionStatus(*bc.Hash) (*bc.TransactionStatus, error)
33         InMainChain(bc.Hash) bool
34         ProcessBlock(*types.Block) (bool, error)
35         ValidateTx(*types.Tx) (bool, error)
36 }
37
38 type Switch interface {
39         AddReactor(name string, reactor p2p.Reactor) p2p.Reactor
40         AddBannedPeer(string) error
41         Start() (bool, error)
42         Stop() bool
43         IsListening() bool
44         DialPeerWithAddress(addr *p2p.NetAddress) error
45         Peers() *p2p.PeerSet
46 }
47
48 // Mempool is the interface for Bytom mempool
49 type Mempool interface {
50         GetTransactions() []*core.TxDesc
51 }
52
53 //Manager is responsible for the business layer information synchronization
54 type Manager struct {
55         sw          Switch
56         chain       Chain
57         mempool     Mempool
58         blockKeeper *blockKeeper
59         peers       *peers.PeerSet
60
61         txSyncCh chan *txSyncMsg
62         quit     chan struct{}
63         config   *cfg.Config
64
65         eventDispatcher *event.Dispatcher
66         txMsgSub        *event.Subscription
67 }
68
69 //NewChainManager create a chain sync manager.
70 func NewManager(config *cfg.Config, sw Switch, chain Chain, mempool Mempool, dispatcher *event.Dispatcher, peers *peers.PeerSet) (*Manager, error) {
71         manager := &Manager{
72                 sw:              sw,
73                 mempool:         mempool,
74                 chain:           chain,
75                 blockKeeper:     newBlockKeeper(chain, peers),
76                 peers:           peers,
77                 txSyncCh:        make(chan *txSyncMsg),
78                 quit:            make(chan struct{}),
79                 config:          config,
80                 eventDispatcher: dispatcher,
81         }
82
83         if !config.VaultMode {
84                 protocolReactor := NewProtocolReactor(manager)
85                 manager.sw.AddReactor("PROTOCOL", protocolReactor)
86         }
87         return manager, nil
88 }
89
90 func (m *Manager) AddPeer(peer peers.BasePeer) {
91         m.peers.AddPeer(peer)
92 }
93
94 //IsCaughtUp check wheather the peer finish the sync
95 func (m *Manager) IsCaughtUp() bool {
96         peer := m.peers.BestPeer(consensus.SFFullNode)
97         return peer == nil || peer.Height() <= m.chain.BestBlockHeight()
98 }
99
100 func (m *Manager) handleBlockMsg(peer *peers.Peer, msg *msgs.BlockMessage) {
101         block, err := msg.GetBlock()
102         if err != nil {
103                 return
104         }
105         m.blockKeeper.processBlock(peer.ID(), block)
106 }
107
108 func (m *Manager) handleBlocksMsg(peer *peers.Peer, msg *msgs.BlocksMessage) {
109         blocks, err := msg.GetBlocks()
110         if err != nil {
111                 log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleBlocksMsg GetBlocks")
112                 return
113         }
114
115         m.blockKeeper.processBlocks(peer.ID(), blocks)
116 }
117
118 func (m *Manager) handleFilterAddMsg(peer *peers.Peer, msg *msgs.FilterAddMessage) {
119         peer.AddFilterAddress(msg.Address)
120 }
121
122 func (m *Manager) handleFilterClearMsg(peer *peers.Peer) {
123         peer.FilterClear()
124 }
125
126 func (m *Manager) handleFilterLoadMsg(peer *peers.Peer, msg *msgs.FilterLoadMessage) {
127         peer.AddFilterAddresses(msg.Addresses)
128 }
129
130 func (m *Manager) handleGetBlockMsg(peer *peers.Peer, msg *msgs.GetBlockMessage) {
131         var block *types.Block
132         var err error
133         if msg.Height != 0 {
134                 block, err = m.chain.GetBlockByHeight(msg.Height)
135         } else {
136                 block, err = m.chain.GetBlockByHash(msg.GetHash())
137         }
138         if err != nil {
139                 log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetBlockMsg get block from chain")
140                 return
141         }
142
143         ok, err := peer.SendBlock(block)
144         if !ok {
145                 m.peers.RemovePeer(peer.ID())
146         }
147         if err != nil {
148                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlockMsg sentBlock")
149         }
150 }
151
152 func (m *Manager) handleGetBlocksMsg(peer *peers.Peer, msg *msgs.GetBlocksMessage) {
153         blocks, err := m.blockKeeper.locateBlocks(msg.GetBlockLocator(), msg.GetStopHash())
154         if err != nil || len(blocks) == 0 {
155                 return
156         }
157
158         totalSize := 0
159         sendBlocks := []*types.Block{}
160         for _, block := range blocks {
161                 rawData, err := block.MarshalText()
162                 if err != nil {
163                         log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlocksMsg marshal block")
164                         continue
165                 }
166
167                 if totalSize+len(rawData) > msgs.MaxBlockchainResponseSize/2 {
168                         break
169                 }
170                 totalSize += len(rawData)
171                 sendBlocks = append(sendBlocks, block)
172         }
173
174         ok, err := peer.SendBlocks(sendBlocks)
175         if !ok {
176                 m.peers.RemovePeer(peer.ID())
177         }
178         if err != nil {
179                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlocksMsg sentBlock")
180         }
181 }
182
183 func (m *Manager) handleGetHeadersMsg(peer *peers.Peer, msg *msgs.GetHeadersMessage) {
184         headers, err := m.blockKeeper.locateHeaders(msg.GetBlockLocator(), msg.GetStopHash())
185         if err != nil || len(headers) == 0 {
186                 log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleGetHeadersMsg locateHeaders")
187                 return
188         }
189
190         ok, err := peer.SendHeaders(headers)
191         if !ok {
192                 m.peers.RemovePeer(peer.ID())
193         }
194         if err != nil {
195                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetHeadersMsg sentBlock")
196         }
197 }
198
199 func (m *Manager) handleGetMerkleBlockMsg(peer *peers.Peer, msg *msgs.GetMerkleBlockMessage) {
200         var err error
201         var block *types.Block
202         if msg.Height != 0 {
203                 block, err = m.chain.GetBlockByHeight(msg.Height)
204         } else {
205                 block, err = m.chain.GetBlockByHash(msg.GetHash())
206         }
207         if err != nil {
208                 log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetMerkleBlockMsg get block from chain")
209                 return
210         }
211
212         blockHash := block.Hash()
213         txStatus, err := m.chain.GetTransactionStatus(&blockHash)
214         if err != nil {
215                 log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetMerkleBlockMsg get transaction status")
216                 return
217         }
218
219         ok, err := peer.SendMerkleBlock(block, txStatus)
220         if err != nil {
221                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetMerkleBlockMsg sentMerkleBlock")
222                 return
223         }
224
225         if !ok {
226                 m.peers.RemovePeer(peer.ID())
227         }
228 }
229
230 func (m *Manager) handleHeadersMsg(peer *peers.Peer, msg *msgs.HeadersMessage) {
231         headers, err := msg.GetHeaders()
232         if err != nil {
233                 log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleHeadersMsg GetHeaders")
234                 return
235         }
236
237         m.blockKeeper.processHeaders(peer.ID(), headers)
238 }
239
240 func (m *Manager) handleStatusMsg(basePeer peers.BasePeer, msg *msgs.StatusMessage) {
241         if peer := m.peers.GetPeer(basePeer.ID()); peer != nil {
242                 peer.SetStatus(msg.Height, msg.GetHash())
243                 return
244         }
245 }
246
247 func (m *Manager) handleTransactionMsg(peer *peers.Peer, msg *msgs.TransactionMessage) {
248         tx, err := msg.GetTransaction()
249         if err != nil {
250                 m.peers.AddBanScore(peer.ID(), 0, 10, "fail on get tx from message")
251                 return
252         }
253
254         if isOrphan, err := m.chain.ValidateTx(tx); err != nil && err != core.ErrDustTx && !isOrphan {
255                 m.peers.AddBanScore(peer.ID(), 10, 0, "fail on validate tx transaction")
256         }
257         m.peers.MarkTx(peer.ID(), tx.ID)
258 }
259
260 func (m *Manager) handleTransactionsMsg(peer *peers.Peer, msg *msgs.TransactionsMessage) {
261         txs, err := msg.GetTransactions()
262         if err != nil {
263                 m.peers.AddBanScore(peer.ID(), 0, 20, "fail on get txs from message")
264                 return
265         }
266
267         if len(txs) > msgs.TxsMsgMaxTxNum {
268                 m.peers.AddBanScore(peer.ID(), 20, 0, "exceeded the maximum tx number limit")
269                 return
270         }
271
272         for _, tx := range txs {
273                 if isOrphan, err := m.chain.ValidateTx(tx); err != nil && !isOrphan {
274                         m.peers.AddBanScore(peer.ID(), 10, 0, "fail on validate tx transaction")
275                         return
276                 }
277                 m.peers.MarkTx(peer.ID(), tx.ID)
278         }
279 }
280
281 func (m *Manager) processMsg(basePeer peers.BasePeer, msgType byte, msg msgs.BlockchainMessage) {
282         peer := m.peers.GetPeer(basePeer.ID())
283         if peer == nil {
284                 return
285         }
286
287         log.WithFields(log.Fields{
288                 "module":  logModule,
289                 "peer":    basePeer.Addr(),
290                 "type":    reflect.TypeOf(msg),
291                 "message": msg.String(),
292         }).Info("receive message from peer")
293
294         switch msg := msg.(type) {
295         case *msgs.GetBlockMessage:
296                 m.handleGetBlockMsg(peer, msg)
297
298         case *msgs.BlockMessage:
299                 m.handleBlockMsg(peer, msg)
300
301         case *msgs.StatusMessage:
302                 m.handleStatusMsg(basePeer, msg)
303
304         case *msgs.TransactionMessage:
305                 m.handleTransactionMsg(peer, msg)
306
307         case *msgs.TransactionsMessage:
308                 m.handleTransactionsMsg(peer, msg)
309
310         case *msgs.GetHeadersMessage:
311                 m.handleGetHeadersMsg(peer, msg)
312
313         case *msgs.HeadersMessage:
314                 m.handleHeadersMsg(peer, msg)
315
316         case *msgs.GetBlocksMessage:
317                 m.handleGetBlocksMsg(peer, msg)
318
319         case *msgs.BlocksMessage:
320                 m.handleBlocksMsg(peer, msg)
321
322         case *msgs.FilterLoadMessage:
323                 m.handleFilterLoadMsg(peer, msg)
324
325         case *msgs.FilterAddMessage:
326                 m.handleFilterAddMsg(peer, msg)
327
328         case *msgs.FilterClearMessage:
329                 m.handleFilterClearMsg(peer)
330
331         case *msgs.GetMerkleBlockMessage:
332                 m.handleGetMerkleBlockMsg(peer, msg)
333
334         default:
335                 log.WithFields(log.Fields{
336                         "module":       logModule,
337                         "peer":         basePeer.Addr(),
338                         "message_type": reflect.TypeOf(msg),
339                 }).Error("unhandled message type")
340         }
341 }
342
343 func (m *Manager) RemovePeer(peerID string) {
344         m.peers.RemovePeer(peerID)
345 }
346
347 func (m *Manager) SendStatus(peer peers.BasePeer) error {
348         p := m.peers.GetPeer(peer.ID())
349         if p == nil {
350                 return errors.New("invalid peer")
351         }
352
353         if err := p.SendStatus(m.chain.BestBlockHeader()); err != nil {
354                 m.peers.RemovePeer(p.ID())
355                 return err
356         }
357         return nil
358 }
359
360 func (m *Manager) Start() error {
361         var err error
362         m.txMsgSub, err = m.eventDispatcher.Subscribe(core.TxMsgEvent{})
363         if err != nil {
364                 return err
365         }
366
367         go m.broadcastTxsLoop()
368         go m.syncMempoolLoop()
369
370         return nil
371 }
372
373 //Stop stop sync manager
374 func (m *Manager) Stop() {
375         close(m.quit)
376 }