OSDN Git Service

add if-else statement (#7)
authoroysheng <33340252+oysheng@users.noreply.github.com>
Mon, 10 Sep 2018 08:09:52 +0000 (16:09 +0800)
committerPaladz <yzhu101@uottawa.ca>
Mon, 10 Sep 2018 08:09:52 +0000 (16:09 +0800)
* add define statement for equity

* add test

* add if-else statement

* handle with stack

* optimise parameter reference check

* optimise else label

* optimise ifbody stack

* optimse compile if else statement

* check math for count

* add unit test

* optimise parse ifstatement

* after endif support add statements

* add sequese for ifstatement

* add test

* rm redundant modify

compiler/ast.go
compiler/builder.go
compiler/checks.go
compiler/compile.go
compiler/compile_test.go
compiler/equitytest/equitytest.go
compiler/parse.go

index 4169e24..b896fcc 100644 (file)
@@ -90,18 +90,13 @@ type HashCall struct {
        ArgType string `json:"arg_type"`
 }
 
-// ClauseReq describes a payment requirement of a clause (one of the
-// things after the "requires" keyword).
-type ClauseReq struct {
-       Name string `json:"name"`
-
-       assetExpr, amountExpr expression
+// IfBody describes a if ... else ... struct
+type IfStatmentBody struct {
+       // if statements body
+       trueBody []statement
 
-       // Asset is the expression describing the required asset.
-       Asset string `json:"asset"`
-
-       // Amount is the expression describing the required amount.
-       Amount string `json:"amount"`
+       // else statements body
+       falseBody []statement
 }
 
 type statement interface {
@@ -117,6 +112,15 @@ func (s defineStatement) countVarRefs(counts map[string]int) {
        s.expr.countVarRefs(counts)
 }
 
+type ifStatement struct {
+       condition expression
+       body      *IfStatmentBody
+}
+
+func (s ifStatement) countVarRefs(counts map[string]int) {
+       s.condition.countVarRefs(counts)
+}
+
 type verifyStatement struct {
        expr expression
 }
index 13b1faf..e8dca4e 100644 (file)
@@ -44,6 +44,10 @@ func (b *builder) addInt64(stk stack, n int64) stack {
        return b.add(s, stk.add(s))
 }
 
+func (b *builder) addEqual(stk stack, desc string) stack {
+       return b.add("EQUAL", stk.dropN(2).add(desc))
+}
+
 func (b *builder) addNumEqual(stk stack, desc string) stack {
        return b.add("NUMEQUAL", stk.dropN(2).add(desc))
 }
index bd2c524..a0c7992 100644 (file)
@@ -36,6 +36,7 @@ func requireAllParamsUsedInClauses(params []*Param, clauses []*Clause) error {
                                break
                        }
                }
+
                if !used {
                        return fmt.Errorf("parameter \"%s\" is unused", p.Name)
                }
@@ -47,17 +48,7 @@ func requireAllParamsUsedInClause(params []*Param, clause *Clause) error {
        for _, p := range params {
                used := false
                for _, stmt := range clause.statements {
-                       switch s := stmt.(type) {
-                       case *defineStatement:
-                               used = references(s.expr, p.Name)
-                       case *verifyStatement:
-                               used = references(s.expr, p.Name)
-                       case *lockStatement:
-                               used = references(s.lockedAmount, p.Name) || references(s.lockedAsset, p.Name) || references(s.program, p.Name)
-                       case *unlockStatement:
-                               used = references(s.unlockedAmount, p.Name) || references(s.unlockedAsset, p.Name)
-                       }
-                       if used {
+                       if used = checkParamUsedInStatement(p, stmt); used {
                                break
                        }
                }
@@ -69,6 +60,40 @@ func requireAllParamsUsedInClause(params []*Param, clause *Clause) error {
        return nil
 }
 
+func checkParamUsedInStatement(param *Param, stmt statement) (used bool) {
+       switch s := stmt.(type) {
+       case *ifStatement:
+               if used = references(s.condition, param.Name); used {
+                       return used
+               }
+
+               for _, st := range s.body.trueBody {
+                       if used = checkParamUsedInStatement(param, st); used {
+                               break
+                       }
+               }
+
+               if !used {
+                       for _, st := range s.body.falseBody {
+                               if used = checkParamUsedInStatement(param, st); used {
+                                       break
+                               }
+                       }
+               }
+
+       case *defineStatement:
+               used = references(s.expr, param.Name)
+       case *verifyStatement:
+               used = references(s.expr, param.Name)
+       case *lockStatement:
+               used = references(s.lockedAmount, param.Name) || references(s.lockedAsset, param.Name) || references(s.program, param.Name)
+       case *unlockStatement:
+               used = references(s.unlockedAmount, param.Name) || references(s.unlockedAsset, param.Name)
+       }
+
+       return used
+}
+
 func references(expr expression, name string) bool {
        switch e := expr.(type) {
        case *binaryExpr:
index 3de0d31..e86399d 100644 (file)
@@ -220,9 +220,10 @@ func compileContract(contract *Contract, globalEnv *environ) error {
        }
 
        b := &builder{}
+       sequence := 0 // sequence is used to count the number of ifStatements
 
        if len(contract.Clauses) == 1 {
-               err = compileClause(b, stk, contract, env, contract.Clauses[0])
+               err = compileClause(b, stk, contract, env, contract.Clauses[0], &sequence)
                if err != nil {
                        return err
                }
@@ -266,7 +267,7 @@ func compileContract(contract *Contract, globalEnv *environ) error {
                                stk = b.addDrop(stk)
                        }
 
-                       err = compileClause(b, stk, contract, env, clause)
+                       err = compileClause(b, stk, contract, env, clause, &sequence)
                        if err != nil {
                                return errors.Wrapf(err, "compiling clause \"%s\"", clause.Name)
                        }
@@ -292,7 +293,7 @@ func compileContract(contract *Contract, globalEnv *environ) error {
        return nil
 }
 
-func compileClause(b *builder, contractStk stack, contract *Contract, env *environ, clause *Clause) error {
+func compileClause(b *builder, contractStk stack, contract *Contract, env *environ, clause *Clause, sequence *int) error {
        var err error
 
        // copy env to leave outerEnv unchanged
@@ -318,118 +319,206 @@ func compileClause(b *builder, contractStk stack, contract *Contract, env *envir
        counts := make(map[string]int)
        for _, s := range clause.statements {
                s.countVarRefs(counts)
+               if stmt, ok := s.(*ifStatement); ok {
+                       for _, trueStmt := range stmt.body.trueBody {
+                               trueStmt.countVarRefs(counts)
+                       }
+
+                       for _, falseStmt := range stmt.body.falseBody {
+                               falseStmt.countVarRefs(counts)
+                       }
+               }
        }
 
-       for _, s := range clause.statements {
-               switch stmt := s.(type) {
-               case *defineStatement:
-                       // variable
-                       stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
-                       if err != nil {
-                               return errors.Wrapf(err, "in define statement in clause \"%s\"", clause.Name)
+       for _, stat := range clause.statements {
+               if stk, err = compileStatement(b, stk, contract, env, clause, counts, stat, sequence); err != nil {
+                       return err
+               }
+       }
+
+       err = requireAllValuesDisposedOnce(contract, clause)
+       if err != nil {
+               return err
+       }
+       err = typeCheckClause(contract, clause, env)
+       if err != nil {
+               return err
+       }
+       err = requireAllParamsUsedInClause(clause.Params, clause)
+       if err != nil {
+               return err
+       }
+
+       return nil
+}
+
+func compileStatement(b *builder, stk stack, contract *Contract, env *environ, clause *Clause, counts map[string]int, stat statement, sequence *int) (stack, error) {
+       var err error
+       switch stmt := stat.(type) {
+       case *ifStatement:
+               // sequence add 1 when the statement is ifStatement
+               *sequence++
+               strSequence := fmt.Sprintf("%d", *sequence)
+
+               // compile condition expression
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.condition)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in check condition of ifStatement in clause \"%s\"", clause.Name)
+               }
+
+               // jump to falseBody when condition is false, while the JUMPIF instruction will be run success when
+               // the value of dataStack is true, therefore add this check
+               conditionExpr := stk.str
+               stk = b.addBoolean(stk, false)
+               stk = b.addEqual(stk, fmt.Sprintf("(%s == false)", conditionExpr)) // stack: [... <condition_result == false>]
+
+               // add label
+               var label string
+               if len(stmt.body.falseBody) != 0 {
+                       label = "else_" + strSequence
+               } else {
+                       label = "endif_" + strSequence
+               }
+               stk = b.addJumpIf(stk, label)
+               b.addJumpTarget(stk, "if_"+strSequence)
+
+               // temporary store stack and counts for falseBody
+               condStk := stk
+               elseCounts := make(map[string]int)
+               for k, v := range counts {
+                       elseCounts[k] = v
+               }
+
+               // compile trueBody statements
+               if len(stmt.body.trueBody) != 0 {
+                       for _, st := range stmt.body.trueBody {
+                               st.countVarRefs(counts)
                        }
 
-                       // check variable type
-                       if stmt.expr.typ(env) != stmt.varName.Type {
-                               return fmt.Errorf("expression in define statement in clause \"%s\" has type \"%s\", must be \"%s\"",
-                                       clause.Name, stmt.expr.typ(env), stmt.varName.Type)
+                       for _, st := range stmt.body.trueBody {
+                               if stk, err = compileStatement(b, stk, contract, env, clause, counts, st, sequence); err != nil {
+                                       return stk, err
+                               }
                        }
+               }
 
-                       // modify stack name
-                       stk.str = stmt.varName.Name
+               // compile falseBody statements
+               if len(stmt.body.falseBody) != 0 {
+                       counts := make(map[string]int)
+                       for k, v := range elseCounts {
+                               counts[k] = v
+                       }
 
-                       // add environ for define variable
-                       if err = env.add(stmt.varName.Name, stmt.varName.Type, roleClauseVariable); err != nil {
-                               return err
+                       for _, st := range stmt.body.falseBody {
+                               st.countVarRefs(counts)
                        }
 
-               case *verifyStatement:
-                       stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
-                       if err != nil {
-                               return errors.Wrapf(err, "in verify statement in clause \"%s\"", clause.Name)
-                       }
-                       stk = b.addVerify(stk)
-
-                       // special-case reporting of certain function calls
-                       if c, ok := stmt.expr.(*callExpr); ok && len(c.args) == 1 {
-                               if b := referencedBuiltin(c.fn); b != nil {
-                                       switch b.name {
-                                       case "below":
-                                               clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
-                                       case "above":
-                                               clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
-                                       }
+                       stk = condStk
+                       b.addJump(stk, "endif_"+strSequence)
+                       b.addJumpTarget(stk, "else_"+strSequence)
+
+                       for _, st := range stmt.body.falseBody {
+                               if stk, err = compileStatement(b, stk, contract, env, clause, counts, st, sequence); err != nil {
+                                       return stk, err
                                }
                        }
+               }
+               b.addJumpTarget(stk, "endif_"+strSequence)
 
-               case *lockStatement:
-                       // index
-                       stk = b.addInt64(stk, stmt.index)
+       case *defineStatement:
+               // variable
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in define statement in clause \"%s\"", clause.Name)
+               }
 
-                       // TODO: permit more complex expressions for locked,
-                       // like "lock x+y with foo" (?)
+               // check variable type
+               if stmt.expr.typ(env) != stmt.varName.Type {
+                       return stk, fmt.Errorf("expression in define statement in clause \"%s\" has type \"%s\", must be \"%s\"",
+                               clause.Name, stmt.expr.typ(env), stmt.varName.Type)
+               }
 
-                       if stmt.lockedAmount.String() == contract.Value.Amount && stmt.lockedAsset.String() == contract.Value.Asset {
-                               stk = b.addAmount(stk, contract.Value.Amount)
-                               stk = b.addAsset(stk, contract.Value.Asset)
-                       } else {
-                               if strings.Contains(stmt.lockedAmount.String(), contract.Value.Amount) {
-                                       stk = b.addAmount(stk, contract.Value.Amount)
-                               }
+               // modify stack name
+               stk.str = stmt.varName.Name
 
-                               if strings.Contains(stmt.lockedAsset.String(), contract.Value.Asset) {
-                                       stk = b.addAsset(stk, contract.Value.Asset)
-                               }
+               // add environ for define variable
+               if err = env.add(stmt.varName.Name, stmt.varName.Type, roleClauseVariable); err != nil {
+                       return stk, err
+               }
 
-                               // amount
-                               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAmount)
-                               if err != nil {
-                                       return errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+       case *verifyStatement:
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in verify statement in clause \"%s\"", clause.Name)
+               }
+               stk = b.addVerify(stk)
+
+               // special-case reporting of certain function calls
+               if c, ok := stmt.expr.(*callExpr); ok && len(c.args) == 1 {
+                       if b := referencedBuiltin(c.fn); b != nil {
+                               switch b.name {
+                               case "below":
+                                       clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
+                               case "above":
+                                       clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
                                }
+                       }
+               }
 
-                               // asset
-                               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAsset)
-                               if err != nil {
-                                       return errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
-                               }
+       case *lockStatement:
+               // index
+               stk = b.addInt64(stk, stmt.index)
+
+               // TODO: permit more complex expressions for locked,
+               // like "lock x+y with foo" (?)
+
+               if stmt.lockedAmount.String() == contract.Value.Amount && stmt.lockedAsset.String() == contract.Value.Asset {
+                       stk = b.addAmount(stk, contract.Value.Amount)
+                       stk = b.addAsset(stk, contract.Value.Asset)
+               } else {
+                       if strings.Contains(stmt.lockedAmount.String(), contract.Value.Amount) {
+                               stk = b.addAmount(stk, contract.Value.Amount)
                        }
 
-                       // version
-                       stk = b.addInt64(stk, 1)
+                       if strings.Contains(stmt.lockedAsset.String(), contract.Value.Asset) {
+                               stk = b.addAsset(stk, contract.Value.Asset)
+                       }
 
-                       // prog
-                       stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.program)
+                       // amount
+                       stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAmount)
                        if err != nil {
-                               return errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+                               return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
                        }
 
-                       stk = b.addCheckOutput(stk, fmt.Sprintf("checkOutput(%s, %s, %s)",
-                               stmt.lockedAmount.String(), stmt.lockedAsset.String(), stmt.program))
-                       stk = b.addVerify(stk)
-
-               case *unlockStatement:
-                       if len(clause.statements) == 1 {
-                               // This is the only statement in the clause, make sure TRUE is
-                               // on the stack.
-                               stk = b.addBoolean(stk, true)
+                       // asset
+                       stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAsset)
+                       if err != nil {
+                               return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
                        }
                }
-       }
 
-       err = requireAllValuesDisposedOnce(contract, clause)
-       if err != nil {
-               return err
-       }
-       err = typeCheckClause(contract, clause, env)
-       if err != nil {
-               return err
-       }
-       err = requireAllParamsUsedInClause(clause.Params, clause)
-       if err != nil {
-               return err
+               // version
+               stk = b.addInt64(stk, 1)
+
+               // prog
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.program)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+               }
+
+               stk = b.addCheckOutput(stk, fmt.Sprintf("checkOutput(%s, %s, %s)",
+                       stmt.lockedAmount.String(), stmt.lockedAsset.String(), stmt.program))
+               stk = b.addVerify(stk)
+
+       case *unlockStatement:
+               if len(clause.statements) == 1 {
+                       // This is the only statement in the clause, make sure TRUE is
+                       // on the stack.
+                       stk = b.addBoolean(stk, true)
+               }
        }
 
-       return nil
+       return stk, nil
 }
 
 func compileExpr(b *builder, stk stack, contract *Contract, clause *Clause, env *environ, counts map[string]int, expr expression) (stack, error) {
index fffbc7a..aea9644 100644 (file)
@@ -61,6 +61,21 @@ func TestCompile(t *testing.T) {
                        equitytest.TestDefineVar,
                        "52797b937b7887916987",
                },
+               {
+                       "TestSigIf",
+                       equitytest.TestSigIf,
+                       "53797b879169765379a00087641c00000052795279a0696321000000765279a069",
+               },
+               {
+                       "TestIfAndMultiClause",
+                       equitytest.TestIfAndMultiClause,
+                       "7b641f0000007087916976547aa00087641a000000765379a06963240000007b7bae7cac",
+               },
+               {
+                       "TestIfRecursive",
+                       equitytest.TestIfRecursive,
+                       "7b644400000054795279879169765579a00087643500000052795479a000876429000000765379a06952795579879169633a000000765479a06953797b8791635c0000007654798791695279a000876459000000527978a0697d8791",
+               },
        }
        for _, c := range cases {
                t.Run(c.name, func(t *testing.T) {
index 98c5f94..b08374a 100644 (file)
@@ -89,3 +89,58 @@ contract TestDefineVar(result: Integer) locks valueAmount of valueAsset {
   }
 }
 `
+
+const TestSigIf = `
+contract TestSigIf(a: Integer, count:Integer) locks valueAmount of valueAsset {
+  clause check(b: Integer, c: Integer) {
+    verify b != count
+    if a > b {
+        verify b > c
+    } else {
+        verify a > c
+    }
+    unlock valueAmount of valueAsset
+  }
+}
+`
+const TestIfAndMultiClause = `
+contract TestIfAndMultiClause(a: Integer, cancelKey: PublicKey) locks valueAmount of valueAsset {
+  clause check(b: Integer, c: Integer) {
+    verify b != c
+    if a > b {
+        verify a > c
+    }
+    unlock valueAmount of valueAsset
+  }
+  clause cancel(sellerSig: Signature) {
+    verify checkTxSig(cancelKey, sellerSig)
+    unlock valueAmount of valueAsset
+  }
+}
+`
+
+const TestIfRecursive = `
+contract TestIfRecursive(a: Integer, count:Integer) locks valueAmount of valueAsset {
+  clause check(b: Integer, c: Integer, d: Integer) {
+    verify b != count
+    if a > b {
+        if d > c {
+           verify a > d
+        }
+        verify d != b
+    } else {
+        verify a > c
+    }
+    verify c != count
+    unlock valueAmount of valueAsset
+  }
+  clause cancel(e: Integer, f: Integer) {
+    verify a != e
+    if a > f {
+      verify e > count
+    }
+    verify f != count
+    unlock valueAmount of valueAsset
+  }
+}
+`
index 3c1a5c0..b773164 100644 (file)
@@ -141,6 +141,8 @@ func parseStatements(p *parser) []statement {
 
 func parseStatement(p *parser) statement {
        switch peekKeyword(p) {
+       case "if":
+               return parseIfStmt(p)
        case "define":
                return parseDefineStmt(p)
        case "verify":
@@ -153,6 +155,22 @@ func parseStatement(p *parser) statement {
        panic(parseErr(p.buf, p.pos, "unknown keyword \"%s\"", peekKeyword(p)))
 }
 
+func parseIfStmt(p *parser) *ifStatement {
+       consumeKeyword(p, "if")
+       condition := parseExpr(p)
+       body := &IfStatmentBody{}
+       consumeTok(p, "{")
+       body.trueBody = parseStatements(p)
+       consumeTok(p, "}")
+       if peekKeyword(p) == "else" {
+               consumeKeyword(p, "else")
+               consumeTok(p, "{")
+               body.falseBody = parseStatements(p)
+               consumeTok(p, "}")
+       }
+       return &ifStatement{condition: condition, body: body}
+}
+
 func parseDefineStmt(p *parser) *defineStatement {
        consumeKeyword(p, "define")
        variableName := consumeIdentifier(p)