OSDN Git Service

Merge pull request #201 from Bytom/v0.1
[bytom/vapor.git] / wallet / wallet.go
index 3f76418..92a24b9 100644 (file)
@@ -5,12 +5,13 @@ import (
        "sync"
 
        log "github.com/sirupsen/logrus"
-       "github.com/tendermint/tmlibs/db"
 
        "github.com/vapor/account"
        "github.com/vapor/asset"
        "github.com/vapor/blockchain/pseudohsm"
-       "github.com/vapor/common"
+       dbm "github.com/vapor/database/leveldb"
+       "github.com/vapor/errors"
+       "github.com/vapor/event"
        "github.com/vapor/protocol"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
@@ -18,13 +19,21 @@ import (
 
 const (
        //SINGLE single sign
-       SINGLE = 1
+       SINGLE    = 1
+       logModule = "wallet"
 )
 
-var walletKey = []byte("walletInfo")
+var (
+       currentVersion = uint(1)
+       walletKey      = []byte("walletInfo")
+
+       errBestBlockNotFoundInCore = errors.New("best block not found in core")
+       errWalletVersionMismatch   = errors.New("wallet version mismatch")
+)
 
 //StatusInfo is base valid block info to handle orphan block rollback
 type StatusInfo struct {
+       Version    uint
        WorkHeight uint64
        WorkHash   bc.Hash
        BestHeight uint64
@@ -33,29 +42,33 @@ type StatusInfo struct {
 
 //Wallet is related to storing account unspent outputs
 type Wallet struct {
-       DB          db.DB
-       rw          sync.RWMutex
-       status      StatusInfo
-       AccountMgr  *account.Manager
-       AssetReg    *asset.Registry
-       Hsm         *pseudohsm.HSM
-       chain       *protocol.Chain
-       RecoveryMgr *recoveryManager
-       rescanCh    chan struct{}
-       dposAddress common.Address
+       DB              dbm.DB
+       rw              sync.RWMutex
+       status          StatusInfo
+       TxIndexFlag     bool
+       AccountMgr      *account.Manager
+       AssetReg        *asset.Registry
+       Hsm             *pseudohsm.HSM
+       chain           *protocol.Chain
+       RecoveryMgr     *recoveryManager
+       eventDispatcher *event.Dispatcher
+       txMsgSub        *event.Subscription
+
+       rescanCh chan struct{}
 }
 
 //NewWallet return a new wallet instance
-func NewWallet(walletDB db.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain, dposAddress common.Address) (*Wallet, error) {
+func NewWallet(walletDB dbm.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain, dispatcher *event.Dispatcher, txIndexFlag bool) (*Wallet, error) {
        w := &Wallet{
-               DB:          walletDB,
-               AccountMgr:  account,
-               AssetReg:    asset,
-               chain:       chain,
-               Hsm:         hsm,
-               RecoveryMgr: newRecoveryManager(walletDB, account),
-               rescanCh:    make(chan struct{}, 1),
-               dposAddress: dposAddress,
+               DB:              walletDB,
+               AccountMgr:      account,
+               AssetReg:        asset,
+               chain:           chain,
+               Hsm:             hsm,
+               RecoveryMgr:     newRecoveryManager(walletDB, account),
+               eventDispatcher: dispatcher,
+               rescanCh:        make(chan struct{}, 1),
+               TxIndexFlag:     txIndexFlag,
        }
 
        if err := w.loadWalletInfo(); err != nil {
@@ -66,18 +79,76 @@ func NewWallet(walletDB db.DB, account *account.Manager, asset *asset.Registry,
                return nil, err
        }
 
+       var err error
+       w.txMsgSub, err = w.eventDispatcher.Subscribe(protocol.TxMsgEvent{})
+       if err != nil {
+               return nil, err
+       }
+
        go w.walletUpdater()
        go w.delUnconfirmedTx()
+       go w.memPoolTxQueryLoop()
        return w, nil
 }
 
-//GetWalletInfo return stored wallet info and nil,if error,
-//return initial wallet info and err
+// memPoolTxQueryLoop constantly pass a transaction accepted by mempool to the wallet.
+func (w *Wallet) memPoolTxQueryLoop() {
+       for {
+               select {
+               case obj, ok := <-w.txMsgSub.Chan():
+                       if !ok {
+                               log.WithFields(log.Fields{"module": logModule}).Warning("tx pool tx msg subscription channel closed")
+                               return
+                       }
+
+                       ev, ok := obj.Data.(protocol.TxMsgEvent)
+                       if !ok {
+                               log.WithFields(log.Fields{"module": logModule}).Error("event type error")
+                               continue
+                       }
+
+                       switch ev.TxMsg.MsgType {
+                       case protocol.MsgNewTx:
+                               w.AddUnconfirmedTx(ev.TxMsg.TxDesc)
+                       case protocol.MsgRemoveTx:
+                               w.RemoveUnconfirmedTx(ev.TxMsg.TxDesc)
+                       default:
+                               log.WithFields(log.Fields{"module": logModule}).Warn("got unknow message type from the txPool channel")
+                       }
+               }
+       }
+}
+
+func (w *Wallet) checkWalletInfo() error {
+       if w.status.Version != currentVersion {
+               return errWalletVersionMismatch
+       } else if !w.chain.BlockExist(&w.status.BestHash) {
+               return errBestBlockNotFoundInCore
+       }
+
+       return nil
+}
+
+//loadWalletInfo return stored wallet info and nil,
+//if error, return initial wallet info and err
 func (w *Wallet) loadWalletInfo() error {
        if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
-               return json.Unmarshal(rawWallet, &w.status)
+               if err := json.Unmarshal(rawWallet, &w.status); err != nil {
+                       return err
+               }
+
+               err := w.checkWalletInfo()
+               if err == nil {
+                       return nil
+               }
+
+               log.WithFields(log.Fields{"module": logModule}).Warn(err.Error())
+               w.deleteAccountTxs()
+               w.deleteUtxos()
        }
 
+       w.status.Version = currentVersion
+       w.status.WorkHash = bc.Hash{}
        block, err := w.chain.GetBlockByHeight(0)
        if err != nil {
                return err
@@ -85,10 +156,10 @@ func (w *Wallet) loadWalletInfo() error {
        return w.AttachBlock(block)
 }
 
-func (w *Wallet) commitWalletInfo(batch db.Batch) error {
+func (w *Wallet) commitWalletInfo(batch dbm.Batch) error {
        rawWallet, err := json.Marshal(w.status)
        if err != nil {
-               log.WithField("err", err).Error("save wallet info")
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("save wallet info")
                return err
        }
 
@@ -114,7 +185,8 @@ func (w *Wallet) AttachBlock(block *types.Block) error {
        }
 
        if err := w.RecoveryMgr.FilterRecoveryTxs(block); err != nil {
-               return err
+               log.WithField("err", err).Error("filter recovery txs")
+               w.RecoveryMgr.finished()
        }
 
        storeBatch := w.DB.NewBatch()
@@ -165,12 +237,12 @@ func (w *Wallet) walletUpdater() {
                for !w.chain.InMainChain(w.status.BestHash) {
                        block, err := w.chain.GetBlockByHash(&w.status.BestHash)
                        if err != nil {
-                               log.WithField("err", err).Error("walletUpdater GetBlockByHash")
+                               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("walletUpdater GetBlockByHash")
                                return
                        }
 
                        if err := w.DetachBlock(block); err != nil {
-                               log.WithField("err", err).Error("walletUpdater detachBlock stop")
+                               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("walletUpdater detachBlock stop")
                                return
                        }
                }
@@ -182,7 +254,7 @@ func (w *Wallet) walletUpdater() {
                }
 
                if err := w.AttachBlock(block); err != nil {
-                       log.WithField("err", err).Error("walletUpdater AttachBlock stop")
+                       log.WithFields(log.Fields{"module": logModule, "err": err}).Error("walletUpdater AttachBlock stop")
                        return
                }
        }
@@ -218,6 +290,22 @@ func (w *Wallet) deleteAccountTxs() {
        storeBatch.Write()
 }
 
+func (w *Wallet) deleteUtxos() {
+       storeBatch := w.DB.NewBatch()
+       ruIter := w.DB.IteratorPrefix([]byte(account.UTXOPreFix))
+       defer ruIter.Release()
+       for ruIter.Next() {
+               storeBatch.Delete(ruIter.Key())
+       }
+
+       suIter := w.DB.IteratorPrefix([]byte(account.SUTXOPrefix))
+       defer suIter.Release()
+       for suIter.Next() {
+               storeBatch.Delete(suIter.Key())
+       }
+       storeBatch.Write()
+}
+
 // DeleteAccount deletes account matching accountID, then rescan wallet
 func (w *Wallet) DeleteAccount(accountID string) (err error) {
        w.rw.Lock()