OSDN Git Service

fix getSigns
[bytom/vapor.git] / federation / warder.go
1 package federation
2
3 import (
4         "database/sql"
5         "encoding/hex"
6         "encoding/json"
7         "time"
8
9         btmTypes "github.com/bytom/protocol/bc/types"
10         "github.com/jinzhu/gorm"
11         log "github.com/sirupsen/logrus"
12
13         "github.com/vapor/crypto/ed25519/chainkd"
14         "github.com/vapor/errors"
15         "github.com/vapor/federation/common"
16         "github.com/vapor/federation/config"
17         "github.com/vapor/federation/database"
18         "github.com/vapor/federation/database/orm"
19         "github.com/vapor/federation/service"
20         "github.com/vapor/federation/util"
21         vaporBc "github.com/vapor/protocol/bc"
22         vaporTypes "github.com/vapor/protocol/bc/types"
23 )
24
25 var collectInterval = 5 * time.Second
26
27 var errUnknownTxType = errors.New("unknown tx type")
28
29 type warder struct {
30         db            *gorm.DB
31         assetStore    *database.AssetStore
32         txCh          chan *orm.CrossTransaction
33         fedProg       []byte
34         position      uint8
35         xpub          chainkd.XPub
36         xprv          chainkd.XPrv
37         mainchainNode *service.Node
38         sidechainNode *service.Node
39         remotes       []*service.Warder
40 }
41
42 func NewWarder(db *gorm.DB, assetStore *database.AssetStore, cfg *config.Config) *warder {
43         local, remotes := parseWarders(cfg)
44         return &warder{
45                 db:            db,
46                 assetStore:    assetStore,
47                 txCh:          make(chan *orm.CrossTransaction),
48                 fedProg:       util.ParseFedProg(cfg.Warders, cfg.Quorum),
49                 position:      local.Position,
50                 xpub:          local.XPub,
51                 xprv:          string2xprv(xprvStr),
52                 mainchainNode: service.NewNode(cfg.Mainchain.Upstream),
53                 sidechainNode: service.NewNode(cfg.Sidechain.Upstream),
54                 remotes:       remotes,
55         }
56 }
57
58 func parseWarders(cfg *config.Config) (*service.Warder, []*service.Warder) {
59         var local *service.Warder
60         var remotes []*service.Warder
61         for _, warderCfg := range cfg.Warders {
62                 if warderCfg.IsLocal {
63                         local = service.NewWarder(&warderCfg)
64                 } else {
65                         remote := service.NewWarder(&warderCfg)
66                         remotes = append(remotes, remote)
67                 }
68         }
69
70         if local == nil {
71                 log.Fatal("none local warder set")
72         }
73
74         return local, remotes
75 }
76
77 func (w *warder) Run() {
78         ticker := time.NewTicker(collectInterval)
79         for ; true; <-ticker.C {
80                 txs := []*orm.CrossTransaction{}
81                 if err := w.db.Preload("Chain").Preload("Reqs").
82                         // do not use "Where(&orm.CrossTransaction{Status: common.CrossTxInitiatedStatus})" directly,
83                         // otherwise the field "status" will be ignored
84                         Model(&orm.CrossTransaction{}).Where("status = ?", common.CrossTxInitiatedStatus).
85                         Find(&txs).Error; err == gorm.ErrRecordNotFound {
86                         continue
87                 } else if err != nil {
88                         log.Warnln("collectPendingTx", err)
89                 }
90
91                 for _, tx := range txs {
92                         go w.tryProcessCrossTx(tx)
93                 }
94         }
95 }
96
97 func (w *warder) tryProcessCrossTx(ormTx *orm.CrossTransaction) error {
98         dbTx := w.db.Begin()
99         if err := w.processCrossTx(ormTx); err != nil {
100                 dbTx.Rollback()
101                 return err
102         }
103
104         return dbTx.Commit().Error
105 }
106
107 func (w *warder) processCrossTx(ormTx *orm.CrossTransaction) error {
108         if err := w.validateCrossTx(ormTx); err != nil {
109                 log.Warnln("invalid cross-chain tx", ormTx)
110                 return err
111         }
112
113         destTx, destTxID, err := w.proposeDestTx(ormTx)
114         if err != nil {
115                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("proposeDestTx")
116                 return err
117         }
118
119         if err := w.initDestTxSigns(destTx, ormTx); err != nil {
120                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("initDestTxSigns")
121                 return err
122         }
123
124         var signersSigns [][][]byte
125
126         signerSigns, err := w.getSigns(destTx, ormTx)
127         if err != nil {
128                 log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("getSigns")
129                 return err
130         }
131
132         // TODO: pass ref?
133         signersSigns = w.attachSignsForTx( /*destTx,*/ ormTx, w.position, signerSigns)
134
135         for _, remote := range w.remotes {
136                 signerSigns, err := remote.RequestSigns(destTx, ormTx)
137                 if err != nil {
138                         log.WithFields(log.Fields{"err": err, "remote": remote, "cross-chain tx": ormTx}).Warnln("RequestSign")
139                         return err
140                 }
141
142                 // TODO: pass ref?
143                 signersSigns = w.attachSignsForTx( /*destTx,*/ ormTx, remote.Position, signerSigns)
144         }
145
146         if w.isTxSignsReachQuorum(signersSigns) && w.isLeader() {
147                 // TODO: check err
148                 w.finalizeTx(destTx, signersSigns)
149                 submittedTxID, err := w.submitTx(destTx)
150                 if err != nil {
151                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "dest tx": destTx}).Warnln("submitTx")
152                         return err
153                 }
154
155                 if submittedTxID != destTxID {
156                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "builtTx ID": destTxID, "submittedTx ID": submittedTxID}).Warnln("submitTx ID mismatch")
157                         return err
158                 }
159
160                 if err := w.updateSubmission(ormTx); err != nil {
161                         log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("updateSubmission")
162                         return err
163                 }
164         }
165
166         return nil
167 }
168
169 func (w *warder) validateCrossTx(tx *orm.CrossTransaction) error {
170         switch tx.Status {
171         case common.CrossTxRejectedStatus:
172                 return errors.New("cross-chain tx rejected")
173         case common.CrossTxSubmittedStatus:
174                 return errors.New("cross-chain tx submitted")
175         case common.CrossTxCompletedStatus:
176                 return errors.New("cross-chain tx completed")
177         default:
178                 return nil
179         }
180 }
181
182 func (w *warder) proposeDestTx(tx *orm.CrossTransaction) (interface{}, string, error) {
183         switch tx.Chain.Name {
184         case "bytom":
185                 return w.buildSidechainTx(tx)
186         case "vapor":
187                 return w.buildMainchainTx(tx)
188         default:
189                 return nil, "", errors.New("unknown source chain")
190         }
191 }
192
193 func (w *warder) buildSidechainTx(ormTx *orm.CrossTransaction) (*vaporTypes.Tx, string, error) {
194         destTxData := &vaporTypes.TxData{Version: 1, TimeRange: 0}
195         muxID := &vaporBc.Hash{}
196         if err := muxID.UnmarshalText([]byte(ormTx.SourceMuxID)); err != nil {
197                 return nil, "", errors.Wrap(err, "Unmarshal muxID")
198         }
199
200         for _, req := range ormTx.Reqs {
201                 // getAsset from assetStore instead of preload asset, in order to save db query overload
202                 asset, err := w.assetStore.GetByOrmID(req.AssetID)
203                 if err != nil {
204                         return nil, "", errors.Wrap(err, "get asset by ormAsset ID")
205                 }
206
207                 assetID := &vaporBc.AssetID{}
208                 if err := assetID.UnmarshalText([]byte(asset.AssetID)); err != nil {
209                         return nil, "", errors.Wrap(err, "Unmarshal muxID")
210                 }
211
212                 rawDefinitionByte, err := hex.DecodeString(asset.RawDefinitionByte)
213                 if err != nil {
214                         return nil, "", errors.Wrap(err, "decode rawDefinitionByte")
215                 }
216
217                 issuanceProgramByte, err := hex.DecodeString(asset.IssuanceProgram)
218                 if err != nil {
219                         return nil, "", errors.Wrap(err, "decode issuanceProgramByte")
220                 }
221
222                 input := vaporTypes.NewCrossChainInput(nil, *muxID, *assetID, req.AssetAmount, req.SourcePos, 1, rawDefinitionByte, issuanceProgramByte)
223                 destTxData.Inputs = append(destTxData.Inputs, input)
224
225                 controlProgram, err := hex.DecodeString(req.Script)
226                 if err != nil {
227                         return nil, "", errors.Wrap(err, "decode req.Script")
228                 }
229
230                 output := vaporTypes.NewIntraChainOutput(*assetID, req.AssetAmount, controlProgram)
231                 destTxData.Outputs = append(destTxData.Outputs, output)
232         }
233
234         destTx := vaporTypes.NewTx(*destTxData)
235         w.addInputWitness(destTx)
236
237         if err := w.db.Model(&orm.CrossTransaction{}).
238                 Where(&orm.CrossTransaction{ID: ormTx.ID}).
239                 UpdateColumn(&orm.CrossTransaction{
240                         DestTxHash: sql.NullString{destTx.ID.String(), true},
241                 }).Error; err != nil {
242                 return nil, "", err
243         }
244
245         return destTx, destTx.ID.String(), nil
246 }
247
248 func (w *warder) buildMainchainTx(ormTx *orm.CrossTransaction) (*btmTypes.Tx, string, error) {
249         return nil, "", errors.New("buildMainchainTx not implemented yet")
250 }
251
252 // tx is a pointer to types.Tx, so the InputArguments can be set and be valid afterward
253 func (w *warder) addInputWitness(tx interface{}) {
254         switch tx := tx.(type) {
255         case *vaporTypes.Tx:
256                 args := [][]byte{w.fedProg}
257                 for i := range tx.Inputs {
258                         tx.SetInputArguments(uint32(i), args)
259                 }
260
261         case *btmTypes.Tx:
262                 args := [][]byte{util.SegWitWrap(w.fedProg)}
263                 for i := range tx.Inputs {
264                         tx.SetInputArguments(uint32(i), args)
265                 }
266         }
267 }
268
269 func (w *warder) initDestTxSigns(destTx interface{}, ormTx *orm.CrossTransaction) error {
270         for i := 1; i <= len(w.remotes)+1; i++ {
271                 if err := w.db.Create(&orm.CrossTransactionSign{
272                         CrossTransactionID: ormTx.ID,
273                         WarderID:           uint8(i),
274                         Status:             common.CrossTxSignPendingStatus,
275                 }).Error; err != nil {
276                         return err
277                 }
278         }
279
280         return w.db.Model(&orm.CrossTransaction{}).
281                 Where(&orm.CrossTransaction{ID: ormTx.ID}).
282                 UpdateColumn(&orm.CrossTransaction{
283                         Status: common.CrossTxPendingStatus,
284                 }).Error
285 }
286
287 func (w *warder) getSignData(destTx interface{}) ([][]byte, error) {
288         var signData [][]byte
289
290         switch destTx := destTx.(type) {
291         case *vaporTypes.Tx:
292                 signData = make([][]byte, len(destTx.Inputs))
293                 for i := range destTx.Inputs {
294                         signHash := destTx.SigHash(uint32(i))
295                         signData[i] = signHash.Bytes()
296                 }
297
298         case *btmTypes.Tx:
299                 signData = make([][]byte, len(destTx.Inputs))
300                 for i := range destTx.Inputs {
301                         signHash := destTx.SigHash(uint32(i))
302                         signData[i] = signHash.Bytes()
303                 }
304
305         default:
306                 return [][]byte{}, errUnknownTxType
307         }
308
309         return signData, nil
310 }
311
312 func (w *warder) getSigns(destTx interface{}, ormTx *orm.CrossTransaction) ([][]byte, error) {
313         if ormTx.Status != common.CrossTxPendingStatus || !ormTx.DestTxHash.Valid {
314                 return nil, errors.New("cross-chain tx status error")
315         }
316
317         signData, err := w.getSignData(destTx)
318         if err != nil {
319                 return nil, errors.New("getSignData")
320         }
321
322         var signs [][]byte
323         for _, data := range signData {
324                 var b [32]byte
325                 copy(b[:], data)
326                 // vaporBc.Hash & btmBc.Hash are marshaled in the same way
327                 msg := vaporBc.NewHash(b)
328                 sign := w.xprv.Sign([]byte(msg.String()))
329                 signs = append(signs, sign)
330         }
331
332         return signs, nil
333 }
334
335 // TODO:
336 func (w *warder) attachSignsForTx(destTx interface{}, ormTx *orm.CrossTransaction, position uint8, signs []string) error {
337         var inputsLen int
338         switch destTx := destTx.(type) {
339         case *vaporTypes.Tx:
340                 inputsLen = len(destTx.Inputs)
341         case *btmTypes.Tx:
342                 inputsLen = len(destTx.Inputs)
343         default:
344                 return errUnknownTxType
345         }
346
347         // finalize tx?
348
349         signWitness := make([][]string, inputsLen)
350
351         b, err := json.Marshal(signs)
352         if err != nil {
353                 return errors.Wrap(err, "marshal signs")
354         }
355
356         return w.db.Model(&orm.CrossTransactionSign{}).
357                 Where(&orm.CrossTransactionSign{
358                         CrossTransactionID: ormTx.ID,
359                         WarderID:           w.position,
360                 }).
361                 UpdateColumn(&orm.CrossTransactionSign{
362                         Signatures: string(b),
363                         Status:     common.CrossTxSignCompletedStatus,
364                 }).Error
365 }
366
367 // TODO:
368 func (w *warder) isTxSignsReachQuorum(destTx interface{}) bool {
369         return false
370 }
371
372 func (w *warder) isLeader() bool {
373         return w.position == 1
374 }
375
376 func (w *warder) submitTx(destTx interface{}) (string, error) {
377         switch tx := destTx.(type) {
378         case *btmTypes.Tx:
379                 return w.mainchainNode.SubmitTx(tx)
380         case *vaporTypes.Tx:
381                 return w.sidechainNode.SubmitTx(tx)
382         default:
383                 return "", errUnknownTxType
384         }
385 }
386
387 func (w *warder) updateSubmission(ormTx *orm.CrossTransaction) error {
388         if err := w.db.Model(&orm.CrossTransaction{}).
389                 Where(&orm.CrossTransaction{ID: ormTx.ID}).
390                 UpdateColumn(&orm.CrossTransaction{
391                         Status: common.CrossTxSubmittedStatus,
392                 }).Error; err != nil {
393                 return err
394         }
395
396         for _, remote := range w.remotes {
397                 remote.NotifySubmission(ormTx)
398         }
399         return nil
400 }