OSDN Git Service

Remove transaction reference data (#416)
[bytom/bytom.git] / protocol / validation / vmcontext.go
1 package validation
2
3 import (
4         "bytes"
5
6         "github.com/bytom/consensus/segwit"
7         "github.com/bytom/crypto/sha3pool"
8         "github.com/bytom/errors"
9         "github.com/bytom/protocol/bc"
10         "github.com/bytom/protocol/vm"
11 )
12
13 // NewTxVMContext generates the vm.Context for BVM
14 func NewTxVMContext(vs *validationState, entry bc.Entry, prog *bc.Program, args [][]byte) *vm.Context {
15         var (
16                 tx          = vs.tx
17                 blockHeight = vs.block.BlockHeader.GetHeight()
18                 numResults  = uint64(len(tx.ResultIds))
19                 entryID     = bc.EntryID(entry) // TODO(bobg): pass this in, don't recompute it
20
21                 assetID       *[]byte
22                 amount        *uint64
23                 entryData     *[]byte
24                 destPos       *uint64
25                 anchorID      *[]byte
26                 spentOutputID *[]byte
27         )
28
29         switch e := entry.(type) {
30         case *bc.Nonce:
31                 anchored := tx.Entries[*e.WitnessAnchoredId]
32                 if iss, ok := anchored.(*bc.Issuance); ok {
33                         a1 := iss.Value.AssetId.Bytes()
34                         assetID = &a1
35                         amount = &iss.Value.Amount
36                 }
37
38         case *bc.Issuance:
39                 a1 := e.Value.AssetId.Bytes()
40                 assetID = &a1
41                 amount = &e.Value.Amount
42                 destPos = &e.WitnessDestination.Position
43                 d := e.Data.Bytes()
44                 entryData = &d
45                 a2 := e.AnchorId.Bytes()
46                 anchorID = &a2
47
48         case *bc.Spend:
49                 spentOutput := tx.Entries[*e.SpentOutputId].(*bc.Output)
50                 a1 := spentOutput.Source.Value.AssetId.Bytes()
51                 assetID = &a1
52                 amount = &spentOutput.Source.Value.Amount
53                 destPos = &e.WitnessDestination.Position
54                 d := e.Data.Bytes()
55                 entryData = &d
56                 s := e.SpentOutputId.Bytes()
57                 spentOutputID = &s
58
59         case *bc.Output:
60                 d := e.Data.Bytes()
61                 entryData = &d
62
63         case *bc.Retirement:
64                 d := e.Data.Bytes()
65                 entryData = &d
66         }
67
68         var txSigHash *[]byte
69         txSigHashFn := func() []byte {
70                 if txSigHash == nil {
71                         hasher := sha3pool.Get256()
72                         defer sha3pool.Put256(hasher)
73
74                         entryID.WriteTo(hasher)
75                         tx.ID.WriteTo(hasher)
76
77                         var hash bc.Hash
78                         hash.ReadFrom(hasher)
79                         hashBytes := hash.Bytes()
80                         txSigHash = &hashBytes
81                 }
82                 return *txSigHash
83         }
84
85         ec := &entryContext{
86                 entry:   entry,
87                 entries: tx.Entries,
88         }
89
90         result := &vm.Context{
91                 VMVersion: prog.VmVersion,
92                 Code:      witnessProgram(prog.Code),
93                 Arguments: args,
94
95                 EntryID: entryID.Bytes(),
96
97                 TxVersion:   &tx.Version,
98                 BlockHeight: &blockHeight,
99
100                 TxSigHash:     txSigHashFn,
101                 NumResults:    &numResults,
102                 AssetID:       assetID,
103                 Amount:        amount,
104                 EntryData:     entryData,
105                 DestPos:       destPos,
106                 AnchorID:      anchorID,
107                 SpentOutputID: spentOutputID,
108                 CheckOutput:   ec.checkOutput,
109         }
110
111         return result
112 }
113
114 func witnessProgram(prog []byte) []byte {
115         if segwit.IsP2WPKHScript(prog) {
116                 if witnessProg, err := segwit.ConvertP2PKHSigProgram([]byte(prog)); err == nil {
117                         return witnessProg
118                 }
119         } else if segwit.IsP2WSHScript(prog) {
120                 if witnessProg, err := segwit.ConvertP2SHProgram([]byte(prog)); err == nil {
121                         return witnessProg
122                 }
123         }
124         return prog
125 }
126
127 type entryContext struct {
128         entry   bc.Entry
129         entries map[bc.Hash]bc.Entry
130 }
131
132 func (ec *entryContext) checkOutput(index uint64, data []byte, amount uint64, assetID []byte, vmVersion uint64, code []byte, expansion bool) (bool, error) {
133         checkEntry := func(e bc.Entry) (bool, error) {
134                 check := func(prog *bc.Program, value *bc.AssetAmount, dataHash *bc.Hash) bool {
135                         return (prog.VmVersion == vmVersion &&
136                                 bytes.Equal(prog.Code, code) &&
137                                 bytes.Equal(value.AssetId.Bytes(), assetID) &&
138                                 value.Amount == amount &&
139                                 (len(data) == 0 || bytes.Equal(dataHash.Bytes(), data)))
140                 }
141
142                 switch e := e.(type) {
143                 case *bc.Output:
144                         return check(e.ControlProgram, e.Source.Value, e.Data), nil
145
146                 case *bc.Retirement:
147                         var prog bc.Program
148                         if expansion {
149                                 // The spec requires prog.Code to be the empty string only
150                                 // when !expansion. When expansion is true, we prepopulate
151                                 // prog.Code to give check() a freebie match.
152                                 //
153                                 // (The spec always requires prog.VmVersion to be zero.)
154                                 prog.Code = code
155                         }
156                         return check(&prog, e.Source.Value, e.Data), nil
157                 }
158
159                 return false, vm.ErrContext
160         }
161
162         checkMux := func(m *bc.Mux) (bool, error) {
163                 if index >= uint64(len(m.WitnessDestinations)) {
164                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= %d", index, len(m.WitnessDestinations))
165                 }
166                 eID := m.WitnessDestinations[index].Ref
167                 e, ok := ec.entries[*eID]
168                 if !ok {
169                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for mux destination %d, id %x, not found", index, eID.Bytes())
170                 }
171                 return checkEntry(e)
172         }
173
174         switch e := ec.entry.(type) {
175         case *bc.Mux:
176                 return checkMux(e)
177
178         case *bc.Issuance:
179                 d, ok := ec.entries[*e.WitnessDestination.Ref]
180                 if !ok {
181                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for issuance destination %x not found", e.WitnessDestination.Ref.Bytes())
182                 }
183                 if m, ok := d.(*bc.Mux); ok {
184                         return checkMux(m)
185                 }
186                 if index != 0 {
187                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
188                 }
189                 return checkEntry(d)
190
191         case *bc.Spend:
192                 d, ok := ec.entries[*e.WitnessDestination.Ref]
193                 if !ok {
194                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for spend destination %x not found", e.WitnessDestination.Ref.Bytes())
195                 }
196                 if m, ok := d.(*bc.Mux); ok {
197                         return checkMux(m)
198                 }
199                 if index != 0 {
200                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
201                 }
202                 return checkEntry(d)
203         }
204
205         return false, vm.ErrContext
206 }