OSDN Git Service

merge order from pool (#432)
authorPoseidon <shenao.78@163.com>
Fri, 1 Nov 2019 02:52:52 +0000 (10:52 +0800)
committerPaladz <yzhu101@uottawa.ca>
Fri, 1 Nov 2019 02:52:52 +0000 (10:52 +0800)
* merge order from pool

* bug fix order table

* bug fix order table

* bug fix

* rename

application/mov/common/type.go
application/mov/database/mock_mov_store.go
application/mov/match/match.go
application/mov/match/match_test.go
application/mov/match/order_table.go
application/mov/mov_core.go
proposal/proposal.go
protocol/protocol.go

index 9f55f2f..570f8a7 100644 (file)
@@ -23,6 +23,18 @@ type Order struct {
        Rate        float64
 }
 
+type OrderSlice []*Order
+
+func (o OrderSlice) Len() int {
+       return len(o)
+}
+func (o OrderSlice) Swap(i, j int) {
+       o[i], o[j] = o[j], o[i]
+}
+func (o OrderSlice) Less(i, j int) bool {
+       return o[i].Rate < o[j].Rate
+}
+
 func NewOrderFromOutput(tx *types.Tx, outputIndex int) (*Order, error) {
        outputID := tx.OutputID(outputIndex)
        output, err := tx.IntraChainOutput(*outputID)
index 52d1513..96bbcdd 100644 (file)
@@ -63,18 +63,6 @@ func (m *MockMovStore) ListTradePairsWithStart(fromAssetIDAfter, toAssetIDAfter
        return result, nil
 }
 
-type OrderSlice []*common.Order
-
-func (o OrderSlice) Len() int {
-       return len(o)
-}
-func (o OrderSlice) Swap(i, j int) {
-       o[i], o[j] = o[j], o[i]
-}
-func (o OrderSlice) Less(i, j int) bool {
-       return o[i].Rate < o[j].Rate
-}
-
 func (m *MockMovStore) ProcessOrders(addOrders []*common.Order, delOrders []*common.Order, blockHeader *types.BlockHeader) error {
        for _, order := range addOrders {
                tradePair := &common.TradePair{FromAssetID: order.FromAssetID, ToAssetID: order.ToAssetID}
@@ -90,7 +78,7 @@ func (m *MockMovStore) ProcessOrders(addOrders []*common.Order, delOrders []*com
                }
        }
        for _, orders := range m.OrderMap {
-               sort.Sort(OrderSlice(orders))
+               sort.Sort(common.OrderSlice(orders))
        }
 
        if blockHeader.Height == m.DBState.Height {
index 5a8a95c..336269a 100644 (file)
@@ -6,7 +6,6 @@ import (
 
        "github.com/vapor/application/mov/common"
        "github.com/vapor/application/mov/contract"
-       "github.com/vapor/application/mov/database"
        "github.com/vapor/consensus/segwit"
        "github.com/vapor/errors"
        vprMath "github.com/vapor/math"
@@ -22,8 +21,8 @@ type Engine struct {
        nodeProgram []byte
 }
 
-func NewEngine(movStore database.MovStore, maxFeeRate float64, nodeProgram []byte) *Engine {
-       return &Engine{orderTable: NewOrderTable(movStore), maxFeeRate: maxFeeRate, nodeProgram: nodeProgram}
+func NewEngine(orderTable *OrderTable, maxFeeRate float64, nodeProgram []byte) *Engine {
+       return &Engine{orderTable: orderTable, maxFeeRate: maxFeeRate, nodeProgram: nodeProgram}
 }
 
 func (e *Engine) HasMatchedTx(tradePairs ...*common.TradePair) bool {
index 5dcb44a..ee38a0a 100644 (file)
@@ -175,7 +175,7 @@ func TestGenerateMatchedTxs(t *testing.T) {
 
        for i, c := range cases {
                movStore := &database.MockMovStore{OrderMap: c.storeOrderMap}
-               matchEngine := NewEngine(movStore, 0.05, []byte{0x51})
+               matchEngine := NewEngine(NewOrderTable(movStore, nil, nil), 0.05, []byte{0x51})
                var gotMatchedTxs []*types.Tx
                for matchEngine.HasMatchedTx(c.tradePair, c.tradePair.Reverse()) {
                        matchedTx, err := matchEngine.NextMatchedTx(c.tradePair, c.tradePair.Reverse())
index b314ed0..bc656bb 100644 (file)
 package match
 
 import (
+       "sort"
+
        "github.com/vapor/application/mov/common"
        "github.com/vapor/application/mov/database"
        "github.com/vapor/errors"
 )
 
 type OrderTable struct {
-       movStore    database.MovStore
-       orderMap    map[string][]*common.Order
-       iteratorMap map[string]*database.OrderIterator
+       movStore       database.MovStore
+       // key of tradePair -> []order
+       dbOrders map[string][]*common.Order
+       // key of tradePair -> iterator
+       orderIterators map[string]*database.OrderIterator
+
+       // key of tradePair -> []order
+       arrivalAddOrders map[string][]*common.Order
+       // key of order -> order
+       arrivalDelOrders map[string]*common.Order
 }
 
-func NewOrderTable(movStore database.MovStore) *OrderTable {
+func NewOrderTable(movStore database.MovStore, arrivalAddOrders, arrivalDelOrders []*common.Order) *OrderTable {
        return &OrderTable{
-               movStore:    movStore,
-               orderMap:    make(map[string][]*common.Order),
-               iteratorMap: make(map[string]*database.OrderIterator),
+               movStore:       movStore,
+               dbOrders:       make(map[string][]*common.Order),
+               orderIterators: make(map[string]*database.OrderIterator),
+
+               arrivalAddOrders: arrangeArrivalAddOrders(arrivalAddOrders),
+               arrivalDelOrders: arrangeArrivalDelOrders(arrivalDelOrders),
        }
 }
 
 func (o *OrderTable) PeekOrder(tradePair *common.TradePair) *common.Order {
-       orders := o.orderMap[tradePair.Key()]
-       if len(orders) != 0 {
-               return orders[len(orders)-1]
+       if len(o.dbOrders[tradePair.Key()]) == 0 {
+               o.extendDBOrders(tradePair)
        }
 
-       iterator, ok := o.iteratorMap[tradePair.Key()]
-       if !ok {
-               iterator = database.NewOrderIterator(o.movStore, tradePair)
-               o.iteratorMap[tradePair.Key()] = iterator
+       var nextOrder *common.Order
+
+       orders := o.dbOrders[tradePair.Key()]
+       if len(orders) != 0 {
+               nextOrder = orders[len(orders) - 1]
        }
 
-       nextOrders := iterator.NextBatch()
-       if len(nextOrders) == 0 {
-               return nil
+       if nextOrder != nil && o.arrivalDelOrders[nextOrder.Key()] != nil {
+               o.dbOrders[tradePair.Key()] = orders[0 : len(orders)-1]
+               delete(o.arrivalDelOrders, nextOrder.Key())
+               return o.PeekOrder(tradePair)
        }
 
-       for i := len(nextOrders) - 1; i >= 0; i-- {
-               o.orderMap[tradePair.Key()] = append(o.orderMap[tradePair.Key()], nextOrders[i])
+       arrivalOrder := o.peekArrivalOrder(tradePair)
+       if nextOrder == nil || (arrivalOrder != nil && arrivalOrder.Rate < nextOrder.Rate) {
+               nextOrder = arrivalOrder
        }
-       return nextOrders[0]
+       return nextOrder
 }
 
 func (o *OrderTable) PopOrder(tradePair *common.TradePair) {
-       if orders := o.orderMap[tradePair.Key()]; len(orders) > 0 {
-               o.orderMap[tradePair.Key()] = orders[0 : len(orders)-1]
+       order := o.PeekOrder(tradePair)
+       if order == nil {
+               return
+       }
+
+       orders := o.dbOrders[tradePair.Key()]
+       if len(orders) != 0 && orders[len(orders) - 1].Key() == order.Key() {
+               o.dbOrders[tradePair.Key()] = orders[0 : len(orders)-1]
+       }
+
+       arrivalOrders := o.arrivalAddOrders[tradePair.Key()]
+       if len(arrivalOrders) != 0 && orders[len(arrivalOrders) - 1].Key() == order.Key() {
+               o.arrivalAddOrders[tradePair.Key()] = arrivalOrders[0 : len(arrivalOrders)-1]
        }
 }
 
 func (o *OrderTable) AddOrder(order *common.Order) error {
        tradePair := order.GetTradePair()
-       orders := o.orderMap[tradePair.Key()]
+       orders := o.dbOrders[tradePair.Key()]
        if len(orders) > 0 && order.Rate > orders[len(orders)-1].Rate {
                return errors.New("rate of order must less than the min order in order table")
        }
 
-       o.orderMap[tradePair.Key()] = append(orders, order)
+       o.dbOrders[tradePair.Key()] = append(orders, order)
        return nil
 }
+
+func (o *OrderTable) extendDBOrders(tradePair *common.TradePair) {
+       iterator, ok := o.orderIterators[tradePair.Key()]
+       if !ok {
+               iterator = database.NewOrderIterator(o.movStore, tradePair)
+               o.orderIterators[tradePair.Key()] = iterator
+       }
+
+       nextOrders := iterator.NextBatch()
+       for i := len(nextOrders) - 1; i >= 0; i-- {
+               o.dbOrders[tradePair.Key()] = append(o.dbOrders[tradePair.Key()], nextOrders[i])
+       }
+}
+
+func (o *OrderTable) peekArrivalOrder(tradePair *common.TradePair) *common.Order {
+       arrivalAddOrders := o.arrivalAddOrders[tradePair.Key()]
+       if len(arrivalAddOrders) > 0 {
+               return arrivalAddOrders[len(arrivalAddOrders) -1]
+       }
+       return nil
+}
+
+func arrangeArrivalAddOrders(orders []*common.Order) map[string][]*common.Order {
+       arrivalAddOrderMap := make(map[string][]*common.Order)
+       for _, order := range orders {
+               arrivalAddOrderMap[order.Key()] = append(arrivalAddOrderMap[order.Key()], order)
+       }
+
+       for _, orders := range arrivalAddOrderMap {
+               sort.Sort(common.OrderSlice(orders))
+       }
+       return arrivalAddOrderMap
+}
+
+func arrangeArrivalDelOrders(orders []*common.Order) map[string]*common.Order {
+       arrivalDelOrderMap := make(map[string]*common.Order)
+       for _, order := range orders {
+               arrivalDelOrderMap[order.Key()] = order
+       }
+       return arrivalDelOrderMap
+}
index ca29612..5663f30 100644 (file)
@@ -209,7 +209,12 @@ func (m *MovCore) ApplyBlock(block *types.Block) error {
 }
 
 func (m *MovCore) validateMatchedTxSequence(txs []*types.Tx) error {
-       matchEngine := match.NewEngine(m.movStore, maxFeeRate, nil)
+       orderTable, err := buildOrderTable(m.movStore, txs)
+       if err != nil {
+               return err
+       }
+
+       matchEngine := match.NewEngine(orderTable, maxFeeRate, nil)
        for _, matchedTx := range txs {
                if !common.IsMatchedTx(matchedTx) {
                        continue
@@ -301,12 +306,17 @@ func (m *MovCore) DetachBlock(block *types.Block) error {
 }
 
 // BeforeProposalBlock return all transactions than can be matched, and the number of transactions cannot exceed the given capacity.
-func (m *MovCore) BeforeProposalBlock(nodeProgram []byte, blockHeight uint64, gasLeft int64) ([]*types.Tx, int64, error) {
+func (m *MovCore) BeforeProposalBlock(txs []*types.Tx, nodeProgram []byte, blockHeight uint64, gasLeft int64) ([]*types.Tx, int64, error) {
        if blockHeight <= m.startBlockHeight {
                return nil, 0, nil
        }
 
-       matchEngine := match.NewEngine(m.movStore, maxFeeRate, nodeProgram)
+       orderTable, err := buildOrderTable(m.movStore, txs)
+       if err != nil {
+               return nil, 0, err
+       }
+
+       matchEngine := match.NewEngine(orderTable, maxFeeRate, nodeProgram)
        tradePairMap := make(map[string]bool)
        tradePairIterator := database.NewTradePairIterator(m.movStore)
 
@@ -339,6 +349,33 @@ func calcMatchedTxGasUsed(tx *types.Tx) int64 {
        return int64(len(tx.Inputs)) * 150 + int64(tx.SerializedSize)
 }
 
+func buildOrderTable(store database.MovStore, txs []*types.Tx) (*match.OrderTable, error) {
+       var nonMatchedTxs []*types.Tx
+       for _, tx := range txs {
+               if !common.IsMatchedTx(tx) {
+                       nonMatchedTxs = append(nonMatchedTxs, tx)
+               }
+       }
+
+       var arrivalAddOrders, arrivalDelOrders []*common.Order
+       for _, tx := range nonMatchedTxs {
+               addOrders, err := getAddOrdersFromTx(tx)
+               if err != nil {
+                       return nil, err
+               }
+
+               delOrders, err := getDeleteOrdersFromTx(tx)
+               if err != nil {
+                       return nil, err
+               }
+
+               arrivalAddOrders = append(arrivalAddOrders, addOrders...)
+               arrivalDelOrders = append(arrivalDelOrders, delOrders...)
+       }
+
+       return match.NewOrderTable(store, arrivalAddOrders, arrivalDelOrders), nil
+}
+
 // IsDust block the transaction that are not generated by the match engine 
 func (m *MovCore) IsDust(tx *types.Tx) bool {
        for _, input := range tx.Inputs {
index a31f5bf..b620465 100644 (file)
@@ -180,7 +180,7 @@ func applyTransactionFromPool(chain *protocol.Chain, view *state.UtxoViewpoint,
 }
 
 func applyTransactionFromSubProtocol(chain *protocol.Chain, view *state.UtxoViewpoint, block *types.Block, txStatus *bc.TransactionStatus, accountManager *account.Manager, gasLeft int64) error {
-       txs, err := getTxsFromSubProtocols(chain, accountManager, gasLeft)
+       txs, err := getTxsFromSubProtocols(chain, accountManager, block.Transactions, gasLeft)
        if err != nil {
                return err
        }
@@ -273,7 +273,7 @@ func getAllTxsFromPool(txPool *protocol.TxPool) []*types.Tx {
        return poolTxs
 }
 
-func getTxsFromSubProtocols(chain *protocol.Chain, accountManager *account.Manager, gasLeft int64) ([]*types.Tx, error) {
+func getTxsFromSubProtocols(chain *protocol.Chain, accountManager *account.Manager, poolTxs []*types.Tx, gasLeft int64) ([]*types.Tx, error) {
        cp, err := accountManager.GetCoinbaseControlProgram()
        if err != nil {
                return nil, err
@@ -286,7 +286,7 @@ func getTxsFromSubProtocols(chain *protocol.Chain, accountManager *account.Manag
                        break
                }
 
-               subTxs, gasLeft, err = p.BeforeProposalBlock(cp, chain.BestBlockHeight() + 1, gasLeft)
+               subTxs, gasLeft, err = p.BeforeProposalBlock(poolTxs, cp, chain.BestBlockHeight() + 1, gasLeft)
                if err != nil {
                        log.WithFields(log.Fields{"module": logModule, "index": i, "error": err}).Error("failed on sub protocol txs package")
                        continue
index fd5a146..3f71715 100644 (file)
@@ -21,7 +21,7 @@ const (
 
 type Protocoler interface {
        Name() string
-       BeforeProposalBlock(nodeProgram []byte, blockHeight uint64, gasLeft int64) ([]*types.Tx, int64, error)
+       BeforeProposalBlock(txs []*types.Tx, nodeProgram []byte, blockHeight uint64, gasLeft int64) ([]*types.Tx, int64, error)
        ChainStatus() (uint64, *bc.Hash, error)
        ValidateBlock(block *types.Block, verifyResults []*bc.TxVerifyResult) error
        ValidateTxs(txs []*types.Tx, verifyResults []*bc.TxVerifyResult) error