OSDN Git Service

opt code
authorshenao78 <shenao.78@163.com>
Tue, 22 Oct 2019 10:20:02 +0000 (18:20 +0800)
committershenao78 <shenao.78@163.com>
Tue, 22 Oct 2019 10:20:02 +0000 (18:20 +0800)
application/mov/match/match.go
application/mov/match/match_test.go
application/mov/util/util.go [new file with mode: 0644]

index 32061b6..d2ca96c 100644 (file)
@@ -6,7 +6,9 @@ import (
 
        "github.com/vapor/application/mov/common"
        "github.com/vapor/application/mov/database"
+       "github.com/vapor/application/mov/util"
        "github.com/vapor/consensus/segwit"
+       "github.com/vapor/errors"
        vprMath "github.com/vapor/math"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
@@ -14,7 +16,7 @@ import (
        "github.com/vapor/protocol/vm/vmutil"
 )
 
-var maxFeeRate = 0.1
+var maxFeeRate = 0.05
 
 type Engine struct {
        orderTable  *OrderTable
@@ -27,67 +29,84 @@ func NewEngine(movStore database.MovStore, nodeProgram []byte) *Engine {
 
 // NextMatchedTx match two opposite pending orders.
 // for example, the buy orders want change A with B, then the sell orders must change B with A.
-func (e *Engine) NextMatchedTx(buyTradePair, sellTradePair *common.TradePair) (*types.Tx, error) {
-       buyOrder := e.orderTable.PeekOrder(buyTradePair)
-       sellOrder := e.orderTable.PeekOrder(sellTradePair)
-       if buyOrder == nil || sellOrder == nil {
-               return nil, nil
+func (e *Engine) NextMatchedTx(tradePairs  ...*common.TradePair) (*types.Tx, error) {
+       if err := validateTradePairs(tradePairs); err != nil {
+               return nil, err
        }
 
-       buyContractArgs, err := segwit.DecodeP2WMCProgram(buyOrder.Utxo.ControlProgram)
-       if err != nil {
-               return nil, err
+       var orders []*common.Order
+       for _, tradePair := range tradePairs {
+               order := e.orderTable.PeekOrder(tradePair)
+               if order == nil {
+                       return nil, nil
+               }
+
+               orders = append(orders, order)
        }
 
-       sellContractArgs, err := segwit.DecodeP2WMCProgram(sellOrder.Utxo.ControlProgram)
+       tx, err := e.buildMatchTx(orders)
        if err != nil {
                return nil, err
        }
 
-       if canNotBeMatched(buyOrder, sellOrder, buyContractArgs, sellContractArgs) {
+       if tx == nil {
                return nil, nil
        }
 
-       tx, err := e.buildMatchTx(buyOrder, sellOrder, buyContractArgs, sellContractArgs)
-       if err != nil {
-               return nil, err
+       for _, tradePair := range tradePairs {
+               e.orderTable.PopOrder(tradePair)
        }
-
-       e.orderTable.PopOrder(buyTradePair)
-       e.orderTable.PopOrder(sellTradePair)
        if err := addPartialTradeOrder(tx, e.orderTable); err != nil {
                return nil, err
        }
        return tx, nil
 }
 
-func canNotBeMatched(buyOrder, sellOrder *common.Order, buyContractArgs, sellContractArgs *vmutil.MagneticContractArgs) bool {
-       if buyOrder.ToAssetID != sellOrder.FromAssetID || sellOrder.ToAssetID != buyOrder.FromAssetID {
-               return false
+func validateTradePairs(tradePairs []*common.TradePair) error {
+       if len(tradePairs) < 2 {
+               return errors.New("size of trade pairs at least 2")
        }
 
-       if buyContractArgs.RatioNumerator == 0 || sellContractArgs.RatioDenominator == 0 {
-               return false
+       for i, tradePair:= range tradePairs {
+               oppositeTradePair := tradePairs[getOppositeIndex(len(tradePairs), i)]
+               if *tradePair.FromAssetID != *oppositeTradePair.ToAssetID || *tradePair.ToAssetID != *oppositeTradePair.FromAssetID {
+                       return errors.New("specified trade pairs is invalid")
+               }
        }
-
-       buyRate := big.NewFloat(0).Quo(big.NewFloat(0).SetInt64(buyContractArgs.RatioDenominator), big.NewFloat(0).SetInt64(buyContractArgs.RatioNumerator))
-       sellRate := big.NewFloat(0).Quo(big.NewFloat(0).SetInt64(sellContractArgs.RatioNumerator), big.NewFloat(0).SetInt64(sellContractArgs.RatioDenominator))
-       return buyRate.Cmp(sellRate) < 0
+       return nil
 }
 
-func (e *Engine) buildMatchTx(buyOrder, sellOrder *common.Order, buyContractArgs, sellContractArgs *vmutil.MagneticContractArgs) (*types.Tx, error) {
+func (e *Engine) buildMatchTx(orders []*common.Order) (*types.Tx, error) {
        txData := &types.TxData{Version: 1}
-       txData.Inputs = append(txData.Inputs, types.NewSpendInput(nil, *buyOrder.Utxo.SourceID, *buyOrder.FromAssetID, buyOrder.Utxo.Amount, buyOrder.Utxo.SourcePos, buyOrder.Utxo.ControlProgram))
-       txData.Inputs = append(txData.Inputs, types.NewSpendInput(nil, *sellOrder.Utxo.SourceID, *sellOrder.FromAssetID, sellOrder.Utxo.Amount, sellOrder.Utxo.SourcePos, sellOrder.Utxo.ControlProgram))
+       var partialTradeStatus []bool
+       var receiveAmounts []uint64
 
-       isBuyPartialTrade, buyReceiveAmount, buyShouldPayAmount := addMatchTxOutput(txData, buyOrder, buyContractArgs, sellOrder.Utxo.Amount)
-       isSellPartialTrade, sellReceiveAmount, sellShouldPayAmount := addMatchTxOutput(txData, sellOrder, sellContractArgs, buyOrder.Utxo.Amount)
+       for i, order := range orders {
+               contractArgs, err := segwit.DecodeP2WMCProgram(order.Utxo.ControlProgram)
+               if err != nil {
+                       return nil, err
+               }
+
+               oppositeOrder := orders[getOppositeIndex(len(orders), i)]
+               oppositeContractArgs, err := segwit.DecodeP2WMCProgram(oppositeOrder.Utxo.ControlProgram)
+               if err != nil {
+                       return nil, err
+               }
 
-       participantPrograms := [][]byte{buyContractArgs.SellerProgram, sellContractArgs.SellerProgram}
-       e.addMatchTxFeeOutput(txData, buyShouldPayAmount, sellReceiveAmount, *buyOrder.FromAssetID, participantPrograms)
-       e.addMatchTxFeeOutput(txData, sellShouldPayAmount, buyReceiveAmount, *sellOrder.FromAssetID, participantPrograms)
+               if canNotBeMatched(contractArgs, oppositeContractArgs) {
+                       return nil, nil
+               }
 
-       setMatchTxArguments(txData, []bool{isBuyPartialTrade, isSellPartialTrade}, []uint64{buyReceiveAmount, sellReceiveAmount})
+               txData.Inputs = append(txData.Inputs, types.NewSpendInput(nil, *order.Utxo.SourceID, *order.FromAssetID, order.Utxo.Amount, order.Utxo.SourcePos, order.Utxo.ControlProgram))
+               isPartialTrade, receiveAmount := addMatchTxOutput(txData, order, contractArgs, oppositeOrder.Utxo.Amount)
+               partialTradeStatus = append(partialTradeStatus, isPartialTrade)
+               receiveAmounts = append(receiveAmounts, receiveAmount)
+       }
+
+       setMatchTxArguments(txData, partialTradeStatus, receiveAmounts)
+       if err := e.addMatchTxFeeOutput(txData); err != nil {
+               return nil, err
+       }
 
        byteData, err := txData.MarshalText()
        if err != nil {
@@ -99,8 +118,18 @@ func (e *Engine) buildMatchTx(buyOrder, sellOrder *common.Order, buyContractArgs
        return tx, nil
 }
 
+func canNotBeMatched(contractArgs, oppositeContractArgs *vmutil.MagneticContractArgs) bool {
+       if contractArgs.RatioNumerator == 0 || oppositeContractArgs.RatioDenominator == 0 {
+               return false
+       }
+
+       buyRate := big.NewFloat(0).Quo(big.NewFloat(0).SetInt64(contractArgs.RatioDenominator), big.NewFloat(0).SetInt64(contractArgs.RatioNumerator))
+       sellRate := big.NewFloat(0).Quo(big.NewFloat(0).SetInt64(oppositeContractArgs.RatioNumerator), big.NewFloat(0).SetInt64(oppositeContractArgs.RatioDenominator))
+       return buyRate.Cmp(sellRate) < 0
+}
+
 // addMatchTxOutput return whether partial matched
-func addMatchTxOutput(txData *types.TxData, order *common.Order, contractArgs *vmutil.MagneticContractArgs, oppositeAmount uint64) (bool, uint64, uint64) {
+func addMatchTxOutput(txData *types.TxData, order *common.Order, contractArgs *vmutil.MagneticContractArgs, oppositeAmount uint64) (bool, uint64) {
        requestAmount := calcRequestAmount(order.Utxo.Amount, contractArgs)
        receiveAmount := vprMath.MinUint64(requestAmount, oppositeAmount)
        shouldPayAmount := CalcShouldPayAmount(receiveAmount, contractArgs)
@@ -108,37 +137,46 @@ func addMatchTxOutput(txData *types.TxData, order *common.Order, contractArgs *v
        txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.ToAssetID, receiveAmount, contractArgs.SellerProgram))
        if order.Utxo.Amount > shouldPayAmount {
                txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.FromAssetID, order.Utxo.Amount-shouldPayAmount, order.Utxo.ControlProgram))
-               return true, receiveAmount, shouldPayAmount
+               return true, receiveAmount
        }
-       return false, receiveAmount, shouldPayAmount
+       return false, receiveAmount
 }
 
-func (e *Engine) addMatchTxFeeOutput(txData *types.TxData, shouldPayAmount, oppositeReceiveAmount uint64, fromAssetID bc.AssetID, participantPrograms [][]byte) {
-       if shouldPayAmount <= oppositeReceiveAmount {
-               return
-       }
-       feeAmount := shouldPayAmount - oppositeReceiveAmount
-       var reminder uint64 = 0
-       maxFeeAmount := CalcMaxFeeAmount(shouldPayAmount)
-       if feeAmount > maxFeeAmount {
-               feeAmount = maxFeeAmount
-               reminder = feeAmount - maxFeeAmount
+func (e *Engine) addMatchTxFeeOutput(txData *types.TxData) error {
+       feeAssetAmountMap, err := CalcFeeFromMatchedTx(txData)
+       if err != nil {
+               return err
        }
-       txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(fromAssetID, feeAmount, e.nodeProgram))
 
-       // There is the remaining amount after paying the handling fee, assign it evenly to participants in the transaction
-       averageAmount := reminder / uint64(len(participantPrograms))
-       if averageAmount == 0 {
-               averageAmount = 1
-       }
-       for i := 0; i < len(participantPrograms) && reminder > 0; i++ {
-               if i == len(participantPrograms)-1 {
-                       txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(fromAssetID, reminder, participantPrograms[i]))
-               } else {
-                       txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(fromAssetID, averageAmount, participantPrograms[i]))
+       for feeAssetID, amount := range feeAssetAmountMap {
+               var reminder uint64 = 0
+               feeAmount := amount.payableFeeAmount
+               if amount.payableFeeAmount > amount.maxFeeAmount {
+                       feeAmount = amount.maxFeeAmount
+                       reminder = amount.payableFeeAmount - amount.maxFeeAmount
+               }
+               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(feeAssetID, feeAmount, e.nodeProgram))
+
+               // There is the remaining amount after paying the handling fee, assign it evenly to participants in the transaction
+               averageAmount := reminder / uint64(len(txData.Inputs))
+               if averageAmount == 0 {
+                       averageAmount = 1
+               }
+               for i := 0; i < len(txData.Inputs) && reminder > 0; i++ {
+                       contractArgs, err := segwit.DecodeP2WMCProgram(txData.Inputs[i].ControlProgram())
+                       if err != nil {
+                               return err
+                       }
+
+                       if i == len(txData.Inputs)-1 {
+                               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(feeAssetID, reminder, contractArgs.SellerProgram))
+                       } else {
+                               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(feeAssetID, averageAmount, contractArgs.SellerProgram))
+                       }
+                       reminder -= averageAmount
                }
-               reminder -= averageAmount
        }
+       return nil
 }
 
 func setMatchTxArguments(txData *types.TxData, partialTradeStatus []bool, receiveAmounts []uint64) {
@@ -174,6 +212,14 @@ func addPartialTradeOrder(tx *types.Tx, orderTable *OrderTable) error {
        return nil
 }
 
+func getOppositeIndex(size int, selfIdx int) int {
+       oppositeIdx := selfIdx + 1
+       if selfIdx >= size - 1 {
+               oppositeIdx = 0
+       }
+       return oppositeIdx
+}
+
 func calcRequestAmount(fromAmount uint64, contractArg *vmutil.MagneticContractArgs) uint64 {
        return uint64(int64(fromAmount) * contractArg.RatioNumerator / contractArg.RatioDenominator)
 }
@@ -185,3 +231,46 @@ func CalcShouldPayAmount(receiveAmount uint64, contractArg *vmutil.MagneticContr
 func CalcMaxFeeAmount(shouldPayAmount uint64) uint64 {
        return uint64(math.Ceil(float64(shouldPayAmount) * maxFeeRate))
 }
+
+type feeAmount struct {
+       maxFeeAmount     uint64
+       payableFeeAmount uint64
+}
+
+func CalcFeeFromMatchedTx(txData *types.TxData) (map[bc.AssetID]*feeAmount, error) {
+       assetAmountMap := make(map[bc.AssetID]*feeAmount)
+       for _, input := range txData.Inputs {
+               assetAmountMap[input.AssetID()] = &feeAmount{}
+       }
+
+       for _, input := range txData.Inputs {
+               assetAmountMap[input.AssetID()].payableFeeAmount += input.AssetAmount().Amount
+               outputPos, err := util.GetTradeReceivePosition(input)
+               if err != nil {
+                       return nil, err
+               }
+
+               receiveOutput := txData.Outputs[outputPos]
+               assetAmountMap[*receiveOutput.AssetAmount().AssetId].payableFeeAmount -= receiveOutput.AssetAmount().Amount
+               contractArgs, err := segwit.DecodeP2WMCProgram(input.ControlProgram())
+               if err != nil {
+                       return nil, err
+               }
+
+               assetAmountMap[input.AssetID()].maxFeeAmount = CalcMaxFeeAmount(CalcShouldPayAmount(receiveOutput.AssetAmount().Amount, contractArgs))
+       }
+
+       for _, output := range txData.Outputs {
+               // minus the amount of the re-order
+               if segwit.IsP2WMCScript(output.ControlProgram()) {
+                       assetAmountMap[*output.AssetAmount().AssetId].payableFeeAmount -= output.AssetAmount().Amount
+               }
+       }
+
+       for assetID, amount := range assetAmountMap {
+               if amount.payableFeeAmount == 0 {
+                       delete(assetAmountMap, assetID)
+               }
+       }
+       return assetAmountMap, nil
+}
index 709fc14..bdde9ab 100644 (file)
@@ -153,7 +153,10 @@ func TestGenerateMatchedTxs(t *testing.T) {
                                                // re-order
                                                types.NewIntraChainOutput(*orders[4].FromAssetID, 270, orders[4].Utxo.ControlProgram),
                                                // fee
-                                               types.NewIntraChainOutput(*orders[4].FromAssetID, 40, []byte{0x51}),
+                                               types.NewIntraChainOutput(*orders[4].FromAssetID, 27, []byte{0x51}),
+                                               // refund
+                                               types.NewIntraChainOutput(*orders[4].FromAssetID, 6, testutil.MustDecodeHexString("51")),
+                                               types.NewIntraChainOutput(*orders[4].FromAssetID, 7, testutil.MustDecodeHexString("55")),
                                        },
                                },
                                {
diff --git a/application/mov/util/util.go b/application/mov/util/util.go
new file mode 100644 (file)
index 0000000..6d6f87a
--- /dev/null
@@ -0,0 +1,36 @@
+package util
+
+import (
+       "encoding/hex"
+
+       "github.com/vapor/errors"
+       "github.com/vapor/protocol/bc/types"
+       "github.com/vapor/protocol/vm"
+)
+
+func IsCancelClauseSelector(input *types.TxInput) bool {
+       return len(input.Arguments()) == 3 && hex.EncodeToString(input.Arguments()[2]) == hex.EncodeToString(vm.Int64Bytes(2))
+}
+
+func IsTradeClauseSelector(input *types.TxInput) bool {
+       return IsPartialTradeClauseSelector(input) || IsFullTradeClauseSelector(input)
+}
+
+func IsPartialTradeClauseSelector(input *types.TxInput) bool {
+       return len(input.Arguments()) == 3 && hex.EncodeToString(input.Arguments()[2]) == hex.EncodeToString(vm.Int64Bytes(0))
+}
+
+func IsFullTradeClauseSelector(input *types.TxInput) bool {
+       return len(input.Arguments()) == 2 && hex.EncodeToString(input.Arguments()[1]) == hex.EncodeToString(vm.Int64Bytes(1))
+}
+
+func GetTradeReceivePosition(input *types.TxInput) (int64, error) {
+       if IsPartialTradeClauseSelector(input) {
+               return vm.AsInt64(input.Arguments()[1])
+       }
+
+       if IsFullTradeClauseSelector(input) {
+               return vm.AsInt64(input.Arguments()[0])
+       }
+       return 0, errors.New("non trade transaction input")
+}