OSDN Git Service

update
[bytom/vapor.git] / wallet / wallet_test.go
1 package wallet
2
3 import (
4         "encoding/binary"
5         "encoding/json"
6         "io/ioutil"
7         "os"
8         "reflect"
9         "testing"
10         "time"
11
12         "github.com/vapor/account"
13         acc "github.com/vapor/account"
14         "github.com/vapor/asset"
15         "github.com/vapor/blockchain/pseudohsm"
16         "github.com/vapor/blockchain/query"
17         "github.com/vapor/blockchain/signers"
18         "github.com/vapor/blockchain/txbuilder"
19         "github.com/vapor/config"
20         "github.com/vapor/consensus"
21         "github.com/vapor/crypto/ed25519/chainkd"
22         "github.com/vapor/database"
23         dbm "github.com/vapor/database/leveldb"
24         "github.com/vapor/errors"
25         "github.com/vapor/event"
26         "github.com/vapor/protocol"
27         "github.com/vapor/protocol/bc"
28         "github.com/vapor/protocol/bc/types"
29 )
30
31 func TestEncodeDecodeGlobalTxIndex(t *testing.T) {
32         want := &struct {
33                 BlockHash bc.Hash
34                 Position  uint64
35         }{
36                 BlockHash: bc.NewHash([32]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}),
37                 Position:  1,
38         }
39
40         globalTxIdx := database.CalcGlobalTxIndex(&want.BlockHash, want.Position)
41         blockHashGot, positionGot := parseGlobalTxIdx(globalTxIdx)
42         if *blockHashGot != want.BlockHash {
43                 t.Errorf("blockHash mismatch. Get: %v. Expect: %v", *blockHashGot, want.BlockHash)
44         }
45
46         if positionGot != want.Position {
47                 t.Errorf("position mismatch. Get: %v. Expect: %v", positionGot, want.Position)
48         }
49 }
50
51 func TestWalletVersion(t *testing.T) {
52         // prepare wallet
53         dirPath, err := ioutil.TempDir(".", "")
54         if err != nil {
55                 t.Fatal(err)
56         }
57         defer os.RemoveAll(dirPath)
58
59         testDB := dbm.NewDB("testdb", "leveldb", "temp")
60         walletStore := newMockWalletStore(testDB)
61         defer func() {
62                 testDB.Close()
63                 os.RemoveAll("temp")
64         }()
65
66         dispatcher := event.NewDispatcher()
67         w := mockWallet(walletStore, nil, nil, nil, dispatcher, false)
68
69         // legacy status test case
70         type legacyStatusInfo struct {
71                 WorkHeight uint64
72                 WorkHash   bc.Hash
73                 BestHeight uint64
74                 BestHash   bc.Hash
75         }
76         rawWallet, err := json.Marshal(legacyStatusInfo{})
77         if err != nil {
78                 t.Fatal("Marshal legacyStatusInfo")
79         }
80
81         w.store.SetWalletInfo(rawWallet)
82         rawWallet = w.store.GetWalletInfo()
83         if rawWallet == nil {
84                 t.Fatal("fail to load wallet StatusInfo")
85         }
86
87         if err := json.Unmarshal(rawWallet, &w.status); err != nil {
88                 t.Fatal(err)
89         }
90
91         if err := w.checkWalletInfo(); err != errWalletVersionMismatch {
92                 t.Fatal("fail to detect legacy wallet version")
93         }
94
95         // lower wallet version test case
96         lowerVersion := StatusInfo{Version: currentVersion - 1}
97         rawWallet, err = json.Marshal(lowerVersion)
98         if err != nil {
99                 t.Fatal("save wallet info")
100         }
101
102         w.store.SetWalletInfo(rawWallet)
103         rawWallet = w.store.GetWalletInfo()
104         if rawWallet == nil {
105                 t.Fatal("fail to load wallet StatusInfo")
106         }
107
108         if err := json.Unmarshal(rawWallet, &w.status); err != nil {
109                 t.Fatal(err)
110         }
111
112         if err := w.checkWalletInfo(); err != errWalletVersionMismatch {
113                 t.Fatal("fail to detect expired wallet version")
114         }
115 }
116
117 func TestWalletUpdate(t *testing.T) {
118         dirPath, err := ioutil.TempDir(".", "")
119         if err != nil {
120                 t.Fatal(err)
121         }
122         defer os.RemoveAll(dirPath)
123
124         config.CommonConfig = config.DefaultConfig()
125         testDB := dbm.NewDB("testdb", "leveldb", "temp")
126         defer func() {
127                 testDB.Close()
128                 os.RemoveAll("temp")
129         }()
130
131         store := database.NewStore(testDB)
132         walletStore := database.NewWalletStore(testDB)
133         // walletStore := newMockWalletStore(testDB)
134         dispatcher := event.NewDispatcher()
135         txPool := protocol.NewTxPool(store, dispatcher)
136
137         chain, err := protocol.NewChain(store, txPool, dispatcher)
138         if err != nil {
139                 t.Fatal(err)
140         }
141
142         accountStore := database.NewAccountStore(testDB)
143         accountManager := account.NewManager(accountStore, chain)
144         hsm, err := pseudohsm.New(dirPath)
145         if err != nil {
146                 t.Fatal(err)
147         }
148
149         xpub1, _, err := hsm.XCreate("test_pub1", "password", "en")
150         if err != nil {
151                 t.Fatal(err)
152         }
153
154         testAccount, err := accountManager.Create([]chainkd.XPub{xpub1.XPub}, 1, "testAccount", signers.BIP0044)
155         if err != nil {
156                 t.Fatal(err)
157         }
158
159         controlProg, err := accountManager.CreateAddress(testAccount.ID, false)
160         if err != nil {
161                 t.Fatal(err)
162         }
163
164         controlProg.KeyIndex = 1
165
166         reg := asset.NewRegistry(testDB, chain)
167         asset := bc.AssetID{V0: 5}
168
169         utxos := []*account.UTXO{}
170         btmUtxo := mockUTXO(controlProg, consensus.BTMAssetID)
171         utxos = append(utxos, btmUtxo)
172         OtherUtxo := mockUTXO(controlProg, &asset)
173         utxos = append(utxos, OtherUtxo)
174
175         _, txData, err := mockTxData(utxos, testAccount)
176         if err != nil {
177                 t.Fatal(err)
178         }
179
180         tx := types.NewTx(*txData)
181         block := mockSingleBlock(tx)
182         txStatus := bc.NewTransactionStatus()
183         txStatus.SetStatus(0, false)
184         txStatus.SetStatus(1, false)
185         store.SaveBlock(block, txStatus)
186
187         w := mockWallet(walletStore, accountManager, reg, chain, dispatcher, true)
188         err = w.AttachBlock(block)
189         if err != nil {
190                 t.Fatal(err)
191         }
192
193         if _, err := w.GetTransactionByTxID(tx.ID.String()); err != nil {
194                 t.Fatal(err)
195         }
196
197         wants, err := w.GetTransactions(testAccount.ID, "", 1, false)
198         if len(wants) != 1 {
199                 t.Fatal(err)
200         }
201
202         if wants[0].ID != tx.ID {
203                 t.Fatal("account txID mismatch")
204         }
205
206         for position, tx := range block.Transactions {
207                 get := w.store.GetGlobalTransactionIndex(tx.ID.String())
208                 bh := block.BlockHeader.Hash()
209                 expect := CalcGlobalTxIndex(&bh, uint64(position))
210                 if !reflect.DeepEqual(get, expect) {
211                         t.Fatalf("position#%d: compare retrieved globalTxIdx err", position)
212                 }
213         }
214 }
215
216 func TestRescanWallet(t *testing.T) {
217         // prepare wallet & db
218         dirPath, err := ioutil.TempDir(".", "")
219         if err != nil {
220                 t.Fatal(err)
221         }
222         defer os.RemoveAll(dirPath)
223
224         config.CommonConfig = config.DefaultConfig()
225         testDB := dbm.NewDB("testdb", "leveldb", "temp")
226         walletStore := database.NewWalletStore(testDB)
227         defer func() {
228                 testDB.Close()
229                 os.RemoveAll("temp")
230         }()
231
232         store := database.NewStore(testDB)
233         dispatcher := event.NewDispatcher()
234         txPool := protocol.NewTxPool(store, dispatcher)
235         chain, err := protocol.NewChain(store, txPool, dispatcher)
236         if err != nil {
237                 t.Fatal(err)
238         }
239
240         statusInfo := StatusInfo{
241                 Version:  currentVersion,
242                 WorkHash: bc.Hash{V0: 0xff},
243         }
244         rawWallet, err := json.Marshal(statusInfo)
245         if err != nil {
246                 t.Fatal("save wallet info")
247         }
248
249         w := mockWallet(walletStore, nil, nil, chain, dispatcher, false)
250         w.store.SetWalletInfo(rawWallet)
251         rawWallet = w.store.GetWalletInfo()
252         if rawWallet == nil {
253                 t.Fatal("fail to load wallet StatusInfo")
254         }
255
256         if err := json.Unmarshal(rawWallet, &w.status); err != nil {
257                 t.Fatal(err)
258         }
259
260         // rescan wallet
261         if err := w.loadWalletInfo(); err != nil {
262                 t.Fatal(err)
263         }
264
265         block := config.GenesisBlock()
266         if w.status.WorkHash != block.Hash() {
267                 t.Fatal("reattach from genesis block")
268         }
269 }
270
271 func TestMemPoolTxQueryLoop(t *testing.T) {
272         dirPath, err := ioutil.TempDir(".", "")
273         if err != nil {
274                 t.Fatal(err)
275         }
276         config.CommonConfig = config.DefaultConfig()
277         testDB := dbm.NewDB("testdb", "leveldb", dirPath)
278         defer func() {
279                 testDB.Close()
280                 os.RemoveAll(dirPath)
281         }()
282
283         store := database.NewStore(testDB)
284         dispatcher := event.NewDispatcher()
285         txPool := protocol.NewTxPool(store, dispatcher)
286
287         chain, err := protocol.NewChain(store, txPool, dispatcher)
288         if err != nil {
289                 t.Fatal(err)
290         }
291
292         accountStore := database.NewAccountStore(testDB)
293         accountManager := account.NewManager(accountStore, chain)
294         hsm, err := pseudohsm.New(dirPath)
295         if err != nil {
296                 t.Fatal(err)
297         }
298
299         xpub1, _, err := hsm.XCreate("test_pub1", "password", "en")
300         if err != nil {
301                 t.Fatal(err)
302         }
303
304         testAccount, err := accountManager.Create([]chainkd.XPub{xpub1.XPub}, 1, "testAccount", signers.BIP0044)
305         if err != nil {
306                 t.Fatal(err)
307         }
308
309         controlProg, err := accountManager.CreateAddress(testAccount.ID, false)
310         if err != nil {
311                 t.Fatal(err)
312         }
313
314         controlProg.KeyIndex = 1
315
316         reg := asset.NewRegistry(testDB, chain)
317         asset := bc.AssetID{V0: 5}
318
319         utxos := []*account.UTXO{}
320         btmUtxo := mockUTXO(controlProg, consensus.BTMAssetID)
321         utxos = append(utxos, btmUtxo)
322         OtherUtxo := mockUTXO(controlProg, &asset)
323         utxos = append(utxos, OtherUtxo)
324
325         _, txData, err := mockTxData(utxos, testAccount)
326         if err != nil {
327                 t.Fatal(err)
328         }
329
330         tx := types.NewTx(*txData)
331         //block := mockSingleBlock(tx)
332         txStatus := bc.NewTransactionStatus()
333         txStatus.SetStatus(0, false)
334         walletStore := database.NewWalletStore(testDB)
335         w, err := NewWallet(walletStore, accountManager, reg, hsm, chain, dispatcher, false)
336         go w.memPoolTxQueryLoop()
337         w.eventDispatcher.Post(protocol.TxMsgEvent{TxMsg: &protocol.TxPoolMsg{TxDesc: &protocol.TxDesc{Tx: tx}, MsgType: protocol.MsgNewTx}})
338         time.Sleep(time.Millisecond * 10)
339         if _, err := w.GetUnconfirmedTxByTxID(tx.ID.String()); err != nil {
340                 t.Fatal("dispatch new tx msg error:", err)
341         }
342         w.eventDispatcher.Post(protocol.TxMsgEvent{TxMsg: &protocol.TxPoolMsg{TxDesc: &protocol.TxDesc{Tx: tx}, MsgType: protocol.MsgRemoveTx}})
343         time.Sleep(time.Millisecond * 10)
344         txs, err := w.GetUnconfirmedTxs(testAccount.ID)
345         if err != nil {
346                 t.Fatal("get unconfirmed tx error:", err)
347         }
348
349         if len(txs) != 0 {
350                 t.Fatal("dispatch remove tx msg error")
351         }
352
353         w.eventDispatcher.Post(protocol.TxMsgEvent{TxMsg: &protocol.TxPoolMsg{TxDesc: &protocol.TxDesc{Tx: tx}, MsgType: 2}})
354 }
355
356 func mockUTXO(controlProg *account.CtrlProgram, assetID *bc.AssetID) *account.UTXO {
357         utxo := &account.UTXO{}
358         utxo.OutputID = bc.Hash{V0: 1}
359         utxo.SourceID = bc.Hash{V0: 2}
360         utxo.AssetID = *assetID
361         utxo.Amount = 1000000000
362         utxo.SourcePos = 0
363         utxo.ControlProgram = controlProg.ControlProgram
364         utxo.AccountID = controlProg.AccountID
365         utxo.Address = controlProg.Address
366         utxo.ControlProgramIndex = controlProg.KeyIndex
367         return utxo
368 }
369
370 func mockTxData(utxos []*account.UTXO, testAccount *account.Account) (*txbuilder.Template, *types.TxData, error) {
371         tplBuilder := txbuilder.NewBuilder(time.Now())
372
373         for _, utxo := range utxos {
374                 txInput, sigInst, err := account.UtxoToInputs(testAccount.Signer, utxo)
375                 if err != nil {
376                         return nil, nil, err
377                 }
378                 tplBuilder.AddInput(txInput, sigInst)
379
380                 out := &types.TxOutput{}
381                 if utxo.AssetID == *consensus.BTMAssetID {
382                         out = types.NewIntraChainOutput(utxo.AssetID, 100, utxo.ControlProgram)
383                 } else {
384                         out = types.NewIntraChainOutput(utxo.AssetID, utxo.Amount, utxo.ControlProgram)
385                 }
386                 tplBuilder.AddOutput(out)
387         }
388
389         return tplBuilder.Build()
390 }
391
392 func mockWallet(store WalletStore, account *account.Manager, asset *asset.Registry, chain *protocol.Chain, dispatcher *event.Dispatcher, txIndexFlag bool) *Wallet {
393         wallet := &Wallet{
394                 store:           store,
395                 AccountMgr:      account,
396                 AssetReg:        asset,
397                 chain:           chain,
398                 RecoveryMgr:     newRecoveryManager(store, account),
399                 eventDispatcher: dispatcher,
400                 TxIndexFlag:     txIndexFlag,
401         }
402         wallet.txMsgSub, _ = wallet.eventDispatcher.Subscribe(protocol.TxMsgEvent{})
403         return wallet
404 }
405
406 func mockSingleBlock(tx *types.Tx) *types.Block {
407         return &types.Block{
408                 BlockHeader: types.BlockHeader{
409                         Version: 1,
410                         Height:  1,
411                 },
412                 Transactions: []*types.Tx{config.GenesisTx(), tx},
413         }
414 }
415
416 var (
417         WalletKey     = []byte{0x00, 0x3a}
418         TxIndexPrefix = []byte{0x01, 0x3a}
419         TxPrefix      = []byte{0x02, 0x3a}
420 )
421
422 func CalcGlobalTxIndex(blockHash *bc.Hash, position uint64) []byte {
423         txIdx := make([]byte, 40)
424         copy(txIdx[:32], blockHash.Bytes())
425         binary.BigEndian.PutUint64(txIdx[32:], position)
426         return txIdx
427 }
428
429 func calcTxIndexKey(txID string) []byte {
430         return append(TxIndexPrefix, []byte(txID)...)
431 }
432
433 func calcAnnotatedKey(formatKey string) []byte {
434         return append(TxPrefix, []byte(formatKey)...)
435 }
436
437 type mockAccountStore struct {
438         accountDB dbm.DB
439         batch     dbm.Batch
440 }
441
442 // NewAccountStore create new AccountStore.
443 func newMockAccountStore(db dbm.DB) *mockAccountStore {
444         return &mockAccountStore{
445                 accountDB: db,
446                 batch:     nil,
447         }
448 }
449
450 func (store *mockAccountStore) InitBatch() error                                   { return nil }
451 func (store *mockAccountStore) CommitBatch() error                                 { return nil }
452 func (store *mockAccountStore) DeleteAccount(*account.Account) error               { return nil }
453 func (store *mockAccountStore) DeleteStandardUTXO(outputID bc.Hash)                { return }
454 func (store *mockAccountStore) GetAccountByAlias(string) (*account.Account, error) { return nil, nil }
455 func (store *mockAccountStore) GetAccountByID(string) (*account.Account, error)    { return nil, nil }
456 func (store *mockAccountStore) GetAccountIndex([]chainkd.XPub) uint64              { return 0 }
457 func (store *mockAccountStore) GetBip44ContractIndex(string, bool) uint64          { return 0 }
458 func (store *mockAccountStore) GetCoinbaseArbitrary() []byte                       { return nil }
459 func (store *mockAccountStore) GetContractIndex(string) uint64                     { return 0 }
460 func (store *mockAccountStore) GetControlProgram(bc.Hash) (*account.CtrlProgram, error) {
461         return nil, nil
462 }
463 func (store *mockAccountStore) GetUTXO(outid bc.Hash) (*account.UTXO, error)               { return nil, nil }
464 func (store *mockAccountStore) GetMiningAddress() (*account.CtrlProgram, error)            { return nil, nil }
465 func (store *mockAccountStore) ListAccounts(string) ([]*account.Account, error)            { return nil, nil }
466 func (store *mockAccountStore) ListControlPrograms() ([]*account.CtrlProgram, error)       { return nil, nil }
467 func (store *mockAccountStore) ListUTXOs() ([]*account.UTXO, error)                        { return nil, nil }
468 func (store *mockAccountStore) SetAccount(*account.Account) error                          { return nil }
469 func (store *mockAccountStore) SetAccountIndex(*account.Account)                           { return }
470 func (store *mockAccountStore) SetBip44ContractIndex(string, bool, uint64)                 { return }
471 func (store *mockAccountStore) SetCoinbaseArbitrary([]byte)                                { return }
472 func (store *mockAccountStore) SetContractIndex(string, uint64)                            { return }
473 func (store *mockAccountStore) SetControlProgram(bc.Hash, *account.CtrlProgram) error      { return nil }
474 func (store *mockAccountStore) SetMiningAddress(*account.CtrlProgram) error                { return nil }
475 func (store *mockAccountStore) SetStandardUTXO(outputID bc.Hash, utxo *account.UTXO) error { return nil }
476
477 // WalletStore store wallet using leveldb
478 type mockWalletStore struct {
479         walletDB dbm.DB
480         batch    dbm.Batch
481 }
482
483 // NewWalletStore create new WalletStore struct
484 func newMockWalletStore(db dbm.DB) *mockWalletStore {
485         return &mockWalletStore{
486                 walletDB: db,
487                 batch:    nil,
488         }
489 }
490
491 func (store *mockWalletStore) InitBatch() error                                    { return nil }
492 func (store *mockWalletStore) CommitBatch() error                                  { return nil }
493 func (store *mockWalletStore) DeleteContractUTXO(bc.Hash)                          { return }
494 func (store *mockWalletStore) DeleteRecoveryStatus()                               { return }
495 func (store *mockWalletStore) DeleteTransactions(uint64)                           { return }
496 func (store *mockWalletStore) DeleteUnconfirmedTransaction(string)                 { return }
497 func (store *mockWalletStore) DeleteWalletTransactions()                           { return }
498 func (store *mockWalletStore) DeleteWalletUTXOs()                                  { return }
499 func (store *mockWalletStore) GetAsset(*bc.AssetID) (*asset.Asset, error)          { return nil, nil }
500 func (store *mockWalletStore) GetControlProgram(bc.Hash) (*acc.CtrlProgram, error) { return nil, nil }
501 func (store *mockWalletStore) GetGlobalTransactionIndex(string) []byte             { return nil }
502 func (store *mockWalletStore) GetStandardUTXO(bc.Hash) (*acc.UTXO, error)          { return nil, nil }
503
504 // func (store *mockWalletStore) GetTransaction(string) (*query.AnnotatedTx, error)   { return nil, nil }
505 func (store *mockWalletStore) GetUnconfirmedTransaction(string) (*query.AnnotatedTx, error) {
506         return nil, nil
507 }
508 func (store *mockWalletStore) GetRecoveryStatus([]byte) []byte              { return nil }
509 func (store *mockWalletStore) ListAccountUTXOs(string) ([]*acc.UTXO, error) { return nil, nil }
510 func (store *mockWalletStore) ListTransactions(string, string, uint, bool) ([]*query.AnnotatedTx, error) {
511         return nil, nil
512 }
513 func (store *mockWalletStore) ListUnconfirmedTransactions() ([]*query.AnnotatedTx, error) {
514         return nil, nil
515 }
516 func (store *mockWalletStore) SetAssetDefinition(*bc.AssetID, []byte)             { return }
517 func (store *mockWalletStore) SetContractUTXO(bc.Hash, *acc.UTXO) error           { return nil }
518 func (store *mockWalletStore) SetGlobalTransactionIndex(string, *bc.Hash, uint64) { return }
519 func (store *mockWalletStore) SetRecoveryStatus([]byte, []byte)                   { return }
520 func (store *mockWalletStore) SetTransaction(uint64, *query.AnnotatedTx) error    { return nil }
521 func (store *mockWalletStore) SetUnconfirmedTransaction(string, *query.AnnotatedTx) error {
522         return nil
523 }
524
525 // GetTransaction get tx by txid
526 func (store *mockWalletStore) GetTransaction(txID string) (*query.AnnotatedTx, error) {
527         formatKey := store.walletDB.Get(calcTxIndexKey(txID))
528         if formatKey == nil {
529                 return nil, errors.New("account TXID not found")
530         }
531         rawTx := store.walletDB.Get(calcAnnotatedKey(string(formatKey)))
532         tx := new(query.AnnotatedTx)
533         if err := json.Unmarshal(rawTx, tx); err != nil {
534                 return nil, err
535         }
536         return tx, nil
537 }
538
539 // GetWalletInfo get wallet information
540 func (store *mockWalletStore) GetWalletInfo() []byte {
541         return store.walletDB.Get([]byte(WalletKey))
542 }
543
544 // SetWalletInfo get wallet information
545 func (store *mockWalletStore) SetWalletInfo(rawWallet []byte) {
546         if store.batch == nil {
547                 store.walletDB.Set([]byte(WalletKey), rawWallet)
548         } else {
549                 store.batch.Set([]byte(WalletKey), rawWallet)
550         }
551 }