OSDN Git Service

fix match collector (#481)
[bytom/vapor.git] / application / mov / match_collector.go
1 package mov
2
3 import (
4         "runtime"
5         "sync"
6
7         "github.com/bytom/vapor/application/mov/common"
8         "github.com/bytom/vapor/application/mov/database"
9         "github.com/bytom/vapor/application/mov/match"
10         "github.com/bytom/vapor/protocol/bc/types"
11 )
12
13 type matchCollector struct {
14         engine            *match.Engine
15         tradePairIterator *database.TradePairIterator
16         gasLeft           int64
17         isTimeout         func() bool
18
19         workerNum   int
20         workerNumCh chan int
21         processCh   chan *matchTxResult
22         tradePairCh chan *common.TradePair
23         closeCh     chan struct{}
24 }
25
26 type matchTxResult struct {
27         matchedTx *types.Tx
28         err       error
29 }
30
31 func newMatchTxCollector(engine *match.Engine, iterator *database.TradePairIterator, gasLeft int64, isTimeout func() bool) *matchCollector {
32         workerNum := runtime.NumCPU()
33         return &matchCollector{
34                 engine:            engine,
35                 tradePairIterator: iterator,
36                 workerNum:         workerNum,
37                 workerNumCh:       make(chan int, workerNum),
38                 processCh:         make(chan *matchTxResult, 32),
39                 tradePairCh:       make(chan *common.TradePair, workerNum),
40                 closeCh:           make(chan struct{}),
41                 gasLeft:           gasLeft,
42                 isTimeout:         isTimeout,
43         }
44 }
45
46 func (m *matchCollector) result() ([]*types.Tx, error) {
47         var wg sync.WaitGroup
48         for i := 0; i < int(m.workerNum); i++ {
49                 wg.Add(1)
50                 go m.matchTxWorker(&wg)
51         }
52
53         wg.Add(1)
54         go m.tradePairProducer(&wg)
55
56         matchedTxs, err := m.collect()
57         // wait for all goroutine release
58         wg.Wait()
59         return matchedTxs, err
60 }
61
62 func (m *matchCollector) collect() ([]*types.Tx, error) {
63         defer close(m.closeCh)
64
65         var matchedTxs []*types.Tx
66         appendMatchedTxs := func(data *matchTxResult) bool {
67                 gasUsed := calcMatchedTxGasUsed(data.matchedTx)
68                 if m.gasLeft -= gasUsed; m.gasLeft >= 0 {
69                         matchedTxs = append(matchedTxs, data.matchedTx)
70                         return false
71                 }
72                 return true
73         }
74
75         completed := 0
76         for !m.isTimeout() {
77                 select {
78                 case data := <-m.processCh:
79                         if data.err != nil {
80                                 return nil, data.err
81                         }
82
83                         if done := appendMatchedTxs(data); done {
84                                 return matchedTxs, nil
85                         }
86                 case <-m.workerNumCh:
87                         if completed++; completed == m.workerNum {
88                                 // read the remaining process results
89                                 close(m.processCh)
90                                 for data := range m.processCh {
91                                         if data.err != nil {
92                                                 return nil, data.err
93                                         }
94
95                                         appendMatchedTxs(data)
96                                 }
97                                 return matchedTxs, nil
98                         }
99                 }
100         }
101         return matchedTxs, nil
102 }
103
104 func (m *matchCollector) tradePairProducer(wg *sync.WaitGroup) {
105         defer func() {
106                 close(m.tradePairCh)
107                 wg.Done()
108         }()
109
110         tradePairMap := make(map[string]bool)
111
112         for m.tradePairIterator.HasNext() {
113                 tradePair := m.tradePairIterator.Next()
114                 if tradePairMap[tradePair.Key()] {
115                         continue
116                 }
117
118                 tradePairMap[tradePair.Key()] = true
119                 tradePairMap[tradePair.Reverse().Key()] = true
120
121                 select {
122                 case <-m.closeCh:
123                         return
124                 case m.tradePairCh <- tradePair:
125                 }
126         }
127 }
128
129 func (m *matchCollector) matchTxWorker(wg *sync.WaitGroup) {
130         dispatchData := func(data *matchTxResult) bool {
131                 select {
132                 case <-m.closeCh:
133                         return true
134                 case m.processCh <- data:
135                         if data.err != nil {
136                                 return true
137                         }
138                         return false
139                 }
140         }
141
142         defer func() {
143                 m.workerNumCh <- 1
144                 wg.Done()
145         }()
146         for {
147                 select {
148                 case <-m.closeCh:
149                         return
150                 case tradePair := <-m.tradePairCh:
151                         if tradePair == nil {
152                                 return
153                         }
154                         for m.engine.HasMatchedTx(tradePair, tradePair.Reverse()) {
155                                 matchedTx, err := m.engine.NextMatchedTx(tradePair, tradePair.Reverse())
156                                 if done := dispatchData(&matchTxResult{matchedTx: matchedTx, err: err}); done {
157                                         return
158                                 }
159                         }
160                 }
161
162         }
163 }