import (
"encoding/json"
- "fmt"
+ "sync"
log "github.com/sirupsen/logrus"
- "github.com/tendermint/go-wire/data/base58"
- "github.com/tendermint/tmlibs/db"
-
- "github.com/bytom/account"
- "github.com/bytom/asset"
- "github.com/bytom/blockchain/pseudohsm"
- "github.com/bytom/crypto/ed25519/chainkd"
- "github.com/bytom/crypto/sha3pool"
- "github.com/bytom/protocol"
- "github.com/bytom/protocol/bc"
- "github.com/bytom/protocol/bc/types"
+
+ "github.com/bytom/bytom/account"
+ "github.com/bytom/bytom/asset"
+ "github.com/bytom/bytom/blockchain/pseudohsm"
+ dbm "github.com/bytom/bytom/database/leveldb"
+ "github.com/bytom/bytom/errors"
+ "github.com/bytom/bytom/event"
+ "github.com/bytom/bytom/protocol"
+ "github.com/bytom/bytom/protocol/bc"
+ "github.com/bytom/bytom/protocol/bc/types"
)
-//SINGLE single sign
-const SINGLE = 1
+const (
+ //SINGLE single sign
+ SINGLE = 1
+ logModule = "wallet"
+)
-//RecoveryIndex walletdb recovery cp number
-const RecoveryIndex = 5000
+var (
+ currentVersion = uint(1)
+ walletKey = []byte("walletInfo")
-var walletKey = []byte("walletInfo")
-var privKeyKey = []byte("keysInfo")
+ errBestBlockNotFoundInCore = errors.New("best block not found in core")
+ errWalletVersionMismatch = errors.New("wallet version mismatch")
+)
//StatusInfo is base valid block info to handle orphan block rollback
type StatusInfo struct {
+ Version uint
WorkHeight uint64
WorkHash bc.Hash
BestHeight uint64
BestHash bc.Hash
}
-//KeyInfo is key import status
-type KeyInfo struct {
- Alias string `json:"alias"`
- XPub chainkd.XPub `json:"xpub"`
- Percent uint8 `json:"percent"`
- Complete bool `json:"complete"`
-}
-
//Wallet is related to storing account unspent outputs
type Wallet struct {
- DB db.DB
- status StatusInfo
- AccountMgr *account.Manager
- AssetReg *asset.Registry
- Hsm *pseudohsm.HSM
- chain *protocol.Chain
- rescanProgress chan struct{}
- ImportPrivKey bool
- keysInfo []KeyInfo
+ DB dbm.DB
+ rw sync.RWMutex
+ status StatusInfo
+ TxIndexFlag bool
+ AccountMgr *account.Manager
+ AssetReg *asset.Registry
+ Hsm *pseudohsm.HSM
+ chain *protocol.Chain
+ RecoveryMgr *recoveryManager
+ eventDispatcher *event.Dispatcher
+ txMsgSub *event.Subscription
+
+ rescanCh chan struct{}
}
//NewWallet return a new wallet instance
-func NewWallet(walletDB db.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain) (*Wallet, error) {
+func NewWallet(walletDB dbm.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain, dispatcher *event.Dispatcher, txIndexFlag bool) (*Wallet, error) {
w := &Wallet{
- DB: walletDB,
- AccountMgr: account,
- AssetReg: asset,
- chain: chain,
- Hsm: hsm,
- rescanProgress: make(chan struct{}, 1),
- keysInfo: make([]KeyInfo, 0),
+ DB: walletDB,
+ AccountMgr: account,
+ AssetReg: asset,
+ chain: chain,
+ Hsm: hsm,
+ RecoveryMgr: newRecoveryManager(walletDB, account),
+ eventDispatcher: dispatcher,
+ rescanCh: make(chan struct{}, 1),
+ TxIndexFlag: txIndexFlag,
}
if err := w.loadWalletInfo(); err != nil {
return nil, err
}
- if err := w.loadKeysInfo(); err != nil {
+ if err := w.RecoveryMgr.LoadStatusInfo(); err != nil {
return nil, err
}
- w.ImportPrivKey = w.getImportKeyFlag()
+ var err error
+ w.txMsgSub, err = w.eventDispatcher.Subscribe(protocol.TxMsgEvent{})
+ if err != nil {
+ return nil, err
+ }
go w.walletUpdater()
-
+ go w.delUnconfirmedTx()
+ go w.memPoolTxQueryLoop()
return w, nil
}
-//GetWalletInfo return stored wallet info and nil,if error,
-//return initial wallet info and err
+// memPoolTxQueryLoop constantly pass a transaction accepted by mempool to the wallet.
+func (w *Wallet) memPoolTxQueryLoop() {
+ for {
+ select {
+ case obj, ok := <-w.txMsgSub.Chan():
+ if !ok {
+ log.WithFields(log.Fields{"module": logModule}).Warning("tx pool tx msg subscription channel closed")
+ return
+ }
+
+ ev, ok := obj.Data.(protocol.TxMsgEvent)
+ if !ok {
+ log.WithFields(log.Fields{"module": logModule}).Error("event type error")
+ continue
+ }
+
+ switch ev.TxMsg.MsgType {
+ case protocol.MsgNewTx:
+ w.AddUnconfirmedTx(ev.TxMsg.TxDesc)
+ case protocol.MsgRemoveTx:
+ w.RemoveUnconfirmedTx(ev.TxMsg.TxDesc)
+ default:
+ log.WithFields(log.Fields{"module": logModule}).Warn("got unknow message type from the txPool channel")
+ }
+ }
+ }
+}
+
+func (w *Wallet) checkWalletInfo() error {
+ if w.status.Version != currentVersion {
+ return errWalletVersionMismatch
+ } else if !w.chain.BlockExist(&w.status.BestHash) {
+ return errBestBlockNotFoundInCore
+ }
+
+ return nil
+}
+
+//loadWalletInfo return stored wallet info and nil,
+//if error, return initial wallet info and err
func (w *Wallet) loadWalletInfo() error {
if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
- return json.Unmarshal(rawWallet, &w.status)
+ if err := json.Unmarshal(rawWallet, &w.status); err != nil {
+ return err
+ }
+
+ err := w.checkWalletInfo()
+ if err == nil {
+ return nil
+ }
+
+ log.WithFields(log.Fields{"module": logModule}).Warn(err.Error())
+ w.deleteAccountTxs()
+ w.deleteUtxos()
}
+ w.status.Version = currentVersion
+ w.status.WorkHash = bc.Hash{}
block, err := w.chain.GetBlockByHeight(0)
if err != nil {
return err
return w.AttachBlock(block)
}
-func (w *Wallet) commitWalletInfo(batch db.Batch) error {
+func (w *Wallet) commitWalletInfo(batch dbm.Batch) error {
rawWallet, err := json.Marshal(w.status)
if err != nil {
- log.WithField("err", err).Error("save wallet info")
+ log.WithFields(log.Fields{"module": logModule, "err": err}).Error("save wallet info")
return err
}
return nil
}
-//GetWalletInfo return stored wallet info and nil,if error,
-//return initial wallet info and err
-func (w *Wallet) loadKeysInfo() error {
- if rawKeyInfo := w.DB.Get(privKeyKey); rawKeyInfo != nil {
- json.Unmarshal(rawKeyInfo, &w.keysInfo)
- return nil
- }
- return nil
-}
-
-func (w *Wallet) commitkeysInfo() error {
- rawKeysInfo, err := json.Marshal(w.keysInfo)
- if err != nil {
- log.WithField("err", err).Error("save wallet info")
- return err
- }
- w.DB.Set(privKeyKey, rawKeysInfo)
- return nil
-}
-
-func (w *Wallet) getImportKeyFlag() bool {
- for _, v := range w.keysInfo {
- if v.Complete == false {
- return true
- }
- }
- return false
-}
-
// AttachBlock attach a new block
func (w *Wallet) AttachBlock(block *types.Block) error {
+ w.rw.Lock()
+ defer w.rw.Unlock()
+
if block.PreviousBlockHash != w.status.WorkHash {
log.Warn("wallet skip attachBlock due to status hash not equal to previous hash")
return nil
return err
}
+ if err := w.RecoveryMgr.FilterRecoveryTxs(block); err != nil {
+ log.WithField("err", err).Error("filter recovery txs")
+ w.RecoveryMgr.finished()
+ }
+
storeBatch := w.DB.NewBatch()
- w.indexTransactions(storeBatch, block, txStatus)
- w.buildAccountUTXOs(storeBatch, block, txStatus)
+ if err := w.indexTransactions(storeBatch, block, txStatus); err != nil {
+ return err
+ }
+ w.attachUtxos(storeBatch, block, txStatus)
w.status.WorkHeight = block.Height
w.status.WorkHash = block.Hash()
if w.status.WorkHeight >= w.status.BestHeight {
// DetachBlock detach a block and rollback state
func (w *Wallet) DetachBlock(block *types.Block) error {
+ w.rw.Lock()
+ defer w.rw.Unlock()
+
blockHash := block.Hash()
txStatus, err := w.chain.GetTransactionStatus(&blockHash)
if err != nil {
}
storeBatch := w.DB.NewBatch()
- w.reverseAccountUTXOs(storeBatch, block, txStatus)
+ w.detachUtxos(storeBatch, block, txStatus)
w.deleteTransactions(storeBatch, w.status.BestHeight)
w.status.BestHeight = block.Height - 1
//WalletUpdate process every valid block and reverse every invalid block which need to rollback
func (w *Wallet) walletUpdater() {
for {
- getRescanNotification(w)
- checkRescanStatus(w)
+ w.getRescanNotification()
for !w.chain.InMainChain(w.status.BestHash) {
block, err := w.chain.GetBlockByHash(&w.status.BestHash)
if err != nil {
- log.WithField("err", err).Error("walletUpdater GetBlockByHash")
+ log.WithFields(log.Fields{"module": logModule, "err": err}).Error("walletUpdater GetBlockByHash")
return
}
if err := w.DetachBlock(block); err != nil {
- log.WithField("err", err).Error("walletUpdater detachBlock")
+ log.WithFields(log.Fields{"module": logModule, "err": err}).Error("walletUpdater detachBlock stop")
return
}
}
block, _ := w.chain.GetBlockByHeight(w.status.WorkHeight + 1)
if block == nil {
- <-w.chain.BlockWaiter(w.status.WorkHeight + 1)
+ w.walletBlockWaiter()
continue
}
if err := w.AttachBlock(block); err != nil {
- log.WithField("err", err).Error("walletUpdater stop")
+ log.WithFields(log.Fields{"module": logModule, "err": err}).Error("walletUpdater AttachBlock stop")
return
}
}
}
-func getRescanNotification(w *Wallet) {
+//RescanBlocks provide a trigger to rescan blocks
+func (w *Wallet) RescanBlocks() {
select {
- case <-w.rescanProgress:
- w.status.WorkHeight = 0
- block, _ := w.chain.GetBlockByHeight(w.status.WorkHeight)
- w.status.WorkHash = block.Hash()
+ case w.rescanCh <- struct{}{}:
default:
return
}
}
-// ExportAccountPrivKey exports the account private key as a WIF for encoding as a string
-// in the Wallet Import Formt.
-func (w *Wallet) ExportAccountPrivKey(xpub chainkd.XPub, auth string) (*string, error) {
- xprv, err := w.Hsm.LoadChainKDKey(xpub, auth)
- if err != nil {
- return nil, err
- }
- var hashed [32]byte
- sha3pool.Sum256(hashed[:], xprv[:])
+// deleteAccountTxs deletes all txs in wallet
+func (w *Wallet) deleteAccountTxs() {
+ storeBatch := w.DB.NewBatch()
- tmp := append(xprv[:], hashed[:4]...)
- res := base58.Encode(tmp)
- return &res, nil
-}
+ txIter := w.DB.IteratorPrefix([]byte(TxPrefix))
+ defer txIter.Release()
-// ImportAccountPrivKey imports the account key in the Wallet Import Formt.
-func (w *Wallet) ImportAccountPrivKey(xprv chainkd.XPrv, keyAlias, auth string, index uint64, accountAlias string) (*pseudohsm.XPub, error) {
- if w.Hsm.HasAlias(keyAlias) {
- return nil, pseudohsm.ErrDuplicateKeyAlias
- }
- if w.Hsm.HasKey(xprv) {
- return nil, pseudohsm.ErrDuplicateKey
+ for txIter.Next() {
+ storeBatch.Delete(txIter.Key())
}
- if acc, _ := w.AccountMgr.FindByAlias(nil, accountAlias); acc != nil {
- return nil, account.ErrDuplicateAlias
- }
+ txIndexIter := w.DB.IteratorPrefix([]byte(TxIndexPrefix))
+ defer txIndexIter.Release()
- xpub, _, err := w.Hsm.ImportXPrvKey(auth, keyAlias, xprv)
- if err != nil {
- return nil, err
+ for txIndexIter.Next() {
+ storeBatch.Delete(txIndexIter.Key())
}
- newAccount, err := w.AccountMgr.Create(nil, []chainkd.XPub{xpub.XPub}, SINGLE, accountAlias)
- if err != nil {
- return nil, err
- }
- if err := w.recoveryAccountWalletDB(newAccount, xpub, index, keyAlias); err != nil {
- return nil, err
- }
- return xpub, nil
+ storeBatch.Write()
}
-// ImportAccountXpubKey imports the account key in the Wallet Import Formt.
-func (w *Wallet) ImportAccountXpubKey(xpubIndex int, xpub pseudohsm.XPub, cpIndex uint64) error {
- accountAlias := fmt.Sprintf("recovery_%d", xpubIndex)
-
- if acc, _ := w.AccountMgr.FindByAlias(nil, accountAlias); acc != nil {
- return account.ErrDuplicateAlias
+func (w *Wallet) deleteUtxos() {
+ storeBatch := w.DB.NewBatch()
+ ruIter := w.DB.IteratorPrefix([]byte(account.UTXOPreFix))
+ defer ruIter.Release()
+ for ruIter.Next() {
+ storeBatch.Delete(ruIter.Key())
}
- newAccount, err := w.AccountMgr.Create(nil, []chainkd.XPub{xpub.XPub}, SINGLE, accountAlias)
- if err != nil {
- return err
+ suIter := w.DB.IteratorPrefix([]byte(account.SUTXOPrefix))
+ defer suIter.Release()
+ for suIter.Next() {
+ storeBatch.Delete(suIter.Key())
}
-
- return w.recoveryAccountWalletDB(newAccount, &xpub, cpIndex, xpub.Alias)
+ storeBatch.Write()
}
-func (w *Wallet) recoveryAccountWalletDB(account *account.Account, XPub *pseudohsm.XPub, index uint64, keyAlias string) error {
- if err := w.createProgram(account, XPub, index); err != nil {
+// DeleteAccount deletes account matching accountID, then rescan wallet
+func (w *Wallet) DeleteAccount(accountID string) (err error) {
+ w.rw.Lock()
+ defer w.rw.Unlock()
+
+ if err := w.AccountMgr.DeleteAccount(accountID); err != nil {
return err
}
- w.ImportPrivKey = true
- tmp := KeyInfo{
- Alias: keyAlias,
- XPub: XPub.XPub,
- Complete: false,
- }
- w.keysInfo = append(w.keysInfo, tmp)
- w.commitkeysInfo()
- w.rescanBlocks()
+ w.deleteAccountTxs()
+ w.RescanBlocks()
return nil
}
-func (w *Wallet) createProgram(account *account.Account, XPub *pseudohsm.XPub, index uint64) error {
- for i := uint64(0); i < index; i++ {
- if _, err := w.AccountMgr.CreateAddress(nil, account.ID, false); err != nil {
- return err
- }
+func (w *Wallet) UpdateAccountAlias(accountID string, newAlias string) (err error) {
+ w.rw.Lock()
+ defer w.rw.Unlock()
+
+ if err := w.AccountMgr.UpdateAccountAlias(accountID, newAlias); err != nil {
+ return err
}
+
+ w.deleteAccountTxs()
+ w.RescanBlocks()
return nil
}
-func (w *Wallet) rescanBlocks() {
+func (w *Wallet) getRescanNotification() {
select {
- case w.rescanProgress <- struct{}{}:
+ case <-w.rescanCh:
+ w.setRescanStatus()
default:
return
}
}
-//GetRescanStatus return key import rescan status
-func (w *Wallet) GetRescanStatus() ([]KeyInfo, error) {
- keysInfo := make([]KeyInfo, len(w.keysInfo))
-
- if rawKeyInfo := w.DB.Get(privKeyKey); rawKeyInfo != nil {
- if err := json.Unmarshal(rawKeyInfo, &keysInfo); err != nil {
- return nil, err
- }
- }
-
- var status StatusInfo
- if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
- if err := json.Unmarshal(rawWallet, &status); err != nil {
- return nil, err
- }
- }
-
- for i := range keysInfo {
- if keysInfo[i].Complete == true || status.BestHeight == 0 {
- keysInfo[i].Percent = 100
- continue
- }
+func (w *Wallet) setRescanStatus() {
+ block, _ := w.chain.GetBlockByHeight(0)
+ w.status.WorkHash = bc.Hash{}
+ w.AttachBlock(block)
+}
- keysInfo[i].Percent = uint8(status.WorkHeight * 100 / status.BestHeight)
- if keysInfo[i].Percent == 100 {
- keysInfo[i].Complete = true
- }
+func (w *Wallet) walletBlockWaiter() {
+ select {
+ case <-w.chain.BlockWaiter(w.status.WorkHeight + 1):
+ case <-w.rescanCh:
+ w.setRescanStatus()
}
- return keysInfo, nil
}
-//checkRescanStatus mark private key import process `Complete` if rescan finished
-func checkRescanStatus(w *Wallet) {
- if !w.ImportPrivKey || w.status.WorkHeight < w.status.BestHeight {
- return
- }
+// GetWalletStatusInfo return current wallet StatusInfo
+func (w *Wallet) GetWalletStatusInfo() StatusInfo {
+ w.rw.RLock()
+ defer w.rw.RUnlock()
- w.ImportPrivKey = false
- for _, keyInfo := range w.keysInfo {
- keyInfo.Complete = true
- }
- w.commitkeysInfo()
+ return w.status
}