OSDN Git Service

the other statements support to operate contract value besides lock/unlock statement...
authoroysheng <33340252+oysheng@users.noreply.github.com>
Wed, 21 Nov 2018 08:48:35 +0000 (16:48 +0800)
committerPaladz <yzhu101@uottawa.ca>
Wed, 21 Nov 2018 08:48:35 +0000 (16:48 +0800)
* the other statements support contract value besides lock/unlock statement

* optimise

* fix counts for reference variables

compiler/checks.go
compiler/compile.go

index e78d8c4..de14e06 100644 (file)
@@ -158,6 +158,29 @@ func referencedBuiltin(expr expression) *builtin {
        return nil
 }
 
+func countsVarRef(stat statement, counts map[string]int) map[string]int {
+       if stmt, ok := stat.(*defineStatement); ok && stmt.expr == nil {
+               return counts
+       }
+
+       if _, ok := stat.(*unlockStatement); ok {
+               return counts
+       }
+
+       stat.countVarRefs(counts)
+       if stmt, ok := stat.(*ifStatement); ok {
+               for _, trueStmt := range stmt.body.trueBody {
+                       counts = countsVarRef(trueStmt, counts)
+               }
+
+               for _, falseStmt := range stmt.body.falseBody {
+                       counts = countsVarRef(falseStmt, counts)
+               }
+       }
+
+       return counts
+}
+
 func assignIndexes(clause *Clause) error {
        var nextIndex int64
        for i, stmt := range clause.statements {
index ddaf496..4b181b5 100644 (file)
@@ -319,21 +319,8 @@ func compileClause(b *builder, contractStk stack, contract *Contract, env *envir
 
        // a count of the number of times each variable is referenced
        counts := make(map[string]int)
-       for _, s := range clause.statements {
-               if stmt, ok := s.(*defineStatement); ok && stmt.expr == nil {
-                       continue
-               }
-
-               s.countVarRefs(counts)
-               if stmt, ok := s.(*ifStatement); ok {
-                       for _, trueStmt := range stmt.body.trueBody {
-                               trueStmt.countVarRefs(counts)
-                       }
-
-                       for _, falseStmt := range stmt.body.falseBody {
-                               falseStmt.countVarRefs(counts)
-                       }
-               }
+       for _, stat := range clause.statements {
+               counts = countsVarRef(stat, counts)
        }
 
        for _, stat := range clause.statements {
@@ -362,6 +349,9 @@ func compileStatement(b *builder, stk stack, contract *Contract, env *environ, c
                *sequence++
                strSequence := fmt.Sprintf("%d", *sequence)
 
+               // compile the contract valueAmount and valueAsset for expression
+               stk, counts = compileContractValue(b, stmt.condition, contract.Value, stk, counts)
+
                // compile condition expression
                stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.condition)
                if err != nil {
@@ -441,6 +431,9 @@ func compileStatement(b *builder, stk stack, contract *Contract, env *environ, c
                }
 
                if stmt.expr != nil {
+                       // compile the contract valueAmount and valueAsset for expression
+                       stk, counts = compileContractValue(b, stmt.expr, contract.Value, stk, counts)
+
                        // variable
                        stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
                        if err != nil {
@@ -487,6 +480,9 @@ func compileStatement(b *builder, stk stack, contract *Contract, env *environ, c
                        stk = b.addDrop(stk)
                }
 
+               // compile the contract valueAmount and valueAsset for expression
+               stk, counts = compileContractValue(b, stmt.expr, contract.Value, stk, counts)
+
                // variable
                stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
                if err != nil {
@@ -502,6 +498,9 @@ func compileStatement(b *builder, stk stack, contract *Contract, env *environ, c
                stk.str = stmt.variable.Name
 
        case *verifyStatement:
+               // compile the contract valueAmount and valueAsset for expression
+               stk, counts = compileContractValue(b, stmt.expr, contract.Value, stk, counts)
+
                stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
                if err != nil {
                        return stk, errors.Wrapf(err, "in verify statement in clause \"%s\"", clause.Name)
@@ -852,6 +851,22 @@ func compileArg(b *builder, stk stack, contract *Contract, clause *Clause, env *
        return stk, 1, err
 }
 
+func compileContractValue(b *builder, expr expression, contractValue ValueInfo, stk stack, counts map[string]int) (stack, map[string]int) {
+       valueCounts := make(map[string]int)
+       expr.countVarRefs(valueCounts)
+       if valueCounts[contractValue.Amount] > 0 {
+               counts[contractValue.Amount] = valueCounts[contractValue.Amount]
+               stk = b.addAmount(stk, contractValue.Amount)
+       }
+
+       if valueCounts[contractValue.Asset] > 0 {
+               counts[contractValue.Asset] = valueCounts[contractValue.Asset]
+               stk = b.addAsset(stk, contractValue.Asset)
+       }
+
+       return stk, counts
+}
+
 func compileRef(b *builder, stk stack, counts map[string]int, ref varRef) (stack, error) {
        depth := stk.find(string(ref))
        if depth < 0 {