OSDN Git Service

update
[bytom/vapor.git] / wallet / wallet_test.go
1 package wallet
2
3 import (
4         "encoding/json"
5         "io/ioutil"
6         "os"
7         "reflect"
8         "testing"
9         "time"
10
11         "github.com/vapor/account"
12         "github.com/vapor/asset"
13         "github.com/vapor/blockchain/pseudohsm"
14         "github.com/vapor/blockchain/signers"
15         "github.com/vapor/blockchain/txbuilder"
16         "github.com/vapor/config"
17         "github.com/vapor/consensus"
18         "github.com/vapor/crypto/ed25519/chainkd"
19         "github.com/vapor/database"
20         "github.com/vapor/database/dbutils"
21         dbm "github.com/vapor/database/leveldb"
22         "github.com/vapor/event"
23         "github.com/vapor/protocol"
24         "github.com/vapor/protocol/bc"
25         "github.com/vapor/protocol/bc/types"
26 )
27
28 func TestEncodeDecodeGlobalTxIndex(t *testing.T) {
29         want := &struct {
30                 BlockHash bc.Hash
31                 Position  uint64
32         }{
33                 BlockHash: bc.NewHash([32]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}),
34                 Position:  1,
35         }
36
37         globalTxIdx := calcGlobalTxIndex(&want.BlockHash, want.Position)
38         blockHashGot, positionGot := parseGlobalTxIdx(globalTxIdx)
39         if *blockHashGot != want.BlockHash {
40                 t.Errorf("blockHash mismatch. Get: %v. Expect: %v", *blockHashGot, want.BlockHash)
41         }
42
43         if positionGot != want.Position {
44                 t.Errorf("position mismatch. Get: %v. Expect: %v", positionGot, want.Position)
45         }
46 }
47
48 func TestWalletVersion(t *testing.T) {
49         // prepare wallet
50         dirPath, err := ioutil.TempDir(".", "")
51         if err != nil {
52                 t.Fatal(err)
53         }
54         defer os.RemoveAll(dirPath)
55
56         testDB := database.NewDB("testdb", "leveldb", "temp")
57         defer func() {
58                 testDB.Close()
59                 os.RemoveAll("temp")
60         }()
61
62         dispatcher := event.NewDispatcher()
63         w := mockWallet(testDB, nil, nil, nil, dispatcher, false)
64
65         // legacy status test case
66         type legacyStatusInfo struct {
67                 WorkHeight uint64
68                 WorkHash   bc.Hash
69                 BestHeight uint64
70                 BestHash   bc.Hash
71         }
72         rawWallet, err := json.Marshal(legacyStatusInfo{})
73         if err != nil {
74                 t.Fatal("Marshal legacyStatusInfo")
75         }
76
77         w.DB.Set(walletKey, rawWallet)
78         rawWallet = w.DB.Get(walletKey)
79         if rawWallet == nil {
80                 t.Fatal("fail to load wallet StatusInfo")
81         }
82
83         if err := json.Unmarshal(rawWallet, &w.status); err != nil {
84                 t.Fatal(err)
85         }
86
87         if err := w.checkWalletInfo(); err != errWalletVersionMismatch {
88                 t.Fatal("fail to detect legacy wallet version")
89         }
90
91         // lower wallet version test case
92         lowerVersion := StatusInfo{Version: currentVersion - 1}
93         rawWallet, err = json.Marshal(lowerVersion)
94         if err != nil {
95                 t.Fatal("save wallet info")
96         }
97
98         w.DB.Set(walletKey, rawWallet)
99         rawWallet = w.DB.Get(walletKey)
100         if rawWallet == nil {
101                 t.Fatal("fail to load wallet StatusInfo")
102         }
103
104         if err := json.Unmarshal(rawWallet, &w.status); err != nil {
105                 t.Fatal(err)
106         }
107
108         if err := w.checkWalletInfo(); err != errWalletVersionMismatch {
109                 t.Fatal("fail to detect expired wallet version")
110         }
111 }
112
113 func TestWalletUpdate(t *testing.T) {
114         dirPath, err := ioutil.TempDir(".", "")
115         if err != nil {
116                 t.Fatal(err)
117         }
118         defer os.RemoveAll(dirPath)
119
120         config.CommonConfig = config.DefaultConfig()
121         testDB := dbm.NewDB("testdb", "leveldb", "temp")
122         defer func() {
123                 testDB.Close()
124                 os.RemoveAll("temp")
125         }()
126
127         store := database.NewStore(testDB)
128         dispatcher := event.NewDispatcher()
129         txPool := protocol.NewTxPool(store, dispatcher)
130
131         chain, err := protocol.NewChain(store, txPool, dispatcher)
132         if err != nil {
133                 t.Fatal(err)
134         }
135
136         accountManager := account.NewManager(testDB, chain)
137         hsm, err := pseudohsm.New(dirPath)
138         if err != nil {
139                 t.Fatal(err)
140         }
141
142         xpub1, _, err := hsm.XCreate("test_pub1", "password", "en")
143         if err != nil {
144                 t.Fatal(err)
145         }
146
147         testAccount, err := accountManager.Create([]chainkd.XPub{xpub1.XPub}, 1, "testAccount", signers.BIP0044)
148         if err != nil {
149                 t.Fatal(err)
150         }
151
152         controlProg, err := accountManager.CreateAddress(testAccount.ID, false)
153         if err != nil {
154                 t.Fatal(err)
155         }
156
157         controlProg.KeyIndex = 1
158
159         reg := asset.NewRegistry(testDB, chain)
160         asset := bc.AssetID{V0: 5}
161
162         utxos := []*account.UTXO{}
163         btmUtxo := mockUTXO(controlProg, consensus.BTMAssetID)
164         utxos = append(utxos, btmUtxo)
165         OtherUtxo := mockUTXO(controlProg, &asset)
166         utxos = append(utxos, OtherUtxo)
167
168         _, txData, err := mockTxData(utxos, testAccount)
169         if err != nil {
170                 t.Fatal(err)
171         }
172
173         tx := types.NewTx(*txData)
174         block := mockSingleBlock(tx)
175         txStatus := bc.NewTransactionStatus()
176         txStatus.SetStatus(0, false)
177         txStatus.SetStatus(1, false)
178         store.SaveBlock(block, txStatus)
179
180         w := mockWallet(testDB, accountManager, reg, chain, dispatcher, true)
181         err = w.AttachBlock(block)
182         if err != nil {
183                 t.Fatal(err)
184         }
185
186         if _, err := w.GetTransactionByTxID(tx.ID.String()); err != nil {
187                 t.Fatal(err)
188         }
189
190         wants, err := w.GetTransactions("")
191         if len(wants) != 1 {
192                 t.Fatal(err)
193         }
194
195         if wants[0].ID != tx.ID {
196                 t.Fatal("account txID mismatch")
197         }
198
199         for position, tx := range block.Transactions {
200                 get := w.DB.Get(calcGlobalTxIndexKey(tx.ID.String()))
201                 bh := block.BlockHeader.Hash()
202                 expect := calcGlobalTxIndex(&bh, uint64(position))
203                 if !reflect.DeepEqual(get, expect) {
204                         t.Fatalf("position#%d: compare retrieved globalTxIdx err", position)
205                 }
206         }
207 }
208
209 func TestRescanWallet(t *testing.T) {
210         // prepare wallet & db
211         dirPath, err := ioutil.TempDir(".", "")
212         if err != nil {
213                 t.Fatal(err)
214         }
215         defer os.RemoveAll(dirPath)
216
217         config.CommonConfig = config.DefaultConfig()
218         testDB := dbm.NewDB("testdb", "leveldb", "temp")
219         defer func() {
220                 testDB.Close()
221                 os.RemoveAll("temp")
222         }()
223
224         store := database.NewStore(testDB)
225         dispatcher := event.NewDispatcher()
226         txPool := protocol.NewTxPool(store, dispatcher)
227         chain, err := protocol.NewChain(store, txPool, dispatcher)
228         if err != nil {
229                 t.Fatal(err)
230         }
231
232         statusInfo := StatusInfo{
233                 Version:  currentVersion,
234                 WorkHash: bc.Hash{V0: 0xff},
235         }
236         rawWallet, err := json.Marshal(statusInfo)
237         if err != nil {
238                 t.Fatal("save wallet info")
239         }
240
241         w := mockWallet(testDB, nil, nil, chain, dispatcher, false)
242         w.DB.Set(walletKey, rawWallet)
243         rawWallet = w.DB.Get(walletKey)
244         if rawWallet == nil {
245                 t.Fatal("fail to load wallet StatusInfo")
246         }
247
248         if err := json.Unmarshal(rawWallet, &w.status); err != nil {
249                 t.Fatal(err)
250         }
251
252         // rescan wallet
253         if err := w.loadWalletInfo(); err != nil {
254                 t.Fatal(err)
255         }
256
257         block := config.GenesisBlock()
258         if w.status.WorkHash != block.Hash() {
259                 t.Fatal("reattach from genesis block")
260         }
261 }
262
263 func TestMemPoolTxQueryLoop(t *testing.T) {
264         dirPath, err := ioutil.TempDir(".", "")
265         if err != nil {
266                 t.Fatal(err)
267         }
268         config.CommonConfig = config.DefaultConfig()
269         testDB := dbm.NewDB("testdb", "leveldb", dirPath)
270         defer func() {
271                 testDB.Close()
272                 os.RemoveAll(dirPath)
273         }()
274
275         store := database.NewStore(testDB)
276         dispatcher := event.NewDispatcher()
277         txPool := protocol.NewTxPool(store, dispatcher)
278
279         chain, err := protocol.NewChain(store, txPool, dispatcher)
280         if err != nil {
281                 t.Fatal(err)
282         }
283
284         accountManager := account.NewManager(testDB, chain)
285         hsm, err := pseudohsm.New(dirPath)
286         if err != nil {
287                 t.Fatal(err)
288         }
289
290         xpub1, _, err := hsm.XCreate("test_pub1", "password", "en")
291         if err != nil {
292                 t.Fatal(err)
293         }
294
295         testAccount, err := accountManager.Create([]chainkd.XPub{xpub1.XPub}, 1, "testAccount", signers.BIP0044)
296         if err != nil {
297                 t.Fatal(err)
298         }
299
300         controlProg, err := accountManager.CreateAddress(testAccount.ID, false)
301         if err != nil {
302                 t.Fatal(err)
303         }
304
305         controlProg.KeyIndex = 1
306
307         reg := asset.NewRegistry(testDB, chain)
308         asset := bc.AssetID{V0: 5}
309
310         utxos := []*account.UTXO{}
311         btmUtxo := mockUTXO(controlProg, consensus.BTMAssetID)
312         utxos = append(utxos, btmUtxo)
313         OtherUtxo := mockUTXO(controlProg, &asset)
314         utxos = append(utxos, OtherUtxo)
315
316         _, txData, err := mockTxData(utxos, testAccount)
317         if err != nil {
318                 t.Fatal(err)
319         }
320
321         tx := types.NewTx(*txData)
322         //block := mockSingleBlock(tx)
323         txStatus := bc.NewTransactionStatus()
324         txStatus.SetStatus(0, false)
325         w, err := NewWallet(testDB, accountManager, reg, hsm, chain, dispatcher, false)
326         go w.memPoolTxQueryLoop()
327         w.eventDispatcher.Post(protocol.TxMsgEvent{TxMsg: &protocol.TxPoolMsg{TxDesc: &protocol.TxDesc{Tx: tx}, MsgType: protocol.MsgNewTx}})
328         time.Sleep(time.Millisecond * 10)
329         if _, err = w.GetUnconfirmedTxByTxID(tx.ID.String()); err != nil {
330                 t.Fatal("disaptch new tx msg error:", err)
331         }
332         w.eventDispatcher.Post(protocol.TxMsgEvent{TxMsg: &protocol.TxPoolMsg{TxDesc: &protocol.TxDesc{Tx: tx}, MsgType: protocol.MsgRemoveTx}})
333         time.Sleep(time.Millisecond * 10)
334         txs, err := w.GetUnconfirmedTxs(testAccount.ID)
335         if err != nil {
336                 t.Fatal("get unconfirmed tx error:", err)
337         }
338
339         if len(txs) != 0 {
340                 t.Fatal("disaptch remove tx msg error")
341         }
342
343         w.eventDispatcher.Post(protocol.TxMsgEvent{TxMsg: &protocol.TxPoolMsg{TxDesc: &protocol.TxDesc{Tx: tx}, MsgType: 2}})
344 }
345
346 func mockUTXO(controlProg *account.CtrlProgram, assetID *bc.AssetID) *account.UTXO {
347         utxo := &account.UTXO{}
348         utxo.OutputID = bc.Hash{V0: 1}
349         utxo.SourceID = bc.Hash{V0: 2}
350         utxo.AssetID = *assetID
351         utxo.Amount = 1000000000
352         utxo.SourcePos = 0
353         utxo.ControlProgram = controlProg.ControlProgram
354         utxo.AccountID = controlProg.AccountID
355         utxo.Address = controlProg.Address
356         utxo.ControlProgramIndex = controlProg.KeyIndex
357         return utxo
358 }
359
360 func mockTxData(utxos []*account.UTXO, testAccount *account.Account) (*txbuilder.Template, *types.TxData, error) {
361         tplBuilder := txbuilder.NewBuilder(time.Now())
362
363         for _, utxo := range utxos {
364                 txInput, sigInst, err := account.UtxoToInputs(testAccount.Signer, utxo)
365                 if err != nil {
366                         return nil, nil, err
367                 }
368                 tplBuilder.AddInput(txInput, sigInst)
369
370                 out := &types.TxOutput{}
371                 if utxo.AssetID == *consensus.BTMAssetID {
372                         out = types.NewIntraChainOutput(utxo.AssetID, 100, utxo.ControlProgram)
373                 } else {
374                         out = types.NewIntraChainOutput(utxo.AssetID, utxo.Amount, utxo.ControlProgram)
375                 }
376                 tplBuilder.AddOutput(out)
377         }
378
379         return tplBuilder.Build()
380 }
381
382 func mockWallet(walletDB dbutils.DB, account *account.Manager, asset *asset.Registry, chain *protocol.Chain, dispatcher *event.Dispatcher, txIndexFlag bool) *Wallet {
383         wallet := &Wallet{
384                 DB:              walletDB,
385                 AccountMgr:      account,
386                 AssetReg:        asset,
387                 chain:           chain,
388                 RecoveryMgr:     newRecoveryManager(walletDB, account),
389                 eventDispatcher: dispatcher,
390                 TxIndexFlag:     txIndexFlag,
391         }
392         wallet.txMsgSub, _ = wallet.eventDispatcher.Subscribe(protocol.TxMsgEvent{})
393         return wallet
394 }
395
396 func mockSingleBlock(tx *types.Tx) *types.Block {
397         return &types.Block{
398                 BlockHeader: types.BlockHeader{
399                         Version: 1,
400                         Height:  1,
401                 },
402                 Transactions: []*types.Tx{config.GenesisTx(), tx},
403         }
404 }