OSDN Git Service

126810d1d816dd26ededad8b7fcaf6ddae950396
[bytom/vapor.git] / application / mov / match / engine.go
1 package match
2
3 import (
4         "math/big"
5
6         "github.com/bytom/vapor/application/mov/common"
7         "github.com/bytom/vapor/application/mov/contract"
8         "github.com/bytom/vapor/consensus/segwit"
9         "github.com/bytom/vapor/errors"
10         vprMath "github.com/bytom/vapor/math"
11         "github.com/bytom/vapor/protocol/bc"
12         "github.com/bytom/vapor/protocol/bc/types"
13         "github.com/bytom/vapor/protocol/vm"
14 )
15
16 // Engine is used to generate math transactions
17 type Engine struct {
18         orderBook     *OrderBook
19         feeStrategy   FeeStrategy
20         rewardProgram []byte
21 }
22
23 // NewEngine return a new Engine
24 func NewEngine(orderBook *OrderBook, feeStrategy FeeStrategy, rewardProgram []byte) *Engine {
25         return &Engine{orderBook: orderBook, feeStrategy: feeStrategy, rewardProgram: rewardProgram}
26 }
27
28 // HasMatchedTx check does the input trade pair can generate a match deal
29 func (e *Engine) HasMatchedTx(tradePairs ...*common.TradePair) bool {
30         if err := validateTradePairs(tradePairs); err != nil {
31                 return false
32         }
33
34         orders := e.orderBook.PeekOrders(tradePairs)
35         if len(orders) == 0 {
36                 return false
37         }
38
39         return IsMatched(orders)
40 }
41
42 // NextMatchedTx return the next matchable transaction by the specified trade pairs
43 // the size of trade pairs at least 2, and the sequence of trade pairs can form a loop
44 // for example, [assetA -> assetB, assetB -> assetC, assetC -> assetA]
45 func (e *Engine) NextMatchedTx(tradePairs ...*common.TradePair) (*types.Tx, error) {
46         if !e.HasMatchedTx(tradePairs...) {
47                 return nil, errors.New("the specified trade pairs can not be matched")
48         }
49
50         tx, err := e.buildMatchTx(sortOrders(e.orderBook.PeekOrders(tradePairs)))
51         if err != nil {
52                 return nil, err
53         }
54
55         for _, tradePair := range tradePairs {
56                 e.orderBook.PopOrder(tradePair)
57         }
58
59         if err := e.addPartialTradeOrder(tx); err != nil {
60                 return nil, err
61         }
62         return tx, nil
63 }
64
65 func (e *Engine) addMatchTxFeeOutput(txData *types.TxData, fees []*bc.AssetAmount) error {
66         for _, feeAmount := range fees {
67                 txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*feeAmount.AssetId, feeAmount.Amount, e.rewardProgram))
68         }
69
70         refoundAmount := map[bc.AssetID]uint64{}
71         assetIDs := []bc.AssetID{}
72         refoundScript := [][]byte{}
73         for _, input := range txData.Inputs {
74                 refoundAmount[input.AssetID()] += input.Amount()
75                 contractArgs, err := segwit.DecodeP2WMCProgram(input.ControlProgram())
76                 if err != nil {
77                         return err
78                 }
79
80                 assetIDs = append(assetIDs, input.AssetID())
81                 refoundScript = append(refoundScript, contractArgs.SellerProgram)
82         }
83
84         for _, output := range txData.Outputs {
85                 assetAmount := output.AssetAmount()
86                 refoundAmount[*assetAmount.AssetId] -= assetAmount.Amount
87         }
88
89         refoundCount := len(refoundScript)
90         for _, assetID := range assetIDs {
91                 amount := refoundAmount[assetID]
92                 averageAmount := amount / uint64(refoundCount)
93                 if averageAmount == 0 {
94                         averageAmount = 1
95                 }
96
97                 for i := 0; i < refoundCount && amount > 0; i++ {
98                         if i == refoundCount-1 {
99                                 averageAmount = amount
100                         }
101                         txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(assetID, averageAmount, refoundScript[i]))
102                         amount -= averageAmount
103                 }
104         }
105         return nil
106 }
107
108 func (e *Engine) addPartialTradeOrder(tx *types.Tx) error {
109         for i, output := range tx.Outputs {
110                 if !segwit.IsP2WMCScript(output.ControlProgram()) || output.AssetAmount().Amount == 0 {
111                         continue
112                 }
113
114                 order, err := common.NewOrderFromOutput(tx, i)
115                 if err != nil {
116                         return err
117                 }
118
119                 e.orderBook.AddOrder(order)
120         }
121         return nil
122 }
123
124 func (e *Engine) buildMatchTx(orders []*common.Order) (*types.Tx, error) {
125         txData := &types.TxData{Version: 1}
126         for _, order := range orders {
127                 input := types.NewSpendInput(nil, *order.Utxo.SourceID, *order.FromAssetID, order.Utxo.Amount, order.Utxo.SourcePos, order.Utxo.ControlProgram)
128                 txData.Inputs = append(txData.Inputs, input)
129         }
130
131         receivedAmounts, priceDiffs := CalcReceivedAmount(orders)
132         allocatedAssets := e.feeStrategy.Allocate(receivedAmounts, priceDiffs)
133         if err := addMatchTxOutput(txData, orders, receivedAmounts, allocatedAssets); err != nil {
134                 return nil, err
135         }
136
137         if err := e.addMatchTxFeeOutput(txData, allocatedAssets.Fees); err != nil {
138                 return nil, err
139         }
140
141         byteData, err := txData.MarshalText()
142         if err != nil {
143                 return nil, err
144         }
145
146         txData.SerializedSize = uint64(len(byteData))
147         return types.NewTx(*txData), nil
148 }
149
150 func addMatchTxOutput(txData *types.TxData, orders []*common.Order, receivedAmounts []*bc.AssetAmount, allocatedAssets *AllocatedAssets) error {
151         for i, order := range orders {
152                 contractArgs, err := segwit.DecodeP2WMCProgram(order.Utxo.ControlProgram)
153                 if err != nil {
154                         return err
155                 }
156
157                 receivedAmount := receivedAmounts[i].Amount
158                 shouldPayAmount := calcShouldPayAmount(receivedAmount, contractArgs.RatioNumerator, contractArgs.RatioDenominator)
159
160                 requestAmount := CalcRequestAmount(order.Utxo.Amount, order.RatioNumerator, order.RatioDenominator)
161                 exchangeAmount := order.Utxo.Amount - shouldPayAmount
162                 isPartialTrade := requestAmount > receivedAmount && CalcRequestAmount(exchangeAmount, contractArgs.RatioNumerator, contractArgs.RatioDenominator) >= 1
163
164                 setMatchTxArguments(txData.Inputs[i], isPartialTrade, len(txData.Outputs), receivedAmount)
165                 txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.ToAssetID, allocatedAssets.Receives[i].Amount, contractArgs.SellerProgram))
166                 if isPartialTrade {
167                         txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.FromAssetID, exchangeAmount, order.Utxo.ControlProgram))
168                 }
169         }
170         return nil
171 }
172
// calcOppositeIndex returns the index that follows selfIdx in a loop of size
// elements, wrapping from the last index back to 0. That element is the
// counterparty of order selfIdx.
func calcOppositeIndex(size int, selfIdx int) int {
	next := selfIdx + 1
	return next % size
}
176
// CalcRequestAmount returns fromAmount * ratioNumerator / ratioDenominator,
// truncated toward zero. The matching code uses it to get the amount an order
// holding fromAmount requests at its ratio. The product is computed with
// math/big, so the intermediate multiplication cannot overflow.
//
// It returns 0 when the result does not fit in a uint64 (negative or too
// large). It also returns 0 when ratioDenominator is 0, instead of panicking
// on a division by zero.
func CalcRequestAmount(fromAmount uint64, ratioNumerator, ratioDenominator int64) uint64 {
	if ratioDenominator == 0 {
		return 0
	}

	res := new(big.Int).SetUint64(fromAmount)
	res.Mul(res, big.NewInt(ratioNumerator))
	res.Quo(res, big.NewInt(ratioDenominator))
	if !res.IsUint64() {
		return 0
	}
	return res.Uint64()
}
186
// calcShouldPayAmount returns receiveAmount * ratioDenominator / ratioNumerator,
// truncated toward zero. It is the inverse of CalcRequestAmount: the amount an
// order pays for receiveAmount at its ratio. The product is computed with
// math/big, so the intermediate multiplication cannot overflow.
//
// It returns 0 when the result does not fit in a uint64 (negative or too
// large). It also returns 0 when ratioNumerator is 0, instead of panicking on
// a division by zero.
func calcShouldPayAmount(receiveAmount uint64, ratioNumerator, ratioDenominator int64) uint64 {
	if ratioNumerator == 0 {
		return 0
	}

	res := new(big.Int).SetUint64(receiveAmount)
	res.Mul(res, big.NewInt(ratioDenominator))
	res.Quo(res, big.NewInt(ratioNumerator))
	if !res.IsUint64() {
		return 0
	}
	return res.Uint64()
}
195
196 // CalcReceivedAmount return amount of assets received by each participant in the matching transaction and the price difference
197 func CalcReceivedAmount(orders []*common.Order) ([]*bc.AssetAmount, []*bc.AssetAmount) {
198         var receivedAmounts, priceDiffs, shouldPayAmounts []*bc.AssetAmount
199         for i, order := range orders {
200                 requestAmount := CalcRequestAmount(order.Utxo.Amount, order.RatioNumerator, order.RatioDenominator)
201                 oppositeOrder := orders[calcOppositeIndex(len(orders), i)]
202                 receiveAmount := vprMath.MinUint64(oppositeOrder.Utxo.Amount, requestAmount)
203                 shouldPayAmount := calcShouldPayAmount(receiveAmount, order.RatioNumerator, order.RatioDenominator)
204                 receivedAmounts = append(receivedAmounts, &bc.AssetAmount{AssetId: order.ToAssetID, Amount: receiveAmount})
205                 shouldPayAmounts = append(shouldPayAmounts, &bc.AssetAmount{AssetId: order.FromAssetID, Amount: shouldPayAmount})
206         }
207
208         for i, receivedAmount := range receivedAmounts {
209                 oppositeShouldPayAmount := shouldPayAmounts[calcOppositeIndex(len(orders), i)]
210                 priceDiffs = append(priceDiffs, &bc.AssetAmount{AssetId: oppositeShouldPayAmount.AssetId, Amount: 0})
211                 if oppositeShouldPayAmount.Amount > receivedAmount.Amount {
212                         priceDiffs[i].Amount = oppositeShouldPayAmount.Amount - receivedAmount.Amount
213                 }
214         }
215         return receivedAmounts, priceDiffs
216 }
217
218 // IsMatched check does the orders can be exchange
219 func IsMatched(orders []*common.Order) bool {
220         sortedOrders := sortOrders(orders)
221         if len(sortedOrders) == 0 {
222                 return false
223         }
224
225         product := big.NewRat(1, 1)
226         for _, order := range orders {
227                 product.Mul(product, big.NewRat(order.RatioNumerator, order.RatioDenominator))
228         }
229         one := big.NewRat(1, 1)
230         return product.Cmp(one) <= 0
231 }
232
233 func setMatchTxArguments(txInput *types.TxInput, isPartialTrade bool, position int, receiveAmounts uint64) {
234         var arguments [][]byte
235         if isPartialTrade {
236                 arguments = [][]byte{vm.Int64Bytes(int64(receiveAmounts)), vm.Int64Bytes(int64(position)), vm.Int64Bytes(contract.PartialTradeClauseSelector)}
237         } else {
238                 arguments = [][]byte{vm.Int64Bytes(int64(position)), vm.Int64Bytes(contract.FullTradeClauseSelector)}
239         }
240         txInput.SetArguments(arguments)
241 }
242
243 func sortOrders(orders []*common.Order) []*common.Order {
244         if len(orders) == 0 {
245                 return nil
246         }
247
248         orderMap := make(map[bc.AssetID]*common.Order)
249         firstOrder := orders[0]
250         for i := 1; i < len(orders); i++ {
251                 orderMap[*orders[i].FromAssetID] = orders[i]
252         }
253
254         sortedOrders := []*common.Order{firstOrder}
255         for order := firstOrder; *order.ToAssetID != *firstOrder.FromAssetID; {
256                 nextOrder, ok := orderMap[*order.ToAssetID]
257                 if !ok {
258                         return nil
259                 }
260
261                 sortedOrders = append(sortedOrders, nextOrder)
262                 order = nextOrder
263         }
264         return sortedOrders
265 }
266
267 func validateTradePairs(tradePairs []*common.TradePair) error {
268         if len(tradePairs) < 2 {
269                 return errors.New("size of trade pairs at least 2")
270         }
271
272         assetMap := make(map[string]bool)
273         for _, tradePair := range tradePairs {
274                 assetMap[tradePair.FromAssetID.String()] = true
275                 if *tradePair.FromAssetID == *tradePair.ToAssetID {
276                         return errors.New("from asset id can't equal to asset id")
277                 }
278         }
279
280         for _, tradePair := range tradePairs {
281                 key := tradePair.ToAssetID.String()
282                 if _, ok := assetMap[key]; !ok {
283                         return errors.New("invalid trade pairs")
284                 }
285                 delete(assetMap, key)
286         }
287         return nil
288 }