OSDN Git Service

delete unused addresses for imported key and account
authorYongfeng LI <wliyongfeng@gmail.com>
Mon, 16 Apr 2018 13:37:00 +0000 (21:37 +0800)
committerYongfeng LI <wliyongfeng@gmail.com>
Tue, 17 Apr 2018 05:33:44 +0000 (13:33 +0800)
account/accounts.go
account/accounts_test.go
account/builder.go
api/accounts.go
crypto/ed25519/chainkd/chainkd.go
wallet/indexer.go
wallet/set.go [new file with mode: 0644]
wallet/wallet.go
wallet/wallet_test.go

index 84925d0..b85cfd0 100644 (file)
@@ -104,7 +104,7 @@ type Manager struct {
        delayedACPsMu sync.Mutex
        delayedACPs   map[*txbuilder.TemplateBuilder][]*CtrlProgram
 
-       accIndexMu  sync.Mutex
+       accIndexMu sync.Mutex
 }
 
 // ExpireReservations removes reservations that have expired periodically.
@@ -180,7 +180,7 @@ func (m *Manager) FindByAlias(ctx context.Context, alias string) (*Account, erro
        cachedID, ok := m.aliasCache.Get(alias)
        m.cacheMu.Unlock()
        if ok {
-               return m.findByID(ctx, cachedID.(string))
+               return m.FindByID(ctx, cachedID.(string))
        }
 
        rawID := m.db.Get(aliasKey(alias))
@@ -192,11 +192,11 @@ func (m *Manager) FindByAlias(ctx context.Context, alias string) (*Account, erro
        m.cacheMu.Lock()
        m.aliasCache.Add(alias, accountID)
        m.cacheMu.Unlock()
-       return m.findByID(ctx, accountID)
+       return m.FindByID(ctx, accountID)
 }
 
-// findByID returns an account's Signer record by its ID.
-func (m *Manager) findByID(ctx context.Context, id string) (*Account, error) {
+// FindByID returns an account's Signer record by its ID.
+func (m *Manager) FindByID(ctx context.Context, id string) (*Account, error) {
        m.cacheMu.Lock()
        cachedAccount, ok := m.cache.Get(id)
        m.cacheMu.Unlock()
@@ -244,7 +244,7 @@ func (m *Manager) CreateCtrlProgramForChange(ctx context.Context, accountID stri
 
 // CreateAddress generate an address for the select account
 func (m *Manager) CreateAddress(ctx context.Context, accountID string, change bool) (cp *CtrlProgram, err error) {
-       account, err := m.findByID(ctx, accountID)
+       account, err := m.FindByID(ctx, accountID)
        if err != nil {
                return nil, err
        }
@@ -268,6 +268,23 @@ func (m *Manager) createAddress(ctx context.Context, account *Account, change bo
        return cp, nil
 }
 
+// listAddressesById
+func (m *Manager) ListCtrlProgramsByXpubs(ctx context.Context, xpubs []chainkd.XPub) ([]*CtrlProgram, error) {
+       cps, err := m.ListControlProgram()
+       if err != nil {
+               return nil, err
+       }
+
+       var result []*CtrlProgram
+       for _, cp := range cps {
+               if cp.Address == "" || chainkd.CompareTwoXPubs(cp.XPubs, xpubs) != 0 {
+                       continue
+               }
+               result = append(result, cp)
+       }
+       return result, nil
+}
+
 func (m *Manager) createP2PKH(ctx context.Context, account *Account, change bool) (*CtrlProgram, error) {
        idx := m.getNextXpubsIndex(account.Signer.XPubs)
        path := signers.Path(account.Signer, signers.AccountKeySpace, idx)
@@ -288,6 +305,7 @@ func (m *Manager) createP2PKH(ctx context.Context, account *Account, change bool
 
        return &CtrlProgram{
                AccountID:      account.ID,
+               XPubs:          account.Signer.XPubs,
                Address:        address.EncodeAddress(),
                KeyIndex:       idx,
                ControlProgram: control,
@@ -319,6 +337,7 @@ func (m *Manager) createP2SH(ctx context.Context, account *Account, change bool)
 
        return &CtrlProgram{
                AccountID:      account.ID,
+               XPubs:          account.Signer.XPubs,
                Address:        address.EncodeAddress(),
                KeyIndex:       idx,
                ControlProgram: control,
@@ -329,6 +348,7 @@ func (m *Manager) createP2SH(ctx context.Context, account *Account, change bool)
 //CtrlProgram is structure of account control program
 type CtrlProgram struct {
        AccountID      string
+       XPubs          []chainkd.XPub
        Address        string
        KeyIndex       uint64
        ControlProgram []byte
@@ -349,6 +369,12 @@ func (m *Manager) insertAccountControlProgram(ctx context.Context, progs ...*Ctr
        return nil
 }
 
+func (m *Manager) DeleteAccountControlProgram(prog []byte) {
+       var hash common.Hash
+       sha3pool.Sum256(hash[:], prog)
+       m.db.Delete(CPKey(hash))
+}
+
 // IsLocalControlProgram check is the input control program belong to local
 func (m *Manager) IsLocalControlProgram(prog []byte) bool {
        var hash common.Hash
@@ -399,7 +425,7 @@ type Info struct {
 func (m *Manager) DeleteAccount(aliasOrId string) (err error) {
        account := &Account{}
        if account, err = m.FindByAlias(nil, aliasOrId); err != nil {
-               if account, err = m.findByID(nil, aliasOrId); err != nil {
+               if account, err = m.FindByID(nil, aliasOrId); err != nil {
                        return err
                }
        }
index b4891ec..9be886b 100644 (file)
@@ -53,7 +53,7 @@ func TestCreateAccount(t *testing.T) {
                testutil.FatalErr(t, err)
        }
 
-       found, err := m.findByID(ctx, account.ID)
+       found, err := m.FindByID(ctx, account.ID)
        if err != nil {
                t.Errorf("unexpected error %v", err)
        }
@@ -91,7 +91,7 @@ func TestDeleteAccount(t *testing.T) {
                testutil.FatalErr(t, err)
        }
 
-       found, err := m.findByID(ctx, account1.ID)
+       found, err := m.FindByID(ctx, account1.ID)
        if err != nil {
                t.Errorf("expected account %v should be deleted", found)
        }
@@ -100,7 +100,7 @@ func TestDeleteAccount(t *testing.T) {
                testutil.FatalErr(t, err)
        }
 
-       found, err = m.findByID(ctx, account2.ID)
+       found, err = m.FindByID(ctx, account2.ID)
        if err != nil {
                t.Errorf("expected account %v should be deleted", found)
        }
@@ -111,7 +111,7 @@ func TestFindByID(t *testing.T) {
        ctx := context.Background()
        account := m.createTestAccount(ctx, t, "", nil)
 
-       found, err := m.findByID(ctx, account.ID)
+       found, err := m.FindByID(ctx, account.ID)
        if err != nil {
                testutil.FatalErr(t, err)
        }
index 945c085..dda4ec3 100644 (file)
@@ -43,7 +43,7 @@ func (a *spendAction) Build(ctx context.Context, b *txbuilder.TemplateBuilder) e
                return txbuilder.MissingFieldsError(missing...)
        }
 
-       acct, err := a.accounts.findByID(ctx, a.AccountID)
+       acct, err := a.accounts.FindByID(ctx, a.AccountID)
        if err != nil {
                return errors.Wrap(err, "get account info")
        }
@@ -115,7 +115,7 @@ func (a *spendUTXOAction) Build(ctx context.Context, b *txbuilder.TemplateBuilde
 
        var accountSigner *signers.Signer
        if len(res.Source.AccountID) != 0 {
-               account, err := a.accounts.findByID(ctx, res.Source.AccountID)
+               account, err := a.accounts.FindByID(ctx, res.Source.AccountID)
                if err != nil {
                        return err
                }
index f08497b..e0cf81c 100644 (file)
@@ -2,6 +2,7 @@ package api
 
 import (
        "context"
+
        log "github.com/sirupsen/logrus"
 
        "github.com/bytom/account"
@@ -13,9 +14,9 @@ import (
 
 // POST /create-account
 func (a *API) createAccount(ctx context.Context, ins struct {
-       RootXPubs []chainkd.XPub         `json:"root_xpubs"`
-       Quorum    int                    `json:"quorum"`
-       Alias     string                 `json:"alias"`
+       RootXPubs []chainkd.XPub `json:"root_xpubs"`
+       Quorum    int            `json:"quorum"`
+       Alias     string         `json:"alias"`
 }) Response {
        acc, err := a.wallet.AccountMgr.Create(ctx, ins.RootXPubs, ins.Quorum, ins.Alias)
        if err != nil {
@@ -90,29 +91,30 @@ func (a *API) listAddresses(ctx context.Context, ins struct {
        AccountAlias string `json:"account_alias"`
 }) Response {
        accountID := ins.AccountID
+       var target *account.Account
        if ins.AccountAlias != "" {
                acc, err := a.wallet.AccountMgr.FindByAlias(ctx, ins.AccountAlias)
                if err != nil {
                        return NewErrorResponse(err)
                }
-
-               accountID = acc.ID
+               target = acc
+       } else {
+               acc, err := a.wallet.AccountMgr.FindByID(ctx, accountID)
+               if err != nil {
+                       return NewErrorResponse(err)
+               }
+               target = acc
        }
 
-       cps, err := a.wallet.AccountMgr.ListControlProgram()
+       cps, err := a.wallet.AccountMgr.ListCtrlProgramsByXpubs(ctx, target.XPubs)
        if err != nil {
                return NewErrorResponse(err)
        }
 
        var addresses []*addressResp
        for _, cp := range cps {
-               if cp.Address == "" || (accountID != "" && accountID != cp.AccountID) {
-                       continue
-               }
-
-               accountAlias := a.wallet.AccountMgr.GetAliasByID(cp.AccountID)
                addresses = append(addresses, &addressResp{
-                       AccountAlias: accountAlias,
+                       AccountAlias: target.Alias,
                        AccountID:    cp.AccountID,
                        Address:      cp.Address,
                        Change:       cp.Change,
index 5d8ca4e..23f6fc2 100644 (file)
@@ -8,6 +8,7 @@ import (
 
        "github.com/bytom/crypto/ed25519"
        "github.com/bytom/crypto/ed25519/ecmath"
+       "bytes"
 )
 
 type (
@@ -17,7 +18,16 @@ type (
        XPub [64]byte
 )
 
-var one = [32]byte{1}
+// CompareTwoXPubs
+func CompareTwoXPubs(a, b []XPub) int {
+       for i, xpub := range a {
+               result := bytes.Compare(xpub[:], b[i][:])
+               if result != 0 {
+                       return result
+               }
+       }
+       return 0
+}
 
 // NewXPrv takes a source of random bytes and produces a new XPrv.
 // If r is nil, crypto/rand.Reader is used.
index 4a55b13..c41d00b 100644 (file)
@@ -383,15 +383,25 @@ func (w *Wallet) filterAccountTxs(b *types.Block, txStatus *bc.TransactionStatus
 transactionLoop:
        for pos, tx := range b.Transactions {
                statusFail, _ := txStatus.GetStatus(pos)
+               isLocal := false
                for _, v := range tx.Outputs {
                        var hash [32]byte
                        sha3pool.Sum256(hash[:], v.ControlProgram)
                        if bytes := w.DB.Get(account.CPKey(hash)); bytes != nil {
+                               cp := &account.CtrlProgram{}
+                               if err := json.Unmarshal(bytes, cp); err == nil {
+                                       w.status.selfProgramsOnChain.Add(cp.Address)
+                               }
+
                                annotatedTxs = append(annotatedTxs, w.buildAnnotatedTransaction(tx, b, statusFail, pos))
-                               continue transactionLoop
+                               isLocal = true
                        }
                }
 
+               if isLocal {
+                       continue
+               }
+
                for _, v := range tx.Inputs {
                        outid, err := v.SpentOutputID()
                        if err != nil {
diff --git a/wallet/set.go b/wallet/set.go
new file mode 100644 (file)
index 0000000..627ebd2
--- /dev/null
@@ -0,0 +1,28 @@
+package wallet
+
+type Set map[interface{}]bool
+
+func NewSet() Set {
+       return make(Set)
+}
+
+// Add Add the specified element to this set if it is not already present (optional operation)
+func (s *Set) Add(i interface{}) bool {
+       _, found := (*s)[i]
+       if found {
+               return false //False if it existed already
+       }
+
+       (*s)[i] = true
+       return true
+}
+
+// Contains Returns true if this set contains the specified elements
+func (s *Set) Contains(i ...interface{}) bool {
+       for _, val := range i {
+               if _, ok := (*s)[val]; !ok {
+                       return false
+               }
+       }
+       return true
+}
index 117476e..5a061ae 100644 (file)
@@ -33,10 +33,12 @@ type StatusInfo struct {
        WorkHash   bc.Hash
        BestHeight uint64
        BestHash   bc.Hash
+       selfProgramsOnChain Set
 }
 
 //KeyInfo is key import status
 type KeyInfo struct {
+       account  account.Account
        Alias    string       `json:"alias"`
        XPub     chainkd.XPub `json:"xpub"`
        Percent  uint8        `json:"percent"`
@@ -59,13 +61,13 @@ type Wallet struct {
 //NewWallet return a new wallet instance
 func NewWallet(walletDB db.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain) (*Wallet, error) {
        w := &Wallet{
-               DB:                walletDB,
-               AccountMgr:        account,
-               AssetReg:          asset,
-               chain:             chain,
-               Hsm:               hsm,
-               rescanProgress:    make(chan struct{}, 1),
-               importingKeysInfo: make([]KeyInfo, 0),
+               DB:                  walletDB,
+               AccountMgr:          account,
+               AssetReg:            asset,
+               chain:               chain,
+               Hsm:                 hsm,
+               rescanProgress:      make(chan struct{}, 1),
+               importingKeysInfo:   make([]KeyInfo, 0),
        }
 
        if err := w.loadWalletInfo(); err != nil {
@@ -90,6 +92,7 @@ func (w *Wallet) loadWalletInfo() error {
                return json.Unmarshal(rawWallet, &w.status)
        }
 
+       w.status.selfProgramsOnChain = NewSet()
        block, err := w.chain.GetBlockByHeight(0)
        if err != nil {
                return err
@@ -191,7 +194,7 @@ func (w *Wallet) DetachBlock(block *types.Block) error {
 func (w *Wallet) walletUpdater() {
        for {
                getRescanNotification(w)
-               updateRescanStatus(w)
+               w.updateRescanStatus()
                for !w.chain.InMainChain(w.status.BestHash) {
                        block, err := w.chain.GetBlockByHash(&w.status.BestHash)
                        if err != nil {
@@ -297,6 +300,7 @@ func (w *Wallet) recoveryAccountWalletDB(account *account.Account, XPub *pseudoh
        }
        w.ImportingPrivateKey = true
        tmp := KeyInfo{
+               account:  *account,
                Alias:    keyAlias,
                XPub:     XPub.XPub,
                Complete: false,
@@ -339,7 +343,7 @@ func (w *Wallet) GetRescanStatus() ([]KeyInfo, error) {
 }
 
 //updateRescanStatus mark private key import process `Complete` if rescan finished
-func updateRescanStatus(w *Wallet) {
+func (w *Wallet) updateRescanStatus() {
        if !w.ImportingPrivateKey {
                return
        }
@@ -349,6 +353,7 @@ func updateRescanStatus(w *Wallet) {
                for _, keyInfo := range w.importingKeysInfo {
                        keyInfo.Percent = percent
                }
+               w.commitkeysInfo()
                return
        }
 
@@ -356,7 +361,14 @@ func updateRescanStatus(w *Wallet) {
        for _, keyInfo := range w.importingKeysInfo {
                keyInfo.Percent = 100
                keyInfo.Complete = true
+
+               if cps, err := w.AccountMgr.ListCtrlProgramsByXpubs(nil, keyInfo.account.XPubs); err == nil {
+                       for _, cp := range cps {
+                               if !w.status.selfProgramsOnChain.Contains(cp.Address) {
+                                       w.AccountMgr.DeleteAccountControlProgram(cp.ControlProgram)
+                               }
+                       }
+               }
        }
        w.commitkeysInfo()
-       // TODO: delete the generated but not used addresses
 }
index 43d8fe3..76c8009 100644 (file)
@@ -236,11 +236,12 @@ func mockTxData(utxos []*account.UTXO, testAccount *account.Account) (*txbuilder
 
 func mockWallet(walletDB dbm.DB, account *account.Manager, asset *asset.Registry, chain *protocol.Chain) *Wallet {
        return &Wallet{
-               DB:             walletDB,
-               AccountMgr:     account,
-               AssetReg:       asset,
-               chain:          chain,
-               rescanProgress: make(chan struct{}, 1),
+               DB:                  walletDB,
+               AccountMgr:          account,
+               AssetReg:            asset,
+               chain:               chain,
+               rescanProgress:      make(chan struct{}, 1),
+               selfProgramsOnChain: NewSet(),
        }
 }