OSDN Git Service

same change while go over the codes check_logic
authorpaladz <453256728@qq.com>
Tue, 5 Nov 2019 16:59:00 +0000 (00:59 +0800)
committerpaladz <453256728@qq.com>
Tue, 5 Nov 2019 16:59:00 +0000 (00:59 +0800)
application/mov/common/type.go
application/mov/common/util.go
application/mov/contract/contract.go
application/mov/database/mov_iterator.go
application/mov/database/mov_iterator_test.go
application/mov/match/match.go
application/mov/match/match_test.go
application/mov/match/order_table.go
application/mov/match/order_table_test.go

index bc4757c..8320459 100644 (file)
@@ -10,6 +10,7 @@ import (
        "github.com/vapor/protocol/bc/types"
 )
 
+// MovUtxo store the utxo information for mov order
 type MovUtxo struct {
        SourceID       *bc.Hash
        SourcePos      uint64
@@ -17,6 +18,7 @@ type MovUtxo struct {
        ControlProgram []byte
 }
 
+// Order store all the order information
 type Order struct {
        FromAssetID *bc.AssetID
        ToAssetID   *bc.AssetID
@@ -24,14 +26,11 @@ type Order struct {
        Rate        float64
 }
 
+// OrderSlice is define for order's sort
 type OrderSlice []*Order
 
-func (o OrderSlice) Len() int {
-       return len(o)
-}
-func (o OrderSlice) Swap(i, j int) {
-       o[i], o[j] = o[j], o[i]
-}
+func (o OrderSlice) Len() int      { return len(o) }
+func (o OrderSlice) Swap(i, j int) { o[i], o[j] = o[j], o[i] }
 func (o OrderSlice) Less(i, j int) bool {
        if o[i].Rate == o[j].Rate {
                return hex.EncodeToString(o[i].UTXOHash().Bytes()) < hex.EncodeToString(o[j].UTXOHash().Bytes())
@@ -39,6 +38,7 @@ func (o OrderSlice) Less(i, j int) bool {
        return o[i].Rate < o[j].Rate
 }
 
+// NewOrderFromOutput convert txinput to order
 func NewOrderFromOutput(tx *types.Tx, outputIndex int) (*Order, error) {
        outputID := tx.OutputID(outputIndex)
        output, err := tx.IntraChainOutput(*outputID)
@@ -65,6 +65,7 @@ func NewOrderFromOutput(tx *types.Tx, outputIndex int) (*Order, error) {
        }, nil
 }
 
+// NewOrderFromInput convert txoutput to order
 func NewOrderFromInput(tx *types.Tx, inputIndex int) (*Order, error) {
        input, ok := tx.Inputs[inputIndex].TypedInput.(*types.SpendInput)
        if !ok {
@@ -83,12 +84,23 @@ func NewOrderFromInput(tx *types.Tx, inputIndex int) (*Order, error) {
                Utxo: &MovUtxo{
                        SourceID:       &input.SourceID,
                        Amount:         input.Amount,
-                       SourcePos:      input.SourcePosition,
+                       SourcePos:      input.SourcePosition,
                        ControlProgram: input.ControlProgram,
                },
        }, nil
 }
 
+// Key return the unique key for representing this order
+func (o *Order) Key() string {
+       return fmt.Sprintf("%s:%d", o.Utxo.SourceID, o.Utxo.SourcePos)
+}
+
+// TradePair return the trade pair info
+func (o *Order) TradePair() *TradePair {
+       return &TradePair{FromAssetID: o.FromAssetID, ToAssetID: o.ToAssetID}
+}
+
+// UTXOHash calculate the utxo hash of this order
 func (o *Order) UTXOHash() *bc.Hash {
        prog := &bc.Program{VmVersion: 1, Code: o.Utxo.ControlProgram}
        src := &bc.ValueSource{
@@ -100,20 +112,19 @@ func (o *Order) UTXOHash() *bc.Hash {
        return &hash
 }
 
-func (o *Order) TradePair() *TradePair {
-       return &TradePair{FromAssetID: o.FromAssetID, ToAssetID: o.ToAssetID}
-}
-
-func (o *Order) Key() string {
-       return fmt.Sprintf("%s:%d", o.Utxo.SourceID, o.Utxo.SourcePos)
-}
-
+// TradePair is the object for record trade pair info
 type TradePair struct {
        FromAssetID *bc.AssetID
        ToAssetID   *bc.AssetID
        Count       int
 }
 
+// Key return the unique key for representing this trade pair
+func (t *TradePair) Key() string {
+       return fmt.Sprintf("%s:%s", t.FromAssetID, t.ToAssetID)
+}
+
+// Reverse return the reverse trade pair object
 func (t *TradePair) Reverse() *TradePair {
        return &TradePair{
                FromAssetID: t.ToAssetID,
@@ -121,10 +132,7 @@ func (t *TradePair) Reverse() *TradePair {
        }
 }
 
-func (t *TradePair) Key() string {
-       return fmt.Sprintf("%s:%s", t.FromAssetID, t.ToAssetID)
-}
-
+// MovDatabaseState is object to record DB image status
 type MovDatabaseState struct {
        Height uint64
        Hash   *bc.Hash
index 4d8567d..860511c 100644 (file)
@@ -6,6 +6,7 @@ import (
        "github.com/vapor/protocol/bc/types"
 )
 
+// IsMatchedTx check if this transaction has trade mov order input
 func IsMatchedTx(tx *types.Tx) bool {
        if len(tx.Inputs) < 2 {
                return false
@@ -18,6 +19,7 @@ func IsMatchedTx(tx *types.Tx) bool {
        return false
 }
 
+// IsCancelOrderTx check if this transaction has cancel mov order input
 func IsCancelOrderTx(tx *types.Tx) bool {
        for _, input := range tx.Inputs {
                if input.InputType() == types.SpendInputType && contract.IsCancelClauseSelector(input) && segwit.IsP2WMCScript(input.ControlProgram()) {
@@ -26,4 +28,3 @@ func IsCancelOrderTx(tx *types.Tx) bool {
        }
        return false
 }
-
index 6bcb6db..fbf870c 100644 (file)
@@ -8,29 +8,34 @@ import (
 )
 
 const (
-       sizeOfCancelClauseArgs = 3
+       sizeOfCancelClauseArgs       = 3
        sizeOfPartialTradeClauseArgs = 3
-       sizeOfFullTradeClauseArgs = 2
+       sizeOfFullTradeClauseArgs    = 2
 )
 
+// smart contract clause select for differnet unlock method
 const (
        PartialTradeClauseSelector int64 = iota
        FullTradeClauseSelector
        CancelClauseSelector
 )
 
+// IsCancelClauseSelector check if input select cancel clause
 func IsCancelClauseSelector(input *types.TxInput) bool {
-       return len(input.Arguments()) == sizeOfCancelClauseArgs && hex.EncodeToString(input.Arguments()[len(input.Arguments()) - 1]) == hex.EncodeToString(vm.Int64Bytes(CancelClauseSelector))
+       return len(input.Arguments()) == sizeOfCancelClauseArgs && hex.EncodeToString(input.Arguments()[len(input.Arguments())-1]) == hex.EncodeToString(vm.Int64Bytes(CancelClauseSelector))
 }
 
+// IsTradeClauseSelector check if input select is partial trade clause or full trade clause
 func IsTradeClauseSelector(input *types.TxInput) bool {
        return IsPartialTradeClauseSelector(input) || IsFullTradeClauseSelector(input)
 }
 
+// IsPartialTradeClauseSelector check if input select partial trade clause
 func IsPartialTradeClauseSelector(input *types.TxInput) bool {
-       return len(input.Arguments()) == sizeOfPartialTradeClauseArgs && hex.EncodeToString(input.Arguments()[len(input.Arguments()) - 1]) == hex.EncodeToString(vm.Int64Bytes(PartialTradeClauseSelector))
+       return len(input.Arguments()) == sizeOfPartialTradeClauseArgs && hex.EncodeToString(input.Arguments()[len(input.Arguments())-1]) == hex.EncodeToString(vm.Int64Bytes(PartialTradeClauseSelector))
 }
 
+// IsFullTradeClauseSelector check if input select full trade clause
 func IsFullTradeClauseSelector(input *types.TxInput) bool {
-       return len(input.Arguments()) == sizeOfFullTradeClauseArgs && hex.EncodeToString(input.Arguments()[len(input.Arguments()) - 1]) == hex.EncodeToString(vm.Int64Bytes(FullTradeClauseSelector))
+       return len(input.Arguments()) == sizeOfFullTradeClauseArgs && hex.EncodeToString(input.Arguments()[len(input.Arguments())-1]) == hex.EncodeToString(vm.Int64Bytes(FullTradeClauseSelector))
 }
index 7ad6ce3..d391ace 100644 (file)
@@ -7,21 +7,25 @@ import (
        "github.com/vapor/protocol/bc"
 )
 
+// TradePairIterator wrap read trade pair from DB action
 type TradePairIterator struct {
        movStore       MovStore
        tradePairs     []*common.TradePair
        tradePairIndex int
 }
 
+// NewTradePairIterator create the new TradePairIterator object
 func NewTradePairIterator(movStore MovStore) *TradePairIterator {
        return &TradePairIterator{movStore: movStore}
 }
 
+// HasNext check if there are more trade pairs in memory or DB
 func (t *TradePairIterator) HasNext() bool {
        tradePairSize := len(t.tradePairs)
        if t.tradePairIndex < tradePairSize {
                return true
        }
+
        var fromAssetID, toAssetID *bc.AssetID
        if len(t.tradePairs) > 0 {
                lastTradePair := t.tradePairs[tradePairSize-1]
@@ -44,6 +48,7 @@ func (t *TradePairIterator) HasNext() bool {
        return true
 }
 
+// Next return the next available trade pair in memory or DB
 func (t *TradePairIterator) Next() *common.TradePair {
        if !t.HasNext() {
                return nil
@@ -54,12 +59,14 @@ func (t *TradePairIterator) Next() *common.TradePair {
        return tradePair
 }
 
+// OrderIterator wrap read order from DB action
 type OrderIterator struct {
        movStore  MovStore
        lastOrder *common.Order
        orders    []*common.Order
 }
 
+// NewOrderIterator create the new OrderIterator object
 func NewOrderIterator(movStore MovStore, tradePair *common.TradePair) *OrderIterator {
        return &OrderIterator{
                movStore:  movStore,
@@ -67,6 +74,7 @@ func NewOrderIterator(movStore MovStore, tradePair *common.TradePair) *OrderIter
        }
 }
 
+// HasNext check if there are more orders in memory or DB
 func (o *OrderIterator) HasNext() bool {
        if len(o.orders) == 0 {
                orders, err := o.movStore.ListOrders(o.lastOrder)
@@ -84,6 +92,7 @@ func (o *OrderIterator) HasNext() bool {
        return true
 }
 
+// NextBatch return the next batch of orders in memory or DB
 func (o *OrderIterator) NextBatch() []*common.Order {
        if !o.HasNext() {
                return nil
index 83e984c..be6c110 100644 (file)
@@ -126,26 +126,26 @@ func TestOrderIterator(t *testing.T) {
                wantOrders  []*common.Order
        }{
                {
-                       desc: "normal case",
-                       tradePair: &common.TradePair{FromAssetID: assetID1, ToAssetID: assetID2},
+                       desc:        "normal case",
+                       tradePair:   &common.TradePair{FromAssetID: assetID1, ToAssetID: assetID2},
                        storeOrders: []*common.Order{order1, order2, order3},
                        wantOrders:  []*common.Order{order1, order2, order3},
                },
                {
-                       desc: "num of orders more than one return",
-                       tradePair: &common.TradePair{FromAssetID: assetID1, ToAssetID: assetID2},
+                       desc:        "num of orders more than one return",
+                       tradePair:   &common.TradePair{FromAssetID: assetID1, ToAssetID: assetID2},
                        storeOrders: []*common.Order{order1, order2, order3, order4, order5},
                        wantOrders:  []*common.Order{order1, order2, order3, order4, order5},
                },
                {
-                       desc: "only one order",
-                       tradePair: &common.TradePair{FromAssetID: assetID1, ToAssetID: assetID2},
+                       desc:        "only one order",
+                       tradePair:   &common.TradePair{FromAssetID: assetID1, ToAssetID: assetID2},
                        storeOrders: []*common.Order{order1},
                        wantOrders:  []*common.Order{order1},
                },
                {
-                       desc: "store is empty",
-                       tradePair: &common.TradePair{FromAssetID: assetID1, ToAssetID: assetID2},
+                       desc:        "store is empty",
+                       tradePair:   &common.TradePair{FromAssetID: assetID1, ToAssetID: assetID2},
                        storeOrders: []*common.Order{},
                        wantOrders:  []*common.Order{},
                },
index 3dc4f6f..61111ba 100644 (file)
@@ -15,16 +15,19 @@ import (
        "github.com/vapor/protocol/vm/vmutil"
 )
 
+// Engine is used to generate math transactions
 type Engine struct {
        orderTable  *OrderTable
        maxFeeRate  float64
        nodeProgram []byte
 }
 
+// NewEngine return a new Engine
 func NewEngine(orderTable *OrderTable, maxFeeRate float64, nodeProgram []byte) *Engine {
        return &Engine{orderTable: orderTable, maxFeeRate: maxFeeRate, nodeProgram: nodeProgram}
 }
 
+// HasMatchedTx check does the input trade pair can generate a match deal
 func (e *Engine) HasMatchedTx(tradePairs ...*common.TradePair) bool {
        if err := validateTradePairs(tradePairs); err != nil {
                return false
@@ -55,141 +58,51 @@ func (e *Engine) NextMatchedTx(tradePairs ...*common.TradePair) (*types.Tx, erro
                e.orderTable.PopOrder(tradePair)
        }
 
-       if err := addPartialTradeOrder(tx, e.orderTable); err != nil {
+       if err := e.addPartialTradeOrder(tx); err != nil {
                return nil, err
        }
        return tx, nil
 }
 
-func (e *Engine) peekOrders(tradePairs []*common.TradePair) []*common.Order {
-       var orders []*common.Order
-       for _, tradePair := range tradePairs {
-               order := e.orderTable.PeekOrder(tradePair)
-               if order == nil {
-                       return nil
-               }
-
-               orders = append(orders, order)
-       }
-       return orders
-}
-
-func validateTradePairs(tradePairs []*common.TradePair) error {
-       if len(tradePairs) < 2 {
-               return errors.New("size of trade pairs at least 2")
-       }
-
-       for i, tradePair := range tradePairs {
-               oppositeTradePair := tradePairs[getOppositeIndex(len(tradePairs), i)]
-               if *tradePair.ToAssetID != *oppositeTradePair.FromAssetID {
-                       return errors.New("specified trade pairs is invalid")
-               }
-       }
-       return nil
-}
-
-func isMatched(orders []*common.Order) bool {
-       for i, order := range orders {
-               opposisteOrder := orders[getOppositeIndex(len(orders), i)]
-               if 1 / order.Rate < opposisteOrder.Rate {
-                       return false
-               }
-       }
-       return true
-}
-
-func (e *Engine) buildMatchTx(orders []*common.Order) (*types.Tx, error) {
-       txData := &types.TxData{Version: 1}
-       for i, order := range orders {
-               input := types.NewSpendInput(nil, *order.Utxo.SourceID, *order.FromAssetID, order.Utxo.Amount, order.Utxo.SourcePos, order.Utxo.ControlProgram)
-               txData.Inputs = append(txData.Inputs, input)
-
-               oppositeOrder := orders[getOppositeIndex(len(orders), i)]
-               if err := addMatchTxOutput(txData, input, order, oppositeOrder.Utxo.Amount); err != nil {
-                       return nil, err
-               }
-       }
-
-       if err := e.addMatchTxFeeOutput(txData); err != nil {
-               return nil, err
-       }
-
-       byteData, err := txData.MarshalText()
-       if err != nil {
-               return nil, err
-       }
-
-       txData.SerializedSize = uint64(len(byteData))
-       return types.NewTx(*txData), nil
-}
-
-func addMatchTxOutput(txData *types.TxData, txInput *types.TxInput, order *common.Order, oppositeAmount uint64) error {
-       contractArgs, err := segwit.DecodeP2WMCProgram(order.Utxo.ControlProgram)
-       if err != nil {
-               return err
-       }
-
-       requestAmount := calcRequestAmount(order.Utxo.Amount, contractArgs)
-       receiveAmount := vprMath.MinUint64(requestAmount, oppositeAmount)
-       shouldPayAmount := CalcShouldPayAmount(receiveAmount, contractArgs)
-       isPartialTrade := requestAmount > receiveAmount
-
-       setMatchTxArguments(txInput, isPartialTrade, len(txData.Outputs), receiveAmount)
-       txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.ToAssetID, receiveAmount, contractArgs.SellerProgram))
-       if isPartialTrade {
-               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.FromAssetID, order.Utxo.Amount-shouldPayAmount, order.Utxo.ControlProgram))
-       }
-       return nil
-}
-
 func (e *Engine) addMatchTxFeeOutput(txData *types.TxData) error {
        txFee, err := CalcMatchedTxFee(txData, e.maxFeeRate)
        if err != nil {
                return err
        }
 
-       for feeAssetID, amount := range txFee {
-               var reminder int64 = 0
-               feeAmount := amount.FeeAmount
-               if amount.FeeAmount > amount.MaxFeeAmount {
-                       feeAmount = amount.MaxFeeAmount
-                       reminder = amount.FeeAmount - amount.MaxFeeAmount
+       for assetID, matchTxFee := range txFee {
+               feeAmount, reminder := matchTxFee.FeeAmount, int64(0)
+               if matchTxFee.FeeAmount > matchTxFee.MaxFeeAmount {
+                       feeAmount = matchTxFee.MaxFeeAmount
+                       reminder = matchTxFee.FeeAmount - matchTxFee.MaxFeeAmount
                }
-               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(feeAssetID, uint64(feeAmount), e.nodeProgram))
+               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(assetID, uint64(feeAmount), e.nodeProgram))
 
                // There is the remaining amount after paying the handling fee, assign it evenly to participants in the transaction
                averageAmount := reminder / int64(len(txData.Inputs))
                if averageAmount == 0 {
                        averageAmount = 1
                }
+
                for i := 0; i < len(txData.Inputs) && reminder > 0; i++ {
                        contractArgs, err := segwit.DecodeP2WMCProgram(txData.Inputs[i].ControlProgram())
                        if err != nil {
                                return err
                        }
 
-                       if i == len(txData.Inputs)-1 {
-                               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(feeAssetID, uint64(reminder), contractArgs.SellerProgram))
-                       } else {
-                               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(feeAssetID, uint64(averageAmount), contractArgs.SellerProgram))
+                       if reminder < 2*averageAmount {
+                               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(assetID, uint64(reminder), contractArgs.SellerProgram))
+                               break
                        }
+
+                       txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(assetID, uint64(averageAmount), contractArgs.SellerProgram))
                        reminder -= averageAmount
                }
        }
        return nil
 }
 
-func setMatchTxArguments(txInput *types.TxInput, isPartialTrade bool, position int, receiveAmounts uint64) {
-       var arguments [][]byte
-       if isPartialTrade {
-               arguments = [][]byte{vm.Int64Bytes(int64(receiveAmounts)), vm.Int64Bytes(int64(position)), vm.Int64Bytes(contract.PartialTradeClauseSelector)}
-       } else {
-               arguments = [][]byte{vm.Int64Bytes(int64(position)), vm.Int64Bytes(contract.FullTradeClauseSelector)}
-       }
-       txInput.SetArguments(arguments)
-}
-
-func addPartialTradeOrder(tx *types.Tx, orderTable *OrderTable) error {
+func (e *Engine) addPartialTradeOrder(tx *types.Tx) error {
        for i, output := range tx.Outputs {
                if !segwit.IsP2WMCScript(output.ControlProgram()) {
                        continue
@@ -200,41 +113,70 @@ func addPartialTradeOrder(tx *types.Tx, orderTable *OrderTable) error {
                        return err
                }
 
-               if err := orderTable.AddOrder(order); err != nil {
+               if err := e.orderTable.AddOrder(order); err != nil {
                        return err
                }
        }
        return nil
 }
 
-func getOppositeIndex(size int, selfIdx int) int {
-       oppositeIdx := selfIdx + 1
-       if selfIdx >= size-1 {
-               oppositeIdx = 0
+func (e *Engine) buildMatchTx(orders []*common.Order) (*types.Tx, error) {
+       txData := &types.TxData{Version: 1}
+       for i, order := range orders {
+               input := types.NewSpendInput(nil, *order.Utxo.SourceID, *order.FromAssetID, order.Utxo.Amount, order.Utxo.SourcePos, order.Utxo.ControlProgram)
+               txData.Inputs = append(txData.Inputs, input)
+
+               oppositeOrder := orders[calcOppositeIndex(len(orders), i)]
+               if err := addMatchTxOutput(txData, input, order, oppositeOrder.Utxo.Amount); err != nil {
+                       return nil, err
+               }
+       }
+
+       if err := e.addMatchTxFeeOutput(txData); err != nil {
+               return nil, err
+       }
+
+       byteData, err := txData.MarshalText()
+       if err != nil {
+               return nil, err
+       }
+
+       txData.SerializedSize = uint64(len(byteData))
+       return types.NewTx(*txData), nil
+}
+
+func (e *Engine) peekOrders(tradePairs []*common.TradePair) []*common.Order {
+       var orders []*common.Order
+       for _, tradePair := range tradePairs {
+               order := e.orderTable.PeekOrder(tradePair)
+               if order == nil {
+                       return nil
+               }
+
+               orders = append(orders, order)
        }
-       return oppositeIdx
+       return orders
 }
 
+// MatchedTxFee is object to record the mov tx's fee information
 type MatchedTxFee struct {
        MaxFeeAmount int64
        FeeAmount    int64
 }
 
+// CalcMatchedTxFee is used to calculate tx's MatchedTxFees
 func CalcMatchedTxFee(txData *types.TxData, maxFeeRate float64) (map[bc.AssetID]*MatchedTxFee, error) {
        assetFeeMap := make(map[bc.AssetID]*MatchedTxFee)
-       sellerProgramMap := make(map[string]bool)
-       assetInputMap := make(map[bc.AssetID]uint64)
+       dealProgMaps := make(map[string]bool)
 
        for _, input := range txData.Inputs {
-               assetFeeMap[input.AssetID()] = &MatchedTxFee{}
-               assetFeeMap[input.AssetID()].FeeAmount += int64(input.AssetAmount().Amount)
+               assetFeeMap[input.AssetID()] = &MatchedTxFee{FeeAmount: int64(input.AssetAmount().Amount)}
                contractArgs, err := segwit.DecodeP2WMCProgram(input.ControlProgram())
                if err != nil {
                        return nil, err
                }
 
-               sellerProgramMap[hex.EncodeToString(contractArgs.SellerProgram)] = true
-               assetInputMap[input.AssetID()] = input.Amount()
+               dealProgMaps[hex.EncodeToString(contractArgs.SellerProgram)] = true
        }
 
        for _, input := range txData.Inputs {
@@ -243,40 +185,87 @@ func CalcMatchedTxFee(txData *types.TxData, maxFeeRate float64) (map[bc.AssetID]
                        return nil, err
                }
 
-               oppositeAmount := assetInputMap[contractArgs.RequestedAsset]
+               oppositeAmount := uint64(assetFeeMap[contractArgs.RequestedAsset].FeeAmount)
                receiveAmount := vprMath.MinUint64(calcRequestAmount(input.Amount(), contractArgs), oppositeAmount)
-               assetFeeMap[input.AssetID()].MaxFeeAmount = CalcMaxFeeAmount(CalcShouldPayAmount(receiveAmount, contractArgs), maxFeeRate)
+               assetFeeMap[input.AssetID()].MaxFeeAmount = calcMaxFeeAmount(calcShouldPayAmount(receiveAmount, contractArgs), maxFeeRate)
        }
 
        for _, output := range txData.Outputs {
-               // minus the amount of the re-order
-               if segwit.IsP2WMCScript(output.ControlProgram()) {
-                       assetFeeMap[*output.AssetAmount().AssetId].FeeAmount -= int64(output.AssetAmount().Amount)
-               }
-               // minus the amount of seller's receiving output
-               if _, ok := sellerProgramMap[hex.EncodeToString(output.ControlProgram())]; ok {
-                       assetID := *output.AssetAmount().AssetId
-                       fee, ok := assetFeeMap[assetID]
-                       if !ok {
-                               continue
-                       }
-                       fee.FeeAmount -= int64(output.AssetAmount().Amount)
-                       if fee.FeeAmount == 0 {
-                               delete(assetFeeMap, assetID)
+               assetAmount := output.AssetAmount()
+               if _, ok := dealProgMaps[hex.EncodeToString(output.ControlProgram())]; ok || segwit.IsP2WMCScript(output.ControlProgram()) {
+                       assetFeeMap[*assetAmount.AssetId].FeeAmount -= int64(assetAmount.Amount)
+                       if assetFeeMap[*assetAmount.AssetId].FeeAmount <= 0 {
+                               delete(assetFeeMap, *assetAmount.AssetId)
                        }
                }
        }
        return assetFeeMap, nil
 }
 
+func addMatchTxOutput(txData *types.TxData, txInput *types.TxInput, order *common.Order, oppositeAmount uint64) error {
+       contractArgs, err := segwit.DecodeP2WMCProgram(order.Utxo.ControlProgram)
+       if err != nil {
+               return err
+       }
+
+       requestAmount := calcRequestAmount(order.Utxo.Amount, contractArgs)
+       receiveAmount := vprMath.MinUint64(requestAmount, oppositeAmount)
+       shouldPayAmount := calcShouldPayAmount(receiveAmount, contractArgs)
+       isPartialTrade := requestAmount > receiveAmount
+
+       setMatchTxArguments(txInput, isPartialTrade, len(txData.Outputs), receiveAmount)
+       txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.ToAssetID, receiveAmount, contractArgs.SellerProgram))
+       if isPartialTrade {
+               txData.Outputs = append(txData.Outputs, types.NewIntraChainOutput(*order.FromAssetID, order.Utxo.Amount-shouldPayAmount, order.Utxo.ControlProgram))
+       }
+       return nil
+}
+
 func calcRequestAmount(fromAmount uint64, contractArg *vmutil.MagneticContractArgs) uint64 {
        return uint64(int64(fromAmount) * contractArg.RatioNumerator / contractArg.RatioDenominator)
 }
 
-func CalcShouldPayAmount(receiveAmount uint64, contractArg *vmutil.MagneticContractArgs) uint64 {
+func calcShouldPayAmount(receiveAmount uint64, contractArg *vmutil.MagneticContractArgs) uint64 {
        return uint64(math.Floor(float64(receiveAmount) * float64(contractArg.RatioDenominator) / float64(contractArg.RatioNumerator)))
 }
 
-func CalcMaxFeeAmount(shouldPayAmount uint64, maxFeeRate float64) int64 {
+func calcMaxFeeAmount(shouldPayAmount uint64, maxFeeRate float64) int64 {
        return int64(math.Ceil(float64(shouldPayAmount) * maxFeeRate))
 }
+
+func calcOppositeIndex(size int, selfIdx int) int {
+       return (selfIdx + 1) % size
+}
+
+func isMatched(orders []*common.Order) bool {
+       for i, order := range orders {
+               if opposisteOrder := orders[calcOppositeIndex(len(orders), i)]; 1/order.Rate < opposisteOrder.Rate {
+                       return false
+               }
+       }
+       return true
+}
+
+func setMatchTxArguments(txInput *types.TxInput, isPartialTrade bool, position int, receiveAmounts uint64) {
+       var arguments [][]byte
+       if isPartialTrade {
+               arguments = [][]byte{vm.Int64Bytes(int64(receiveAmounts)), vm.Int64Bytes(int64(position)), vm.Int64Bytes(contract.PartialTradeClauseSelector)}
+       } else {
+               arguments = [][]byte{vm.Int64Bytes(int64(position)), vm.Int64Bytes(contract.FullTradeClauseSelector)}
+       }
+       txInput.SetArguments(arguments)
+}
+
+func validateTradePairs(tradePairs []*common.TradePair) error {
+       if len(tradePairs) < 2 {
+               return errors.New("size of trade pairs at least 2")
+       }
+
+       for i, tradePair := range tradePairs {
+               oppositeTradePair := tradePairs[calcOppositeIndex(len(tradePairs), i)]
+               if *tradePair.ToAssetID != *oppositeTradePair.FromAssetID {
+                       return errors.New("specified trade pairs is invalid")
+               }
+       }
+       return nil
+}
index 52fc53a..05e42fb 100644 (file)
@@ -11,6 +11,9 @@ import (
        "github.com/vapor/protocol/bc/types"
 )
 
+/*
+       Test: validateTradePairs vaild and invaild case for 2, 3 trade pairs
+*/
 func TestGenerateMatchedTxs(t *testing.T) {
        btc2eth := &common.TradePair{FromAssetID: &mock.BTC, ToAssetID: &mock.ETH}
        eth2btc := &common.TradePair{FromAssetID: &mock.ETH, ToAssetID: &mock.BTC}
@@ -108,22 +111,22 @@ func TestCalcMatchedTxFee(t *testing.T) {
                wantMatchedTxFee map[bc.AssetID]*MatchedTxFee
        }{
                {
-                       desc: "fee less than max fee",
-                       maxFeeRate: 0.05,
+                       desc:             "fee less than max fee",
+                       maxFeeRate:       0.05,
                        wantMatchedTxFee: map[bc.AssetID]*MatchedTxFee{mock.ETH: {FeeAmount: 10, MaxFeeAmount: 26}},
-                       tx: &mock.MatchedTxs[1].TxData,
+                       tx:               &mock.MatchedTxs[1].TxData,
                },
                {
-                       desc: "fee refund in tx",
-                       maxFeeRate: 0.05,
+                       desc:             "fee refund in tx",
+                       maxFeeRate:       0.05,
                        wantMatchedTxFee: map[bc.AssetID]*MatchedTxFee{mock.ETH: {FeeAmount: 27, MaxFeeAmount: 27}},
-                       tx: &mock.MatchedTxs[2].TxData,
+                       tx:               &mock.MatchedTxs[2].TxData,
                },
                {
-                       desc: "fee is zero",
-                       maxFeeRate: 0.05,
+                       desc:             "fee is zero",
+                       maxFeeRate:       0.05,
                        wantMatchedTxFee: map[bc.AssetID]*MatchedTxFee{},
-                       tx: &mock.MatchedTxs[0].TxData,
+                       tx:               &mock.MatchedTxs[0].TxData,
                },
        }
 
index b5fb4bb..7e6043f 100644 (file)
@@ -8,8 +8,9 @@ import (
        "github.com/vapor/errors"
 )
 
+// OrderTable is used to handle the mov orders in memory like stack
 type OrderTable struct {
-       movStore       database.MovStore
+       movStore database.MovStore
        // key of tradePair -> []order
        dbOrders map[string][]*common.Order
        // key of tradePair -> iterator
@@ -21,6 +22,7 @@ type OrderTable struct {
        arrivalDelOrders map[string]*common.Order
 }
 
+// NewOrderTable create a new OrderTable object
 func NewOrderTable(movStore database.MovStore, arrivalAddOrders, arrivalDelOrders []*common.Order) *OrderTable {
        return &OrderTable{
                movStore:       movStore,
@@ -32,21 +34,32 @@ func NewOrderTable(movStore database.MovStore, arrivalAddOrders, arrivalDelOrder
        }
 }
 
+// AddOrder add the in memory temp order to order table
+func (o *OrderTable) AddOrder(order *common.Order) error {
+       tradePairKey := order.TradePair().Key()
+       orders := o.arrivalAddOrders[tradePairKey]
+       if len(orders) > 0 && order.Rate > orders[len(orders)-1].Rate {
+               return errors.New("rate of order must less than the min order in order table")
+       }
+
+       o.arrivalAddOrders[tradePairKey] = append(orders, order)
+       return nil
+}
+
+// PeekOrder return the next lowest order of given trade pair
 func (o *OrderTable) PeekOrder(tradePair *common.TradePair) *common.Order {
        if len(o.dbOrders[tradePair.Key()]) == 0 {
                o.extendDBOrders(tradePair)
        }
 
        var nextOrder *common.Order
-
        orders := o.dbOrders[tradePair.Key()]
        if len(orders) != 0 {
-               nextOrder = orders[len(orders) - 1]
+               nextOrder = orders[len(orders)-1]
        }
 
        if nextOrder != nil && o.arrivalDelOrders[nextOrder.Key()] != nil {
                o.dbOrders[tradePair.Key()] = orders[0 : len(orders)-1]
-               delete(o.arrivalDelOrders, nextOrder.Key())
                return o.PeekOrder(tradePair)
        }
 
@@ -57,6 +70,7 @@ func (o *OrderTable) PeekOrder(tradePair *common.TradePair) *common.Order {
        return nextOrder
 }
 
+// PopOrder delete the next lowest order of given trade pair
 func (o *OrderTable) PopOrder(tradePair *common.TradePair) {
        order := o.PeekOrder(tradePair)
        if order == nil {
@@ -64,25 +78,34 @@ func (o *OrderTable) PopOrder(tradePair *common.TradePair) {
        }
 
        orders := o.dbOrders[tradePair.Key()]
-       if len(orders) != 0 && orders[len(orders) - 1].Key() == order.Key() {
+       if len(orders) != 0 && orders[len(orders)-1].Key() == order.Key() {
                o.dbOrders[tradePair.Key()] = orders[0 : len(orders)-1]
        }
 
        arrivalOrders := o.arrivalAddOrders[tradePair.Key()]
-       if len(arrivalOrders) != 0 && arrivalOrders[len(arrivalOrders) - 1].Key() == order.Key() {
+       if len(arrivalOrders) != 0 && arrivalOrders[len(arrivalOrders)-1].Key() == order.Key() {
                o.arrivalAddOrders[tradePair.Key()] = arrivalOrders[0 : len(arrivalOrders)-1]
        }
 }
 
-func (o *OrderTable) AddOrder(order *common.Order) error {
-       tradePair := order.TradePair()
-       orders := o.dbOrders[tradePair.Key()]
-       if len(orders) > 0 && order.Rate > orders[len(orders)-1].Rate {
-               return errors.New("rate of order must less than the min order in order table")
+func arrangeArrivalAddOrders(orders []*common.Order) map[string][]*common.Order {
+       arrivalAddOrderMap := make(map[string][]*common.Order)
+       for _, order := range orders {
+               arrivalAddOrderMap[order.TradePair().Key()] = append(arrivalAddOrderMap[order.TradePair().Key()], order)
        }
 
-       o.dbOrders[tradePair.Key()] = append(orders, order)
-       return nil
+       for _, orders := range arrivalAddOrderMap {
+               sort.Sort(sort.Reverse(common.OrderSlice(orders)))
+       }
+       return arrivalAddOrderMap
+}
+
+func arrangeArrivalDelOrders(orders []*common.Order) map[string]*common.Order {
+       arrivalDelOrderMap := make(map[string]*common.Order)
+       for _, order := range orders {
+               arrivalDelOrderMap[order.Key()] = order
+       }
+       return arrivalDelOrderMap
 }
 
 func (o *OrderTable) extendDBOrders(tradePair *common.TradePair) {
@@ -99,29 +122,8 @@ func (o *OrderTable) extendDBOrders(tradePair *common.TradePair) {
 }
 
 func (o *OrderTable) peekArrivalOrder(tradePair *common.TradePair) *common.Order {
-       arrivalAddOrders := o.arrivalAddOrders[tradePair.Key()]
-       if len(arrivalAddOrders) > 0 {
-               return arrivalAddOrders[len(arrivalAddOrders) -1]
+       if arrivalAddOrders := o.arrivalAddOrders[tradePair.Key()]; len(arrivalAddOrders) > 0 {
+               return arrivalAddOrders[len(arrivalAddOrders)-1]
        }
        return nil
 }
-
-func arrangeArrivalAddOrders(orders []*common.Order) map[string][]*common.Order {
-       arrivalAddOrderMap := make(map[string][]*common.Order)
-       for _, order := range orders {
-               arrivalAddOrderMap[order.TradePair().Key()] = append(arrivalAddOrderMap[order.TradePair().Key()], order)
-       }
-
-       for _, orders := range arrivalAddOrderMap {
-               sort.Sort(sort.Reverse(common.OrderSlice(orders)))
-       }
-       return arrivalAddOrderMap
-}
-
-func arrangeArrivalDelOrders(orders []*common.Order) map[string]*common.Order {
-       arrivalDelOrderMap := make(map[string]*common.Order)
-       for _, order := range orders {
-               arrivalDelOrderMap[order.Key()] = order
-       }
-       return arrivalDelOrderMap
-}
index 1235e5f..9f9ca17 100644 (file)
@@ -90,7 +90,7 @@ func TestOrderTable(t *testing.T) {
                                        mock.Btc2EthOrders[1], mock.Btc2EthOrders[2], mock.Btc2EthOrders[3],
                                }),
                        initArrivalAddOrders: []*common.Order{mock.Btc2EthOrders[0]},
-                       popOrders: []*common.TradePair{btc2eth},
+                       popOrders:            []*common.TradePair{btc2eth},
                        wantPeekedOrders: map[common.TradePair]*common.Order{
                                *btc2eth: mock.Btc2EthOrders[0],
                        },
@@ -129,7 +129,7 @@ func TestOrderTable(t *testing.T) {
                                }),
                        initArrivalAddOrders: []*common.Order{mock.Btc2EthOrders[0], mock.Btc2EthOrders[2]},
                        initArrivalDelOrders: []*common.Order{mock.Btc2EthOrders[3]},
-                       popOrders: []*common.TradePair{btc2eth},
+                       popOrders:            []*common.TradePair{btc2eth},
                        wantPeekedOrders: map[common.TradePair]*common.Order{
                                *btc2eth: mock.Btc2EthOrders[2],
                        },
@@ -154,6 +154,19 @@ func TestOrderTable(t *testing.T) {
                                *btc2eth: nil,
                        },
                },
+               {
+                       desc: "has arrival delete orders, no add order, no pop order, need recursive to peek one order",
+                       initMovStore: mock.NewMovStore(
+                               []*common.TradePair{btc2eth},
+                               []*common.Order{
+                                       mock.Btc2EthOrders[0], mock.Btc2EthOrders[1], mock.Btc2EthOrders[2], mock.Btc2EthOrders[3],
+                               }),
+                       initArrivalAddOrders: []*common.Order{},
+                       initArrivalDelOrders: []*common.Order{mock.Btc2EthOrders[3], mock.Btc2EthOrders[0], mock.Btc2EthOrders[2]},
+                       wantPeekedOrders: map[common.TradePair]*common.Order{
+                               *btc2eth: mock.Btc2EthOrders[1],
+                       },
+               },
        }
 
        for i, c := range cases {