OSDN Git Service

the other statements support to operate contract value besides lock/unlock statement...
[bytom/equity.git] / compiler / checks.go
index bd2c524..de14e06 100644 (file)
@@ -5,15 +5,37 @@ import "fmt"
 func checkRecursive(contract *Contract) bool {
        for _, clause := range contract.Clauses {
                for _, stmt := range clause.statements {
-                       if l, ok := stmt.(*lockStatement); ok {
-                               if c, ok := l.program.(*callExpr); ok {
-                                       if references(c.fn, contract.Name) {
-                                               return true
-                                       }
-                               }
+                       if result := checkStatRecursive(stmt, contract.Name); result {
+                               return true
+                       }
+               }
+       }
+       return false
+}
+
+func checkStatRecursive(stmt statement, contractName string) bool {
+       switch s := stmt.(type) {
+       case *ifStatement:
+               for _, trueStmt := range s.body.trueBody {
+                       if result := checkStatRecursive(trueStmt, contractName); result {
+                               return true
+                       }
+               }
+
+               for _, falseStmt := range s.body.falseBody {
+                       if result := checkStatRecursive(falseStmt, contractName); result {
+                               return true
+                       }
+               }
+
+       case *lockStatement:
+               if c, ok := s.program.(*callExpr); ok {
+                       if references(c.fn, contractName) {
+                               return true
                        }
                }
        }
+
        return false
 }
 
@@ -36,6 +58,7 @@ func requireAllParamsUsedInClauses(params []*Param, clauses []*Clause) error {
                                break
                        }
                }
+
                if !used {
                        return fmt.Errorf("parameter \"%s\" is unused", p.Name)
                }
@@ -47,17 +70,7 @@ func requireAllParamsUsedInClause(params []*Param, clause *Clause) error {
        for _, p := range params {
                used := false
                for _, stmt := range clause.statements {
-                       switch s := stmt.(type) {
-                       case *defineStatement:
-                               used = references(s.expr, p.Name)
-                       case *verifyStatement:
-                               used = references(s.expr, p.Name)
-                       case *lockStatement:
-                               used = references(s.lockedAmount, p.Name) || references(s.lockedAsset, p.Name) || references(s.program, p.Name)
-                       case *unlockStatement:
-                               used = references(s.unlockedAmount, p.Name) || references(s.unlockedAsset, p.Name)
-                       }
-                       if used {
+                       if used = checkParamUsedInStatement(p, stmt); used {
                                break
                        }
                }
@@ -69,6 +82,42 @@ func requireAllParamsUsedInClause(params []*Param, clause *Clause) error {
        return nil
 }
 
+func checkParamUsedInStatement(param *Param, stmt statement) (used bool) {
+       switch s := stmt.(type) {
+       case *ifStatement:
+               if used = references(s.condition, param.Name); used {
+                       return used
+               }
+
+               for _, st := range s.body.trueBody {
+                       if used = checkParamUsedInStatement(param, st); used {
+                               break
+                       }
+               }
+
+               if !used {
+                       for _, st := range s.body.falseBody {
+                               if used = checkParamUsedInStatement(param, st); used {
+                                       break
+                               }
+                       }
+               }
+
+       case *defineStatement:
+               used = references(s.expr, param.Name)
+       case *assignStatement:
+               used = references(s.expr, param.Name)
+       case *verifyStatement:
+               used = references(s.expr, param.Name)
+       case *lockStatement:
+               used = references(s.lockedAmount, param.Name) || references(s.lockedAsset, param.Name) || references(s.program, param.Name)
+       case *unlockStatement:
+               used = references(s.unlockedAmount, param.Name) || references(s.unlockedAsset, param.Name)
+       }
+
+       return used
+}
+
 func references(expr expression, name string) bool {
        switch e := expr.(type) {
        case *binaryExpr:
@@ -98,94 +147,145 @@ func references(expr expression, name string) bool {
        return false
 }
 
-func requireAllValuesDisposedOnce(contract *Contract, clause *Clause) error {
-       err := valueDisposedOnce(contract.Value, clause)
-       if err != nil {
-               return err
+func referencedBuiltin(expr expression) *builtin {
+       if v, ok := expr.(varRef); ok {
+               for _, b := range builtins {
+                       if string(v) == b.name {
+                               return &b
+                       }
+               }
        }
        return nil
 }
 
-func valueDisposedOnce(value ValueInfo, clause *Clause) error {
-       var count int
-       for _, s := range clause.statements {
-               switch stmt := s.(type) {
-               case *unlockStatement:
-                       if references(stmt.unlockedAmount, value.Amount) && references(stmt.unlockedAsset, value.Asset) {
-                               count++
-                       }
-               case *lockStatement:
-                       if references(stmt.lockedAmount, value.Amount) && references(stmt.lockedAsset, value.Asset) {
-                               count++
-                       }
-               }
+func countsVarRef(stat statement, counts map[string]int) map[string]int {
+       if stmt, ok := stat.(*defineStatement); ok && stmt.expr == nil {
+               return counts
+       }
+
+       if _, ok := stat.(*unlockStatement); ok {
+               return counts
        }
-       switch count {
-       case 0:
-               return fmt.Errorf("valueAmount \"%s\" or valueAsset \"%s\" not disposed in clause \"%s\"", value.Amount, value.Asset, clause.Name)
-       case 1:
-               return nil
-       default:
-               return fmt.Errorf("valueAmount \"%s\" or valueAsset \"%s\" disposed multiple times in clause \"%s\"", value.Amount, value.Asset, clause.Name)
+
+       stat.countVarRefs(counts)
+       if stmt, ok := stat.(*ifStatement); ok {
+               for _, trueStmt := range stmt.body.trueBody {
+                       counts = countsVarRef(trueStmt, counts)
+               }
+
+               for _, falseStmt := range stmt.body.falseBody {
+                       counts = countsVarRef(falseStmt, counts)
+               }
        }
+
+       return counts
 }
 
-func referencedBuiltin(expr expression) *builtin {
-       if v, ok := expr.(varRef); ok {
-               for _, b := range builtins {
-                       if string(v) == b.name {
-                               return &b
-                       }
+func assignIndexes(clause *Clause) error {
+       var nextIndex int64
+       for i, stmt := range clause.statements {
+               if nextIndex = assignStatIndexes(stmt, nextIndex, i != len(clause.statements)-1); nextIndex < 0 {
+                       return fmt.Errorf("Not support that the number of lock/unlock statement is not equal between ifbody and elsebody when the if-else is not the last statement in clause \"%s\"", clause.Name)
                }
        }
+
        return nil
 }
 
-func assignIndexes(clause *Clause) {
-       var nextIndex int64
-       for _, s := range clause.statements {
-               switch stmt := s.(type) {
-               case *lockStatement:
-                       stmt.index = nextIndex
-                       nextIndex++
+func assignStatIndexes(stat statement, nextIndex int64, nonFinalFlag bool) int64 {
+       switch stmt := stat.(type) {
+       case *ifStatement:
+               trueIndex := nextIndex
+               falseIndex := nextIndex
+               for _, trueStmt := range stmt.body.trueBody {
+                       trueIndex = assignStatIndexes(trueStmt, trueIndex, nonFinalFlag)
+               }
+
+               for _, falseStmt := range stmt.body.falseBody {
+                       falseIndex = assignStatIndexes(falseStmt, falseIndex, nonFinalFlag)
+               }
 
-               case *unlockStatement:
-                       nextIndex++
+               if trueIndex != falseIndex && nonFinalFlag {
+                       return -1
+               } else if trueIndex == falseIndex {
+                       nextIndex = trueIndex
                }
+
+       case *lockStatement:
+               stmt.index = nextIndex
+               nextIndex++
+
+       case *unlockStatement:
+               nextIndex++
        }
+
+       return nextIndex
 }
 
 func typeCheckClause(contract *Contract, clause *Clause, env *environ) error {
        for _, s := range clause.statements {
-               switch stmt := s.(type) {
-               case *verifyStatement:
-                       if t := stmt.expr.typ(env); t != boolType {
-                               return fmt.Errorf("expression in verify statement in clause \"%s\" has type \"%s\", must be Boolean", clause.Name, t)
-                       }
+               if err := typeCheckStatement(s, contract.Value, clause.Name, env); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
 
-               case *lockStatement:
-                       if t := stmt.lockedAmount.typ(env); !(t == intType || t == amountType) {
-                               return fmt.Errorf("lockedAmount expression \"%s\" in lock statement in clause \"%s\" has type \"%s\", must be Integer", stmt.lockedAmount, clause.Name, t)
-                       }
-                       if t := stmt.lockedAsset.typ(env); t != assetType {
-                               return fmt.Errorf("lockedAsset expression \"%s\" in lock statement in clause \"%s\" has type \"%s\", must be Asset", stmt.lockedAsset, clause.Name, t)
-                       }
-                       if t := stmt.program.typ(env); t != progType {
-                               return fmt.Errorf("program in lock statement in clause \"%s\" has type \"%s\", must be Program", clause.Name, t)
+func typeCheckStatement(stat statement, contractValue ValueInfo, clauseName string, env *environ) error {
+       switch stmt := stat.(type) {
+       case *ifStatement:
+               for _, trueStmt := range stmt.body.trueBody {
+                       if err := typeCheckStatement(trueStmt, contractValue, clauseName, env); err != nil {
+                               return err
                        }
+               }
 
-               case *unlockStatement:
-                       if t := stmt.unlockedAmount.typ(env); !(t == intType || t == amountType) {
-                               return fmt.Errorf("unlockedAmount expression \"%s\" in unlock statement of clause \"%s\" has type \"%s\", must be Integer", stmt.unlockedAmount, clause.Name, t)
-                       }
-                       if t := stmt.unlockedAsset.typ(env); t != assetType {
-                               return fmt.Errorf("unlockedAsset expression \"%s\" in unlock statement of clause \"%s\" has type \"%s\", must be Asset", stmt.unlockedAsset, clause.Name, t)
-                       }
-                       if stmt.unlockedAmount.String() != contract.Value.Amount || stmt.unlockedAsset.String() != contract.Value.Asset {
-                               return fmt.Errorf("amount \"%s\" of asset \"%s\" expression in unlock statement of clause \"%s\" must be the contract valueAmount \"%s\" of valueAsset \"%s\"",
-                                       stmt.unlockedAmount.String(), stmt.unlockedAsset.String(), clause.Name, contract.Value.Amount, contract.Value.Asset)
+               for _, falseStmt := range stmt.body.falseBody {
+                       if err := typeCheckStatement(falseStmt, contractValue, clauseName, env); err != nil {
+                               return err
                        }
                }
+
+       case *defineStatement:
+               if stmt.expr != nil && stmt.expr.typ(env) != stmt.variable.Type && !isHashSubtype(stmt.expr.typ(env)) {
+                       return fmt.Errorf("expression in define statement in clause \"%s\" has type \"%s\", must be \"%s\"",
+                               clauseName, stmt.expr.typ(env), stmt.variable.Type)
+               }
+
+       case *assignStatement:
+               if stmt.expr.typ(env) != stmt.variable.Type && !isHashSubtype(stmt.expr.typ(env)) {
+                       return fmt.Errorf("expression in assign statement in clause \"%s\" has type \"%s\", must be \"%s\"",
+                               clauseName, stmt.expr.typ(env), stmt.variable.Type)
+               }
+
+       case *verifyStatement:
+               if t := stmt.expr.typ(env); t != boolType {
+                       return fmt.Errorf("expression in verify statement in clause \"%s\" has type \"%s\", must be Boolean", clauseName, t)
+               }
+
+       case *lockStatement:
+               if t := stmt.lockedAmount.typ(env); !(t == intType || t == amountType) {
+                       return fmt.Errorf("lockedAmount expression \"%s\" in lock statement in clause \"%s\" has type \"%s\", must be Integer", stmt.lockedAmount, clauseName, t)
+               }
+               if t := stmt.lockedAsset.typ(env); t != assetType {
+                       return fmt.Errorf("lockedAsset expression \"%s\" in lock statement in clause \"%s\" has type \"%s\", must be Asset", stmt.lockedAsset, clauseName, t)
+               }
+               if t := stmt.program.typ(env); t != progType {
+                       return fmt.Errorf("program in lock statement in clause \"%s\" has type \"%s\", must be Program", clauseName, t)
+               }
+
+       case *unlockStatement:
+               if t := stmt.unlockedAmount.typ(env); !(t == intType || t == amountType) {
+                       return fmt.Errorf("unlockedAmount expression \"%s\" in unlock statement of clause \"%s\" has type \"%s\", must be Integer", stmt.unlockedAmount, clauseName, t)
+               }
+               if t := stmt.unlockedAsset.typ(env); t != assetType {
+                       return fmt.Errorf("unlockedAsset expression \"%s\" in unlock statement of clause \"%s\" has type \"%s\", must be Asset", stmt.unlockedAsset, clauseName, t)
+               }
+               if stmt.unlockedAmount.String() != contractValue.Amount || stmt.unlockedAsset.String() != contractValue.Asset {
+                       return fmt.Errorf("amount \"%s\" of asset \"%s\" expression in unlock statement of clause \"%s\" must be the contract valueAmount \"%s\" of valueAsset \"%s\"",
+                               stmt.unlockedAmount.String(), stmt.unlockedAsset.String(), clauseName, contractValue.Amount, contractValue.Asset)
+               }
        }
+
        return nil
 }