OSDN Git Service

fix
[bytom/vapor.git] / federation / warder.go
1 package federation
2
3 import (
4         "database/sql"
5         "encoding/hex"
6         "time"
7
8         btmTypes "github.com/bytom/protocol/bc/types"
9         "github.com/jinzhu/gorm"
10         log "github.com/sirupsen/logrus"
11
12         "github.com/vapor/crypto/ed25519/chainkd"
13         "github.com/vapor/errors"
14         "github.com/vapor/federation/common"
15         "github.com/vapor/federation/config"
16         "github.com/vapor/federation/database/orm"
17         "github.com/vapor/federation/service"
18         vaporBc "github.com/vapor/protocol/bc"
19         vaporTypes "github.com/vapor/protocol/bc/types"
20 )
21
22 var collectInterval = 5 * time.Second
23
24 var xprvStr = "d20e3d81ba2c5509619fbc276d7cd8b94f52a1dce1291ae9e6b28d4a48ee67d8ac5826ba65c9da0b035845b7cb379e816c529194c7e369492d8828dee5ede3e2"
25
26 func string2xprv(str string) (xprv chainkd.XPrv) {
27         if err := xprv.UnmarshalText([]byte(str)); err != nil {
28                 log.Panicf("fail to convert xprv string")
29         }
30         return xprv
31 }
32
33 type warder struct {
34         db            *gorm.DB
35         assetKeeper   *service.AssetKeeper
36         txCh          chan *orm.CrossTransaction
37         fedProg       []byte
38         position      uint8
39         xpub          chainkd.XPub
40         xprv          chainkd.XPrv
41         mainchainNode *service.Node
42         sidechainNode *service.Node
43         remotes       []*service.Warder
44 }
45
46 func NewWarder(db *gorm.DB, assetKeeper *service.AssetKeeper, cfg *config.Config) *warder {
47         local, remotes := parseWarders(cfg)
48         return &warder{
49                 db:          db,
50                 assetKeeper: assetKeeper,
51                 txCh:        make(chan *orm.CrossTransaction),
52                 fedProg:     ParseFedProg(cfg.Warders, cfg.Quorum),
53                 position:    local.Position,
54                 xpub:        local.XPub,
55                 // TODO:
56                 xprv:          string2xprv(xprvStr),
57                 mainchainNode: service.NewNode(cfg.Mainchain.Upstream),
58                 sidechainNode: service.NewNode(cfg.Sidechain.Upstream),
59                 remotes:       remotes,
60         }
61 }
62
63 func parseWarders(cfg *config.Config) (*service.Warder, []*service.Warder) {
64         var local *service.Warder
65         var remotes []*service.Warder
66         for _, warderCfg := range cfg.Warders {
67                 // TODO: use private key to check
68                 if warderCfg.IsLocal {
69                         local = service.NewWarder(&warderCfg)
70                 } else {
71                         remoteWarder := service.NewWarder(&warderCfg)
72                         remotes = append(remotes, remoteWarder)
73                 }
74         }
75
76         if local == nil {
77                 log.Fatal("none local warder set")
78         }
79
80         return local, remotes
81 }
82
83 func (w *warder) Run() {
84         go w.collectPendingTx()
85         go w.processCrossTxRoutine()
86 }
87
88 func (w *warder) collectPendingTx() {
89         ticker := time.NewTicker(collectInterval)
90         for ; true; <-ticker.C {
91                 txs := []*orm.CrossTransaction{}
92                 if err := w.db.Preload("Chain").Preload("Reqs").
93                         // do not use "Where(&orm.CrossTransaction{Status: common.CrossTxPendingStatus})" directly,
94                         // otherwise the field "status" will be ignored
95                         Model(&orm.CrossTransaction{}).Where("status = ?", common.CrossTxPendingStatus).
96                         Find(&txs).Error; err == gorm.ErrRecordNotFound {
97                         continue
98                 } else if err != nil {
99                         log.Warnln("collectPendingTx", err)
100                 }
101
102                 for _, tx := range txs {
103                         w.txCh <- tx
104                 }
105         }
106 }
107
108 func (w *warder) processCrossTxRoutine() {
109         for ormTx := range w.txCh {
110                 if err := w.validateCrossTx(ormTx); err != nil {
111                         log.Warnln("invalid cross-chain tx", ormTx)
112                         continue
113                 }
114
115                 destTx, destTxID, err := w.proposeDestTx(ormTx)
116                 if err != nil {
117                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("proposeDestTx")
118                         continue
119                 }
120
121                 if err := w.signDestTx(destTx, ormTx); err != nil {
122                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("signDestTx")
123                         continue
124                 }
125
126                 for _, remote := range w.remotes {
127                         signs, err := remote.RequestSign(destTx, ormTx)
128                         if err != nil {
129                                 log.WithFields(log.Fields{"err": err, "remote": remote, "cross-chain tx": ormTx}).Warnln("RequestSign")
130                                 continue
131                         }
132
133                         w.attachSignsForTx(destTx, ormTx, remote.Position, signs)
134                 }
135
136                 if w.isTxSignsReachQuorum(destTx) && w.isLeader() {
137                         submittedTxID, err := w.submitTx(destTx)
138                         if err != nil {
139                                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "dest tx": destTx}).Warnln("submitTx")
140                                 continue
141                         }
142
143                         if submittedTxID != destTxID {
144                                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "builtTx ID": destTxID, "submittedTx ID": submittedTxID}).Warnln("submitTx ID mismatch")
145                                 continue
146                         }
147
148                         if err := w.updateSubmission(ormTx); err != nil {
149                                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("updateSubmission")
150                                 continue
151                         }
152                 }
153         }
154 }
155
156 func (w *warder) validateCrossTx(tx *orm.CrossTransaction) error {
157         switch tx.Status {
158         case common.CrossTxRejectedStatus:
159                 return errors.New("cross-chain tx rejected")
160         case common.CrossTxSubmittedStatus:
161                 return errors.New("cross-chain tx submitted")
162         case common.CrossTxCompletedStatus:
163                 return errors.New("cross-chain tx completed")
164         default:
165                 return nil
166         }
167 }
168
169 func (w *warder) proposeDestTx(tx *orm.CrossTransaction) (interface{}, string, error) {
170         switch tx.Chain.Name {
171         case "bytom":
172                 return w.buildSidechainTx(tx)
173         case "vapor":
174                 return w.buildMainchainTx(tx)
175         default:
176                 return nil, "", errors.New("unknown source chain")
177         }
178 }
179
180 // TODO:
181 // signInsts?
182 // addInputWitness(tx, signInsts)?
183 func (w *warder) buildSidechainTx(ormTx *orm.CrossTransaction) (*vaporTypes.Tx, string, error) {
184         destTxData := &vaporTypes.TxData{Version: 1, TimeRange: 0}
185         // signInsts := []*SigningInstruction{}
186         muxID := &vaporBc.Hash{}
187         if err := muxID.UnmarshalText([]byte(ormTx.SourceMuxID)); err != nil {
188                 return nil, "", errors.Wrap(err, "Unmarshal muxID")
189         }
190
191         for _, req := range ormTx.Reqs {
192                 // TODO:
193                 // getAsset from assetKeeper instead of preload asset, in order to save db query overload
194                 asset := &orm.Asset{}
195                 assetID := &vaporBc.AssetID{}
196                 if err := assetID.UnmarshalText([]byte(asset.AssetID)); err != nil {
197                         return nil, "", errors.Wrap(err, "Unmarshal muxID")
198                 }
199
200                 rawDefinitionByte, err := hex.DecodeString(asset.RawDefinitionByte)
201                 if err != nil {
202                         return nil, "", errors.Wrap(err, "decode rawDefinitionByte")
203                 }
204
205                 input := vaporTypes.NewCrossChainInput(nil, *muxID, *assetID, req.AssetAmount, req.SourcePos, w.fedProg, rawDefinitionByte)
206                 destTxData.Inputs = append(destTxData.Inputs, input)
207         }
208
209         // for?{
210
211         // txInput := btmTypes.NewSpendInput(nil, *utxoInfo.SourceID, *assetID, utxo.Amount, utxoInfo.SourcePos, cp)
212         // tx.Inputs = append(tx.Inputs, txInput)
213
214         // signInst := &SigningInstruction{}
215         // if utxo.Address == nil || utxo.Address.Wallet == nil {
216         //     return signInst, nil
217         // }
218
219         // path := pathForAddress(utxo.Address.Wallet.Idx, utxo.Address.Idx, utxo.Address.Change)
220         // for _, p := range path {
221         //     signInst.DerivationPath = append(signInst.DerivationPath, hex.EncodeToString(p))
222         // }
223
224         // xPubs, err := signersToXPubs(utxo.Address.Wallet.WalletSigners)
225         // if err != nil {
226         //     return nil, errors.Wrap(err, "signersToXPubs")
227         // }
228
229         // derivedXPubs := chainkd.DeriveXPubs(xPubs, path)
230         // derivedPKs := chainkd.XPubKeys(derivedXPubs)
231         // if len(derivedPKs) == 1 {
232         //     signInst.DataWitness = derivedPKs[0]
233         //     signInst.Pubkey = hex.EncodeToString(derivedPKs[0])
234         // } else if len(derivedPKs) > 1 {
235         //     if signInst.DataWitness, err = vmutil.P2SPMultiSigProgram(derivedPKs, int(utxo.Address.Wallet.M)); err != nil {
236         //         return nil, err
237         //     }
238         // }
239         // return signInst, nil
240
241         // signInsts = append(signInsts, signInst)
242
243         // }
244
245         // add the payment output && handle the fee
246         // if err := addOutput(txData, address, asset, amount, netParams); err != nil {
247         //     return nil, nil, errors.Wrap(err, "add payment output")
248         // }
249
250         destTx := vaporTypes.NewTx(*destTxData)
251         // addInputWitness(tx, signInsts)
252
253         if err := w.db.Where(ormTx).UpdateColumn(&orm.CrossTransaction{
254                 DestTxHash: sql.NullString{destTx.ID.String(), true},
255         }).Error; err != nil {
256                 return nil, "", err
257         }
258
259         return destTx, destTx.ID.String(), nil
260 }
261
262 // TODO:
263 func (w *warder) buildMainchainTx(tx *orm.CrossTransaction) (*btmTypes.Tx, string, error) {
264         mainchainTx := &btmTypes.Tx{}
265
266         if err := w.db.Where(tx).UpdateColumn(&orm.CrossTransaction{
267                 DestTxHash: sql.NullString{mainchainTx.ID.String(), true},
268         }).Error; err != nil {
269                 return nil, "", err
270         }
271
272         return mainchainTx, mainchainTx.ID.String(), nil
273 }
274
275 // TODO:
276 func (w *warder) signDestTx(destTx interface{}, tx *orm.CrossTransaction) error {
277         if tx.Status != common.CrossTxPendingStatus || !tx.DestTxHash.Valid {
278                 return errors.New("cross-chain tx status error")
279         }
280
281         return nil
282 }
283
284 // TODO:
285 func (w *warder) attachSignsForTx(destTx interface{}, ormTx *orm.CrossTransaction, position uint8, signs string) {
286 }
287
288 // TODO:
289 func (w *warder) isTxSignsReachQuorum(destTx interface{}) bool {
290         return false
291 }
292
293 func (w *warder) isLeader() bool {
294         return w.position == 1
295 }
296
297 func (w *warder) submitTx(destTx interface{}) (string, error) {
298         switch tx := destTx.(type) {
299         case *btmTypes.Tx:
300                 return w.mainchainNode.SubmitTx(tx)
301         case *vaporTypes.Tx:
302                 return w.sidechainNode.SubmitTx(tx)
303         default:
304                 return "", errors.New("unknown destTx type")
305         }
306 }
307
308 func (w *warder) updateSubmission(tx *orm.CrossTransaction) error {
309         if err := w.db.Where(tx).UpdateColumn(&orm.CrossTransaction{
310                 Status: common.CrossTxSubmittedStatus,
311         }).Error; err != nil {
312                 return err
313         }
314
315         for _, remote := range w.remotes {
316                 remote.NotifySubmission(tx)
317         }
318         return nil
319 }