OSDN Git Service

a7689fda085b6e0ac63f1fcf9197997e9e2a5167
[bytom/vapor.git] / federation / warder.go
1 package federation
2
3 import (
4         "database/sql"
5         "encoding/hex"
6         "encoding/json"
7         "time"
8
9         btmTypes "github.com/bytom/protocol/bc/types"
10         "github.com/jinzhu/gorm"
11         log "github.com/sirupsen/logrus"
12
13         "github.com/vapor/crypto/ed25519/chainkd"
14         "github.com/vapor/errors"
15         "github.com/vapor/federation/common"
16         "github.com/vapor/federation/config"
17         "github.com/vapor/federation/database"
18         "github.com/vapor/federation/database/orm"
19         "github.com/vapor/federation/service"
20         "github.com/vapor/federation/util"
21         vaporBc "github.com/vapor/protocol/bc"
22         vaporTypes "github.com/vapor/protocol/bc/types"
23 )
24
25 var collectInterval = 5 * time.Second
26
27 var errUnknownTxType = errors.New("unknown tx type")
28
29 type warder struct {
30         db            *gorm.DB
31         assetStore    *database.AssetStore
32         txCh          chan *orm.CrossTransaction
33         fedProg       []byte
34         position      uint8
35         xpub          chainkd.XPub
36         xprv          chainkd.XPrv
37         mainchainNode *service.Node
38         sidechainNode *service.Node
39         remotes       []*service.Warder
40 }
41
42 func NewWarder(db *gorm.DB, assetStore *database.AssetStore, cfg *config.Config) *warder {
43         local, remotes := parseWarders(cfg)
44         return &warder{
45                 db:            db,
46                 assetStore:    assetStore,
47                 txCh:          make(chan *orm.CrossTransaction),
48                 fedProg:       util.ParseFedProg(cfg.Warders, cfg.Quorum),
49                 position:      local.Position,
50                 xpub:          local.XPub,
51                 xprv:          string2xprv(xprvStr),
52                 mainchainNode: service.NewNode(cfg.Mainchain.Upstream),
53                 sidechainNode: service.NewNode(cfg.Sidechain.Upstream),
54                 remotes:       remotes,
55         }
56 }
57
58 func parseWarders(cfg *config.Config) (*service.Warder, []*service.Warder) {
59         var local *service.Warder
60         var remotes []*service.Warder
61         for _, warderCfg := range cfg.Warders {
62                 if warderCfg.IsLocal {
63                         local = service.NewWarder(&warderCfg)
64                 } else {
65                         remote := service.NewWarder(&warderCfg)
66                         remotes = append(remotes, remote)
67                 }
68         }
69
70         if local == nil {
71                 log.Fatal("none local warder set")
72         }
73
74         return local, remotes
75 }
76
77 func (w *warder) Run() {
78         ticker := time.NewTicker(collectInterval)
79         for ; true; <-ticker.C {
80                 txs := []*orm.CrossTransaction{}
81                 if err := w.db.Preload("Chain").Preload("Reqs").
82                         // do not use "Where(&orm.CrossTransaction{Status: common.CrossTxInitiatedStatus})" directly,
83                         // otherwise the field "status" will be ignored
84                         Model(&orm.CrossTransaction{}).Where("status = ?", common.CrossTxInitiatedStatus).
85                         Find(&txs).Error; err == gorm.ErrRecordNotFound {
86                         continue
87                 } else if err != nil {
88                         log.Warnln("collectPendingTx", err)
89                 }
90
91                 for _, tx := range txs {
92                         go w.tryProcessCrossTx(tx)
93                 }
94         }
95 }
96
97 func (w *warder) tryProcessCrossTx(ormTx *orm.CrossTransaction) error {
98         dbTx := w.db.Begin()
99         if err := w.processCrossTx(ormTx); err != nil {
100                 dbTx.Rollback()
101                 return err
102         }
103
104         return dbTx.Commit().Error
105 }
106
107 func (w *warder) processCrossTx(ormTx *orm.CrossTransaction) error {
108         if err := w.validateCrossTx(ormTx); err != nil {
109                 log.Warnln("invalid cross-chain tx", ormTx)
110                 return err
111         }
112
113         destTx, destTxID, err := w.proposeDestTx(ormTx)
114         if err != nil {
115                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("proposeDestTx")
116                 return err
117         }
118
119         if err := w.initDestTxSigns(destTx, ormTx); err != nil {
120                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("initDestTxSigns")
121                 return err
122         }
123
124         signs, err := w.getSigns(destTx, ormTx)
125         if err != nil {
126                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("getSigns")
127                 return err
128         }
129
130         w.attachSignsForTx(destTx, ormTx, w.position, signs)
131
132         for _, remote := range w.remotes {
133                 signs, err := remote.RequestSigns(destTx, ormTx)
134                 if err != nil {
135                         log.WithFields(log.Fields{"err": err, "remote": remote, "cross-chain tx": ormTx}).Warnln("RequestSign")
136                         return err
137                 }
138
139                 w.attachSignsForTx(destTx, ormTx, remote.Position, signs)
140         }
141
142         if w.isTxSignsReachQuorum(destTx) && w.isLeader() {
143                 submittedTxID, err := w.submitTx(destTx)
144                 if err != nil {
145                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "dest tx": destTx}).Warnln("submitTx")
146                         return err
147                 }
148
149                 if submittedTxID != destTxID {
150                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "builtTx ID": destTxID, "submittedTx ID": submittedTxID}).Warnln("submitTx ID mismatch")
151                         return err
152                 }
153
154                 if err := w.updateSubmission(ormTx); err != nil {
155                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("updateSubmission")
156                         return err
157                 }
158         }
159
160         return nil
161 }
162
163 func (w *warder) validateCrossTx(tx *orm.CrossTransaction) error {
164         switch tx.Status {
165         case common.CrossTxRejectedStatus:
166                 return errors.New("cross-chain tx rejected")
167         case common.CrossTxSubmittedStatus:
168                 return errors.New("cross-chain tx submitted")
169         case common.CrossTxCompletedStatus:
170                 return errors.New("cross-chain tx completed")
171         default:
172                 return nil
173         }
174 }
175
176 func (w *warder) proposeDestTx(tx *orm.CrossTransaction) (interface{}, string, error) {
177         switch tx.Chain.Name {
178         case "bytom":
179                 return w.buildSidechainTx(tx)
180         case "vapor":
181                 return w.buildMainchainTx(tx)
182         default:
183                 return nil, "", errors.New("unknown source chain")
184         }
185 }
186
187 func (w *warder) buildSidechainTx(ormTx *orm.CrossTransaction) (*vaporTypes.Tx, string, error) {
188         destTxData := &vaporTypes.TxData{Version: 1, TimeRange: 0}
189         muxID := &vaporBc.Hash{}
190         if err := muxID.UnmarshalText([]byte(ormTx.SourceMuxID)); err != nil {
191                 return nil, "", errors.Wrap(err, "Unmarshal muxID")
192         }
193
194         for _, req := range ormTx.Reqs {
195                 // getAsset from assetStore instead of preload asset, in order to save db query overload
196                 asset, err := w.assetStore.GetByOrmID(req.AssetID)
197                 if err != nil {
198                         return nil, "", errors.Wrap(err, "get asset by ormAsset ID")
199                 }
200
201                 assetID := &vaporBc.AssetID{}
202                 if err := assetID.UnmarshalText([]byte(asset.AssetID)); err != nil {
203                         return nil, "", errors.Wrap(err, "Unmarshal muxID")
204                 }
205
206                 rawDefinitionByte, err := hex.DecodeString(asset.RawDefinitionByte)
207                 if err != nil {
208                         return nil, "", errors.Wrap(err, "decode rawDefinitionByte")
209                 }
210
211                 issuanceProgramByte, err := hex.DecodeString(asset.IssuanceProgram)
212                 if err != nil {
213                         return nil, "", errors.Wrap(err, "decode issuanceProgramByte")
214                 }
215
216                 input := vaporTypes.NewCrossChainInput(nil, *muxID, *assetID, req.AssetAmount, req.SourcePos, 1, rawDefinitionByte, issuanceProgramByte)
217                 destTxData.Inputs = append(destTxData.Inputs, input)
218
219                 controlProgram, err := hex.DecodeString(req.Script)
220                 if err != nil {
221                         return nil, "", errors.Wrap(err, "decode req.Script")
222                 }
223
224                 output := vaporTypes.NewIntraChainOutput(*assetID, req.AssetAmount, controlProgram)
225                 destTxData.Outputs = append(destTxData.Outputs, output)
226         }
227
228         destTx := vaporTypes.NewTx(*destTxData)
229         w.addInputWitness(destTx)
230
231         if err := w.db.Model(&orm.CrossTransaction{}).
232                 Where(&orm.CrossTransaction{ID: ormTx.ID}).
233                 UpdateColumn(&orm.CrossTransaction{
234                         DestTxHash: sql.NullString{destTx.ID.String(), true},
235                 }).Error; err != nil {
236                 return nil, "", err
237         }
238
239         return destTx, destTx.ID.String(), nil
240 }
241
242 func (w *warder) buildMainchainTx(ormTx *orm.CrossTransaction) (*btmTypes.Tx, string, error) {
243         return nil, "", errors.New("buildMainchainTx not implemented yet")
244 }
245
246 // tx is a pointer to types.Tx, so the InputArguments can be set and be valid afterward
247 func (w *warder) addInputWitness(tx interface{}) {
248         switch tx := tx.(type) {
249         case *vaporTypes.Tx:
250                 args := [][]byte{w.fedProg}
251                 for i := range tx.Inputs {
252                         tx.SetInputArguments(uint32(i), args)
253                 }
254
255         case *btmTypes.Tx:
256                 args := [][]byte{util.SegWitWrap(w.fedProg)}
257                 for i := range tx.Inputs {
258                         tx.SetInputArguments(uint32(i), args)
259                 }
260         }
261 }
262
263 func (w *warder) initDestTxSigns(destTx interface{}, ormTx *orm.CrossTransaction) error {
264         for i := 1; i <= len(w.remotes)+1; i++ {
265                 if err := w.db.Create(&orm.CrossTransactionSign{
266                         CrossTransactionID: ormTx.ID,
267                         WarderID:           uint8(i),
268                         Status:             common.CrossTxSignPendingStatus,
269                 }).Error; err != nil {
270                         return err
271                 }
272         }
273
274         return w.db.Model(&orm.CrossTransaction{}).
275                 Where(&orm.CrossTransaction{ID: ormTx.ID}).
276                 UpdateColumn(&orm.CrossTransaction{
277                         Status: common.CrossTxPendingStatus,
278                 }).Error
279 }
280
281 func (w *warder) getSignData(destTx interface{}) ([]string, error) {
282         var signData []string
283
284         switch destTx := destTx.(type) {
285         case *vaporTypes.Tx:
286                 signData = make([]string, len(destTx.Inputs))
287                 for i := range destTx.Inputs {
288                         signHash := destTx.SigHash(uint32(i))
289                         signData[i] = signHash.String()
290                 }
291
292         case *btmTypes.Tx:
293                 signData = make([]string, len(destTx.Inputs))
294                 for i := range destTx.Inputs {
295                         signHash := destTx.SigHash(uint32(i))
296                         signData[i] = signHash.String()
297                 }
298
299         default:
300                 return []string{}, errUnknownTxType
301         }
302
303         return signData, nil
304 }
305
306 func (w *warder) getSigns(destTx interface{}, ormTx *orm.CrossTransaction) ([]string, error) {
307         if ormTx.Status != common.CrossTxPendingStatus || !ormTx.DestTxHash.Valid {
308                 return nil, errors.New("cross-chain tx status error")
309         }
310
311         signData, err := w.getSignData(destTx)
312         if err != nil {
313                 return nil, errors.New("getSignData")
314         }
315
316         var signs []string
317         for _, data := range signData {
318                 var sign []byte
319                 // vaporBc.Hash & btmBc.Hash are marshaled in the same way
320                 msg := &vaporBc.Hash{}
321                 if err := msg.UnmarshalText([]byte(data)); err != nil {
322                         return nil, errors.Wrap(err, "Unmarshal signData")
323                 }
324
325                 sign = w.xprv.Sign([]byte(msg.String()))
326                 signs = append(signs, hex.EncodeToString(sign))
327         }
328
329         return signs, nil
330 }
331
332 // TODO:
333 func (w *warder) attachSignsForTx(destTx interface{}, ormTx *orm.CrossTransaction, position uint8, signs []string) error {
334         var inputsLen int
335         switch destTx := destTx.(type) {
336         case *vaporTypes.Tx:
337                 inputsLen = len(destTx.Inputs)
338         case *btmTypes.Tx:
339                 inputsLen = len(destTx.Inputs)
340         default:
341                 return errUnknownTxType
342         }
343
344         signWitness := make([][]string, inputsLen)
345
346         b, err := json.Marshal(signs)
347         if err != nil {
348                 return errors.Wrap(err, "marshal signs")
349         }
350
351         return w.db.Model(&orm.CrossTransactionSign{}).
352                 Where(&orm.CrossTransactionSign{
353                         CrossTransactionID: ormTx.ID,
354                         WarderID:           w.position,
355                 }).
356                 UpdateColumn(&orm.CrossTransactionSign{
357                         Signatures: string(b),
358                         Status:     common.CrossTxSignCompletedStatus,
359                 }).Error
360 }
361
362 // TODO:
363 func (w *warder) isTxSignsReachQuorum(destTx interface{}) bool {
364         return false
365 }
366
367 func (w *warder) isLeader() bool {
368         return w.position == 1
369 }
370
371 func (w *warder) submitTx(destTx interface{}) (string, error) {
372         switch tx := destTx.(type) {
373         case *btmTypes.Tx:
374                 return w.mainchainNode.SubmitTx(tx)
375         case *vaporTypes.Tx:
376                 return w.sidechainNode.SubmitTx(tx)
377         default:
378                 return "", errUnknownTxType
379         }
380 }
381
382 func (w *warder) updateSubmission(ormTx *orm.CrossTransaction) error {
383         if err := w.db.Model(&orm.CrossTransaction{}).
384                 Where(&orm.CrossTransaction{ID: ormTx.ID}).
385                 UpdateColumn(&orm.CrossTransaction{
386                         Status: common.CrossTxSubmittedStatus,
387                 }).Error; err != nil {
388                 return err
389         }
390
391         for _, remote := range w.remotes {
392                 remote.NotifySubmission(ormTx)
393         }
394         return nil
395 }