From c41548e5cc1bc365017e2fca0ac7807c867d3650 Mon Sep 17 00:00:00 2001 From: Yongfeng LI Date: Mon, 16 Apr 2018 21:37:00 +0800 Subject: [PATCH] delete unused addresses for imported key and account --- account/accounts.go | 40 ++++++++++++++++++++++++++++++++------- account/accounts_test.go | 8 ++++---- account/builder.go | 4 ++-- api/accounts.go | 26 +++++++++++++------------ crypto/ed25519/chainkd/chainkd.go | 12 +++++++++++- wallet/indexer.go | 12 +++++++++++- wallet/set.go | 28 +++++++++++++++++++++++++++ wallet/wallet.go | 32 +++++++++++++++++++++---------- wallet/wallet_test.go | 11 ++++++----- 9 files changed, 131 insertions(+), 42 deletions(-) create mode 100644 wallet/set.go diff --git a/account/accounts.go b/account/accounts.go index 84925d08..b85cfd07 100644 --- a/account/accounts.go +++ b/account/accounts.go @@ -104,7 +104,7 @@ type Manager struct { delayedACPsMu sync.Mutex delayedACPs map[*txbuilder.TemplateBuilder][]*CtrlProgram - accIndexMu sync.Mutex + accIndexMu sync.Mutex } // ExpireReservations removes reservations that have expired periodically. @@ -180,7 +180,7 @@ func (m *Manager) FindByAlias(ctx context.Context, alias string) (*Account, erro cachedID, ok := m.aliasCache.Get(alias) m.cacheMu.Unlock() if ok { - return m.findByID(ctx, cachedID.(string)) + return m.FindByID(ctx, cachedID.(string)) } rawID := m.db.Get(aliasKey(alias)) @@ -192,11 +192,11 @@ func (m *Manager) FindByAlias(ctx context.Context, alias string) (*Account, erro m.cacheMu.Lock() m.aliasCache.Add(alias, accountID) m.cacheMu.Unlock() - return m.findByID(ctx, accountID) + return m.FindByID(ctx, accountID) } -// findByID returns an account's Signer record by its ID. -func (m *Manager) findByID(ctx context.Context, id string) (*Account, error) { +// FindByID returns an account's Signer record by its ID. +func (m *Manager) FindByID(ctx context.Context, id string) (*Account, error) { m.cacheMu.Lock() cachedAccount, ok := m.cache.Get(id) m.cacheMu.Unlock() @@ -244,7 +244,7 @@ func (m *Manager) CreateCtrlProgramForChange(ctx context.Context, accountID stri // CreateAddress generate an address for the select account func (m *Manager) CreateAddress(ctx context.Context, accountID string, change bool) (cp *CtrlProgram, err error) { - account, err := m.findByID(ctx, accountID) + account, err := m.FindByID(ctx, accountID) if err != nil { return nil, err } @@ -268,6 +268,23 @@ func (m *Manager) createAddress(ctx context.Context, account *Account, change bo return cp, nil } +// listAddressesById +func (m *Manager) ListCtrlProgramsByXpubs(ctx context.Context, xpubs []chainkd.XPub) ([]*CtrlProgram, error) { + cps, err := m.ListControlProgram() + if err != nil { + return nil, err + } + + var result []*CtrlProgram + for _, cp := range cps { + if cp.Address == "" || chainkd.CompareTwoXPubs(cp.XPubs, xpubs) != 0 { + continue + } + result = append(result, cp) + } + return result, nil +} + func (m *Manager) createP2PKH(ctx context.Context, account *Account, change bool) (*CtrlProgram, error) { idx := m.getNextXpubsIndex(account.Signer.XPubs) path := signers.Path(account.Signer, signers.AccountKeySpace, idx) @@ -288,6 +305,7 @@ func (m *Manager) createP2PKH(ctx context.Context, account *Account, change bool return &CtrlProgram{ AccountID: account.ID, + XPubs: account.Signer.XPubs, Address: address.EncodeAddress(), KeyIndex: idx, ControlProgram: control, @@ -319,6 +337,7 @@ func (m *Manager) createP2SH(ctx context.Context, account *Account, change bool) return &CtrlProgram{ AccountID: account.ID, + XPubs: account.Signer.XPubs, Address: address.EncodeAddress(), KeyIndex: idx, ControlProgram: control, @@ -329,6 +348,7 @@ func (m *Manager) createP2SH(ctx context.Context, account *Account, change bool) //CtrlProgram is structure of account control program type CtrlProgram struct { AccountID string + XPubs []chainkd.XPub Address string KeyIndex uint64 ControlProgram []byte @@ -349,6 +369,12 @@ func (m *Manager) insertAccountControlProgram(ctx context.Context, progs ...*Ctr return nil } +func (m *Manager) DeleteAccountControlProgram(prog []byte) { + var hash common.Hash + sha3pool.Sum256(hash[:], prog) + m.db.Delete(CPKey(hash)) +} + // IsLocalControlProgram check is the input control program belong to local func (m *Manager) IsLocalControlProgram(prog []byte) bool { var hash common.Hash @@ -399,7 +425,7 @@ type Info struct { func (m *Manager) DeleteAccount(aliasOrId string) (err error) { account := &Account{} if account, err = m.FindByAlias(nil, aliasOrId); err != nil { - if account, err = m.findByID(nil, aliasOrId); err != nil { + if account, err = m.FindByID(nil, aliasOrId); err != nil { return err } } diff --git a/account/accounts_test.go b/account/accounts_test.go index b4891ec6..9be886ba 100644 --- a/account/accounts_test.go +++ b/account/accounts_test.go @@ -53,7 +53,7 @@ func TestCreateAccount(t *testing.T) { testutil.FatalErr(t, err) } - found, err := m.findByID(ctx, account.ID) + found, err := m.FindByID(ctx, account.ID) if err != nil { t.Errorf("unexpected error %v", err) } @@ -91,7 +91,7 @@ func TestDeleteAccount(t *testing.T) { testutil.FatalErr(t, err) } - found, err := m.findByID(ctx, account1.ID) + found, err := m.FindByID(ctx, account1.ID) if err != nil { t.Errorf("expected account %v should be deleted", found) } @@ -100,7 +100,7 @@ func TestDeleteAccount(t *testing.T) { testutil.FatalErr(t, err) } - found, err = m.findByID(ctx, account2.ID) + found, err = m.FindByID(ctx, account2.ID) if err != nil { t.Errorf("expected account %v should be deleted", found) } @@ -111,7 +111,7 @@ func TestFindByID(t *testing.T) { ctx := context.Background() account := m.createTestAccount(ctx, t, "", nil) - found, err := m.findByID(ctx, account.ID) + found, err := m.FindByID(ctx, account.ID) if err != nil { testutil.FatalErr(t, err) } diff --git a/account/builder.go b/account/builder.go index 945c0855..dda4ec3e 100644 --- a/account/builder.go +++ b/account/builder.go @@ -43,7 +43,7 @@ func (a *spendAction) Build(ctx context.Context, b *txbuilder.TemplateBuilder) e return txbuilder.MissingFieldsError(missing...) } - acct, err := a.accounts.findByID(ctx, a.AccountID) + acct, err := a.accounts.FindByID(ctx, a.AccountID) if err != nil { return errors.Wrap(err, "get account info") } @@ -115,7 +115,7 @@ func (a *spendUTXOAction) Build(ctx context.Context, b *txbuilder.TemplateBuilde var accountSigner *signers.Signer if len(res.Source.AccountID) != 0 { - account, err := a.accounts.findByID(ctx, res.Source.AccountID) + account, err := a.accounts.FindByID(ctx, res.Source.AccountID) if err != nil { return err } diff --git a/api/accounts.go b/api/accounts.go index f08497ba..e0cf81c9 100644 --- a/api/accounts.go +++ b/api/accounts.go @@ -2,6 +2,7 @@ package api import ( "context" + log "github.com/sirupsen/logrus" "github.com/bytom/account" @@ -13,9 +14,9 @@ import ( // POST /create-account func (a *API) createAccount(ctx context.Context, ins struct { - RootXPubs []chainkd.XPub `json:"root_xpubs"` - Quorum int `json:"quorum"` - Alias string `json:"alias"` + RootXPubs []chainkd.XPub `json:"root_xpubs"` + Quorum int `json:"quorum"` + Alias string `json:"alias"` }) Response { acc, err := a.wallet.AccountMgr.Create(ctx, ins.RootXPubs, ins.Quorum, ins.Alias) if err != nil { @@ -90,29 +91,30 @@ func (a *API) listAddresses(ctx context.Context, ins struct { AccountAlias string `json:"account_alias"` }) Response { accountID := ins.AccountID + var target *account.Account if ins.AccountAlias != "" { acc, err := a.wallet.AccountMgr.FindByAlias(ctx, ins.AccountAlias) if err != nil { return NewErrorResponse(err) } - - accountID = acc.ID + target = acc + } else { + acc, err := a.wallet.AccountMgr.FindByID(ctx, accountID) + if err != nil { + return NewErrorResponse(err) + } + target = acc } - cps, err := a.wallet.AccountMgr.ListControlProgram() + cps, err := a.wallet.AccountMgr.ListCtrlProgramsByXpubs(ctx, target.XPubs) if err != nil { return NewErrorResponse(err) } var addresses []*addressResp for _, cp := range cps { - if cp.Address == "" || (accountID != "" && accountID != cp.AccountID) { - continue - } - - accountAlias := a.wallet.AccountMgr.GetAliasByID(cp.AccountID) addresses = append(addresses, &addressResp{ - AccountAlias: accountAlias, + AccountAlias: target.Alias, AccountID: cp.AccountID, Address: cp.Address, Change: cp.Change, diff --git a/crypto/ed25519/chainkd/chainkd.go b/crypto/ed25519/chainkd/chainkd.go index 5d8ca4ea..23f6fc24 100644 --- a/crypto/ed25519/chainkd/chainkd.go +++ b/crypto/ed25519/chainkd/chainkd.go @@ -8,6 +8,7 @@ import ( "github.com/bytom/crypto/ed25519" "github.com/bytom/crypto/ed25519/ecmath" + "bytes" ) type ( @@ -17,7 +18,16 @@ type ( XPub [64]byte ) -var one = [32]byte{1} +// CompareTwoXPubs +func CompareTwoXPubs(a, b []XPub) int { + for i, xpub := range a { + result := bytes.Compare(xpub[:], b[i][:]) + if result != 0 { + return result + } + } + return 0 +} // NewXPrv takes a source of random bytes and produces a new XPrv. // If r is nil, crypto/rand.Reader is used. diff --git a/wallet/indexer.go b/wallet/indexer.go index 4a55b13a..c41d00b5 100644 --- a/wallet/indexer.go +++ b/wallet/indexer.go @@ -383,15 +383,25 @@ func (w *Wallet) filterAccountTxs(b *types.Block, txStatus *bc.TransactionStatus transactionLoop: for pos, tx := range b.Transactions { statusFail, _ := txStatus.GetStatus(pos) + isLocal := false for _, v := range tx.Outputs { var hash [32]byte sha3pool.Sum256(hash[:], v.ControlProgram) if bytes := w.DB.Get(account.CPKey(hash)); bytes != nil { + cp := &account.CtrlProgram{} + if err := json.Unmarshal(bytes, cp); err == nil { + w.status.selfProgramsOnChain.Add(cp.Address) + } + annotatedTxs = append(annotatedTxs, w.buildAnnotatedTransaction(tx, b, statusFail, pos)) - continue transactionLoop + isLocal = true } } + if isLocal { + continue + } + for _, v := range tx.Inputs { outid, err := v.SpentOutputID() if err != nil { diff --git a/wallet/set.go b/wallet/set.go new file mode 100644 index 00000000..627ebd28 --- /dev/null +++ b/wallet/set.go @@ -0,0 +1,28 @@ +package wallet + +type Set map[interface{}]bool + +func NewSet() Set { + return make(Set) +} + +// Add Add the specified element to this set if it is not already present (optional operation) +func (s *Set) Add(i interface{}) bool { + _, found := (*s)[i] + if found { + return false //False if it existed already + } + + (*s)[i] = true + return true +} + +// Contains Returns true if this set contains the specified elements +func (s *Set) Contains(i ...interface{}) bool { + for _, val := range i { + if _, ok := (*s)[val]; !ok { + return false + } + } + return true +} diff --git a/wallet/wallet.go b/wallet/wallet.go index 117476e3..5a061ae5 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -33,10 +33,12 @@ type StatusInfo struct { WorkHash bc.Hash BestHeight uint64 BestHash bc.Hash + selfProgramsOnChain Set } //KeyInfo is key import status type KeyInfo struct { + account account.Account Alias string `json:"alias"` XPub chainkd.XPub `json:"xpub"` Percent uint8 `json:"percent"` @@ -59,13 +61,13 @@ type Wallet struct { //NewWallet return a new wallet instance func NewWallet(walletDB db.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain) (*Wallet, error) { w := &Wallet{ - DB: walletDB, - AccountMgr: account, - AssetReg: asset, - chain: chain, - Hsm: hsm, - rescanProgress: make(chan struct{}, 1), - importingKeysInfo: make([]KeyInfo, 0), + DB: walletDB, + AccountMgr: account, + AssetReg: asset, + chain: chain, + Hsm: hsm, + rescanProgress: make(chan struct{}, 1), + importingKeysInfo: make([]KeyInfo, 0), } if err := w.loadWalletInfo(); err != nil { @@ -90,6 +92,7 @@ func (w *Wallet) loadWalletInfo() error { return json.Unmarshal(rawWallet, &w.status) } + w.status.selfProgramsOnChain = NewSet() block, err := w.chain.GetBlockByHeight(0) if err != nil { return err @@ -191,7 +194,7 @@ func (w *Wallet) DetachBlock(block *types.Block) error { func (w *Wallet) walletUpdater() { for { getRescanNotification(w) - updateRescanStatus(w) + w.updateRescanStatus() for !w.chain.InMainChain(w.status.BestHash) { block, err := w.chain.GetBlockByHash(&w.status.BestHash) if err != nil { @@ -297,6 +300,7 @@ func (w *Wallet) recoveryAccountWalletDB(account *account.Account, XPub *pseudoh } w.ImportingPrivateKey = true tmp := KeyInfo{ + account: *account, Alias: keyAlias, XPub: XPub.XPub, Complete: false, @@ -339,7 +343,7 @@ func (w *Wallet) GetRescanStatus() ([]KeyInfo, error) { } //updateRescanStatus mark private key import process `Complete` if rescan finished -func updateRescanStatus(w *Wallet) { +func (w *Wallet) updateRescanStatus() { if !w.ImportingPrivateKey { return } @@ -349,6 +353,7 @@ func updateRescanStatus(w *Wallet) { for _, keyInfo := range w.importingKeysInfo { keyInfo.Percent = percent } + w.commitkeysInfo() return } @@ -356,7 +361,14 @@ func updateRescanStatus(w *Wallet) { for _, keyInfo := range w.importingKeysInfo { keyInfo.Percent = 100 keyInfo.Complete = true + + if cps, err := w.AccountMgr.ListCtrlProgramsByXpubs(nil, keyInfo.account.XPubs); err == nil { + for _, cp := range cps { + if !w.status.selfProgramsOnChain.Contains(cp.Address) { + w.AccountMgr.DeleteAccountControlProgram(cp.ControlProgram) + } + } + } } w.commitkeysInfo() - // TODO: delete the generated but not used addresses } diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go index 43d8fe31..76c80099 100644 --- a/wallet/wallet_test.go +++ b/wallet/wallet_test.go @@ -236,11 +236,12 @@ func mockTxData(utxos []*account.UTXO, testAccount *account.Account) (*txbuilder func mockWallet(walletDB dbm.DB, account *account.Manager, asset *asset.Registry, chain *protocol.Chain) *Wallet { return &Wallet{ - DB: walletDB, - AccountMgr: account, - AssetReg: asset, - chain: chain, - rescanProgress: make(chan struct{}, 1), + DB: walletDB, + AccountMgr: account, + AssetReg: asset, + chain: chain, + rescanProgress: make(chan struct{}, 1), + selfProgramsOnChain: NewSet(), } } -- 2.11.0