From 93dd7ceb9c426fa99362a81bf7e137d410d4ef92 Mon Sep 17 00:00:00 2001 From: yahtoo Date: Mon, 3 Jun 2019 10:45:54 +0800 Subject: [PATCH] Add mempool sync test (#114) --- netsync/chainmgr/block_keeper_test.go | 24 ++++---- netsync/chainmgr/handle.go | 22 +++++--- netsync/chainmgr/protocol_reactor.go | 2 +- netsync/chainmgr/tool_test.go | 6 +- netsync/chainmgr/tx_keeper.go | 15 ++--- netsync/chainmgr/tx_keeper_test.go | 101 ++++++++++++++++++++++++++++++++++ test/mock/chain.go | 14 +++-- test/mock/mempool.go | 24 ++++++++ 8 files changed, 173 insertions(+), 35 deletions(-) create mode 100644 netsync/chainmgr/tx_keeper_test.go create mode 100644 test/mock/mempool.go diff --git a/netsync/chainmgr/block_keeper_test.go b/netsync/chainmgr/block_keeper_test.go index 2b1cc699..43e6ec7a 100644 --- a/netsync/chainmgr/block_keeper_test.go +++ b/netsync/chainmgr/block_keeper_test.go @@ -109,7 +109,7 @@ func TestBlockLocator(t *testing.T) { } for i, c := range cases { - mockChain := mock.NewChain() + mockChain := mock.NewChain(nil) bk := &blockKeeper{chain: mockChain} mockChain.SetBestBlockHeader(&blocks[c.bestHeight].BlockHeader) for i := uint64(0); i <= c.bestHeight; i++ { @@ -178,8 +178,8 @@ func TestFastBlockSync(t *testing.T) { for i, c := range cases { syncTimeout = c.syncTimeout - a := mockSync(c.aBlocks) - b := mockSync(c.bBlocks) + a := mockSync(c.aBlocks, nil) + b := mockSync(c.bBlocks, nil) netWork := NewNetWork() netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode) netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode) @@ -225,7 +225,7 @@ func TestLocateBlocks(t *testing.T) { }, } - mockChain := mock.NewChain() + mockChain := mock.NewChain(nil) bk := &blockKeeper{chain: mockChain} for _, block := range blocks { mockChain.SetBlockByHeight(block.Height, block) @@ -305,7 +305,7 @@ func TestLocateHeaders(t *testing.T) { } for i, c := range cases { - mockChain := mock.NewChain() + mockChain := mock.NewChain(nil) bk := &blockKeeper{chain: mockChain} for i := uint64(0); i <= c.chainHeight; i++ { mockChain.SetBlockByHeight(i, blocks[i]) @@ -379,7 +379,7 @@ func TestNextCheckpoint(t *testing.T) { }, } - mockChain := mock.NewChain() + mockChain := mock.NewChain(nil) for i, c := range cases { consensus.ActiveNetParams.Checkpoints = c.checkPoints mockChain.SetBestBlockHeader(&types.BlockHeader{Height: c.bestHeight}) @@ -439,8 +439,8 @@ func TestRegularBlockSync(t *testing.T) { for i, c := range cases { syncTimeout = c.syncTimeout - a := mockSync(c.aBlocks) - b := mockSync(c.bBlocks) + a := mockSync(c.aBlocks, nil) + b := mockSync(c.bBlocks, nil) netWork := NewNetWork() netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode) netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode) @@ -473,8 +473,8 @@ func TestRegularBlockSync(t *testing.T) { func TestRequireBlock(t *testing.T) { blocks := mockBlocks(nil, 5) - a := mockSync(blocks[:1]) - b := mockSync(blocks[:5]) + a := mockSync(blocks[:1], nil) + b := mockSync(blocks[:5], nil) netWork := NewNetWork() netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode) netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode) @@ -560,7 +560,7 @@ func TestSendMerkleBlock(t *testing.T) { t.Fatal(err) } - spvNode := mockSync(blocks) + spvNode := mockSync(blocks, nil) blockHash := targetBlock.Hash() var statusResult *bc.TransactionStatus if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil { @@ -571,7 +571,7 @@ func TestSendMerkleBlock(t *testing.T) { t.Fatal(err) } - fullNode := mockSync(blocks) + fullNode := mockSync(blocks, nil) netWork := NewNetWork() netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync) netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices) diff --git a/netsync/chainmgr/handle.go b/netsync/chainmgr/handle.go index 9dea086f..51dee4c0 100644 --- a/netsync/chainmgr/handle.go +++ b/netsync/chainmgr/handle.go @@ -45,16 +45,21 @@ type Switch interface { Peers() *p2p.PeerSet } +// Mempool is the interface for Bytom mempool +type Mempool interface { + GetTransactions() []*core.TxDesc +} + //Manager is responsible for the business layer information synchronization type Manager struct { sw Switch chain Chain - txPool *core.TxPool + mempool Mempool blockKeeper *blockKeeper peers *peers.PeerSet txSyncCh chan *txSyncMsg - quitSync chan struct{} + quit chan struct{} config *cfg.Config eventDispatcher *event.Dispatcher @@ -62,15 +67,15 @@ type Manager struct { } //NewChainManager create a chain sync manager. -func NewManager(config *cfg.Config, sw Switch, chain Chain, txPool *core.TxPool, dispatcher *event.Dispatcher, peers *peers.PeerSet) (*Manager, error) { +func NewManager(config *cfg.Config, sw Switch, chain Chain, mempool Mempool, dispatcher *event.Dispatcher, peers *peers.PeerSet) (*Manager, error) { manager := &Manager{ sw: sw, - txPool: txPool, + mempool: mempool, chain: chain, blockKeeper: newBlockKeeper(chain, peers), peers: peers, txSyncCh: make(chan *txSyncMsg), - quitSync: make(chan struct{}), + quit: make(chan struct{}), config: config, eventDispatcher: dispatcher, } @@ -359,14 +364,13 @@ func (m *Manager) Start() error { return err } - // broadcast transactions - go m.txBroadcastLoop() - go m.txSyncLoop() + go m.broadcastTxsLoop() + go m.syncMempoolLoop() return nil } //Stop stop sync manager func (m *Manager) Stop() { - close(m.quitSync) + close(m.quit) } diff --git a/netsync/chainmgr/protocol_reactor.go b/netsync/chainmgr/protocol_reactor.go index 85a5c259..7cf0909a 100644 --- a/netsync/chainmgr/protocol_reactor.go +++ b/netsync/chainmgr/protocol_reactor.go @@ -56,7 +56,7 @@ func (pr *ProtocolReactor) AddPeer(peer *p2p.Peer) error { if err := pr.manager.SendStatus(peer); err != nil { return err } - pr.manager.syncTransactions(peer.Key) + pr.manager.syncMempool(peer.Key) return nil } diff --git a/netsync/chainmgr/tool_test.go b/netsync/chainmgr/tool_test.go index dba3c89c..e3549841 100644 --- a/netsync/chainmgr/tool_test.go +++ b/netsync/chainmgr/tool_test.go @@ -150,8 +150,8 @@ func mockBlocks(startBlock *types.Block, height uint64) []*types.Block { return blocks } -func mockSync(blocks []*types.Block) *Manager { - chain := mock.NewChain() +func mockSync(blocks []*types.Block, mempool *mock.Mempool) *Manager { + chain := mock.NewChain(mempool) peers := peers.NewPeerSet(NewPeerSet()) chain.SetBestBlockHeader(&blocks[len(blocks)-1].BlockHeader) for _, block := range blocks { @@ -162,6 +162,8 @@ func mockSync(blocks []*types.Block) *Manager { chain: chain, blockKeeper: newBlockKeeper(chain, peers), peers: peers, + mempool: mempool, + txSyncCh: make(chan *txSyncMsg), } } diff --git a/netsync/chainmgr/tx_keeper.go b/netsync/chainmgr/tx_keeper.go index 50714031..6c6f5f9e 100644 --- a/netsync/chainmgr/tx_keeper.go +++ b/netsync/chainmgr/tx_keeper.go @@ -21,8 +21,8 @@ type txSyncMsg struct { txs []*types.Tx } -func (m *Manager) syncTransactions(peerID string) { - pending := m.txPool.GetTransactions() +func (m *Manager) syncMempool(peerID string) { + pending := m.mempool.GetTransactions() if len(pending) == 0 { return } @@ -34,7 +34,7 @@ func (m *Manager) syncTransactions(peerID string) { m.txSyncCh <- &txSyncMsg{peerID, txs} } -func (m *Manager) txBroadcastLoop() { +func (m *Manager) broadcastTxsLoop() { for { select { case obj, ok := <-m.txMsgSub.Chan(): @@ -55,17 +55,17 @@ func (m *Manager) txBroadcastLoop() { continue } } - case <-m.quitSync: + case <-m.quit: return } } } -// txSyncLoop takes care of the initial transaction sync for each new +// syncMempoolLoop takes care of the initial transaction sync for each new // connection. When a new peer appears, we relay all currently pending // transactions. In order to minimise egress bandwidth usage, we send // the transactions in small packs to one peer at a time. -func (m *Manager) txSyncLoop() { +func (m *Manager) syncMempoolLoop() { pending := make(map[string]*txSyncMsg) sending := false // whether a send is active done := make(chan error, 1) // result of the send @@ -130,7 +130,6 @@ func (m *Manager) txSyncLoop() { if !sending { send(msg) } - case err := <-done: sending = false if err != nil { @@ -140,6 +139,8 @@ func (m *Manager) txSyncLoop() { if s := pick(); s != nil { send(s) } + case <-m.quit: + return } } } diff --git a/netsync/chainmgr/tx_keeper_test.go b/netsync/chainmgr/tx_keeper_test.go new file mode 100644 index 00000000..7401af21 --- /dev/null +++ b/netsync/chainmgr/tx_keeper_test.go @@ -0,0 +1,101 @@ +package chainmgr + +import ( + "reflect" + "testing" + "time" + + "github.com/davecgh/go-spew/spew" + + "github.com/vapor/consensus" + "github.com/vapor/protocol" + "github.com/vapor/protocol/bc" + "github.com/vapor/protocol/bc/types" + "github.com/vapor/test/mock" +) + +const txsNumber = 2000 + +func getTransactions() []*types.Tx { + txs := []*types.Tx{} + for i := 0; i < txsNumber; i++ { + txInput := types.NewSpendInput(nil, bc.NewHash([32]byte{0x01}), *consensus.BTMAssetID, uint64(i), 1, []byte{0x51}) + txInput.CommitmentSuffix = []byte{0, 1, 2} + txInput.WitnessSuffix = []byte{0, 1, 2} + + tx := &types.Tx{ + + TxData: types.TxData{ + //SerializedSize: uint64(i * 10), + Inputs: []*types.TxInput{ + txInput, + }, + Outputs: []*types.TxOutput{ + types.NewIntraChainOutput(*consensus.BTMAssetID, uint64(i), []byte{0x6a}), + }, + }, + Tx: &bc.Tx{ + ID: bc.Hash{V0: uint64(i), V1: uint64(i), V2: uint64(i), V3: uint64(i)}, + }, + } + txs = append(txs, tx) + } + return txs +} + +func TestSyncMempool(t *testing.T) { + blocks := mockBlocks(nil, 5) + a := mockSync(blocks, &mock.Mempool{}) + b := mockSync(blocks, &mock.Mempool{}) + + netWork := NewNetWork() + netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode) + netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode) + if B2A, A2B, err := netWork.HandsShake(a, b); err != nil { + t.Errorf("fail on peer hands shake %v", err) + } else { + go B2A.postMan() + go A2B.postMan() + } + + go a.syncMempoolLoop() + a.syncMempool("test node B") + wantTxs := getTransactions() + a.txSyncCh <- &txSyncMsg{"test node B", wantTxs} + + timeout := time.NewTimer(2 * time.Second) + defer timeout.Stop() + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + gotTxs := []*protocol.TxDesc{} + for { + select { + case <-ticker.C: + gotTxs = b.mempool.GetTransactions() + if len(gotTxs) >= txsNumber { + goto out + } + case <-timeout.C: + t.Fatalf("mempool sync timeout") + } + } + +out: + if len(gotTxs) != txsNumber { + t.Fatalf("mempool sync txs num err. got:%d want:%d", len(gotTxs), txsNumber) + } + + for i, gotTx := range gotTxs { + index := gotTx.Tx.Inputs[0].Amount() + if !reflect.DeepEqual(gotTx.Tx.Inputs[0].Amount(), wantTxs[index].Inputs[0].Amount()) { + t.Fatalf("mempool tx err. index:%d\n,gotTx:%s\n,wantTx:%s", i, spew.Sdump(gotTx.Tx.Inputs), spew.Sdump(wantTxs[0].Inputs)) + } + + if !reflect.DeepEqual(gotTx.Tx.Outputs[0].AssetAmount(), wantTxs[index].Outputs[0].AssetAmount()) { + t.Fatalf("mempool tx err. index:%d\n,gotTx:%s\n,wantTx:%s", i, spew.Sdump(gotTx.Tx.Outputs), spew.Sdump(wantTxs[0].Outputs)) + } + + } + +} diff --git a/test/mock/chain.go b/test/mock/chain.go index b1601b19..0b93ce9e 100644 --- a/test/mock/chain.go +++ b/test/mock/chain.go @@ -8,19 +8,24 @@ import ( "github.com/vapor/protocol/bc/types" ) +type mempool interface { + AddTx(tx *types.Tx) +} + type Chain struct { bestBlockHeader *types.BlockHeader heightMap map[uint64]*types.Block blockMap map[bc.Hash]*types.Block - - prevOrphans map[bc.Hash]*types.Block + prevOrphans map[bc.Hash]*types.Block + mempool mempool } -func NewChain() *Chain { +func NewChain(mempool *Mempool) *Chain { return &Chain{ heightMap: map[uint64]*types.Block{}, blockMap: map[bc.Hash]*types.Block{}, prevOrphans: make(map[bc.Hash]*types.Block), + mempool: mempool, } } @@ -137,6 +142,7 @@ func (c *Chain) SetBlockByHeight(height uint64, block *types.Block) { c.blockMap[block.Hash()] = block } -func (c *Chain) ValidateTx(*types.Tx) (bool, error) { +func (c *Chain) ValidateTx(tx *types.Tx) (bool, error) { + c.mempool.AddTx(tx) return false, nil } diff --git a/test/mock/mempool.go b/test/mock/mempool.go new file mode 100644 index 00000000..767fb7b2 --- /dev/null +++ b/test/mock/mempool.go @@ -0,0 +1,24 @@ +package mock + +import ( + "github.com/vapor/protocol" + "github.com/vapor/protocol/bc/types" +) + +type Mempool struct { + txs []*protocol.TxDesc +} + +func newMempool() *Mempool { + return &Mempool{ + txs: []*protocol.TxDesc{}, + } +} + +func (m *Mempool) AddTx(tx *types.Tx) { + m.txs = append(m.txs, &protocol.TxDesc{Tx: tx}) +} + +func (m *Mempool) GetTransactions() []*protocol.TxDesc { + return m.txs +} -- 2.11.0