OSDN Git Service

edit code for elegant (#1983)
authorPaladz <yzhu101@uottawa.ca>
Wed, 23 Jun 2021 08:08:15 +0000 (16:08 +0800)
committerGitHub <noreply@github.com>
Wed, 23 Jun 2021 08:08:15 +0000 (16:08 +0800)
Co-authored-by: paladz <colt@ColtdeMacBook-Pro.local>
protocol/vm/bitwise.go
protocol/vm/control.go
protocol/vm/crypto.go
protocol/vm/introspection.go
protocol/vm/numeric.go
protocol/vm/ops.go
protocol/vm/pushdata_test.go
protocol/vm/vmutil/script.go

index 54d05ad..5fdffce 100644 (file)
@@ -3,16 +3,16 @@ package vm
 import "bytes"
 
 func opInvert(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
+
        top, err := vm.top()
        if err != nil {
                return err
        }
-       err = vm.applyCost(int64(len(top)))
-       if err != nil {
+
+       if err = vm.applyCost(int64(len(top))); err != nil {
                return err
        }
        // Could rewrite top in place but maybe it's a shared data
@@ -26,26 +26,29 @@ func opInvert(vm *virtualMachine) error {
 }
 
 func opAnd(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
+
        b, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        a, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        min, max := len(a), len(b)
        if min > max {
                min, max = max, min
        }
-       err = vm.applyCost(int64(min))
-       if err != nil {
+
+       if err = vm.applyCost(int64(min)); err != nil {
                return err
        }
+
        res := make([]byte, 0, min)
        for i := 0; i < min; i++ {
                res = append(res, a[i]&b[i])
@@ -62,26 +65,29 @@ func opXor(vm *virtualMachine) error {
 }
 
 func doOr(vm *virtualMachine, xor bool) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
+
        b, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        a, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        min, max := len(a), len(b)
        if min > max {
                min, max = max, min
        }
-       err = vm.applyCost(int64(max))
-       if err != nil {
+
+       if err = vm.applyCost(int64(max)); err != nil {
                return err
        }
+
        res := make([]byte, 0, max)
        for i := 0; i < max; i++ {
                var aByte, bByte, resByte byte
@@ -111,6 +117,7 @@ func opEqual(vm *virtualMachine) error {
        if err != nil {
                return err
        }
+
        return vm.pushBool(res, true)
 }
 
@@ -119,6 +126,7 @@ func opEqualVerify(vm *virtualMachine) error {
        if err != nil {
                return err
        }
+
        if res {
                return nil
        }
@@ -126,24 +134,26 @@ func opEqualVerify(vm *virtualMachine) error {
 }
 
 func doEqual(vm *virtualMachine) (bool, error) {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return false, err
        }
+
        b, err := vm.pop(true)
        if err != nil {
                return false, err
        }
+
        a, err := vm.pop(true)
        if err != nil {
                return false, err
        }
+
        min, max := len(a), len(b)
        if min > max {
                min, max = max, min
        }
-       err = vm.applyCost(int64(min))
-       if err != nil {
+
+       if err = vm.applyCost(int64(min)); err != nil {
                return false, err
        }
        return bytes.Equal(a, b), nil
index c1989d4..dbf4283 100644 (file)
@@ -5,14 +5,15 @@ import (
 )
 
 func opVerify(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
+
        p, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        if AsBool(p) {
                return nil
        }
@@ -20,10 +21,10 @@ func opVerify(vm *virtualMachine) error {
 }
 
 func opFail(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
+
        return ErrReturn
 }
 
@@ -92,24 +93,25 @@ func opCheckPredicate(vm *virtualMachine) error {
 }
 
 func opJump(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
+
        address := binary.LittleEndian.Uint32(vm.data)
        vm.nextPC = address
        return nil
 }
 
 func opJumpIf(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
+
        p, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        if AsBool(p) {
                address := binary.LittleEndian.Uint32(vm.data)
                vm.nextPC = address
index 2f6c6af..6a7db93 100644 (file)
@@ -24,42 +24,47 @@ func doHash(vm *virtualMachine, hashFactory func() hash.Hash) error {
        if err != nil {
                return err
        }
+
        cost := int64(len(x))
        if cost < 64 {
                cost = 64
        }
-       err = vm.applyCost(cost)
-       if err != nil {
+
+       if err = vm.applyCost(cost); err != nil {
                return err
        }
+
        h := hashFactory()
-       _, err = h.Write(x)
-       if err != nil {
+       if _, err = h.Write(x); err != nil {
                return err
        }
        return vm.pushDataStack(h.Sum(nil), false)
 }
 
 func opCheckSig(vm *virtualMachine) error {
-       err := vm.applyCost(1024)
-       if err != nil {
+       if err := vm.applyCost(1024); err != nil {
                return err
        }
+
        pubkeyBytes, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        msg, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        sig, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        if len(msg) != 32 {
                return ErrBadValue
        }
+
        if len(pubkeyBytes) != ed25519.PublicKeySize {
                return vm.pushBool(false, true)
        }
@@ -99,6 +104,7 @@ func opCheckMultiSig(vm *virtualMachine) error {
        if numSigs < 0 || numSigs > numPubkeys || (numPubkeys > 0 && numSigs == 0) {
                return ErrBadValue
        }
+
        pubkeyByteses := make([][]byte, 0, numPubkeys)
        for i := int64(0); i < numPubkeys; i++ {
                pubkeyBytes, err := vm.pop(true)
@@ -107,13 +113,16 @@ func opCheckMultiSig(vm *virtualMachine) error {
                }
                pubkeyByteses = append(pubkeyByteses, pubkeyBytes)
        }
+
        msg, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        if len(msg) != 32 {
                return ErrBadValue
        }
+
        sigs := make([][]byte, 0, numSigs)
        for i := int64(0); i < numSigs; i++ {
                sig, err := vm.pop(true)
@@ -137,17 +146,19 @@ func opCheckMultiSig(vm *virtualMachine) error {
                }
                pubkeys = pubkeys[1:]
        }
+
        return vm.pushBool(len(sigs) == 0, true)
 }
 
 func opTxSigHash(vm *virtualMachine) error {
-       err := vm.applyCost(256)
-       if err != nil {
+       if err := vm.applyCost(256); err != nil {
                return err
        }
+
        if vm.context.TxSigHash == nil {
                return ErrContext
        }
+
        return vm.pushDataStack(vm.context.TxSigHash(), false)
 }
 
index acab1e6..099655f 100644 (file)
@@ -23,6 +23,7 @@ func opCheckOutput(vm *virtualMachine) error {
        if err != nil {
                return err
        }
+
        amountInt, err := vm.popBigInt(true)
        if err != nil {
                return err
@@ -50,8 +51,7 @@ func opCheckOutput(vm *virtualMachine) error {
 }
 
 func opAsset(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
 
@@ -62,8 +62,7 @@ func opAsset(vm *virtualMachine) error {
 }
 
 func opAmount(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
 
@@ -75,8 +74,7 @@ func opAmount(vm *virtualMachine) error {
 }
 
 func opProgram(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
 
@@ -96,16 +94,14 @@ func opIndex(vm *virtualMachine) error {
 }
 
 func opEntryID(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
        return vm.pushDataStack(vm.context.EntryID, true)
 }
 
 func opOutputID(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
 
@@ -116,8 +112,7 @@ func opOutputID(vm *virtualMachine) error {
 }
 
 func opBlockHeight(vm *virtualMachine) error {
-       err := vm.applyCost(1)
-       if err != nil {
+       if err := vm.applyCost(1); err != nil {
                return err
        }
 
index 5289327..51afdb5 100644 (file)
@@ -2,13 +2,10 @@ package vm
 
 import (
        "github.com/holiman/uint256"
-
-       "github.com/bytom/bytom/math/checked"
 )
 
 func op1Add(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
 
@@ -17,11 +14,7 @@ func op1Add(vm *virtualMachine) error {
                return err
        }
 
-       num, ok := checked.NewUInt256("1")
-       if !ok {
-               return ErrBadValue
-       }
-
+       num := uint256.NewInt().SetUint64(1)
        if num.Add(n, num); num.Sign() < 0 {
                return ErrRange
        }
@@ -30,8 +23,7 @@ func op1Add(vm *virtualMachine) error {
 }
 
 func op1Sub(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
 
@@ -40,11 +32,7 @@ func op1Sub(vm *virtualMachine) error {
                return err
        }
 
-       num, ok := checked.NewUInt256("1")
-       if !ok {
-               return ErrBadValue
-       }
-
+       num := uint256.NewInt().SetUint64(1)
        if num.Sub(n, num); num.Sign() < 0 {
                return ErrRange
        }
@@ -53,8 +41,7 @@ func op1Sub(vm *virtualMachine) error {
 }
 
 func op2Mul(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
 
@@ -63,11 +50,7 @@ func op2Mul(vm *virtualMachine) error {
                return err
        }
 
-       num, ok := checked.NewUInt256("2")
-       if !ok {
-               return ErrBadValue
-       }
-
+       num := uint256.NewInt().SetUint64(2)
        if num.Mul(n, num); num.Sign() < 0 {
                return ErrRange
        }
@@ -90,8 +73,7 @@ func op2Div(vm *virtualMachine) error {
 }
 
 func opNot(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
 
@@ -104,8 +86,7 @@ func opNot(vm *virtualMachine) error {
 }
 
 func op0NotEqual(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
 
@@ -118,8 +99,7 @@ func op0NotEqual(vm *virtualMachine) error {
 }
 
 func opAdd(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
 
@@ -141,8 +121,7 @@ func opAdd(vm *virtualMachine) error {
 }
 
 func opSub(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
 
@@ -164,8 +143,7 @@ func opSub(vm *virtualMachine) error {
 }
 
 func opMul(vm *virtualMachine) error {
-       err := vm.applyCost(8)
-       if err != nil {
+       if err := vm.applyCost(8); err != nil {
                return err
        }
 
@@ -187,8 +165,7 @@ func opMul(vm *virtualMachine) error {
 }
 
 func opDiv(vm *virtualMachine) error {
-       err := vm.applyCost(8)
-       if err != nil {
+       if err := vm.applyCost(8); err != nil {
                return err
        }
 
@@ -210,8 +187,7 @@ func opDiv(vm *virtualMachine) error {
 }
 
 func opMod(vm *virtualMachine) error {
-       err := vm.applyCost(8)
-       if err != nil {
+       if err := vm.applyCost(8); err != nil {
                return err
        }
 
@@ -233,8 +209,7 @@ func opMod(vm *virtualMachine) error {
 }
 
 func opLshift(vm *virtualMachine) error {
-       err := vm.applyCost(8)
-       if err != nil {
+       if err := vm.applyCost(8); err != nil {
                return err
        }
 
@@ -261,8 +236,7 @@ func opLshift(vm *virtualMachine) error {
 }
 
 func opRshift(vm *virtualMachine) error {
-       err := vm.applyCost(8)
-       if err != nil {
+       if err := vm.applyCost(8); err != nil {
                return err
        }
 
@@ -286,14 +260,15 @@ func opRshift(vm *virtualMachine) error {
 }
 
 func opBoolAnd(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
+
        b, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        a, err := vm.pop(true)
        if err != nil {
                return err
@@ -302,14 +277,15 @@ func opBoolAnd(vm *virtualMachine) error {
 }
 
 func opBoolOr(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
+
        b, err := vm.pop(true)
        if err != nil {
                return err
        }
+
        a, err := vm.pop(true)
        if err != nil {
                return err
@@ -331,8 +307,7 @@ func opNumEqual(vm *virtualMachine) error {
 }
 
 func opNumEqualVerify(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
 
@@ -373,18 +348,20 @@ func opGreaterThanOrEqual(vm *virtualMachine) error {
 }
 
 func doNumCompare(vm *virtualMachine, op int) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
+
        y, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
+
        x, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
+
        var res bool
        switch op {
        case cmpLess:
@@ -404,8 +381,7 @@ func doNumCompare(vm *virtualMachine, op int) error {
 }
 
 func opMin(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
 
@@ -426,8 +402,7 @@ func opMin(vm *virtualMachine) error {
 }
 
 func opMax(vm *virtualMachine) error {
-       err := vm.applyCost(2)
-       if err != nil {
+       if err := vm.applyCost(2); err != nil {
                return err
        }
 
@@ -448,10 +423,10 @@ func opMax(vm *virtualMachine) error {
 }
 
 func opWithin(vm *virtualMachine) error {
-       err := vm.applyCost(4)
-       if err != nil {
+       if err := vm.applyCost(4); err != nil {
                return err
        }
+
        max, err := vm.popBigInt(true)
        if err != nil {
                return err
index fd168f1..0b4ba7f 100644 (file)
@@ -124,7 +124,6 @@ const (
        OP_PUSHDATA1 Op = 0x4c
        OP_PUSHDATA2 Op = 0x4d
        OP_PUSHDATA4 Op = 0x4e
-       OP_1NEGATE   Op = 0x4f
        OP_NOP       Op = 0x61
 
        OP_JUMP           Op = 0x63
@@ -318,14 +317,15 @@ var (
 // ParseOp parses the op at position pc in prog, returning the parsed
 // instruction (opcode plus any associated data).
 func ParseOp(prog []byte, pc uint32) (inst Instruction, err error) {
-       if len(prog) > math.MaxInt32 {
-               err = ErrLongProgram
-       }
        l := uint32(len(prog))
+       if l > math.MaxInt32 {
+               return inst, ErrLongProgram
+       }
+
        if pc >= l {
-               err = ErrShortProgram
-               return
+               return inst, ErrShortProgram
        }
+
        opcode := Op(prog[pc])
        inst.Op = opcode
        inst.Len = 1
@@ -333,94 +333,99 @@ func ParseOp(prog []byte, pc uint32) (inst Instruction, err error) {
                inst.Data = []byte{uint8(opcode-OP_1) + 1}
                return
        }
+
        if opcode >= OP_DATA_1 && opcode <= OP_DATA_75 {
                inst.Len += uint32(opcode - OP_DATA_1 + 1)
                end, ok := checked.AddUint32(pc, inst.Len)
                if !ok {
-                       err = errors.WithDetail(checked.ErrOverflow, "data length exceeds max program size")
-                       return
+                       return inst, errors.WithDetail(checked.ErrOverflow, "data length exceeds max program size")
                }
+
                if end > l {
-                       err = ErrShortProgram
-                       return
+                       return inst, ErrShortProgram
                }
+
                inst.Data = prog[pc+1 : end]
                return
        }
+
        if opcode == OP_PUSHDATA1 {
                if pc == l-1 {
-                       err = ErrShortProgram
-                       return
+                       return inst, ErrShortProgram
                }
+
                n := prog[pc+1]
                inst.Len += uint32(n) + 1
                end, ok := checked.AddUint32(pc, inst.Len)
                if !ok {
-                       err = errors.WithDetail(checked.ErrOverflow, "data length exceeds max program size")
+                       return inst, errors.WithDetail(checked.ErrOverflow, "data length exceeds max program size")
                }
+
                if end > l {
-                       err = ErrShortProgram
-                       return
+                       return inst, ErrShortProgram
                }
+
                inst.Data = prog[pc+2 : end]
                return
        }
+
        if opcode == OP_PUSHDATA2 {
                if len(prog) < 3 || pc > l-3 {
-                       err = ErrShortProgram
-                       return
+                       return inst, ErrShortProgram
                }
+
                n := binary.LittleEndian.Uint16(prog[pc+1 : pc+3])
                inst.Len += uint32(n) + 2
                end, ok := checked.AddUint32(pc, inst.Len)
                if !ok {
-                       err = errors.WithDetail(checked.ErrOverflow, "data length exceeds max program size")
-                       return
+                       return inst, errors.WithDetail(checked.ErrOverflow, "data length exceeds max program size")
                }
+
                if end > l {
-                       err = ErrShortProgram
-                       return
+                       return inst, ErrShortProgram
                }
+
                inst.Data = prog[pc+3 : end]
                return
        }
+
        if opcode == OP_PUSHDATA4 {
                if len(prog) < 5 || pc > l-5 {
-                       err = ErrShortProgram
-                       return
+                       return inst, ErrShortProgram
                }
-               inst.Len += 4
 
+               inst.Len += 4
                n := binary.LittleEndian.Uint32(prog[pc+1 : pc+5])
                var ok bool
                inst.Len, ok = checked.AddUint32(inst.Len, n)
                if !ok {
-                       err = errors.WithDetail(checked.ErrOverflow, "data length exceeds max program size")
-                       return
+                       return inst, errors.WithDetail(checked.ErrOverflow, "data length exceeds max program size")
                }
+
                end, ok := checked.AddUint32(pc, inst.Len)
                if !ok {
-                       err = errors.WithDetail(checked.ErrOverflow, "data length exceeds max program size")
-                       return
+                       return inst, errors.WithDetail(checked.ErrOverflow, "data length exceeds max program size")
                }
+
                if end > l {
-                       err = ErrShortProgram
-                       return
+                       return inst, ErrShortProgram
                }
+
                inst.Data = prog[pc+5 : end]
                return
        }
+
        if opcode == OP_JUMP || opcode == OP_JUMPIF {
                inst.Len += 4
                end, ok := checked.AddUint32(pc, inst.Len)
                if !ok {
-                       err = errors.WithDetail(checked.ErrOverflow, "jump target exceeds max program size")
-                       return
+                       return inst, errors.WithDetail(checked.ErrOverflow, "jump target exceeds max program size")
                }
+
                if end > l {
-                       err = ErrShortProgram
-                       return
+                       return inst, ErrShortProgram
                }
+
                inst.Data = prog[pc+1 : end]
                return
        }
@@ -434,6 +439,7 @@ func ParseProgram(prog []byte) ([]Instruction, error) {
                if err != nil {
                        return nil, err
                }
+
                result = append(result, inst)
                var ok bool
                pc, ok = checked.AddUint32(pc, inst.Len)
@@ -475,10 +481,5 @@ func init() {
 
 // IsPushdata judge instruction whether is a pushdata operation(include opFalse operation)
 func (inst *Instruction) IsPushdata() bool {
-       if reflect.ValueOf(ops[inst.Op].fn) == reflect.ValueOf(ops[OP_1].fn) ||
-               reflect.ValueOf(ops[inst.Op].fn) == reflect.ValueOf(ops[OP_0].fn) {
-               return true
-       }
-
-       return false
+       return reflect.ValueOf(ops[inst.Op].fn) == reflect.ValueOf(ops[OP_1].fn) || reflect.ValueOf(ops[inst.Op].fn) == reflect.ValueOf(ops[OP_0].fn)
 }
index 3de2d73..c415f96 100644 (file)
@@ -61,7 +61,7 @@ func TestPushdataOps(t *testing.T) {
                })
        }
 
-       pushops := append(pushdataops, OP_FALSE, OP_1NEGATE, OP_NOP)
+       pushops := append(pushdataops, OP_FALSE, OP_NOP)
        for _, op := range pushops {
                cases = append(cases, testStruct{
                        op: op,
index 6cd77f6..7b90dbc 100644 (file)
@@ -135,9 +135,8 @@ func P2SPMultiSigProgramWithHeight(pubkeys []ed25519.PublicKey, nrequired int, b
                builder.AddOp(vm.OP_BLOCKHEIGHT)
                builder.AddOp(vm.OP_GREATERTHAN)
                builder.AddOp(vm.OP_VERIFY)
-       } else if blockHeight < 0 {
-               return nil, errors.WithDetail(ErrBadValue, "negative blockHeight")
        }
+
        if err := builder.addP2SPMultiSig(pubkeys, nrequired); err != nil {
                return nil, err
        }