OSDN Git Service

fix: add CrossChainInput in NewTxVMContext (#124)
[bytom/vapor.git] / protocol / validation / vmcontext.go
1 package validation
2
3 import (
4         "bytes"
5
6         "github.com/vapor/consensus/segwit"
7         "github.com/vapor/crypto/sha3pool"
8         "github.com/vapor/errors"
9         "github.com/vapor/protocol/bc"
10         "github.com/vapor/protocol/vm"
11 )
12
13 // NewTxVMContext generates the vm.Context for BVM
14 func NewTxVMContext(vs *validationState, entry bc.Entry, prog *bc.Program, args [][]byte) *vm.Context {
15         var (
16                 tx          = vs.tx
17                 blockHeight = vs.block.BlockHeader.GetHeight()
18                 numResults  = uint64(len(tx.ResultIds))
19                 entryID     = bc.EntryID(entry) // TODO(bobg): pass this in, don't recompute it
20
21                 assetID       *[]byte
22                 amount        *uint64
23                 destPos       *uint64
24                 spentOutputID *[]byte
25         )
26
27         switch e := entry.(type) {
28         case *bc.CrossChainInput:
29                 a1 := e.Value.AssetId.Bytes()
30                 assetID = &a1
31                 amount = &e.Value.Amount
32                 destPos = &e.WitnessDestination.Position
33
34         case *bc.Spend:
35                 switch spentOutput := tx.Entries[*e.SpentOutputId].(type) {
36                 case *bc.IntraChainOutput:
37                         a1 := spentOutput.Source.Value.AssetId.Bytes()
38                         assetID = &a1
39                         amount = &spentOutput.Source.Value.Amount
40                         destPos = &e.WitnessDestination.Position
41                         s := e.SpentOutputId.Bytes()
42                         spentOutputID = &s
43
44                 case *bc.VoteOutput:
45                         a1 := spentOutput.Source.Value.AssetId.Bytes()
46                         assetID = &a1
47                         amount = &spentOutput.Source.Value.Amount
48                         destPos = &e.WitnessDestination.Position
49                         s := e.SpentOutputId.Bytes()
50                         spentOutputID = &s
51                 }
52         }
53
54         var txSigHash *[]byte
55         txSigHashFn := func() []byte {
56                 if txSigHash == nil {
57                         hasher := sha3pool.Get256()
58                         defer sha3pool.Put256(hasher)
59
60                         entryID.WriteTo(hasher)
61                         tx.ID.WriteTo(hasher)
62
63                         var hash bc.Hash
64                         hash.ReadFrom(hasher)
65                         hashBytes := hash.Bytes()
66                         txSigHash = &hashBytes
67                 }
68                 return *txSigHash
69         }
70
71         ec := &entryContext{
72                 entry:   entry,
73                 entries: tx.Entries,
74         }
75
76         result := &vm.Context{
77                 VMVersion: prog.VmVersion,
78                 Code:      witnessProgram(prog.Code),
79                 Arguments: args,
80
81                 EntryID: entryID.Bytes(),
82
83                 TxVersion:   &tx.Version,
84                 BlockHeight: &blockHeight,
85
86                 TxSigHash:     txSigHashFn,
87                 NumResults:    &numResults,
88                 AssetID:       assetID,
89                 Amount:        amount,
90                 DestPos:       destPos,
91                 SpentOutputID: spentOutputID,
92                 CheckOutput:   ec.checkOutput,
93         }
94
95         return result
96 }
97
98 func witnessProgram(prog []byte) []byte {
99         if segwit.IsP2WPKHScript(prog) {
100                 if witnessProg, err := segwit.ConvertP2PKHSigProgram([]byte(prog)); err == nil {
101                         return witnessProg
102                 }
103         } else if segwit.IsP2WSHScript(prog) {
104                 if witnessProg, err := segwit.ConvertP2SHProgram([]byte(prog)); err == nil {
105                         return witnessProg
106                 }
107         }
108         return prog
109 }
110
111 type entryContext struct {
112         entry   bc.Entry
113         entries map[bc.Hash]bc.Entry
114 }
115
116 func (ec *entryContext) checkOutput(index uint64, amount uint64, assetID []byte, vmVersion uint64, code []byte, expansion bool) (bool, error) {
117         checkEntry := func(e bc.Entry) (bool, error) {
118                 check := func(prog *bc.Program, value *bc.AssetAmount) bool {
119                         return (prog.VmVersion == vmVersion &&
120                                 bytes.Equal(prog.Code, code) &&
121                                 bytes.Equal(value.AssetId.Bytes(), assetID) &&
122                                 value.Amount == amount)
123                 }
124
125                 switch e := e.(type) {
126                 case *bc.IntraChainOutput:
127                         return check(e.ControlProgram, e.Source.Value), nil
128
129                 case *bc.VoteOutput:
130                         return check(e.ControlProgram, e.Source.Value), nil
131
132                 case *bc.Retirement:
133                         var prog bc.Program
134                         if expansion {
135                                 // The spec requires prog.Code to be the empty string only
136                                 // when !expansion. When expansion is true, we prepopulate
137                                 // prog.Code to give check() a freebie match.
138                                 //
139                                 // (The spec always requires prog.VmVersion to be zero.)
140                                 prog.Code = code
141                         }
142                         return check(&prog, e.Source.Value), nil
143                 }
144
145                 return false, vm.ErrContext
146         }
147
148         checkMux := func(m *bc.Mux) (bool, error) {
149                 if index >= uint64(len(m.WitnessDestinations)) {
150                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= %d", index, len(m.WitnessDestinations))
151                 }
152                 eID := m.WitnessDestinations[index].Ref
153                 e, ok := ec.entries[*eID]
154                 if !ok {
155                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for mux destination %d, id %x, not found", index, eID.Bytes())
156                 }
157                 return checkEntry(e)
158         }
159
160         switch e := ec.entry.(type) {
161         case *bc.Mux:
162                 return checkMux(e)
163
164         case *bc.Spend:
165                 d, ok := ec.entries[*e.WitnessDestination.Ref]
166                 if !ok {
167                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for spend destination %x not found", e.WitnessDestination.Ref.Bytes())
168                 }
169                 if m, ok := d.(*bc.Mux); ok {
170                         return checkMux(m)
171                 }
172                 if index != 0 {
173                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
174                 }
175                 return checkEntry(d)
176         }
177
178         return false, vm.ErrContext
179 }