OSDN Git Service

Format netsync module code directory (#88)
authoryahtoo <yahtoo.ma@gmail.com>
Mon, 27 May 2019 07:30:22 +0000 (15:30 +0800)
committerPaladz <yzhu101@uottawa.ca>
Mon, 27 May 2019 07:30:22 +0000 (15:30 +0800)
15 files changed:
api/api.go
api/nodeinfo.go
netsync/block_fetcher.go [deleted file]
netsync/chainmgr/block_keeper.go [moved from netsync/block_keeper.go with 91% similarity]
netsync/chainmgr/block_keeper_test.go [moved from netsync/block_keeper_test.go with 92% similarity]
netsync/chainmgr/handle.go [new file with mode: 0644]
netsync/chainmgr/protocol_reactor.go [moved from netsync/protocol_reactor.go with 61% similarity]
netsync/chainmgr/tool_test.go [moved from netsync/tool_test.go with 85% similarity]
netsync/chainmgr/tx_keeper.go [moved from netsync/tx_keeper.go with 84% similarity]
netsync/handle.go [deleted file]
netsync/messages/chain_msg.go [moved from netsync/message.go with 94% similarity]
netsync/messages/chain_msg_test.go [moved from netsync/message_test.go with 99% similarity]
netsync/peers/peer.go [moved from netsync/peer.go with 51% similarity]
netsync/sync_manager.go [new file with mode: 0644]
p2p/switch.go

index 2396703..1574151 100644 (file)
@@ -24,7 +24,7 @@ import (
        "github.com/vapor/net/http/httpjson"
        "github.com/vapor/net/http/static"
        "github.com/vapor/net/websocket"
-       "github.com/vapor/netsync"
+       "github.com/vapor/netsync/peers"
        "github.com/vapor/p2p"
        "github.com/vapor/protocol"
        "github.com/vapor/wallet"
@@ -173,9 +173,9 @@ type NetSync interface {
        IsCaughtUp() bool
        PeerCount() int
        GetNetwork() string
-       BestPeer() *netsync.PeerInfo
+       BestPeer() *peers.PeerInfo
        DialPeerWithAddress(addr *p2p.NetAddress) error
-       GetPeerInfos() []*netsync.PeerInfo
+       GetPeerInfos() []*peers.PeerInfo
        StopPeer(peerID string) error
 }
 
index 6558aaa..7ddde83 100644 (file)
@@ -5,7 +5,7 @@ import (
        "net"
 
        "github.com/vapor/errors"
-       "github.com/vapor/netsync"
+       "github.com/vapor/netsync/peers"
        "github.com/vapor/p2p"
        "github.com/vapor/version"
 )
@@ -53,7 +53,7 @@ func (a *API) GetNodeInfo() *NetInfo {
 }
 
 // return the currently connected peers with net address
-func (a *API) getPeerInfoByAddr(addr string) *netsync.PeerInfo {
+func (a *API) getPeerInfoByAddr(addr string) *peers.PeerInfo {
        peerInfos := a.sync.GetPeerInfos()
        for _, peerInfo := range peerInfos {
                if peerInfo.RemoteAddr == addr {
@@ -69,7 +69,7 @@ func (a *API) disconnectPeerById(peerID string) error {
 }
 
 // connect peer b y net address
-func (a *API) connectPeerByIpAndPort(ip string, port uint16) (*netsync.PeerInfo, error) {
+func (a *API) connectPeerByIpAndPort(ip string, port uint16) (*peers.PeerInfo, error) {
        netIp := net.ParseIP(ip)
        if netIp == nil {
                return nil, errors.New("invalid ip address")
diff --git a/netsync/block_fetcher.go b/netsync/block_fetcher.go
deleted file mode 100644 (file)
index c48cd0b..0000000
+++ /dev/null
@@ -1,104 +0,0 @@
-package netsync
-
-import (
-       log "github.com/sirupsen/logrus"
-       "gopkg.in/karalabe/cookiejar.v2/collections/prque"
-
-       "github.com/vapor/protocol/bc"
-)
-
-const (
-       maxBlockDistance = 64
-       maxMsgSetSize    = 128
-       newBlockChSize   = 64
-)
-
-// blockFetcher is responsible for accumulating block announcements from various peers
-// and scheduling them for retrieval.
-type blockFetcher struct {
-       chain Chain
-       peers *peerSet
-
-       newBlockCh chan *blockMsg
-       queue      *prque.Prque
-       msgSet     map[bc.Hash]*blockMsg
-}
-
-//NewBlockFetcher creates a block fetcher to retrieve blocks of the new mined.
-func newBlockFetcher(chain Chain, peers *peerSet) *blockFetcher {
-       f := &blockFetcher{
-               chain:      chain,
-               peers:      peers,
-               newBlockCh: make(chan *blockMsg, newBlockChSize),
-               queue:      prque.New(),
-               msgSet:     make(map[bc.Hash]*blockMsg),
-       }
-       go f.blockProcessor()
-       return f
-}
-
-func (f *blockFetcher) blockProcessor() {
-       for {
-               height := f.chain.BestBlockHeight()
-               for !f.queue.Empty() {
-                       msg := f.queue.PopItem().(*blockMsg)
-                       if msg.block.Height > height+1 {
-                               f.queue.Push(msg, -float32(msg.block.Height))
-                               break
-                       }
-
-                       f.insert(msg)
-                       delete(f.msgSet, msg.block.Hash())
-               }
-               f.add(<-f.newBlockCh)
-       }
-}
-
-func (f *blockFetcher) add(msg *blockMsg) {
-       bestHeight := f.chain.BestBlockHeight()
-       if len(f.msgSet) > maxMsgSetSize || bestHeight > msg.block.Height || msg.block.Height-bestHeight > maxBlockDistance {
-               return
-       }
-
-       blockHash := msg.block.Hash()
-       if _, ok := f.msgSet[blockHash]; !ok {
-               f.msgSet[blockHash] = msg
-               f.queue.Push(msg, -float32(msg.block.Height))
-               log.WithFields(log.Fields{
-                       "module":       logModule,
-                       "block height": msg.block.Height,
-                       "block hash":   blockHash.String(),
-               }).Debug("blockFetcher receive mine block")
-       }
-}
-
-func (f *blockFetcher) insert(msg *blockMsg) {
-       isOrphan, err := f.chain.ProcessBlock(msg.block)
-       if err != nil {
-               peer := f.peers.getPeer(msg.peerID)
-               if peer == nil {
-                       return
-               }
-
-               f.peers.addBanScore(msg.peerID, 20, 0, err.Error())
-               return
-       }
-
-       if isOrphan {
-               return
-       }
-
-       if err := f.peers.broadcastMinedBlock(msg.block); err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("blockFetcher fail on broadcast new block")
-               return
-       }
-
-       if err := f.peers.broadcastNewStatus(msg.block); err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("blockFetcher fail on broadcast new status")
-               return
-       }
-}
-
-func (f *blockFetcher) processNewBlock(msg *blockMsg) {
-       f.newBlockCh <- msg
-}
similarity index 91%
rename from netsync/block_keeper.go
rename to netsync/chainmgr/block_keeper.go
index 9b61081..8ce20ea 100644 (file)
@@ -1,4 +1,4 @@
-package netsync
+package chainmgr
 
 import (
        "container/list"
@@ -8,6 +8,7 @@ import (
 
        "github.com/vapor/consensus"
        "github.com/vapor/errors"
+       "github.com/vapor/netsync/peers"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
 )
@@ -27,7 +28,6 @@ var (
        errAppendHeaders  = errors.New("fail to append list due to order dismatch")
        errRequestTimeout = errors.New("request timeout")
        errPeerDropped    = errors.New("Peer dropped")
-       errPeerMisbehave  = errors.New("peer is misbehave")
 )
 
 type blockMsg struct {
@@ -47,9 +47,9 @@ type headersMsg struct {
 
 type blockKeeper struct {
        chain Chain
-       peers *peerSet
+       peers *peers.PeerSet
 
-       syncPeer         *peer
+       syncPeer         *peers.Peer
        blockProcessCh   chan *blockMsg
        blocksProcessCh  chan *blocksMsg
        headersProcessCh chan *headersMsg
@@ -57,7 +57,7 @@ type blockKeeper struct {
        headerList *list.List
 }
 
-func newBlockKeeper(chain Chain, peers *peerSet) *blockKeeper {
+func newBlockKeeper(chain Chain, peers *peers.PeerSet) *blockKeeper {
        bk := &blockKeeper{
                chain:            chain,
                peers:            peers,
@@ -117,7 +117,7 @@ func (bk *blockKeeper) fastBlockSync(checkPoint *consensus.Checkpoint) error {
        lastHeader := bk.headerList.Back().Value.(*types.BlockHeader)
        for ; lastHeader.Hash() != checkPoint.Hash; lastHeader = bk.headerList.Back().Value.(*types.BlockHeader) {
                if lastHeader.Height >= checkPoint.Height {
-                       return errors.Wrap(errPeerMisbehave, "peer is not in the checkpoint branch")
+                       return errors.Wrap(peers.ErrPeerMisbehave, "peer is not in the checkpoint branch")
                }
 
                lastHash := lastHeader.Hash()
@@ -127,7 +127,7 @@ func (bk *blockKeeper) fastBlockSync(checkPoint *consensus.Checkpoint) error {
                }
 
                if len(headers) == 0 {
-                       return errors.Wrap(errPeerMisbehave, "requireHeaders return empty list")
+                       return errors.Wrap(peers.ErrPeerMisbehave, "requireHeaders return empty list")
                }
 
                if err := bk.appendHeaderList(headers); err != nil {
@@ -144,7 +144,7 @@ func (bk *blockKeeper) fastBlockSync(checkPoint *consensus.Checkpoint) error {
                }
 
                if len(blocks) == 0 {
-                       return errors.Wrap(errPeerMisbehave, "requireBlocks return empty list")
+                       return errors.Wrap(peers.ErrPeerMisbehave, "requireBlocks return empty list")
                }
 
                for _, block := range blocks {
@@ -271,7 +271,7 @@ func (bk *blockKeeper) regularBlockSync(wantHeight uint64) error {
 }
 
 func (bk *blockKeeper) requireBlock(height uint64) (*types.Block, error) {
-       if ok := bk.syncPeer.getBlockByHeight(height); !ok {
+       if ok := bk.syncPeer.GetBlockByHeight(height); !ok {
                return nil, errPeerDropped
        }
 
@@ -295,7 +295,7 @@ func (bk *blockKeeper) requireBlock(height uint64) (*types.Block, error) {
 }
 
 func (bk *blockKeeper) requireBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error) {
-       if ok := bk.syncPeer.getBlocks(locator, stopHash); !ok {
+       if ok := bk.syncPeer.GetBlocks(locator, stopHash); !ok {
                return nil, errPeerDropped
        }
 
@@ -316,7 +316,7 @@ func (bk *blockKeeper) requireBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*
 }
 
 func (bk *blockKeeper) requireHeaders(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.BlockHeader, error) {
-       if ok := bk.syncPeer.getHeaders(locator, stopHash); !ok {
+       if ok := bk.syncPeer.GetHeaders(locator, stopHash); !ok {
                return nil, errPeerDropped
        }
 
@@ -348,19 +348,19 @@ func (bk *blockKeeper) resetHeaderState() {
 
 func (bk *blockKeeper) startSync() bool {
        checkPoint := bk.nextCheckpoint()
-       peer := bk.peers.bestPeer(consensus.SFFastSync | consensus.SFFullNode)
+       peer := bk.peers.BestPeer(consensus.SFFastSync | consensus.SFFullNode)
        if peer != nil && checkPoint != nil && peer.Height() >= checkPoint.Height {
                bk.syncPeer = peer
                if err := bk.fastBlockSync(checkPoint); err != nil {
                        log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on fastBlockSync")
-                       bk.peers.errorHandler(peer.ID(), err)
+                       bk.peers.ErrorHandler(peer.ID(), err)
                        return false
                }
                return true
        }
 
        blockHeight := bk.chain.BestBlockHeight()
-       peer = bk.peers.bestPeer(consensus.SFFullNode)
+       peer = bk.peers.BestPeer(consensus.SFFullNode)
        if peer != nil && peer.Height() > blockHeight {
                bk.syncPeer = peer
                targetHeight := blockHeight + maxBlockPerMsg
@@ -370,7 +370,7 @@ func (bk *blockKeeper) startSync() bool {
 
                if err := bk.regularBlockSync(targetHeight); err != nil {
                        log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on regularBlockSync")
-                       bk.peers.errorHandler(peer.ID(), err)
+                       bk.peers.ErrorHandler(peer.ID(), err)
                        return false
                }
                return true
@@ -393,7 +393,7 @@ func (bk *blockKeeper) syncWorker() {
                        log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on syncWorker get best block")
                }
 
-               if err = bk.peers.broadcastNewStatus(block); err != nil {
+               if err = bk.peers.BroadcastNewStatus(block); err != nil {
                        log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on syncWorker broadcast new status")
                }
        }
similarity index 92%
rename from netsync/block_keeper_test.go
rename to netsync/chainmgr/block_keeper_test.go
index 68c4cdd..2be129b 100644 (file)
@@ -1,14 +1,14 @@
-package netsync
+package chainmgr
 
 import (
        "container/list"
-       "encoding/hex"
        "encoding/json"
        "testing"
        "time"
 
        "github.com/vapor/consensus"
        "github.com/vapor/errors"
+       msgs "github.com/vapor/netsync/messages"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
        "github.com/vapor/test/mock"
@@ -190,7 +190,7 @@ func TestFastBlockSync(t *testing.T) {
                        go A2B.postMan()
                }
 
-               a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
+               a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
                if err := a.blockKeeper.fastBlockSync(c.checkPoint); errors.Root(err) != c.err {
                        t.Errorf("case %d: got %v want %v", i, err, c.err)
                }
@@ -345,34 +345,34 @@ func TestNextCheckpoint(t *testing.T) {
                },
                {
                        checkPoints: []consensus.Checkpoint{
-                               {10000, bc.Hash{V0: 1}},
+                               {Height: 10000, Hash: bc.Hash{V0: 1}},
                        },
                        bestHeight: 5000,
-                       want:       &consensus.Checkpoint{10000, bc.Hash{V0: 1}},
+                       want:       &consensus.Checkpoint{Height: 10000, Hash: bc.Hash{V0: 1}},
                },
                {
                        checkPoints: []consensus.Checkpoint{
-                               {10000, bc.Hash{V0: 1}},
-                               {20000, bc.Hash{V0: 2}},
-                               {30000, bc.Hash{V0: 3}},
+                               {Height: 10000, Hash: bc.Hash{V0: 1}},
+                               {Height: 20000, Hash: bc.Hash{V0: 2}},
+                               {Height: 30000, Hash: bc.Hash{V0: 3}},
                        },
                        bestHeight: 15000,
-                       want:       &consensus.Checkpoint{20000, bc.Hash{V0: 2}},
+                       want:       &consensus.Checkpoint{Height: 20000, Hash: bc.Hash{V0: 2}},
                },
                {
                        checkPoints: []consensus.Checkpoint{
-                               {10000, bc.Hash{V0: 1}},
-                               {20000, bc.Hash{V0: 2}},
-                               {30000, bc.Hash{V0: 3}},
+                               {Height: 10000, Hash: bc.Hash{V0: 1}},
+                               {Height: 20000, Hash: bc.Hash{V0: 2}},
+                               {Height: 30000, Hash: bc.Hash{V0: 3}},
                        },
                        bestHeight: 10000,
-                       want:       &consensus.Checkpoint{20000, bc.Hash{V0: 2}},
+                       want:       &consensus.Checkpoint{Height: 20000, Hash: bc.Hash{V0: 2}},
                },
                {
                        checkPoints: []consensus.Checkpoint{
-                               {10000, bc.Hash{V0: 1}},
-                               {20000, bc.Hash{V0: 2}},
-                               {30000, bc.Hash{V0: 3}},
+                               {Height: 10000, Hash: bc.Hash{V0: 1}},
+                               {Height: 20000, Hash: bc.Hash{V0: 2}},
+                               {Height: 30000, Hash: bc.Hash{V0: 3}},
                        },
                        bestHeight: 35000,
                        want:       nil,
@@ -451,7 +451,7 @@ func TestRegularBlockSync(t *testing.T) {
                        go A2B.postMan()
                }
 
-               a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
+               a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
                if err := a.blockKeeper.regularBlockSync(c.syncHeight); errors.Root(err) != c.err {
                        t.Errorf("case %d: got %v want %v", i, err, c.err)
                }
@@ -485,11 +485,11 @@ func TestRequireBlock(t *testing.T) {
                go A2B.postMan()
        }
 
-       a.blockKeeper.syncPeer = a.peers.getPeer("test node B")
-       b.blockKeeper.syncPeer = b.peers.getPeer("test node A")
+       a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
+       b.blockKeeper.syncPeer = b.peers.GetPeer("test node A")
        cases := []struct {
                syncTimeout   time.Duration
-               testNode      *SyncManager
+               testNode      *ChainManager
                requireHeight uint64
                want          *types.Block
                err           error
@@ -584,9 +584,9 @@ func TestSendMerkleBlock(t *testing.T) {
                completed := make(chan error)
                go func() {
                        msgBytes := <-F2S.msgCh
-                       _, msg, _ := DecodeMessage(msgBytes)
+                       _, msg, _ := decodeMessage(msgBytes)
                        switch m := msg.(type) {
-                       case *MerkleBlockMessage:
+                       case *msgs.MerkleBlockMessage:
                                var relatedTxIDs []*bc.Hash
                                for _, rawTx := range m.RawTxDatas {
                                        tx := &types.Tx{}
@@ -627,11 +627,11 @@ func TestSendMerkleBlock(t *testing.T) {
                        }
                }()
 
-               spvPeer := fullNode.peers.getPeer("spv_node")
+               spvPeer := fullNode.peers.GetPeer("spv_node")
                for i := 0; i < len(c.relatedTxIndex); i++ {
-                       spvPeer.filterAdds.Add(hex.EncodeToString(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram()))
+                       spvPeer.AddFilterAddress(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram())
                }
-               msg := &GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
+               msg := &msgs.GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
                fullNode.handleGetMerkleBlockMsg(spvPeer, msg)
                if err := <-completed; err != nil {
                        t.Fatal(err)
diff --git a/netsync/chainmgr/handle.go b/netsync/chainmgr/handle.go
new file mode 100644 (file)
index 0000000..8f7d679
--- /dev/null
@@ -0,0 +1,372 @@
+package chainmgr
+
+import (
+       "errors"
+       "reflect"
+
+       log "github.com/sirupsen/logrus"
+
+       cfg "github.com/vapor/config"
+       "github.com/vapor/consensus"
+       "github.com/vapor/event"
+       msgs "github.com/vapor/netsync/messages"
+       "github.com/vapor/netsync/peers"
+       "github.com/vapor/p2p"
+       core "github.com/vapor/protocol"
+       "github.com/vapor/protocol/bc"
+       "github.com/vapor/protocol/bc/types"
+)
+
+const (
+       logModule = "netsync"
+)
+
+// Chain is the interface for Bytom core
+type Chain interface {
+       BestBlockHeader() *types.BlockHeader
+       BestBlockHeight() uint64
+       GetBlockByHash(*bc.Hash) (*types.Block, error)
+       GetBlockByHeight(uint64) (*types.Block, error)
+       GetHeaderByHash(*bc.Hash) (*types.BlockHeader, error)
+       GetHeaderByHeight(uint64) (*types.BlockHeader, error)
+       GetTransactionStatus(*bc.Hash) (*bc.TransactionStatus, error)
+       InMainChain(bc.Hash) bool
+       ProcessBlock(*types.Block) (bool, error)
+       ValidateTx(*types.Tx) (bool, error)
+}
+
+type Switch interface {
+       AddReactor(name string, reactor p2p.Reactor) p2p.Reactor
+       AddBannedPeer(string) error
+       Start() (bool, error)
+       Stop() bool
+       IsListening() bool
+       DialPeerWithAddress(addr *p2p.NetAddress) error
+       Peers() *p2p.PeerSet
+}
+
+//ChainManager is responsible for the business layer information synchronization
+type ChainManager struct {
+       sw          Switch
+       chain       Chain
+       txPool      *core.TxPool
+       blockKeeper *blockKeeper
+       peers       *peers.PeerSet
+
+       txSyncCh chan *txSyncMsg
+       quitSync chan struct{}
+       config   *cfg.Config
+
+       eventDispatcher *event.Dispatcher
+       txMsgSub        *event.Subscription
+}
+
+//NewChainManager create a chain sync manager.
+func NewChainManager(config *cfg.Config, sw Switch, chain Chain, txPool *core.TxPool, dispatcher *event.Dispatcher, peers *peers.PeerSet) (*ChainManager, error) {
+       manager := &ChainManager{
+               sw:              sw,
+               txPool:          txPool,
+               chain:           chain,
+               blockKeeper:     newBlockKeeper(chain, peers),
+               peers:           peers,
+               txSyncCh:        make(chan *txSyncMsg),
+               quitSync:        make(chan struct{}),
+               config:          config,
+               eventDispatcher: dispatcher,
+       }
+
+       if !config.VaultMode {
+               protocolReactor := NewProtocolReactor(manager)
+               manager.sw.AddReactor("PROTOCOL", protocolReactor)
+       }
+       return manager, nil
+}
+
+func (cm *ChainManager) AddPeer(peer peers.BasePeer) {
+       cm.peers.AddPeer(peer)
+}
+
+//IsCaughtUp check wheather the peer finish the sync
+func (cm *ChainManager) IsCaughtUp() bool {
+       peer := cm.peers.BestPeer(consensus.SFFullNode)
+       return peer == nil || peer.Height() <= cm.chain.BestBlockHeight()
+}
+
+func (cm *ChainManager) handleBlockMsg(peer *peers.Peer, msg *msgs.BlockMessage) {
+       block, err := msg.GetBlock()
+       if err != nil {
+               return
+       }
+       cm.blockKeeper.processBlock(peer.ID(), block)
+}
+
+func (cm *ChainManager) handleBlocksMsg(peer *peers.Peer, msg *msgs.BlocksMessage) {
+       blocks, err := msg.GetBlocks()
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleBlocksMsg GetBlocks")
+               return
+       }
+
+       cm.blockKeeper.processBlocks(peer.ID(), blocks)
+}
+
+func (cm *ChainManager) handleFilterAddMsg(peer *peers.Peer, msg *msgs.FilterAddMessage) {
+       peer.AddFilterAddress(msg.Address)
+}
+
+func (cm *ChainManager) handleFilterClearMsg(peer *peers.Peer) {
+       peer.FilterClear()
+}
+
+func (cm *ChainManager) handleFilterLoadMsg(peer *peers.Peer, msg *msgs.FilterLoadMessage) {
+       peer.AddFilterAddresses(msg.Addresses)
+}
+
+func (cm *ChainManager) handleGetBlockMsg(peer *peers.Peer, msg *msgs.GetBlockMessage) {
+       var block *types.Block
+       var err error
+       if msg.Height != 0 {
+               block, err = cm.chain.GetBlockByHeight(msg.Height)
+       } else {
+               block, err = cm.chain.GetBlockByHash(msg.GetHash())
+       }
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetBlockMsg get block from chain")
+               return
+       }
+
+       ok, err := peer.SendBlock(block)
+       if !ok {
+               cm.peers.RemovePeer(peer.ID())
+       }
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlockMsg sentBlock")
+       }
+}
+
+func (cm *ChainManager) handleGetBlocksMsg(peer *peers.Peer, msg *msgs.GetBlocksMessage) {
+       blocks, err := cm.blockKeeper.locateBlocks(msg.GetBlockLocator(), msg.GetStopHash())
+       if err != nil || len(blocks) == 0 {
+               return
+       }
+
+       totalSize := 0
+       sendBlocks := []*types.Block{}
+       for _, block := range blocks {
+               rawData, err := block.MarshalText()
+               if err != nil {
+                       log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlocksMsg marshal block")
+                       continue
+               }
+
+               if totalSize+len(rawData) > msgs.MaxBlockchainResponseSize/2 {
+                       break
+               }
+               totalSize += len(rawData)
+               sendBlocks = append(sendBlocks, block)
+       }
+
+       ok, err := peer.SendBlocks(sendBlocks)
+       if !ok {
+               cm.peers.RemovePeer(peer.ID())
+       }
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlocksMsg sentBlock")
+       }
+}
+
+func (cm *ChainManager) handleGetHeadersMsg(peer *peers.Peer, msg *msgs.GetHeadersMessage) {
+       headers, err := cm.blockKeeper.locateHeaders(msg.GetBlockLocator(), msg.GetStopHash())
+       if err != nil || len(headers) == 0 {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleGetHeadersMsg locateHeaders")
+               return
+       }
+
+       ok, err := peer.SendHeaders(headers)
+       if !ok {
+               cm.peers.RemovePeer(peer.ID())
+       }
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetHeadersMsg sentBlock")
+       }
+}
+
+func (cm *ChainManager) handleGetMerkleBlockMsg(peer *peers.Peer, msg *msgs.GetMerkleBlockMessage) {
+       var err error
+       var block *types.Block
+       if msg.Height != 0 {
+               block, err = cm.chain.GetBlockByHeight(msg.Height)
+       } else {
+               block, err = cm.chain.GetBlockByHash(msg.GetHash())
+       }
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetMerkleBlockMsg get block from chain")
+               return
+       }
+
+       blockHash := block.Hash()
+       txStatus, err := cm.chain.GetTransactionStatus(&blockHash)
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetMerkleBlockMsg get transaction status")
+               return
+       }
+
+       ok, err := peer.SendMerkleBlock(block, txStatus)
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetMerkleBlockMsg sentMerkleBlock")
+               return
+       }
+
+       if !ok {
+               cm.peers.RemovePeer(peer.ID())
+       }
+}
+
+func (cm *ChainManager) handleHeadersMsg(peer *peers.Peer, msg *msgs.HeadersMessage) {
+       headers, err := msg.GetHeaders()
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleHeadersMsg GetHeaders")
+               return
+       }
+
+       cm.blockKeeper.processHeaders(peer.ID(), headers)
+}
+
+func (cm *ChainManager) handleStatusMsg(basePeer peers.BasePeer, msg *msgs.StatusMessage) {
+       if peer := cm.peers.GetPeer(basePeer.ID()); peer != nil {
+               peer.SetStatus(msg.Height, msg.GetHash())
+               return
+       }
+}
+
+func (cm *ChainManager) handleTransactionMsg(peer *peers.Peer, msg *msgs.TransactionMessage) {
+       tx, err := msg.GetTransaction()
+       if err != nil {
+               cm.peers.AddBanScore(peer.ID(), 0, 10, "fail on get tx from message")
+               return
+       }
+
+       if isOrphan, err := cm.chain.ValidateTx(tx); err != nil && err != core.ErrDustTx && !isOrphan {
+               cm.peers.AddBanScore(peer.ID(), 10, 0, "fail on validate tx transaction")
+       }
+       cm.peers.MarkTx(peer.ID(), tx.ID)
+}
+
+func (cm *ChainManager) handleTransactionsMsg(peer *peers.Peer, msg *msgs.TransactionsMessage) {
+       txs, err := msg.GetTransactions()
+       if err != nil {
+               cm.peers.AddBanScore(peer.ID(), 0, 20, "fail on get txs from message")
+               return
+       }
+
+       if len(txs) > msgs.TxsMsgMaxTxNum {
+               cm.peers.AddBanScore(peer.ID(), 20, 0, "exceeded the maximum tx number limit")
+               return
+       }
+
+       for _, tx := range txs {
+               if isOrphan, err := cm.chain.ValidateTx(tx); err != nil && !isOrphan {
+                       cm.peers.AddBanScore(peer.ID(), 10, 0, "fail on validate tx transaction")
+                       return
+               }
+               cm.peers.MarkTx(peer.ID(), tx.ID)
+       }
+}
+
+func (cm *ChainManager) processMsg(basePeer peers.BasePeer, msgType byte, msg msgs.BlockchainMessage) {
+       peer := cm.peers.GetPeer(basePeer.ID())
+       if peer == nil {
+               return
+       }
+
+       log.WithFields(log.Fields{
+               "module":  logModule,
+               "peer":    basePeer.Addr(),
+               "type":    reflect.TypeOf(msg),
+               "message": msg.String(),
+       }).Info("receive message from peer")
+
+       switch msg := msg.(type) {
+       case *msgs.GetBlockMessage:
+               cm.handleGetBlockMsg(peer, msg)
+
+       case *msgs.BlockMessage:
+               cm.handleBlockMsg(peer, msg)
+
+       case *msgs.StatusMessage:
+               cm.handleStatusMsg(basePeer, msg)
+
+       case *msgs.TransactionMessage:
+               cm.handleTransactionMsg(peer, msg)
+
+       case *msgs.TransactionsMessage:
+               cm.handleTransactionsMsg(peer, msg)
+
+       case *msgs.GetHeadersMessage:
+               cm.handleGetHeadersMsg(peer, msg)
+
+       case *msgs.HeadersMessage:
+               cm.handleHeadersMsg(peer, msg)
+
+       case *msgs.GetBlocksMessage:
+               cm.handleGetBlocksMsg(peer, msg)
+
+       case *msgs.BlocksMessage:
+               cm.handleBlocksMsg(peer, msg)
+
+       case *msgs.FilterLoadMessage:
+               cm.handleFilterLoadMsg(peer, msg)
+
+       case *msgs.FilterAddMessage:
+               cm.handleFilterAddMsg(peer, msg)
+
+       case *msgs.FilterClearMessage:
+               cm.handleFilterClearMsg(peer)
+
+       case *msgs.GetMerkleBlockMessage:
+               cm.handleGetMerkleBlockMsg(peer, msg)
+
+       default:
+               log.WithFields(log.Fields{
+                       "module":       logModule,
+                       "peer":         basePeer.Addr(),
+                       "message_type": reflect.TypeOf(msg),
+               }).Error("unhandled message type")
+       }
+}
+
+func (cm *ChainManager) RemovePeer(peerID string) {
+       cm.peers.RemovePeer(peerID)
+}
+
+func (cm *ChainManager) SendStatus(peer peers.BasePeer) error {
+       p := cm.peers.GetPeer(peer.ID())
+       if p == nil {
+               return errors.New("invalid peer")
+       }
+
+       if err := p.SendStatus(cm.chain.BestBlockHeader()); err != nil {
+               cm.peers.RemovePeer(p.ID())
+               return err
+       }
+       return nil
+}
+
+func (cm *ChainManager) Start() error {
+       var err error
+       cm.txMsgSub, err = cm.eventDispatcher.Subscribe(core.TxMsgEvent{})
+       if err != nil {
+               return err
+       }
+
+       // broadcast transactions
+       go cm.txBroadcastLoop()
+       go cm.txSyncLoop()
+
+       return nil
+}
+
+//Stop stop sync manager
+func (cm *ChainManager) Stop() {
+       close(cm.quitSync)
+}
similarity index 61%
rename from netsync/protocol_reactor.go
rename to netsync/chainmgr/protocol_reactor.go
index 8a6c610..86987fb 100644 (file)
@@ -1,38 +1,28 @@
-package netsync
+package chainmgr
 
 import (
-       "time"
+       "bytes"
 
        log "github.com/sirupsen/logrus"
+       "github.com/tendermint/go-wire"
 
        "github.com/vapor/errors"
+       msgs "github.com/vapor/netsync/messages"
        "github.com/vapor/p2p"
        "github.com/vapor/p2p/connection"
 )
 
-const (
-       handshakeTimeout    = 10 * time.Second
-       handshakeCheckPerid = 500 * time.Millisecond
-)
-
-var (
-       errProtocolHandshakeTimeout = errors.New("Protocol handshake timeout")
-       errStatusRequest            = errors.New("Status request error")
-)
-
 //ProtocolReactor handles new coming protocol message.
 type ProtocolReactor struct {
        p2p.BaseReactor
 
-       sm    *SyncManager
-       peers *peerSet
+       cm *ChainManager
 }
 
 // NewProtocolReactor returns the reactor of whole blockchain.
-func NewProtocolReactor(sm *SyncManager, peers *peerSet) *ProtocolReactor {
+func NewProtocolReactor(cm *ChainManager) *ProtocolReactor {
        pr := &ProtocolReactor{
-               sm:    sm,
-               peers: peers,
+               cm: cm,
        }
        pr.BaseReactor = *p2p.NewBaseReactor("ProtocolReactor", pr)
        return pr
@@ -42,7 +32,7 @@ func NewProtocolReactor(sm *SyncManager, peers *peerSet) *ProtocolReactor {
 func (pr *ProtocolReactor) GetChannels() []*connection.ChannelDescriptor {
        return []*connection.ChannelDescriptor{
                {
-                       ID:                BlockchainChannel,
+                       ID:                msgs.BlockchainChannel,
                        Priority:          5,
                        SendQueueCapacity: 100,
                },
@@ -62,26 +52,38 @@ func (pr *ProtocolReactor) OnStop() {
 
 // AddPeer implements Reactor by sending our state to peer.
 func (pr *ProtocolReactor) AddPeer(peer *p2p.Peer) error {
-       pr.sm.AddPeer(peer)
-       if err := pr.sm.SendStatus(peer); err != nil {
+       pr.cm.AddPeer(peer)
+       if err := pr.cm.SendStatus(peer); err != nil {
                return err
        }
-       pr.sm.syncTransactions(peer.Key)
+       pr.cm.syncTransactions(peer.Key)
        return nil
 }
 
 // RemovePeer implements Reactor by removing peer from the pool.
 func (pr *ProtocolReactor) RemovePeer(peer *p2p.Peer, reason interface{}) {
-       pr.peers.removePeer(peer.Key)
+       pr.cm.RemovePeer(peer.Key)
+}
+
+//decodeMessage decode msg
+func decodeMessage(bz []byte) (msgType byte, msg msgs.BlockchainMessage, err error) {
+       msgType = bz[0]
+       n := int(0)
+       r := bytes.NewReader(bz)
+       msg = wire.ReadBinary(struct{ msgs.BlockchainMessage }{}, r, msgs.MaxBlockchainResponseSize, &n, &err).(struct{ msgs.BlockchainMessage }).BlockchainMessage
+       if err != nil && n != len(bz) {
+               err = errors.New("DecodeMessage() had bytes left over")
+       }
+       return
 }
 
 // Receive implements Reactor by handling 4 types of messages (look below).
 func (pr *ProtocolReactor) Receive(chID byte, src *p2p.Peer, msgBytes []byte) {
-       msgType, msg, err := DecodeMessage(msgBytes)
+       msgType, msg, err := decodeMessage(msgBytes)
        if err != nil {
                log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on reactor decoding message")
                return
        }
 
-       pr.sm.processMsg(src, msgType, msg)
+       pr.cm.processMsg(src, msgType, msg)
 }
similarity index 85%
rename from netsync/tool_test.go
rename to netsync/chainmgr/tool_test.go
index e817930..34f0cc6 100644 (file)
@@ -1,4 +1,4 @@
-package netsync
+package chainmgr
 
 import (
        "errors"
@@ -9,6 +9,7 @@ import (
        "github.com/tendermint/tmlibs/flowrate"
 
        "github.com/vapor/consensus"
+       "github.com/vapor/netsync/peers"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
        "github.com/vapor/test/mock"
@@ -20,7 +21,7 @@ type P2PPeer struct {
        flag consensus.ServiceFlag
 
        srcPeer    *P2PPeer
-       remoteNode *SyncManager
+       remoteNode *ChainManager
        msgCh      chan []byte
        async      bool
 }
@@ -51,7 +52,7 @@ func (p *P2PPeer) ServiceFlag() consensus.ServiceFlag {
        return p.flag
 }
 
-func (p *P2PPeer) SetConnection(srcPeer *P2PPeer, node *SyncManager) {
+func (p *P2PPeer) SetConnection(srcPeer *P2PPeer, node *ChainManager) {
        p.srcPeer = srcPeer
        p.remoteNode = node
 }
@@ -65,7 +66,7 @@ func (p *P2PPeer) TrySend(b byte, msg interface{}) bool {
        if p.async {
                p.msgCh <- msgBytes
        } else {
-               msgType, msg, _ := DecodeMessage(msgBytes)
+               msgType, msg, _ := decodeMessage(msgBytes)
                p.remoteNode.processMsg(p.srcPeer, msgType, msg)
        }
        return true
@@ -77,7 +78,7 @@ func (p *P2PPeer) setAsync(b bool) {
 
 func (p *P2PPeer) postMan() {
        for msgBytes := range p.msgCh {
-               msgType, msg, _ := DecodeMessage(msgBytes)
+               msgType, msg, _ := decodeMessage(msgBytes)
                p.remoteNode.processMsg(p.srcPeer, msgType, msg)
        }
 }
@@ -92,19 +93,19 @@ func (ps *PeerSet) AddBannedPeer(string) error { return nil }
 func (ps *PeerSet) StopPeerGracefully(string)  {}
 
 type NetWork struct {
-       nodes map[*SyncManager]P2PPeer
+       nodes map[*ChainManager]P2PPeer
 }
 
 func NewNetWork() *NetWork {
-       return &NetWork{map[*SyncManager]P2PPeer{}}
+       return &NetWork{map[*ChainManager]P2PPeer{}}
 }
 
-func (nw *NetWork) Register(node *SyncManager, addr, id string, flag consensus.ServiceFlag) {
+func (nw *NetWork) Register(node *ChainManager, addr, id string, flag consensus.ServiceFlag) {
        peer := NewP2PPeer(addr, id, flag)
        nw.nodes[node] = *peer
 }
 
-func (nw *NetWork) HandsShake(nodeA, nodeB *SyncManager) (*P2PPeer, *P2PPeer, error) {
+func (nw *NetWork) HandsShake(nodeA, nodeB *ChainManager) (*P2PPeer, *P2PPeer, error) {
        B2A, ok := nw.nodes[nodeA]
        if !ok {
                return nil, nil, errors.New("can't find nodeA's p2p peer on network")
@@ -149,17 +150,15 @@ func mockBlocks(startBlock *types.Block, height uint64) []*types.Block {
        return blocks
 }
 
-func mockSync(blocks []*types.Block) *SyncManager {
+func mockSync(blocks []*types.Block) *ChainManager {
        chain := mock.NewChain()
-       peers := newPeerSet(NewPeerSet())
+       peers := peers.NewPeerSet(NewPeerSet())
        chain.SetBestBlockHeader(&blocks[len(blocks)-1].BlockHeader)
        for _, block := range blocks {
                chain.SetBlockByHeight(block.Height, block)
        }
 
-       genesis, _ := chain.GetHeaderByHeight(0)
-       return &SyncManager{
-               genesisHash: genesis.Hash(),
+       return &ChainManager{
                chain:       chain,
                blockKeeper: newBlockKeeper(chain, peers),
                peers:       peers,
similarity index 84%
rename from netsync/tx_keeper.go
rename to netsync/chainmgr/tx_keeper.go
index 5b95ba9..9f1a7cc 100644 (file)
@@ -1,4 +1,4 @@
-package netsync
+package chainmgr
 
 import (
        "math/rand"
@@ -21,8 +21,8 @@ type txSyncMsg struct {
        txs    []*types.Tx
 }
 
-func (sm *SyncManager) syncTransactions(peerID string) {
-       pending := sm.txPool.GetTransactions()
+func (cm *ChainManager) syncTransactions(peerID string) {
+       pending := cm.txPool.GetTransactions()
        if len(pending) == 0 {
                return
        }
@@ -31,13 +31,13 @@ func (sm *SyncManager) syncTransactions(peerID string) {
        for i, batch := range pending {
                txs[i] = batch.Tx
        }
-       sm.txSyncCh <- &txSyncMsg{peerID, txs}
+       cm.txSyncCh <- &txSyncMsg{peerID, txs}
 }
 
-func (sm *SyncManager) txBroadcastLoop() {
+func (cm *ChainManager) txBroadcastLoop() {
        for {
                select {
-               case obj, ok := <-sm.txMsgSub.Chan():
+               case obj, ok := <-cm.txMsgSub.Chan():
                        if !ok {
                                log.WithFields(log.Fields{"module": logModule}).Warning("mempool tx msg subscription channel closed")
                                return
@@ -50,12 +50,12 @@ func (sm *SyncManager) txBroadcastLoop() {
                        }
 
                        if ev.TxMsg.MsgType == core.MsgNewTx {
-                               if err := sm.peers.broadcastTx(ev.TxMsg.Tx); err != nil {
+                               if err := cm.peers.BroadcastTx(ev.TxMsg.Tx); err != nil {
                                        log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on broadcast new tx.")
                                        continue
                                }
                        }
-               case <-sm.quitSync:
+               case <-cm.quitSync:
                        return
                }
        }
@@ -65,14 +65,14 @@ func (sm *SyncManager) txBroadcastLoop() {
 // connection. When a new peer appears, we relay all currently pending
 // transactions. In order to minimise egress bandwidth usage, we send
 // the transactions in small packs to one peer at a time.
-func (sm *SyncManager) txSyncLoop() {
+func (cm *ChainManager) txSyncLoop() {
        pending := make(map[string]*txSyncMsg)
        sending := false            // whether a send is active
        done := make(chan error, 1) // result of the send
 
        // send starts a sending a pack of transactions from the sync.
        send := func(msg *txSyncMsg) {
-               peer := sm.peers.getPeer(msg.peerID)
+               peer := cm.peers.GetPeer(msg.peerID)
                if peer == nil {
                        delete(pending, msg.peerID)
                        return
@@ -100,9 +100,9 @@ func (sm *SyncManager) txSyncLoop() {
                }).Debug("txSyncLoop sending transactions")
                sending = true
                go func() {
-                       err := peer.sendTransactions(sendTxs)
+                       err := peer.SendTransactions(sendTxs)
                        if err != nil {
-                               sm.peers.removePeer(msg.peerID)
+                               cm.peers.RemovePeer(msg.peerID)
                        }
                        done <- err
                }()
@@ -125,7 +125,7 @@ func (sm *SyncManager) txSyncLoop() {
 
        for {
                select {
-               case msg := <-sm.txSyncCh:
+               case msg := <-cm.txSyncCh:
                        pending[msg.peerID] = msg
                        if !sending {
                                send(msg)
diff --git a/netsync/handle.go b/netsync/handle.go
deleted file mode 100644 (file)
index 7dc35b9..0000000
+++ /dev/null
@@ -1,500 +0,0 @@
-package netsync
-
-import (
-       "errors"
-       "reflect"
-
-       log "github.com/sirupsen/logrus"
-
-       cfg "github.com/vapor/config"
-       "github.com/vapor/consensus"
-       "github.com/vapor/event"
-       "github.com/vapor/p2p"
-       core "github.com/vapor/protocol"
-       "github.com/vapor/protocol/bc"
-       "github.com/vapor/protocol/bc/types"
-)
-
-const (
-       logModule             = "netsync"
-       maxTxChanSize         = 10000
-       maxFilterAddressSize  = 50
-       maxFilterAddressCount = 1000
-)
-
-var (
-       errVaultModeDialPeer = errors.New("can't dial peer in vault mode")
-)
-
-// Chain is the interface for Bytom core
-type Chain interface {
-       BestBlockHeader() *types.BlockHeader
-       BestBlockHeight() uint64
-       GetBlockByHash(*bc.Hash) (*types.Block, error)
-       GetBlockByHeight(uint64) (*types.Block, error)
-       GetHeaderByHash(*bc.Hash) (*types.BlockHeader, error)
-       GetHeaderByHeight(uint64) (*types.BlockHeader, error)
-       GetTransactionStatus(*bc.Hash) (*bc.TransactionStatus, error)
-       InMainChain(bc.Hash) bool
-       ProcessBlock(*types.Block) (bool, error)
-       ValidateTx(*types.Tx) (bool, error)
-}
-
-type Switch interface {
-       AddReactor(name string, reactor p2p.Reactor) p2p.Reactor
-       AddBannedPeer(string) error
-       StopPeerGracefully(string)
-       Start() (bool, error)
-       Stop() bool
-       IsListening() bool
-       DialPeerWithAddress(addr *p2p.NetAddress) error
-       Peers() *p2p.PeerSet
-}
-
-//SyncManager Sync Manager is responsible for the business layer information synchronization
-type SyncManager struct {
-       sw           Switch
-       genesisHash  bc.Hash
-       chain        Chain
-       txPool       *core.TxPool
-       blockFetcher *blockFetcher
-       blockKeeper  *blockKeeper
-       peers        *peerSet
-
-       txSyncCh chan *txSyncMsg
-       quitSync chan struct{}
-       config   *cfg.Config
-
-       eventDispatcher *event.Dispatcher
-       minedBlockSub   *event.Subscription
-       txMsgSub        *event.Subscription
-}
-
-// CreateSyncManager create sync manager and set switch.
-func NewSyncManager(config *cfg.Config, chain Chain, txPool *core.TxPool, dispatcher *event.Dispatcher) (*SyncManager, error) {
-       sw, err := p2p.NewSwitch(config)
-       if err != nil {
-               return nil, err
-       }
-
-       return newSyncManager(config, sw, chain, txPool, dispatcher)
-}
-
-//NewSyncManager create a sync manager
-func newSyncManager(config *cfg.Config, sw Switch, chain Chain, txPool *core.TxPool, dispatcher *event.Dispatcher) (*SyncManager, error) {
-       genesisHeader, err := chain.GetHeaderByHeight(0)
-       if err != nil {
-               return nil, err
-       }
-       peers := newPeerSet(sw)
-       manager := &SyncManager{
-               sw:              sw,
-               genesisHash:     genesisHeader.Hash(),
-               txPool:          txPool,
-               chain:           chain,
-               blockFetcher:    newBlockFetcher(chain, peers),
-               blockKeeper:     newBlockKeeper(chain, peers),
-               peers:           peers,
-               txSyncCh:        make(chan *txSyncMsg),
-               quitSync:        make(chan struct{}),
-               config:          config,
-               eventDispatcher: dispatcher,
-       }
-
-       if !config.VaultMode {
-               protocolReactor := NewProtocolReactor(manager, peers)
-               manager.sw.AddReactor("PROTOCOL", protocolReactor)
-       }
-       return manager, nil
-}
-
-func (sm *SyncManager) AddPeer(peer BasePeer) {
-       sm.peers.addPeer(peer)
-}
-
-//BestPeer return the highest p2p peerInfo
-func (sm *SyncManager) BestPeer() *PeerInfo {
-       bestPeer := sm.peers.bestPeer(consensus.SFFullNode)
-       if bestPeer != nil {
-               return bestPeer.getPeerInfo()
-       }
-       return nil
-}
-
-func (sm *SyncManager) DialPeerWithAddress(addr *p2p.NetAddress) error {
-       if sm.config.VaultMode {
-               return errVaultModeDialPeer
-       }
-
-       return sm.sw.DialPeerWithAddress(addr)
-}
-
-func (sm *SyncManager) GetNetwork() string {
-       return sm.config.ChainID
-}
-
-//GetPeerInfos return peer info of all peers
-func (sm *SyncManager) GetPeerInfos() []*PeerInfo {
-       return sm.peers.getPeerInfos()
-}
-
-//IsCaughtUp check wheather the peer finish the sync
-func (sm *SyncManager) IsCaughtUp() bool {
-       peer := sm.peers.bestPeer(consensus.SFFullNode)
-       return peer == nil || peer.Height() <= sm.chain.BestBlockHeight()
-}
-
-//StopPeer try to stop peer by given ID
-func (sm *SyncManager) StopPeer(peerID string) error {
-       if peer := sm.peers.getPeer(peerID); peer == nil {
-               return errors.New("peerId not exist")
-       }
-       sm.peers.removePeer(peerID)
-       return nil
-}
-
-func (sm *SyncManager) handleBlockMsg(peer *peer, msg *BlockMessage) {
-       block, err := msg.GetBlock()
-       if err != nil {
-               return
-       }
-       sm.blockKeeper.processBlock(peer.ID(), block)
-}
-
-func (sm *SyncManager) handleBlocksMsg(peer *peer, msg *BlocksMessage) {
-       blocks, err := msg.GetBlocks()
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleBlocksMsg GetBlocks")
-               return
-       }
-
-       sm.blockKeeper.processBlocks(peer.ID(), blocks)
-}
-
-func (sm *SyncManager) handleFilterAddMsg(peer *peer, msg *FilterAddMessage) {
-       peer.addFilterAddress(msg.Address)
-}
-
-func (sm *SyncManager) handleFilterClearMsg(peer *peer) {
-       peer.filterAdds.Clear()
-}
-
-func (sm *SyncManager) handleFilterLoadMsg(peer *peer, msg *FilterLoadMessage) {
-       peer.addFilterAddresses(msg.Addresses)
-}
-
-func (sm *SyncManager) handleGetBlockMsg(peer *peer, msg *GetBlockMessage) {
-       var block *types.Block
-       var err error
-       if msg.Height != 0 {
-               block, err = sm.chain.GetBlockByHeight(msg.Height)
-       } else {
-               block, err = sm.chain.GetBlockByHash(msg.GetHash())
-       }
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetBlockMsg get block from chain")
-               return
-       }
-
-       ok, err := peer.sendBlock(block)
-       if !ok {
-               sm.peers.removePeer(peer.ID())
-       }
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlockMsg sentBlock")
-       }
-}
-
-func (sm *SyncManager) handleGetBlocksMsg(peer *peer, msg *GetBlocksMessage) {
-       blocks, err := sm.blockKeeper.locateBlocks(msg.GetBlockLocator(), msg.GetStopHash())
-       if err != nil || len(blocks) == 0 {
-               return
-       }
-
-       totalSize := 0
-       sendBlocks := []*types.Block{}
-       for _, block := range blocks {
-               rawData, err := block.MarshalText()
-               if err != nil {
-                       log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlocksMsg marshal block")
-                       continue
-               }
-
-               if totalSize+len(rawData) > maxBlockchainResponseSize/2 {
-                       break
-               }
-               totalSize += len(rawData)
-               sendBlocks = append(sendBlocks, block)
-       }
-
-       ok, err := peer.sendBlocks(sendBlocks)
-       if !ok {
-               sm.peers.removePeer(peer.ID())
-       }
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetBlocksMsg sentBlock")
-       }
-}
-
-func (sm *SyncManager) handleGetHeadersMsg(peer *peer, msg *GetHeadersMessage) {
-       headers, err := sm.blockKeeper.locateHeaders(msg.GetBlockLocator(), msg.GetStopHash())
-       if err != nil || len(headers) == 0 {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleGetHeadersMsg locateHeaders")
-               return
-       }
-
-       ok, err := peer.sendHeaders(headers)
-       if !ok {
-               sm.peers.removePeer(peer.ID())
-       }
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetHeadersMsg sentBlock")
-       }
-}
-
-func (sm *SyncManager) handleGetMerkleBlockMsg(peer *peer, msg *GetMerkleBlockMessage) {
-       var err error
-       var block *types.Block
-       if msg.Height != 0 {
-               block, err = sm.chain.GetBlockByHeight(msg.Height)
-       } else {
-               block, err = sm.chain.GetBlockByHash(msg.GetHash())
-       }
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetMerkleBlockMsg get block from chain")
-               return
-       }
-
-       blockHash := block.Hash()
-       txStatus, err := sm.chain.GetTransactionStatus(&blockHash)
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleGetMerkleBlockMsg get transaction status")
-               return
-       }
-
-       ok, err := peer.sendMerkleBlock(block, txStatus)
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on handleGetMerkleBlockMsg sentMerkleBlock")
-               return
-       }
-
-       if !ok {
-               sm.peers.removePeer(peer.ID())
-       }
-}
-
-func (sm *SyncManager) handleHeadersMsg(peer *peer, msg *HeadersMessage) {
-       headers, err := msg.GetHeaders()
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Debug("fail on handleHeadersMsg GetHeaders")
-               return
-       }
-
-       sm.blockKeeper.processHeaders(peer.ID(), headers)
-}
-
-func (sm *SyncManager) handleMineBlockMsg(peer *peer, msg *MineBlockMessage) {
-       block, err := msg.GetMineBlock()
-       if err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on handleMineBlockMsg GetMineBlock")
-               return
-       }
-
-       hash := block.Hash()
-       peer.markBlock(&hash)
-       sm.blockFetcher.processNewBlock(&blockMsg{peerID: peer.ID(), block: block})
-       peer.setStatus(block.Height, &hash)
-}
-
-func (sm *SyncManager) handleStatusMsg(basePeer BasePeer, msg *StatusMessage) {
-       if peer := sm.peers.getPeer(basePeer.ID()); peer != nil {
-               peer.setStatus(msg.Height, msg.GetHash())
-               return
-       }
-}
-
-func (sm *SyncManager) handleTransactionMsg(peer *peer, msg *TransactionMessage) {
-       tx, err := msg.GetTransaction()
-       if err != nil {
-               sm.peers.addBanScore(peer.ID(), 0, 10, "fail on get tx from message")
-               return
-       }
-
-       if isOrphan, err := sm.chain.ValidateTx(tx); err != nil && err != core.ErrDustTx && !isOrphan {
-               sm.peers.addBanScore(peer.ID(), 10, 0, "fail on validate tx transaction")
-       }
-       sm.peers.markTx(peer.ID(), tx.ID)
-}
-
-func (sm *SyncManager) handleTransactionsMsg(peer *peer, msg *TransactionsMessage) {
-       txs, err := msg.GetTransactions()
-       if err != nil {
-               sm.peers.addBanScore(peer.ID(), 0, 20, "fail on get txs from message")
-               return
-       }
-
-       if len(txs) > txsMsgMaxTxNum {
-               sm.peers.addBanScore(peer.ID(), 20, 0, "exceeded the maximum tx number limit")
-               return
-       }
-
-       for _, tx := range txs {
-               if isOrphan, err := sm.chain.ValidateTx(tx); err != nil && !isOrphan {
-                       sm.peers.addBanScore(peer.ID(), 10, 0, "fail on validate tx transaction")
-                       return
-               }
-               sm.peers.markTx(peer.ID(), tx.ID)
-       }
-}
-
-func (sm *SyncManager) IsListening() bool {
-       if sm.config.VaultMode {
-               return false
-       }
-       return sm.sw.IsListening()
-}
-
-func (sm *SyncManager) PeerCount() int {
-       if sm.config.VaultMode {
-               return 0
-       }
-       return len(sm.sw.Peers().List())
-}
-
-func (sm *SyncManager) processMsg(basePeer BasePeer, msgType byte, msg BlockchainMessage) {
-       peer := sm.peers.getPeer(basePeer.ID())
-       if peer == nil {
-               return
-       }
-
-       log.WithFields(log.Fields{
-               "module":  logModule,
-               "peer":    basePeer.Addr(),
-               "type":    reflect.TypeOf(msg),
-               "message": msg.String(),
-       }).Info("receive message from peer")
-
-       switch msg := msg.(type) {
-       case *GetBlockMessage:
-               sm.handleGetBlockMsg(peer, msg)
-
-       case *BlockMessage:
-               sm.handleBlockMsg(peer, msg)
-
-       case *StatusMessage:
-               sm.handleStatusMsg(basePeer, msg)
-
-       case *TransactionMessage:
-               sm.handleTransactionMsg(peer, msg)
-
-       case *TransactionsMessage:
-               sm.handleTransactionsMsg(peer, msg)
-
-       case *MineBlockMessage:
-               sm.handleMineBlockMsg(peer, msg)
-
-       case *GetHeadersMessage:
-               sm.handleGetHeadersMsg(peer, msg)
-
-       case *HeadersMessage:
-               sm.handleHeadersMsg(peer, msg)
-
-       case *GetBlocksMessage:
-               sm.handleGetBlocksMsg(peer, msg)
-
-       case *BlocksMessage:
-               sm.handleBlocksMsg(peer, msg)
-
-       case *FilterLoadMessage:
-               sm.handleFilterLoadMsg(peer, msg)
-
-       case *FilterAddMessage:
-               sm.handleFilterAddMsg(peer, msg)
-
-       case *FilterClearMessage:
-               sm.handleFilterClearMsg(peer)
-
-       case *GetMerkleBlockMessage:
-               sm.handleGetMerkleBlockMsg(peer, msg)
-
-       default:
-               log.WithFields(log.Fields{
-                       "module":       logModule,
-                       "peer":         basePeer.Addr(),
-                       "message_type": reflect.TypeOf(msg),
-               }).Error("unhandled message type")
-       }
-}
-
-func (sm *SyncManager) SendStatus(peer BasePeer) error {
-       p := sm.peers.getPeer(peer.ID())
-       if p == nil {
-               return errors.New("invalid peer")
-       }
-
-       if err := p.sendStatus(sm.chain.BestBlockHeader()); err != nil {
-               sm.peers.removePeer(p.ID())
-               return err
-       }
-       return nil
-}
-
-func (sm *SyncManager) Start() error {
-       var err error
-       if _, err = sm.sw.Start(); err != nil {
-               log.Error("switch start err")
-               return err
-       }
-
-       sm.minedBlockSub, err = sm.eventDispatcher.Subscribe(event.NewMinedBlockEvent{})
-       if err != nil {
-               return err
-       }
-
-       sm.txMsgSub, err = sm.eventDispatcher.Subscribe(core.TxMsgEvent{})
-       if err != nil {
-               return err
-       }
-
-       // broadcast transactions
-       go sm.txBroadcastLoop()
-       go sm.minedBroadcastLoop()
-       go sm.txSyncLoop()
-
-       return nil
-}
-
-//Stop stop sync manager
-func (sm *SyncManager) Stop() {
-       close(sm.quitSync)
-       sm.minedBlockSub.Unsubscribe()
-       if !sm.config.VaultMode {
-               sm.sw.Stop()
-       }
-}
-
-func (sm *SyncManager) minedBroadcastLoop() {
-       for {
-               select {
-               case obj, ok := <-sm.minedBlockSub.Chan():
-                       if !ok {
-                               log.WithFields(log.Fields{"module": logModule}).Warning("mined block subscription channel closed")
-                               return
-                       }
-
-                       ev, ok := obj.Data.(event.NewMinedBlockEvent)
-                       if !ok {
-                               log.WithFields(log.Fields{"module": logModule}).Error("event type error")
-                               continue
-                       }
-
-                       if err := sm.peers.broadcastMinedBlock(&ev.Block); err != nil {
-                               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on broadcast mine block")
-                               continue
-                       }
-
-               case <-sm.quitSync:
-                       return
-               }
-       }
-}
similarity index 94%
rename from netsync/message.go
rename to netsync/messages/chain_msg.go
index b4f702a..84bc031 100644 (file)
@@ -1,10 +1,8 @@
-package netsync
+package messages
 
 import (
-       "bytes"
        "encoding/hex"
        "encoding/json"
-       "errors"
        "fmt"
 
        "github.com/tendermint/go-wire"
@@ -33,8 +31,8 @@ const (
        MerkleRequestByte   = byte(0x60)
        MerkleResponseByte  = byte(0x61)
 
-       maxBlockchainResponseSize = 22020096 + 2
-       txsMsgMaxTxNum            = 1024
+       MaxBlockchainResponseSize = 22020096 + 2
+       TxsMsgMaxTxNum            = 1024
 )
 
 //BlockchainMessage is a generic message for this reactor.
@@ -61,18 +59,6 @@ var _ = wire.RegisterInterface(
        wire.ConcreteType{&MerkleBlockMessage{}, MerkleResponseByte},
 )
 
-//DecodeMessage decode msg
-func DecodeMessage(bz []byte) (msgType byte, msg BlockchainMessage, err error) {
-       msgType = bz[0]
-       n := int(0)
-       r := bytes.NewReader(bz)
-       msg = wire.ReadBinary(struct{ BlockchainMessage }{}, r, maxBlockchainResponseSize, &n, &err).(struct{ BlockchainMessage }).BlockchainMessage
-       if err != nil && n != len(bz) {
-               err = errors.New("DecodeMessage() had bytes left over")
-       }
-       return
-}
-
 //GetBlockMessage request blocks from remote peers by height/hash
 type GetBlockMessage struct {
        Height  uint64
@@ -453,7 +439,7 @@ type MerkleBlockMessage struct {
        Flags          []byte
 }
 
-func (m *MerkleBlockMessage) setRawBlockHeader(bh types.BlockHeader) error {
+func (m *MerkleBlockMessage) SetRawBlockHeader(bh types.BlockHeader) error {
        rawHeader, err := bh.MarshalText()
        if err != nil {
                return err
@@ -463,7 +449,7 @@ func (m *MerkleBlockMessage) setRawBlockHeader(bh types.BlockHeader) error {
        return nil
 }
 
-func (m *MerkleBlockMessage) setTxInfo(txHashes []*bc.Hash, txFlags []uint8, relatedTxs []*types.Tx) error {
+func (m *MerkleBlockMessage) SetTxInfo(txHashes []*bc.Hash, txFlags []uint8, relatedTxs []*types.Tx) error {
        for _, txHash := range txHashes {
                m.TxHashes = append(m.TxHashes, txHash.Byte32())
        }
@@ -479,7 +465,7 @@ func (m *MerkleBlockMessage) setTxInfo(txHashes []*bc.Hash, txFlags []uint8, rel
        return nil
 }
 
-func (m *MerkleBlockMessage) setStatusInfo(statusHashes []*bc.Hash, relatedStatuses []*bc.TxVerifyResult) error {
+func (m *MerkleBlockMessage) SetStatusInfo(statusHashes []*bc.Hash, relatedStatuses []*bc.TxVerifyResult) error {
        for _, statusHash := range statusHashes {
                m.StatusHashes = append(m.StatusHashes, statusHash.Byte32())
        }
similarity index 99%
rename from netsync/message_test.go
rename to netsync/messages/chain_msg_test.go
index 14bb71f..c87d6bc 100644 (file)
@@ -1,4 +1,4 @@
-package netsync
+package messages
 
 import (
        "reflect"
similarity index 51%
rename from netsync/peer.go
rename to netsync/peers/peer.go
index 794501b..ee5068a 100644 (file)
@@ -1,4 +1,4 @@
-package netsync
+package peers
 
 import (
        "encoding/hex"
@@ -12,18 +12,27 @@ import (
 
        "github.com/vapor/consensus"
        "github.com/vapor/errors"
+       msgs "github.com/vapor/netsync/messages"
        "github.com/vapor/p2p/trust"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
 )
 
 const (
-       maxKnownTxs         = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS)
-       maxKnownBlocks      = 1024  // Maximum block hashes to keep in the known list (prevent DOS)
-       defaultBanThreshold = uint32(100)
+       maxKnownTxs           = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS)
+       maxKnownSignatures    = 1024  // Maximum block signatures to keep in the known list (prevent DOS)
+       maxKnownBlocks        = 1024  // Maximum block hashes to keep in the known list (prevent DOS)
+       defaultBanThreshold   = uint32(100)
+       maxFilterAddressSize  = 50
+       maxFilterAddressCount = 1000
+
+       logModule = "peers"
 )
 
-var errSendStatusMsg = errors.New("send status msg fail")
+var (
+       errSendStatusMsg = errors.New("send status msg fail")
+       ErrPeerMisbehave = errors.New("peer is misbehave")
+)
 
 //BasePeer is the interface for connection level peer
 type BasePeer interface {
@@ -41,6 +50,14 @@ type BasePeerSet interface {
        StopPeerGracefully(string)
 }
 
+type BroadcastMsg interface {
+       FilterTargetPeers(ps *PeerSet) []string
+       MarkSendRecord(ps *PeerSet, peers []string)
+       GetChan() byte
+       GetMsg() interface{}
+       MsgString() string
+}
+
 // PeerInfo indicate peer status snap
 type PeerInfo struct {
        ID                  string `json:"peer_id"`
@@ -56,36 +73,38 @@ type PeerInfo struct {
        CurrentReceivedRate int64  `json:"current_received_rate"`
 }
 
-type peer struct {
+type Peer struct {
        BasePeer
-       mtx         sync.RWMutex
-       services    consensus.ServiceFlag
-       height      uint64
-       hash        *bc.Hash
-       banScore    trust.DynamicBanScore
-       knownTxs    *set.Set // Set of transaction hashes known to be known by this peer
-       knownBlocks *set.Set // Set of block hashes known to be known by this peer
-       knownStatus uint64   // Set of chain status known to be known by this peer
-       filterAdds  *set.Set // Set of addresses that the spv node cares about.
-}
-
-func newPeer(basePeer BasePeer) *peer {
-       return &peer{
-               BasePeer:    basePeer,
-               services:    basePeer.ServiceFlag(),
-               knownTxs:    set.New(),
-               knownBlocks: set.New(),
-               filterAdds:  set.New(),
-       }
-}
-
-func (p *peer) Height() uint64 {
+       mtx             sync.RWMutex
+       services        consensus.ServiceFlag
+       height          uint64
+       hash            *bc.Hash
+       banScore        trust.DynamicBanScore
+       knownTxs        *set.Set // Set of transaction hashes known to be known by this peer
+       knownBlocks     *set.Set // Set of block hashes known to be known by this peer
+       knownSignatures *set.Set // Set of block signatures known to be known by this peer
+       knownStatus     uint64   // Set of chain status known to be known by this peer
+       filterAdds      *set.Set // Set of addresses that the spv node cares about.
+}
+
+func newPeer(basePeer BasePeer) *Peer {
+       return &Peer{
+               BasePeer:        basePeer,
+               services:        basePeer.ServiceFlag(),
+               knownTxs:        set.New(),
+               knownBlocks:     set.New(),
+               knownSignatures: set.New(),
+               filterAdds:      set.New(),
+       }
+}
+
+func (p *Peer) Height() uint64 {
        p.mtx.RLock()
        defer p.mtx.RUnlock()
        return p.height
 }
 
-func (p *peer) addBanScore(persistent, transient uint32, reason string) bool {
+func (p *Peer) addBanScore(persistent, transient uint32, reason string) bool {
        score := p.banScore.Increase(persistent, transient)
        if score > defaultBanThreshold {
                log.WithFields(log.Fields{
@@ -109,7 +128,7 @@ func (p *peer) addBanScore(persistent, transient uint32, reason string) bool {
        return false
 }
 
-func (p *peer) addFilterAddress(address []byte) {
+func (p *Peer) AddFilterAddress(address []byte) {
        p.mtx.Lock()
        defer p.mtx.Unlock()
 
@@ -125,31 +144,35 @@ func (p *peer) addFilterAddress(address []byte) {
        p.filterAdds.Add(hex.EncodeToString(address))
 }
 
-func (p *peer) addFilterAddresses(addresses [][]byte) {
+func (p *Peer) AddFilterAddresses(addresses [][]byte) {
        if !p.filterAdds.IsEmpty() {
                p.filterAdds.Clear()
        }
        for _, address := range addresses {
-               p.addFilterAddress(address)
+               p.AddFilterAddress(address)
        }
 }
 
-func (p *peer) getBlockByHeight(height uint64) bool {
-       msg := struct{ BlockchainMessage }{&GetBlockMessage{Height: height}}
-       return p.TrySend(BlockchainChannel, msg)
+func (p *Peer) FilterClear() {
+       p.filterAdds.Clear()
 }
 
-func (p *peer) getBlocks(locator []*bc.Hash, stopHash *bc.Hash) bool {
-       msg := struct{ BlockchainMessage }{NewGetBlocksMessage(locator, stopHash)}
-       return p.TrySend(BlockchainChannel, msg)
+func (p *Peer) GetBlockByHeight(height uint64) bool {
+       msg := struct{ msgs.BlockchainMessage }{&msgs.GetBlockMessage{Height: height}}
+       return p.TrySend(msgs.BlockchainChannel, msg)
 }
 
-func (p *peer) getHeaders(locator []*bc.Hash, stopHash *bc.Hash) bool {
-       msg := struct{ BlockchainMessage }{NewGetHeadersMessage(locator, stopHash)}
-       return p.TrySend(BlockchainChannel, msg)
+func (p *Peer) GetBlocks(locator []*bc.Hash, stopHash *bc.Hash) bool {
+       msg := struct{ msgs.BlockchainMessage }{msgs.NewGetBlocksMessage(locator, stopHash)}
+       return p.TrySend(msgs.BlockchainChannel, msg)
 }
 
-func (p *peer) getPeerInfo() *PeerInfo {
+func (p *Peer) GetHeaders(locator []*bc.Hash, stopHash *bc.Hash) bool {
+       msg := struct{ msgs.BlockchainMessage }{msgs.NewGetHeadersMessage(locator, stopHash)}
+       return p.TrySend(msgs.BlockchainChannel, msg)
+}
+
+func (p *Peer) GetPeerInfo() *PeerInfo {
        p.mtx.RLock()
        defer p.mtx.RUnlock()
 
@@ -174,7 +197,7 @@ func (p *peer) getPeerInfo() *PeerInfo {
        }
 }
 
-func (p *peer) getRelatedTxAndStatus(txs []*types.Tx, txStatuses *bc.TransactionStatus) ([]*types.Tx, []*bc.TxVerifyResult) {
+func (p *Peer) getRelatedTxAndStatus(txs []*types.Tx, txStatuses *bc.TransactionStatus) ([]*types.Tx, []*bc.TxVerifyResult) {
        var relatedTxs []*types.Tx
        var relatedStatuses []*bc.TxVerifyResult
        for i, tx := range txs {
@@ -186,7 +209,7 @@ func (p *peer) getRelatedTxAndStatus(txs []*types.Tx, txStatuses *bc.Transaction
        return relatedTxs, relatedStatuses
 }
 
-func (p *peer) isRelatedTx(tx *types.Tx) bool {
+func (p *Peer) isRelatedTx(tx *types.Tx) bool {
        for _, input := range tx.Inputs {
                switch inp := input.TypedInput.(type) {
                case *types.SpendInput:
@@ -203,11 +226,11 @@ func (p *peer) isRelatedTx(tx *types.Tx) bool {
        return false
 }
 
-func (p *peer) isSPVNode() bool {
+func (p *Peer) isSPVNode() bool {
        return !p.services.IsEnable(consensus.SFFullNode)
 }
 
-func (p *peer) markBlock(hash *bc.Hash) {
+func (p *Peer) MarkBlock(hash *bc.Hash) {
        p.mtx.Lock()
        defer p.mtx.Unlock()
 
@@ -217,14 +240,24 @@ func (p *peer) markBlock(hash *bc.Hash) {
        p.knownBlocks.Add(hash.String())
 }
 
-func (p *peer) markNewStatus(height uint64) {
+func (p *Peer) markNewStatus(height uint64) {
        p.mtx.Lock()
        defer p.mtx.Unlock()
 
        p.knownStatus = height
 }
 
-func (p *peer) markTransaction(hash *bc.Hash) {
+func (p *Peer) markSign(signature []byte) {
+       p.mtx.Lock()
+       defer p.mtx.Unlock()
+
+       for p.knownSignatures.Size() >= maxKnownSignatures {
+               p.knownSignatures.Pop()
+       }
+       p.knownSignatures.Add(signature)
+}
+
+func (p *Peer) markTransaction(hash *bc.Hash) {
        p.mtx.Lock()
        defer p.mtx.Unlock()
 
@@ -234,13 +267,39 @@ func (p *peer) markTransaction(hash *bc.Hash) {
        p.knownTxs.Add(hash.String())
 }
 
-func (p *peer) sendBlock(block *types.Block) (bool, error) {
-       msg, err := NewBlockMessage(block)
+func (ps *PeerSet) PeersWithoutBlock(hash bc.Hash) []string {
+       ps.mtx.RLock()
+       defer ps.mtx.RUnlock()
+
+       var peers []string
+       for _, peer := range ps.peers {
+               if !peer.knownBlocks.Has(hash.String()) {
+                       peers = append(peers, peer.ID())
+               }
+       }
+       return peers
+}
+
+func (ps *PeerSet) PeersWithoutSign(signature []byte) []string {
+       ps.mtx.RLock()
+       defer ps.mtx.RUnlock()
+
+       var peers []string
+       for _, peer := range ps.peers {
+               if !peer.knownSignatures.Has(signature) {
+                       peers = append(peers, peer.ID())
+               }
+       }
+       return peers
+}
+
+func (p *Peer) SendBlock(block *types.Block) (bool, error) {
+       msg, err := msgs.NewBlockMessage(block)
        if err != nil {
                return false, errors.Wrap(err, "fail on NewBlockMessage")
        }
 
-       ok := p.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg})
+       ok := p.TrySend(msgs.BlockchainChannel, struct{ msgs.BlockchainMessage }{msg})
        if ok {
                blcokHash := block.Hash()
                p.knownBlocks.Add(blcokHash.String())
@@ -248,13 +307,13 @@ func (p *peer) sendBlock(block *types.Block) (bool, error) {
        return ok, nil
 }
 
-func (p *peer) sendBlocks(blocks []*types.Block) (bool, error) {
-       msg, err := NewBlocksMessage(blocks)
+func (p *Peer) SendBlocks(blocks []*types.Block) (bool, error) {
+       msg, err := msgs.NewBlocksMessage(blocks)
        if err != nil {
                return false, errors.Wrap(err, "fail on NewBlocksMessage")
        }
 
-       if ok := p.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg}); !ok {
+       if ok := p.TrySend(msgs.BlockchainChannel, struct{ msgs.BlockchainMessage }{msg}); !ok {
                return ok, nil
        }
 
@@ -265,39 +324,39 @@ func (p *peer) sendBlocks(blocks []*types.Block) (bool, error) {
        return true, nil
 }
 
-func (p *peer) sendHeaders(headers []*types.BlockHeader) (bool, error) {
-       msg, err := NewHeadersMessage(headers)
+func (p *Peer) SendHeaders(headers []*types.BlockHeader) (bool, error) {
+       msg, err := msgs.NewHeadersMessage(headers)
        if err != nil {
                return false, errors.New("fail on NewHeadersMessage")
        }
 
-       ok := p.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg})
+       ok := p.TrySend(msgs.BlockchainChannel, struct{ msgs.BlockchainMessage }{msg})
        return ok, nil
 }
 
-func (p *peer) sendMerkleBlock(block *types.Block, txStatuses *bc.TransactionStatus) (bool, error) {
-       msg := NewMerkleBlockMessage()
-       if err := msg.setRawBlockHeader(block.BlockHeader); err != nil {
+func (p *Peer) SendMerkleBlock(block *types.Block, txStatuses *bc.TransactionStatus) (bool, error) {
+       msg := msgs.NewMerkleBlockMessage()
+       if err := msg.SetRawBlockHeader(block.BlockHeader); err != nil {
                return false, err
        }
 
        relatedTxs, relatedStatuses := p.getRelatedTxAndStatus(block.Transactions, txStatuses)
 
        txHashes, txFlags := types.GetTxMerkleTreeProof(block.Transactions, relatedTxs)
-       if err := msg.setTxInfo(txHashes, txFlags, relatedTxs); err != nil {
+       if err := msg.SetTxInfo(txHashes, txFlags, relatedTxs); err != nil {
                return false, nil
        }
 
        statusHashes := types.GetStatusMerkleTreeProof(txStatuses.VerifyStatus, txFlags)
-       if err := msg.setStatusInfo(statusHashes, relatedStatuses); err != nil {
+       if err := msg.SetStatusInfo(statusHashes, relatedStatuses); err != nil {
                return false, nil
        }
 
-       ok := p.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg})
+       ok := p.TrySend(msgs.BlockchainChannel, struct{ msgs.BlockchainMessage }{msg})
        return ok, nil
 }
 
-func (p *peer) sendTransactions(txs []*types.Tx) error {
+func (p *Peer) SendTransactions(txs []*types.Tx) error {
        validTxs := make([]*types.Tx, 0, len(txs))
        for i, tx := range txs {
                if p.isSPVNode() && !p.isRelatedTx(tx) || p.knownTxs.Has(tx.ID.String()) {
@@ -305,16 +364,16 @@ func (p *peer) sendTransactions(txs []*types.Tx) error {
                }
 
                validTxs = append(validTxs, tx)
-               if len(validTxs) != txsMsgMaxTxNum && i != len(txs)-1 {
+               if len(validTxs) != msgs.TxsMsgMaxTxNum && i != len(txs)-1 {
                        continue
                }
 
-               msg, err := NewTransactionsMessage(validTxs)
+               msg, err := msgs.NewTransactionsMessage(validTxs)
                if err != nil {
                        return err
                }
 
-               if ok := p.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg}); !ok {
+               if ok := p.TrySend(msgs.BlockchainChannel, struct{ msgs.BlockchainMessage }{msg}); !ok {
                        return errors.New("failed to send txs msg")
                }
 
@@ -328,37 +387,37 @@ func (p *peer) sendTransactions(txs []*types.Tx) error {
        return nil
 }
 
-func (p *peer) sendStatus(header *types.BlockHeader) error {
-       msg := NewStatusMessage(header)
-       if ok := p.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg}); !ok {
+func (p *Peer) SendStatus(header *types.BlockHeader) error {
+       msg := msgs.NewStatusMessage(header)
+       if ok := p.TrySend(msgs.BlockchainChannel, struct{ msgs.BlockchainMessage }{msg}); !ok {
                return errSendStatusMsg
        }
        p.markNewStatus(header.Height)
        return nil
 }
 
-func (p *peer) setStatus(height uint64, hash *bc.Hash) {
+func (p *Peer) SetStatus(height uint64, hash *bc.Hash) {
        p.mtx.Lock()
        defer p.mtx.Unlock()
        p.height = height
        p.hash = hash
 }
 
-type peerSet struct {
+type PeerSet struct {
        BasePeerSet
        mtx   sync.RWMutex
-       peers map[string]*peer
+       peers map[string]*Peer
 }
 
 // newPeerSet creates a new peer set to track the active participants.
-func newPeerSet(basePeerSet BasePeerSet) *peerSet {
-       return &peerSet{
+func NewPeerSet(basePeerSet BasePeerSet) *PeerSet {
+       return &PeerSet{
                BasePeerSet: basePeerSet,
-               peers:       make(map[string]*peer),
+               peers:       make(map[string]*Peer),
        }
 }
 
-func (ps *peerSet) addBanScore(peerID string, persistent, transient uint32, reason string) {
+func (ps *PeerSet) AddBanScore(peerID string, persistent, transient uint32, reason string) {
        ps.mtx.Lock()
        peer := ps.peers[peerID]
        ps.mtx.Unlock()
@@ -372,10 +431,10 @@ func (ps *peerSet) addBanScore(peerID string, persistent, transient uint32, reas
        if err := ps.AddBannedPeer(peer.Addr().String()); err != nil {
                log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on add ban peer")
        }
-       ps.removePeer(peerID)
+       ps.RemovePeer(peerID)
 }
 
-func (ps *peerSet) addPeer(peer BasePeer) {
+func (ps *PeerSet) AddPeer(peer BasePeer) {
        ps.mtx.Lock()
        defer ps.mtx.Unlock()
 
@@ -386,11 +445,11 @@ func (ps *peerSet) addPeer(peer BasePeer) {
        log.WithField("module", logModule).Warning("add existing peer to blockKeeper")
 }
 
-func (ps *peerSet) bestPeer(flag consensus.ServiceFlag) *peer {
+func (ps *PeerSet) BestPeer(flag consensus.ServiceFlag) *Peer {
        ps.mtx.RLock()
        defer ps.mtx.RUnlock()
 
-       var bestPeer *peer
+       var bestPeer *Peer
        for _, p := range ps.peers {
                if !p.services.IsEnable(flag) {
                        continue
@@ -402,35 +461,47 @@ func (ps *peerSet) bestPeer(flag consensus.ServiceFlag) *peer {
        return bestPeer
 }
 
-func (ps *peerSet) broadcastMinedBlock(block *types.Block) error {
-       msg, err := NewMinedBlockMessage(block)
-       if err != nil {
-               return errors.Wrap(err, "fail on broadcast mined block")
+//SendMsg send message to the target peer.
+func (ps *PeerSet) SendMsg(peerID string, msgChannel byte, msg interface{}) bool {
+       peer := ps.GetPeer(peerID)
+       if peer == nil {
+               return false
        }
 
-       hash := block.Hash()
-       peers := ps.peersWithoutBlock(&hash)
+       ok := peer.TrySend(msgChannel, msg)
+       if !ok {
+               ps.RemovePeer(peerID)
+       }
+       return ok
+}
+
+//BroadcastMsg Broadcast message to the target peers
+// and mark the message send record
+func (ps *PeerSet) BroadcastMsg(bm BroadcastMsg) error {
+       //filter target peers
+       peers := bm.FilterTargetPeers(ps)
+
+       //broadcast to target peers
+       peersSuccess := make([]string, 0)
        for _, peer := range peers {
-               if peer.isSPVNode() {
+               if ok := ps.SendMsg(peer, bm.GetChan(), bm.GetMsg()); !ok {
+                       log.WithFields(log.Fields{"module": logModule, "peer": peer, "type": reflect.TypeOf(bm.GetMsg()), "message": bm.MsgString()}).Warning("send message to peer error")
                        continue
                }
-               if ok := peer.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg}); !ok {
-                       log.WithFields(log.Fields{"module": logModule, "peer": peer.Addr(), "type": reflect.TypeOf(msg), "message": msg.String()}).Warning("send message to peer error")
-                       ps.removePeer(peer.ID())
-                       continue
-               }
-               peer.markBlock(&hash)
-               peer.markNewStatus(block.Height)
+               peersSuccess = append(peersSuccess, peer)
        }
+
+       //mark the message send record
+       bm.MarkSendRecord(ps, peersSuccess)
        return nil
 }
 
-func (ps *peerSet) broadcastNewStatus(bestBlock *types.Block) error {
-       msg := NewStatusMessage(&bestBlock.BlockHeader)
+func (ps *PeerSet) BroadcastNewStatus(bestBlock *types.Block) error {
+       msg := msgs.NewStatusMessage(&bestBlock.BlockHeader)
        peers := ps.peersWithoutNewStatus(bestBlock.Height)
        for _, peer := range peers {
-               if ok := peer.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg}); !ok {
-                       ps.removePeer(peer.ID())
+               if ok := peer.TrySend(msgs.BlockchainChannel, struct{ msgs.BlockchainMessage }{msg}); !ok {
+                       ps.RemovePeer(peer.ID())
                        continue
                }
 
@@ -439,8 +510,8 @@ func (ps *peerSet) broadcastNewStatus(bestBlock *types.Block) error {
        return nil
 }
 
-func (ps *peerSet) broadcastTx(tx *types.Tx) error {
-       msg, err := NewTransactionMessage(tx)
+func (ps *PeerSet) BroadcastTx(tx *types.Tx) error {
+       msg, err := msgs.NewTransactionMessage(tx)
        if err != nil {
                return errors.Wrap(err, "fail on broadcast tx")
        }
@@ -450,14 +521,14 @@ func (ps *peerSet) broadcastTx(tx *types.Tx) error {
                if peer.isSPVNode() && !peer.isRelatedTx(tx) {
                        continue
                }
-               if ok := peer.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg}); !ok {
+               if ok := peer.TrySend(msgs.BlockchainChannel, struct{ msgs.BlockchainMessage }{msg}); !ok {
                        log.WithFields(log.Fields{
                                "module":  logModule,
                                "peer":    peer.Addr(),
                                "type":    reflect.TypeOf(msg),
                                "message": msg.String(),
                        }).Warning("send message to peer error")
-                       ps.removePeer(peer.ID())
+                       ps.RemovePeer(peer.ID())
                        continue
                }
                peer.markTransaction(&tx.ID)
@@ -465,33 +536,57 @@ func (ps *peerSet) broadcastTx(tx *types.Tx) error {
        return nil
 }
 
-func (ps *peerSet) errorHandler(peerID string, err error) {
-       if errors.Root(err) == errPeerMisbehave {
-               ps.addBanScore(peerID, 20, 0, err.Error())
+func (ps *PeerSet) ErrorHandler(peerID string, err error) {
+       if errors.Root(err) == ErrPeerMisbehave {
+               ps.AddBanScore(peerID, 20, 0, err.Error())
        } else {
-               ps.removePeer(peerID)
+               ps.RemovePeer(peerID)
        }
 }
 
 // Peer retrieves the registered peer with the given id.
-func (ps *peerSet) getPeer(id string) *peer {
+func (ps *PeerSet) GetPeer(id string) *Peer {
        ps.mtx.RLock()
        defer ps.mtx.RUnlock()
        return ps.peers[id]
 }
 
-func (ps *peerSet) getPeerInfos() []*PeerInfo {
+func (ps *PeerSet) GetPeerInfos() []*PeerInfo {
        ps.mtx.RLock()
        defer ps.mtx.RUnlock()
 
        result := []*PeerInfo{}
        for _, peer := range ps.peers {
-               result = append(result, peer.getPeerInfo())
+               result = append(result, peer.GetPeerInfo())
        }
        return result
 }
 
-func (ps *peerSet) markTx(peerID string, txHash bc.Hash) {
+func (ps *PeerSet) MarkBlock(peerID string, hash *bc.Hash) {
+       peer := ps.GetPeer(peerID)
+       if peer == nil {
+               return
+       }
+       peer.MarkBlock(hash)
+}
+
+func (ps *PeerSet) MarkBlockSignature(peerID string, signature []byte) {
+       peer := ps.GetPeer(peerID)
+       if peer == nil {
+               return
+       }
+       peer.markSign(signature)
+}
+
+func (ps *PeerSet) MarkStatus(peerID string, height uint64) {
+       peer := ps.GetPeer(peerID)
+       if peer == nil {
+               return
+       }
+       peer.markNewStatus(height)
+}
+
+func (ps *PeerSet) MarkTx(peerID string, txHash bc.Hash) {
        ps.mtx.Lock()
        peer := ps.peers[peerID]
        ps.mtx.Unlock()
@@ -502,11 +597,11 @@ func (ps *peerSet) markTx(peerID string, txHash bc.Hash) {
        peer.markTransaction(&txHash)
 }
 
-func (ps *peerSet) peersWithoutBlock(hash *bc.Hash) []*peer {
+func (ps *PeerSet) peersWithoutBlock(hash *bc.Hash) []*Peer {
        ps.mtx.RLock()
        defer ps.mtx.RUnlock()
 
-       peers := []*peer{}
+       peers := []*Peer{}
        for _, peer := range ps.peers {
                if !peer.knownBlocks.Has(hash.String()) {
                        peers = append(peers, peer)
@@ -515,11 +610,11 @@ func (ps *peerSet) peersWithoutBlock(hash *bc.Hash) []*peer {
        return peers
 }
 
-func (ps *peerSet) peersWithoutNewStatus(height uint64) []*peer {
+func (ps *PeerSet) peersWithoutNewStatus(height uint64) []*Peer {
        ps.mtx.RLock()
        defer ps.mtx.RUnlock()
 
-       var peers []*peer
+       var peers []*Peer
        for _, peer := range ps.peers {
                if peer.knownStatus < height {
                        peers = append(peers, peer)
@@ -528,11 +623,11 @@ func (ps *peerSet) peersWithoutNewStatus(height uint64) []*peer {
        return peers
 }
 
-func (ps *peerSet) peersWithoutTx(hash *bc.Hash) []*peer {
+func (ps *PeerSet) peersWithoutTx(hash *bc.Hash) []*Peer {
        ps.mtx.RLock()
        defer ps.mtx.RUnlock()
 
-       peers := []*peer{}
+       peers := []*Peer{}
        for _, peer := range ps.peers {
                if !peer.knownTxs.Has(hash.String()) {
                        peers = append(peers, peer)
@@ -541,9 +636,18 @@ func (ps *peerSet) peersWithoutTx(hash *bc.Hash) []*peer {
        return peers
 }
 
-func (ps *peerSet) removePeer(peerID string) {
+func (ps *PeerSet) RemovePeer(peerID string) {
        ps.mtx.Lock()
        delete(ps.peers, peerID)
        ps.mtx.Unlock()
        ps.StopPeerGracefully(peerID)
 }
+
+func (ps *PeerSet) SetStatus(peerID string, height uint64, hash *bc.Hash) {
+       peer := ps.GetPeer(peerID)
+       if peer == nil {
+               return
+       }
+
+       peer.SetStatus(height, hash)
+}
diff --git a/netsync/sync_manager.go b/netsync/sync_manager.go
new file mode 100644 (file)
index 0000000..18d5291
--- /dev/null
@@ -0,0 +1,137 @@
+package netsync
+
+import (
+       "errors"
+
+       log "github.com/sirupsen/logrus"
+
+       cfg "github.com/vapor/config"
+       "github.com/vapor/consensus"
+       "github.com/vapor/event"
+       "github.com/vapor/netsync/chainmgr"
+       "github.com/vapor/netsync/peers"
+       "github.com/vapor/p2p"
+       core "github.com/vapor/protocol"
+)
+
+const (
+       logModule = "netsync"
+)
+
+var (
+       errVaultModeDialPeer = errors.New("can't dial peer in vault mode")
+)
+
+type ChainMgr interface {
+       Start() error
+       IsCaughtUp() bool
+       Stop()
+}
+
+type Switch interface {
+       Start() (bool, error)
+       Stop() bool
+       IsListening() bool
+       DialPeerWithAddress(addr *p2p.NetAddress) error
+       Peers() *p2p.PeerSet
+}
+
+//SyncManager Sync Manager is responsible for the business layer information synchronization
+type SyncManager struct {
+       config   *cfg.Config
+       sw       Switch
+       chainMgr ChainMgr
+       peers    *peers.PeerSet
+}
+
+// NewSyncManager create sync manager and set switch.
+func NewSyncManager(config *cfg.Config, chain *core.Chain, txPool *core.TxPool, dispatcher *event.Dispatcher) (*SyncManager, error) {
+       sw, err := p2p.NewSwitch(config)
+       if err != nil {
+               return nil, err
+       }
+       peers := peers.NewPeerSet(sw)
+
+       chainManger, err := chainmgr.NewChainManager(config, sw, chain, txPool, dispatcher, peers)
+       if err != nil {
+               return nil, err
+       }
+
+       return &SyncManager{
+               config:   config,
+               sw:       sw,
+               chainMgr: chainManger,
+               peers:    peers,
+       }, nil
+}
+
+func (sm *SyncManager) Start() error {
+       if _, err := sm.sw.Start(); err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("failed start switch")
+               return err
+       }
+
+       return sm.chainMgr.Start()
+}
+
+func (sm *SyncManager) Stop() {
+       sm.chainMgr.Stop()
+       if !sm.config.VaultMode {
+               sm.sw.Stop()
+       }
+
+}
+
+func (sm *SyncManager) IsListening() bool {
+       if sm.config.VaultMode {
+               return false
+       }
+       return sm.sw.IsListening()
+
+}
+
+//IsCaughtUp check wheather the peer finish the sync
+func (sm *SyncManager) IsCaughtUp() bool {
+       return sm.chainMgr.IsCaughtUp()
+}
+
+func (sm *SyncManager) PeerCount() int {
+       if sm.config.VaultMode {
+               return 0
+       }
+       return len(sm.sw.Peers().List())
+}
+
+func (sm *SyncManager) GetNetwork() string {
+       return sm.config.ChainID
+}
+
+func (sm *SyncManager) BestPeer() *peers.PeerInfo {
+       bestPeer := sm.peers.BestPeer(consensus.SFFullNode)
+       if bestPeer != nil {
+               return bestPeer.GetPeerInfo()
+       }
+       return nil
+}
+
+func (sm *SyncManager) DialPeerWithAddress(addr *p2p.NetAddress) error {
+       if sm.config.VaultMode {
+               return errVaultModeDialPeer
+       }
+
+       return sm.sw.DialPeerWithAddress(addr)
+}
+
+//GetPeerInfos return peer info of all peers
+func (sm *SyncManager) GetPeerInfos() []*peers.PeerInfo {
+       return sm.peers.GetPeerInfos()
+}
+
+//StopPeer try to stop peer by given ID
+func (sm *SyncManager) StopPeer(peerID string) error {
+       if peer := sm.peers.GetPeer(peerID); peer == nil {
+               return errors.New("peerId not exist")
+       }
+       sm.peers.RemovePeer(peerID)
+       return nil
+}
index d70fa75..57b16cf 100644 (file)
@@ -293,6 +293,10 @@ func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
        return nil
 }
 
+func (sw *Switch) ID() [32]byte {
+       return sw.nodeInfo.PubKey
+}
+
 //IsDialing prevent duplicate dialing
 func (sw *Switch) IsDialing(addr *NetAddress) bool {
        return sw.dialing.Has(addr.IP.String())