OSDN Git Service

add trader type to input artumetns
[bytom/vapor.git] / application / mov / match / engine.go
1 package match
2
3 import (
4         "math/big"
5
6         "github.com/bytom/vapor/application/mov/common"
7         "github.com/bytom/vapor/application/mov/contract"
8         "github.com/bytom/vapor/errors"
9         vprMath "github.com/bytom/vapor/math"
10         "github.com/bytom/vapor/protocol/bc"
11         "github.com/bytom/vapor/protocol/bc/types"
12         "github.com/bytom/vapor/protocol/vm"
13 )
14
15 // Engine is used to generate math transactions
16 type Engine struct {
17         orderBook     *OrderBook
18         feeStrategy   FeeStrategy
19         rewardProgram []byte
20         blockHeight   uint64
21 }
22
23 // NewEngine return a new Engine
24 func NewEngine(orderBook *OrderBook, rewardProgram []byte) *Engine {
25         return &Engine{orderBook: orderBook, feeStrategy: NewDefaultFeeStrategy(), rewardProgram: rewardProgram}
26 }
27
28 // HasMatchedTx check does the input trade pair can generate a match deal
29 func (e *Engine) HasMatchedTx(tradePairs ...*common.TradePair) bool {
30         if err := validateTradePairs(tradePairs); err != nil {
31                 return false
32         }
33
34         orders := e.orderBook.PeekOrders(tradePairs)
35         if len(orders) == 0 {
36                 return false
37         }
38
39         return IsMatched(orders)
40 }
41
42 // NextMatchedTx return the next matchable transaction by the specified trade pairs
43 // the size of trade pairs at least 2, and the sequence of trade pairs can form a loop
44 // for example, [assetA -> assetB, assetB -> assetC, assetC -> assetA]
45 func (e *Engine) NextMatchedTx(tradePairs ...*common.TradePair) (*types.Tx, error) {
46         if !e.HasMatchedTx(tradePairs...) {
47                 return nil, errors.New("the specified trade pairs can not be matched")
48         }
49
50         tx, partialOrders, err := e.buildMatchTx(sortOrders(e.orderBook.PeekOrders(tradePairs)))
51         if err != nil {
52                 return nil, err
53         }
54
55         for _, tradePair := range tradePairs {
56                 e.orderBook.PopOrder(tradePair)
57         }
58
59         for _, order := range partialOrders {
60                 e.orderBook.AddOrder(order)
61         }
62         return tx, nil
63 }
64
65 func addMatchTxFeeOutput(txData *types.TxData, fees []*bc.AssetAmount, rewardProgram []byte) {
66         for _, feeAmount := range fees {
67                 if feeAmount.Amount != 0 {
68                         txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*feeAmount.AssetId, feeAmount.Amount, rewardProgram))
69                 }
70         }
71 }
72
73 func addRefundOutput(txData *types.TxData, orders []*common.Order) {
74         refundAmount := map[bc.AssetID]uint64{}
75         var assetIDs []bc.AssetID
76         var refundScript [][]byte
77         for i, input := range txData.Inputs {
78                 refundAmount[input.AssetID()] += input.Amount()
79                 assetIDs = append(assetIDs, input.AssetID())
80                 refundScript = append(refundScript, orders[i].ContractArgs.SellerProgram)
81         }
82
83         for _, output := range txData.Outputs {
84                 assetAmount := output.AssetAmount()
85                 refundAmount[*assetAmount.AssetId] -= assetAmount.Amount
86         }
87
88         refundCount := len(refundScript)
89         for _, assetID := range assetIDs {
90                 amount := refundAmount[assetID]
91                 averageAmount := amount / uint64(refundCount)
92                 if averageAmount == 0 {
93                         averageAmount = 1
94                 }
95
96                 for i := 0; i < refundCount && amount > 0; i++ {
97                         if i == refundCount-1 {
98                                 averageAmount = amount
99                         }
100                         txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(assetID, averageAmount, refundScript[i]))
101                         amount -= averageAmount
102                 }
103         }
104 }
105
106 func addTakerOutput(txData *types.TxData, orders []*common.Order, priceDiffs []*bc.AssetAmount, isMakers []bool) {
107         for i, order := range orders {
108                 if isMakers[i] {
109                         continue
110                 }
111                 for _, priceDiff := range priceDiffs {
112                         if priceDiff.AssetId.String() == orders[i].FromAssetID.String() {
113                                 txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*priceDiff.AssetId, priceDiff.Amount, order.Utxo.ControlProgram))
114                         } else {
115                                 txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*priceDiff.AssetId, priceDiff.Amount, orders[i].ContractArgs.SellerProgram))
116                         }
117                 }
118                 break
119         }
120 }
121
122 func (e *Engine) buildMatchTx(orders []*common.Order) (*types.Tx, []*common.Order, error) {
123         txData := &types.TxData{Version: 1}
124         for _, order := range orders {
125                 input := types.NewSpendInput(nil, *order.Utxo.SourceID, *order.FromAssetID, order.Utxo.Amount, order.Utxo.SourcePos, order.Utxo.ControlProgram)
126                 txData.Inputs = append(txData.Inputs, input)
127         }
128
129         isMakers := MakerFlags(orders)
130         receivedAmounts, priceDiffs := CalcReceivedAmount(orders)
131         allocatedAssets := e.feeStrategy.Allocate(receivedAmounts, isMakers)
132
133         partialOrders, err := addMatchTxOutput(txData, orders, receivedAmounts, allocatedAssets, isMakers)
134         if err != nil {
135                 return nil, nil, err
136         }
137
138         addMatchTxFeeOutput(txData, allocatedAssets.Fees, e.rewardProgram)
139         addTakerOutput(txData, orders, priceDiffs, isMakers)
140         addRefundOutput(txData, orders)
141
142         byteData, err := txData.MarshalText()
143         if err != nil {
144                 return nil, nil, err
145         }
146
147         txData.SerializedSize = uint64(len(byteData))
148         return types.NewTx(*txData), partialOrders, nil
149 }
150
151 func addMatchTxOutput(txData *types.TxData, orders []*common.Order, receivedAmounts []*bc.AssetAmount, allocatedAssets *AllocatedAssets, isMakers []bool) ([]*common.Order, error) {
152         var partialOrders []*common.Order
153         for i, order := range orders {
154                 contractArgs := order.ContractArgs
155                 receivedAmount := receivedAmounts[i].Amount
156                 shouldPayAmount := calcShouldPayAmount(receivedAmount, contractArgs.RatioNumerator, contractArgs.RatioDenominator)
157
158                 requestAmount := CalcRequestAmount(order.Utxo.Amount, order.RatioNumerator, order.RatioDenominator)
159                 exchangeAmount := order.Utxo.Amount - shouldPayAmount
160                 isPartialTrade := requestAmount > receivedAmount && CalcRequestAmount(exchangeAmount, contractArgs.RatioNumerator, contractArgs.RatioDenominator) >= 1
161
162                 setMatchTxArguments(txData.Inputs[i], isPartialTrade, len(txData.Outputs), receivedAmount, isMakers[i])
163
164                 txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.ToAssetID, allocatedAssets.Receives[i].Amount, contractArgs.SellerProgram))
165                 if isPartialTrade {
166                         txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.FromAssetID, exchangeAmount, order.Utxo.ControlProgram))
167                         partialOrder, err := common.NewOrderFromOutput(&types.Tx{TxData: *txData}, len(txData.Outputs)-1, order.Sequence, order.BlockHeight)
168                         if err != nil {
169                                 return nil, err
170                         }
171
172                         partialOrders = append(partialOrders, partialOrder)
173                 }
174         }
175         return partialOrders, nil
176 }
177
178 func calcOppositeIndex(size int, selfIdx int) int {
179         return (selfIdx + 1) % size
180 }
181
182 // CalcRequestAmount is from amount * numerator / ratioDenominator
183 func CalcRequestAmount(fromAmount uint64, ratioNumerator, ratioDenominator int64) uint64 {
184         res := big.NewInt(0).SetUint64(fromAmount)
185         res.Mul(res, big.NewInt(ratioNumerator)).Quo(res, big.NewInt(ratioDenominator))
186         if !res.IsUint64() {
187                 return 0
188         }
189         return res.Uint64()
190 }
191
192 func calcShouldPayAmount(receiveAmount uint64, ratioNumerator, ratioDenominator int64) uint64 {
193         res := big.NewInt(0).SetUint64(receiveAmount)
194         res.Mul(res, big.NewInt(ratioDenominator)).Quo(res, big.NewInt(ratioNumerator))
195         if !res.IsUint64() {
196                 return 0
197         }
198         return res.Uint64()
199 }
200
201 // CalcReceivedAmount return amount of assets received by each participant in the matching transaction and the price difference
202 func CalcReceivedAmount(orders []*common.Order) ([]*bc.AssetAmount, []*bc.AssetAmount) {
203         var receivedAmounts, priceDiffs, shouldPayAmounts []*bc.AssetAmount
204         for i, order := range orders {
205                 requestAmount := CalcRequestAmount(order.Utxo.Amount, order.RatioNumerator, order.RatioDenominator)
206                 oppositeOrder := orders[calcOppositeIndex(len(orders), i)]
207                 receiveAmount := vprMath.MinUint64(oppositeOrder.Utxo.Amount, requestAmount)
208                 shouldPayAmount := calcShouldPayAmount(receiveAmount, order.RatioNumerator, order.RatioDenominator)
209                 receivedAmounts = append(receivedAmounts, &bc.AssetAmount{AssetId: order.ToAssetID, Amount: receiveAmount})
210                 shouldPayAmounts = append(shouldPayAmounts, &bc.AssetAmount{AssetId: order.FromAssetID, Amount: shouldPayAmount})
211         }
212
213         for i, receivedAmount := range receivedAmounts {
214                 oppositeShouldPayAmount := shouldPayAmounts[calcOppositeIndex(len(orders), i)]
215                 priceDiffs = append(priceDiffs, &bc.AssetAmount{AssetId: oppositeShouldPayAmount.AssetId, Amount: 0})
216                 if oppositeShouldPayAmount.Amount > receivedAmount.Amount {
217                         priceDiffs[i].Amount = oppositeShouldPayAmount.Amount - receivedAmount.Amount
218                 }
219         }
220         return receivedAmounts, priceDiffs
221 }
222
223 // IsMatched check does the orders can be exchange
224 func IsMatched(orders []*common.Order) bool {
225         sortedOrders := sortOrders(orders)
226         if len(sortedOrders) == 0 {
227                 return false
228         }
229
230         product := big.NewRat(1, 1)
231         for _, order := range orders {
232                 product.Mul(product, big.NewRat(order.RatioNumerator, order.RatioDenominator))
233         }
234         one := big.NewRat(1, 1)
235         return product.Cmp(one) <= 0
236 }
237
238 // MakerFlags return a slice of array indicate whether orders[i] is maker
239 func MakerFlags(orders []*common.Order) []bool {
240         isMakers := make([]bool, len(orders))
241         for i, order := range orders {
242                 isMakers[i] = isMaker(order, orders[calcOppositeIndex(i, len(orders))])
243         }
244         return isMakers
245 }
246
247 func isMaker(order, oppositeOrder *common.Order) bool {
248         // old version of order's block height and tx sequence is 0
249         if order.BlockHeight == 0 && oppositeOrder.BlockHeight != 0 {
250                 return true
251         }
252         if order.BlockHeight != 0 && oppositeOrder.BlockHeight == 0 {
253                 return false
254         }
255         if order.BlockHeight == 0 && oppositeOrder.BlockHeight == 0 {
256                 return order.UTXOHash().String() < oppositeOrder.UTXOHash().String()
257         }
258         if order.BlockHeight == oppositeOrder.BlockHeight {
259                 return order.Sequence < oppositeOrder.Sequence
260         }
261         return order.BlockHeight < oppositeOrder.BlockHeight
262 }
263
264 func setMatchTxArguments(txInput *types.TxInput, isPartialTrade bool, position int, receiveAmounts uint64, isMaker bool) {
265         traderType := contract.Taker
266         if isMaker {
267                 traderType = contract.Maker
268         }
269
270         var arguments [][]byte
271         if isPartialTrade {
272                 arguments = [][]byte{vm.Int64Bytes(int64(receiveAmounts)), vm.Int64Bytes(int64(position)), vm.Int64Bytes(contract.PartialTradeClauseSelector), vm.Int64Bytes(traderType)}
273         } else {
274                 arguments = [][]byte{vm.Int64Bytes(int64(position)), vm.Int64Bytes(contract.FullTradeClauseSelector), vm.Int64Bytes(traderType)}
275         }
276         txInput.SetArguments(arguments)
277 }
278
279 func sortOrders(orders []*common.Order) []*common.Order {
280         if len(orders) == 0 {
281                 return nil
282         }
283
284         orderMap := make(map[bc.AssetID]*common.Order)
285         firstOrder := orders[0]
286         for i := 1; i < len(orders); i++ {
287                 orderMap[*orders[i].FromAssetID] = orders[i]
288         }
289
290         sortedOrders := []*common.Order{firstOrder}
291         for order := firstOrder; *order.ToAssetID != *firstOrder.FromAssetID; {
292                 nextOrder, ok := orderMap[*order.ToAssetID]
293                 if !ok {
294                         return nil
295                 }
296
297                 sortedOrders = append(sortedOrders, nextOrder)
298                 order = nextOrder
299         }
300         return sortedOrders
301 }
302
303 func validateTradePairs(tradePairs []*common.TradePair) error {
304         if len(tradePairs) < 2 {
305                 return errors.New("size of trade pairs at least 2")
306         }
307
308         assetMap := make(map[string]bool)
309         for _, tradePair := range tradePairs {
310                 assetMap[tradePair.FromAssetID.String()] = true
311                 if *tradePair.FromAssetID == *tradePair.ToAssetID {
312                         return errors.New("from asset id can't equal to asset id")
313                 }
314         }
315
316         for _, tradePair := range tradePairs {
317                 key := tradePair.ToAssetID.String()
318                 if _, ok := assetMap[key]; !ok {
319                         return errors.New("invalid trade pairs")
320                 }
321                 delete(assetMap, key)
322         }
323         return nil
324 }