OSDN Git Service

fix fee (#523)
[bytom/vapor.git] / application / mov / match / engine.go
index 4e5a029..d332d75 100644 (file)
@@ -62,20 +62,44 @@ func (e *Engine) NextMatchedTx(tradePairs ...*common.TradePair) (*types.Tx, erro
        return tx, nil
 }
 
-func (e *Engine) addMatchTxFeeOutput(txData *types.TxData, refunds RefundAssets, fees []*bc.AssetAmount) error {
+func (e *Engine) addMatchTxFeeOutput(txData *types.TxData, fees []*bc.AssetAmount) error {
        for _, feeAmount := range fees {
                txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*feeAmount.AssetId, feeAmount.Amount, e.rewardProgram))
        }
 
-       for i, refund := range refunds {
-               // each trading participant may be refunded multiple assets
-               for _, assetAmount := range refund {
-                       contractArgs, err := segwit.DecodeP2WMCProgram(txData.Inputs[i].ControlProgram())
-                       if err != nil {
-                               return err
-                       }
+       refoundAmount := map[bc.AssetID]uint64{}
+       assetIDs := []bc.AssetID{}
+       refoundScript := [][]byte{}
+       for _, input := range txData.Inputs {
+               refoundAmount[input.AssetID()] += input.Amount()
+               contractArgs, err := segwit.DecodeP2WMCProgram(input.ControlProgram())
+               if err != nil {
+                       return err
+               }
+
+               assetIDs = append(assetIDs, input.AssetID())
+               refoundScript = append(refoundScript, contractArgs.SellerProgram)
+       }
 
-                       txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*assetAmount.AssetId, assetAmount.Amount, contractArgs.SellerProgram))
+       for _, output := range txData.Outputs {
+               assetAmount := output.AssetAmount()
+               refoundAmount[*assetAmount.AssetId] -= assetAmount.Amount
+       }
+
+       refoundCount := len(refoundScript)
+       for _, assetID := range assetIDs {
+               amount := refoundAmount[assetID]
+               averageAmount := amount / uint64(refoundCount)
+               if averageAmount == 0 {
+                       averageAmount = 1
+               }
+
+               for i := 0; i < refoundCount && amount > 0; i++ {
+                       if i == refoundCount-1 {
+                               averageAmount = amount
+                       }
+                       txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(assetID, averageAmount, refoundScript[i]))
+                       amount -= averageAmount
                }
        }
        return nil
@@ -110,7 +134,7 @@ func (e *Engine) buildMatchTx(orders []*common.Order) (*types.Tx, error) {
                return nil, err
        }
 
-       if err := e.addMatchTxFeeOutput(txData, allocatedAssets.Refunds, allocatedAssets.Fees); err != nil {
+       if err := e.addMatchTxFeeOutput(txData, allocatedAssets.Fees); err != nil {
                return nil, err
        }
 
@@ -140,8 +164,6 @@ func addMatchTxOutput(txData *types.TxData, orders []*common.Order, receivedAmou
                txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.ToAssetID, allocatedAssets.Receives[i].Amount, contractArgs.SellerProgram))
                if isPartialTrade {
                        txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.FromAssetID, exchangeAmount, order.Utxo.ControlProgram))
-               } else if exchangeAmount > 0 {
-                       allocatedAssets.Refunds.Add(i, *order.FromAssetID, exchangeAmount)
                }
        }
        return nil
@@ -184,10 +206,9 @@ func CalcReceivedAmount(orders []*common.Order) ([]*bc.AssetAmount, []*bc.AssetA
 
        for i, receivedAmount := range receivedAmounts {
                oppositeShouldPayAmount := shouldPayAmounts[calcOppositeIndex(len(orders), i)]
+               priceDiffs = append(priceDiffs, &bc.AssetAmount{AssetId: oppositeShouldPayAmount.AssetId, Amount: 0})
                if oppositeShouldPayAmount.Amount > receivedAmount.Amount {
-                       assetID := oppositeShouldPayAmount.AssetId
-                       amount := oppositeShouldPayAmount.Amount - receivedAmount.Amount
-                       priceDiffs = append(priceDiffs, &bc.AssetAmount{AssetId: assetID, Amount: amount})
+                       priceDiffs[i].Amount = oppositeShouldPayAmount.Amount - receivedAmount.Amount
                }
        }
        return receivedAmounts, priceDiffs