OSDN Git Service

clean
[bytom/vapor.git] / federation / warder.go
1 package federation
2
3 import (
4         "database/sql"
5         "encoding/hex"
6         "time"
7
8         btmBc "github.com/bytom/protocol/bc"
9         btmTypes "github.com/bytom/protocol/bc/types"
10         "github.com/jinzhu/gorm"
11         log "github.com/sirupsen/logrus"
12
13         "github.com/vapor/crypto/ed25519/chainkd"
14         "github.com/vapor/errors"
15         "github.com/vapor/federation/common"
16         "github.com/vapor/federation/config"
17         "github.com/vapor/federation/database"
18         "github.com/vapor/federation/database/orm"
19         "github.com/vapor/federation/service"
20         vaporBc "github.com/vapor/protocol/bc"
21         vaporTypes "github.com/vapor/protocol/bc/types"
22 )
23
24 var collectInterval = 5 * time.Second
25
26 type warder struct {
27         db            *gorm.DB
28         assetStore    *database.AssetStore
29         txCh          chan *orm.CrossTransaction
30         fedProg       []byte
31         position      uint8
32         xpub          chainkd.XPub
33         xprv          chainkd.XPrv
34         mainchainNode *service.Node
35         sidechainNode *service.Node
36         remotes       []*service.Warder
37 }
38
39 func NewWarder(db *gorm.DB, assetStore *database.AssetStore, cfg *config.Config) *warder {
40         local, remotes := parseWarders(cfg)
41         return &warder{
42                 db:            db,
43                 assetStore:    assetStore,
44                 txCh:          make(chan *orm.CrossTransaction),
45                 fedProg:       ParseFedProg(cfg.Warders, cfg.Quorum),
46                 position:      local.Position,
47                 xpub:          local.XPub,
48                 xprv:          string2xprv(xprvStr),
49                 mainchainNode: service.NewNode(cfg.Mainchain.Upstream),
50                 sidechainNode: service.NewNode(cfg.Sidechain.Upstream),
51                 remotes:       remotes,
52         }
53 }
54
55 func parseWarders(cfg *config.Config) (*service.Warder, []*service.Warder) {
56         var local *service.Warder
57         var remotes []*service.Warder
58         for _, warderCfg := range cfg.Warders {
59                 if warderCfg.IsLocal {
60                         local = service.NewWarder(&warderCfg)
61                 } else {
62                         remoteWarder := service.NewWarder(&warderCfg)
63                         remotes = append(remotes, remoteWarder)
64                 }
65         }
66
67         if local == nil {
68                 log.Fatal("none local warder set")
69         }
70
71         return local, remotes
72 }
73
74 func (w *warder) Run() {
75         go w.collectPendingTx()
76         go w.processCrossTxRoutine()
77 }
78
79 func (w *warder) collectPendingTx() {
80         ticker := time.NewTicker(collectInterval)
81         for ; true; <-ticker.C {
82                 txs := []*orm.CrossTransaction{}
83                 if err := w.db.Preload("Chain").Preload("Reqs").
84                         // do not use "Where(&orm.CrossTransaction{Status: common.CrossTxPendingStatus})" directly,
85                         // otherwise the field "status" will be ignored
86                         Model(&orm.CrossTransaction{}).Where("status = ?", common.CrossTxPendingStatus).
87                         Find(&txs).Error; err == gorm.ErrRecordNotFound {
88                         continue
89                 } else if err != nil {
90                         log.Warnln("collectPendingTx", err)
91                 }
92
93                 for _, tx := range txs {
94                         w.txCh <- tx
95                 }
96         }
97 }
98
99 func (w *warder) processCrossTxRoutine() {
100         for ormTx := range w.txCh {
101                 if err := w.validateCrossTx(ormTx); err != nil {
102                         log.Warnln("invalid cross-chain tx", ormTx)
103                         continue
104                 }
105
106                 destTx, destTxID, err := w.proposeDestTx(ormTx)
107                 if err != nil {
108                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("proposeDestTx")
109                         continue
110                 }
111
112                 if err := w.initDestTxSigns(destTx, ormTx); err != nil {
113                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("initDestTxSigns")
114                         continue
115                 }
116
117                 if err := w.signDestTx(destTx, ormTx); err != nil {
118                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("signDestTx")
119                         continue
120                 }
121
122                 for _, remote := range w.remotes {
123                         signs, err := remote.RequestSign(destTx, ormTx)
124                         if err != nil {
125                                 log.WithFields(log.Fields{"err": err, "remote": remote, "cross-chain tx": ormTx}).Warnln("RequestSign")
126                                 continue
127                         }
128
129                         w.attachSignsForTx(destTx, ormTx, remote.Position, signs)
130                 }
131
132                 if w.isTxSignsReachQuorum(destTx) && w.isLeader() {
133                         submittedTxID, err := w.submitTx(destTx)
134                         if err != nil {
135                                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "dest tx": destTx}).Warnln("submitTx")
136                                 continue
137                         }
138
139                         if submittedTxID != destTxID {
140                                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "builtTx ID": destTxID, "submittedTx ID": submittedTxID}).Warnln("submitTx ID mismatch")
141                                 continue
142                         }
143
144                         if err := w.updateSubmission(ormTx); err != nil {
145                                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("updateSubmission")
146                                 continue
147                         }
148                 }
149         }
150 }
151
152 func (w *warder) validateCrossTx(tx *orm.CrossTransaction) error {
153         switch tx.Status {
154         case common.CrossTxRejectedStatus:
155                 return errors.New("cross-chain tx rejected")
156         case common.CrossTxSubmittedStatus:
157                 return errors.New("cross-chain tx submitted")
158         case common.CrossTxCompletedStatus:
159                 return errors.New("cross-chain tx completed")
160         default:
161                 return nil
162         }
163 }
164
165 func (w *warder) proposeDestTx(tx *orm.CrossTransaction) (interface{}, string, error) {
166         switch tx.Chain.Name {
167         case "bytom":
168                 return w.buildSidechainTx(tx)
169         case "vapor":
170                 return w.buildMainchainTx(tx)
171         default:
172                 return nil, "", errors.New("unknown source chain")
173         }
174 }
175
176 func (w *warder) buildSidechainTx(ormTx *orm.CrossTransaction) (*vaporTypes.Tx, string, error) {
177         destTxData := &vaporTypes.TxData{Version: 1, TimeRange: 0}
178         muxID := &vaporBc.Hash{}
179         if err := muxID.UnmarshalText([]byte(ormTx.SourceMuxID)); err != nil {
180                 return nil, "", errors.Wrap(err, "Unmarshal muxID")
181         }
182
183         for _, req := range ormTx.Reqs {
184                 // getAsset from assetStore instead of preload asset, in order to save db query overload
185                 asset, err := w.assetStore.GetByOrmID(req.AssetID)
186                 if err != nil {
187                         return nil, "", errors.Wrap(err, "get asset by ormAsset ID")
188                 }
189
190                 assetID := &vaporBc.AssetID{}
191                 if err := assetID.UnmarshalText([]byte(asset.AssetID)); err != nil {
192                         return nil, "", errors.Wrap(err, "Unmarshal muxID")
193                 }
194
195                 rawDefinitionByte, err := hex.DecodeString(asset.RawDefinitionByte)
196                 if err != nil {
197                         return nil, "", errors.Wrap(err, "decode rawDefinitionByte")
198                 }
199
200                 input := vaporTypes.NewCrossChainInput(nil, *muxID, *assetID, req.AssetAmount, req.SourcePos, w.fedProg, rawDefinitionByte)
201                 destTxData.Inputs = append(destTxData.Inputs, input)
202
203                 controlProgram, err := hex.DecodeString(req.Script)
204                 if err != nil {
205                         return nil, "", errors.Wrap(err, "decode req.Script")
206                 }
207
208                 output := vaporTypes.NewIntraChainOutput(*assetID, req.AssetAmount, controlProgram)
209                 destTxData.Outputs = append(destTxData.Outputs, output)
210         }
211
212         destTx := vaporTypes.NewTx(*destTxData)
213         w.addInputWitness(destTx)
214
215         if err := w.db.Where(ormTx).UpdateColumn(&orm.CrossTransaction{
216                 DestTxHash: sql.NullString{destTx.ID.String(), true},
217         }).Error; err != nil {
218                 return nil, "", err
219         }
220
221         return destTx, destTx.ID.String(), nil
222 }
223
224 // TODO:
225 func (w *warder) buildMainchainTx(ormTx *orm.CrossTransaction) (*btmTypes.Tx, string, error) {
226         destTxData := &btmTypes.TxData{Version: 1, TimeRange: 0}
227         muxID := &btmBc.Hash{}
228         if err := muxID.UnmarshalText([]byte(ormTx.SourceMuxID)); err != nil {
229                 return nil, "", errors.Wrap(err, "Unmarshal muxID")
230         }
231
232         for _, req := range ormTx.Reqs {
233                 // getAsset from assetStore instead of preload asset, in order to save db query overload
234                 asset, err := w.assetStore.GetByOrmID(req.AssetID)
235                 if err != nil {
236                         return nil, "", errors.Wrap(err, "get asset by ormAsset ID")
237                 }
238
239                 assetID := &btmBc.AssetID{}
240                 if err := assetID.UnmarshalText([]byte(asset.AssetID)); err != nil {
241                         return nil, "", errors.Wrap(err, "Unmarshal muxID")
242                 }
243
244                 // rawDefinitionByte, err := hex.DecodeString(asset.RawDefinitionByte)
245                 // if err != nil {
246                 //      return nil, "", errors.Wrap(err, "decode rawDefinitionByte")
247                 // }
248
249                 // input := vaporTypes.NewCrossChainInput(nil, *muxID, *assetID, req.AssetAmount, req.SourcePos, w.fedProg, rawDefinitionByte)
250                 // destTxData.Inputs = append(destTxData.Inputs, input)
251
252                 // controlProgram, err := hex.DecodeString(req.Script)
253                 // if err != nil {
254                 //      return nil, "", errors.Wrap(err, "decode req.Script")
255                 // }
256
257                 // output := vaporTypes.NewIntraChainOutput(*assetID, req.AssetAmount, controlProgram)
258                 // destTxData.Outputs = append(destTxData.Outputs, output)
259         }
260
261         destTx := btmTypes.NewTx(*destTxData)
262         w.addInputWitness(destTx)
263
264         if err := w.db.Where(ormTx).UpdateColumn(&orm.CrossTransaction{
265                 DestTxHash: sql.NullString{destTx.ID.String(), true},
266         }).Error; err != nil {
267                 return nil, "", err
268         }
269
270         return destTx, destTx.ID.String(), nil
271
272 }
273
274 // tx is a pointer to types.Tx, so the InputArguments can be set and be valid afterward
275 func (w *warder) addInputWitness(tx interface{}) {
276         witness := [][]byte{w.fedProg}
277         switch tx := tx.(type) {
278         case *vaporTypes.Tx:
279                 for i := range tx.Inputs {
280                         tx.SetInputArguments(uint32(i), witness)
281                 }
282
283         case *btmTypes.Tx:
284                 for i := range tx.Inputs {
285                         tx.SetInputArguments(uint32(i), witness)
286                 }
287         }
288 }
289
290 func (w *warder) initDestTxSigns(destTx interface{}, ormTx *orm.CrossTransaction) error {
291         crossTxSigns := []*orm.CrossTransactionSign{}
292         for i := 1; i <= len(w.remotes)+1; i++ {
293                 crossTxSigns = append(crossTxSigns, &orm.CrossTransactionSign{
294                         CrossTransactionID: ormTx.ID,
295                         WarderID:           uint8(i),
296                         Status:             common.CrossTxSignPendingStatus,
297                 })
298         }
299         return w.db.Create(crossTxSigns).Error
300 }
301
302 // TODO:
303 func (w *warder) signDestTx(destTx interface{}, tx *orm.CrossTransaction) error {
304         if tx.Status != common.CrossTxPendingStatus || !tx.DestTxHash.Valid {
305                 return errors.New("cross-chain tx status error")
306         }
307
308         return nil
309 }
310
311 // TODO:
312 func (w *warder) attachSignsForTx(destTx interface{}, ormTx *orm.CrossTransaction, position uint8, signs string) {
313 }
314
315 // TODO:
316 func (w *warder) isTxSignsReachQuorum(destTx interface{}) bool {
317         return false
318 }
319
320 func (w *warder) isLeader() bool {
321         return w.position == 1
322 }
323
324 func (w *warder) submitTx(destTx interface{}) (string, error) {
325         switch tx := destTx.(type) {
326         case *btmTypes.Tx:
327                 return w.mainchainNode.SubmitTx(tx)
328         case *vaporTypes.Tx:
329                 return w.sidechainNode.SubmitTx(tx)
330         default:
331                 return "", errors.New("unknown destTx type")
332         }
333 }
334
335 func (w *warder) updateSubmission(tx *orm.CrossTransaction) error {
336         if err := w.db.Where(tx).UpdateColumn(&orm.CrossTransaction{
337                 Status: common.CrossTxSubmittedStatus,
338         }).Error; err != nil {
339                 return err
340         }
341
342         for _, remote := range w.remotes {
343                 remote.NotifySubmission(tx)
344         }
345         return nil
346 }