OSDN Git Service

add tx unit test
[bytom/bytom.git] / test / wallet_test_util.go
1 package test
2
3 import (
4         "fmt"
5         "io/ioutil"
6         "os"
7         "path"
8         "reflect"
9
10         dbm "github.com/tendermint/tmlibs/db"
11
12         "github.com/bytom/account"
13         "github.com/bytom/asset"
14         "github.com/bytom/blockchain/pseudohsm"
15         "github.com/bytom/blockchain/signers"
16         "github.com/bytom/crypto/ed25519/chainkd"
17         "github.com/bytom/event"
18         "github.com/bytom/protocol"
19         "github.com/bytom/protocol/bc/types"
20         w "github.com/bytom/wallet"
21 )
22
23 type walletTestConfig struct {
24         Keys       []*keyInfo     `json:"keys"`
25         Accounts   []*accountInfo `json:"accounts"`
26         Blocks     []*wtBlock     `json:"blocks"`
27         RollbackTo uint64         `json:"rollback_to"`
28 }
29
30 type keyInfo struct {
31         Name     string `json:"name"`
32         Password string `json:"password"`
33 }
34
35 type accountInfo struct {
36         Name   string   `json:"name"`
37         Keys   []string `json:"keys"`
38         Quorum int      `json:"quorum"`
39 }
40
41 type wtBlock struct {
42         CoinbaseAccount string            `json:"coinbase_account"`
43         Transactions    []*wtTransaction  `json:"transactions"`
44         PostStates      []*accountBalance `json:"post_states"`
45         Append          uint64            `json:"append"`
46 }
47
48 func (b *wtBlock) create(ctx *walletTestContext) (*types.Block, error) {
49         transactions := make([]*types.Tx, 0, len(b.Transactions))
50         for _, t := range b.Transactions {
51                 tx, err := t.create(ctx)
52                 if err != nil {
53                         continue
54                 }
55                 transactions = append(transactions, tx)
56         }
57         return ctx.newBlock(transactions, b.CoinbaseAccount)
58 }
59
60 func (b *wtBlock) verifyPostStates(ctx *walletTestContext) error {
61         for _, state := range b.PostStates {
62                 balance, err := ctx.getBalance(state.AccountAlias, state.AssetAlias)
63                 if err != nil {
64                         return err
65                 }
66
67                 if balance != state.Amount {
68                         return fmt.Errorf("AccountAlias: %s, AssetAlias: %s, expected: %d, have: %d", state.AccountAlias, state.AssetAlias, state.Amount, balance)
69                 }
70         }
71         return nil
72 }
73
74 type wtTransaction struct {
75         Passwords []string  `json:"passwords"`
76         Inputs    []*action `json:"inputs"`
77         Outputs   []*action `json:"outputs"`
78 }
79
80 // create signed transaction
81 func (t *wtTransaction) create(ctx *walletTestContext) (*types.Tx, error) {
82         generator := NewTxGenerator(ctx.Wallet.AccountMgr, ctx.Wallet.AssetReg, ctx.Wallet.Hsm)
83         for _, input := range t.Inputs {
84                 switch input.Type {
85                 case "spend_account":
86                         if err := generator.AddSpendInput(input.AccountAlias, input.AssetAlias, input.Amount); err != nil {
87                                 return nil, err
88                         }
89                 case "issue":
90                         _, err := ctx.createAsset(input.AccountAlias, input.AssetAlias)
91                         if err != nil {
92                                 return nil, err
93                         }
94                         if err := generator.AddIssuanceInput(input.AssetAlias, input.Amount); err != nil {
95                                 return nil, err
96                         }
97                 }
98         }
99
100         for _, output := range t.Outputs {
101                 switch output.Type {
102                 case "output":
103                         if err := generator.AddTxOutput(output.AccountAlias, output.AssetAlias, output.Amount); err != nil {
104                                 return nil, err
105                         }
106                 case "retire":
107                         if err := generator.AddRetirement(output.AssetAlias, output.Amount); err != nil {
108                                 return nil, err
109                         }
110                 }
111         }
112         return generator.Sign(t.Passwords)
113 }
114
115 type action struct {
116         Type         string `json:"type"`
117         AccountAlias string `json:"name"`
118         AssetAlias   string `json:"asset"`
119         Amount       uint64 `json:"amount"`
120 }
121
122 type accountBalance struct {
123         AssetAlias   string `json:"asset"`
124         AccountAlias string `json:"name"`
125         Amount       uint64 `json:"amount"`
126 }
127
128 type walletTestContext struct {
129         Wallet *w.Wallet
130         Chain  *protocol.Chain
131 }
132
133 func (ctx *walletTestContext) createControlProgram(accountName string, change bool) (*account.CtrlProgram, error) {
134         acc, err := ctx.Wallet.AccountMgr.FindByAlias(accountName)
135         if err != nil {
136                 return nil, err
137         }
138
139         return ctx.Wallet.AccountMgr.CreateAddress(acc.ID, change)
140 }
141
142 func (ctx *walletTestContext) getPubkey(keyAlias string) *chainkd.XPub {
143         pubKeys := ctx.Wallet.Hsm.ListKeys()
144         for i, key := range pubKeys {
145                 if key.Alias == keyAlias {
146                         return &pubKeys[i].XPub
147                 }
148         }
149         return nil
150 }
151
152 func (ctx *walletTestContext) createAsset(accountAlias string, assetAlias string) (*asset.Asset, error) {
153         acc, err := ctx.Wallet.AccountMgr.FindByAlias(accountAlias)
154         if err != nil {
155                 return nil, err
156         }
157         return ctx.Wallet.AssetReg.Define(acc.XPubs, len(acc.XPubs), nil, assetAlias, nil)
158 }
159
160 func (ctx *walletTestContext) newBlock(txs []*types.Tx, coinbaseAccount string) (*types.Block, error) {
161         controlProgram, err := ctx.createControlProgram(coinbaseAccount, true)
162         if err != nil {
163                 return nil, err
164         }
165         return NewBlock(ctx.Chain, txs, controlProgram.ControlProgram)
166 }
167
168 func (ctx *walletTestContext) createKey(name string, password string) error {
169         _, _, err := ctx.Wallet.Hsm.XCreate(name, password, "en")
170         return err
171 }
172
173 func (ctx *walletTestContext) createAccount(name string, keys []string, quorum int) error {
174         xpubs := []chainkd.XPub{}
175         for _, alias := range keys {
176                 xpub := ctx.getPubkey(alias)
177                 if xpub == nil {
178                         return fmt.Errorf("can't find pubkey for %s", alias)
179                 }
180                 xpubs = append(xpubs, *xpub)
181         }
182         _, err := ctx.Wallet.AccountMgr.Create(xpubs, quorum, name, signers.BIP0044)
183         return err
184 }
185
186 func (ctx *walletTestContext) update(block *types.Block) error {
187         if err := SolveAndUpdate(ctx.Chain, block); err != nil {
188                 return err
189         }
190         if err := ctx.Wallet.AttachBlock(block); err != nil {
191                 return err
192         }
193         return nil
194 }
195
196 func (ctx *walletTestContext) getBalance(accountAlias string, assetAlias string) (uint64, error) {
197         balances, _ := ctx.Wallet.GetAccountBalances("", "")
198         for _, balance := range balances {
199                 if balance.Alias == accountAlias && balance.AssetAlias == assetAlias {
200                         return balance.Amount, nil
201                 }
202         }
203         return 0, nil
204 }
205
206 func (ctx *walletTestContext) getAccBalances() map[string]map[string]uint64 {
207         accBalances := make(map[string]map[string]uint64)
208         balances, _ := ctx.Wallet.GetAccountBalances("", "")
209         for _, balance := range balances {
210                 if accBalance, ok := accBalances[balance.Alias]; ok {
211                         if _, ok := accBalance[balance.AssetAlias]; ok {
212                                 accBalance[balance.AssetAlias] += balance.Amount
213                                 continue
214                         }
215                         accBalance[balance.AssetAlias] = balance.Amount
216                         continue
217                 }
218                 accBalances[balance.Alias] = map[string]uint64{balance.AssetAlias: balance.Amount}
219         }
220         return accBalances
221 }
222
223 func (ctx *walletTestContext) getDetachedBlocks(height uint64) ([]*types.Block, error) {
224         currentHeight := ctx.Chain.BestBlockHeight()
225         detachedBlocks := make([]*types.Block, 0, currentHeight-height)
226         for i := currentHeight; i > height; i-- {
227                 block, err := ctx.Chain.GetBlockByHeight(i)
228                 if err != nil {
229                         return detachedBlocks, err
230                 }
231                 detachedBlocks = append(detachedBlocks, block)
232         }
233         return detachedBlocks, nil
234 }
235
236 func (ctx *walletTestContext) validateRollback(oldAccBalances map[string]map[string]uint64) error {
237         accBalances := ctx.getAccBalances()
238         if reflect.DeepEqual(oldAccBalances, accBalances) {
239                 return nil
240         }
241         return fmt.Errorf("different account balances after rollback")
242 }
243
244 func (cfg *walletTestConfig) Run() error {
245         dirPath, err := ioutil.TempDir(".", "pseudo_hsm")
246         if err != nil {
247                 return err
248         }
249         defer os.RemoveAll(dirPath)
250         hsm, err := pseudohsm.New(dirPath)
251         if err != nil {
252                 return err
253         }
254
255         db := dbm.NewDB("wallet_test_db", "leveldb", path.Join(dirPath, "wallet_test_db"))
256         chain, _, _, err := MockChain(db)
257         if err != nil {
258                 return err
259         }
260         walletDB := dbm.NewDB("wallet", "leveldb", path.Join(dirPath, "wallet_db"))
261         accountManager := account.NewManager(walletDB, chain)
262         assets := asset.NewRegistry(walletDB, chain)
263         dispatcher := event.NewDispatcher()
264         wallet, err := w.NewWallet(walletDB, accountManager, assets, hsm, chain, dispatcher)
265         if err != nil {
266                 return err
267         }
268         ctx := &walletTestContext{
269                 Wallet: wallet,
270                 Chain:  chain,
271         }
272
273         for _, key := range cfg.Keys {
274                 if err := ctx.createKey(key.Name, key.Password); err != nil {
275                         return err
276                 }
277         }
278
279         for _, acc := range cfg.Accounts {
280                 if err := ctx.createAccount(acc.Name, acc.Keys, acc.Quorum); err != nil {
281                         return err
282                 }
283         }
284
285         var accBalances map[string]map[string]uint64
286         var rollbackBlock *types.Block
287         for _, blk := range cfg.Blocks {
288                 block, err := blk.create(ctx)
289                 if err != nil {
290                         return err
291                 }
292                 if err := ctx.update(block); err != nil {
293                         return err
294                 }
295                 if err := blk.verifyPostStates(ctx); err != nil {
296                         return err
297                 }
298                 if block.Height <= cfg.RollbackTo && cfg.RollbackTo <= block.Height+blk.Append {
299                         accBalances = ctx.getAccBalances()
300                         rollbackBlock = block
301                 }
302                 if err := AppendBlocks(ctx.Chain, blk.Append); err != nil {
303                         return err
304                 }
305         }
306
307         if rollbackBlock == nil {
308                 return nil
309         }
310
311         // rollback and validate
312         detachedBlocks, err := ctx.getDetachedBlocks(rollbackBlock.Height)
313         if err != nil {
314                 return err
315         }
316
317         forkPath, err := ioutil.TempDir(".", "forked_chain")
318         if err != nil {
319                 return err
320         }
321
322         forkedChain, err := declChain(forkPath, ctx.Chain, rollbackBlock.Height, ctx.Chain.BestBlockHeight()+1)
323         defer os.RemoveAll(forkPath)
324         if err != nil {
325                 return err
326         }
327
328         if err := merge(forkedChain, ctx.Chain); err != nil {
329                 return err
330         }
331
332         for _, block := range detachedBlocks {
333                 if err := ctx.Wallet.DetachBlock(block); err != nil {
334                         return err
335                 }
336         }
337         return ctx.validateRollback(accBalances)
338 }