OSDN Git Service

fix StrLiteral and BytesLiteral to support constant (#11)
[bytom/equity.git] / compiler / compile.go
index 99fe605..76913a2 100644 (file)
@@ -89,21 +89,17 @@ func Compile(r io.Reader) ([]*Contract, error) {
                                switch s := stmt.(type) {
                                case *lockStatement:
                                        valueInfo := ValueInfo{
-                                               Name:    s.locked.String(),
+                                               Amount:  s.lockedAmount.String(),
+                                               Asset:   s.lockedAsset.String(),
                                                Program: s.program.String(),
                                        }
-                                       if s.locked.String() != contract.Value {
-                                               for _, r := range clause.Reqs {
-                                                       if s.locked.String() == r.Name {
-                                                               valueInfo.Asset = r.assetExpr.String()
-                                                               valueInfo.Amount = r.amountExpr.String()
-                                                               break
-                                                       }
-                                               }
-                                       }
+
                                        clause.Values = append(clause.Values, valueInfo)
                                case *unlockStatement:
-                                       valueInfo := ValueInfo{Name: contract.Value}
+                                       valueInfo := ValueInfo{
+                                               Amount: contract.Value.Amount,
+                                               Asset:  contract.Value.Asset,
+                                       }
                                        clause.Values = append(clause.Values, valueInfo)
                                }
                        }
@@ -182,10 +178,15 @@ func compileContract(contract *Contract, globalEnv *environ) error {
                        return err
                }
        }
-       err = env.add(contract.Value, valueType, roleContractValue)
-       if err != nil {
+
+       // value is spilt with valueAmount and valueAsset
+       if err = env.add(contract.Value.Amount, amountType, roleContractValue); err != nil {
+               return err
+       }
+       if err = env.add(contract.Value.Asset, assetType, roleContractValue); err != nil {
                return err
        }
+
        for _, c := range contract.Clauses {
                err = env.add(c.Name, nilType, roleClause)
                if err != nil {
@@ -193,10 +194,6 @@ func compileContract(contract *Contract, globalEnv *environ) error {
                }
        }
 
-       err = prohibitValueParams(contract)
-       if err != nil {
-               return err
-       }
        err = prohibitSigParams(contract)
        if err != nil {
                return err
@@ -222,9 +219,10 @@ func compileContract(contract *Contract, globalEnv *environ) error {
        }
 
        b := &builder{}
+       sequence := 0 // sequence is used to count the number of ifStatements
 
        if len(contract.Clauses) == 1 {
-               err = compileClause(b, stk, contract, env, contract.Clauses[0])
+               err = compileClause(b, stk, contract, env, contract.Clauses[0], &sequence)
                if err != nil {
                        return err
                }
@@ -268,7 +266,7 @@ func compileContract(contract *Contract, globalEnv *environ) error {
                                stk = b.addDrop(stk)
                        }
 
-                       err = compileClause(b, stk, contract, env, clause)
+                       err = compileClause(b, stk, contract, env, clause, &sequence)
                        if err != nil {
                                return errors.Wrapf(err, "compiling clause \"%s\"", clause.Name)
                        }
@@ -294,7 +292,7 @@ func compileContract(contract *Contract, globalEnv *environ) error {
        return nil
 }
 
-func compileClause(b *builder, contractStk stack, contract *Contract, env *environ, clause *Clause) error {
+func compileClause(b *builder, contractStk stack, contract *Contract, env *environ, clause *Clause, sequence *int) error {
        var err error
 
        // copy env to leave outerEnv unchanged
@@ -305,16 +303,10 @@ func compileClause(b *builder, contractStk stack, contract *Contract, env *envir
                        return err
                }
        }
-       for _, req := range clause.Reqs {
-               err = env.add(req.Name, valueType, roleClauseValue)
-               if err != nil {
-                       return err
-               }
-               req.Asset = req.assetExpr.String()
-               req.Amount = req.amountExpr.String()
-       }
 
-       assignIndexes(clause)
+       if err = assignIndexes(clause); err != nil {
+               return err
+       }
 
        var stk stack
        for _, p := range clause.Params {
@@ -327,105 +319,299 @@ func compileClause(b *builder, contractStk stack, contract *Contract, env *envir
 
        // a count of the number of times each variable is referenced
        counts := make(map[string]int)
-       for _, req := range clause.Reqs {
-               req.assetExpr.countVarRefs(counts)
-               req.amountExpr.countVarRefs(counts)
-       }
        for _, s := range clause.statements {
+               if stmt, ok := s.(*defineStatement); ok && stmt.expr == nil {
+                       continue
+               }
+
                s.countVarRefs(counts)
+               if stmt, ok := s.(*ifStatement); ok {
+                       for _, trueStmt := range stmt.body.trueBody {
+                               trueStmt.countVarRefs(counts)
+                       }
+
+                       for _, falseStmt := range stmt.body.falseBody {
+                               falseStmt.countVarRefs(counts)
+                       }
+               }
        }
 
-       for _, s := range clause.statements {
-               switch stmt := s.(type) {
-               case *verifyStatement:
+       for _, stat := range clause.statements {
+               if stk, err = compileStatement(b, stk, contract, env, clause, counts, stat, sequence); err != nil {
+                       return err
+               }
+       }
+
+       err = requireAllValuesDisposedOnce(contract, clause)
+       if err != nil {
+               return err
+       }
+       err = typeCheckClause(contract, clause, env)
+       if err != nil {
+               return err
+       }
+       err = requireAllParamsUsedInClause(clause.Params, clause)
+       if err != nil {
+               return err
+       }
+
+       return nil
+}
+
+func compileStatement(b *builder, stk stack, contract *Contract, env *environ, clause *Clause, counts map[string]int, stat statement, sequence *int) (stack, error) {
+       var err error
+       switch stmt := stat.(type) {
+       case *ifStatement:
+               // sequence add 1 when the statement is ifStatement
+               *sequence++
+               strSequence := fmt.Sprintf("%d", *sequence)
+
+               // compile condition expression
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.condition)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in check condition of ifStatement in clause \"%s\"", clause.Name)
+               }
+
+               // jump to falseBody when condition is false, while the JUMPIF instruction will be run success when
+               // the value of dataStack is true, therefore add this check
+               conditionExpr := stk.str
+               stk = b.addBoolean(stk, false)
+               stk = b.addEqual(stk, fmt.Sprintf("(%s == false)", conditionExpr)) // stack: [... <condition_result == false>]
+
+               // add label
+               var label string
+               if len(stmt.body.falseBody) != 0 {
+                       label = "else_" + strSequence
+               } else {
+                       label = "endif_" + strSequence
+               }
+               stk = b.addJumpIf(stk, label)
+               b.addJumpTarget(stk, "if_"+strSequence)
+
+               // temporary store stack and counts for falseBody
+               condStk := stk
+               elseCounts := make(map[string]int)
+               for k, v := range counts {
+                       elseCounts[k] = v
+               }
+
+               // compile trueBody statements
+               if len(stmt.body.trueBody) != 0 {
+                       for _, st := range stmt.body.trueBody {
+                               st.countVarRefs(counts)
+                       }
+
+                       // modify value amount because of using only once
+                       if counts[contract.Value.Amount] > 1 {
+                               counts[contract.Value.Amount] = 1
+                       }
+
+                       // modify value asset because of using only once
+                       if counts[contract.Value.Asset] > 1 {
+                               counts[contract.Value.Asset] = 1
+                       }
+
+                       for _, st := range stmt.body.trueBody {
+                               if stk, err = compileStatement(b, stk, contract, env, clause, counts, st, sequence); err != nil {
+                                       return stk, err
+                               }
+                       }
+               }
+
+               // compile falseBody statements
+               if len(stmt.body.falseBody) != 0 {
+                       counts := make(map[string]int)
+                       for k, v := range elseCounts {
+                               counts[k] = v
+                       }
+
+                       for _, st := range stmt.body.falseBody {
+                               st.countVarRefs(counts)
+                       }
+
+                       // modify value amount because of using only once
+                       if counts[contract.Value.Amount] > 1 {
+                               counts[contract.Value.Amount] = 1
+                       }
+
+                       // modify value asset because of using only once
+                       if counts[contract.Value.Asset] > 1 {
+                               counts[contract.Value.Asset] = 1
+                       }
+
+                       stk = condStk
+                       b.addJump(stk, "endif_"+strSequence)
+                       b.addJumpTarget(stk, "else_"+strSequence)
+
+                       for _, st := range stmt.body.falseBody {
+                               if stk, err = compileStatement(b, stk, contract, env, clause, counts, st, sequence); err != nil {
+                                       return stk, err
+                               }
+                       }
+               }
+               b.addJumpTarget(stk, "endif_"+strSequence)
+
+       case *defineStatement:
+               // add environ for define variable
+               if err = env.add(stmt.variable.Name, stmt.variable.Type, roleClauseVariable); err != nil {
+                       return stk, err
+               }
+
+               // check whether the variable is used or not
+               if counts[stmt.variable.Name] == 0 {
+                       return stk, fmt.Errorf("the defined variable \"%s\" is unused in clause \"%s\"", stmt.variable.Name, clause.Name)
+               }
+
+               if stmt.expr != nil {
+                       // variable
                        stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
                        if err != nil {
-                               return errors.Wrapf(err, "in verify statement in clause \"%s\"", clause.Name)
-                       }
-                       stk = b.addVerify(stk)
-
-                       // special-case reporting of certain function calls
-                       if c, ok := stmt.expr.(*callExpr); ok && len(c.args) == 1 {
-                               if b := referencedBuiltin(c.fn); b != nil {
-                                       switch b.name {
-                                       case "below":
-                                               clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
-                                       case "above":
-                                               clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
-                                       }
+                               return stk, errors.Wrapf(err, "in define statement in clause \"%s\"", clause.Name)
+                       }
+
+                       // modify stack name
+                       stk.str = stmt.variable.Name
+               }
+
+       case *assignStatement:
+               // find variable from environ with roleClauseVariable
+               if entry := env.lookup(string(stmt.variable.Name)); entry != nil {
+                       if entry.r != roleClauseVariable {
+                               return stk, fmt.Errorf("the type of variable is not roleClauseVariable in assign statement in clause \"%s\"", clause.Name)
+                       }
+                       stmt.variable.Type = entry.t
+               } else {
+                       return stk, fmt.Errorf("the variable \"%s\" is not defined before the assign statement in clause \"%s\"", stmt.variable.Name, clause.Name)
+               }
+
+               // temporary store the counts of defined variable
+               varCount := counts[stmt.variable.Name]
+
+               // calculate the counts of variable for assign statement
+               tmpCounts := make(map[string]int)
+               stmt.countVarRefs(tmpCounts)
+
+               // modify the map counts of defined variable to 1 and minus the number of defined variable
+               // when the assign expression contains the defined variable
+               if tmpCounts[stmt.variable.Name] > 0 {
+                       counts[stmt.variable.Name] = 1
+                       varCount -= tmpCounts[stmt.variable.Name]
+               } else {
+                       depth := stk.find(stmt.variable.Name)
+                       switch depth {
+                       case 0:
+                               break
+                       case 1:
+                               stk = b.addSwap(stk)
+                       default:
+                               stk = b.addRoll(stk, depth)
+                       }
+                       stk = b.addDrop(stk)
+               }
+
+               // variable
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in define statement in clause \"%s\"", clause.Name)
+               }
+
+               // restore the defined variable counts
+               counts[stmt.variable.Name] = varCount
+
+               // modify stack name
+               stk.str = stmt.variable.Name
+
+       case *verifyStatement:
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.expr)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in verify statement in clause \"%s\"", clause.Name)
+               }
+               stk = b.addVerify(stk)
+
+               // special-case reporting of certain function calls
+               if c, ok := stmt.expr.(*callExpr); ok && len(c.args) == 1 {
+                       if b := referencedBuiltin(c.fn); b != nil {
+                               switch b.name {
+                               case "below":
+                                       clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
+                               case "above":
+                                       clause.BlockHeight = append(clause.BlockHeight, c.args[0].String())
                                }
                        }
+               }
 
-               case *lockStatement:
-                       // index
-                       stk = b.addInt64(stk, stmt.index)
+       case *lockStatement:
+               // index
+               stk = b.addInt64(stk, stmt.index)
 
-                       // TODO: permit more complex expressions for locked,
-                       // like "lock x+y with foo" (?)
+               // TODO: permit more complex expressions for locked,
+               // like "lock x+y with foo" (?)
 
-                       if stmt.locked.String() == contract.Value {
-                               stk = b.addAmount(stk)
-                               stk = b.addAsset(stk)
-                       } else {
-                               var req *ClauseReq
-                               for _, r := range clause.Reqs {
-                                       if stmt.locked.String() == r.Name {
-                                               req = r
-                                               break
-                                       }
+               if stmt.lockedAmount.String() == contract.Value.Amount && stmt.lockedAsset.String() == contract.Value.Asset {
+                       stk = b.addAmount(stk, contract.Value.Amount)
+                       stk = b.addAsset(stk, contract.Value.Asset)
+               } else {
+                       // calculate the counts of variable for lockStatement
+                       lockCounts := make(map[string]int)
+                       stmt.countVarRefs(lockCounts)
+
+                       // amount
+                       switch {
+                       case stmt.lockedAmount.String() == contract.Value.Amount:
+                               stk = b.addAmount(stk, contract.Value.Amount)
+                       case stmt.lockedAmount.String() != contract.Value.Amount && lockCounts[contract.Value.Amount] > 0:
+                               stk = b.addAmount(stk, contract.Value.Amount)
+                               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAmount)
+                               if err != nil {
+                                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
                                }
-                               if req == nil {
-                                       return fmt.Errorf("unknown value \"%s\" in lock statement in clause \"%s\"", stmt.locked, clause.Name)
+                       default:
+                               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAmount)
+                               if err != nil {
+                                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
                                }
+                       }
 
-                               // amount
-                               stk, err = compileExpr(b, stk, contract, clause, env, counts, req.amountExpr)
+                       // asset
+                       switch {
+                       case stmt.lockedAsset.String() == contract.Value.Asset:
+                               stk = b.addAsset(stk, contract.Value.Asset)
+                       case stmt.lockedAsset.String() != contract.Value.Asset && lockCounts[contract.Value.Asset] > 0:
+                               stk = b.addAsset(stk, contract.Value.Asset)
+                               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAsset)
                                if err != nil {
-                                       return errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+                                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
                                }
-
-                               // asset
-                               stk, err = compileExpr(b, stk, contract, clause, env, counts, req.assetExpr)
+                       default:
+                               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.lockedAsset)
                                if err != nil {
-                                       return errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+                                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
                                }
                        }
+               }
 
-                       // version
-                       stk = b.addInt64(stk, 1)
+               // version
+               stk = b.addInt64(stk, 1)
 
-                       // prog
-                       stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.program)
-                       if err != nil {
-                               return errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
-                       }
+               // prog
+               stk, err = compileExpr(b, stk, contract, clause, env, counts, stmt.program)
+               if err != nil {
+                       return stk, errors.Wrapf(err, "in lock statement in clause \"%s\"", clause.Name)
+               }
 
-                       stk = b.addCheckOutput(stk, fmt.Sprintf("checkOutput(%s, %s)", stmt.locked, stmt.program))
-                       stk = b.addVerify(stk)
+               stk = b.addCheckOutput(stk, fmt.Sprintf("checkOutput(%s, %s, %s)",
+                       stmt.lockedAmount.String(), stmt.lockedAsset.String(), stmt.program))
+               stk = b.addVerify(stk)
 
-               case *unlockStatement:
-                       if len(clause.statements) == 1 {
-                               // This is the only statement in the clause, make sure TRUE is
-                               // on the stack.
-                               stk = b.addBoolean(stk, true)
-                       }
+       case *unlockStatement:
+               if len(clause.statements) == 1 {
+                       // This is the only statement in the clause, make sure TRUE is
+                       // on the stack.
+                       stk = b.addBoolean(stk, true)
                }
        }
 
-       err = requireAllValuesDisposedOnce(contract, clause)
-       if err != nil {
-               return err
-       }
-       err = typeCheckClause(contract, clause, env)
-       if err != nil {
-               return err
-       }
-       err = requireAllParamsUsedInClause(clause.Params, clause)
-       if err != nil {
-               return err
-       }
-
-       return nil
+       return stk, nil
 }
 
 func compileExpr(b *builder, stk stack, contract *Contract, clause *Clause, env *environ, counts map[string]int, expr expression) (stack, error) {
@@ -447,12 +633,12 @@ func compileExpr(b *builder, stk stack, contract *Contract, clause *Clause, env
                }
 
                lType := e.left.typ(env)
-               if e.op.left != "" && lType != e.op.left {
+               if e.op.left != "" && !(lType == e.op.left || lType == amountType) {
                        return stk, fmt.Errorf("in \"%s\", left operand has type \"%s\", must be \"%s\"", e, lType, e.op.left)
                }
 
                rType := e.right.typ(env)
-               if e.op.right != "" && rType != e.op.right {
+               if e.op.right != "" && !(rType == e.op.right || rType == amountType) {
                        return stk, fmt.Errorf("in \"%s\", right operand has type \"%s\", must be \"%s\"", e, rType, e.op.right)
                }