OSDN Git Service

05112c6660f4241ce21243dcfecf575116b555a2
[bytom/vapor.git] / federation / warder.go
1 package federation
2
3 import (
4         "database/sql"
5         "encoding/hex"
6         "time"
7
8         btmTypes "github.com/bytom/protocol/bc/types"
9         "github.com/jinzhu/gorm"
10         log "github.com/sirupsen/logrus"
11
12         "github.com/vapor/crypto/ed25519/chainkd"
13         "github.com/vapor/errors"
14         "github.com/vapor/federation/common"
15         "github.com/vapor/federation/config"
16         "github.com/vapor/federation/database"
17         "github.com/vapor/federation/database/orm"
18         "github.com/vapor/federation/service"
19         "github.com/vapor/federation/util"
20         vaporBc "github.com/vapor/protocol/bc"
21         vaporTypes "github.com/vapor/protocol/bc/types"
22 )
23
24 var collectInterval = 5 * time.Second
25
26 type warder struct {
27         db            *gorm.DB
28         assetStore    *database.AssetStore
29         txCh          chan *orm.CrossTransaction
30         fedProg       []byte
31         position      uint8
32         xpub          chainkd.XPub
33         xprv          chainkd.XPrv
34         mainchainNode *service.Node
35         sidechainNode *service.Node
36         remotes       []*service.Warder
37 }
38
39 func NewWarder(db *gorm.DB, assetStore *database.AssetStore, cfg *config.Config) *warder {
40         local, remotes := parseWarders(cfg)
41         return &warder{
42                 db:            db,
43                 assetStore:    assetStore,
44                 txCh:          make(chan *orm.CrossTransaction),
45                 fedProg:       util.ParseFedProg(cfg.Warders, cfg.Quorum),
46                 position:      local.Position,
47                 xpub:          local.XPub,
48                 xprv:          string2xprv(xprvStr),
49                 mainchainNode: service.NewNode(cfg.Mainchain.Upstream),
50                 sidechainNode: service.NewNode(cfg.Sidechain.Upstream),
51                 remotes:       remotes,
52         }
53 }
54
55 func parseWarders(cfg *config.Config) (*service.Warder, []*service.Warder) {
56         var local *service.Warder
57         var remotes []*service.Warder
58         for _, warderCfg := range cfg.Warders {
59                 if warderCfg.IsLocal {
60                         local = service.NewWarder(&warderCfg)
61                 } else {
62                         remote := service.NewWarder(&warderCfg)
63                         remotes = append(remotes, remote)
64                 }
65         }
66
67         if local == nil {
68                 log.Fatal("none local warder set")
69         }
70
71         return local, remotes
72 }
73
74 func (w *warder) Run() {
75         ticker := time.NewTicker(collectInterval)
76         for ; true; <-ticker.C {
77                 txs := []*orm.CrossTransaction{}
78                 if err := w.db.Preload("Chain").Preload("Reqs").
79                         // do not use "Where(&orm.CrossTransaction{Status: common.CrossTxInitiatedStatus})" directly,
80                         // otherwise the field "status" will be ignored
81                         Model(&orm.CrossTransaction{}).Where("status = ?", common.CrossTxInitiatedStatus).
82                         Find(&txs).Error; err == gorm.ErrRecordNotFound {
83                         continue
84                 } else if err != nil {
85                         log.Warnln("collectPendingTx", err)
86                 }
87
88                 for _, tx := range txs {
89                         go w.tryProcessCrossTx(tx)
90                 }
91         }
92 }
93
94 func (w *warder) tryProcessCrossTx(ormTx *orm.CrossTransaction) error {
95         dbTx := w.db.Begin()
96         if err := w.processCrossTx(ormTx); err != nil {
97                 dbTx.Rollback()
98                 return err
99         }
100
101         return dbTx.Commit().Error
102 }
103
104 func (w *warder) processCrossTx(ormTx *orm.CrossTransaction) error {
105         if err := w.validateCrossTx(ormTx); err != nil {
106                 log.Warnln("invalid cross-chain tx", ormTx)
107                 return err
108         }
109
110         destTx, destTxID, err := w.proposeDestTx(ormTx)
111         if err != nil {
112                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("proposeDestTx")
113                 return err
114         }
115
116         if err := w.initDestTxSigns(destTx, ormTx); err != nil {
117                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("initDestTxSigns")
118                 return err
119         }
120
121         if err := w.signDestTx(destTx, ormTx); err != nil {
122                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("signDestTx")
123                 return err
124         }
125
126         for _, remote := range w.remotes {
127                 signs, err := remote.RequestSign(destTx, ormTx)
128                 if err != nil {
129                         log.WithFields(log.Fields{"err": err, "remote": remote, "cross-chain tx": ormTx}).Warnln("RequestSign")
130                         return err
131                 }
132
133                 w.attachSignsForTx(destTx, ormTx, remote.Position, signs)
134         }
135
136         if w.isTxSignsReachQuorum(destTx) && w.isLeader() {
137                 submittedTxID, err := w.submitTx(destTx)
138                 if err != nil {
139                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "dest tx": destTx}).Warnln("submitTx")
140                         return err
141                 }
142
143                 if submittedTxID != destTxID {
144                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "builtTx ID": destTxID, "submittedTx ID": submittedTxID}).Warnln("submitTx ID mismatch")
145                         return err
146                 }
147
148                 if err := w.updateSubmission(ormTx); err != nil {
149                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("updateSubmission")
150                         return err
151                 }
152         }
153
154         return nil
155 }
156
157 func (w *warder) validateCrossTx(tx *orm.CrossTransaction) error {
158         switch tx.Status {
159         case common.CrossTxRejectedStatus:
160                 return errors.New("cross-chain tx rejected")
161         case common.CrossTxSubmittedStatus:
162                 return errors.New("cross-chain tx submitted")
163         case common.CrossTxCompletedStatus:
164                 return errors.New("cross-chain tx completed")
165         default:
166                 return nil
167         }
168 }
169
170 func (w *warder) proposeDestTx(tx *orm.CrossTransaction) (interface{}, string, error) {
171         switch tx.Chain.Name {
172         case "bytom":
173                 return w.buildSidechainTx(tx)
174         case "vapor":
175                 return w.buildMainchainTx(tx)
176         default:
177                 return nil, "", errors.New("unknown source chain")
178         }
179 }
180
181 func (w *warder) buildSidechainTx(ormTx *orm.CrossTransaction) (*vaporTypes.Tx, string, error) {
182         destTxData := &vaporTypes.TxData{Version: 1, TimeRange: 0}
183         muxID := &vaporBc.Hash{}
184         if err := muxID.UnmarshalText([]byte(ormTx.SourceMuxID)); err != nil {
185                 return nil, "", errors.Wrap(err, "Unmarshal muxID")
186         }
187
188         for _, req := range ormTx.Reqs {
189                 // getAsset from assetStore instead of preload asset, in order to save db query overload
190                 asset, err := w.assetStore.GetByOrmID(req.AssetID)
191                 if err != nil {
192                         return nil, "", errors.Wrap(err, "get asset by ormAsset ID")
193                 }
194
195                 assetID := &vaporBc.AssetID{}
196                 if err := assetID.UnmarshalText([]byte(asset.AssetID)); err != nil {
197                         return nil, "", errors.Wrap(err, "Unmarshal muxID")
198                 }
199
200                 rawDefinitionByte, err := hex.DecodeString(asset.RawDefinitionByte)
201                 if err != nil {
202                         return nil, "", errors.Wrap(err, "decode rawDefinitionByte")
203                 }
204
205                 issuanceProgramByte, err := hex.DecodeString(asset.IssuanceProgram)
206                 if err != nil {
207                         return nil, "", errors.Wrap(err, "decode issuanceProgramByte")
208                 }
209
210                 input := vaporTypes.NewCrossChainInput(nil, *muxID, *assetID, req.AssetAmount, req.SourcePos, 1, rawDefinitionByte, issuanceProgramByte)
211                 destTxData.Inputs = append(destTxData.Inputs, input)
212
213                 controlProgram, err := hex.DecodeString(req.Script)
214                 if err != nil {
215                         return nil, "", errors.Wrap(err, "decode req.Script")
216                 }
217
218                 output := vaporTypes.NewIntraChainOutput(*assetID, req.AssetAmount, controlProgram)
219                 destTxData.Outputs = append(destTxData.Outputs, output)
220         }
221
222         destTx := vaporTypes.NewTx(*destTxData)
223         w.addInputWitness(destTx)
224
225         if err := w.db.Model(&orm.CrossTransaction{}).
226                 Where(&orm.CrossTransaction{ID: ormTx.ID}).
227                 UpdateColumn(&orm.CrossTransaction{
228                         DestTxHash: sql.NullString{destTx.ID.String(), true},
229                 }).Error; err != nil {
230                 return nil, "", err
231         }
232
233         return destTx, destTx.ID.String(), nil
234 }
235
236 func (w *warder) buildMainchainTx(ormTx *orm.CrossTransaction) (*btmTypes.Tx, string, error) {
237         return nil, "", errors.New("buildMainchainTx not implemented yet")
238 }
239
240 // tx is a pointer to types.Tx, so the InputArguments can be set and be valid afterward
241 func (w *warder) addInputWitness(tx interface{}) {
242         switch tx := tx.(type) {
243         case *vaporTypes.Tx:
244                 args := [][]byte{w.fedProg}
245                 for i := range tx.Inputs {
246                         tx.SetInputArguments(uint32(i), args)
247                 }
248
249         case *btmTypes.Tx:
250                 args := [][]byte{util.SegWitWrap(w.fedProg)}
251                 for i := range tx.Inputs {
252                         tx.SetInputArguments(uint32(i), args)
253                 }
254         }
255 }
256
257 func (w *warder) initDestTxSigns(destTx interface{}, ormTx *orm.CrossTransaction) error {
258         for i := 1; i <= len(w.remotes)+1; i++ {
259                 if err := w.db.Create(&orm.CrossTransactionSign{
260                         CrossTransactionID: ormTx.ID,
261                         WarderID:           uint8(i),
262                         Status:             common.CrossTxSignPendingStatus,
263                 }).Error; err != nil {
264                         return err
265                 }
266         }
267
268         return w.db.Model(&orm.CrossTransaction{}).
269                 Where(&orm.CrossTransaction{ID: ormTx.ID}).
270                 UpdateColumn(&orm.CrossTransaction{
271                         Status: common.CrossTxPendingStatus,
272                 }).Error
273 }
274
275 // TODO:
276 func (w *warder) signDestTx(destTx interface{}, ormTx *orm.CrossTransaction) error {
277         if ormTx.Status != common.CrossTxPendingStatus || !ormTx.DestTxHash.Valid {
278                 return errors.New("cross-chain tx status error")
279         }
280
281         return nil
282 }
283
284 func (w *warder) getSignData(destTx interface{}) ([]string, error) {
285         var signData []string
286
287         switch destTx := destTx.(type) {
288         case *vaporTypes.Tx:
289                 signData = make([]string, len(destTx.Inputs))
290                 for i := range destTx.Inputs {
291                         signHash := destTx.SigHash(uint32(i))
292                         signData[i] = signHash.String()
293                 }
294
295         case *btmTypes.Tx:
296                 signData = make([]string, len(destTx.Inputs))
297                 for i := range destTx.Inputs {
298                         signHash := destTx.SigHash(uint32(i))
299                         signData[i] = signHash.String()
300                 }
301
302         default:
303                 return []string{}, errors.New("unknown tx type")
304         }
305
306         return signData, nil
307 }
308
309 // TODO:
310 func (w *warder) attachSignsForTx(destTx interface{}, ormTx *orm.CrossTransaction, position uint8, signs string) {
311 }
312
313 // TODO:
314 func (w *warder) isTxSignsReachQuorum(destTx interface{}) bool {
315         return false
316 }
317
318 func (w *warder) isLeader() bool {
319         return w.position == 1
320 }
321
322 func (w *warder) submitTx(destTx interface{}) (string, error) {
323         switch tx := destTx.(type) {
324         case *btmTypes.Tx:
325                 return w.mainchainNode.SubmitTx(tx)
326         case *vaporTypes.Tx:
327                 return w.sidechainNode.SubmitTx(tx)
328         default:
329                 return "", errors.New("unknown destTx type")
330         }
331 }
332
333 func (w *warder) updateSubmission(ormTx *orm.CrossTransaction) error {
334         if err := w.db.Model(&orm.CrossTransaction{}).
335                 Where(&orm.CrossTransaction{ID: ormTx.ID}).
336                 UpdateColumn(&orm.CrossTransaction{
337                         Status: common.CrossTxSubmittedStatus,
338                 }).Error; err != nil {
339                 return err
340         }
341
342         for _, remote := range w.remotes {
343                 remote.NotifySubmission(ormTx)
344         }
345         return nil
346 }