OSDN Git Service

opt code
[bytom/vapor.git] / application / mov / match / match.go
index c09ec88..3dc4f6f 100644 (file)
@@ -6,7 +6,6 @@ import (
 
        "github.com/vapor/application/mov/common"
        "github.com/vapor/application/mov/contract"
-       "github.com/vapor/application/mov/database"
        "github.com/vapor/consensus/segwit"
        "github.com/vapor/errors"
        vprMath "github.com/vapor/math"
@@ -16,15 +15,14 @@ import (
        "github.com/vapor/protocol/vm/vmutil"
 )
 
-const maxFeeRate = 0.05
-
 type Engine struct {
        orderTable  *OrderTable
+       maxFeeRate  float64
        nodeProgram []byte
 }
 
-func NewEngine(movStore database.MovStore, nodeProgram []byte) *Engine {
-       return &Engine{orderTable: NewOrderTable(movStore), nodeProgram: nodeProgram}
+func NewEngine(orderTable *OrderTable, maxFeeRate float64, nodeProgram []byte) *Engine {
+       return &Engine{orderTable: orderTable, maxFeeRate: maxFeeRate, nodeProgram: nodeProgram}
 }
 
 func (e *Engine) HasMatchedTx(tradePairs ...*common.TradePair) bool {
@@ -37,28 +35,18 @@ func (e *Engine) HasMatchedTx(tradePairs ...*common.TradePair) bool {
                return false
        }
 
-       for i, order := range orders {
-               if canNotBeMatched(order, orders[getOppositeIndex(len(orders), i)]) {
-                       return false
-               }
-       }
-       return true
+       return isMatched(orders)
 }
 
 // NextMatchedTx return the next matchable transaction by the specified trade pairs
 // the size of trade pairs at least 2, and the sequence of trade pairs can form a loop
 // for example, [assetA -> assetB, assetB -> assetC, assetC -> assetA]
 func (e *Engine) NextMatchedTx(tradePairs ...*common.TradePair) (*types.Tx, error) {
-       if err := validateTradePairs(tradePairs); err != nil {
-               return nil, err
-       }
-
-       orders := e.peekOrders(tradePairs)
-       if len(orders) == 0 {
-               return nil, errors.New("no order for the specified trade pair in the order table")
+       if !e.HasMatchedTx(tradePairs...) {
+               return nil, errors.New("the specified trade pairs can not be matched")
        }
 
-       tx, err := e.buildMatchTx(orders)
+       tx, err := e.buildMatchTx(e.peekOrders(tradePairs))
        if err != nil {
                return nil, err
        }
@@ -100,9 +88,14 @@ func validateTradePairs(tradePairs []*common.TradePair) error {
        return nil
 }
 
-func canNotBeMatched(order, oppositeOrder *common.Order) bool {
-       rate := 1 / order.Rate
-       return rate < oppositeOrder.Rate
+func isMatched(orders []*common.Order) bool {
+       for i, order := range orders {
+               opposisteOrder := orders[getOppositeIndex(len(orders), i)]
+               if 1 / order.Rate < opposisteOrder.Rate {
+                       return false
+               }
+       }
+       return true
 }
 
 func (e *Engine) buildMatchTx(orders []*common.Order) (*types.Tx, error) {
@@ -139,7 +132,7 @@ func addMatchTxOutput(txData *types.TxData, txInput *types.TxInput, order *commo
        requestAmount := calcRequestAmount(order.Utxo.Amount, contractArgs)
        receiveAmount := vprMath.MinUint64(requestAmount, oppositeAmount)
        shouldPayAmount := CalcShouldPayAmount(receiveAmount, contractArgs)
-       isPartialTrade := order.Utxo.Amount > shouldPayAmount
+       isPartialTrade := requestAmount > receiveAmount
 
        setMatchTxArguments(txInput, isPartialTrade, len(txData.Outputs), receiveAmount)
        txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.ToAssetID, receiveAmount, contractArgs.SellerProgram))
@@ -150,17 +143,17 @@ func addMatchTxOutput(txData *types.TxData, txInput *types.TxInput, order *commo
 }
 
 func (e *Engine) addMatchTxFeeOutput(txData *types.TxData) error {
-       feeAssetAmountMap, err := CalcFeeFromMatchedTx(txData)
+       txFee, err := CalcMatchedTxFee(txData, e.maxFeeRate)
        if err != nil {
                return err
        }
 
-       for feeAssetID, amount := range feeAssetAmountMap {
+       for feeAssetID, amount := range txFee {
                var reminder int64 = 0
-               feeAmount := amount.payableFeeAmount
-               if amount.payableFeeAmount > amount.maxFeeAmount {
-                       feeAmount = amount.maxFeeAmount
-                       reminder = amount.payableFeeAmount - amount.maxFeeAmount
+               feeAmount := amount.FeeAmount
+               if amount.FeeAmount > amount.MaxFeeAmount {
+                       feeAmount = amount.MaxFeeAmount
+                       reminder = amount.FeeAmount - amount.MaxFeeAmount
                }
                txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(feeAssetID, uint64(feeAmount), e.nodeProgram))
 
@@ -222,25 +215,26 @@ func getOppositeIndex(size int, selfIdx int) int {
        return oppositeIdx
 }
 
-type feeAmount struct {
-       maxFeeAmount     int64
-       payableFeeAmount int64
+type MatchedTxFee struct {
+       MaxFeeAmount int64
+       FeeAmount    int64
 }
 
-func CalcFeeFromMatchedTx(txData *types.TxData) (map[bc.AssetID]*feeAmount, error) {
-       assetAmountMap := make(map[bc.AssetID]*feeAmount)
-       for _, input := range txData.Inputs {
-               assetAmountMap[input.AssetID()] = &feeAmount{}
-       }
+func CalcMatchedTxFee(txData *types.TxData, maxFeeRate float64) (map[bc.AssetID]*MatchedTxFee, error) {
+       assetFeeMap := make(map[bc.AssetID]*MatchedTxFee)
+       sellerProgramMap := make(map[string]bool)
+       assetInputMap := make(map[bc.AssetID]uint64)
 
-       receiveOutputMap := make(map[string]*types.TxOutput)
-       for _, output := range txData.Outputs {
-               // minus the amount of the re-order
-               if segwit.IsP2WMCScript(output.ControlProgram()) {
-                       assetAmountMap[*output.AssetAmount().AssetId].payableFeeAmount -= int64(output.AssetAmount().Amount)
-               } else {
-                       receiveOutputMap[hex.EncodeToString(output.ControlProgram())] = output
+       for _, input := range txData.Inputs {
+               assetFeeMap[input.AssetID()] = &MatchedTxFee{}
+               assetFeeMap[input.AssetID()].FeeAmount += int64(input.AssetAmount().Amount)
+               contractArgs, err := segwit.DecodeP2WMCProgram(input.ControlProgram())
+               if err != nil {
+                       return nil, err
                }
+
+               sellerProgramMap[hex.EncodeToString(contractArgs.SellerProgram)] = true
+               assetInputMap[input.AssetID()] = input.Amount()
        }
 
        for _, input := range txData.Inputs {
@@ -249,22 +243,30 @@ func CalcFeeFromMatchedTx(txData *types.TxData) (map[bc.AssetID]*feeAmount, erro
                        return nil, err
                }
 
-               assetAmountMap[input.AssetID()].payableFeeAmount += int64(input.AssetAmount().Amount)
-               receiveOutput, ok := receiveOutputMap[hex.EncodeToString(contractArgs.SellerProgram)]
-               if !ok {
-                       return nil, errors.New("the input of matched tx has no receive output")
-               }
-
-               assetAmountMap[*receiveOutput.AssetAmount().AssetId].payableFeeAmount -= int64(receiveOutput.AssetAmount().Amount)
-               assetAmountMap[input.AssetID()].maxFeeAmount = CalcMaxFeeAmount(CalcShouldPayAmount(receiveOutput.AssetAmount().Amount, contractArgs))
+               oppositeAmount := assetInputMap[contractArgs.RequestedAsset]
+               receiveAmount := vprMath.MinUint64(calcRequestAmount(input.Amount(), contractArgs), oppositeAmount)
+               assetFeeMap[input.AssetID()].MaxFeeAmount = CalcMaxFeeAmount(CalcShouldPayAmount(receiveAmount, contractArgs), maxFeeRate)
        }
 
-       for assetID, amount := range assetAmountMap {
-               if amount.payableFeeAmount == 0 {
-                       delete(assetAmountMap, assetID)
+       for _, output := range txData.Outputs {
+               // minus the amount of the re-order
+               if segwit.IsP2WMCScript(output.ControlProgram()) {
+                       assetFeeMap[*output.AssetAmount().AssetId].FeeAmount -= int64(output.AssetAmount().Amount)
+               }
+               // minus the amount of seller's receiving output
+               if _, ok := sellerProgramMap[hex.EncodeToString(output.ControlProgram())]; ok {
+                       assetID := *output.AssetAmount().AssetId
+                       fee, ok := assetFeeMap[assetID]
+                       if !ok {
+                               continue
+                       }
+                       fee.FeeAmount -= int64(output.AssetAmount().Amount)
+                       if fee.FeeAmount == 0 {
+                               delete(assetFeeMap, assetID)
+                       }
                }
        }
-       return assetAmountMap, nil
+       return assetFeeMap, nil
 }
 
 func calcRequestAmount(fromAmount uint64, contractArg *vmutil.MagneticContractArgs) uint64 {
@@ -272,9 +274,9 @@ func calcRequestAmount(fromAmount uint64, contractArg *vmutil.MagneticContractAr
 }
 
 func CalcShouldPayAmount(receiveAmount uint64, contractArg *vmutil.MagneticContractArgs) uint64 {
-       return uint64(math.Ceil(float64(receiveAmount) * float64(contractArg.RatioDenominator) / float64(contractArg.RatioNumerator)))
+       return uint64(math.Floor(float64(receiveAmount) * float64(contractArg.RatioDenominator) / float64(contractArg.RatioNumerator)))
 }
 
-func CalcMaxFeeAmount(shouldPayAmount uint64) int64 {
+func CalcMaxFeeAmount(shouldPayAmount uint64, maxFeeRate float64) int64 {
        return int64(math.Ceil(float64(shouldPayAmount) * maxFeeRate))
 }