OSDN Git Service

init finalizeTx
[bytom/vapor.git] / federation / warder.go
index 960df34..4e4c407 100644 (file)
@@ -2,6 +2,8 @@ package federation
 
 import (
        "database/sql"
+       "encoding/hex"
+       "encoding/json"
        "time"
 
        btmTypes "github.com/bytom/protocol/bc/types"
@@ -12,31 +14,41 @@ import (
        "github.com/vapor/errors"
        "github.com/vapor/federation/common"
        "github.com/vapor/federation/config"
+       "github.com/vapor/federation/database"
        "github.com/vapor/federation/database/orm"
        "github.com/vapor/federation/service"
+       "github.com/vapor/federation/util"
+       vaporBc "github.com/vapor/protocol/bc"
        vaporTypes "github.com/vapor/protocol/bc/types"
 )
 
 var collectInterval = 5 * time.Second
 
+var errUnknownTxType = errors.New("unknown tx type")
+
 type warder struct {
-       position       uint8
-       xpub           chainkd.XPub
-       colletInterval time.Duration
-       db             *gorm.DB
-       txCh           chan *orm.CrossTransaction
-       mainchainNode  *service.Node
-       sidechainNode  *service.Node
-       remotes        []*service.Warder
+       db            *gorm.DB
+       assetStore    *database.AssetStore
+       txCh          chan *orm.CrossTransaction
+       fedProg       []byte
+       position      uint8
+       xpub          chainkd.XPub
+       xprv          chainkd.XPrv
+       mainchainNode *service.Node
+       sidechainNode *service.Node
+       remotes       []*service.Warder
 }
 
-func NewWarder(cfg *config.Config, db *gorm.DB) *warder {
+func NewWarder(db *gorm.DB, assetStore *database.AssetStore, cfg *config.Config) *warder {
        local, remotes := parseWarders(cfg)
        return &warder{
-               position:      local.Position,
-               xpub:          local.XPub,
                db:            db,
+               assetStore:    assetStore,
                txCh:          make(chan *orm.CrossTransaction),
+               fedProg:       util.ParseFedProg(cfg.Warders, cfg.Quorum),
+               position:      local.Position,
+               xpub:          local.XPub,
+               xprv:          string2xprv(xprvStr),
                mainchainNode: service.NewNode(cfg.Mainchain.Upstream),
                sidechainNode: service.NewNode(cfg.Sidechain.Upstream),
                remotes:       remotes,
@@ -50,8 +62,8 @@ func parseWarders(cfg *config.Config) (*service.Warder, []*service.Warder) {
                if warderCfg.IsLocal {
                        local = service.NewWarder(&warderCfg)
                } else {
-                       remoteWarder := service.NewWarder(&warderCfg)
-                       remotes = append(remotes, remoteWarder)
+                       remote := service.NewWarder(&warderCfg)
+                       remotes = append(remotes, remote)
                }
        }
 
@@ -63,18 +75,13 @@ func parseWarders(cfg *config.Config) (*service.Warder, []*service.Warder) {
 }
 
 func (w *warder) Run() {
-       go w.collectPendingTx()
-       go w.processCrossTxRoutine()
-}
-
-func (w *warder) collectPendingTx() {
        ticker := time.NewTicker(collectInterval)
        for ; true; <-ticker.C {
                txs := []*orm.CrossTransaction{}
                if err := w.db.Preload("Chain").Preload("Reqs").
-                       // do not use "Where(&orm.CrossTransaction{Status: common.CrossTxPendingStatus})" directly,
+                       // do not use "Where(&orm.CrossTransaction{Status: common.CrossTxInitiatedStatus})" directly,
                        // otherwise the field "status" will be ignored
-                       Model(&orm.CrossTransaction{}).Where("status = ?", common.CrossTxPendingStatus).
+                       Model(&orm.CrossTransaction{}).Where("status = ?", common.CrossTxInitiatedStatus).
                        Find(&txs).Error; err == gorm.ErrRecordNotFound {
                        continue
                } else if err != nil {
@@ -82,58 +89,81 @@ func (w *warder) collectPendingTx() {
                }
 
                for _, tx := range txs {
-                       w.txCh <- tx
+                       go w.tryProcessCrossTx(tx)
                }
        }
 }
 
-func (w *warder) processCrossTxRoutine() {
-       for ormTx := range w.txCh {
-               if err := w.validateCrossTx(ormTx); err != nil {
-                       log.Warnln("invalid cross-chain tx", ormTx)
-                       continue
-               }
+func (w *warder) tryProcessCrossTx(ormTx *orm.CrossTransaction) error {
+       dbTx := w.db.Begin()
+       if err := w.processCrossTx(ormTx); err != nil {
+               dbTx.Rollback()
+               return err
+       }
 
-               destTx, destTxID, err := w.proposeDestTx(ormTx)
-               if err != nil {
-                       log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("proposeDestTx")
-                       continue
-               }
+       return dbTx.Commit().Error
+}
 
-               if err := w.signDestTx(destTx, ormTx); err != nil {
-                       log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("signDestTx")
-                       continue
-               }
+func (w *warder) processCrossTx(ormTx *orm.CrossTransaction) error {
+       if err := w.validateCrossTx(ormTx); err != nil {
+               log.Warnln("invalid cross-chain tx", ormTx)
+               return err
+       }
+
+       destTx, destTxID, err := w.proposeDestTx(ormTx)
+       if err != nil {
+               log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("proposeDestTx")
+               return err
+       }
+
+       if err := w.initDestTxSigns(destTx, ormTx); err != nil {
+               log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("initDestTxSigns")
+               return err
+       }
 
-               for _, remote := range w.remotes {
-                       signs, err := remote.RequestSign(destTx, ormTx)
-                       if err != nil {
-                               log.WithFields(log.Fields{"err": err, "remote": remote, "cross-chain tx": ormTx}).Warnln("RequestSign")
-                               continue
-                       }
+       var signersSigns [][][]byte
+
+       signerSigns, err := w.getSigns(destTx, ormTx)
+       if err != nil {
+               log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("getSigns")
+               return err
+       }
 
-                       w.attachSignsForTx(destTx, ormTx, remote.Position, signs)
+       // TODO: pass ref?
+       signersSigns = w.attachSignsForTx( /*destTx,*/ ormTx, w.position, signerSigns)
+
+       for _, remote := range w.remotes {
+               signerSigns, err := remote.RequestSigns(destTx, ormTx)
+               if err != nil {
+                       log.WithFields(log.Fields{"err": err, "remote": remote, "cross-chain tx": ormTx}).Warnln("RequestSign")
+                       return err
                }
 
-               if w.isTxSignsReachQuorum(destTx) && w.isLeader() {
-                       submittedTxID, err := w.submitTx(destTx)
-                       if err != nil {
-                               log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "dest tx": destTx}).Warnln("submitTx")
-                               continue
-                       }
+               // TODO: pass ref?
+               signersSigns = w.attachSignsForTx( /*destTx,*/ ormTx, remote.Position, signerSigns)
+       }
 
-                       if submittedTxID != destTxID {
-                               log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "builtTx ID": destTxID, "submittedTx ID": submittedTxID}).Warnln("submitTx ID mismatch")
-                               continue
+       if w.isTxSignsReachQuorum(signersSigns) && w.isLeader() {
+               // TODO: check err
+               w.finalizeTx(destTx, signersSigns)
+               submittedTxID, err := w.submitTx(destTx)
+               if err != nil {
+                       log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "dest tx": destTx}).Warnln("submitTx")
+                       return err
+               }
 
-                       }
+               if submittedTxID != destTxID {
+                       log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx, "builtTx ID": destTxID, "submittedTx ID": submittedTxID}).Warnln("submitTx ID mismatch")
+                       return err
+               }
 
-                       if err := w.updateSubmission(ormTx); err != nil {
-                               log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("updateSubmission")
-                               continue
-                       }
+               if err := w.updateSubmission(ormTx); err != nil {
+                       log.WithFields(log.Fields{"err": err, "cross-chain tx": ormTx}).Warnln("updateSubmission")
+                       return err
                }
        }
+
+       return nil
 }
 
 func (w *warder) validateCrossTx(tx *orm.CrossTransaction) error {
@@ -144,18 +174,9 @@ func (w *warder) validateCrossTx(tx *orm.CrossTransaction) error {
                return errors.New("cross-chain tx submitted")
        case common.CrossTxCompletedStatus:
                return errors.New("cross-chain tx completed")
+       default:
+               return nil
        }
-
-       crossTxReqs := []*orm.CrossTransactionReq{}
-       if err := w.db.Where(&orm.CrossTransactionReq{CrossTransactionID: tx.ID}).Find(&crossTxReqs).Error; err != nil {
-               return err
-       }
-
-       if len(crossTxReqs) != len(tx.Reqs) {
-               return errors.New("cross-chain requests count mismatch")
-       }
-
-       return nil
 }
 
 func (w *warder) proposeDestTx(tx *orm.CrossTransaction) (interface{}, string, error) {
@@ -169,43 +190,178 @@ func (w *warder) proposeDestTx(tx *orm.CrossTransaction) (interface{}, string, e
        }
 }
 
-// TODO:
-func (w *warder) buildSidechainTx(tx *orm.CrossTransaction) (*vaporTypes.Tx, string, error) {
-       sidechainTx := &vaporTypes.Tx{}
+func (w *warder) buildSidechainTx(ormTx *orm.CrossTransaction) (*vaporTypes.Tx, string, error) {
+       destTxData := &vaporTypes.TxData{Version: 1, TimeRange: 0}
+       muxID := &vaporBc.Hash{}
+       if err := muxID.UnmarshalText([]byte(ormTx.SourceMuxID)); err != nil {
+               return nil, "", errors.Wrap(err, "Unmarshal muxID")
+       }
+
+       for _, req := range ormTx.Reqs {
+               // getAsset from assetStore instead of preload asset, in order to save db query overload
+               asset, err := w.assetStore.GetByOrmID(req.AssetID)
+               if err != nil {
+                       return nil, "", errors.Wrap(err, "get asset by ormAsset ID")
+               }
+
+               assetID := &vaporBc.AssetID{}
+               if err := assetID.UnmarshalText([]byte(asset.AssetID)); err != nil {
+                       return nil, "", errors.Wrap(err, "Unmarshal muxID")
+               }
+
+               rawDefinitionByte, err := hex.DecodeString(asset.RawDefinitionByte)
+               if err != nil {
+                       return nil, "", errors.Wrap(err, "decode rawDefinitionByte")
+               }
 
-       if err := w.db.Where(tx).UpdateColumn(&orm.CrossTransaction{
-               DestTxHash: sql.NullString{sidechainTx.ID.String(), true},
-       }).Error; err != nil {
+               issuanceProgramByte, err := hex.DecodeString(asset.IssuanceProgram)
+               if err != nil {
+                       return nil, "", errors.Wrap(err, "decode issuanceProgramByte")
+               }
+
+               input := vaporTypes.NewCrossChainInput(nil, *muxID, *assetID, req.AssetAmount, req.SourcePos, 1, rawDefinitionByte, issuanceProgramByte)
+               destTxData.Inputs = append(destTxData.Inputs, input)
+
+               controlProgram, err := hex.DecodeString(req.Script)
+               if err != nil {
+                       return nil, "", errors.Wrap(err, "decode req.Script")
+               }
+
+               output := vaporTypes.NewIntraChainOutput(*assetID, req.AssetAmount, controlProgram)
+               destTxData.Outputs = append(destTxData.Outputs, output)
+       }
+
+       destTx := vaporTypes.NewTx(*destTxData)
+       w.addInputWitness(destTx)
+
+       if err := w.db.Model(&orm.CrossTransaction{}).
+               Where(&orm.CrossTransaction{ID: ormTx.ID}).
+               UpdateColumn(&orm.CrossTransaction{
+                       DestTxHash: sql.NullString{destTx.ID.String(), true},
+               }).Error; err != nil {
                return nil, "", err
        }
 
-       return sidechainTx, sidechainTx.ID.String(), nil
+       return destTx, destTx.ID.String(), nil
 }
 
-// TODO:
-func (w *warder) buildMainchainTx(tx *orm.CrossTransaction) (*btmTypes.Tx, string, error) {
-       mainchainTx := &btmTypes.Tx{}
+func (w *warder) buildMainchainTx(ormTx *orm.CrossTransaction) (*btmTypes.Tx, string, error) {
+       return nil, "", errors.New("buildMainchainTx not implemented yet")
+}
 
-       if err := w.db.Where(tx).UpdateColumn(&orm.CrossTransaction{
-               DestTxHash: sql.NullString{mainchainTx.ID.String(), true},
-       }).Error; err != nil {
-               return nil, "", err
+// tx is a pointer to types.Tx, so the InputArguments can be set and be valid afterward
+func (w *warder) addInputWitness(tx interface{}) {
+       switch tx := tx.(type) {
+       case *vaporTypes.Tx:
+               args := [][]byte{w.fedProg}
+               for i := range tx.Inputs {
+                       tx.SetInputArguments(uint32(i), args)
+               }
+
+       case *btmTypes.Tx:
+               args := [][]byte{util.SegWitWrap(w.fedProg)}
+               for i := range tx.Inputs {
+                       tx.SetInputArguments(uint32(i), args)
+               }
        }
+}
 
-       return mainchainTx, mainchainTx.ID.String(), nil
+func (w *warder) initDestTxSigns(destTx interface{}, ormTx *orm.CrossTransaction) error {
+       for i := 1; i <= len(w.remotes)+1; i++ {
+               if err := w.db.Create(&orm.CrossTransactionSign{
+                       CrossTransactionID: ormTx.ID,
+                       WarderID:           uint8(i),
+                       Status:             common.CrossTxSignPendingStatus,
+               }).Error; err != nil {
+                       return err
+               }
+       }
+
+       return w.db.Model(&orm.CrossTransaction{}).
+               Where(&orm.CrossTransaction{ID: ormTx.ID}).
+               UpdateColumn(&orm.CrossTransaction{
+                       Status: common.CrossTxPendingStatus,
+               }).Error
 }
 
-// TODO:
-func (w *warder) signDestTx(destTx interface{}, tx *orm.CrossTransaction) error {
-       if tx.Status != common.CrossTxPendingStatus || !tx.DestTxHash.Valid {
-               return errors.New("cross-chain tx status error")
+func (w *warder) getSignData(destTx interface{}) ([][]byte, error) {
+       var signData [][]byte
+
+       switch destTx := destTx.(type) {
+       case *vaporTypes.Tx:
+               signData = make([][]byte, len(destTx.Inputs))
+               for i := range destTx.Inputs {
+                       signHash := destTx.SigHash(uint32(i))
+                       signData[i] = signHash.Bytes()
+               }
+
+       case *btmTypes.Tx:
+               signData = make([][]byte, len(destTx.Inputs))
+               for i := range destTx.Inputs {
+                       signHash := destTx.SigHash(uint32(i))
+                       signData[i] = signHash.Bytes()
+               }
+
+       default:
+               return [][]byte{}, errUnknownTxType
        }
 
-       return nil
+       return signData, nil
+}
+
+func (w *warder) getSigns(destTx interface{}, ormTx *orm.CrossTransaction) ([][]byte, error) {
+       if ormTx.Status != common.CrossTxPendingStatus || !ormTx.DestTxHash.Valid {
+               return nil, errors.New("cross-chain tx status error")
+       }
+
+       signData, err := w.getSignData(destTx)
+       if err != nil {
+               return nil, errors.New("getSignData")
+       }
+
+       var signs [][]byte
+       for _, data := range signData {
+               var b [32]byte
+               copy(b[:], data)
+               // vaporBc.Hash & btmBc.Hash are marshaled in the same way
+               msg := vaporBc.NewHash(b)
+               sign := w.xprv.Sign([]byte(msg.String()))
+               signs = append(signs, sign)
+       }
+
+       return signs, nil
 }
 
 // TODO:
-func (w *warder) attachSignsForTx(destTx interface{}, ormTx *orm.CrossTransaction, position uint8, signs string) {
+func (w *warder) attachSignsForTx(destTx interface{}, ormTx *orm.CrossTransaction, position uint8, signs []string) error {
+       var inputsLen int
+       switch destTx := destTx.(type) {
+       case *vaporTypes.Tx:
+               inputsLen = len(destTx.Inputs)
+       case *btmTypes.Tx:
+               inputsLen = len(destTx.Inputs)
+       default:
+               return errUnknownTxType
+       }
+
+       // finalize tx?
+
+       signWitness := make([][]string, inputsLen)
+
+       b, err := json.Marshal(signs)
+       if err != nil {
+               return errors.Wrap(err, "marshal signs")
+       }
+
+       return w.db.Model(&orm.CrossTransactionSign{}).
+               Where(&orm.CrossTransactionSign{
+                       CrossTransactionID: ormTx.ID,
+                       WarderID:           w.position,
+               }).
+               UpdateColumn(&orm.CrossTransactionSign{
+                       Signatures: string(b),
+                       Status:     common.CrossTxSignCompletedStatus,
+               }).Error
 }
 
 // TODO:
@@ -213,9 +369,12 @@ func (w *warder) isTxSignsReachQuorum(destTx interface{}) bool {
        return false
 }
 
-// TODO:
 func (w *warder) isLeader() bool {
-       return false
+       return w.position == 1
+}
+
+func (w *warder) finalizeTx(destTx interface{}, signersSigns [][][]byte) error {
+       return nil
 }
 
 func (w *warder) submitTx(destTx interface{}) (string, error) {
@@ -225,19 +384,21 @@ func (w *warder) submitTx(destTx interface{}) (string, error) {
        case *vaporTypes.Tx:
                return w.sidechainNode.SubmitTx(tx)
        default:
-               return "", errors.New("unknown destTx type")
+               return "", errUnknownTxType
        }
 }
 
-func (w *warder) updateSubmission(tx *orm.CrossTransaction) error {
-       if err := w.db.Where(tx).UpdateColumn(&orm.CrossTransaction{
-               Status: common.CrossTxSubmittedStatus,
-       }).Error; err != nil {
+func (w *warder) updateSubmission(ormTx *orm.CrossTransaction) error {
+       if err := w.db.Model(&orm.CrossTransaction{}).
+               Where(&orm.CrossTransaction{ID: ormTx.ID}).
+               UpdateColumn(&orm.CrossTransaction{
+                       Status: common.CrossTxSubmittedStatus,
+               }).Error; err != nil {
                return err
        }
 
        for _, remote := range w.remotes {
-               remote.NotifySubmission(tx)
+               remote.NotifySubmission(ormTx)
        }
        return nil
 }