OSDN Git Service

feat: add cross-chain output (#56)
[bytom/vapor.git] / protocol / validation / tx.go
index 8d36ea5..c12c04f 100644 (file)
@@ -210,7 +210,14 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                        }
                }
 
-       case *bc.Output:
+       case *bc.IntraChainOutput:
+               vs2 := *vs
+               vs2.sourcePos = 0
+               if err = checkValidSrc(&vs2, e.Source); err != nil {
+                       return errors.Wrap(err, "checking output source")
+               }
+
+       case *bc.CrossChainOutput:
                vs2 := *vs
                vs2.sourcePos = 0
                if err = checkValidSrc(&vs2, e.Source); err != nil {
@@ -248,7 +255,7 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                if e.SpentOutputId == nil {
                        return errors.Wrap(ErrMissingField, "spend without spent output ID")
                }
-               spentOutput, err := vs.tx.Output(*e.SpentOutputId)
+               spentOutput, err := vs.tx.IntraChainOutput(*e.SpentOutputId)
                if err != nil {
                        return errors.Wrap(err, "getting spend prevout")
                }
@@ -398,7 +405,13 @@ func checkValidDest(vs *validationState, vd *bc.ValueDestination) error {
 
        var src *bc.ValueSource
        switch ref := e.(type) {
-       case *bc.Output:
+       case *bc.IntraChainOutput:
+               if vd.Position != 0 {
+                       return errors.Wrapf(ErrPosition, "invalid position %d for output destination", vd.Position)
+               }
+               src = ref.Source
+
+       case *bc.CrossChainOutput:
                if vd.Position != 0 {
                        return errors.Wrapf(ErrPosition, "invalid position %d for output destination", vd.Position)
                }
@@ -417,7 +430,7 @@ func checkValidDest(vs *validationState, vd *bc.ValueDestination) error {
                src = ref.Sources[vd.Position]
 
        default:
-               return errors.Wrapf(bc.ErrEntryType, "value destination is %T, should be output, retirement, or mux", e)
+               return errors.Wrapf(bc.ErrEntryType, "value destination is %T, should be intra-chain/cross-chain output, retirement, or mux", e)
        }
 
        if src.Ref == nil || *src.Ref != vs.entryID {
@@ -451,12 +464,13 @@ func checkStandardTx(tx *bc.Tx, blockHeight uint64) error {
                if err != nil {
                        continue
                }
-               spentOutput, err := tx.Output(*spend.SpentOutputId)
+
+               intraChainSpentOutput, err := tx.IntraChainOutput(*spend.SpentOutputId)
                if err != nil {
                        return err
                }
 
-               if !segwit.IsP2WScript(spentOutput.ControlProgram.Code) {
+               if !segwit.IsP2WScript(intraChainSpentOutput.ControlProgram.Code) {
                        return ErrNotStandardTx
                }
        }
@@ -467,15 +481,29 @@ func checkStandardTx(tx *bc.Tx, blockHeight uint64) error {
                        return errors.Wrapf(bc.ErrMissingEntry, "id %x", id.Bytes())
                }
 
-               output, ok := e.(*bc.Output)
-               if !ok || *output.Source.Value.AssetId != *consensus.BTMAssetID {
+               var prog []byte
+               switch e := e.(type) {
+               case *bc.IntraChainOutput:
+                       if *e.Source.Value.AssetId != *consensus.BTMAssetID {
+                               continue
+                       }
+                       prog = e.ControlProgram.Code
+
+               case *bc.CrossChainOutput:
+                       if *e.Source.Value.AssetId != *consensus.BTMAssetID {
+                               continue
+                       }
+                       prog = e.ControlProgram.Code
+
+               default:
                        continue
                }
 
-               if !segwit.IsP2WScript(output.ControlProgram.Code) {
+               if !segwit.IsP2WScript(prog) {
                        return ErrNotStandardTx
                }
        }
+
        return nil
 }
 
@@ -487,6 +515,7 @@ func checkTimeRange(tx *bc.Tx, block *bc.Block) error {
        if tx.TimeRange < block.Height {
                return ErrBadTimeRange
        }
+
        return nil
 }