OSDN Git Service

match_fee_strategy (#506)
authorPaladz <yzhu101@uottawa.ca>
Thu, 5 Mar 2020 15:12:28 +0000 (23:12 +0800)
committerGitHub <noreply@github.com>
Thu, 5 Mar 2020 15:12:28 +0000 (23:12 +0800)
* match_fee_strategy

* rename variable

* opt code

* rename

* adjust order

* add test case

application/mov/match/match.go
application/mov/match/match_fee.go [new file with mode: 0644]
application/mov/match/match_test.go
application/mov/mock/mock.go
application/mov/mov_core.go
application/mov/mov_core_test.go

index bfc3943..7b81e1a 100644 (file)
@@ -1,8 +1,6 @@
 package match
 
 import (
-       "encoding/hex"
-       "math"
        "math/big"
 
        "github.com/bytom/vapor/application/mov/common"
@@ -13,19 +11,18 @@ import (
        "github.com/bytom/vapor/protocol/bc"
        "github.com/bytom/vapor/protocol/bc/types"
        "github.com/bytom/vapor/protocol/vm"
-       "github.com/bytom/vapor/protocol/vm/vmutil"
 )
 
 // Engine is used to generate math transactions
 type Engine struct {
        orderBook     *OrderBook
-       maxFeeRate    float64
+       feeStrategy   FeeStrategy
        rewardProgram []byte
 }
 
 // NewEngine return a new Engine
-func NewEngine(orderBook *OrderBook, maxFeeRate float64, rewardProgram []byte) *Engine {
-       return &Engine{orderBook: orderBook, maxFeeRate: maxFeeRate, rewardProgram: rewardProgram}
+func NewEngine(orderBook *OrderBook, feeStrategy FeeStrategy, rewardProgram []byte) *Engine {
+       return &Engine{orderBook: orderBook, feeStrategy: feeStrategy, rewardProgram: rewardProgram}
 }
 
 // HasMatchedTx check does the input trade pair can generate a match deal
@@ -65,39 +62,18 @@ func (e *Engine) NextMatchedTx(tradePairs ...*common.TradePair) (*types.Tx, erro
        return tx, nil
 }
 
-func (e *Engine) addMatchTxFeeOutput(txData *types.TxData) error {
-       txFee, err := CalcMatchedTxFee(txData, e.maxFeeRate)
-       if err != nil {
-               return err
+func (e *Engine) addMatchTxFeeOutput(txData *types.TxData, refundAmounts, feeAmounts []*bc.AssetAmount) error {
+       for _, feeAmount := range feeAmounts {
+               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*feeAmount.AssetId, feeAmount.Amount, e.rewardProgram))
        }
 
-       for assetID, matchTxFee := range txFee {
-               feeAmount, reminder := matchTxFee.FeeAmount, int64(0)
-               if matchTxFee.FeeAmount > matchTxFee.MaxFeeAmount {
-                       feeAmount = matchTxFee.MaxFeeAmount
-                       reminder = matchTxFee.FeeAmount - matchTxFee.MaxFeeAmount
-               }
-               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(assetID, uint64(feeAmount), e.rewardProgram))
-
-               // There is the remaining amount after paying the handling fee, assign it evenly to participants in the transaction
-               averageAmount := reminder / int64(len(txData.Inputs))
-               if averageAmount == 0 {
-                       averageAmount = 1
+       for i, refundAmount := range refundAmounts {
+               contractArgs, err := segwit.DecodeP2WMCProgram(txData.Inputs[i].ControlProgram())
+               if err != nil {
+                       return err
                }
 
-               for i := 0; i < len(txData.Inputs) && reminder > 0; i++ {
-                       contractArgs, err := segwit.DecodeP2WMCProgram(txData.Inputs[i].ControlProgram())
-                       if err != nil {
-                               return err
-                       }
-
-                       if i == len(txData.Inputs)-1 {
-                               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(assetID, uint64(reminder), contractArgs.SellerProgram))
-                       } else {
-                               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(assetID, uint64(averageAmount), contractArgs.SellerProgram))
-                       }
-                       reminder -= averageAmount
-               }
+               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*refundAmount.AssetId, refundAmount.Amount, contractArgs.SellerProgram))
        }
        return nil
 }
@@ -120,17 +96,18 @@ func (e *Engine) addPartialTradeOrder(tx *types.Tx) error {
 
 func (e *Engine) buildMatchTx(orders []*common.Order) (*types.Tx, error) {
        txData := &types.TxData{Version: 1}
-       for i, order := range orders {
+       for _, order := range orders {
                input := types.NewSpendInput(nil, *order.Utxo.SourceID, *order.FromAssetID, order.Utxo.Amount, order.Utxo.SourcePos, order.Utxo.ControlProgram)
                txData.Inputs = append(txData.Inputs, input)
+       }
 
-               oppositeOrder := orders[calcOppositeIndex(len(orders), i)]
-               if err := addMatchTxOutput(txData, input, order, oppositeOrder.Utxo.Amount); err != nil {
-                       return nil, err
-               }
+       receivedAmounts, priceDiff := CalcReceivedAmount(orders)
+       receivedAfterDeductFee, refundAmounts, feeAmounts := e.feeStrategy.Allocate(receivedAmounts, priceDiff)
+       if err := addMatchTxOutput(txData, orders, receivedAmounts, receivedAfterDeductFee); err != nil {
+               return nil, err
        }
 
-       if err := e.addMatchTxFeeOutput(txData); err != nil {
+       if err := e.addMatchTxFeeOutput(txData, refundAmounts, feeAmounts); err != nil {
                return nil, err
        }
 
@@ -143,94 +120,77 @@ func (e *Engine) buildMatchTx(orders []*common.Order) (*types.Tx, error) {
        return types.NewTx(*txData), nil
 }
 
-// MatchedTxFee is object to record the mov tx's fee information
-type MatchedTxFee struct {
-       MaxFeeAmount int64
-       FeeAmount    int64
-}
-
-// CalcMatchedTxFee is used to calculate tx's MatchedTxFees
-func CalcMatchedTxFee(txData *types.TxData, maxFeeRate float64) (map[bc.AssetID]*MatchedTxFee, error) {
-       assetFeeMap := make(map[bc.AssetID]*MatchedTxFee)
-       dealProgMaps := make(map[string]bool)
-
-       for _, input := range txData.Inputs {
-               assetFeeMap[input.AssetID()] = &MatchedTxFee{FeeAmount: int64(input.AssetAmount().Amount)}
-               contractArgs, err := segwit.DecodeP2WMCProgram(input.ControlProgram())
-               if err != nil {
-                       return nil, err
-               }
-
-               dealProgMaps[hex.EncodeToString(contractArgs.SellerProgram)] = true
-       }
-
-       for _, input := range txData.Inputs {
-               contractArgs, err := segwit.DecodeP2WMCProgram(input.ControlProgram())
+func addMatchTxOutput(txData *types.TxData, orders []*common.Order, receivedAmounts, receivedAfterDeductFee []*bc.AssetAmount) error {
+       for i, order := range orders {
+               contractArgs, err := segwit.DecodeP2WMCProgram(order.Utxo.ControlProgram)
                if err != nil {
-                       return nil, err
+                       return err
                }
 
-               oppositeAmount := uint64(assetFeeMap[contractArgs.RequestedAsset].FeeAmount)
-               receiveAmount := vprMath.MinUint64(CalcRequestAmount(input.Amount(), contractArgs), oppositeAmount)
-               assetFeeMap[input.AssetID()].MaxFeeAmount = calcMaxFeeAmount(calcShouldPayAmount(receiveAmount, contractArgs), maxFeeRate)
-       }
+               requestAmount := CalcRequestAmount(order.Utxo.Amount, contractArgs.RatioNumerator, contractArgs.RatioDenominator)
+               receivedAmount := receivedAmounts[i].Amount
+               shouldPayAmount := calcShouldPayAmount(receivedAmount, contractArgs.RatioNumerator, contractArgs.RatioDenominator)
+               isPartialTrade := requestAmount > receivedAmount
 
-       for _, output := range txData.Outputs {
-               assetAmount := output.AssetAmount()
-               if _, ok := dealProgMaps[hex.EncodeToString(output.ControlProgram())]; ok || segwit.IsP2WMCScript(output.ControlProgram()) {
-                       assetFeeMap[*assetAmount.AssetId].FeeAmount -= int64(assetAmount.Amount)
-                       if assetFeeMap[*assetAmount.AssetId].FeeAmount <= 0 {
-                               delete(assetFeeMap, *assetAmount.AssetId)
-                       }
+               setMatchTxArguments(txData.Inputs[i], isPartialTrade, len(txData.Outputs), receivedAfterDeductFee[i].Amount)
+               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.ToAssetID, receivedAfterDeductFee[i].Amount, contractArgs.SellerProgram))
+               if isPartialTrade {
+                       txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.FromAssetID, order.Utxo.Amount-shouldPayAmount, order.Utxo.ControlProgram))
                }
        }
-       return assetFeeMap, nil
+       return nil
 }
 
-func addMatchTxOutput(txData *types.TxData, txInput *types.TxInput, order *common.Order, oppositeAmount uint64) error {
-       contractArgs, err := segwit.DecodeP2WMCProgram(order.Utxo.ControlProgram)
-       if err != nil {
-               return err
-       }
-
-       requestAmount := CalcRequestAmount(order.Utxo.Amount, contractArgs)
-       receiveAmount := vprMath.MinUint64(requestAmount, oppositeAmount)
-       shouldPayAmount := calcShouldPayAmount(receiveAmount, contractArgs)
-       isPartialTrade := requestAmount > receiveAmount
-
-       setMatchTxArguments(txInput, isPartialTrade, len(txData.Outputs), receiveAmount)
-       txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.ToAssetID, receiveAmount, contractArgs.SellerProgram))
-       if isPartialTrade {
-               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.FromAssetID, order.Utxo.Amount-shouldPayAmount, order.Utxo.ControlProgram))
-       }
-       return nil
+func calcOppositeIndex(size int, selfIdx int) int {
+       return (selfIdx + 1) % size
 }
 
 // CalcRequestAmount is from amount * numerator / ratioDenominator
-func CalcRequestAmount(fromAmount uint64, contractArg *vmutil.MagneticContractArgs) uint64 {
+func CalcRequestAmount(fromAmount uint64, ratioNumerator, ratioDenominator int64) uint64 {
        res := big.NewInt(0).SetUint64(fromAmount)
-       res.Mul(res, big.NewInt(contractArg.RatioNumerator)).Quo(res, big.NewInt(contractArg.RatioDenominator))
+       res.Mul(res, big.NewInt(ratioNumerator)).Quo(res, big.NewInt(ratioDenominator))
        if !res.IsUint64() {
                return 0
        }
        return res.Uint64()
 }
 
-func calcShouldPayAmount(receiveAmount uint64, contractArg *vmutil.MagneticContractArgs) uint64 {
+func calcShouldPayAmount(receiveAmount uint64, ratioNumerator, ratioDenominator int64) uint64 {
        res := big.NewInt(0).SetUint64(receiveAmount)
-       res.Mul(res, big.NewInt(contractArg.RatioDenominator)).Quo(res, big.NewInt(contractArg.RatioNumerator))
+       res.Mul(res, big.NewInt(ratioDenominator)).Quo(res, big.NewInt(ratioNumerator))
        if !res.IsUint64() {
                return 0
        }
        return res.Uint64()
 }
 
-func calcMaxFeeAmount(shouldPayAmount uint64, maxFeeRate float64) int64 {
-       return int64(math.Ceil(float64(shouldPayAmount) * maxFeeRate))
-}
+// CalcReceivedAmount return amount of assets received by each participant in the matching transaction and the price difference
+func CalcReceivedAmount(orders []*common.Order) ([]*bc.AssetAmount, *bc.AssetAmount) {
+       priceDiff := &bc.AssetAmount{}
+       if len(orders) == 0 {
+               return nil, priceDiff
+       }
 
-func calcOppositeIndex(size int, selfIdx int) int {
-       return (selfIdx + 1) % size
+       var receivedAmounts, shouldPayAmounts []*bc.AssetAmount
+       for i, order := range orders {
+               requestAmount := CalcRequestAmount(order.Utxo.Amount, order.RatioNumerator, order.RatioDenominator)
+               oppositeOrder := orders[calcOppositeIndex(len(orders), i)]
+               receiveAmount := vprMath.MinUint64(oppositeOrder.Utxo.Amount, requestAmount)
+               shouldPayAmount := calcShouldPayAmount(receiveAmount, order.RatioNumerator, order.RatioDenominator)
+               receivedAmounts = append(receivedAmounts, &bc.AssetAmount{AssetId: order.ToAssetID, Amount: receiveAmount})
+               shouldPayAmounts = append(shouldPayAmounts, &bc.AssetAmount{AssetId: order.FromAssetID, Amount: shouldPayAmount})
+       }
+
+       for i, receivedAmount := range receivedAmounts {
+               oppositeShouldPayAmount := shouldPayAmounts[calcOppositeIndex(len(orders), i)]
+               if oppositeShouldPayAmount.Amount > receivedAmount.Amount {
+                       priceDiff.AssetId = oppositeShouldPayAmount.AssetId
+                       priceDiff.Amount = oppositeShouldPayAmount.Amount - receivedAmount.Amount
+                       // price differential can only produce once
+                       break
+               }
+       }
+       return receivedAmounts, priceDiff
 }
 
 // IsMatched check does the orders can be exchange
diff --git a/application/mov/match/match_fee.go b/application/mov/match/match_fee.go
new file mode 100644 (file)
index 0000000..6569ce1
--- /dev/null
@@ -0,0 +1,102 @@
+package match
+
+import (
+       "math"
+
+       "github.com/bytom/vapor/errors"
+       "github.com/bytom/vapor/protocol/bc"
+)
+
+var (
+       // ErrAmountOfFeeExceedMaximum represent The fee charged is exceeded the maximum
+       ErrAmountOfFeeExceedMaximum = errors.New("amount of fee greater than max fee amount")
+       // ErrFeeMoreThanOneAsset represent the fee charged can only have one asset
+       ErrFeeMoreThanOneAsset      = errors.New("fee can only be an asset")
+)
+
+// FeeStrategy used to indicate how to charge a matching fee
+type FeeStrategy interface {
+       // Allocate will allocate the price differential in matching transaction to the participants and the fee
+       // @param receiveAmounts the amount of assets that the participants in the matching transaction can received when no fee is considered
+       // @param priceDiff price differential of matching transaction
+       // @return the amount of assets that the participants in the matching transaction can received when fee is considered
+       // @return the amount of assets returned to the transaction participant when the fee exceeds a certain ratio
+       // @return the amount of fees
+       Allocate(receiveAmounts []*bc.AssetAmount, priceDiff *bc.AssetAmount) ([]*bc.AssetAmount, []*bc.AssetAmount, []*bc.AssetAmount)
+
+       // Validate verify that the fee charged for a matching transaction is correct
+       Validate(receiveAmounts []*bc.AssetAmount, priceDiff *bc.AssetAmount, feeAmounts map[bc.AssetID]int64) error
+}
+
+// DefaultFeeStrategy represent the default fee charge strategy
+type DefaultFeeStrategy struct {
+       maxFeeRate float64
+}
+
+// NewDefaultFeeStrategy return a new instance of DefaultFeeStrategy
+func NewDefaultFeeStrategy(maxFeeRate float64) *DefaultFeeStrategy {
+       return &DefaultFeeStrategy{maxFeeRate: maxFeeRate}
+}
+
+// Allocate will allocate the price differential in matching transaction to the participants and the fee
+func (d *DefaultFeeStrategy) Allocate(receiveAmounts []*bc.AssetAmount, priceDiff *bc.AssetAmount) ([]*bc.AssetAmount, []*bc.AssetAmount, []*bc.AssetAmount) {
+       receivedAfterDeductFee := make([]*bc.AssetAmount, len(receiveAmounts))
+       copy(receivedAfterDeductFee, receiveAmounts)
+
+       if priceDiff.Amount == 0 {
+               return receivedAfterDeductFee, nil, nil
+       }
+
+       var maxFeeAmount int64
+       for _, receiveAmount := range receiveAmounts {
+               if *receiveAmount.AssetId == *priceDiff.AssetId {
+                       maxFeeAmount = calcMaxFeeAmount(receiveAmount.Amount, d.maxFeeRate)
+               }
+       }
+
+       priceDiffAmount := int64(priceDiff.Amount)
+       feeAmount, reminder := priceDiffAmount, int64(0)
+       if priceDiffAmount > maxFeeAmount {
+               feeAmount = maxFeeAmount
+               reminder = priceDiffAmount - maxFeeAmount
+       }
+
+       // There is the remaining amount after paying the handling fee, assign it evenly to participants in the transaction
+       averageAmount := reminder / int64(len(receiveAmounts))
+       if averageAmount == 0 {
+               averageAmount = 1
+       }
+
+       var refundAmounts []*bc.AssetAmount
+       for i := 0; i < len(receiveAmounts) && reminder > 0; i++ {
+               amount := averageAmount
+               if i == len(receiveAmounts)-1 {
+                       amount = reminder
+               }
+               refundAmounts = append(refundAmounts, &bc.AssetAmount{AssetId: priceDiff.AssetId, Amount: uint64(amount)})
+               reminder -= averageAmount
+       }
+
+       feeAmounts := []*bc.AssetAmount{{AssetId: priceDiff.AssetId, Amount: uint64(feeAmount)}}
+       return receivedAfterDeductFee, refundAmounts, feeAmounts
+}
+
+// Validate verify that the fee charged for a matching transaction is correct
+func (d *DefaultFeeStrategy) Validate(receiveAmounts []*bc.AssetAmount, priceDiff *bc.AssetAmount, feeAmounts map[bc.AssetID]int64) error {
+       if len(feeAmounts) > 1 {
+               return ErrFeeMoreThanOneAsset
+       }
+
+       for _, receiveAmount := range receiveAmounts {
+               if feeAmount, ok := feeAmounts[*receiveAmount.AssetId]; ok {
+                       if feeAmount > calcMaxFeeAmount(receiveAmount.Amount, d.maxFeeRate) {
+                               return ErrAmountOfFeeExceedMaximum
+                       }
+               }
+       }
+       return nil
+}
+
+func calcMaxFeeAmount(amount uint64, maxFeeRate float64) int64 {
+       return int64(math.Ceil(float64(amount) * maxFeeRate))
+}
index 70fda39..c14067b 100644 (file)
@@ -8,7 +8,6 @@ import (
        "github.com/bytom/vapor/protocol/bc"
        "github.com/bytom/vapor/protocol/bc/types"
        "github.com/bytom/vapor/protocol/validation"
-       "github.com/bytom/vapor/testutil"
 )
 
 func TestGenerateMatchedTxs(t *testing.T) {
@@ -80,7 +79,7 @@ func TestGenerateMatchedTxs(t *testing.T) {
 
        for i, c := range cases {
                movStore := mock.NewMovStore([]*common.TradePair{btc2eth, eth2btc}, c.initStoreOrders)
-               matchEngine := NewEngine(NewOrderBook(movStore, nil, nil), 0.05, mock.RewardProgram)
+               matchEngine := NewEngine(NewOrderBook(movStore, nil, nil), NewDefaultFeeStrategy(0.05), mock.RewardProgram)
                var gotMatchedTxs []*types.Tx
                for matchEngine.HasMatchedTx(c.tradePairs...) {
                        matchedTx, err := matchEngine.NextMatchedTx(c.tradePairs...)
@@ -96,19 +95,19 @@ func TestGenerateMatchedTxs(t *testing.T) {
                        continue
                }
 
-               for i, gotMatchedTx := range gotMatchedTxs {
+               for j, gotMatchedTx := range gotMatchedTxs {
                        if _, err := validation.ValidateTx(gotMatchedTx.Tx, &bc.Block{BlockHeader: &bc.BlockHeader{Version: 1}}); err != nil {
                                t.Fatal(err)
                        }
 
-                       c.wantMatchedTxs[i].Version = 1
-                       byteData, err := c.wantMatchedTxs[i].MarshalText()
+                       c.wantMatchedTxs[j].Version = 1
+                       byteData, err := c.wantMatchedTxs[j].MarshalText()
                        if err != nil {
                                t.Fatal(err)
                        }
 
-                       c.wantMatchedTxs[i].SerializedSize = uint64(len(byteData))
-                       wantMatchedTx := types.NewTx(c.wantMatchedTxs[i].TxData)
+                       c.wantMatchedTxs[j].SerializedSize = uint64(len(byteData))
+                       wantMatchedTx := types.NewTx(c.wantMatchedTxs[j].TxData)
                        if gotMatchedTx.ID != wantMatchedTx.ID {
                                t.Errorf("#%d(%s) the tx hash of got matched tx: %s is not equals want matched tx: %s", i, c.desc, gotMatchedTx.ID.String(), wantMatchedTx.ID.String())
                        }
@@ -116,45 +115,6 @@ func TestGenerateMatchedTxs(t *testing.T) {
        }
 }
 
-func TestCalcMatchedTxFee(t *testing.T) {
-       cases := []struct {
-               desc             string
-               tx               *types.TxData
-               maxFeeRate       float64
-               wantMatchedTxFee map[bc.AssetID]*MatchedTxFee
-       }{
-               {
-                       desc:             "fee less than max fee",
-                       maxFeeRate:       0.05,
-                       wantMatchedTxFee: map[bc.AssetID]*MatchedTxFee{mock.ETH: {FeeAmount: 10, MaxFeeAmount: 26}},
-                       tx:               &mock.MatchedTxs[1].TxData,
-               },
-               {
-                       desc:             "fee refund in tx",
-                       maxFeeRate:       0.05,
-                       wantMatchedTxFee: map[bc.AssetID]*MatchedTxFee{mock.ETH: {FeeAmount: 27, MaxFeeAmount: 27}},
-                       tx:               &mock.MatchedTxs[2].TxData,
-               },
-               {
-                       desc:             "fee is zero",
-                       maxFeeRate:       0.05,
-                       wantMatchedTxFee: map[bc.AssetID]*MatchedTxFee{},
-                       tx:               &mock.MatchedTxs[0].TxData,
-               },
-       }
-
-       for i, c := range cases {
-               gotMatchedTxFee, err := CalcMatchedTxFee(c.tx, c.maxFeeRate)
-               if err != nil {
-                       t.Fatal(err)
-               }
-
-               if !testutil.DeepEqual(gotMatchedTxFee, c.wantMatchedTxFee) {
-                       t.Errorf("#%d(%s):fail to caculate matched tx fee, got (%v), want (%v)", i, c.desc, gotMatchedTxFee, c.wantMatchedTxFee)
-               }
-       }
-}
-
 func TestValidateTradePairs(t *testing.T) {
        cases := []struct {
                desc       string
index f351e64..4a2be60 100644 (file)
@@ -285,10 +285,10 @@ var (
                                // re-order
                                types.NewIntraChainOutput(*Eth2BtcOrders[2].FromAssetID, 270, Eth2BtcOrders[2].Utxo.ControlProgram),
                                // fee
-                               types.NewIntraChainOutput(*Eth2BtcOrders[2].FromAssetID, 27, RewardProgram),
+                               types.NewIntraChainOutput(*Eth2BtcOrders[2].FromAssetID, 25, RewardProgram),
                                // refund
-                               types.NewIntraChainOutput(*Eth2BtcOrders[2].FromAssetID, 6, testutil.MustDecodeHexString("0014f928b723999312df4ed51cb275a2644336c19251")),
-                               types.NewIntraChainOutput(*Eth2BtcOrders[2].FromAssetID, 7, testutil.MustDecodeHexString("0014f928b723999312df4ed51cb275a2644336c19255")),
+                               types.NewIntraChainOutput(*Eth2BtcOrders[2].FromAssetID, 7, testutil.MustDecodeHexString("0014f928b723999312df4ed51cb275a2644336c19251")),
+                               types.NewIntraChainOutput(*Eth2BtcOrders[2].FromAssetID, 8, testutil.MustDecodeHexString("0014f928b723999312df4ed51cb275a2644336c19255")),
                        },
                }),
                types.NewTx(types.TxData{
index dcf17ad..6a8855a 100644 (file)
@@ -23,7 +23,6 @@ var (
        errInputProgramMustP2WMCScript   = errors.New("input program of trade tx must p2wmc script")
        errExistCancelOrderInMatchedTx   = errors.New("can't exist cancel order in the matched transaction")
        errExistTradeInCancelOrderTx     = errors.New("can't exist trade in the cancel order transaction")
-       errAmountOfFeeGreaterThanMaximum = errors.New("amount of fee greater than max fee amount")
        errAssetIDMustUniqueInMatchedTx  = errors.New("asset id must unique in matched transaction")
        errRatioOfTradeLessThanZero      = errors.New("ratio arguments must greater than zero")
        errSpendOutputIDIsIncorrect      = errors.New("spend output id of matched tx is not equals to actual matched tx")
@@ -89,7 +88,7 @@ func (m *MovCore) BeforeProposalBlock(txs []*types.Tx, blockHeight uint64, gasLe
                return nil, err
        }
 
-       matchEngine := match.NewEngine(orderBook, maxFeeRate, rewardProgram)
+       matchEngine := match.NewEngine(orderBook, match.NewDefaultFeeStrategy(maxFeeRate), rewardProgram)
        tradePairIterator := database.NewTradePairIterator(m.movStore)
        matchCollector := newMatchTxCollector(matchEngine, tradePairIterator, gasLeft, isTimeout)
        return matchCollector.result()
@@ -183,6 +182,33 @@ func (m *MovCore) ValidateTx(tx *types.Tx, verifyResult *bc.TxVerifyResult) erro
        return nil
 }
 
+// calcFeeAmount return the amount of fee in the matching transaction
+func calcFeeAmount(matchedTx *types.Tx) (map[bc.AssetID]int64, error) {
+       assetFeeMap := make(map[bc.AssetID]int64)
+       dealProgMaps := make(map[string]bool)
+
+       for _, input := range matchedTx.Inputs {
+               assetFeeMap[input.AssetID()] = int64(input.AssetAmount().Amount)
+               contractArgs, err := segwit.DecodeP2WMCProgram(input.ControlProgram())
+               if err != nil {
+                       return nil, err
+               }
+
+               dealProgMaps[hex.EncodeToString(contractArgs.SellerProgram)] = true
+       }
+
+       for _, output := range matchedTx.Outputs {
+               assetAmount := output.AssetAmount()
+               if _, ok := dealProgMaps[hex.EncodeToString(output.ControlProgram())]; ok || segwit.IsP2WMCScript(output.ControlProgram()) {
+                       assetFeeMap[*assetAmount.AssetId] -= int64(assetAmount.Amount)
+                       if assetFeeMap[*assetAmount.AssetId] <= 0 {
+                               delete(assetFeeMap, *assetAmount.AssetId)
+                       }
+               }
+       }
+       return assetFeeMap, nil
+}
+
 func validateCancelOrderTx(tx *types.Tx, verifyResult *bc.TxVerifyResult) error {
        if verifyResult.StatusFail {
                return errStatusFailMustFalse
@@ -214,7 +240,7 @@ func validateMagneticContractArgs(fromAssetAmount bc.AssetAmount, program []byte
                return errRatioOfTradeLessThanZero
        }
 
-       if match.CalcRequestAmount(fromAssetAmount.Amount, contractArgs) < 1 {
+       if match.CalcRequestAmount(fromAssetAmount.Amount, contractArgs.RatioNumerator, contractArgs.RatioDenominator) < 1 {
                return errRequestAmountMath
        }
        return nil
@@ -253,17 +279,19 @@ func validateMatchedTx(tx *types.Tx, verifyResult *bc.TxVerifyResult) error {
 }
 
 func validateMatchedTxFeeAmount(tx *types.Tx) error {
-       txFee, err := match.CalcMatchedTxFee(&tx.TxData, maxFeeRate)
+       orders, err := getDeleteOrdersFromTx(tx)
        if err != nil {
                return err
        }
 
-       for _, amount := range txFee {
-               if amount.FeeAmount > amount.MaxFeeAmount {
-                       return errAmountOfFeeGreaterThanMaximum
-               }
+       receivedAmount, priceDiff := match.CalcReceivedAmount(orders)
+       feeAmounts, err := calcFeeAmount(tx)
+       if err != nil {
+               return err
        }
-       return nil
+
+       feeStrategy := match.NewDefaultFeeStrategy(maxFeeRate)
+       return feeStrategy.Validate(receivedAmount, priceDiff, feeAmounts)
 }
 
 func (m *MovCore) validateMatchedTxSequence(txs []*types.Tx) error {
index 28a0bd9..9358f89 100644 (file)
@@ -8,6 +8,7 @@ import (
 
        "github.com/bytom/vapor/application/mov/common"
        "github.com/bytom/vapor/application/mov/database"
+       "github.com/bytom/vapor/application/mov/match"
        "github.com/bytom/vapor/application/mov/mock"
        "github.com/bytom/vapor/consensus"
        dbm "github.com/bytom/vapor/database/leveldb"
@@ -446,8 +447,8 @@ func TestValidateBlock(t *testing.T) {
                                                        types.NewSpendInput([][]byte{vm.Int64Bytes(10), vm.Int64Bytes(1), vm.Int64Bytes(0)}, *mock.Eth2BtcOrders[2].Utxo.SourceID, *mock.Eth2BtcOrders[2].FromAssetID, mock.Eth2BtcOrders[2].Utxo.Amount, mock.Eth2BtcOrders[2].Utxo.SourcePos, mock.Eth2BtcOrders[2].Utxo.ControlProgram),
                                                },
                                                Outputs: []*types.TxOutput{
-                                                       types.NewIntraChainOutput(*mock.Btc2EthOrders[0].ToAssetID, 500, testutil.MustDecodeHexString("51")),
-                                                       types.NewIntraChainOutput(*mock.Eth2BtcOrders[2].ToAssetID, 10, testutil.MustDecodeHexString("55")),
+                                                       types.NewIntraChainOutput(*mock.Btc2EthOrders[0].ToAssetID, 500, testutil.MustDecodeHexString("0014f928b723999312df4ed51cb275a2644336c19251")),
+                                                       types.NewIntraChainOutput(*mock.Eth2BtcOrders[2].ToAssetID, 10, testutil.MustDecodeHexString("0014f928b723999312df4ed51cb275a2644336c19255")),
                                                        // re-order
                                                        types.NewIntraChainOutput(*mock.Eth2BtcOrders[2].FromAssetID, 270, mock.Eth2BtcOrders[2].Utxo.ControlProgram),
                                                        // fee
@@ -457,7 +458,7 @@ func TestValidateBlock(t *testing.T) {
                                },
                        },
                        verifyResults: []*bc.TxVerifyResult{{StatusFail: false}},
-                       wantError:     errAmountOfFeeGreaterThanMaximum,
+                       wantError:     match.ErrAmountOfFeeExceedMaximum,
                },
                {
                        desc: "ratio numerator is zero",
@@ -508,6 +509,45 @@ func TestValidateBlock(t *testing.T) {
        }
 }
 
+func TestCalcMatchedTxFee(t *testing.T) {
+       cases := []struct {
+               desc             string
+               tx               types.TxData
+               maxFeeRate       float64
+               wantMatchedTxFee map[bc.AssetID]int64
+       }{
+               {
+                       desc:             "fee less than max fee",
+                       maxFeeRate:       0.05,
+                       wantMatchedTxFee: map[bc.AssetID]int64{mock.ETH: 10},
+                       tx:               mock.MatchedTxs[1].TxData,
+               },
+               {
+                       desc:             "fee refund in tx",
+                       maxFeeRate:       0.05,
+                       wantMatchedTxFee: map[bc.AssetID]int64{mock.ETH: 25},
+                       tx:               mock.MatchedTxs[2].TxData,
+               },
+               {
+                       desc:             "fee is zero",
+                       maxFeeRate:       0.05,
+                       wantMatchedTxFee: map[bc.AssetID]int64{},
+                       tx:               mock.MatchedTxs[0].TxData,
+               },
+       }
+
+       for i, c := range cases {
+               gotMatchedTxFee, err := calcFeeAmount(types.NewTx(c.tx))
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               if !testutil.DeepEqual(gotMatchedTxFee, c.wantMatchedTxFee) {
+                       t.Errorf("#%d(%s):fail to caculate matched tx fee, got (%v), want (%v)", i, c.desc, gotMatchedTxFee, c.wantMatchedTxFee)
+               }
+       }
+}
+
 func TestBeforeProposalBlock(t *testing.T) {
        consensus.ActiveNetParams.MovRewardProgram = hex.EncodeToString(mock.RewardProgram)