OSDN Git Service

Feat(BVM): change op cmp (#1873)
[bytom/bytom.git] / protocol / vm / numeric.go
index 09486d2..392a1b2 100644 (file)
@@ -3,7 +3,9 @@ package vm
 import (
        "math"
 
-       "github.com/bytom/math/checked"
+       "github.com/holiman/uint256"
+
+       "github.com/bytom/bytom/math/checked"
 )
 
 func op1Add(vm *virtualMachine) error {
@@ -11,15 +13,22 @@ func op1Add(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       n, err := vm.popInt64(true)
+
+       n, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       res, ok := checked.AddInt64(n, 1)
+
+       num, ok := checked.NewUInt256("1")
        if !ok {
+               return ErrBadValue
+       }
+
+       if num.Add(n, num); num.Sign() < 0 {
                return ErrRange
        }
-       return vm.pushInt64(res, true)
+
+       return vm.pushBigInt(num, true)
 }
 
 func op1Sub(vm *virtualMachine) error {
@@ -27,15 +36,22 @@ func op1Sub(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       n, err := vm.popInt64(true)
+
+       n, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       res, ok := checked.SubInt64(n, 1)
+
+       num, ok := checked.NewUInt256("1")
        if !ok {
+               return ErrBadValue
+       }
+
+       if num.Sub(n, num); num.Sign() < 0 {
                return ErrRange
        }
-       return vm.pushInt64(res, true)
+
+       return vm.pushBigInt(num, true)
 }
 
 func op2Mul(vm *virtualMachine) error {
@@ -43,15 +59,22 @@ func op2Mul(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       n, err := vm.popInt64(true)
+
+       n, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       res, ok := checked.MulInt64(n, 2)
+
+       num, ok := checked.NewUInt256("2")
        if !ok {
+               return ErrBadValue
+       }
+
+       if num.Mul(n, num); num.Sign() < 0 {
                return ErrRange
        }
-       return vm.pushInt64(res, true)
+
+       return vm.pushBigInt(num, true)
 }
 
 func op2Div(vm *virtualMachine) error {
@@ -59,11 +82,13 @@ func op2Div(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       n, err := vm.popInt64(true)
+
+       n, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       return vm.pushInt64(n>>1, true)
+
+       return vm.pushBigInt(n.Rsh(n, 1), true)
 }
 
 func opNegate(vm *virtualMachine) error {
@@ -105,11 +130,13 @@ func opNot(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       n, err := vm.popInt64(true)
+
+       n, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       return vm.pushBool(n == 0, true)
+
+       return vm.pushBool(n.Cmp(uint256.NewInt()) == 0, true)
 }
 
 func op0NotEqual(vm *virtualMachine) error {
@@ -117,11 +144,13 @@ func op0NotEqual(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       n, err := vm.popInt64(true)
+
+       n, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       return vm.pushBool(n != 0, true)
+
+       return vm.pushBool(n.Cmp(uint256.NewInt()) != 0, true)
 }
 
 func opAdd(vm *virtualMachine) error {
@@ -129,19 +158,22 @@ func opAdd(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       y, err := vm.popInt64(true)
+
+       y, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       x, err := vm.popInt64(true)
+
+       x, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       res, ok := checked.AddInt64(x, y)
-       if !ok {
+
+       if x.Add(x, y); x.Sign() < 0 {
                return ErrRange
        }
-       return vm.pushInt64(res, true)
+
+       return vm.pushBigInt(x, true)
 }
 
 func opSub(vm *virtualMachine) error {
@@ -149,19 +181,22 @@ func opSub(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       y, err := vm.popInt64(true)
+
+       y, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       x, err := vm.popInt64(true)
+
+       x, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       res, ok := checked.SubInt64(x, y)
-       if !ok {
+
+       if x.Sub(x, y); x.Sign() < 0 {
                return ErrRange
        }
-       return vm.pushInt64(res, true)
+
+       return vm.pushBigInt(x, true)
 }
 
 func opMul(vm *virtualMachine) error {
@@ -169,19 +204,22 @@ func opMul(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       y, err := vm.popInt64(true)
+
+       y, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       x, err := vm.popInt64(true)
+
+       x, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       res, ok := checked.MulInt64(x, y)
-       if !ok {
+
+       if overflow := x.MulOverflow(x, y); overflow || x.Sign() < 0 {
                return ErrRange
        }
-       return vm.pushInt64(res, true)
+
+       return vm.pushBigInt(x, true)
 }
 
 func opDiv(vm *virtualMachine) error {
@@ -189,22 +227,22 @@ func opDiv(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       y, err := vm.popInt64(true)
+
+       y, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       x, err := vm.popInt64(true)
+
+       x, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       if y == 0 {
+
+       if y.IsZero() {
                return ErrDivZero
        }
-       res, ok := checked.DivInt64(x, y)
-       if !ok {
-               return ErrRange
-       }
-       return vm.pushInt64(res, true)
+
+       return vm.pushBigInt(x.Div(x, y), true)
 }
 
 func opMod(vm *virtualMachine) error {
@@ -212,30 +250,22 @@ func opMod(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       y, err := vm.popInt64(true)
+
+       y, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       x, err := vm.popInt64(true)
+
+       x, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       if y == 0 {
-               return ErrDivZero
-       }
-
-       res, ok := checked.ModInt64(x, y)
-       if !ok {
-               return ErrRange
-       }
 
-       // Go's modulus operator produces the wrong result for mixed-sign
-       // operands
-       if res != 0 && (x >= 0) != (y >= 0) {
-               res += y
+       if y.IsZero() {
+               return ErrDivZero
        }
 
-       return vm.pushInt64(res, true)
+       return vm.pushBigInt(x.Mod(x, y), true)
 }
 
 func opLshift(vm *virtualMachine) error {
@@ -374,28 +404,28 @@ func doNumCompare(vm *virtualMachine, op int) error {
        if err != nil {
                return err
        }
-       y, err := vm.popInt64(true)
+       y, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       x, err := vm.popInt64(true)
+       x, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
        var res bool
        switch op {
        case cmpLess:
-               res = x < y
+               res = x.Cmp(y) < 0
        case cmpLessEqual:
-               res = x <= y
+               res = x.Cmp(y) <= 0
        case cmpGreater:
-               res = x > y
+               res = x.Cmp(y) > 0
        case cmpGreaterEqual:
-               res = x >= y
+               res = x.Cmp(y) >= 0
        case cmpEqual:
-               res = x == y
+               res = x.Cmp(y) == 0
        case cmpNotEqual:
-               res = x != y
+               res = x.Cmp(y) != 0
        }
        return vm.pushBool(res, true)
 }
@@ -405,18 +435,21 @@ func opMin(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       y, err := vm.popInt64(true)
+
+       y, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       x, err := vm.popInt64(true)
+
+       x, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       if x > y {
-               x = y
+
+       if x.Cmp(y) > 0 {
+               return vm.pushBigInt(y, true)
        }
-       return vm.pushInt64(x, true)
+       return vm.pushBigInt(x, true)
 }
 
 func opMax(vm *virtualMachine) error {
@@ -424,18 +457,21 @@ func opMax(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       y, err := vm.popInt64(true)
+
+       y, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       x, err := vm.popInt64(true)
+
+       x, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       if x < y {
-               x = y
+
+       if x.Cmp(y) < 0 {
+               return vm.pushBigInt(y, true)
        }
-       return vm.pushInt64(x, true)
+       return vm.pushBigInt(x, true)
 }
 
 func opWithin(vm *virtualMachine) error {
@@ -443,17 +479,20 @@ func opWithin(vm *virtualMachine) error {
        if err != nil {
                return err
        }
-       max, err := vm.popInt64(true)
+       max, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       min, err := vm.popInt64(true)
+
+       min, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       x, err := vm.popInt64(true)
+
+       x, err := vm.popBigInt(true)
        if err != nil {
                return err
        }
-       return vm.pushBool(x >= min && x < max, true)
+
+       return vm.pushBool(x.Cmp(min) >= 0 && x.Cmp(max) < 0, true)
 }