OSDN Git Service

new repo
[bytom/vapor.git] / protocol / validation / vmcontext.go
1 package validation
2
3 import (
4         "bytes"
5
6         "github.com/vapor/consensus/segwit"
7         "github.com/vapor/crypto/sha3pool"
8         "github.com/vapor/errors"
9         "github.com/vapor/protocol/bc"
10         "github.com/vapor/protocol/vm"
11 )
12
13 // NewTxVMContext generates the vm.Context for BVM
14 func NewTxVMContext(vs *validationState, entry bc.Entry, prog *bc.Program, args [][]byte) *vm.Context {
15         var (
16                 tx          = vs.tx
17                 blockHeight = vs.block.BlockHeader.GetHeight()
18                 numResults  = uint64(len(tx.ResultIds))
19                 entryID     = bc.EntryID(entry) // TODO(bobg): pass this in, don't recompute it
20
21                 assetID       *[]byte
22                 amount        *uint64
23                 destPos       *uint64
24                 spentOutputID *[]byte
25         )
26
27         switch e := entry.(type) {
28         case *bc.Issuance:
29                 a1 := e.Value.AssetId.Bytes()
30                 assetID = &a1
31                 amount = &e.Value.Amount
32                 destPos = &e.WitnessDestination.Position
33
34         case *bc.Spend:
35                 spentOutput := tx.Entries[*e.SpentOutputId].(*bc.Output)
36                 a1 := spentOutput.Source.Value.AssetId.Bytes()
37                 assetID = &a1
38                 amount = &spentOutput.Source.Value.Amount
39                 destPos = &e.WitnessDestination.Position
40                 s := e.SpentOutputId.Bytes()
41                 spentOutputID = &s
42         }
43
44         var txSigHash *[]byte
45         txSigHashFn := func() []byte {
46                 if txSigHash == nil {
47                         hasher := sha3pool.Get256()
48                         defer sha3pool.Put256(hasher)
49
50                         entryID.WriteTo(hasher)
51                         tx.ID.WriteTo(hasher)
52
53                         var hash bc.Hash
54                         hash.ReadFrom(hasher)
55                         hashBytes := hash.Bytes()
56                         txSigHash = &hashBytes
57                 }
58                 return *txSigHash
59         }
60
61         ec := &entryContext{
62                 entry:   entry,
63                 entries: tx.Entries,
64         }
65
66         result := &vm.Context{
67                 VMVersion: prog.VmVersion,
68                 Code:      witnessProgram(prog.Code),
69                 Arguments: args,
70
71                 EntryID: entryID.Bytes(),
72
73                 TxVersion:   &tx.Version,
74                 BlockHeight: &blockHeight,
75
76                 TxSigHash:     txSigHashFn,
77                 NumResults:    &numResults,
78                 AssetID:       assetID,
79                 Amount:        amount,
80                 DestPos:       destPos,
81                 SpentOutputID: spentOutputID,
82                 CheckOutput:   ec.checkOutput,
83         }
84
85         return result
86 }
87
88 func witnessProgram(prog []byte) []byte {
89         if segwit.IsP2WPKHScript(prog) {
90                 if witnessProg, err := segwit.ConvertP2PKHSigProgram([]byte(prog)); err == nil {
91                         return witnessProg
92                 }
93         } else if segwit.IsP2WSHScript(prog) {
94                 if witnessProg, err := segwit.ConvertP2SHProgram([]byte(prog)); err == nil {
95                         return witnessProg
96                 }
97         }
98         return prog
99 }
100
101 type entryContext struct {
102         entry   bc.Entry
103         entries map[bc.Hash]bc.Entry
104 }
105
106 func (ec *entryContext) checkOutput(index uint64, amount uint64, assetID []byte, vmVersion uint64, code []byte, expansion bool) (bool, error) {
107         checkEntry := func(e bc.Entry) (bool, error) {
108                 check := func(prog *bc.Program, value *bc.AssetAmount) bool {
109                         return (prog.VmVersion == vmVersion &&
110                                 bytes.Equal(prog.Code, code) &&
111                                 bytes.Equal(value.AssetId.Bytes(), assetID) &&
112                                 value.Amount == amount)
113                 }
114
115                 switch e := e.(type) {
116                 case *bc.Output:
117                         return check(e.ControlProgram, e.Source.Value), nil
118
119                 case *bc.Retirement:
120                         var prog bc.Program
121                         if expansion {
122                                 // The spec requires prog.Code to be the empty string only
123                                 // when !expansion. When expansion is true, we prepopulate
124                                 // prog.Code to give check() a freebie match.
125                                 //
126                                 // (The spec always requires prog.VmVersion to be zero.)
127                                 prog.Code = code
128                         }
129                         return check(&prog, e.Source.Value), nil
130                 }
131
132                 return false, vm.ErrContext
133         }
134
135         checkMux := func(m *bc.Mux) (bool, error) {
136                 if index >= uint64(len(m.WitnessDestinations)) {
137                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= %d", index, len(m.WitnessDestinations))
138                 }
139                 eID := m.WitnessDestinations[index].Ref
140                 e, ok := ec.entries[*eID]
141                 if !ok {
142                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for mux destination %d, id %x, not found", index, eID.Bytes())
143                 }
144                 return checkEntry(e)
145         }
146
147         switch e := ec.entry.(type) {
148         case *bc.Mux:
149                 return checkMux(e)
150
151         case *bc.Issuance:
152                 d, ok := ec.entries[*e.WitnessDestination.Ref]
153                 if !ok {
154                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for issuance destination %x not found", e.WitnessDestination.Ref.Bytes())
155                 }
156                 if m, ok := d.(*bc.Mux); ok {
157                         return checkMux(m)
158                 }
159                 if index != 0 {
160                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
161                 }
162                 return checkEntry(d)
163
164         case *bc.Spend:
165                 d, ok := ec.entries[*e.WitnessDestination.Ref]
166                 if !ok {
167                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for spend destination %x not found", e.WitnessDestination.Ref.Bytes())
168                 }
169                 if m, ok := d.(*bc.Mux); ok {
170                         return checkMux(m)
171                 }
172                 if index != 0 {
173                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
174                 }
175                 return checkEntry(d)
176         }
177
178         return false, vm.ErrContext
179 }