OSDN Git Service

cdaedfa86b3666a98aef1642b08c2dbe1f7831ca
[bytom/bytom.git] / protocol / validation / vmcontext.go
1 package validation
2
3 import (
4         "bytes"
5
6         "github.com/bytom/bytom/consensus/bcrp"
7         "github.com/bytom/bytom/consensus/segwit"
8         "github.com/bytom/bytom/crypto/sha3pool"
9         "github.com/bytom/bytom/errors"
10         "github.com/bytom/bytom/protocol/bc"
11         "github.com/bytom/bytom/protocol/vm"
12 )
13
14 // NewTxVMContext generates the vm.Context for BVM
15 func NewTxVMContext(vs *validationState, entry bc.Entry, prog *bc.Program, stateData [][]byte, args [][]byte) *vm.Context {
16         var (
17                 tx          = vs.tx
18                 blockHeight = vs.block.BlockHeader.GetHeight()
19                 numResults  = uint64(len(tx.ResultIds))
20                 entryID     = bc.EntryID(entry) // TODO(bobg): pass this in, don't recompute it
21
22                 assetID       *[]byte
23                 amount        *uint64
24                 destPos       *uint64
25                 spentOutputID *[]byte
26         )
27
28         switch e := entry.(type) {
29         case *bc.Issuance:
30                 a1 := e.Value.AssetId.Bytes()
31                 assetID = &a1
32                 amount = &e.Value.Amount
33                 destPos = &e.WitnessDestination.Position
34
35         case *bc.Spend:
36                 spentOutput := tx.Entries[*e.SpentOutputId].(*bc.OriginalOutput)
37                 a1 := spentOutput.Source.Value.AssetId.Bytes()
38                 assetID = &a1
39                 amount = &spentOutput.Source.Value.Amount
40                 destPos = &e.WitnessDestination.Position
41                 s := e.SpentOutputId.Bytes()
42                 spentOutputID = &s
43         }
44
45         var txSigHash *[]byte
46         txSigHashFn := func() []byte {
47                 if txSigHash == nil {
48                         hasher := sha3pool.Get256()
49                         defer sha3pool.Put256(hasher)
50
51                         entryID.WriteTo(hasher)
52                         tx.ID.WriteTo(hasher)
53
54                         var hash bc.Hash
55                         hash.ReadFrom(hasher)
56                         hashBytes := hash.Bytes()
57                         txSigHash = &hashBytes
58                 }
59                 return *txSigHash
60         }
61
62         ec := &entryContext{
63                 entry:   entry,
64                 entries: tx.Entries,
65         }
66
67         result := &vm.Context{
68                 VMVersion: prog.VmVersion,
69                 Code:      convertProgram(prog.Code, vs.converter),
70                 StateData: stateData,
71                 Arguments: args,
72
73                 EntryID: entryID.Bytes(),
74
75                 TxVersion:   &tx.Version,
76                 BlockHeight: &blockHeight,
77
78                 TxSigHash:     txSigHashFn,
79                 NumResults:    &numResults,
80                 AssetID:       assetID,
81                 Amount:        amount,
82                 DestPos:       destPos,
83                 SpentOutputID: spentOutputID,
84                 CheckOutput:   ec.checkOutput,
85         }
86
87         return result
88 }
89
90 func convertProgram(prog []byte, converter ProgramConverterFunc) []byte {
91         if segwit.IsP2WPKHScript(prog) {
92                 if witnessProg, err := segwit.ConvertP2PKHSigProgram([]byte(prog)); err == nil {
93                         return witnessProg
94                 }
95         } else if segwit.IsP2WSHScript(prog) {
96                 if witnessProg, err := segwit.ConvertP2SHProgram([]byte(prog)); err == nil {
97                         return witnessProg
98                 }
99         } else if bcrp.IsCallContractScript(prog) {
100                 if contractProg, err := converter(prog); err == nil {
101                         return contractProg
102                 }
103         }
104         return prog
105 }
106
// entryContext holds what checkOutput needs to resolve output destinations:
// the entry whose program is running and the transaction's entry map.
type entryContext struct {
	// entry is the entry whose program is being executed (a Mux, Issuance
	// or Spend is expected by checkOutput; other types yield vm.ErrContext).
	entry bc.Entry
	// entries maps entry IDs to entries; NewTxVMContext fills it with
	// tx.Entries.
	entries map[bc.Hash]bc.Entry
}
111
112 func (ec *entryContext) checkOutput(index uint64, amount uint64, assetID []byte, vmVersion uint64, code []byte, state [][]byte, expansion bool) (bool, error) {
113         checkEntry := func(e bc.Entry) (bool, error) {
114                 check := func(prog *bc.Program, value *bc.AssetAmount, stateData [][]byte) bool {
115                         return (prog.VmVersion == vmVersion &&
116                                 bytes.Equal(prog.Code, code) &&
117                                 bytes.Equal(value.AssetId.Bytes(), assetID) &&
118                                 value.Amount == amount &&
119                                 bytesEqual(stateData, state))
120                 }
121
122                 switch e := e.(type) {
123                 case *bc.OriginalOutput:
124                         return check(e.ControlProgram, e.Source.Value, e.StateData), nil
125
126                 case *bc.Retirement:
127                         var prog bc.Program
128                         if expansion {
129                                 // The spec requires prog.Code to be the empty string only
130                                 // when !expansion. When expansion is true, we prepopulate
131                                 // prog.Code to give check() a freebie match.
132                                 //
133                                 // (The spec always requires prog.VmVersion to be zero.)
134                                 prog.Code = code
135                         }
136                         return check(&prog, e.Source.Value, [][]byte{}), nil
137                 }
138
139                 return false, vm.ErrContext
140         }
141
142         checkMux := func(m *bc.Mux) (bool, error) {
143                 if index >= uint64(len(m.WitnessDestinations)) {
144                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= %d", index, len(m.WitnessDestinations))
145                 }
146                 eID := m.WitnessDestinations[index].Ref
147                 e, ok := ec.entries[*eID]
148                 if !ok {
149                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for mux destination %d, id %x, not found", index, eID.Bytes())
150                 }
151                 return checkEntry(e)
152         }
153
154         switch e := ec.entry.(type) {
155         case *bc.Mux:
156                 return checkMux(e)
157
158         case *bc.Issuance:
159                 d, ok := ec.entries[*e.WitnessDestination.Ref]
160                 if !ok {
161                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for issuance destination %x not found", e.WitnessDestination.Ref.Bytes())
162                 }
163                 if m, ok := d.(*bc.Mux); ok {
164                         return checkMux(m)
165                 }
166                 if index != 0 {
167                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
168                 }
169                 return checkEntry(d)
170
171         case *bc.Spend:
172                 d, ok := ec.entries[*e.WitnessDestination.Ref]
173                 if !ok {
174                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for spend destination %x not found", e.WitnessDestination.Ref.Bytes())
175                 }
176                 if m, ok := d.(*bc.Mux); ok {
177                         return checkMux(m)
178                 }
179                 if index != 0 {
180                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
181                 }
182                 return checkEntry(d)
183         }
184
185         return false, vm.ErrContext
186 }
187
// bytesEqual reports whether a and b contain the same byte slices in the
// same order.
//
// A nil list and an empty non-nil list are NOT considered equal. Individual
// elements are compared with bytes.Equal, so a nil element does equal an
// empty one.
func bytesEqual(a, b [][]byte) bool {
	aNil, bNil := a == nil, b == nil
	if aNil != bNil || len(a) != len(b) {
		return false
	}

	for i := range a {
		if !bytes.Equal(a[i], b[i]) {
			return false
		}
	}
	return true
}