OSDN Git Service

optimise equity commandline (#36)
[bytom/equity.git] / compiler / checks.go
1 package compiler
2
3 import (
4         "fmt"
5         "strings"
6 )
7
8 func checkRecursive(contract *Contract) bool {
9         for _, clause := range contract.Clauses {
10                 for _, stmt := range clause.statements {
11                         if result := checkStatRecursive(stmt, contract.Name); result {
12                                 return true
13                         }
14                 }
15         }
16         return false
17 }
18
19 func checkStatRecursive(stmt statement, contractName string) bool {
20         switch s := stmt.(type) {
21         case *ifStatement:
22                 for _, trueStmt := range s.body.trueBody {
23                         if result := checkStatRecursive(trueStmt, contractName); result {
24                                 return true
25                         }
26                 }
27
28                 for _, falseStmt := range s.body.falseBody {
29                         if result := checkStatRecursive(falseStmt, contractName); result {
30                                 return true
31                         }
32                 }
33
34         case *lockStatement:
35                 if c, ok := s.program.(*callExpr); ok {
36                         if references(c.fn, contractName) {
37                                 return true
38                         }
39                 }
40         }
41
42         return false
43 }
44
45 func calClauseValues(contract *Contract, env *environ, stmt statement, condValues *[]CondValueInfo, tempVariables map[string]ExpressionInfo) (valueInfo *ValueInfo) {
46         switch s := stmt.(type) {
47         case *ifStatement:
48                 conditionCounts := make(map[string]int)
49                 s.condition.countVarRefs(conditionCounts)
50                 condExpr := s.condition.String()
51                 params := getParams(env, conditionCounts, &condExpr, tempVariables)
52                 condition := ExpressionInfo{Source: condExpr, Params: params}
53
54                 var trueValues []ValueInfo
55                 for _, trueStmt := range s.body.trueBody {
56                         var trueValue *ValueInfo
57                         trueValue = calClauseValues(contract, env, trueStmt, condValues, tempVariables)
58                         if trueValue != nil {
59                                 trueValues = append(trueValues, *trueValue)
60                         }
61                 }
62
63                 var falseValues []ValueInfo
64                 if len(s.body.falseBody) != 0 {
65                         for _, falseStmt := range s.body.falseBody {
66                                 var falseValue *ValueInfo
67                                 falseValue = calClauseValues(contract, env, falseStmt, condValues, tempVariables)
68                                 if falseValue != nil {
69                                         falseValues = append(falseValues, *falseValue)
70                                 }
71                         }
72                 }
73                 condValue := CondValueInfo{Condition: condition, TrueBodyValues: trueValues, FalseBodyValues: falseValues}
74                 *condValues = append([]CondValueInfo{condValue}, *condValues...)
75
76         case *defineStatement:
77                 if s.expr != nil {
78                         defineCounts := make(map[string]int)
79                         s.expr.countVarRefs(defineCounts)
80                         defineExpr := s.expr.String()
81                         params := getParams(env, defineCounts, &defineExpr, tempVariables)
82                         tempVariables[s.variable.Name] = ExpressionInfo{Source: defineExpr, Params: params}
83                 }
84
85         case *assignStatement:
86                 assignCounts := make(map[string]int)
87                 s.expr.countVarRefs(assignCounts)
88                 assignExpr := s.expr.String()
89                 params := getParams(env, assignCounts, &assignExpr, tempVariables)
90                 tempVariables[s.variable.Name] = ExpressionInfo{Source: assignExpr, Params: params}
91
92         case *lockStatement:
93                 valueInfo = &ValueInfo{Asset: s.lockedAsset.String()}
94                 lockCounts := make(map[string]int)
95                 s.lockedAmount.countVarRefs(lockCounts)
96                 lockedAmountExpr := s.lockedAmount.String()
97                 if _, ok := lockCounts[lockedAmountExpr]; !ok {
98                         valueInfo.AmountParams = getParams(env, lockCounts, &lockedAmountExpr, tempVariables)
99                 } else if _, ok := tempVariables[lockedAmountExpr]; ok {
100                         valueInfo.AmountParams = tempVariables[lockedAmountExpr].Params
101                         lockedAmountExpr = tempVariables[lockedAmountExpr].Source
102                 }
103                 valueInfo.Amount = lockedAmountExpr
104
105                 programExpr := s.program.String()
106                 if res, ok := s.program.(*callExpr); ok {
107                         if bi := referencedBuiltin(res.fn); bi == nil {
108                                 if v, ok := res.fn.(varRef); ok {
109                                         if entry := env.lookup(string(v)); entry != nil && entry.t == contractType {
110                                                 programExpr = fmt.Sprintf("%s(", string(v))
111                                                 for i := 0; i < len(res.args); i++ {
112                                                         argExpr := res.args[i].String()
113                                                         argCounts := make(map[string]int)
114                                                         res.args[i].countVarRefs(argCounts)
115                                                         if _, ok := argCounts[argExpr]; !ok {
116                                                                 params := getParams(env, argCounts, &argExpr, tempVariables)
117                                                                 valueInfo.ContractCalls = append(valueInfo.ContractCalls, CallArgs{Source: argExpr, Position: i, Params: params})
118                                                         } else if _, ok := tempVariables[argExpr]; ok {
119                                                                 valueInfo.ContractCalls = append(valueInfo.ContractCalls, CallArgs{Source: tempVariables[argExpr].Source, Position: i, Params: tempVariables[argExpr].Params})
120                                                                 argExpr = tempVariables[argExpr].Source
121                                                         }
122
123                                                         if i == len(res.args)-1 {
124                                                                 programExpr = fmt.Sprintf("%s%s)", programExpr, argExpr)
125                                                         } else {
126                                                                 programExpr = fmt.Sprintf("%s%s, ", programExpr, argExpr)
127                                                         }
128                                                 }
129                                         }
130                                 }
131                         }
132                 }
133                 valueInfo.Program = programExpr
134
135         case *unlockStatement:
136                 valueInfo = &ValueInfo{
137                         Amount: contract.Value.Amount,
138                         Asset:  contract.Value.Asset,
139                 }
140         }
141
142         return valueInfo
143 }
144
145 func getParams(env *environ, counts map[string]int, expr *string, tempVariables map[string]ExpressionInfo) (params []*Param) {
146         for v := range counts {
147                 if entry := env.lookup(v); entry != nil && (entry.r == roleContractParam || entry.r == roleContractValue || entry.r == roleClauseParam) {
148                         params = append(params, &Param{Name: v, Type: entry.t})
149                 } else if entry.r == roleClauseVariable {
150                         if expr != nil {
151                                 *expr = strings.Replace(*expr, v, tempVariables[v].Source, -1)
152                         }
153
154                         if _, ok := tempVariables[v]; ok {
155                                 for _, param := range tempVariables[v].Params {
156                                         if ok := checkParams(param, params); !ok {
157                                                 params = append(params, &Param{Name: param.Name, Type: param.Type})
158                                         }
159                                 }
160                         }
161                 }
162         }
163         return params
164 }
165
166 func checkParams(param *Param, params []*Param) bool {
167         for _, p := range params {
168                 if p.Name == param.Name {
169                         return true
170                 }
171         }
172         return false
173 }
174
175 func prohibitSigParams(contract *Contract) error {
176         for _, p := range contract.Params {
177                 if p.Type == sigType {
178                         return fmt.Errorf("contract parameter \"%s\" has type Signature, but contract parameters cannot have type Signature", p.Name)
179                 }
180         }
181         return nil
182 }
183
184 func requireAllParamsUsedInClauses(params []*Param, clauses []*Clause) error {
185         for _, p := range params {
186                 used := false
187                 for _, c := range clauses {
188                         err := requireAllParamsUsedInClause([]*Param{p}, c)
189                         if err == nil {
190                                 used = true
191                                 break
192                         }
193                 }
194
195                 if !used {
196                         return fmt.Errorf("parameter \"%s\" is unused", p.Name)
197                 }
198         }
199         return nil
200 }
201
202 func requireAllParamsUsedInClause(params []*Param, clause *Clause) error {
203         for _, p := range params {
204                 used := false
205                 for _, stmt := range clause.statements {
206                         if used = checkParamUsedInStatement(p, stmt); used {
207                                 break
208                         }
209                 }
210
211                 if !used {
212                         return fmt.Errorf("parameter \"%s\" is unused in clause \"%s\"", p.Name, clause.Name)
213                 }
214         }
215         return nil
216 }
217
218 func checkParamUsedInStatement(param *Param, stmt statement) (used bool) {
219         switch s := stmt.(type) {
220         case *ifStatement:
221                 if used = references(s.condition, param.Name); used {
222                         return used
223                 }
224
225                 for _, st := range s.body.trueBody {
226                         if used = checkParamUsedInStatement(param, st); used {
227                                 break
228                         }
229                 }
230
231                 if !used {
232                         for _, st := range s.body.falseBody {
233                                 if used = checkParamUsedInStatement(param, st); used {
234                                         break
235                                 }
236                         }
237                 }
238
239         case *defineStatement:
240                 used = references(s.expr, param.Name)
241         case *assignStatement:
242                 used = references(s.expr, param.Name)
243         case *verifyStatement:
244                 used = references(s.expr, param.Name)
245         case *lockStatement:
246                 used = references(s.lockedAmount, param.Name) || references(s.lockedAsset, param.Name) || references(s.program, param.Name)
247         case *unlockStatement:
248                 used = references(s.unlockedAmount, param.Name) || references(s.unlockedAsset, param.Name)
249         }
250
251         return used
252 }
253
254 func references(expr expression, name string) bool {
255         switch e := expr.(type) {
256         case *binaryExpr:
257                 return references(e.left, name) || references(e.right, name)
258         case *unaryExpr:
259                 return references(e.expr, name)
260         case *callExpr:
261                 if references(e.fn, name) {
262                         return true
263                 }
264                 for _, a := range e.args {
265                         if references(a, name) {
266                                 return true
267                         }
268                 }
269                 return false
270         case varRef:
271                 return string(e) == name
272         case listExpr:
273                 for _, elt := range []expression(e) {
274                         if references(elt, name) {
275                                 return true
276                         }
277                 }
278                 return false
279         }
280         return false
281 }
282
283 func referencedBuiltin(expr expression) *builtin {
284         if v, ok := expr.(varRef); ok {
285                 for _, b := range builtins {
286                         if string(v) == b.name {
287                                 return &b
288                         }
289                 }
290         }
291         return nil
292 }
293
294 func countsVarRef(stat statement, counts map[string]int) map[string]int {
295         if stmt, ok := stat.(*defineStatement); ok && stmt.expr == nil {
296                 return counts
297         }
298
299         if _, ok := stat.(*unlockStatement); ok {
300                 return counts
301         }
302
303         stat.countVarRefs(counts)
304         if stmt, ok := stat.(*ifStatement); ok {
305                 for _, trueStmt := range stmt.body.trueBody {
306                         counts = countsVarRef(trueStmt, counts)
307                 }
308
309                 for _, falseStmt := range stmt.body.falseBody {
310                         counts = countsVarRef(falseStmt, counts)
311                 }
312         }
313
314         return counts
315 }
316
317 func assignIndexes(clause *Clause) error {
318         var nextIndex int64
319         for i, stmt := range clause.statements {
320                 if nextIndex = assignStatIndexes(stmt, nextIndex, i != len(clause.statements)-1); nextIndex < 0 {
321                         return fmt.Errorf("Not support that the number of lock/unlock statement is not equal between ifbody and elsebody when the if-else is not the last statement in clause \"%s\"", clause.Name)
322                 }
323         }
324
325         return nil
326 }
327
328 func assignStatIndexes(stat statement, nextIndex int64, nonFinalFlag bool) int64 {
329         switch stmt := stat.(type) {
330         case *ifStatement:
331                 trueIndex := nextIndex
332                 falseIndex := nextIndex
333                 for _, trueStmt := range stmt.body.trueBody {
334                         trueIndex = assignStatIndexes(trueStmt, trueIndex, nonFinalFlag)
335                 }
336
337                 for _, falseStmt := range stmt.body.falseBody {
338                         falseIndex = assignStatIndexes(falseStmt, falseIndex, nonFinalFlag)
339                 }
340
341                 if trueIndex != falseIndex && nonFinalFlag {
342                         return -1
343                 } else if trueIndex == falseIndex {
344                         nextIndex = trueIndex
345                 }
346
347         case *lockStatement:
348                 stmt.index = nextIndex
349                 nextIndex++
350
351         case *unlockStatement:
352                 nextIndex++
353         }
354
355         return nextIndex
356 }
357
358 func typeCheckClause(contract *Contract, clause *Clause, env *environ) error {
359         for _, s := range clause.statements {
360                 if err := typeCheckStatement(s, contract.Value, clause.Name, env); err != nil {
361                         return err
362                 }
363         }
364         return nil
365 }
366
367 func typeCheckStatement(stat statement, contractValue ValueInfo, clauseName string, env *environ) error {
368         switch stmt := stat.(type) {
369         case *ifStatement:
370                 for _, trueStmt := range stmt.body.trueBody {
371                         if err := typeCheckStatement(trueStmt, contractValue, clauseName, env); err != nil {
372                                 return err
373                         }
374                 }
375
376                 for _, falseStmt := range stmt.body.falseBody {
377                         if err := typeCheckStatement(falseStmt, contractValue, clauseName, env); err != nil {
378                                 return err
379                         }
380                 }
381
382         case *defineStatement:
383                 if stmt.expr != nil && stmt.expr.typ(env) != stmt.variable.Type && !(stmt.variable.Type == hashType && isHashSubtype(stmt.expr.typ(env))) {
384                         return fmt.Errorf("expression in define statement in clause \"%s\" has type \"%s\", must be \"%s\"",
385                                 clauseName, stmt.expr.typ(env), stmt.variable.Type)
386                 }
387
388         case *assignStatement:
389                 if stmt.expr.typ(env) != stmt.variable.Type && !(stmt.variable.Type == hashType && isHashSubtype(stmt.expr.typ(env))) {
390                         return fmt.Errorf("expression in assign statement in clause \"%s\" has type \"%s\", must be \"%s\"",
391                                 clauseName, stmt.expr.typ(env), stmt.variable.Type)
392                 }
393
394         case *verifyStatement:
395                 if t := stmt.expr.typ(env); t != boolType {
396                         return fmt.Errorf("expression in verify statement in clause \"%s\" has type \"%s\", must be Boolean", clauseName, t)
397                 }
398
399         case *lockStatement:
400                 if t := stmt.lockedAmount.typ(env); !(t == intType || t == amountType) {
401                         return fmt.Errorf("lockedAmount expression \"%s\" in lock statement in clause \"%s\" has type \"%s\", must be Integer", stmt.lockedAmount, clauseName, t)
402                 }
403                 if t := stmt.lockedAsset.typ(env); t != assetType {
404                         return fmt.Errorf("lockedAsset expression \"%s\" in lock statement in clause \"%s\" has type \"%s\", must be Asset", stmt.lockedAsset, clauseName, t)
405                 }
406                 if t := stmt.program.typ(env); t != progType {
407                         return fmt.Errorf("program in lock statement in clause \"%s\" has type \"%s\", must be Program", clauseName, t)
408                 }
409
410         case *unlockStatement:
411                 if t := stmt.unlockedAmount.typ(env); !(t == intType || t == amountType) {
412                         return fmt.Errorf("unlockedAmount expression \"%s\" in unlock statement of clause \"%s\" has type \"%s\", must be Integer", stmt.unlockedAmount, clauseName, t)
413                 }
414                 if t := stmt.unlockedAsset.typ(env); t != assetType {
415                         return fmt.Errorf("unlockedAsset expression \"%s\" in unlock statement of clause \"%s\" has type \"%s\", must be Asset", stmt.unlockedAsset, clauseName, t)
416                 }
417                 if stmt.unlockedAmount.String() != contractValue.Amount || stmt.unlockedAsset.String() != contractValue.Asset {
418                         return fmt.Errorf("amount \"%s\" of asset \"%s\" expression in unlock statement of clause \"%s\" must be the contract valueAmount \"%s\" of valueAsset \"%s\"",
419                                 stmt.unlockedAmount.String(), stmt.unlockedAsset.String(), clauseName, contractValue.Amount, contractValue.Asset)
420                 }
421         }
422
423         return nil
424 }