OSDN Git Service

add wallet version check & globalTxIdx (#1657)
authorHAOYUatHZ <37070449+HAOYUatHZ@users.noreply.github.com>
Wed, 27 Mar 2019 07:08:31 +0000 (15:08 +0800)
committerPaladz <yzhu101@uottawa.ca>
Wed, 27 Mar 2019 07:08:31 +0000 (15:08 +0800)
* feat: add current wallet version check

* feat: save globalTxIdx

* refactor: clean up

* test: add wallet txID check

* test: add genesisTx in wallet test

* refactor: adjust calcGlobalTxIndex format

* refactor: change date type for wallet version

* test: fix attach block in wallet_test

* refactor: fix https://github.com/Bytom/bytom/pull/1657/files#r269394040

* fix: move json.unmarshal form check_walletinfo to load_wallet

* fix: fix wallet version check test

* refactor: clean

* feat: add w.GetGlobalTxIdxes()

* test: add globalTxIdx comparasion

* refactor: clean up

* refactor: clean

* refactor: change bh type to pointer

* refactor: use nil txPool for TestWalletVersion

* refactor: use nil for unnecessary wallet params

config/genesis.go
wallet/indexer.go
wallet/wallet.go
wallet/wallet_test.go

index e031f75..90b481c 100644 (file)
@@ -10,7 +10,7 @@ import (
        "github.com/bytom/protocol/bc/types"
 )
 
-func genesisTx() *types.Tx {
+func GenesisTx() *types.Tx {
        contract, err := hex.DecodeString("00148c9d063ff74ee6d9ffa88d83aeb038068366c4c4")
        if err != nil {
                log.Panicf("fail on decode genesis tx output control program")
@@ -29,7 +29,7 @@ func genesisTx() *types.Tx {
 }
 
 func mainNetGenesisBlock() *types.Block {
-       tx := genesisTx()
+       tx := GenesisTx()
        txStatus := bc.NewTransactionStatus()
        if err := txStatus.SetStatus(0, false); err != nil {
                log.Panicf(err.Error())
@@ -62,7 +62,7 @@ func mainNetGenesisBlock() *types.Block {
 }
 
 func testNetGenesisBlock() *types.Block {
-       tx := genesisTx()
+       tx := GenesisTx()
        txStatus := bc.NewTransactionStatus()
        if err := txStatus.SetStatus(0, false); err != nil {
                log.Panicf(err.Error())
@@ -95,7 +95,7 @@ func testNetGenesisBlock() *types.Block {
 }
 
 func soloNetGenesisBlock() *types.Block {
-       tx := genesisTx()
+       tx := GenesisTx()
        txStatus := bc.NewTransactionStatus()
        if err := txStatus.SetStatus(0, false); err != nil {
                log.Panicf(err.Error())
index dad3b64..459efad 100644 (file)
@@ -22,6 +22,8 @@ const (
        TxPrefix = "TXS:"
        //TxIndexPrefix is wallet database tx index prefix
        TxIndexPrefix = "TID:"
+       //TxIndexPrefix is wallet database global tx index prefix
+       GlobalTxIndexPrefix = "GTID:"
 )
 
 func formatKey(blockHeight uint64, position uint32) string {
@@ -40,6 +42,14 @@ func calcTxIndexKey(txID string) []byte {
        return []byte(TxIndexPrefix + txID)
 }
 
+func calcGlobalTxIndexKey(txID string) []byte {
+       return []byte(GlobalTxIndexPrefix + txID)
+}
+
+func calcGlobalTxIndex(blockHash *bc.Hash, position int) []byte {
+       return []byte(fmt.Sprintf("%064x%08x", blockHash.String(), position))
+}
+
 // deleteTransaction delete transactions when orphan block rollback
 func (w *Wallet) deleteTransactions(batch db.Batch, height uint64) {
        tmpTx := query.AnnotatedTx{}
@@ -113,6 +123,12 @@ func (w *Wallet) indexTransactions(batch db.Batch, b *types.Block, txStatus *bc.
                // delete unconfirmed transaction
                batch.Delete(calcUnconfirmedTxKey(tx.ID.String()))
        }
+
+       for position, globalTx := range b.Transactions {
+               blockHash := b.BlockHeader.Hash()
+               batch.Set(calcGlobalTxIndexKey(globalTx.ID.String()), calcGlobalTxIndex(&blockHash, position))
+       }
+
        return nil
 }
 
index 365f3c0..58d7a67 100644 (file)
@@ -10,6 +10,7 @@ import (
        "github.com/bytom/account"
        "github.com/bytom/asset"
        "github.com/bytom/blockchain/pseudohsm"
+       "github.com/bytom/errors"
        "github.com/bytom/event"
        "github.com/bytom/protocol"
        "github.com/bytom/protocol/bc"
@@ -22,10 +23,17 @@ const (
        logModule = "wallet"
 )
 
-var walletKey = []byte("walletInfo")
+var (
+       currentVersion = uint(1)
+       walletKey      = []byte("walletInfo")
+
+       errBestBlockNotFoundInCore = errors.New("best block not found in core")
+       errWalletVersionMismatch   = errors.New("wallet version mismatch")
+)
 
 //StatusInfo is base valid block info to handle orphan block rollback
 type StatusInfo struct {
+       Version    uint
        WorkHeight uint64
        WorkHash   bc.Hash
        BestHeight uint64
@@ -109,25 +117,35 @@ func (w *Wallet) memPoolTxQueryLoop() {
        }
 }
 
-//GetWalletInfo return stored wallet info and nil,if error,
-//return initial wallet info and err
+func (w *Wallet) checkWalletInfo() error {
+       if w.status.Version != currentVersion {
+               return errWalletVersionMismatch
+       } else if !w.chain.BlockExist(&w.status.BestHash) {
+               return errBestBlockNotFoundInCore
+       }
+
+       return nil
+}
+
+//loadWalletInfo return stored wallet info and nil,
+//if error, return initial wallet info and err
 func (w *Wallet) loadWalletInfo() error {
        if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
                if err := json.Unmarshal(rawWallet, &w.status); err != nil {
                        return err
                }
 
-               //handle the case than use replace the coreDB during status in fork chain
-               if w.chain.BlockExist(&w.status.BestHash) {
+               err := w.checkWalletInfo()
+               if err == nil {
                        return nil
                }
 
-               log.WithFields(log.Fields{"module": logModule}).Warn("reset the wallet status due to core doesn't have wallet best block")
+               log.WithFields(log.Fields{"module": logModule}).Warn(err.Error())
                w.deleteAccountTxs()
                w.deleteUtxos()
-               w.status = StatusInfo{}
        }
 
+       w.status.Version = currentVersion
        block, err := w.chain.GetBlockByHeight(0)
        if err != nil {
                return err
index 5481bfd..f3e7153 100644 (file)
@@ -1,8 +1,10 @@
 package wallet
 
 import (
+       "encoding/json"
        "io/ioutil"
        "os"
+       "reflect"
        "testing"
        "time"
 
@@ -13,6 +15,7 @@ import (
        "github.com/bytom/blockchain/pseudohsm"
        "github.com/bytom/blockchain/signers"
        "github.com/bytom/blockchain/txbuilder"
+       "github.com/bytom/config"
        "github.com/bytom/consensus"
        "github.com/bytom/crypto/ed25519/chainkd"
        "github.com/bytom/database/leveldb"
@@ -22,6 +25,68 @@ import (
        "github.com/bytom/protocol/bc/types"
 )
 
+func TestWalletVersion(t *testing.T) {
+       // prepare wallet
+       dirPath, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(dirPath)
+
+       testDB := dbm.NewDB("testdb", "leveldb", "temp")
+       defer os.RemoveAll("temp")
+
+       dispatcher := event.NewDispatcher()
+       w := mockWallet(testDB, nil, nil, nil, dispatcher)
+
+       // legacy status test case
+       type legacyStatusInfo struct {
+               WorkHeight uint64
+               WorkHash   bc.Hash
+               BestHeight uint64
+               BestHash   bc.Hash
+       }
+       rawWallet, err := json.Marshal(legacyStatusInfo{})
+       if err != nil {
+               t.Fatal("Marshal legacyStatusInfo")
+       }
+
+       w.DB.Set(walletKey, rawWallet)
+       rawWallet = w.DB.Get(walletKey)
+       if rawWallet == nil {
+               t.Fatal("fail to load wallet StatusInfo")
+       }
+
+       if err := json.Unmarshal(rawWallet, &w.status); err != nil {
+               t.Fatal(err)
+       }
+
+       if err := w.checkWalletInfo(); err != errWalletVersionMismatch {
+               t.Fatal("fail to detect legacy wallet version")
+       }
+
+       // lower wallet version test case
+       lowerVersion := StatusInfo{Version: currentVersion - 1}
+       rawWallet, err = json.Marshal(lowerVersion)
+       if err != nil {
+               t.Fatal("save wallet info")
+       }
+
+       w.DB.Set(walletKey, rawWallet)
+       rawWallet = w.DB.Get(walletKey)
+       if rawWallet == nil {
+               t.Fatal("fail to load wallet StatusInfo")
+       }
+
+       if err := json.Unmarshal(rawWallet, &w.status); err != nil {
+               t.Fatal(err)
+       }
+
+       if err := w.checkWalletInfo(); err != errWalletVersionMismatch {
+               t.Fatal("fail to detect expired wallet version")
+       }
+}
+
 func TestWalletUpdate(t *testing.T) {
        dirPath, err := ioutil.TempDir(".", "")
        if err != nil {
@@ -85,6 +150,7 @@ func TestWalletUpdate(t *testing.T) {
        block := mockSingleBlock(tx)
        txStatus := bc.NewTransactionStatus()
        txStatus.SetStatus(0, false)
+       txStatus.SetStatus(1, false)
        store.SaveBlock(block, txStatus)
 
        w := mockWallet(testDB, accountManager, reg, chain, dispatcher)
@@ -101,6 +167,19 @@ func TestWalletUpdate(t *testing.T) {
        if len(wants) != 1 {
                t.Fatal(err)
        }
+
+       if wants[0].ID != tx.ID {
+               t.Fatal("account txID mismatch")
+       }
+
+       for position, tx := range block.Transactions {
+               get := w.DB.Get(calcGlobalTxIndexKey(tx.ID.String()))
+               bh := block.BlockHeader.Hash()
+               expect := calcGlobalTxIndex(&bh, position)
+               if !reflect.DeepEqual(get, expect) {
+                       t.Fatalf("position#%d: compare retrieved globalTxIdx err", position)
+               }
+       }
 }
 
 func TestMemPoolTxQueryLoop(t *testing.T) {
@@ -242,6 +321,6 @@ func mockSingleBlock(tx *types.Tx) *types.Block {
                        Height:  1,
                        Bits:    2305843009230471167,
                },
-               Transactions: []*types.Tx{tx},
+               Transactions: []*types.Tx{config.GenesisTx(), tx},
        }
 }