OSDN Git Service

validateBlock unit test (#346)
[bytom/vapor.git] / wallet / wallet.go
index 92a24b9..0a955d2 100644 (file)
@@ -1,7 +1,6 @@
 package wallet
 
 import (
-       "encoding/json"
        "sync"
 
        log "github.com/sirupsen/logrus"
@@ -9,7 +8,6 @@ import (
        "github.com/vapor/account"
        "github.com/vapor/asset"
        "github.com/vapor/blockchain/pseudohsm"
-       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/errors"
        "github.com/vapor/event"
        "github.com/vapor/protocol"
@@ -25,10 +23,13 @@ const (
 
 var (
        currentVersion = uint(1)
-       walletKey      = []byte("walletInfo")
 
        errBestBlockNotFoundInCore = errors.New("best block not found in core")
        errWalletVersionMismatch   = errors.New("wallet version mismatch")
+       ErrGetWalletStatusInfo     = errors.New("failed get wallet info")
+       ErrGetAsset                = errors.New("Failed to find asset definition")
+       ErrAccntTxIDNotFound       = errors.New("account TXID not found")
+       ErrGetStandardUTXO         = errors.New("failed get standard UTXO")
 )
 
 //StatusInfo is base valid block info to handle orphan block rollback
@@ -42,36 +43,36 @@ type StatusInfo struct {
 
 //Wallet is related to storing account unspent outputs
 type Wallet struct {
-       DB              dbm.DB
+       Store           WalletStore
        rw              sync.RWMutex
-       status          StatusInfo
+       Status          StatusInfo
        TxIndexFlag     bool
        AccountMgr      *account.Manager
        AssetReg        *asset.Registry
        Hsm             *pseudohsm.HSM
-       chain           *protocol.Chain
+       Chain           *protocol.Chain
        RecoveryMgr     *recoveryManager
-       eventDispatcher *event.Dispatcher
-       txMsgSub        *event.Subscription
+       EventDispatcher *event.Dispatcher
+       TxMsgSub        *event.Subscription
 
        rescanCh chan struct{}
 }
 
 //NewWallet return a new wallet instance
-func NewWallet(walletDB dbm.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain, dispatcher *event.Dispatcher, txIndexFlag bool) (*Wallet, error) {
+func NewWallet(store WalletStore, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain, dispatcher *event.Dispatcher, txIndexFlag bool) (*Wallet, error) {
        w := &Wallet{
-               DB:              walletDB,
+               Store:           store,
                AccountMgr:      account,
                AssetReg:        asset,
-               chain:           chain,
+               Chain:           chain,
                Hsm:             hsm,
-               RecoveryMgr:     newRecoveryManager(walletDB, account),
-               eventDispatcher: dispatcher,
+               RecoveryMgr:     NewRecoveryManager(store, account),
+               EventDispatcher: dispatcher,
                rescanCh:        make(chan struct{}, 1),
                TxIndexFlag:     txIndexFlag,
        }
 
-       if err := w.loadWalletInfo(); err != nil {
+       if err := w.LoadWalletInfo(); err != nil {
                return nil, err
        }
 
@@ -80,22 +81,22 @@ func NewWallet(walletDB dbm.DB, account *account.Manager, asset *asset.Registry,
        }
 
        var err error
-       w.txMsgSub, err = w.eventDispatcher.Subscribe(protocol.TxMsgEvent{})
+       w.TxMsgSub, err = w.EventDispatcher.Subscribe(protocol.TxMsgEvent{})
        if err != nil {
                return nil, err
        }
 
        go w.walletUpdater()
        go w.delUnconfirmedTx()
-       go w.memPoolTxQueryLoop()
+       go w.MemPoolTxQueryLoop()
        return w, nil
 }
 
-// memPoolTxQueryLoop constantly pass a transaction accepted by mempool to the wallet.
-func (w *Wallet) memPoolTxQueryLoop() {
+// MemPoolTxQueryLoop constantly pass a transaction accepted by mempool to the wallet.
+func (w *Wallet) MemPoolTxQueryLoop() {
        for {
                select {
-               case obj, ok := <-w.txMsgSub.Chan():
+               case obj, ok := <-w.TxMsgSub.Chan():
                        if !ok {
                                log.WithFields(log.Fields{"module": logModule}).Warning("tx pool tx msg subscription channel closed")
                                return
@@ -120,51 +121,50 @@ func (w *Wallet) memPoolTxQueryLoop() {
 }
 
 func (w *Wallet) checkWalletInfo() error {
-       if w.status.Version != currentVersion {
+       if w.Status.Version != currentVersion {
                return errWalletVersionMismatch
-       } else if !w.chain.BlockExist(&w.status.BestHash) {
+       } else if !w.Chain.BlockExist(&w.Status.BestHash) {
                return errBestBlockNotFoundInCore
        }
 
        return nil
 }
 
-//loadWalletInfo return stored wallet info and nil,
+//LoadWalletInfo return stored wallet info and nil,
 //if error, return initial wallet info and err
-func (w *Wallet) loadWalletInfo() error {
-       if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
-               if err := json.Unmarshal(rawWallet, &w.status); err != nil {
-                       return err
-               }
+func (w *Wallet) LoadWalletInfo() error {
+       walletStatus, err := w.Store.GetWalletInfo()
+       if walletStatus == nil && err != ErrGetWalletStatusInfo {
+               return err
+       }
 
-               err := w.checkWalletInfo()
+       if walletStatus != nil {
+               w.Status = *walletStatus
+               err = w.checkWalletInfo()
                if err == nil {
                        return nil
                }
 
                log.WithFields(log.Fields{"module": logModule}).Warn(err.Error())
-               w.deleteAccountTxs()
-               w.deleteUtxos()
+               w.Store.DeleteWalletTransactions()
+               w.Store.DeleteWalletUTXOs()
        }
 
-       w.status.Version = currentVersion
-       w.status.WorkHash = bc.Hash{}
-       block, err := w.chain.GetBlockByHeight(0)
+       w.Status.Version = currentVersion
+       w.Status.WorkHash = bc.Hash{}
+       block, err := w.Chain.GetBlockByHeight(0)
        if err != nil {
                return err
        }
+
        return w.AttachBlock(block)
 }
 
-func (w *Wallet) commitWalletInfo(batch dbm.Batch) error {
-       rawWallet, err := json.Marshal(w.status)
-       if err != nil {
+func (w *Wallet) commitWalletInfo(store WalletStore) error {
+       if err := store.SetWalletInfo(&w.Status); err != nil {
                log.WithFields(log.Fields{"module": logModule, "err": err}).Error("save wallet info")
                return err
        }
-
-       batch.Set(walletKey, rawWallet)
-       batch.Write()
        return nil
 }
 
@@ -173,13 +173,13 @@ func (w *Wallet) AttachBlock(block *types.Block) error {
        w.rw.Lock()
        defer w.rw.Unlock()
 
-       if block.PreviousBlockHash != w.status.WorkHash {
+       if block.PreviousBlockHash != w.Status.WorkHash {
                log.Warn("wallet skip attachBlock due to status hash not equal to previous hash")
                return nil
        }
 
        blockHash := block.Hash()
-       txStatus, err := w.chain.GetTransactionStatus(&blockHash)
+       txStatus, err := w.Chain.GetTransactionStatus(&blockHash)
        if err != nil {
                return err
        }
@@ -189,19 +189,35 @@ func (w *Wallet) AttachBlock(block *types.Block) error {
                w.RecoveryMgr.finished()
        }
 
-       storeBatch := w.DB.NewBatch()
-       if err := w.indexTransactions(storeBatch, block, txStatus); err != nil {
+       annotatedTxs := w.filterAccountTxs(block, txStatus)
+       if err := saveExternalAssetDefinition(block, w.Store); err != nil {
+               return err
+       }
+
+       w.annotateTxsAccount(annotatedTxs)
+
+       newStore := w.Store.InitBatch()
+       if err := w.indexTransactions(block, txStatus, annotatedTxs, newStore); err != nil {
+               return err
+       }
+
+       w.attachUtxos(block, txStatus, newStore)
+       w.Status.WorkHeight = block.Height
+       w.Status.WorkHash = block.Hash()
+       if w.Status.WorkHeight >= w.Status.BestHeight {
+               w.Status.BestHeight = w.Status.WorkHeight
+               w.Status.BestHash = w.Status.WorkHash
+       }
+
+       if err := w.commitWalletInfo(newStore); err != nil {
                return err
        }
 
-       w.attachUtxos(storeBatch, block, txStatus)
-       w.status.WorkHeight = block.Height
-       w.status.WorkHash = block.Hash()
-       if w.status.WorkHeight >= w.status.BestHeight {
-               w.status.BestHeight = w.status.WorkHeight
-               w.status.BestHash = w.status.WorkHash
+       if err := newStore.CommitBatch(); err != nil {
+               return err
        }
-       return w.commitWalletInfo(storeBatch)
+
+       return nil
 }
 
 // DetachBlock detach a block and rollback state
@@ -210,32 +226,40 @@ func (w *Wallet) DetachBlock(block *types.Block) error {
        defer w.rw.Unlock()
 
        blockHash := block.Hash()
-       txStatus, err := w.chain.GetTransactionStatus(&blockHash)
+       txStatus, err := w.Chain.GetTransactionStatus(&blockHash)
        if err != nil {
                return err
        }
 
-       storeBatch := w.DB.NewBatch()
-       w.detachUtxos(storeBatch, block, txStatus)
-       w.deleteTransactions(storeBatch, w.status.BestHeight)
+       newStore := w.Store.InitBatch()
 
-       w.status.BestHeight = block.Height - 1
-       w.status.BestHash = block.PreviousBlockHash
+       w.detachUtxos(block, txStatus, newStore)
+       newStore.DeleteTransactions(w.Status.BestHeight)
 
-       if w.status.WorkHeight > w.status.BestHeight {
-               w.status.WorkHeight = w.status.BestHeight
-               w.status.WorkHash = w.status.BestHash
+       w.Status.BestHeight = block.Height - 1
+       w.Status.BestHash = block.PreviousBlockHash
+
+       if w.Status.WorkHeight > w.Status.BestHeight {
+               w.Status.WorkHeight = w.Status.BestHeight
+               w.Status.WorkHash = w.Status.BestHash
+       }
+       if err := w.commitWalletInfo(newStore); err != nil {
+               return err
        }
 
-       return w.commitWalletInfo(storeBatch)
+       if err := newStore.CommitBatch(); err != nil {
+               return err
+       }
+
+       return nil
 }
 
 //WalletUpdate process every valid block and reverse every invalid block which need to rollback
 func (w *Wallet) walletUpdater() {
        for {
                w.getRescanNotification()
-               for !w.chain.InMainChain(w.status.BestHash) {
-                       block, err := w.chain.GetBlockByHash(&w.status.BestHash)
+               for !w.Chain.InMainChain(w.Status.BestHash) {
+                       block, err := w.Chain.GetBlockByHash(&w.Status.BestHash)
                        if err != nil {
                                log.WithFields(log.Fields{"module": logModule, "err": err}).Error("walletUpdater GetBlockByHash")
                                return
@@ -247,7 +271,7 @@ func (w *Wallet) walletUpdater() {
                        }
                }
 
-               block, _ := w.chain.GetBlockByHeight(w.status.WorkHeight + 1)
+               block, _ := w.Chain.GetBlockByHeight(w.Status.WorkHeight + 1)
                if block == nil {
                        w.walletBlockWaiter()
                        continue
@@ -269,43 +293,6 @@ func (w *Wallet) RescanBlocks() {
        }
 }
 
-// deleteAccountTxs deletes all txs in wallet
-func (w *Wallet) deleteAccountTxs() {
-       storeBatch := w.DB.NewBatch()
-
-       txIter := w.DB.IteratorPrefix([]byte(TxPrefix))
-       defer txIter.Release()
-
-       for txIter.Next() {
-               storeBatch.Delete(txIter.Key())
-       }
-
-       txIndexIter := w.DB.IteratorPrefix([]byte(TxIndexPrefix))
-       defer txIndexIter.Release()
-
-       for txIndexIter.Next() {
-               storeBatch.Delete(txIndexIter.Key())
-       }
-
-       storeBatch.Write()
-}
-
-func (w *Wallet) deleteUtxos() {
-       storeBatch := w.DB.NewBatch()
-       ruIter := w.DB.IteratorPrefix([]byte(account.UTXOPreFix))
-       defer ruIter.Release()
-       for ruIter.Next() {
-               storeBatch.Delete(ruIter.Key())
-       }
-
-       suIter := w.DB.IteratorPrefix([]byte(account.SUTXOPrefix))
-       defer suIter.Release()
-       for suIter.Next() {
-               storeBatch.Delete(suIter.Key())
-       }
-       storeBatch.Write()
-}
-
 // DeleteAccount deletes account matching accountID, then rescan wallet
 func (w *Wallet) DeleteAccount(accountID string) (err error) {
        w.rw.Lock()
@@ -315,7 +302,7 @@ func (w *Wallet) DeleteAccount(accountID string) (err error) {
                return err
        }
 
-       w.deleteAccountTxs()
+       w.Store.DeleteWalletTransactions()
        w.RescanBlocks()
        return nil
 }
@@ -328,7 +315,7 @@ func (w *Wallet) UpdateAccountAlias(accountID string, newAlias string) (err erro
                return err
        }
 
-       w.deleteAccountTxs()
+       w.Store.DeleteWalletTransactions()
        w.RescanBlocks()
        return nil
 }
@@ -343,14 +330,14 @@ func (w *Wallet) getRescanNotification() {
 }
 
 func (w *Wallet) setRescanStatus() {
-       block, _ := w.chain.GetBlockByHeight(0)
-       w.status.WorkHash = bc.Hash{}
+       block, _ := w.Chain.GetBlockByHeight(0)
+       w.Status.WorkHash = bc.Hash{}
        w.AttachBlock(block)
 }
 
 func (w *Wallet) walletBlockWaiter() {
        select {
-       case <-w.chain.BlockWaiter(w.status.WorkHeight + 1):
+       case <-w.Chain.BlockWaiter(w.Status.WorkHeight + 1):
        case <-w.rescanCh:
                w.setRescanStatus()
        }
@@ -361,5 +348,5 @@ func (w *Wallet) GetWalletStatusInfo() StatusInfo {
        w.rw.RLock()
        defer w.rw.RUnlock()
 
-       return w.status
+       return w.Status
 }