OSDN Git Service

feat(consensus): update mainnet and testnet checkpoint (#1824)
[bytom/bytom.git] / wallet / wallet.go
index 32e2ae6..ed52121 100644 (file)
@@ -2,94 +2,153 @@ package wallet
 
 import (
        "encoding/json"
-       "fmt"
+       "sync"
 
        log "github.com/sirupsen/logrus"
-       "github.com/tendermint/go-wire/data/base58"
-       "github.com/tendermint/tmlibs/db"
-
-       "github.com/bytom/account"
-       "github.com/bytom/asset"
-       "github.com/bytom/blockchain/pseudohsm"
-       "github.com/bytom/crypto/ed25519/chainkd"
-       "github.com/bytom/crypto/sha3pool"
-       "github.com/bytom/protocol"
-       "github.com/bytom/protocol/bc"
-       "github.com/bytom/protocol/bc/types"
+
+       "github.com/bytom/bytom/account"
+       "github.com/bytom/bytom/asset"
+       "github.com/bytom/bytom/blockchain/pseudohsm"
+       dbm "github.com/bytom/bytom/database/leveldb"
+       "github.com/bytom/bytom/errors"
+       "github.com/bytom/bytom/event"
+       "github.com/bytom/bytom/protocol"
+       "github.com/bytom/bytom/protocol/bc"
+       "github.com/bytom/bytom/protocol/bc/types"
 )
 
-//SINGLE single sign
-const SINGLE = 1
+const (
+       //SINGLE single sign
+       SINGLE    = 1
+       logModule = "wallet"
+)
 
-//RecoveryIndex walletdb recovery cp number
-const RecoveryIndex = 5000
+var (
+       currentVersion = uint(1)
+       walletKey      = []byte("walletInfo")
 
-var walletKey = []byte("walletInfo")
-var privKeyKey = []byte("keysInfo")
+       errBestBlockNotFoundInCore = errors.New("best block not found in core")
+       errWalletVersionMismatch   = errors.New("wallet version mismatch")
+)
 
 //StatusInfo is base valid block info to handle orphan block rollback
 type StatusInfo struct {
+       Version    uint
        WorkHeight uint64
        WorkHash   bc.Hash
        BestHeight uint64
        BestHash   bc.Hash
 }
 
-//KeyInfo is key import status
-type KeyInfo struct {
-       Alias    string       `json:"alias"`
-       XPub     chainkd.XPub `json:"xpub"`
-       Percent  uint8        `json:"percent"`
-       Complete bool         `json:"complete"`
-}
-
 //Wallet is related to storing account unspent outputs
 type Wallet struct {
-       DB             db.DB
-       status         StatusInfo
-       AccountMgr     *account.Manager
-       AssetReg       *asset.Registry
-       Hsm            *pseudohsm.HSM
-       chain          *protocol.Chain
-       rescanProgress chan struct{}
-       ImportPrivKey  bool
-       keysInfo       []KeyInfo
+       DB              dbm.DB
+       rw              sync.RWMutex
+       status          StatusInfo
+       TxIndexFlag     bool
+       AccountMgr      *account.Manager
+       AssetReg        *asset.Registry
+       Hsm             *pseudohsm.HSM
+       chain           *protocol.Chain
+       RecoveryMgr     *recoveryManager
+       eventDispatcher *event.Dispatcher
+       txMsgSub        *event.Subscription
+
+       rescanCh chan struct{}
 }
 
 //NewWallet return a new wallet instance
-func NewWallet(walletDB db.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain) (*Wallet, error) {
+func NewWallet(walletDB dbm.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain, dispatcher *event.Dispatcher, txIndexFlag bool) (*Wallet, error) {
        w := &Wallet{
-               DB:             walletDB,
-               AccountMgr:     account,
-               AssetReg:       asset,
-               chain:          chain,
-               Hsm:            hsm,
-               rescanProgress: make(chan struct{}, 1),
-               keysInfo:       make([]KeyInfo, 0),
+               DB:              walletDB,
+               AccountMgr:      account,
+               AssetReg:        asset,
+               chain:           chain,
+               Hsm:             hsm,
+               RecoveryMgr:     newRecoveryManager(walletDB, account),
+               eventDispatcher: dispatcher,
+               rescanCh:        make(chan struct{}, 1),
+               TxIndexFlag:     txIndexFlag,
        }
 
        if err := w.loadWalletInfo(); err != nil {
                return nil, err
        }
 
-       if err := w.loadKeysInfo(); err != nil {
+       if err := w.RecoveryMgr.LoadStatusInfo(); err != nil {
                return nil, err
        }
 
-       w.ImportPrivKey = w.getImportKeyFlag()
+       var err error
+       w.txMsgSub, err = w.eventDispatcher.Subscribe(protocol.TxMsgEvent{})
+       if err != nil {
+               return nil, err
+       }
 
        go w.walletUpdater()
-
+       go w.delUnconfirmedTx()
+       go w.memPoolTxQueryLoop()
        return w, nil
 }
 
-//GetWalletInfo return stored wallet info and nil,if error,
-//return initial wallet info and err
+// memPoolTxQueryLoop constantly pass a transaction accepted by mempool to the wallet.
+func (w *Wallet) memPoolTxQueryLoop() {
+       for {
+               select {
+               case obj, ok := <-w.txMsgSub.Chan():
+                       if !ok {
+                               log.WithFields(log.Fields{"module": logModule}).Warning("tx pool tx msg subscription channel closed")
+                               return
+                       }
+
+                       ev, ok := obj.Data.(protocol.TxMsgEvent)
+                       if !ok {
+                               log.WithFields(log.Fields{"module": logModule}).Error("event type error")
+                               continue
+                       }
+
+                       switch ev.TxMsg.MsgType {
+                       case protocol.MsgNewTx:
+                               w.AddUnconfirmedTx(ev.TxMsg.TxDesc)
+                       case protocol.MsgRemoveTx:
+                               w.RemoveUnconfirmedTx(ev.TxMsg.TxDesc)
+                       default:
+                               log.WithFields(log.Fields{"module": logModule}).Warn("got unknow message type from the txPool channel")
+                       }
+               }
+       }
+}
+
+func (w *Wallet) checkWalletInfo() error {
+       if w.status.Version != currentVersion {
+               return errWalletVersionMismatch
+       } else if !w.chain.BlockExist(&w.status.BestHash) {
+               return errBestBlockNotFoundInCore
+       }
+
+       return nil
+}
+
+//loadWalletInfo return stored wallet info and nil,
+//if error, return initial wallet info and err
 func (w *Wallet) loadWalletInfo() error {
        if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
-               return json.Unmarshal(rawWallet, &w.status)
+               if err := json.Unmarshal(rawWallet, &w.status); err != nil {
+                       return err
+               }
+
+               err := w.checkWalletInfo()
+               if err == nil {
+                       return nil
+               }
+
+               log.WithFields(log.Fields{"module": logModule}).Warn(err.Error())
+               w.deleteAccountTxs()
+               w.deleteUtxos()
        }
 
+       w.status.Version = currentVersion
+       w.status.WorkHash = bc.Hash{}
        block, err := w.chain.GetBlockByHeight(0)
        if err != nil {
                return err
@@ -97,10 +156,10 @@ func (w *Wallet) loadWalletInfo() error {
        return w.AttachBlock(block)
 }
 
-func (w *Wallet) commitWalletInfo(batch db.Batch) error {
+func (w *Wallet) commitWalletInfo(batch dbm.Batch) error {
        rawWallet, err := json.Marshal(w.status)
        if err != nil {
-               log.WithField("err", err).Error("save wallet info")
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("save wallet info")
                return err
        }
 
@@ -109,37 +168,11 @@ func (w *Wallet) commitWalletInfo(batch db.Batch) error {
        return nil
 }
 
-//GetWalletInfo return stored wallet info and nil,if error,
-//return initial wallet info and err
-func (w *Wallet) loadKeysInfo() error {
-       if rawKeyInfo := w.DB.Get(privKeyKey); rawKeyInfo != nil {
-               json.Unmarshal(rawKeyInfo, &w.keysInfo)
-               return nil
-       }
-       return nil
-}
-
-func (w *Wallet) commitkeysInfo() error {
-       rawKeysInfo, err := json.Marshal(w.keysInfo)
-       if err != nil {
-               log.WithField("err", err).Error("save wallet info")
-               return err
-       }
-       w.DB.Set(privKeyKey, rawKeysInfo)
-       return nil
-}
-
-func (w *Wallet) getImportKeyFlag() bool {
-       for _, v := range w.keysInfo {
-               if v.Complete == false {
-                       return true
-               }
-       }
-       return false
-}
-
 // AttachBlock attach a new block
 func (w *Wallet) AttachBlock(block *types.Block) error {
+       w.rw.Lock()
+       defer w.rw.Unlock()
+
        if block.PreviousBlockHash != w.status.WorkHash {
                log.Warn("wallet skip attachBlock due to status hash not equal to previous hash")
                return nil
@@ -151,10 +184,17 @@ func (w *Wallet) AttachBlock(block *types.Block) error {
                return err
        }
 
+       if err := w.RecoveryMgr.FilterRecoveryTxs(block); err != nil {
+               log.WithField("err", err).Error("filter recovery txs")
+               w.RecoveryMgr.finished()
+       }
+
        storeBatch := w.DB.NewBatch()
-       w.indexTransactions(storeBatch, block, txStatus)
-       w.buildAccountUTXOs(storeBatch, block, txStatus)
+       if err := w.indexTransactions(storeBatch, block, txStatus); err != nil {
+               return err
+       }
 
+       w.attachUtxos(storeBatch, block, txStatus)
        w.status.WorkHeight = block.Height
        w.status.WorkHash = block.Hash()
        if w.status.WorkHeight >= w.status.BestHeight {
@@ -166,6 +206,9 @@ func (w *Wallet) AttachBlock(block *types.Block) error {
 
 // DetachBlock detach a block and rollback state
 func (w *Wallet) DetachBlock(block *types.Block) error {
+       w.rw.Lock()
+       defer w.rw.Unlock()
+
        blockHash := block.Hash()
        txStatus, err := w.chain.GetTransactionStatus(&blockHash)
        if err != nil {
@@ -173,7 +216,7 @@ func (w *Wallet) DetachBlock(block *types.Block) error {
        }
 
        storeBatch := w.DB.NewBatch()
-       w.reverseAccountUTXOs(storeBatch, block, txStatus)
+       w.detachUtxos(storeBatch, block, txStatus)
        w.deleteTransactions(storeBatch, w.status.BestHeight)
 
        w.status.BestHeight = block.Height - 1
@@ -190,178 +233,133 @@ func (w *Wallet) DetachBlock(block *types.Block) error {
 //WalletUpdate process every valid block and reverse every invalid block which need to rollback
 func (w *Wallet) walletUpdater() {
        for {
-               getRescanNotification(w)
-               checkRescanStatus(w)
+               w.getRescanNotification()
                for !w.chain.InMainChain(w.status.BestHash) {
                        block, err := w.chain.GetBlockByHash(&w.status.BestHash)
                        if err != nil {
-                               log.WithField("err", err).Error("walletUpdater GetBlockByHash")
+                               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("walletUpdater GetBlockByHash")
                                return
                        }
 
                        if err := w.DetachBlock(block); err != nil {
-                               log.WithField("err", err).Error("walletUpdater detachBlock")
+                               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("walletUpdater detachBlock stop")
                                return
                        }
                }
 
                block, _ := w.chain.GetBlockByHeight(w.status.WorkHeight + 1)
                if block == nil {
-                       <-w.chain.BlockWaiter(w.status.WorkHeight + 1)
+                       w.walletBlockWaiter()
                        continue
                }
 
                if err := w.AttachBlock(block); err != nil {
-                       log.WithField("err", err).Error("walletUpdater stop")
+                       log.WithFields(log.Fields{"module": logModule, "err": err}).Error("walletUpdater AttachBlock stop")
                        return
                }
        }
 }
 
-func getRescanNotification(w *Wallet) {
+//RescanBlocks provide a trigger to rescan blocks
+func (w *Wallet) RescanBlocks() {
        select {
-       case <-w.rescanProgress:
-               w.status.WorkHeight = 0
-               block, _ := w.chain.GetBlockByHeight(w.status.WorkHeight)
-               w.status.WorkHash = block.Hash()
+       case w.rescanCh <- struct{}{}:
        default:
                return
        }
 }
 
-// ExportAccountPrivKey exports the account private key as a WIF for encoding as a string
-// in the Wallet Import Formt.
-func (w *Wallet) ExportAccountPrivKey(xpub chainkd.XPub, auth string) (*string, error) {
-       xprv, err := w.Hsm.LoadChainKDKey(xpub, auth)
-       if err != nil {
-               return nil, err
-       }
-       var hashed [32]byte
-       sha3pool.Sum256(hashed[:], xprv[:])
+// deleteAccountTxs deletes all txs in wallet
+func (w *Wallet) deleteAccountTxs() {
+       storeBatch := w.DB.NewBatch()
 
-       tmp := append(xprv[:], hashed[:4]...)
-       res := base58.Encode(tmp)
-       return &res, nil
-}
+       txIter := w.DB.IteratorPrefix([]byte(TxPrefix))
+       defer txIter.Release()
 
-// ImportAccountPrivKey imports the account key in the Wallet Import Formt.
-func (w *Wallet) ImportAccountPrivKey(xprv chainkd.XPrv, keyAlias, auth string, index uint64, accountAlias string) (*pseudohsm.XPub, error) {
-       if w.Hsm.HasAlias(keyAlias) {
-               return nil, pseudohsm.ErrDuplicateKeyAlias
-       }
-       if w.Hsm.HasKey(xprv) {
-               return nil, pseudohsm.ErrDuplicateKey
+       for txIter.Next() {
+               storeBatch.Delete(txIter.Key())
        }
 
-       if acc, _ := w.AccountMgr.FindByAlias(nil, accountAlias); acc != nil {
-               return nil, account.ErrDuplicateAlias
-       }
+       txIndexIter := w.DB.IteratorPrefix([]byte(TxIndexPrefix))
+       defer txIndexIter.Release()
 
-       xpub, _, err := w.Hsm.ImportXPrvKey(auth, keyAlias, xprv)
-       if err != nil {
-               return nil, err
+       for txIndexIter.Next() {
+               storeBatch.Delete(txIndexIter.Key())
        }
 
-       newAccount, err := w.AccountMgr.Create(nil, []chainkd.XPub{xpub.XPub}, SINGLE, accountAlias)
-       if err != nil {
-               return nil, err
-       }
-       if err := w.recoveryAccountWalletDB(newAccount, xpub, index, keyAlias); err != nil {
-               return nil, err
-       }
-       return xpub, nil
+       storeBatch.Write()
 }
 
-// ImportAccountXpubKey imports the account key in the Wallet Import Formt.
-func (w *Wallet) ImportAccountXpubKey(xpubIndex int, xpub pseudohsm.XPub, cpIndex uint64) error {
-       accountAlias := fmt.Sprintf("recovery_%d", xpubIndex)
-
-       if acc, _ := w.AccountMgr.FindByAlias(nil, accountAlias); acc != nil {
-               return account.ErrDuplicateAlias
+func (w *Wallet) deleteUtxos() {
+       storeBatch := w.DB.NewBatch()
+       ruIter := w.DB.IteratorPrefix([]byte(account.UTXOPreFix))
+       defer ruIter.Release()
+       for ruIter.Next() {
+               storeBatch.Delete(ruIter.Key())
        }
 
-       newAccount, err := w.AccountMgr.Create(nil, []chainkd.XPub{xpub.XPub}, SINGLE, accountAlias)
-       if err != nil {
-               return err
+       suIter := w.DB.IteratorPrefix([]byte(account.SUTXOPrefix))
+       defer suIter.Release()
+       for suIter.Next() {
+               storeBatch.Delete(suIter.Key())
        }
-
-       return w.recoveryAccountWalletDB(newAccount, &xpub, cpIndex, xpub.Alias)
+       storeBatch.Write()
 }
 
-func (w *Wallet) recoveryAccountWalletDB(account *account.Account, XPub *pseudohsm.XPub, index uint64, keyAlias string) error {
-       if err := w.createProgram(account, XPub, index); err != nil {
+// DeleteAccount deletes account matching accountID, then rescan wallet
+func (w *Wallet) DeleteAccount(accountID string) (err error) {
+       w.rw.Lock()
+       defer w.rw.Unlock()
+
+       if err := w.AccountMgr.DeleteAccount(accountID); err != nil {
                return err
        }
-       w.ImportPrivKey = true
-       tmp := KeyInfo{
-               Alias:    keyAlias,
-               XPub:     XPub.XPub,
-               Complete: false,
-       }
-       w.keysInfo = append(w.keysInfo, tmp)
-       w.commitkeysInfo()
-       w.rescanBlocks()
 
+       w.deleteAccountTxs()
+       w.RescanBlocks()
        return nil
 }
 
-func (w *Wallet) createProgram(account *account.Account, XPub *pseudohsm.XPub, index uint64) error {
-       for i := uint64(0); i < index; i++ {
-               if _, err := w.AccountMgr.CreateAddress(nil, account.ID, false); err != nil {
-                       return err
-               }
+func (w *Wallet) UpdateAccountAlias(accountID string, newAlias string) (err error) {
+       w.rw.Lock()
+       defer w.rw.Unlock()
+
+       if err := w.AccountMgr.UpdateAccountAlias(accountID, newAlias); err != nil {
+               return err
        }
+
+       w.deleteAccountTxs()
+       w.RescanBlocks()
        return nil
 }
 
-func (w *Wallet) rescanBlocks() {
+func (w *Wallet) getRescanNotification() {
        select {
-       case w.rescanProgress <- struct{}{}:
+       case <-w.rescanCh:
+               w.setRescanStatus()
        default:
                return
        }
 }
 
-//GetRescanStatus return key import rescan status
-func (w *Wallet) GetRescanStatus() ([]KeyInfo, error) {
-       keysInfo := make([]KeyInfo, len(w.keysInfo))
-
-       if rawKeyInfo := w.DB.Get(privKeyKey); rawKeyInfo != nil {
-               if err := json.Unmarshal(rawKeyInfo, &keysInfo); err != nil {
-                       return nil, err
-               }
-       }
-
-       var status StatusInfo
-       if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
-               if err := json.Unmarshal(rawWallet, &status); err != nil {
-                       return nil, err
-               }
-       }
-
-       for i := range keysInfo {
-               if keysInfo[i].Complete == true || status.BestHeight == 0 {
-                       keysInfo[i].Percent = 100
-                       continue
-               }
+func (w *Wallet) setRescanStatus() {
+       block, _ := w.chain.GetBlockByHeight(0)
+       w.status.WorkHash = bc.Hash{}
+       w.AttachBlock(block)
+}
 
-               keysInfo[i].Percent = uint8(status.WorkHeight * 100 / status.BestHeight)
-               if keysInfo[i].Percent == 100 {
-                       keysInfo[i].Complete = true
-               }
+func (w *Wallet) walletBlockWaiter() {
+       select {
+       case <-w.chain.BlockWaiter(w.status.WorkHeight + 1):
+       case <-w.rescanCh:
+               w.setRescanStatus()
        }
-       return keysInfo, nil
 }
 
-//checkRescanStatus mark private key import process `Complete` if rescan finished
-func checkRescanStatus(w *Wallet) {
-       if !w.ImportPrivKey || w.status.WorkHeight < w.status.BestHeight {
-               return
-       }
+// GetWalletStatusInfo return current wallet StatusInfo
+func (w *Wallet) GetWalletStatusInfo() StatusInfo {
+       w.rw.RLock()
+       defer w.rw.RUnlock()
 
-       w.ImportPrivKey = false
-       for _, keyInfo := range w.keysInfo {
-               keyInfo.Complete = true
-       }
-       w.commitkeysInfo()
+       return w.status
 }