OSDN Git Service

fix StrLiteral and BytesLiteral to support constant (#11)
[bytom/equity.git] / compiler / compile.go
index b543650..76913a2 100644 (file)
@@ -5,7 +5,6 @@ import (
        "fmt"
        "io"
        "io/ioutil"
-       "strings"
 
        chainjson "github.com/bytom/encoding/json"
        "github.com/bytom/errors"
@@ -220,9 +219,10 @@ func compileContract(contract *Contract, globalEnv *environ) error {
        }
 
        b := &builder{}
+       sequence := 0 // sequence is used to count the number of ifStatements
 
        if len(contract.Clauses) == 1 {
-               err = compileClause(b, stk, contract, env, contract.Clauses[0])
+               err = compileClause(b, stk, contract, env, contract.Clauses[0], &sequence)
                if err != nil {
                        return err
                }
@@ -266,7 +266,7 @@ func compileContract(contract *Contract, globalEnv *environ) error {
                                stk = b.addDrop(stk)
                        }
 
-                       err = compileClause(b, stk, contract, env, clause)
+                       err = compileClause(b, stk, contract, env, clause, &sequence)
                        if err != nil {
                                return errors.Wrapf(err, "compiling clause \"%s\"", clause.Name)
                        }
@@ -292,7 +292,7 @@ func compileContract(contract *Contract, globalEnv *environ) error {
        return nil
 }
 
-func compileClause(b *builder, contractStk stack, contract *Contract, env *environ, clause *Clause) error {
+func compileClause(b *builder, contractStk stack, contract *Contract, env *environ, clause *Clause, sequence *int) error {
        var err error
 
        // copy env to leave outerEnv unchanged
@@ -303,7 +303,10 @@ func compileClause(b *builder, contractStk stack, contract *Contract, env *envir
                        return err
                }
        }
-       assignIndexes(clause)
+
+       if err = assignIndexes(clause); err != nil {
+               return err
+       }
 
        var stk stack
        for _, p := range clause.Params {
@@ -317,98 +320,298 @@ func compileClause(b *builder, contractStk stack, contract *Contract, env *envir
        // a count of the number of times each variable is referenced
        counts := make(map[string]int)
        for _, s := range clause.statements {
+               if stmt, ok := s.(*defineStatement); ok && stmt.expr == nil {
+                       continue
+               }
+
                s.countVarRefs(counts)
+               if stmt, ok := s.(*ifStatement); ok {
+                       for _, trueStmt := range stmt.body.trueBody {
+                               trueStmt.countVarRefs(counts)
+                       }
+
+                       for _, falseStmt := range stmt.body.falseBody {
+                               falseStmt.countVarRefs(counts)
+                       }
+               }
        }
 
-       for _, s := range clause.statements {
-               switch stmt := s.(type) {
-               case *verifyStatement:
-                       stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
-                       if err != nil {
-                               return errors.Wrapf(err, "in verify statement in clause \"%s\"", clause.Name)
-                       }
-                       stk = b.addVerify(stk)
-
-                       // special-case reporting of certain function calls
-                       if c, ok := stmt.expr.(*callExpr); ok && len(c.args) == 1 {
-                               if b := referencedBuiltin(c.fn); b != nil {
-                                       switch b.name {
-                                       case "below":
-                                               clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
-                                       case "above":
-                                               clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
-                                       }
+       for _, stat := range clause.statements {
+               if stk, err = compileStatement(b, stk, contract, env, clause, counts, stat, sequence); err != nil {
+                       return err
+               }
+       }
+
+       err = requireAllValuesDisposedOnce(contract, clause)
+       if err != nil {
+               return err
+       }
+       err = typeCheckClause(contract, clause, env)
+       if err != nil {
+               return err
+       }
+       err = requireAllParamsUsedInClause(clause.Params, clause)
+       if err != nil {
+               return err
+       }
+
+       return nil
+}
+
+func compileStatement(b *builder, stk stack, contract *Contract, env *environ, clause *Clause, counts map[string]int, stat statement, sequence *int) (stack, error) {
+       var err error
+       switch stmt := stat.(type) {
+       case *ifStatement:
+               // sequence add 1 when the statement is ifStatement
+               *sequence++
+               strSequence := fmt.Sprintf("%d", *sequence)
+
+               // compile condition expression
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.condition)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in check condition of ifStatement in clause \"%s\"", clause.Name)
+               }
+
+               // jump to falseBody when condition is false, while the JUMPIF instruction will be run success when
+               // the value of dataStack is true, therefore add this check
+               conditionExpr := stk.str
+               stk = b.addBoolean(stk, false)
+               stk = b.addEqual(stk, fmt.Sprintf("(%s == false)", conditionExpr)) // stack: [... <condition_result == false>]
+
+               // add label
+               var label string
+               if len(stmt.body.falseBody) != 0 {
+                       label = "else_" + strSequence
+               } else {
+                       label = "endif_" + strSequence
+               }
+               stk = b.addJumpIf(stk, label)
+               b.addJumpTarget(stk, "if_"+strSequence)
+
+               // temporary store stack and counts for falseBody
+               condStk := stk
+               elseCounts := make(map[string]int)
+               for k, v := range counts {
+                       elseCounts[k] = v
+               }
+
+               // compile trueBody statements
+               if len(stmt.body.trueBody) != 0 {
+                       for _, st := range stmt.body.trueBody {
+                               st.countVarRefs(counts)
+                       }
+
+                       // modify value amount because of using only once
+                       if counts[contract.Value.Amount] > 1 {
+                               counts[contract.Value.Amount] = 1
+                       }
+
+                       // modify value asset because of using only once
+                       if counts[contract.Value.Asset] > 1 {
+                               counts[contract.Value.Asset] = 1
+                       }
+
+                       for _, st := range stmt.body.trueBody {
+                               if stk, err = compileStatement(b, stk, contract, env, clause, counts, st, sequence); err != nil {
+                                       return stk, err
                                }
                        }
+               }
+
+               // compile falseBody statements
+               if len(stmt.body.falseBody) != 0 {
+                       counts := make(map[string]int)
+                       for k, v := range elseCounts {
+                               counts[k] = v
+                       }
 
-               case *lockStatement:
-                       // index
-                       stk = b.addInt64(stk, stmt.index)
+                       for _, st := range stmt.body.falseBody {
+                               st.countVarRefs(counts)
+                       }
 
-                       // TODO: permit more complex expressions for locked,
-                       // like "lock x+y with foo" (?)
+                       // modify value amount because of using only once
+                       if counts[contract.Value.Amount] > 1 {
+                               counts[contract.Value.Amount] = 1
+                       }
 
-                       if stmt.lockedAmount.String() == contract.Value.Amount && stmt.lockedAsset.String() == contract.Value.Asset {
-                               stk = b.addAmount(stk, contract.Value.Amount)
-                               stk = b.addAsset(stk, contract.Value.Asset)
-                       } else {
-                               if strings.Contains(stmt.lockedAmount.String(), contract.Value.Amount) {
-                                       stk = b.addAmount(stk, contract.Value.Amount)
+                       // modify value asset because of using only once
+                       if counts[contract.Value.Asset] > 1 {
+                               counts[contract.Value.Asset] = 1
+                       }
+
+                       stk = condStk
+                       b.addJump(stk, "endif_"+strSequence)
+                       b.addJumpTarget(stk, "else_"+strSequence)
+
+                       for _, st := range stmt.body.falseBody {
+                               if stk, err = compileStatement(b, stk, contract, env, clause, counts, st, sequence); err != nil {
+                                       return stk, err
                                }
+                       }
+               }
+               b.addJumpTarget(stk, "endif_"+strSequence)
+
+       case *defineStatement:
+               // add environ for define variable
+               if err = env.add(stmt.variable.Name, stmt.variable.Type, roleClauseVariable); err != nil {
+                       return stk, err
+               }
 
-                               if strings.Contains(stmt.lockedAsset.String(), contract.Value.Asset) {
-                                       stk = b.addAsset(stk, contract.Value.Asset)
+               // check whether the variable is used or not
+               if counts[stmt.variable.Name] == 0 {
+                       return stk, fmt.Errorf("the defined variable \"%s\" is unused in clause \"%s\"", stmt.variable.Name, clause.Name)
+               }
+
+               if stmt.expr != nil {
+                       // variable
+                       stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
+                       if err != nil {
+                               return stk, errors.Wrapf(err, "in define statement in clause \"%s\"", clause.Name)
+                       }
+
+                       // modify stack name
+                       stk.str = stmt.variable.Name
+               }
+
+       case *assignStatement:
+               // find variable from environ with roleClauseVariable
+               if entry := env.lookup(string(stmt.variable.Name)); entry != nil {
+                       if entry.r != roleClauseVariable {
+                               return stk, fmt.Errorf("the type of variable is not roleClauseVariable in assign statement in clause \"%s\"", clause.Name)
+                       }
+                       stmt.variable.Type = entry.t
+               } else {
+                       return stk, fmt.Errorf("the variable \"%s\" is not defined before the assign statement in clause \"%s\"", stmt.variable.Name, clause.Name)
+               }
+
+               // temporary store the counts of defined variable
+               varCount := counts[stmt.variable.Name]
+
+               // calculate the counts of variable for assign statement
+               tmpCounts := make(map[string]int)
+               stmt.countVarRefs(tmpCounts)
+
+               // modify the map counts of defined variable to 1 and minus the number of defined variable
+               // when the assign expression contains the defined variable
+               if tmpCounts[stmt.variable.Name] > 0 {
+                       counts[stmt.variable.Name] = 1
+                       varCount -= tmpCounts[stmt.variable.Name]
+               } else {
+                       depth := stk.find(stmt.variable.Name)
+                       switch depth {
+                       case 0:
+                               break
+                       case 1:
+                               stk = b.addSwap(stk)
+                       default:
+                               stk = b.addRoll(stk, depth)
+                       }
+                       stk = b.addDrop(stk)
+               }
+
+               // variable
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in define statement in clause \"%s\"", clause.Name)
+               }
+
+               // restore the defined variable counts
+               counts[stmt.variable.Name] = varCount
+
+               // modify stack name
+               stk.str = stmt.variable.Name
+
+       case *verifyStatement:
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in verify statement in clause \"%s\"", clause.Name)
+               }
+               stk = b.addVerify(stk)
+
+               // special-case reporting of certain function calls
+               if c, ok := stmt.expr.(*callExpr); ok && len(c.args) == 1 {
+                       if b := referencedBuiltin(c.fn); b != nil {
+                               switch b.name {
+                               case "below":
+                                       clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
+                               case "above":
+                                       clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
                                }
+                       }
+               }
+
+       case *lockStatement:
+               // index
+               stk = b.addInt64(stk, stmt.index)
+
+               // TODO: permit more complex expressions for locked,
+               // like "lock x+y with foo" (?)
 
-                               // amount
+               if stmt.lockedAmount.String() == contract.Value.Amount && stmt.lockedAsset.String() == contract.Value.Asset {
+                       stk = b.addAmount(stk, contract.Value.Amount)
+                       stk = b.addAsset(stk, contract.Value.Asset)
+               } else {
+                       // calculate the counts of variable for lockStatement
+                       lockCounts := make(map[string]int)
+                       stmt.countVarRefs(lockCounts)
+
+                       // amount
+                       switch {
+                       case stmt.lockedAmount.String() == contract.Value.Amount:
+                               stk = b.addAmount(stk, contract.Value.Amount)
+                       case stmt.lockedAmount.String() != contract.Value.Amount && lockCounts[contract.Value.Amount] > 0:
+                               stk = b.addAmount(stk, contract.Value.Amount)
+                               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAmount)
+                               if err != nil {
+                                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+                               }
+                       default:
                                stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAmount)
                                if err != nil {
-                                       return errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+                                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
                                }
+                       }
 
-                               // asset
+                       // asset
+                       switch {
+                       case stmt.lockedAsset.String() == contract.Value.Asset:
+                               stk = b.addAsset(stk, contract.Value.Asset)
+                       case stmt.lockedAsset.String() != contract.Value.Asset && lockCounts[contract.Value.Asset] > 0:
+                               stk = b.addAsset(stk, contract.Value.Asset)
+                               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAsset)
+                               if err != nil {
+                                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+                               }
+                       default:
                                stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAsset)
                                if err != nil {
-                                       return errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+                                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
                                }
                        }
+               }
 
-                       // version
-                       stk = b.addInt64(stk, 1)
+               // version
+               stk = b.addInt64(stk, 1)
 
-                       // prog
-                       stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.program)
-                       if err != nil {
-                               return errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
-                       }
+               // prog
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.program)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+               }
 
-                       stk = b.addCheckOutput(stk, fmt.Sprintf("checkOutput(%s, %s, %s)",
-                               stmt.lockedAmount.String(), stmt.lockedAsset.String(), stmt.program))
-                       stk = b.addVerify(stk)
+               stk = b.addCheckOutput(stk, fmt.Sprintf("checkOutput(%s, %s, %s)",
+                       stmt.lockedAmount.String(), stmt.lockedAsset.String(), stmt.program))
+               stk = b.addVerify(stk)
 
-               case *unlockStatement:
-                       if len(clause.statements) == 1 {
-                               // This is the only statement in the clause, make sure TRUE is
-                               // on the stack.
-                               stk = b.addBoolean(stk, true)
-                       }
+       case *unlockStatement:
+               if len(clause.statements) == 1 {
+                       // This is the only statement in the clause, make sure TRUE is
+                       // on the stack.
+                       stk = b.addBoolean(stk, true)
                }
        }
 
-       err = requireAllValuesDisposedOnce(contract, clause)
-       if err != nil {
-               return err
-       }
-       err = typeCheckClause(contract, clause, env)
-       if err != nil {
-               return err
-       }
-       err = requireAllParamsUsedInClause(clause.Params, clause)
-       if err != nil {
-               return err
-       }
-
-       return nil
+       return stk, nil
 }
 
 func compileExpr(b *builder, stk stack, contract *Contract, clause *Clause, env *environ, counts map[string]int, expr expression) (stack, error) {