OSDN Git Service

7ac617c92282f04f3d268c2a8a42796164dc56af
[bytom/vapor.git] / protocol / validation / vmcontext.go
1 package validation
2
3 import (
4         "bytes"
5
6         "github.com/vapor/consensus/segwit"
7         "github.com/vapor/crypto/sha3pool"
8         "github.com/vapor/errors"
9         "github.com/vapor/protocol/bc"
10         "github.com/vapor/protocol/vm"
11 )
12
13 // NewTxVMContext generates the vm.Context for BVM
14 func NewTxVMContext(vs *validationState, entry bc.Entry, prog *bc.Program, args [][]byte) *vm.Context {
15         var (
16                 tx          = vs.tx
17                 blockHeight = vs.block.BlockHeader.GetHeight()
18                 numResults  = uint64(len(tx.ResultIds))
19                 entryID     = bc.EntryID(entry) // TODO(bobg): pass this in, don't recompute it
20
21                 assetID       *[]byte
22                 amount        *uint64
23                 destPos       *uint64
24                 spentOutputID *[]byte
25         )
26
27         switch e := entry.(type) {
28         case *bc.CrossChainInput:
29                 mainchainOutput := tx.Entries[*e.MainchainOutputId].(*bc.IntraChainOutput)
30                 a1 := mainchainOutput.Source.Value.AssetId.Bytes()
31                 assetID = &a1
32                 amount = &mainchainOutput.Source.Value.Amount
33                 destPos = &e.WitnessDestination.Position
34                 s := e.MainchainOutputId.Bytes()
35                 spentOutputID = &s
36
37         case *bc.Spend:
38                 spentOutput := tx.Entries[*e.SpentOutputId].(*bc.IntraChainOutput)
39                 a1 := spentOutput.Source.Value.AssetId.Bytes()
40                 assetID = &a1
41                 amount = &spentOutput.Source.Value.Amount
42                 destPos = &e.WitnessDestination.Position
43                 s := e.SpentOutputId.Bytes()
44                 spentOutputID = &s
45
46         case *bc.VetoInput:
47                 voteOutput := tx.Entries[*e.SpentOutputId].(*bc.VoteOutput)
48                 a1 := voteOutput.Source.Value.AssetId.Bytes()
49                 assetID = &a1
50                 amount = &voteOutput.Source.Value.Amount
51                 destPos = &e.WitnessDestination.Position
52                 s := e.SpentOutputId.Bytes()
53                 spentOutputID = &s
54         }
55
56         var txSigHash *[]byte
57         txSigHashFn := func() []byte {
58                 if txSigHash == nil {
59                         hasher := sha3pool.Get256()
60                         defer sha3pool.Put256(hasher)
61
62                         entryID.WriteTo(hasher)
63                         tx.ID.WriteTo(hasher)
64
65                         var hash bc.Hash
66                         hash.ReadFrom(hasher)
67                         hashBytes := hash.Bytes()
68                         txSigHash = &hashBytes
69                 }
70                 return *txSigHash
71         }
72
73         ec := &entryContext{
74                 entry:   entry,
75                 entries: tx.Entries,
76         }
77
78         result := &vm.Context{
79                 VMVersion: prog.VmVersion,
80                 Code:      witnessProgram(prog.Code),
81                 Arguments: args,
82
83                 EntryID: entryID.Bytes(),
84
85                 TxVersion:   &tx.Version,
86                 BlockHeight: &blockHeight,
87
88                 TxSigHash:     txSigHashFn,
89                 NumResults:    &numResults,
90                 AssetID:       assetID,
91                 Amount:        amount,
92                 DestPos:       destPos,
93                 SpentOutputID: spentOutputID,
94                 CheckOutput:   ec.checkOutput,
95         }
96
97         return result
98 }
99
100 func witnessProgram(prog []byte) []byte {
101         if segwit.IsP2WPKHScript(prog) {
102                 if witnessProg, err := segwit.ConvertP2PKHSigProgram([]byte(prog)); err == nil {
103                         return witnessProg
104                 }
105         } else if segwit.IsP2WSHScript(prog) {
106                 if witnessProg, err := segwit.ConvertP2SHProgram([]byte(prog)); err == nil {
107                         return witnessProg
108                 }
109         }
110         return prog
111 }
112
// entryContext holds what checkOutput needs to resolve the destinations of
// the entry currently being validated. NewTxVMContext builds one from the
// entry and its transaction's entry map.
type entryContext struct {
	// entry is the transaction entry whose control program is running.
	entry bc.Entry
	// entries maps entry IDs to every entry of the enclosing transaction
	// (populated from tx.Entries).
	entries map[bc.Hash]bc.Entry
}
117
118 func (ec *entryContext) checkOutput(index uint64, amount uint64, assetID []byte, vmVersion uint64, code []byte, expansion bool) (bool, error) {
119         checkEntry := func(e bc.Entry) (bool, error) {
120                 check := func(prog *bc.Program, value *bc.AssetAmount) bool {
121                         return (prog.VmVersion == vmVersion &&
122                                 bytes.Equal(prog.Code, code) &&
123                                 bytes.Equal(value.AssetId.Bytes(), assetID) &&
124                                 value.Amount == amount)
125                 }
126
127                 switch e := e.(type) {
128                 case *bc.IntraChainOutput:
129                         return check(e.ControlProgram, e.Source.Value), nil
130
131                 case *bc.VoteOutput:
132                         return check(e.ControlProgram, e.Source.Value), nil
133
134                 case *bc.Retirement:
135                         var prog bc.Program
136                         if expansion {
137                                 // The spec requires prog.Code to be the empty string only
138                                 // when !expansion. When expansion is true, we prepopulate
139                                 // prog.Code to give check() a freebie match.
140                                 //
141                                 // (The spec always requires prog.VmVersion to be zero.)
142                                 prog.Code = code
143                         }
144                         return check(&prog, e.Source.Value), nil
145                 }
146
147                 return false, vm.ErrContext
148         }
149
150         checkMux := func(m *bc.Mux) (bool, error) {
151                 if index >= uint64(len(m.WitnessDestinations)) {
152                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= %d", index, len(m.WitnessDestinations))
153                 }
154                 eID := m.WitnessDestinations[index].Ref
155                 e, ok := ec.entries[*eID]
156                 if !ok {
157                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for mux destination %d, id %x, not found", index, eID.Bytes())
158                 }
159                 return checkEntry(e)
160         }
161
162         switch e := ec.entry.(type) {
163         case *bc.Mux:
164                 return checkMux(e)
165
166         case *bc.Spend:
167                 d, ok := ec.entries[*e.WitnessDestination.Ref]
168                 if !ok {
169                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for spend destination %x not found", e.WitnessDestination.Ref.Bytes())
170                 }
171                 if m, ok := d.(*bc.Mux); ok {
172                         return checkMux(m)
173                 }
174                 if index != 0 {
175                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
176                 }
177                 return checkEntry(d)
178
179         case *bc.VetoInput:
180                 d, ok := ec.entries[*e.WitnessDestination.Ref]
181                 if !ok {
182                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for vetoInput destination %x not found", e.WitnessDestination.Ref.Bytes())
183                 }
184                 if m, ok := d.(*bc.Mux); ok {
185                         return checkMux(m)
186                 }
187                 if index != 0 {
188                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
189                 }
190                 return checkEntry(d)
191         }
192
193         return false, vm.ErrContext
194 }