OSDN Git Service

Optimize status message process (#66)
authoryahtoo <yahtoo.ma@gmail.com>
Thu, 16 May 2019 14:15:48 +0000 (22:15 +0800)
committerPaladz <yzhu101@uottawa.ca>
Thu, 16 May 2019 14:15:48 +0000 (22:15 +0800)
* StatusResponseMessage del GenesisHash

* Del useless broadcastMinedBlock msg

* Add StatusMsg process

* Add new status broadcast

netsync/block_fetcher.go
netsync/block_keeper.go
netsync/handle.go
netsync/message.go
netsync/message_test.go [new file with mode: 0644]
netsync/peer.go
netsync/protocol_reactor.go
netsync/tool_test.go

index e7fe585..c48cd0b 100644 (file)
@@ -92,6 +92,11 @@ func (f *blockFetcher) insert(msg *blockMsg) {
                log.WithFields(log.Fields{"module": logModule, "err": err}).Error("blockFetcher fail on broadcast new block")
                return
        }
+
+       if err := f.peers.broadcastNewStatus(msg.block); err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("blockFetcher fail on broadcast new status")
+               return
+       }
 }
 
 func (f *blockFetcher) processNewBlock(msg *blockMsg) {
index 62a4025..9b61081 100644 (file)
@@ -379,11 +379,6 @@ func (bk *blockKeeper) startSync() bool {
 }
 
 func (bk *blockKeeper) syncWorker() {
-       genesisBlock, err := bk.chain.GetBlockByHeight(0)
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleStatusRequestMsg get genesis")
-               return
-       }
        syncTicker := time.NewTicker(syncCycle)
        defer syncTicker.Stop()
 
@@ -398,11 +393,7 @@ func (bk *blockKeeper) syncWorker() {
                        log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on syncWorker get best block")
                }
 
-               if err := bk.peers.broadcastMinedBlock(block); err != nil {
-                       log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on syncWorker broadcast new block")
-               }
-
-               if err = bk.peers.broadcastNewStatus(block, genesisBlock); err != nil {
+               if err = bk.peers.broadcastNewStatus(block); err != nil {
                        log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on syncWorker broadcast new status")
                }
        }
index db50d8a..70702c9 100644 (file)
@@ -109,6 +109,10 @@ func newSyncManager(config *cfg.Config, sw Switch, chain Chain, txPool *core.TxP
        return manager, nil
 }
 
+func (sm *SyncManager) AddPeer(peer BasePeer) {
+       sm.peers.addPeer(peer)
+}
+
 //BestPeer return the highest p2p peerInfo
 func (sm *SyncManager) BestPeer() *PeerInfo {
        bestPeer := sm.peers.bestPeer(consensus.SFFullNode)
@@ -303,32 +307,11 @@ func (sm *SyncManager) handleMineBlockMsg(peer *peer, msg *MineBlockMessage) {
        peer.setStatus(block.Height, &hash)
 }
 
-func (sm *SyncManager) handleStatusRequestMsg(peer BasePeer) {
-       bestHeader := sm.chain.BestBlockHeader()
-       genesisBlock, err := sm.chain.GetBlockByHeight(0)
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleStatusRequestMsg get genesis")
-       }
-
-       genesisHash := genesisBlock.Hash()
-       msg := NewStatusResponseMessage(bestHeader, &genesisHash)
-       if ok := peer.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg}); !ok {
-               sm.peers.removePeer(peer.ID())
-       }
-}
-
-func (sm *SyncManager) handleStatusResponseMsg(basePeer BasePeer, msg *StatusResponseMessage) {
+func (sm *SyncManager) handleStatusMsg(basePeer BasePeer, msg *StatusMessage) {
        if peer := sm.peers.getPeer(basePeer.ID()); peer != nil {
                peer.setStatus(msg.Height, msg.GetHash())
                return
        }
-
-       if genesisHash := msg.GetGenesisHash(); sm.genesisHash != *genesisHash {
-               log.WithFields(log.Fields{"module": logModule, "remote genesis": genesisHash.String(), "local genesis": sm.genesisHash.String()}).Warn("fail hand shake due to differnt genesis")
-               return
-       }
-
-       sm.peers.addPeer(basePeer, msg.Height, msg.GetHash())
 }
 
 func (sm *SyncManager) handleTransactionMsg(peer *peer, msg *TransactionMessage) {
@@ -363,7 +346,7 @@ func (sm *SyncManager) PeerCount() int {
 
 func (sm *SyncManager) processMsg(basePeer BasePeer, msgType byte, msg BlockchainMessage) {
        peer := sm.peers.getPeer(basePeer.ID())
-       if peer == nil && msgType != StatusResponseByte && msgType != StatusRequestByte {
+       if peer == nil {
                return
        }
 
@@ -381,11 +364,8 @@ func (sm *SyncManager) processMsg(basePeer BasePeer, msgType byte, msg Blockchai
        case *BlockMessage:
                sm.handleBlockMsg(peer, msg)
 
-       case *StatusRequestMessage:
-               sm.handleStatusRequestMsg(basePeer)
-
-       case *StatusResponseMessage:
-               sm.handleStatusResponseMsg(basePeer, msg)
+       case *StatusMessage:
+               sm.handleStatusMsg(basePeer, msg)
 
        case *TransactionMessage:
                sm.handleTransactionMsg(peer, msg)
@@ -426,6 +406,19 @@ func (sm *SyncManager) processMsg(basePeer BasePeer, msgType byte, msg Blockchai
        }
 }
 
+func (sm *SyncManager) SendStatus(peer BasePeer) error {
+       p := sm.peers.getPeer(peer.ID())
+       if p == nil {
+               return errors.New("invalid peer")
+       }
+
+       if err := p.sendStatus(sm.chain.BestBlockHeader()); err != nil {
+               sm.peers.removePeer(p.ID())
+               return err
+       }
+       return nil
+}
+
 func (sm *SyncManager) Start() error {
        var err error
        if _, err = sm.sw.Start(); err != nil {
index ed7b4a4..1c86d14 100644 (file)
@@ -23,8 +23,7 @@ const (
        HeadersResponseByte = byte(0x13)
        BlocksRequestByte   = byte(0x14)
        BlocksResponseByte  = byte(0x15)
-       StatusRequestByte   = byte(0x20)
-       StatusResponseByte  = byte(0x21)
+       StatusByte          = byte(0x21)
        NewTransactionByte  = byte(0x30)
        NewMineBlockByte    = byte(0x40)
        FilterLoadByte      = byte(0x50)
@@ -49,8 +48,7 @@ var _ = wire.RegisterInterface(
        wire.ConcreteType{&HeadersMessage{}, HeadersResponseByte},
        wire.ConcreteType{&GetBlocksMessage{}, BlocksRequestByte},
        wire.ConcreteType{&BlocksMessage{}, BlocksResponseByte},
-       wire.ConcreteType{&StatusRequestMessage{}, StatusRequestByte},
-       wire.ConcreteType{&StatusResponseMessage{}, StatusResponseByte},
+       wire.ConcreteType{&StatusMessage{}, StatusByte},
        wire.ConcreteType{&TransactionMessage{}, NewTransactionByte},
        wire.ConcreteType{&MineBlockMessage{}, NewMineBlockByte},
        wire.ConcreteType{&FilterLoadMessage{}, FilterLoadByte},
@@ -274,42 +272,27 @@ func (m *BlocksMessage) String() string {
        return fmt.Sprintf("{blocks_length: %d}", len(m.RawBlocks))
 }
 
-//StatusRequestMessage status request msg
-type StatusRequestMessage struct{}
-
-func (m *StatusRequestMessage) String() string {
-       return "{}"
-}
-
 //StatusResponseMessage get status response msg
-type StatusResponseMessage struct {
-       Height      uint64
-       RawHash     [32]byte
-       GenesisHash [32]byte
+type StatusMessage struct {
+       Height  uint64
+       RawHash [32]byte
 }
 
 //NewStatusResponseMessage construct get status response msg
-func NewStatusResponseMessage(blockHeader *types.BlockHeader, hash *bc.Hash) *StatusResponseMessage {
-       return &StatusResponseMessage{
-               Height:      blockHeader.Height,
-               RawHash:     blockHeader.Hash().Byte32(),
-               GenesisHash: hash.Byte32(),
+func NewStatusMessage(blockHeader *types.BlockHeader) *StatusMessage {
+       return &StatusMessage{
+               Height:  blockHeader.Height,
+               RawHash: blockHeader.Hash().Byte32(),
        }
 }
 
 //GetHash get hash from msg
-func (m *StatusResponseMessage) GetHash() *bc.Hash {
+func (m *StatusMessage) GetHash() *bc.Hash {
        hash := bc.NewHash(m.RawHash)
        return &hash
 }
 
-//GetGenesisHash get hash from msg
-func (m *StatusResponseMessage) GetGenesisHash() *bc.Hash {
-       hash := bc.NewHash(m.GenesisHash)
-       return &hash
-}
-
-func (m *StatusResponseMessage) String() string {
+func (m *StatusMessage) String() string {
        return fmt.Sprintf("{height: %d, hash: %s}", m.Height, hex.EncodeToString(m.RawHash[:]))
 }
 
diff --git a/netsync/message_test.go b/netsync/message_test.go
new file mode 100644 (file)
index 0000000..8743df7
--- /dev/null
@@ -0,0 +1,173 @@
+package netsync
+
+import (
+       "reflect"
+       "testing"
+
+       "github.com/davecgh/go-spew/spew"
+
+       "github.com/vapor/protocol/bc"
+       "github.com/vapor/protocol/bc/types"
+)
+
+var testBlock = &types.Block{
+       BlockHeader: types.BlockHeader{
+               Version:   1,
+               Height:    0,
+               Timestamp: 1528945000,
+               BlockCommitment: types.BlockCommitment{
+                       TransactionsMerkleRoot: bc.Hash{V0: uint64(0x11)},
+                       TransactionStatusHash:  bc.Hash{V0: uint64(0x55)},
+               },
+       },
+}
+
+func TestBlockMessage(t *testing.T) {
+       blockMsg, err := NewBlockMessage(testBlock)
+       if err != nil {
+               t.Fatalf("create new block msg err:%s", err)
+       }
+
+       gotBlock, err := blockMsg.GetBlock()
+       if err != nil {
+               t.Fatalf("got block err:%s", err)
+       }
+
+       if !reflect.DeepEqual(gotBlock.BlockHeader, testBlock.BlockHeader) {
+               t.Errorf("block msg test err: got %s\nwant %s", spew.Sdump(gotBlock.BlockHeader), spew.Sdump(testBlock.BlockHeader))
+       }
+
+       blockMsg.RawBlock[1] = blockMsg.RawBlock[1] + 0x1
+       _, err = blockMsg.GetBlock()
+       if err == nil {
+               t.Fatalf("get mine block err")
+       }
+}
+
+var testHeaders = []*types.BlockHeader{
+       {
+               Version:   1,
+               Height:    0,
+               Timestamp: 1528945000,
+               BlockCommitment: types.BlockCommitment{
+                       TransactionsMerkleRoot: bc.Hash{V0: uint64(0x11)},
+                       TransactionStatusHash:  bc.Hash{V0: uint64(0x55)},
+               },
+       },
+       {
+               Version:   1,
+               Height:    1,
+               Timestamp: 1528945000,
+               BlockCommitment: types.BlockCommitment{
+                       TransactionsMerkleRoot: bc.Hash{V0: uint64(0x11)},
+                       TransactionStatusHash:  bc.Hash{V0: uint64(0x55)},
+               },
+       },
+       {
+               Version:   1,
+               Height:    3,
+               Timestamp: 1528945000,
+               BlockCommitment: types.BlockCommitment{
+                       TransactionsMerkleRoot: bc.Hash{V0: uint64(0x11)},
+                       TransactionStatusHash:  bc.Hash{V0: uint64(0x55)},
+               },
+       },
+}
+
+func TestHeadersMessage(t *testing.T) {
+       headersMsg, err := NewHeadersMessage(testHeaders)
+       if err != nil {
+               t.Fatalf("create headers msg err:%s", err)
+       }
+
+       gotHeaders, err := headersMsg.GetHeaders()
+       if err != nil {
+               t.Fatalf("got headers err:%s", err)
+       }
+
+       if !reflect.DeepEqual(gotHeaders, testHeaders) {
+               t.Errorf("headers msg test err: got %s\nwant %s", spew.Sdump(gotHeaders), spew.Sdump(testHeaders))
+       }
+}
+
+func TestGetBlockMessage(t *testing.T) {
+       getBlockMsg := GetBlockMessage{RawHash: [32]byte{0x01}}
+       gotHash := getBlockMsg.GetHash()
+
+       if !reflect.DeepEqual(gotHash.Byte32(), getBlockMsg.RawHash) {
+               t.Errorf("get block msg test err: got %s\nwant %s", spew.Sdump(gotHash.Byte32()), spew.Sdump(getBlockMsg.RawHash))
+       }
+}
+
+type testGetHeadersMessage struct {
+       blockLocator []*bc.Hash
+       stopHash     *bc.Hash
+}
+
+func TestGetHeadersMessage(t *testing.T) {
+       testMsg := testGetHeadersMessage{
+               blockLocator: []*bc.Hash{{V0: 0x01}, {V0: 0x02}, {V0: 0x03}},
+               stopHash:     &bc.Hash{V0: 0xaa, V2: 0x55},
+       }
+       getHeadersMsg := NewGetHeadersMessage(testMsg.blockLocator, testMsg.stopHash)
+       gotBlockLocator := getHeadersMsg.GetBlockLocator()
+       gotStopHash := getHeadersMsg.GetStopHash()
+
+       if !reflect.DeepEqual(testMsg.blockLocator, gotBlockLocator) {
+               t.Errorf("get headers msg test err: got %s\nwant %s", spew.Sdump(gotBlockLocator), spew.Sdump(testMsg.blockLocator))
+       }
+
+       if !reflect.DeepEqual(testMsg.stopHash, gotStopHash) {
+               t.Errorf("get headers msg test err: got %s\nwant %s", spew.Sdump(gotStopHash), spew.Sdump(testMsg.stopHash))
+       }
+}
+
+var testBlocks = []*types.Block{
+       {
+               BlockHeader: types.BlockHeader{
+                       Version:   1,
+                       Height:    0,
+                       Timestamp: 1528945000,
+                       BlockCommitment: types.BlockCommitment{
+                               TransactionsMerkleRoot: bc.Hash{V0: uint64(0x11)},
+                               TransactionStatusHash:  bc.Hash{V0: uint64(0x55)},
+                       },
+               },
+       },
+       {
+               BlockHeader: types.BlockHeader{
+                       Version:   1,
+                       Height:    0,
+                       Timestamp: 1528945000,
+                       BlockCommitment: types.BlockCommitment{
+                               TransactionsMerkleRoot: bc.Hash{V0: uint64(0x11)},
+                               TransactionStatusHash:  bc.Hash{V0: uint64(0x55)},
+                       },
+               },
+       },
+}
+
+func TestBlocksMessage(t *testing.T) {
+       blocksMsg, err := NewBlocksMessage(testBlocks)
+       if err != nil {
+               t.Fatalf("create blocks msg err:%s", err)
+       }
+       gotBlocks, err := blocksMsg.GetBlocks()
+       if err != nil {
+               t.Fatalf("get blocks err:%s", err)
+       }
+
+       for _, gotBlock := range gotBlocks {
+               if !reflect.DeepEqual(gotBlock.BlockHeader, testBlock.BlockHeader) {
+                       t.Errorf("block msg test err: got %s\nwant %s", spew.Sdump(gotBlock.BlockHeader), spew.Sdump(testBlock.BlockHeader))
+               }
+       }
+}
+
+func TestStatusMessage(t *testing.T) {
+       statusResponseMsg := NewStatusMessage(&testBlock.BlockHeader)
+       gotHash := statusResponseMsg.GetHash()
+       if !reflect.DeepEqual(*gotHash, testBlock.Hash()) {
+               t.Errorf("status response msg test err: got %s\nwant %s", spew.Sdump(*gotHash), spew.Sdump(testBlock.Hash()))
+       }
+}
index b024989..246f1bc 100644 (file)
@@ -23,6 +23,8 @@ const (
        defaultBanThreshold = uint32(100)
 )
 
+var errSendStatusMsg = errors.New("send status msg fail")
+
 //BasePeer is the interface for connection level peer
 type BasePeer interface {
        Addr() net.Addr
@@ -63,15 +65,14 @@ type peer struct {
        banScore    trust.DynamicBanScore
        knownTxs    *set.Set // Set of transaction hashes known to be known by this peer
        knownBlocks *set.Set // Set of block hashes known to be known by this peer
+       knownStatus uint64   // Set of chain status known to be known by this peer
        filterAdds  *set.Set // Set of addresses that the spv node cares about.
 }
 
-func newPeer(height uint64, hash *bc.Hash, basePeer BasePeer) *peer {
+func newPeer(basePeer BasePeer) *peer {
        return &peer{
                BasePeer:    basePeer,
                services:    basePeer.ServiceFlag(),
-               height:      height,
-               hash:        hash,
                knownTxs:    set.New(),
                knownBlocks: set.New(),
                filterAdds:  set.New(),
@@ -216,6 +217,13 @@ func (p *peer) markBlock(hash *bc.Hash) {
        p.knownBlocks.Add(hash.String())
 }
 
+func (p *peer) markNewStatus(height uint64) {
+       p.mtx.Lock()
+       defer p.mtx.Unlock()
+
+       p.knownStatus = height
+}
+
 func (p *peer) markTransaction(hash *bc.Hash) {
        p.mtx.Lock()
        defer p.mtx.Unlock()
@@ -310,6 +318,15 @@ func (p *peer) sendTransactions(txs []*types.Tx) (bool, error) {
        return true, nil
 }
 
+func (p *peer) sendStatus(header *types.BlockHeader) error {
+       msg := NewStatusMessage(header)
+       if ok := p.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg}); !ok {
+               return errSendStatusMsg
+       }
+       p.markNewStatus(header.Height)
+       return nil
+}
+
 func (p *peer) setStatus(height uint64, hash *bc.Hash) {
        p.mtx.Lock()
        defer p.mtx.Unlock()
@@ -348,12 +365,12 @@ func (ps *peerSet) addBanScore(peerID string, persistent, transient uint32, reas
        ps.removePeer(peerID)
 }
 
-func (ps *peerSet) addPeer(peer BasePeer, height uint64, hash *bc.Hash) {
+func (ps *peerSet) addPeer(peer BasePeer) {
        ps.mtx.Lock()
        defer ps.mtx.Unlock()
 
        if _, ok := ps.peers[peer.ID()]; !ok {
-               ps.peers[peer.ID()] = newPeer(height, hash, peer)
+               ps.peers[peer.ID()] = newPeer(peer)
                return
        }
        log.WithField("module", logModule).Warning("add existing peer to blockKeeper")
@@ -393,22 +410,21 @@ func (ps *peerSet) broadcastMinedBlock(block *types.Block) error {
                        continue
                }
                peer.markBlock(&hash)
+               peer.markNewStatus(block.Height)
        }
        return nil
 }
 
-func (ps *peerSet) broadcastNewStatus(bestBlock, genesisBlock *types.Block) error {
-       bestBlockHash := bestBlock.Hash()
-       peers := ps.peersWithoutBlock(&bestBlockHash)
-
-       genesisHash := genesisBlock.Hash()
-       msg := NewStatusResponseMessage(&bestBlock.BlockHeader, &genesisHash)
+func (ps *peerSet) broadcastNewStatus(bestBlock *types.Block) error {
+       msg := NewStatusMessage(&bestBlock.BlockHeader)
+       peers := ps.peersWithoutNewStatus(bestBlock.Height)
        for _, peer := range peers {
                if ok := peer.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg}); !ok {
-                       log.WithFields(log.Fields{"module": logModule, "peer": peer.Addr(), "type": reflect.TypeOf(msg), "message": msg.String()}).Warning("send message to peer error")
                        ps.removePeer(peer.ID())
                        continue
                }
+
+               peer.markNewStatus(bestBlock.Height)
        }
        return nil
 }
@@ -478,6 +494,19 @@ func (ps *peerSet) peersWithoutBlock(hash *bc.Hash) []*peer {
        return peers
 }
 
+func (ps *peerSet) peersWithoutNewStatus(height uint64) []*peer {
+       ps.mtx.RLock()
+       defer ps.mtx.RUnlock()
+
+       var peers []*peer
+       for _, peer := range ps.peers {
+               if peer.knownStatus < height {
+                       peers = append(peers, peer)
+               }
+       }
+       return peers
+}
+
 func (ps *peerSet) peersWithoutTx(hash *bc.Hash) []*peer {
        ps.mtx.RLock()
        defer ps.mtx.RUnlock()
index cb9be4b..8a6c610 100644 (file)
@@ -62,26 +62,12 @@ func (pr *ProtocolReactor) OnStop() {
 
 // AddPeer implements Reactor by sending our state to peer.
 func (pr *ProtocolReactor) AddPeer(peer *p2p.Peer) error {
-       if ok := peer.TrySend(BlockchainChannel, struct{ BlockchainMessage }{&StatusRequestMessage{}}); !ok {
-               return errStatusRequest
-       }
-
-       checkTicker := time.NewTicker(handshakeCheckPerid)
-       defer checkTicker.Stop()
-       timeout := time.NewTimer(handshakeTimeout)
-       defer timeout.Stop()
-       for {
-               select {
-               case <-checkTicker.C:
-                       if exist := pr.peers.getPeer(peer.Key); exist != nil {
-                               pr.sm.syncTransactions(peer.Key)
-                               return nil
-                       }
-
-               case <-timeout.C:
-                       return errProtocolHandshakeTimeout
-               }
+       pr.sm.AddPeer(peer)
+       if err := pr.sm.SendStatus(peer); err != nil {
+               return err
        }
+       pr.sm.syncTransactions(peer.Key)
+       return nil
 }
 
 // RemovePeer implements Reactor by removing peer from the pool.
index 69de67e..e817930 100644 (file)
@@ -117,8 +117,8 @@ func (nw *NetWork) HandsShake(nodeA, nodeB *SyncManager) (*P2PPeer, *P2PPeer, er
        A2B.SetConnection(&B2A, nodeB)
        B2A.SetConnection(&A2B, nodeA)
 
-       nodeA.handleStatusRequestMsg(&A2B)
-       nodeB.handleStatusRequestMsg(&B2A)
+       nodeA.AddPeer(&A2B)
+       nodeB.AddPeer(&B2A)
 
        A2B.setAsync(true)
        B2A.setAsync(true)