OSDN Git Service

edht tx txoutput (#1966)
[bytom/bytom.git] / protocol / validation / tx.go
index face9b5..8759345 100644 (file)
@@ -7,21 +7,18 @@ import (
        "sync"
 
        "github.com/bytom/bytom/consensus"
-       "github.com/bytom/bytom/consensus/segwit"
        "github.com/bytom/bytom/errors"
        "github.com/bytom/bytom/math/checked"
        "github.com/bytom/bytom/protocol/bc"
        "github.com/bytom/bytom/protocol/vm"
 )
 
-const ruleAA = 142500
-
 // validate transaction error
 var (
        ErrTxVersion                 = errors.New("invalid transaction version")
        ErrWrongTransactionSize      = errors.New("invalid transaction size")
        ErrBadTimeRange              = errors.New("invalid transaction time range")
-       ErrEmptyInputIDs             = errors.New("got the empty InputIDs")
+       ErrInputDoubleSend           = errors.New("got the double spend input")
        ErrNotStandardTx             = errors.New("not standard transaction")
        ErrWrongCoinbaseTransaction  = errors.New("wrong coinbase transaction")
        ErrWrongCoinbaseAsset        = errors.New("wrong coinbase assetID")
@@ -38,6 +35,9 @@ var (
        ErrUnbalanced                = errors.New("unbalanced asset amount between input and output")
        ErrOverGasCredit             = errors.New("all gas credit has been spend")
        ErrGasCalculate              = errors.New("gas usage calculate got a math error")
+       ErrVotePubKey                = errors.New("invalid public key of vote")
+       ErrVoteOutputAmount          = errors.New("invalid vote amount")
+       ErrVoteOutputAseet           = errors.New("incorrect asset_id while checking vote asset")
 )
 
 // GasState record the gas usage status
@@ -45,7 +45,6 @@ type GasState struct {
        BTMValue   uint64
        GasLeft    int64
        GasUsed    int64
-       GasValid   bool
        StorageGas int64
 }
 
@@ -81,7 +80,6 @@ func (g *GasState) setGasValid() error {
                return errors.Wrap(ErrGasCalculate, "setGasValid calc gasUsed")
        }
 
-       g.GasValid = true
        return nil
 }
 
@@ -97,22 +95,26 @@ func (g *GasState) updateUsage(gasLeft int64) error {
                return errors.Wrap(ErrGasCalculate, "updateUsage calc gas diff")
        }
 
-       if !g.GasValid && (g.GasUsed > consensus.DefaultGasCredit || g.StorageGas > g.GasLeft) {
+       if g.GasUsed > consensus.DefaultGasCredit || g.StorageGas > g.GasLeft {
                return ErrOverGasCredit
        }
        return nil
 }
 
+// ProgramConverterFunc represent a func convert control program
+type ProgramConverterFunc func(prog []byte) ([]byte, error)
+
 // validationState contains the context that must propagate through
 // the transaction graph when validating entries.
 type validationState struct {
        block     *bc.Block
        tx        *bc.Tx
        gasStatus *GasState
-       entryID   bc.Hash           // The ID of the nearest enclosing entry
-       sourcePos uint64            // The source position, for validate ValueSources
-       destPos   uint64            // The destination position, for validate ValueDestinations
-       cache     map[bc.Hash]error // Memoized per-entry validation results
+       entryID   bc.Hash              // The ID of the nearest enclosing entry
+       sourcePos uint64               // The source position, for validate ValueSources
+       destPos   uint64               // The destination position, for validate ValueDestinations
+       cache     map[bc.Hash]error    // Memoized per-entry validation results
+       converter ProgramConverterFunc // Program converter function
 }
 
 func checkValid(vs *validationState, e bc.Entry) (err error) {
@@ -179,19 +181,6 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                        }
                }
 
-               for _, BTMInputID := range vs.tx.GasInputIDs {
-                       e, ok := vs.tx.Entries[BTMInputID]
-                       if !ok {
-                               return errors.Wrapf(bc.ErrMissingEntry, "entry for bytom input %x not found", BTMInputID)
-                       }
-
-                       vs2 := *vs
-                       vs2.entryID = BTMInputID
-                       if err := checkValid(&vs2, e); err != nil {
-                               return errors.Wrap(err, "checking gas input")
-                       }
-               }
-
                for i, dest := range e.WitnessDestinations {
                        vs2 := *vs
                        vs2.destPos = uint64(i)
@@ -200,10 +189,6 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                        }
                }
 
-               if err := vs.gasStatus.setGasValid(); err != nil {
-                       return err
-               }
-
                for i, src := range e.Sources {
                        vs2 := *vs
                        vs2.sourcePos = uint64(i)
@@ -212,7 +197,11 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                        }
                }
 
-       case *bc.Output:
+               if err := vs.gasStatus.setGasValid(); err != nil {
+                       return err
+               }
+
+       case *bc.OriginalOutput:
                vs2 := *vs
                vs2.sourcePos = 0
                if err = checkValidSrc(&vs2, e.Source); err != nil {
@@ -225,14 +214,31 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                if err = checkValidSrc(&vs2, e.Source); err != nil {
                        return errors.Wrap(err, "checking retirement source")
                }
+       case *bc.VoteOutput:
+               if len(e.Vote) != 64 {
+                       return ErrVotePubKey
+               }
+
+               vs2 := *vs
+               vs2.sourcePos = 0
+               if err = checkValidSrc(&vs2, e.Source); err != nil {
+                       return errors.Wrap(err, "checking vote output source")
+               }
+
+               if e.Source.Value.Amount < consensus.MinVoteOutputAmount {
+                       return ErrVoteOutputAmount
+               }
 
+               if *e.Source.Value.AssetId != *consensus.BTMAssetID {
+                       return ErrVoteOutputAseet
+               }
        case *bc.Issuance:
                computedAssetID := e.WitnessAssetDefinition.ComputeAssetID()
                if computedAssetID != *e.Value.AssetId {
                        return errors.WithDetailf(ErrMismatchedAssetID, "asset ID is %x, issuance wants %x", computedAssetID.Bytes(), e.Value.AssetId.Bytes())
                }
 
-               gasLeft, err := vm.Verify(NewTxVMContext(vs, e, e.WitnessAssetDefinition.IssuanceProgram, e.WitnessArguments), vs.gasStatus.GasLeft)
+               gasLeft, err := vm.Verify(NewTxVMContext(vs, e, e.WitnessAssetDefinition.IssuanceProgram, [][]byte{}, e.WitnessArguments), vs.gasStatus.GasLeft)
                if err != nil {
                        return errors.Wrap(err, "checking issuance program")
                }
@@ -250,12 +256,12 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                if e.SpentOutputId == nil {
                        return errors.Wrap(ErrMissingField, "spend without spent output ID")
                }
-               spentOutput, err := vs.tx.Output(*e.SpentOutputId)
+               spentOutput, err := vs.tx.OriginalOutput(*e.SpentOutputId)
                if err != nil {
                        return errors.Wrap(err, "getting spend prevout")
                }
 
-               gasLeft, err := vm.Verify(NewTxVMContext(vs, e, spentOutput.ControlProgram, e.WitnessArguments), vs.gasStatus.GasLeft)
+               gasLeft, err := vm.Verify(NewTxVMContext(vs, e, spentOutput.ControlProgram, spentOutput.StateData, e.WitnessArguments), vs.gasStatus.GasLeft)
                if err != nil {
                        return errors.Wrap(err, "checking control program")
                }
@@ -283,6 +289,47 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                if err = checkValidDest(&vs2, e.WitnessDestination); err != nil {
                        return errors.Wrap(err, "checking spend destination")
                }
+       case *bc.VetoInput:
+               if e.SpentOutputId == nil {
+                       return errors.Wrap(ErrMissingField, "vetoInput without vetoInput output ID")
+               }
+
+               voteOutput, err := vs.tx.VoteOutput(*e.SpentOutputId)
+               if err != nil {
+                       return errors.Wrap(err, "getting vetoInput prevout")
+               }
+
+               if len(voteOutput.Vote) != 64 {
+                       return ErrVotePubKey
+               }
+
+               gasLeft, err := vm.Verify(NewTxVMContext(vs, e, voteOutput.ControlProgram, voteOutput.StateData, e.WitnessArguments), vs.gasStatus.GasLeft)
+               if err != nil {
+                       return errors.Wrap(err, "checking control program")
+               }
+               if err = vs.gasStatus.updateUsage(gasLeft); err != nil {
+                       return err
+               }
+
+               eq, err := voteOutput.Source.Value.Equal(e.WitnessDestination.Value)
+               if err != nil {
+                       return err
+               }
+               if !eq {
+                       return errors.WithDetailf(
+                               ErrMismatchedValue,
+                               "previous output is for %d unit(s) of %x, vetoInput wants %d unit(s) of %x",
+                               voteOutput.Source.Value.Amount,
+                               voteOutput.Source.Value.AssetId.Bytes(),
+                               e.WitnessDestination.Value.Amount,
+                               e.WitnessDestination.Value.AssetId.Bytes(),
+                       )
+               }
+               vs2 := *vs
+               vs2.destPos = 0
+               if err = checkValidDest(&vs2, e.WitnessDestination); err != nil {
+                       return errors.Wrap(err, "checking vetoInput destination")
+               }
 
        case *bc.Coinbase:
                if vs.block == nil || len(vs.block.Transactions) == 0 || vs.block.Transactions[0] != vs.tx {
@@ -353,6 +400,12 @@ func checkValidSrc(vstate *validationState, vs *bc.ValueSource) error {
                }
                dest = ref.WitnessDestination
 
+       case *bc.VetoInput:
+               if vs.Position != 0 {
+                       return errors.Wrapf(ErrPosition, "invalid position %d for veto-input source", vs.Position)
+               }
+               dest = ref.WitnessDestination
+
        case *bc.Mux:
                if vs.Position >= uint64(len(ref.WitnessDestinations)) {
                        return errors.Wrapf(ErrPosition, "invalid position %d for %d-destination mux source", vs.Position, len(ref.WitnessDestinations))
@@ -400,7 +453,7 @@ func checkValidDest(vs *validationState, vd *bc.ValueDestination) error {
 
        var src *bc.ValueSource
        switch ref := e.(type) {
-       case *bc.Output:
+       case *bc.OriginalOutput:
                if vd.Position != 0 {
                        return errors.Wrapf(ErrPosition, "invalid position %d for output destination", vd.Position)
                }
@@ -412,6 +465,12 @@ func checkValidDest(vs *validationState, vd *bc.ValueDestination) error {
                }
                src = ref.Source
 
+       case *bc.VoteOutput:
+               if vd.Position != 0 {
+                       return errors.Wrapf(ErrPosition, "invalid position %d for output destination", vd.Position)
+               }
+               src = ref.Source
+
        case *bc.Mux:
                if vd.Position >= uint64(len(ref.Sources)) {
                        return errors.Wrapf(ErrPosition, "invalid position %d for %d-source mux destination", vd.Position, len(ref.Sources))
@@ -441,43 +500,16 @@ func checkValidDest(vs *validationState, vd *bc.ValueDestination) error {
        return nil
 }
 
-func checkStandardTx(tx *bc.Tx, blockHeight uint64) error {
+func checkDoubleSpend(tx *bc.Tx) error {
+       usedInputMap := make(map[bc.Hash]bool)
        for _, id := range tx.InputIDs {
-               if blockHeight >= ruleAA && id.IsZero() {
-                       return ErrEmptyInputIDs
+               if _, ok := usedInputMap[id]; ok {
+                       return ErrInputDoubleSend
                }
-       }
 
-       for _, id := range tx.GasInputIDs {
-               spend, err := tx.Spend(id)
-               if err != nil {
-                       continue
-               }
-               spentOutput, err := tx.Output(*spend.SpentOutputId)
-               if err != nil {
-                       return err
-               }
-
-               if !segwit.IsP2WScript(spentOutput.ControlProgram.Code) {
-                       return ErrNotStandardTx
-               }
+               usedInputMap[id] = true
        }
 
-       for _, id := range tx.ResultIds {
-               e, ok := tx.Entries[*id]
-               if !ok {
-                       return errors.Wrapf(bc.ErrMissingEntry, "id %x", id.Bytes())
-               }
-
-               output, ok := e.(*bc.Output)
-               if !ok || *output.Source.Value.AssetId != *consensus.BTMAssetID {
-                       continue
-               }
-
-               if !segwit.IsP2WScript(output.ControlProgram.Code) {
-                       return ErrNotStandardTx
-               }
-       }
        return nil
 }
 
@@ -493,29 +525,37 @@ func checkTimeRange(tx *bc.Tx, block *bc.Block) error {
 }
 
 // ValidateTx validates a transaction.
-func ValidateTx(tx *bc.Tx, block *bc.Block) (*GasState, error) {
-       gasStatus := &GasState{GasValid: false}
+func ValidateTx(tx *bc.Tx, block *bc.Block, converter ProgramConverterFunc) (*GasState, error) {
        if block.Version == 1 && tx.Version != 1 {
-               return gasStatus, errors.WithDetailf(ErrTxVersion, "block version %d, transaction version %d", block.Version, tx.Version)
+               return nil, errors.WithDetailf(ErrTxVersion, "block version %d, transaction version %d", block.Version, tx.Version)
        }
+
        if tx.SerializedSize == 0 {
-               return gasStatus, ErrWrongTransactionSize
+               return nil, ErrWrongTransactionSize
        }
+
        if err := checkTimeRange(tx, block); err != nil {
-               return gasStatus, err
+               return nil, err
        }
-       if err := checkStandardTx(tx, block.Height); err != nil {
-               return gasStatus, err
+
+       if err := checkDoubleSpend(tx); err != nil {
+               return nil, err
        }
 
        vs := &validationState{
                block:     block,
                tx:        tx,
                entryID:   tx.ID,
-               gasStatus: gasStatus,
+               gasStatus: &GasState{},
                cache:     make(map[bc.Hash]error),
+               converter: converter,
        }
-       return vs.gasStatus, checkValid(vs, tx.TxHeader)
+
+       if err := checkValid(vs, tx.TxHeader); err != nil {
+               return nil, err
+       }
+
+       return vs.gasStatus, nil
 }
 
 type validateTxWork struct {
@@ -541,16 +581,16 @@ func (r *ValidateTxResult) GetError() error {
        return r.err
 }
 
-func validateTxWorker(workCh chan *validateTxWork, resultCh chan *ValidateTxResult, wg *sync.WaitGroup) {
+func validateTxWorker(workCh chan *validateTxWork, resultCh chan *ValidateTxResult, wg *sync.WaitGroup, converter ProgramConverterFunc) {
        for work := range workCh {
-               gasStatus, err := ValidateTx(work.tx, work.block)
+               gasStatus, err := ValidateTx(work.tx, work.block, converter)
                resultCh <- &ValidateTxResult{i: work.i, gasStatus: gasStatus, err: err}
        }
        wg.Done()
 }
 
 // ValidateTxs validates txs in async mode
-func ValidateTxs(txs []*bc.Tx, block *bc.Block) []*ValidateTxResult {
+func ValidateTxs(txs []*bc.Tx, block *bc.Block, converter ProgramConverterFunc) []*ValidateTxResult {
        txSize := len(txs)
        validateWorkerNum := runtime.NumCPU()
        //init the goroutine validate worker
@@ -559,7 +599,7 @@ func ValidateTxs(txs []*bc.Tx, block *bc.Block) []*ValidateTxResult {
        resultCh := make(chan *ValidateTxResult, txSize)
        for i := 0; i <= validateWorkerNum && i < txSize; i++ {
                wg.Add(1)
-               go validateTxWorker(workCh, resultCh, &wg)
+               go validateTxWorker(workCh, resultCh, &wg, converter)
        }
 
        //sent the works