OSDN Git Service

version 1.1.9 (#594)
[bytom/vapor.git] / protocol / validation / vmcontext.go
1 package validation
2
3 import (
4         "bytes"
5
6         "github.com/bytom/vapor/consensus/segwit"
7         "github.com/bytom/vapor/crypto/sha3pool"
8         "github.com/bytom/vapor/errors"
9         "github.com/bytom/vapor/protocol/bc"
10         "github.com/bytom/vapor/protocol/vm"
11 )
12
13 // NewTxVMContext generates the vm.Context for BVM
14 func NewTxVMContext(vs *validationState, entry bc.Entry, prog *bc.Program, args [][]byte) *vm.Context {
15         var (
16                 tx          = vs.tx
17                 blockHeight = vs.block.BlockHeader.GetHeight()
18                 numResults  = uint64(len(tx.ResultIds))
19                 entryID     = bc.EntryID(entry) // TODO(bobg): pass this in, don't recompute it
20
21                 assetID       *[]byte
22                 amount        *uint64
23                 destPos       *uint64
24                 spentOutputID *[]byte
25         )
26
27         switch e := entry.(type) {
28         case *bc.CrossChainInput:
29                 mainchainOutput := tx.Entries[*e.MainchainOutputId].(*bc.IntraChainOutput)
30                 a1 := mainchainOutput.Source.Value.AssetId.Bytes()
31                 assetID = &a1
32                 amount = &mainchainOutput.Source.Value.Amount
33                 destPos = &e.WitnessDestination.Position
34                 s := e.MainchainOutputId.Bytes()
35                 spentOutputID = &s
36
37         case *bc.Spend:
38                 spentOutput := tx.Entries[*e.SpentOutputId].(*bc.IntraChainOutput)
39                 a1 := spentOutput.Source.Value.AssetId.Bytes()
40                 assetID = &a1
41                 amount = &spentOutput.Source.Value.Amount
42                 destPos = &e.WitnessDestination.Position
43                 s := e.SpentOutputId.Bytes()
44                 spentOutputID = &s
45
46         case *bc.VetoInput:
47                 voteOutput := tx.Entries[*e.SpentOutputId].(*bc.VoteOutput)
48                 a1 := voteOutput.Source.Value.AssetId.Bytes()
49                 assetID = &a1
50                 amount = &voteOutput.Source.Value.Amount
51                 destPos = &e.WitnessDestination.Position
52                 s := e.SpentOutputId.Bytes()
53                 spentOutputID = &s
54         }
55
56         var txSigHash *[]byte
57         txSigHashFn := func() []byte {
58                 if txSigHash == nil {
59                         hasher := sha3pool.Get256()
60                         defer sha3pool.Put256(hasher)
61
62                         entryID.WriteTo(hasher)
63                         tx.ID.WriteTo(hasher)
64
65                         var hash bc.Hash
66                         hash.ReadFrom(hasher)
67                         hashBytes := hash.Bytes()
68                         txSigHash = &hashBytes
69                 }
70                 return *txSigHash
71         }
72
73         ec := &entryContext{
74                 entry:   entry,
75                 entries: tx.Entries,
76         }
77
78         result := &vm.Context{
79                 VMVersion: prog.VmVersion,
80                 Code:      witnessProgram(prog.Code),
81                 Arguments: args,
82
83                 EntryID: entryID.Bytes(),
84
85                 TxVersion:   &tx.Version,
86                 BlockHeight: &blockHeight,
87
88                 TxSigHash:     txSigHashFn,
89                 NumResults:    &numResults,
90                 AssetID:       assetID,
91                 Amount:        amount,
92                 DestPos:       destPos,
93                 SpentOutputID: spentOutputID,
94                 CheckOutput:   ec.checkOutput,
95         }
96
97         return result
98 }
99
100 func witnessProgram(prog []byte) []byte {
101         switch {
102         case segwit.IsP2WPKHScript(prog):
103                 if witnessProg, err := segwit.ConvertP2PKHSigProgram(prog); err == nil {
104                         return witnessProg
105                 }
106         case segwit.IsP2WSHScript(prog):
107                 if witnessProg, err := segwit.ConvertP2SHProgram(prog); err == nil {
108                         return witnessProg
109                 }
110         case segwit.IsP2WMCScript(prog):
111                 if witnessProg, err := segwit.ConvertP2MCProgram(prog); err == nil {
112                         return witnessProg
113                 }
114         }
115         return prog
116 }
117
118 type entryContext struct {
119         entry   bc.Entry
120         entries map[bc.Hash]bc.Entry
121 }
122
123 func (ec *entryContext) checkOutput(index uint64, amount uint64, assetID []byte, vmVersion uint64, code []byte, expansion bool) (bool, error) {
124         checkEntry := func(e bc.Entry) (bool, error) {
125                 check := func(prog *bc.Program, value *bc.AssetAmount) bool {
126                         return (prog.VmVersion == vmVersion &&
127                                 bytes.Equal(prog.Code, code) &&
128                                 bytes.Equal(value.AssetId.Bytes(), assetID) &&
129                                 value.Amount == amount)
130                 }
131
132                 switch e := e.(type) {
133                 case *bc.IntraChainOutput:
134                         return check(e.ControlProgram, e.Source.Value), nil
135
136                 case *bc.VoteOutput:
137                         return check(e.ControlProgram, e.Source.Value), nil
138
139                 case *bc.Retirement:
140                         var prog bc.Program
141                         if expansion {
142                                 // The spec requires prog.Code to be the empty string only
143                                 // when !expansion. When expansion is true, we prepopulate
144                                 // prog.Code to give check() a freebie match.
145                                 //
146                                 // (The spec always requires prog.VmVersion to be zero.)
147                                 prog.Code = code
148                         }
149                         return check(&prog, e.Source.Value), nil
150                 }
151
152                 return false, vm.ErrContext
153         }
154
155         checkMux := func(m *bc.Mux) (bool, error) {
156                 if index >= uint64(len(m.WitnessDestinations)) {
157                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= %d", index, len(m.WitnessDestinations))
158                 }
159                 eID := m.WitnessDestinations[index].Ref
160                 e, ok := ec.entries[*eID]
161                 if !ok {
162                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for mux destination %d, id %x, not found", index, eID.Bytes())
163                 }
164                 return checkEntry(e)
165         }
166
167         switch e := ec.entry.(type) {
168         case *bc.Mux:
169                 return checkMux(e)
170
171         case *bc.Spend:
172                 d, ok := ec.entries[*e.WitnessDestination.Ref]
173                 if !ok {
174                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for spend destination %x not found", e.WitnessDestination.Ref.Bytes())
175                 }
176                 if m, ok := d.(*bc.Mux); ok {
177                         return checkMux(m)
178                 }
179                 if index != 0 {
180                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
181                 }
182                 return checkEntry(d)
183
184         case *bc.VetoInput:
185                 d, ok := ec.entries[*e.WitnessDestination.Ref]
186                 if !ok {
187                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for vetoInput destination %x not found", e.WitnessDestination.Ref.Bytes())
188                 }
189                 if m, ok := d.(*bc.Mux); ok {
190                         return checkMux(m)
191                 }
192                 if index != 0 {
193                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
194                 }
195                 return checkEntry(d)
196         }
197
198         return false, vm.ErrContext
199 }