OSDN Git Service

fix votetx for validation (#75)
[bytom/vapor.git] / protocol / validation / vmcontext.go
1 package validation
2
3 import (
4         "bytes"
5
6         "github.com/vapor/consensus/segwit"
7         "github.com/vapor/crypto/sha3pool"
8         "github.com/vapor/errors"
9         "github.com/vapor/protocol/bc"
10         "github.com/vapor/protocol/vm"
11 )
12
13 // NewTxVMContext generates the vm.Context for BVM
14 func NewTxVMContext(vs *validationState, entry bc.Entry, prog *bc.Program, args [][]byte) *vm.Context {
15         var (
16                 tx          = vs.tx
17                 blockHeight = vs.block.BlockHeader.GetHeight()
18                 numResults  = uint64(len(tx.ResultIds))
19                 entryID     = bc.EntryID(entry) // TODO(bobg): pass this in, don't recompute it
20
21                 assetID       *[]byte
22                 amount        *uint64
23                 destPos       *uint64
24                 spentOutputID *[]byte
25         )
26
27         switch e := entry.(type) {
28         case *bc.Spend:
29                 switch spentOutput := tx.Entries[*e.SpentOutputId].(type) {
30                 case *bc.IntraChainOutput:
31                         a1 := spentOutput.Source.Value.AssetId.Bytes()
32                         assetID = &a1
33                         amount = &spentOutput.Source.Value.Amount
34                         destPos = &e.WitnessDestination.Position
35                         s := e.SpentOutputId.Bytes()
36                         spentOutputID = &s
37                 case *bc.VoteOutput:
38                         a1 := spentOutput.Source.Value.AssetId.Bytes()
39                         assetID = &a1
40                         amount = &spentOutput.Source.Value.Amount
41                         destPos = &e.WitnessDestination.Position
42                         s := e.SpentOutputId.Bytes()
43                         spentOutputID = &s
44                 }
45         }
46
47         var txSigHash *[]byte
48         txSigHashFn := func() []byte {
49                 if txSigHash == nil {
50                         hasher := sha3pool.Get256()
51                         defer sha3pool.Put256(hasher)
52
53                         entryID.WriteTo(hasher)
54                         tx.ID.WriteTo(hasher)
55
56                         var hash bc.Hash
57                         hash.ReadFrom(hasher)
58                         hashBytes := hash.Bytes()
59                         txSigHash = &hashBytes
60                 }
61                 return *txSigHash
62         }
63
64         ec := &entryContext{
65                 entry:   entry,
66                 entries: tx.Entries,
67         }
68
69         result := &vm.Context{
70                 VMVersion: prog.VmVersion,
71                 Code:      witnessProgram(prog.Code),
72                 Arguments: args,
73
74                 EntryID: entryID.Bytes(),
75
76                 TxVersion:   &tx.Version,
77                 BlockHeight: &blockHeight,
78
79                 TxSigHash:     txSigHashFn,
80                 NumResults:    &numResults,
81                 AssetID:       assetID,
82                 Amount:        amount,
83                 DestPos:       destPos,
84                 SpentOutputID: spentOutputID,
85                 CheckOutput:   ec.checkOutput,
86         }
87
88         return result
89 }
90
91 func witnessProgram(prog []byte) []byte {
92         if segwit.IsP2WPKHScript(prog) {
93                 if witnessProg, err := segwit.ConvertP2PKHSigProgram([]byte(prog)); err == nil {
94                         return witnessProg
95                 }
96         } else if segwit.IsP2WSHScript(prog) {
97                 if witnessProg, err := segwit.ConvertP2SHProgram([]byte(prog)); err == nil {
98                         return witnessProg
99                 }
100         }
101         return prog
102 }
103
104 type entryContext struct {
105         entry   bc.Entry
106         entries map[bc.Hash]bc.Entry
107 }
108
109 func (ec *entryContext) checkOutput(index uint64, amount uint64, assetID []byte, vmVersion uint64, code []byte, expansion bool) (bool, error) {
110         checkEntry := func(e bc.Entry) (bool, error) {
111                 check := func(prog *bc.Program, value *bc.AssetAmount) bool {
112                         return (prog.VmVersion == vmVersion &&
113                                 bytes.Equal(prog.Code, code) &&
114                                 bytes.Equal(value.AssetId.Bytes(), assetID) &&
115                                 value.Amount == amount)
116                 }
117
118                 switch e := e.(type) {
119                 case *bc.IntraChainOutput:
120                         return check(e.ControlProgram, e.Source.Value), nil
121
122                 case *bc.VoteOutput:
123                         return check(e.ControlProgram, e.Source.Value), nil
124
125                 case *bc.Retirement:
126                         var prog bc.Program
127                         if expansion {
128                                 // The spec requires prog.Code to be the empty string only
129                                 // when !expansion. When expansion is true, we prepopulate
130                                 // prog.Code to give check() a freebie match.
131                                 //
132                                 // (The spec always requires prog.VmVersion to be zero.)
133                                 prog.Code = code
134                         }
135                         return check(&prog, e.Source.Value), nil
136                 }
137
138                 return false, vm.ErrContext
139         }
140
141         checkMux := func(m *bc.Mux) (bool, error) {
142                 if index >= uint64(len(m.WitnessDestinations)) {
143                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= %d", index, len(m.WitnessDestinations))
144                 }
145                 eID := m.WitnessDestinations[index].Ref
146                 e, ok := ec.entries[*eID]
147                 if !ok {
148                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for mux destination %d, id %x, not found", index, eID.Bytes())
149                 }
150                 return checkEntry(e)
151         }
152
153         switch e := ec.entry.(type) {
154         case *bc.Mux:
155                 return checkMux(e)
156
157         case *bc.Spend:
158                 d, ok := ec.entries[*e.WitnessDestination.Ref]
159                 if !ok {
160                         return false, errors.Wrapf(bc.ErrMissingEntry, "entry for spend destination %x not found", e.WitnessDestination.Ref.Bytes())
161                 }
162                 if m, ok := d.(*bc.Mux); ok {
163                         return checkMux(m)
164                 }
165                 if index != 0 {
166                         return false, errors.Wrapf(vm.ErrBadValue, "index %d >= 1", index)
167                 }
168                 return checkEntry(d)
169         }
170
171         return false, vm.ErrContext
172 }