import (
"sort"
+ "sync"
"github.com/bytom/vapor/application/mov/common"
"github.com/bytom/vapor/application/mov/database"
type OrderBook struct {
movStore database.MovStore
// key of tradePair -> []order
- dbOrders map[string][]*common.Order
+ dbOrders *sync.Map
// key of tradePair -> iterator
- orderIterators map[string]*database.OrderIterator
+ orderIterators *sync.Map
// key of tradePair -> []order
- arrivalAddOrders map[string][]*common.Order
+ arrivalAddOrders *sync.Map
// key of order -> order
- arrivalDelOrders map[string]*common.Order
+ arrivalDelOrders *sync.Map
}
// NewOrderBook create a new OrderBook object
func NewOrderBook(movStore database.MovStore, arrivalAddOrders, arrivalDelOrders []*common.Order) *OrderBook {
return &OrderBook{
movStore: movStore,
- dbOrders: make(map[string][]*common.Order),
- orderIterators: make(map[string]*database.OrderIterator),
+ dbOrders: &sync.Map{},
+ orderIterators: &sync.Map{},
arrivalAddOrders: arrangeArrivalAddOrders(arrivalAddOrders),
arrivalDelOrders: arrangeArrivalDelOrders(arrivalDelOrders),
// AddOrder add the in memory temp order to order table
func (o *OrderBook) AddOrder(order *common.Order) error {
tradePairKey := order.TradePair().Key()
- orders := o.arrivalAddOrders[tradePairKey]
- if len(orders) > 0 && order.Rate() > orders[len(orders)-1].Rate() {
+ orders := o.getArrivalAddOrders(tradePairKey)
+ if len(orders) > 0 && order.Cmp(orders[len(orders)-1]) > 0 {
return errors.New("rate of order must less than the min order in order table")
}
- o.arrivalAddOrders[tradePairKey] = append(orders, order)
+ orders = append(orders, order)
+ o.arrivalAddOrders.Store(tradePairKey, orders)
return nil
}
// PeekOrder return the next lowest order of given trade pair
func (o *OrderBook) PeekOrder(tradePair *common.TradePair) *common.Order {
- if len(o.dbOrders[tradePair.Key()]) == 0 {
+ if len(o.getDBOrders(tradePair.Key())) == 0 {
o.extendDBOrders(tradePair)
}
var nextOrder *common.Order
- orders := o.dbOrders[tradePair.Key()]
+ orders := o.getDBOrders(tradePair.Key())
if len(orders) != 0 {
nextOrder = orders[len(orders)-1]
}
- if nextOrder != nil && o.arrivalDelOrders[nextOrder.Key()] != nil {
- o.dbOrders[tradePair.Key()] = orders[0 : len(orders)-1]
+ if nextOrder != nil && o.getArrivalDelOrders(nextOrder.Key()) != nil {
+ o.dbOrders.Store(tradePair.Key(), orders[0:len(orders)-1])
return o.PeekOrder(tradePair)
}
arrivalOrder := o.peekArrivalOrder(tradePair)
- if nextOrder == nil || (arrivalOrder != nil && arrivalOrder.Rate() < nextOrder.Rate()) {
+ if nextOrder == nil || (arrivalOrder != nil && arrivalOrder.Cmp(nextOrder) < 0) {
nextOrder = arrivalOrder
}
return nextOrder
return
}
- orders := o.dbOrders[tradePair.Key()]
+ orders := o.getDBOrders(tradePair.Key())
if len(orders) != 0 && orders[len(orders)-1].Key() == order.Key() {
- o.dbOrders[tradePair.Key()] = orders[0 : len(orders)-1]
+ o.dbOrders.Store(tradePair.Key(), orders[0 : len(orders)-1])
}
- arrivalOrders := o.arrivalAddOrders[tradePair.Key()]
+ arrivalOrders := o.getArrivalAddOrders(tradePair.Key())
if len(arrivalOrders) != 0 && arrivalOrders[len(arrivalOrders)-1].Key() == order.Key() {
- o.arrivalAddOrders[tradePair.Key()] = arrivalOrders[0 : len(arrivalOrders)-1]
+ o.arrivalAddOrders.Store(tradePair.Key(), arrivalOrders[0 : len(arrivalOrders)-1])
}
}
return orders
}
-func arrangeArrivalAddOrders(orders []*common.Order) map[string][]*common.Order {
- arrivalAddOrderMap := make(map[string][]*common.Order)
+func (o *OrderBook) getDBOrders(tradePairKey string) []*common.Order {
+ if orders, ok := o.dbOrders.Load(tradePairKey); ok {
+ return orders.([]*common.Order)
+ }
+ return []*common.Order{}
+}
+
+func (o *OrderBook) getArrivalAddOrders(tradePairKey string) []*common.Order {
+ if orders, ok := o.arrivalAddOrders.Load(tradePairKey); ok {
+ return orders.([]*common.Order)
+ }
+ return []*common.Order{}
+}
+
+func (o *OrderBook) getArrivalDelOrders(orderKey string) *common.Order {
+ if order, ok := o.arrivalDelOrders.Load(orderKey); ok {
+ return order.(*common.Order)
+ }
+ return nil
+}
+
+func arrangeArrivalAddOrders(orders []*common.Order) *sync.Map {
+ orderMap := make(map[string][]*common.Order)
for _, order := range orders {
- arrivalAddOrderMap[order.TradePair().Key()] = append(arrivalAddOrderMap[order.TradePair().Key()], order)
+ orderMap[order.TradePair().Key()] = append(orderMap[order.TradePair().Key()], order)
}
- for _, orders := range arrivalAddOrderMap {
+ arrivalOrderMap := &sync.Map{}
+ for key, orders := range orderMap {
sort.Sort(sort.Reverse(common.OrderSlice(orders)))
+ arrivalOrderMap.Store(key, orders)
+
}
- return arrivalAddOrderMap
+ return arrivalOrderMap
}
-func arrangeArrivalDelOrders(orders []*common.Order) map[string]*common.Order {
- arrivalDelOrderMap := make(map[string]*common.Order)
+func arrangeArrivalDelOrders(orders []*common.Order) *sync.Map {
+ arrivalDelOrderMap := &sync.Map{}
for _, order := range orders {
- arrivalDelOrderMap[order.Key()] = order
+ arrivalDelOrderMap.Store(order.Key(), order)
}
return arrivalDelOrderMap
}
func (o *OrderBook) extendDBOrders(tradePair *common.TradePair) {
- iterator, ok := o.orderIterators[tradePair.Key()]
+ iterator, ok := o.orderIterators.Load(tradePair.Key())
if !ok {
iterator = database.NewOrderIterator(o.movStore, tradePair)
- o.orderIterators[tradePair.Key()] = iterator
+ o.orderIterators.Store(tradePair.Key(), iterator)
}
- nextOrders := iterator.NextBatch()
+ nextOrders := iterator.(*database.OrderIterator).NextBatch()
+ orders := o.getDBOrders(tradePair.Key())
for i := len(nextOrders) - 1; i >= 0; i-- {
- o.dbOrders[tradePair.Key()] = append(o.dbOrders[tradePair.Key()], nextOrders[i])
+ orders = append(orders, nextOrders[i])
}
+ o.dbOrders.Store(tradePair.Key(), orders)
}
func (o *OrderBook) peekArrivalOrder(tradePair *common.TradePair) *common.Order {
- if arrivalAddOrders := o.arrivalAddOrders[tradePair.Key()]; len(arrivalAddOrders) > 0 {
+ if arrivalAddOrders := o.getArrivalAddOrders(tradePair.Key()); len(arrivalAddOrders) > 0 {
return arrivalAddOrders[len(arrivalAddOrders)-1]
}
return nil
--- /dev/null
+package mov
+
+import (
+ "runtime"
+ "sync"
+
+ "github.com/bytom/vapor/application/mov/common"
+ "github.com/bytom/vapor/application/mov/database"
+ "github.com/bytom/vapor/application/mov/match"
+ "github.com/bytom/vapor/protocol/bc/types"
+)
+
+type matchCollector struct {
+ engine *match.Engine
+ tradePairIterator *database.TradePairIterator
+ gasLeft int64
+ isTimeout func() bool
+
+ workerNum int
+ workerNumCh chan int
+ processCh chan *matchTxResult
+ tradePairCh chan *common.TradePair
+ closeCh chan struct{}
+}
+
+type matchTxResult struct {
+ matchedTx *types.Tx
+ err error
+}
+
+func newMatchTxCollector(engine *match.Engine, iterator *database.TradePairIterator, gasLeft int64, isTimeout func() bool) *matchCollector {
+ workerNum := runtime.NumCPU()
+ return &matchCollector{
+ engine: engine,
+ tradePairIterator: iterator,
+ workerNum: workerNum,
+ workerNumCh: make(chan int, workerNum),
+ processCh: make(chan *matchTxResult, 32),
+ tradePairCh: make(chan *common.TradePair, workerNum),
+ closeCh: make(chan struct{}),
+ gasLeft: gasLeft,
+ isTimeout: isTimeout,
+ }
+}
+
+func (m *matchCollector) result() ([]*types.Tx, error) {
+ var wg sync.WaitGroup
+ for i := 0; i < int(m.workerNum); i++ {
+ wg.Add(1)
+ go m.matchTxWorker(&wg)
+ }
+
+ wg.Add(1)
+ go m.tradePairProducer(&wg)
+
+ matchedTxs, err := m.collect()
+ // wait for all goroutine release
+ wg.Wait()
+ return matchedTxs, err
+}
+
+func (m *matchCollector) collect() ([]*types.Tx, error) {
+ defer close(m.closeCh)
+
+ var matchedTxs []*types.Tx
+ completed := 0
+ for !m.isTimeout() {
+ select {
+ case data := <-m.processCh:
+ if data.err != nil {
+ return nil, data.err
+ }
+
+ gasUsed := calcMatchedTxGasUsed(data.matchedTx)
+ if m.gasLeft -= gasUsed; m.gasLeft >= 0 {
+ matchedTxs = append(matchedTxs, data.matchedTx)
+ } else {
+ return matchedTxs, nil
+ }
+ case <-m.workerNumCh:
+ if completed++; completed == m.workerNum {
+ return matchedTxs, nil
+ }
+ }
+ }
+ return matchedTxs, nil
+}
+
+func (m *matchCollector) tradePairProducer(wg *sync.WaitGroup) {
+ defer func() {
+ close(m.tradePairCh)
+ wg.Done()
+ }()
+
+ tradePairMap := make(map[string]bool)
+
+ for m.tradePairIterator.HasNext() {
+ tradePair := m.tradePairIterator.Next()
+ if tradePairMap[tradePair.Key()] {
+ continue
+ }
+
+ tradePairMap[tradePair.Key()] = true
+ tradePairMap[tradePair.Reverse().Key()] = true
+
+ select {
+ case <-m.closeCh:
+ return
+ case m.tradePairCh <- tradePair:
+ }
+ }
+}
+
+func (m *matchCollector) matchTxWorker(wg *sync.WaitGroup) {
+ dispatchData := func(data *matchTxResult) bool {
+ select {
+ case <-m.closeCh:
+ return true
+ case m.processCh <- data:
+ if data.err != nil {
+ return true
+ }
+ return false
+ }
+ }
+
+ defer func() {
+ m.workerNumCh <- 1
+ wg.Done()
+ }()
+ for {
+ select {
+ case <-m.closeCh:
+ return
+ case tradePair := <-m.tradePairCh:
+ if tradePair == nil {
+ return
+ }
+ for m.engine.HasMatchedTx(tradePair, tradePair.Reverse()) {
+ matchedTx, err := m.engine.NextMatchedTx(tradePair, tradePair.Reverse())
+ if done := dispatchData(&matchTxResult{matchedTx: matchedTx, err: err}); done {
+ return
+ }
+ }
+ }
+
+ }
+}
}
matchEngine := match.NewEngine(orderBook, maxFeeRate, nodeProgram)
- tradePairMap := make(map[string]bool)
tradePairIterator := database.NewTradePairIterator(m.movStore)
-
- var packagedTxs []*types.Tx
- for gasLeft > 0 && !isTimeout() && tradePairIterator.HasNext() {
- tradePair := tradePairIterator.Next()
- if tradePairMap[tradePair.Key()] {
- continue
- }
- tradePairMap[tradePair.Key()] = true
- tradePairMap[tradePair.Reverse().Key()] = true
-
- for gasLeft > 0 && !isTimeout() && matchEngine.HasMatchedTx(tradePair, tradePair.Reverse()) {
- matchedTx, err := matchEngine.NextMatchedTx(tradePair, tradePair.Reverse())
- if err != nil {
- return nil, err
- }
-
- gasUsed := calcMatchedTxGasUsed(matchedTx)
- if gasLeft-gasUsed >= 0 {
- packagedTxs = append(packagedTxs, matchedTx)
- }
- gasLeft -= gasUsed
- }
- }
- return packagedTxs, nil
+ matchCollector := newMatchTxCollector(matchEngine, tradePairIterator, gasLeft, isTimeout)
+ return matchCollector.result()
}
// ChainStatus return the current block height and block hash in dex core
}
func (m *MovCore) validateMatchedTxSequence(txs []*types.Tx) error {
+ var matchedTxs []*types.Tx
+ for _, tx := range txs {
+ if common.IsMatchedTx(tx) {
+ matchedTxs = append(matchedTxs, tx)
+ }
+ }
+
+ if len(matchedTxs) == 0 {
+ return nil
+ }
+
orderBook, err := buildOrderBook(m.movStore, txs)
if err != nil {
return err
}
- for _, matchedTx := range txs {
- if !common.IsMatchedTx(matchedTx) {
- continue
- }
-
+ for _, matchedTx := range matchedTxs {
tradePairs, err := getTradePairsFromMatchedTx(matchedTx)
if err != nil {
return err
return nil
}
-
func validateSpendOrders(matchedTx *types.Tx, orders []*common.Order) error {
spendOutputIDs := make(map[string]bool)
for _, input := range matchedTx.Inputs {