OSDN Git Service

Add mempool new tx subscription support (#1578)
[bytom/bytom.git] / wallet / wallet.go
1 package wallet
2
3 import (
4         "encoding/json"
5         "sync"
6
7         log "github.com/sirupsen/logrus"
8         "github.com/tendermint/tmlibs/db"
9
10         "github.com/bytom/account"
11         "github.com/bytom/asset"
12         "github.com/bytom/blockchain/pseudohsm"
13         "github.com/bytom/event"
14         "github.com/bytom/protocol"
15         "github.com/bytom/protocol/bc"
16         "github.com/bytom/protocol/bc/types"
17 )
18
// logModule is the value used for the "module" field on log entries emitted
// by this package.
const logModule = "wallet"

// SINGLE denotes single-signature ("single sign") mode.
const SINGLE = 1

// walletKey is the DB key under which the wallet's StatusInfo is persisted
// as JSON.
var walletKey = []byte("walletInfo")
26
27 //StatusInfo is base valid block info to handle orphan block rollback
28 type StatusInfo struct {
29         WorkHeight uint64
30         WorkHash   bc.Hash
31         BestHeight uint64
32         BestHash   bc.Hash
33 }
34
35 //Wallet is related to storing account unspent outputs
36 type Wallet struct {
37         DB              db.DB
38         rw              sync.RWMutex
39         status          StatusInfo
40         AccountMgr      *account.Manager
41         AssetReg        *asset.Registry
42         Hsm             *pseudohsm.HSM
43         chain           *protocol.Chain
44         RecoveryMgr     *recoveryManager
45         eventDispatcher *event.Dispatcher
46         txMsgSub        *event.Subscription
47
48         rescanCh chan struct{}
49 }
50
51 //NewWallet return a new wallet instance
52 func NewWallet(walletDB db.DB, account *account.Manager, asset *asset.Registry, hsm *pseudohsm.HSM, chain *protocol.Chain, dispatcher *event.Dispatcher) (*Wallet, error) {
53         w := &Wallet{
54                 DB:              walletDB,
55                 AccountMgr:      account,
56                 AssetReg:        asset,
57                 chain:           chain,
58                 Hsm:             hsm,
59                 RecoveryMgr:     newRecoveryManager(walletDB, account),
60                 eventDispatcher: dispatcher,
61                 rescanCh:        make(chan struct{}, 1),
62         }
63
64         if err := w.loadWalletInfo(); err != nil {
65                 return nil, err
66         }
67
68         if err := w.RecoveryMgr.LoadStatusInfo(); err != nil {
69                 return nil, err
70         }
71
72         var err error
73         w.txMsgSub, err = w.eventDispatcher.Subscribe(protocol.TxMsgEvent{})
74         if err != nil {
75                 return nil, err
76         }
77
78         go w.walletUpdater()
79         go w.delUnconfirmedTx()
80         go w.memPoolTxQueryLoop()
81         return w, nil
82 }
83
84 // memPoolTxQueryLoop constantly pass a transaction accepted by mempool to the wallet.
85 func (w *Wallet) memPoolTxQueryLoop() {
86         for {
87                 select {
88                 case obj, ok := <-w.txMsgSub.Chan():
89                         if !ok {
90                                 log.WithFields(log.Fields{"module": logModule}).Warning("tx pool tx msg subscription channel closed")
91                                 return
92                         }
93
94                         ev, ok := obj.Data.(protocol.TxMsgEvent)
95                         if !ok {
96                                 log.WithFields(log.Fields{"module": logModule}).Error("event type error")
97                                 continue
98                         }
99
100                         switch ev.TxMsg.MsgType {
101                         case protocol.MsgNewTx:
102                                 w.AddUnconfirmedTx(ev.TxMsg.TxDesc)
103                         case protocol.MsgRemoveTx:
104                                 w.RemoveUnconfirmedTx(ev.TxMsg.TxDesc)
105                         default:
106                                 log.WithFields(log.Fields{"module": logModule}).Warn("got unknow message type from the txPool channel")
107                         }
108                 }
109         }
110 }
111
112 //GetWalletInfo return stored wallet info and nil,if error,
113 //return initial wallet info and err
114 func (w *Wallet) loadWalletInfo() error {
115         if rawWallet := w.DB.Get(walletKey); rawWallet != nil {
116                 return json.Unmarshal(rawWallet, &w.status)
117         }
118
119         block, err := w.chain.GetBlockByHeight(0)
120         if err != nil {
121                 return err
122         }
123         return w.AttachBlock(block)
124 }
125
126 func (w *Wallet) commitWalletInfo(batch db.Batch) error {
127         rawWallet, err := json.Marshal(w.status)
128         if err != nil {
129                 log.WithField("err", err).Error("save wallet info")
130                 return err
131         }
132
133         batch.Set(walletKey, rawWallet)
134         batch.Write()
135         return nil
136 }
137
138 // AttachBlock attach a new block
139 func (w *Wallet) AttachBlock(block *types.Block) error {
140         w.rw.Lock()
141         defer w.rw.Unlock()
142
143         if block.PreviousBlockHash != w.status.WorkHash {
144                 log.Warn("wallet skip attachBlock due to status hash not equal to previous hash")
145                 return nil
146         }
147
148         blockHash := block.Hash()
149         txStatus, err := w.chain.GetTransactionStatus(&blockHash)
150         if err != nil {
151                 return err
152         }
153
154         if err := w.RecoveryMgr.FilterRecoveryTxs(block); err != nil {
155                 return err
156         }
157
158         storeBatch := w.DB.NewBatch()
159         if err := w.indexTransactions(storeBatch, block, txStatus); err != nil {
160                 return err
161         }
162
163         w.attachUtxos(storeBatch, block, txStatus)
164         w.status.WorkHeight = block.Height
165         w.status.WorkHash = block.Hash()
166         if w.status.WorkHeight >= w.status.BestHeight {
167                 w.status.BestHeight = w.status.WorkHeight
168                 w.status.BestHash = w.status.WorkHash
169         }
170         return w.commitWalletInfo(storeBatch)
171 }
172
173 // DetachBlock detach a block and rollback state
174 func (w *Wallet) DetachBlock(block *types.Block) error {
175         w.rw.Lock()
176         defer w.rw.Unlock()
177
178         blockHash := block.Hash()
179         txStatus, err := w.chain.GetTransactionStatus(&blockHash)
180         if err != nil {
181                 return err
182         }
183
184         storeBatch := w.DB.NewBatch()
185         w.detachUtxos(storeBatch, block, txStatus)
186         w.deleteTransactions(storeBatch, w.status.BestHeight)
187
188         w.status.BestHeight = block.Height - 1
189         w.status.BestHash = block.PreviousBlockHash
190
191         if w.status.WorkHeight > w.status.BestHeight {
192                 w.status.WorkHeight = w.status.BestHeight
193                 w.status.WorkHash = w.status.BestHash
194         }
195
196         return w.commitWalletInfo(storeBatch)
197 }
198
199 //WalletUpdate process every valid block and reverse every invalid block which need to rollback
200 func (w *Wallet) walletUpdater() {
201         for {
202                 w.getRescanNotification()
203                 for !w.chain.InMainChain(w.status.BestHash) {
204                         block, err := w.chain.GetBlockByHash(&w.status.BestHash)
205                         if err != nil {
206                                 log.WithField("err", err).Error("walletUpdater GetBlockByHash")
207                                 return
208                         }
209
210                         if err := w.DetachBlock(block); err != nil {
211                                 log.WithField("err", err).Error("walletUpdater detachBlock stop")
212                                 return
213                         }
214                 }
215
216                 block, _ := w.chain.GetBlockByHeight(w.status.WorkHeight + 1)
217                 if block == nil {
218                         w.walletBlockWaiter()
219                         continue
220                 }
221
222                 if err := w.AttachBlock(block); err != nil {
223                         log.WithField("err", err).Error("walletUpdater AttachBlock stop")
224                         return
225                 }
226         }
227 }
228
229 //RescanBlocks provide a trigger to rescan blocks
230 func (w *Wallet) RescanBlocks() {
231         select {
232         case w.rescanCh <- struct{}{}:
233         default:
234                 return
235         }
236 }
237
238 // deleteAccountTxs deletes all txs in wallet
239 func (w *Wallet) deleteAccountTxs() {
240         storeBatch := w.DB.NewBatch()
241
242         txIter := w.DB.IteratorPrefix([]byte(TxPrefix))
243         defer txIter.Release()
244
245         for txIter.Next() {
246                 storeBatch.Delete(txIter.Key())
247         }
248
249         txIndexIter := w.DB.IteratorPrefix([]byte(TxIndexPrefix))
250         defer txIndexIter.Release()
251
252         for txIndexIter.Next() {
253                 storeBatch.Delete(txIndexIter.Key())
254         }
255
256         storeBatch.Write()
257 }
258
259 // DeleteAccount deletes account matching accountID, then rescan wallet
260 func (w *Wallet) DeleteAccount(accountID string) (err error) {
261         w.rw.Lock()
262         defer w.rw.Unlock()
263
264         if err := w.AccountMgr.DeleteAccount(accountID); err != nil {
265                 return err
266         }
267
268         w.deleteAccountTxs()
269         w.RescanBlocks()
270         return nil
271 }
272
273 func (w *Wallet) UpdateAccountAlias(accountID string, newAlias string) (err error) {
274         w.rw.Lock()
275         defer w.rw.Unlock()
276
277         if err := w.AccountMgr.UpdateAccountAlias(accountID, newAlias); err != nil {
278                 return err
279         }
280
281         w.deleteAccountTxs()
282         w.RescanBlocks()
283         return nil
284 }
285
286 func (w *Wallet) getRescanNotification() {
287         select {
288         case <-w.rescanCh:
289                 w.setRescanStatus()
290         default:
291                 return
292         }
293 }
294
295 func (w *Wallet) setRescanStatus() {
296         block, _ := w.chain.GetBlockByHeight(0)
297         w.status.WorkHash = bc.Hash{}
298         w.AttachBlock(block)
299 }
300
301 func (w *Wallet) walletBlockWaiter() {
302         select {
303         case <-w.chain.BlockWaiter(w.status.WorkHeight + 1):
304         case <-w.rescanCh:
305                 w.setRescanStatus()
306         }
307 }
308
309 // GetWalletStatusInfo return current wallet StatusInfo
310 func (w *Wallet) GetWalletStatusInfo() StatusInfo {
311         w.rw.RLock()
312         defer w.rw.RUnlock()
313
314         return w.status
315 }