OSDN Git Service

Frozen vm (#577)
authorPaladz <yzhu101@uottawa.ca>
Wed, 11 Apr 2018 13:12:48 +0000 (21:12 +0800)
committerGitHub <noreply@github.com>
Wed, 11 Apr 2018 13:12:48 +0000 (21:12 +0800)
* remove mux vm verify

* tmp save

* change the code style

* edit code format

* rename the file

* mv validate block to right folder

* delete unused from validation package

* add unit test for checkBlockTime && checkCoinbaseAmount

* delete unused test_file

* update the blockindex import

* format protocol level

* fix for golint

* edit for fix bug

23 files changed:
account/reserve.go
api/block_retrieve.go
api/nodeinfo.go
consensus/general.go
database/leveldb/store.go
mining/cpuminer/cpuminer.go
netsync/block_keeper.go
netsync/fetcher.go
netsync/sync.go
protocol/block.go
protocol/orphan_manage.go
protocol/protocol.go
protocol/state/blockindex.go [moved from protocol/blockindex.go with 65% similarity]
protocol/store.go
protocol/tx.go
protocol/tx_test.go [deleted file]
protocol/validation.go [deleted file]
protocol/validation/block.go [new file with mode: 0644]
protocol/validation/block_test.go
protocol/validation/tx.go [moved from protocol/validation/validation.go with 65% similarity]
protocol/validation/tx_test.go [moved from protocol/validation/validation_test.go with 95% similarity]
protocol/validation/vmcontext_test.go
test/performance/mining_test.go

index 40adfb8..2c04cb4 100644 (file)
@@ -174,7 +174,7 @@ func (re *reserver) reserveUTXO(ctx context.Context, out bc.Hash, exp time.Time,
        }
 
        //u.ValidHeight > 0 means coinbase utxo
-       if u.ValidHeight > 0 && u.ValidHeight > re.c.Height() {
+       if u.ValidHeight > 0 && u.ValidHeight > re.c.BestBlockHeight() {
                return nil, errors.WithDetail(ErrMatchUTXO, "this coinbase utxo is immature")
        }
 
@@ -257,7 +257,7 @@ func (re *reserver) source(src source) *sourceReserver {
                db:            re.db,
                src:           src,
                reserved:      make(map[bc.Hash]uint64),
-               currentHeight: re.c.Height,
+               currentHeight: re.c.BestBlockHeight,
        }
        re.sources[src] = sr
        return sr
index e1ea528..60b9896 100644 (file)
@@ -162,6 +162,6 @@ func (a *API) getBlockTransactionsCountByHeight(height uint64) Response {
 
 // return current block count
 func (a *API) getBlockCount() Response {
-       blockHeight := map[string]uint64{"block_count": a.chain.Height()}
+       blockHeight := map[string]uint64{"block_count": a.chain.BestBlockHeight()}
        return NewSuccessResponse(blockHeight)
 }
index 0a6c125..9ea1a57 100644 (file)
@@ -15,7 +15,7 @@ func (a *API) GetNodeInfo() *NetInfo {
                Syncing:      a.sync.BlockKeeper().IsCaughtUp(),
                Mining:       a.cpuMiner.IsMining(),
                PeerCount:    len(a.sync.Switch().Peers().List()),
-               CurrentBlock: a.chain.Height(),
+               CurrentBlock: a.chain.BestBlockHeight(),
        }
        _, info.HighestBlock = a.sync.Peers().BestPeer()
        if info.CurrentBlock > info.HighestBlock {
index f18a0b3..137d3aa 100644 (file)
@@ -8,8 +8,7 @@ import (
 
 //consensus variables
 const (
-       // define the Max transaction size and Max block size
-       MaxTxSize   = uint64(1048576)
+       // Max gas that one block contains
        MaxBlockGas = uint64(100000000)
 
        //config parameter for coinbase reward
@@ -34,7 +33,7 @@ const (
        CoinbaseArbitrarySizeLimit = 128
 
        VMGasRate        = int64(1000)
-       StorageGasRate   = int64(0)
+       StorageGasRate   = int64(5)
        MaxGasAmount     = int64(100000)
        DefaultGasCredit = int64(80000)
 
index 7834a5c..4e9a164 100644 (file)
@@ -5,6 +5,7 @@ import (
        "encoding/json"
 
        "github.com/golang/protobuf/proto"
+       log "github.com/sirupsen/logrus"
        "github.com/tendermint/tmlibs/common"
        dbm "github.com/tendermint/tmlibs/db"
 
@@ -23,15 +24,13 @@ var (
        txStatusPrefix    = []byte("BTS:")
 )
 
-func loadBlockStoreStateJSON(db dbm.DB) protocol.BlockStoreStateJSON {
+func loadBlockStoreStateJSON(db dbm.DB) *protocol.BlockStoreState {
        bytes := db.Get(blockStoreKey)
        if bytes == nil {
-               return protocol.BlockStoreStateJSON{
-                       Height: 0,
-               }
+               return nil
        }
-       bsj := protocol.BlockStoreStateJSON{}
-       if err := json.Unmarshal(bytes, &bsj); err != nil {
+       bsj := &protocol.BlockStoreState{}
+       if err := json.Unmarshal(bytes, bsj); err != nil {
                common.PanicCrisis(common.Fmt("Could not unmarshal bytes: %X", bytes))
        }
        return bsj
@@ -119,30 +118,30 @@ func (s *Store) GetTransactionStatus(hash *bc.Hash) (*bc.TransactionStatus, erro
 }
 
 // GetStoreStatus return the BlockStoreStateJSON
-func (s *Store) GetStoreStatus() protocol.BlockStoreStateJSON {
+func (s *Store) GetStoreStatus() *protocol.BlockStoreState {
        return loadBlockStoreStateJSON(s.db)
 }
 
-func (s *Store) LoadBlockIndex() (*protocol.BlockIndex, error) {
-       blockIndex := protocol.NewBlockIndex()
+func (s *Store) LoadBlockIndex() (*state.BlockIndex, error) {
+       blockIndex := state.NewBlockIndex()
        bhIter := s.db.IteratorPrefix(blockHeaderPrefix)
        defer bhIter.Release()
 
-       var lastNode *protocol.BlockNode
+       var lastNode *state.BlockNode
        for bhIter.Next() {
                bh := &types.BlockHeader{}
                if err := bh.UnmarshalText(bhIter.Value()); err != nil {
                        return nil, err
                }
 
-               var parent *protocol.BlockNode
+               var parent *state.BlockNode
                if lastNode == nil || lastNode.Hash == bh.PreviousBlockHash {
                        parent = lastNode
                } else {
                        parent = blockIndex.GetNode(&bh.PreviousBlockHash)
                }
 
-               node, err := protocol.NewBlockNode(bh, parent)
+               node, err := state.NewBlockNode(bh, parent)
                if err != nil {
                        return nil, err
                }
@@ -177,19 +176,19 @@ func (s *Store) SaveBlock(block *types.Block, ts *bc.TransactionStatus) error {
        batch.Set(calcBlockHeaderKey(block.Height, &blockHash), binaryBlockHeader)
        batch.Set(calcTxStatusKey(&blockHash), binaryTxStatus)
        batch.Write()
+
+       log.WithFields(log.Fields{"height": block.Height, "hash": blockHash.String()}).Info("block saved on disk")
        return nil
 }
 
 // SaveChainStatus save the core's newest status && delete old status
-func (s *Store) SaveChainStatus(block *types.Block, view *state.UtxoViewpoint) error {
-       hash := block.Hash()
+func (s *Store) SaveChainStatus(node *state.BlockNode, view *state.UtxoViewpoint) error {
        batch := s.db.NewBatch()
-
        if err := saveUtxoView(batch, view); err != nil {
                return err
        }
 
-       bytes, err := json.Marshal(protocol.BlockStoreStateJSON{Height: block.Height, Hash: &hash})
+       bytes, err := json.Marshal(protocol.BlockStoreState{Height: node.Height, Hash: &node.Hash})
        if err != nil {
                return err
        }
index 5cbf0c0..06dd44f 100644 (file)
@@ -59,7 +59,7 @@ func (m *CPUMiner) solveBlock(block *types.Block, ticker *time.Ticker, quit chan
                case <-quit:
                        return false
                case <-ticker.C:
-                       if m.chain.Height() >= header.Height {
+                       if m.chain.BestBlockHeight() >= header.Height {
                                return false
                        }
                default:
index b170d39..6758654 100644 (file)
@@ -66,11 +66,11 @@ func (bk *blockKeeper) AddTx(tx *types.Tx, peerID string) {
 
 func (bk *blockKeeper) IsCaughtUp() bool {
        _, height := bk.peers.BestPeer()
-       return bk.chain.Height() < height
+       return bk.chain.BestBlockHeight() < height
 }
 
 func (bk *blockKeeper) BlockRequestWorker(peerID string, maxPeerHeight uint64) error {
-       num := bk.chain.Height() + 1
+       num := bk.chain.BestBlockHeight() + 1
        currentHash := bk.chain.BestBlockHash()
        orphanNum := uint64(0)
        reqNum := uint64(0)
@@ -106,7 +106,7 @@ func (bk *blockKeeper) BlockRequestWorker(peerID string, maxPeerHeight uint64) e
                num++
        }
        bestHash := bk.chain.BestBlockHash()
-       log.Info("Block sync complete. height:", bk.chain.Height(), " hash:", bestHash)
+       log.Info("Block sync complete. height:", bk.chain.BestBlockHeight(), " hash:", bestHash)
        if strings.Compare(currentHash.String(), bestHash.String()) != 0 {
                log.Info("Broadcast new chain status.")
 
index 754e2ce..ecb7522 100644 (file)
@@ -6,11 +6,12 @@ import (
        log "github.com/sirupsen/logrus"
        "gopkg.in/karalabe/cookiejar.v2/collections/prque"
 
+       "strings"
+
        "github.com/bytom/p2p"
        core "github.com/bytom/protocol"
        "github.com/bytom/protocol/bc"
        "github.com/bytom/protocol/bc/types"
-       "strings"
 )
 
 const (
@@ -83,7 +84,7 @@ func (f *Fetcher) Enqueue(peer string, block *types.Block) error {
 func (f *Fetcher) loop() {
        for {
                // Import any queued blocks that could potentially fit
-               height := f.chain.Height()
+               height := f.chain.BestBlockHeight()
                for !f.queue.Empty() {
                        op := f.queue.PopItem().(*blockPending)
                        // If too high up the chain or phase, continue later
@@ -125,7 +126,7 @@ func (f *Fetcher) enqueue(peer string, block *types.Block) {
 
        //TODO: Ensure the peer isn't DOSing us
        // Discard any past or too distant blocks
-       if dist := int64(block.Height) - int64(f.chain.Height()); dist < 0 || dist > maxQueueDist {
+       if dist := int64(block.Height) - int64(f.chain.BestBlockHeight()); dist < 0 || dist > maxQueueDist {
                log.Info("Discarded propagated block, too far away", " peer: ", peer, "number: ", block.Height, "distance: ", dist)
                return
        }
index c3c1238..7d4531a 100644 (file)
@@ -87,7 +87,7 @@ func (sm *SyncManager) synchronise() {
        if peer == nil {
                return
        }
-       if bestHeight > sm.chain.Height() {
+       if bestHeight > sm.chain.BestBlockHeight() {
                sm.blockKeeper.BlockRequestWorker(peer.Key, bestHeight)
        }
 }
index 0ffcfb5..1a79413 100644 (file)
@@ -7,6 +7,7 @@ import (
        "github.com/bytom/protocol/bc"
        "github.com/bytom/protocol/bc/types"
        "github.com/bytom/protocol/state"
+       "github.com/bytom/protocol/validation"
 )
 
 var (
@@ -20,7 +21,7 @@ var (
 
 // BlockExist check is a block in chain or orphan
 func (c *Chain) BlockExist(hash *bc.Hash) bool {
-       return c.orphanManage.BlockExist(hash) || c.index.BlockExist(hash)
+       return c.index.BlockExist(hash) || c.orphanManage.BlockExist(hash)
 }
 
 // GetBlockByHash return a block by given hash
@@ -37,21 +38,13 @@ func (c *Chain) GetBlockByHeight(height uint64) (*types.Block, error) {
        return c.store.GetBlock(&node.Hash)
 }
 
-// ConnectBlock append block to end of chain
-func (c *Chain) ConnectBlock(block *types.Block) error {
-       c.state.cond.L.Lock()
-       defer c.state.cond.L.Unlock()
-       return c.connectBlock(block)
-}
-
 func (c *Chain) connectBlock(block *types.Block) (err error) {
        bcBlock := types.MapBlock(block)
-       utxoView := state.NewUtxoViewpoint()
-       bcBlock.TransactionStatus, err = c.store.GetTransactionStatus(&bcBlock.ID)
-       if err != nil {
+       if bcBlock.TransactionStatus, err = c.store.GetTransactionStatus(&bcBlock.ID); err != nil {
                return err
        }
 
+       utxoView := state.NewUtxoViewpoint()
        if err := c.store.GetTransactionsUtxo(utxoView, bcBlock.Transactions); err != nil {
                return err
        }
@@ -59,7 +52,8 @@ func (c *Chain) connectBlock(block *types.Block) (err error) {
                return err
        }
 
-       if err := c.setState(block, utxoView); err != nil {
+       node := c.index.GetNode(&bcBlock.ID)
+       if err := c.setState(node, utxoView); err != nil {
                return err
        }
 
@@ -69,29 +63,35 @@ func (c *Chain) connectBlock(block *types.Block) (err error) {
        return nil
 }
 
-func (c *Chain) getReorganizeBlocks(block *types.Block) ([]*types.Block, []*types.Block) {
-       attachBlocks := []*types.Block{}
-       detachBlocks := []*types.Block{}
-       ancestor := block
+func (c *Chain) calcReorganizeNodes(node *state.BlockNode) ([]*state.BlockNode, []*state.BlockNode) {
+       var attachNodes []*state.BlockNode
+       var detachNodes []*state.BlockNode
 
-       for !c.index.InMainchain(ancestor.Hash()) {
-               attachBlocks = append([]*types.Block{ancestor}, attachBlocks...)
-               ancestor, _ = c.GetBlockByHash(&ancestor.PreviousBlockHash)
+       attachIter := node
+       for c.index.NodeByHeight(attachIter.Height) != attachIter {
+               attachNodes = append([]*state.BlockNode{attachIter}, attachNodes...)
+               attachIter = attachIter.Parent
        }
 
-       for d, _ := c.store.GetBlock(c.state.hash); d.Hash() != ancestor.Hash(); d, _ = c.store.GetBlock(&d.PreviousBlockHash) {
-               detachBlocks = append(detachBlocks, d)
+       detachIter := c.bestNode
+       for detachIter != attachIter {
+               detachNodes = append(detachNodes, detachIter)
+               detachIter = detachIter.Parent
        }
-
-       return attachBlocks, detachBlocks
+       return attachNodes, detachNodes
 }
 
-func (c *Chain) reorganizeChain(block *types.Block) error {
-       attachBlocks, detachBlocks := c.getReorganizeBlocks(block)
+func (c *Chain) reorganizeChain(node *state.BlockNode) error {
+       attachNodes, detachNodes := c.calcReorganizeNodes(node)
        utxoView := state.NewUtxoViewpoint()
 
-       for _, d := range detachBlocks {
-               detachBlock := types.MapBlock(d)
+       for _, detachNode := range detachNodes {
+               b, err := c.store.GetBlock(&detachNode.Hash)
+               if err != nil {
+                       return err
+               }
+
+               detachBlock := types.MapBlock(b)
                if err := c.store.GetTransactionsUtxo(utxoView, detachBlock.Transactions); err != nil {
                        return err
                }
@@ -102,11 +102,17 @@ func (c *Chain) reorganizeChain(block *types.Block) error {
                if err := utxoView.DetachBlock(detachBlock, txStatus); err != nil {
                        return err
                }
-               log.WithFields(log.Fields{"height": detachBlock.Height, "hash": detachBlock.ID.String()}).Debug("Detach from mainchain")
+
+               log.WithFields(log.Fields{"height": node.Height, "hash": node.Hash.String()}).Debug("detach from mainchain")
        }
 
-       for _, a := range attachBlocks {
-               attachBlock := types.MapBlock(a)
+       for _, attachNode := range attachNodes {
+               b, err := c.store.GetBlock(&attachNode.Hash)
+               if err != nil {
+                       return err
+               }
+
+               attachBlock := types.MapBlock(b)
                if err := c.store.GetTransactionsUtxo(utxoView, attachBlock.Transactions); err != nil {
                        return err
                }
@@ -114,31 +120,30 @@ func (c *Chain) reorganizeChain(block *types.Block) error {
                if err != nil {
                        return err
                }
-
                if err := utxoView.ApplyBlock(attachBlock, txStatus); err != nil {
                        return err
                }
-               log.WithFields(log.Fields{"height": attachBlock.Height, "hash": attachBlock.ID.String()}).Debug("Attach from mainchain")
+
+               log.WithFields(log.Fields{"height": node.Height, "hash": node.Hash.String()}).Debug("attach from mainchain")
        }
 
-       return c.setState(block, utxoView)
+       return c.setState(node, utxoView)
 }
 
 // SaveBlock will validate and save block into storage
-func (c *Chain) SaveBlock(block *types.Block) error {
-       blockEnts := types.MapBlock(block)
-       if err := c.validateBlock(blockEnts); err != nil {
+func (c *Chain) saveBlock(block *types.Block) error {
+       bcBlock := types.MapBlock(block)
+       parent := c.index.GetNode(&block.PreviousBlockHash)
+
+       if err := validation.ValidateBlock(bcBlock, parent); err != nil {
                return errors.Sub(ErrBadBlock, err)
        }
-
-       if err := c.store.SaveBlock(block, blockEnts.TransactionStatus); err != nil {
+       if err := c.store.SaveBlock(block, bcBlock.TransactionStatus); err != nil {
                return err
        }
-       log.WithFields(log.Fields{"height": block.Height, "hash": blockEnts.ID.String()}).Info("Block saved on disk")
 
-       c.orphanManage.Delete(&blockEnts.ID)
-       parent := c.index.GetNode(&block.PreviousBlockHash)
-       node, err := NewBlockNode(&block.BlockHeader, parent)
+       c.orphanManage.Delete(&bcBlock.ID)
+       node, err := state.NewBlockNode(&block.BlockHeader, parent)
        if err != nil {
                return err
        }
@@ -147,34 +152,30 @@ func (c *Chain) SaveBlock(block *types.Block) error {
        return nil
 }
 
-func (c *Chain) findBestChainTail(block *types.Block) (bestBlock *types.Block) {
-       bestBlock = block
+func (c *Chain) saveSubBlock(block *types.Block) *types.Block {
        blockHash := block.Hash()
-       preorphans, ok := c.orphanManage.preOrphans[blockHash]
+       prevOrphans, ok := c.orphanManage.GetPrevOrphans(&blockHash)
        if !ok {
-               return
+               return block
        }
 
-       for _, preorphan := range preorphans {
-               orphanBlock, ok := c.orphanManage.Get(preorphan)
+       bestBlock := block
+       for _, prevOrphan := range prevOrphans {
+               orphanBlock, ok := c.orphanManage.Get(prevOrphan)
                if !ok {
+                       log.WithFields(log.Fields{"hash": prevOrphan.String()}).Warning("saveSubBlock fail to get block from orphanManage")
                        continue
                }
-
-               if err := c.SaveBlock(orphanBlock); err != nil {
-                       log.WithFields(log.Fields{
-                               "height": block.Height,
-                               "hash":   blockHash.String(),
-                       }).Errorf("findBestChainTail fail on save block %v", err)
+               if err := c.saveBlock(orphanBlock); err != nil {
+                       log.WithFields(log.Fields{"hash": prevOrphan.String(), "height": orphanBlock.Height}).Warning("saveSubBlock fail to save block")
                        continue
                }
 
-               if subResult := c.findBestChainTail(orphanBlock); subResult.Height > bestBlock.Height {
-                       bestBlock = subResult
+               if subBestBlock := c.saveSubBlock(orphanBlock); subBestBlock.Height > bestBlock.Height {
+                       bestBlock = subBestBlock
                }
        }
-
-       return
+       return bestBlock
 }
 
 type processBlockResponse struct {
@@ -187,6 +188,7 @@ type processBlockMsg struct {
        reply chan processBlockResponse
 }
 
+// ProcessBlock is the entry for chain update
 func (c *Chain) ProcessBlock(block *types.Block) (bool, error) {
        reply := make(chan processBlockResponse, 1)
        c.processBlockCh <- &processBlockMsg{block: block, reply: reply}
@@ -205,32 +207,31 @@ func (c *Chain) blockProcesser() {
 func (c *Chain) processBlock(block *types.Block) (bool, error) {
        blockHash := block.Hash()
        if c.BlockExist(&blockHash) {
-               log.WithField("hash", blockHash.String()).Debug("Skip process due to block already been handled")
+               log.WithFields(log.Fields{"hash": blockHash.String(), "height": block.Height}).Info("block has been processed")
                return c.orphanManage.BlockExist(&blockHash), nil
        }
-       if !c.store.BlockExist(&block.PreviousBlockHash) {
-               log.WithField("hash", blockHash.String()).Debug("Add block to orphan manage")
+
+       if parent := c.index.GetNode(&block.PreviousBlockHash); parent == nil {
                c.orphanManage.Add(block)
                return true, nil
        }
-       if err := c.SaveBlock(block); err != nil {
+
+       if err := c.saveBlock(block); err != nil {
                return false, err
        }
 
-       bestBlock := c.findBestChainTail(block)
-       bestMainChain := c.index.BestNode()
+       bestBlock := c.saveSubBlock(block)
        bestBlockHash := bestBlock.Hash()
        bestNode := c.index.GetNode(&bestBlockHash)
 
-       if bestNode.parent == bestMainChain {
-               log.WithField("hash", blockHash.String()).Debug("Start to append block to the tail of mainchain")
+       if bestNode.Parent == c.bestNode {
+               log.Debug("append block to the end of mainchain")
                return false, c.connectBlock(bestBlock)
        }
 
-       if bestNode.height > bestMainChain.height && bestNode.workSum.Cmp(bestMainChain.workSum) >= 0 {
-               log.WithField("hash", blockHash.String()).Debug("Start to reorganize mainchain")
-               return false, c.reorganizeChain(bestBlock)
+       if bestNode.Height > c.bestNode.Height && bestNode.WorkSum.Cmp(c.bestNode.WorkSum) >= 0 {
+               log.Debug("start to reorganize chain")
+               return false, c.reorganizeChain(bestNode)
        }
-
        return false, nil
 }
index 04b0ab8..20067e1 100644 (file)
@@ -3,6 +3,8 @@ package protocol
 import (
        "sync"
 
+       log "github.com/sirupsen/logrus"
+
        "github.com/bytom/protocol/bc"
        "github.com/bytom/protocol/bc/types"
 )
@@ -10,16 +12,16 @@ import (
 // OrphanManage is use to handle all the orphan block
 type OrphanManage struct {
        //TODO: add orphan cached block limit
-       orphan     map[bc.Hash]*types.Block
-       preOrphans map[bc.Hash][]*bc.Hash
-       mtx        sync.RWMutex
+       orphan      map[bc.Hash]*types.Block
+       prevOrphans map[bc.Hash][]*bc.Hash
+       mtx         sync.RWMutex
 }
 
 // NewOrphanManage return a new orphan block
 func NewOrphanManage() *OrphanManage {
        return &OrphanManage{
-               orphan:     make(map[bc.Hash]*types.Block),
-               preOrphans: make(map[bc.Hash][]*bc.Hash),
+               orphan:      make(map[bc.Hash]*types.Block),
+               prevOrphans: make(map[bc.Hash][]*bc.Hash),
        }
 }
 
@@ -42,7 +44,9 @@ func (o *OrphanManage) Add(block *types.Block) {
        }
 
        o.orphan[blockHash] = block
-       o.preOrphans[block.PreviousBlockHash] = append(o.preOrphans[block.PreviousBlockHash], &blockHash)
+       o.prevOrphans[block.PreviousBlockHash] = append(o.prevOrphans[block.PreviousBlockHash], &blockHash)
+
+       log.WithFields(log.Fields{"hash": blockHash.String(), "height": block.Height}).Info("add block to orphan")
 }
 
 // Delete will delelte the block from OrphanManage
@@ -55,15 +59,15 @@ func (o *OrphanManage) Delete(hash *bc.Hash) {
        }
        delete(o.orphan, *hash)
 
-       preOrphans, ok := o.preOrphans[block.PreviousBlockHash]
-       if !ok || len(preOrphans) == 1 {
-               delete(o.preOrphans, block.PreviousBlockHash)
+       prevOrphans, ok := o.prevOrphans[block.PreviousBlockHash]
+       if !ok || len(prevOrphans) == 1 {
+               delete(o.prevOrphans, block.PreviousBlockHash)
                return
        }
 
-       for i, preOrphan := range preOrphans {
+       for i, preOrphan := range prevOrphans {
                if preOrphan == hash {
-                       o.preOrphans[block.PreviousBlockHash] = append(preOrphans[:i], preOrphans[i+1:]...)
+                       o.prevOrphans[block.PreviousBlockHash] = append(prevOrphans[:i], prevOrphans[i+1:]...)
                        return
                }
        }
@@ -76,3 +80,11 @@ func (o *OrphanManage) Get(hash *bc.Hash) (*types.Block, bool) {
        o.mtx.RUnlock()
        return block, ok
 }
+
+// GetPrevOrphans return the list of child orphans
+func (o *OrphanManage) GetPrevOrphans(hash *bc.Hash) ([]*bc.Hash, bool) {
+       o.mtx.RLock()
+       prevOrphans, ok := o.prevOrphans[*hash]
+       o.mtx.RUnlock()
+       return prevOrphans, ok
+}
index 3e80d8d..e4480db 100644 (file)
@@ -1,77 +1,60 @@
 package protocol
 
 import (
-       "context"
-       "math/big"
        "sync"
 
        log "github.com/sirupsen/logrus"
 
        "github.com/bytom/config"
-       "github.com/bytom/database/storage"
        "github.com/bytom/errors"
        "github.com/bytom/protocol/bc"
        "github.com/bytom/protocol/bc/types"
        "github.com/bytom/protocol/state"
 )
 
-const (
-       maxProcessBlockChSize = 1024
-)
+const maxProcessBlockChSize = 1024
 
-var (
-       // ErrTheDistantFuture is returned when waiting for a blockheight
-       // too far in excess of the tip of the blockchain.
-       ErrTheDistantFuture = errors.New("block height too far in future")
-)
+// ErrTheDistantFuture is returned when waiting for a blockheight too far in
+// excess of the tip of the blockchain.
+var ErrTheDistantFuture = errors.New("block height too far in future")
 
-// Chain provides a complete, minimal blockchain database. It
-// delegates the underlying storage to other objects, and uses
-// validation logic from package validation to decide what
-// objects can be safely stored.
+// Chain provides functions for working with the Bytom block chain.
 type Chain struct {
-       index          *BlockIndex
+       index          *state.BlockIndex
        orphanManage   *OrphanManage
        txPool         *TxPool
+       store          Store
        processBlockCh chan *processBlockMsg
 
-       state struct {
-               cond    sync.Cond
-               hash    *bc.Hash
-               height  uint64
-               workSum *big.Int
-       }
-
-       store Store
+       cond     sync.Cond
+       bestNode *state.BlockNode
 }
 
 // NewChain returns a new Chain using store as the underlying storage.
 func NewChain(store Store, txPool *TxPool) (*Chain, error) {
        c := &Chain{
                orphanManage:   NewOrphanManage(),
-               store:          store,
                txPool:         txPool,
+               store:          store,
                processBlockCh: make(chan *processBlockMsg, maxProcessBlockChSize),
        }
-       c.state.cond.L = new(sync.Mutex)
+       c.cond.L = new(sync.Mutex)
 
-       var err error
-       if storeStatus := store.GetStoreStatus(); storeStatus.Hash != nil {
-               c.state.hash = storeStatus.Hash
-       } else {
-               if err = c.initChainStatus(); err != nil {
+       storeStatus := store.GetStoreStatus()
+       if storeStatus == nil {
+               if err := c.initChainStatus(); err != nil {
                        return nil, err
                }
+               storeStatus = store.GetStoreStatus()
        }
 
+       var err error
        if c.index, err = store.LoadBlockIndex(); err != nil {
                return nil, err
        }
 
-       bestNode := c.index.GetNode(c.state.hash)
-       c.index.SetMainChain(bestNode)
-       c.state.height = bestNode.height
-       c.state.workSum = bestNode.workSum
+       c.bestNode = c.index.GetNode(storeStatus.Hash)
+       c.index.SetMainChain(c.bestNode)
        go c.blockProcesser()
        return c, nil
 }
@@ -79,7 +62,7 @@ func NewChain(store Store, txPool *TxPool) (*Chain, error) {
 func (c *Chain) initChainStatus() error {
        genesisBlock := config.GenerateGenesisBlock()
        txStatus := bc.NewTransactionStatus()
-       for i, _ := range genesisBlock.Transactions {
+       for i := range genesisBlock.Transactions {
                txStatus.SetStatus(i, false)
        }
 
@@ -93,43 +76,36 @@ func (c *Chain) initChainStatus() error {
                return err
        }
 
-       if err := c.store.SaveChainStatus(genesisBlock, utxoView); err != nil {
+       node, err := state.NewBlockNode(&genesisBlock.BlockHeader, nil)
+       if err != nil {
                return err
        }
-
-       hash := genesisBlock.Hash()
-       c.state.hash = &hash
-       return nil
+       return c.store.SaveChainStatus(node, utxoView)
 }
 
-// Height returns the current height of the blockchain.
-func (c *Chain) Height() uint64 {
-       c.state.cond.L.Lock()
-       defer c.state.cond.L.Unlock()
-       return c.state.height
+// BestBlockHeight returns the current height of the blockchain.
+func (c *Chain) BestBlockHeight() uint64 {
+       c.cond.L.Lock()
+       defer c.cond.L.Unlock()
+       return c.bestNode.Height
 }
 
 // BestBlockHash return the hash of the chain tail block
 func (c *Chain) BestBlockHash() *bc.Hash {
-       c.state.cond.L.Lock()
-       defer c.state.cond.L.Unlock()
-       return c.state.hash
-}
-
-// InMainChain checks wheather a block is in the main chain
-func (c *Chain) InMainChain(hash bc.Hash) bool {
-       return c.index.InMainchain(hash)
+       c.cond.L.Lock()
+       defer c.cond.L.Unlock()
+       return &c.bestNode.Hash
 }
 
-// BestBlock returns the chain tail block
+// BestBlockHeader returns the chain tail block
 func (c *Chain) BestBlockHeader() *types.BlockHeader {
        node := c.index.BestNode()
-       return node.blockHeader()
+       return node.BlockHeader()
 }
 
-// GetUtxo try to find the utxo status in db
-func (c *Chain) GetUtxo(hash *bc.Hash) (*storage.UtxoEntry, error) {
-       return c.store.GetUtxo(hash)
+// InMainChain checks wheather a block is in the main chain
+func (c *Chain) InMainChain(hash bc.Hash) bool {
+       return c.index.InMainchain(hash)
 }
 
 // CalcNextSeed return the seed for the given block
@@ -150,76 +126,31 @@ func (c *Chain) CalcNextBits(preBlock *bc.Hash) (uint64, error) {
        return node.CalcNextBits(), nil
 }
 
-// GetTransactionStatus return the transaction status of give block
-func (c *Chain) GetTransactionStatus(hash *bc.Hash) (*bc.TransactionStatus, error) {
-       return c.store.GetTransactionStatus(hash)
-}
-
-// GetTransactionsUtxo return all the utxos that related to the txs' inputs
-func (c *Chain) GetTransactionsUtxo(view *state.UtxoViewpoint, txs []*bc.Tx) error {
-       return c.store.GetTransactionsUtxo(view, txs)
-}
-
 // This function must be called with mu lock in above level
-func (c *Chain) setState(block *types.Block, view *state.UtxoViewpoint) error {
-       if err := c.store.SaveChainStatus(block, view); err != nil {
+func (c *Chain) setState(node *state.BlockNode, view *state.UtxoViewpoint) error {
+       if err := c.store.SaveChainStatus(node, view); err != nil {
                return err
        }
 
-       c.state.cond.L.Lock()
-       defer c.state.cond.L.Unlock()
+       c.cond.L.Lock()
+       defer c.cond.L.Unlock()
 
-       blockHash := block.Hash()
-       node := c.index.GetNode(&blockHash)
        c.index.SetMainChain(node)
-       c.state.hash = &blockHash
-       c.state.height = node.height
-       c.state.workSum = node.workSum
-
-       log.WithFields(log.Fields{
-               "height":  c.state.height,
-               "hash":    c.state.hash.String(),
-               "workSum": c.state.workSum,
-       }).Debug("Chain best status has been changed")
-       c.state.cond.Broadcast()
-       return nil
-}
-
-// BlockSoonWaiter returns a channel that
-// waits for the block at the given height,
-// but it is an error to wait for a block far in the future.
-// WaitForBlockSoon will timeout if the context times out.
-// To wait unconditionally, the caller should use WaitForBlock.
-func (c *Chain) BlockSoonWaiter(ctx context.Context, height uint64) <-chan error {
-       ch := make(chan error, 1)
+       c.bestNode = node
 
-       go func() {
-               const slop = 3
-               if height > c.Height()+slop {
-                       ch <- ErrTheDistantFuture
-                       return
-               }
-
-               select {
-               case <-c.BlockWaiter(height):
-                       ch <- nil
-               case <-ctx.Done():
-                       ch <- ctx.Err()
-               }
-       }()
-
-       return ch
+       log.WithFields(log.Fields{"height": c.bestNode.Height, "hash": c.bestNode.Hash}).Debug("chain best status has been update")
+       c.cond.Broadcast()
+       return nil
 }
 
-// BlockWaiter returns a channel that
-// waits for the block at the given height.
+// BlockWaiter returns a channel that waits for the block at the given height.
 func (c *Chain) BlockWaiter(height uint64) <-chan struct{} {
        ch := make(chan struct{}, 1)
        go func() {
-               c.state.cond.L.Lock()
-               defer c.state.cond.L.Unlock()
-               for c.state.height < height {
-                       c.state.cond.Wait()
+               c.cond.L.Lock()
+               defer c.cond.L.Unlock()
+               for c.bestNode.Height < height {
+                       c.cond.Wait()
                }
                ch <- struct{}{}
        }()
similarity index 65%
rename from protocol/blockindex.go
rename to protocol/state/blockindex.go
index f95cb17..9278f13 100644 (file)
@@ -1,4 +1,4 @@
-package protocol
+package state
 
 import (
        "errors"
@@ -20,18 +20,18 @@ const approxNodesPerDay = 24 * 24
 // BlockNode represents a block within the block chain and is primarily used to
 // aid in selecting the best chain to be the main chain.
 type BlockNode struct {
-       parent  *BlockNode // parent is the parent block for this node.
+       Parent  *BlockNode // parent is the parent block for this node.
        Hash    bc.Hash    // hash of the block.
-       seed    *bc.Hash   // seed hash of the block
-       workSum *big.Int   // total amount of work in the chain up to
+       Seed    *bc.Hash   // seed hash of the block
+       WorkSum *big.Int   // total amount of work in the chain up to
 
-       version                uint64
-       height                 uint64
-       timestamp              uint64
-       nonce                  uint64
-       bits                   uint64
-       transactionsMerkleRoot bc.Hash
-       transactionStatusHash  bc.Hash
+       Version                uint64
+       Height                 uint64
+       Timestamp              uint64
+       Nonce                  uint64
+       Bits                   uint64
+       TransactionsMerkleRoot bc.Hash
+       TransactionStatusHash  bc.Hash
 }
 
 func NewBlockNode(bh *types.BlockHeader, parent *BlockNode) (*BlockNode, error) {
@@ -40,43 +40,43 @@ func NewBlockNode(bh *types.BlockHeader, parent *BlockNode) (*BlockNode, error)
        }
 
        node := &BlockNode{
-               parent:    parent,
+               Parent:    parent,
                Hash:      bh.Hash(),
-               workSum:   difficulty.CalcWork(bh.Bits),
-               version:   bh.Version,
-               height:    bh.Height,
-               timestamp: bh.Timestamp,
-               nonce:     bh.Nonce,
-               bits:      bh.Bits,
-               transactionsMerkleRoot: bh.TransactionsMerkleRoot,
-               transactionStatusHash:  bh.TransactionStatusHash,
+               WorkSum:   difficulty.CalcWork(bh.Bits),
+               Version:   bh.Version,
+               Height:    bh.Height,
+               Timestamp: bh.Timestamp,
+               Nonce:     bh.Nonce,
+               Bits:      bh.Bits,
+               TransactionsMerkleRoot: bh.TransactionsMerkleRoot,
+               TransactionStatusHash:  bh.TransactionStatusHash,
        }
 
        if bh.Height == 0 {
-               node.seed = consensus.InitialSeed
+               node.Seed = consensus.InitialSeed
        } else {
-               node.seed = parent.CalcNextSeed()
-               node.workSum = node.workSum.Add(parent.workSum, node.workSum)
+               node.Seed = parent.CalcNextSeed()
+               node.WorkSum = node.WorkSum.Add(parent.WorkSum, node.WorkSum)
        }
        return node, nil
 }
 
 // blockHeader convert a node to the header struct
-func (node *BlockNode) blockHeader() *types.BlockHeader {
+func (node *BlockNode) BlockHeader() *types.BlockHeader {
        previousBlockHash := bc.Hash{}
-       if node.parent != nil {
-               previousBlockHash = node.parent.Hash
+       if node.Parent != nil {
+               previousBlockHash = node.Parent.Hash
        }
        return &types.BlockHeader{
-               Version:           node.version,
-               Height:            node.height,
+               Version:           node.Version,
+               Height:            node.Height,
                PreviousBlockHash: previousBlockHash,
-               Timestamp:         node.timestamp,
-               Nonce:             node.nonce,
-               Bits:              node.bits,
+               Timestamp:         node.Timestamp,
+               Nonce:             node.Nonce,
+               Bits:              node.Bits,
                BlockCommitment: types.BlockCommitment{
-                       TransactionsMerkleRoot: node.transactionsMerkleRoot,
-                       TransactionStatusHash:  node.transactionStatusHash,
+                       TransactionsMerkleRoot: node.TransactionsMerkleRoot,
+                       TransactionStatusHash:  node.TransactionStatusHash,
                },
        }
 }
@@ -85,8 +85,8 @@ func (node *BlockNode) CalcPastMedianTime() uint64 {
        timestamps := []uint64{}
        iterNode := node
        for i := 0; i < consensus.MedianTimeBlocks && iterNode != nil; i++ {
-               timestamps = append(timestamps, iterNode.timestamp)
-               iterNode = iterNode.parent
+               timestamps = append(timestamps, iterNode.Timestamp)
+               iterNode = iterNode.Parent
        }
 
        sort.Sort(common.TimeSorter(timestamps))
@@ -95,23 +95,23 @@ func (node *BlockNode) CalcPastMedianTime() uint64 {
 
 // CalcNextBits calculate the seed for next block
 func (node *BlockNode) CalcNextBits() uint64 {
-       if node.height%consensus.BlocksPerRetarget != 0 || node.height == 0 {
-               return node.bits
+       if node.Height%consensus.BlocksPerRetarget != 0 || node.Height == 0 {
+               return node.Bits
        }
 
-       compareNode := node.parent
-       for compareNode.height%consensus.BlocksPerRetarget != 0 {
-               compareNode = compareNode.parent
+       compareNode := node.Parent
+       for compareNode.Height%consensus.BlocksPerRetarget != 0 {
+               compareNode = compareNode.Parent
        }
-       return difficulty.CalcNextRequiredDifficulty(node.blockHeader(), compareNode.blockHeader())
+       return difficulty.CalcNextRequiredDifficulty(node.BlockHeader(), compareNode.BlockHeader())
 }
 
 // CalcNextSeed calculate the seed for next block
 func (node *BlockNode) CalcNextSeed() *bc.Hash {
-       if node.height%consensus.SeedPerRetarget == 0 {
+       if node.Height%consensus.SeedPerRetarget == 0 {
                return &node.Hash
        }
-       return node.seed
+       return node.Seed
 }
 
 // BlockIndex is the struct for help chain trace block chain as tree
@@ -167,7 +167,7 @@ func (bi *BlockIndex) InMainchain(hash bc.Hash) bool {
        if !ok {
                return false
        }
-       return bi.nodeByHeight(node.height) == node
+       return bi.nodeByHeight(node.Height) == node
 }
 
 func (bi *BlockIndex) nodeByHeight(height uint64) *BlockNode {
@@ -189,7 +189,7 @@ func (bi *BlockIndex) SetMainChain(node *BlockNode) {
        bi.Lock()
        defer bi.Unlock()
 
-       needed := node.height + 1
+       needed := node.Height + 1
        if uint64(cap(bi.mainChain)) < needed {
                nodes := make([]*BlockNode, needed, needed+approxNodesPerDay)
                copy(nodes, bi.mainChain)
@@ -202,8 +202,8 @@ func (bi *BlockIndex) SetMainChain(node *BlockNode) {
                }
        }
 
-       for node != nil && bi.mainChain[node.height] != node {
-               bi.mainChain[node.height] = node
-               node = node.parent
+       for node != nil && bi.mainChain[node.Height] != node {
+               bi.mainChain[node.Height] = node
+               node = node.Parent
        }
 }
index e77431f..f7bf017 100644 (file)
@@ -12,18 +12,18 @@ type Store interface {
        BlockExist(*bc.Hash) bool
 
        GetBlock(*bc.Hash) (*types.Block, error)
-       GetStoreStatus() BlockStoreStateJSON
+       GetStoreStatus() *BlockStoreState
        GetTransactionStatus(*bc.Hash) (*bc.TransactionStatus, error)
        GetTransactionsUtxo(*state.UtxoViewpoint, []*bc.Tx) error
        GetUtxo(*bc.Hash) (*storage.UtxoEntry, error)
 
-       LoadBlockIndex() (*BlockIndex, error)
+       LoadBlockIndex() (*state.BlockIndex, error)
        SaveBlock(*types.Block, *bc.TransactionStatus) error
-       SaveChainStatus(*types.Block, *state.UtxoViewpoint) error
+       SaveChainStatus(*state.BlockNode, *state.UtxoViewpoint) error
 }
 
-// BlockStoreStateJSON represents the core's db status
-type BlockStoreStateJSON struct {
+// BlockStoreState represents the core's db status
+type BlockStoreState struct {
        Height uint64
        Hash   *bc.Hash
 }
index 4ed5177..7322de1 100644 (file)
@@ -4,44 +4,47 @@ import (
        "github.com/bytom/errors"
        "github.com/bytom/protocol/bc"
        "github.com/bytom/protocol/bc/types"
+       "github.com/bytom/protocol/state"
        "github.com/bytom/protocol/validation"
 )
 
 // ErrBadTx is returned for transactions failing validation
 var ErrBadTx = errors.New("invalid transaction")
 
+// GetTransactionStatus return the transaction status of give block
+func (c *Chain) GetTransactionStatus(hash *bc.Hash) (*bc.TransactionStatus, error) {
+       return c.store.GetTransactionStatus(hash)
+}
+
+// GetTransactionsUtxo return all the utxos that related to the txs' inputs
+func (c *Chain) GetTransactionsUtxo(view *state.UtxoViewpoint, txs []*bc.Tx) error {
+       return c.store.GetTransactionsUtxo(view, txs)
+}
+
 // ValidateTx validates the given transaction. A cache holds
 // per-transaction validation results and is consulted before
 // performing full validation.
 func (c *Chain) ValidateTx(tx *types.Tx) (bool, error) {
-       newTx := tx.Tx
-       bh := c.BestBlockHeader()
-       block := types.MapBlock(&types.Block{BlockHeader: *bh})
-       if ok := c.txPool.HaveTransaction(&newTx.ID); ok {
-               return false, c.txPool.GetErrCache(&newTx.ID)
+       if ok := c.txPool.HaveTransaction(&tx.ID); ok {
+               return false, c.txPool.GetErrCache(&tx.ID)
        }
 
-       // validate the UTXO
        view := c.txPool.GetTransactionUTXO(tx.Tx)
-       if err := c.GetTransactionsUtxo(view, []*bc.Tx{newTx}); err != nil {
-               c.txPool.AddErrCache(&newTx.ID, err)
-               return false, err
+       if err := c.GetTransactionsUtxo(view, []*bc.Tx{tx.Tx}); err != nil {
+               return true, err
        }
-       if err := view.ApplyTransaction(block, newTx, false); err != nil {
+
+       bh := c.BestBlockHeader()
+       block := types.MapBlock(&types.Block{BlockHeader: *bh})
+       if err := view.ApplyTransaction(block, tx.Tx, false); err != nil {
                return true, err
        }
 
-       // validate the BVM contract
-       gasOnlyTx := false
-       gasStatus, err := validation.ValidateTx(newTx, block)
-       if err != nil {
-               if gasStatus == nil || !gasStatus.GasVaild {
-                       c.txPool.AddErrCache(&newTx.ID, err)
-                       return false, err
-               }
-               gasOnlyTx = true
+       gasStatus, err := validation.ValidateTx(tx.Tx, block)
+       if gasStatus.GasVaild == false {
+               c.txPool.AddErrCache(&tx.ID, err)
        }
 
-       _, err = c.txPool.AddTransaction(tx, gasOnlyTx, block.BlockHeader.Height, gasStatus.BTMValue)
+       _, err = c.txPool.AddTransaction(tx, err != nil, block.BlockHeader.Height, gasStatus.BTMValue)
        return false, err
 }
diff --git a/protocol/tx_test.go b/protocol/tx_test.go
deleted file mode 100644 (file)
index c7b060b..0000000
+++ /dev/null
@@ -1,81 +0,0 @@
-package protocol
-
-import (
-       "fmt"
-       "testing"
-
-       "golang.org/x/crypto/sha3"
-
-       "github.com/bytom/crypto/ed25519"
-       "github.com/bytom/protocol/bc"
-       "github.com/bytom/protocol/bc/types"
-       "github.com/bytom/protocol/vm"
-       "github.com/bytom/protocol/vm/vmutil"
-       "github.com/bytom/testutil"
-)
-
-type testDest struct {
-       privKey ed25519.PrivateKey
-}
-
-func newDest(t testing.TB) *testDest {
-       _, priv, err := ed25519.GenerateKey(nil)
-       if err != nil {
-               testutil.FatalErr(t, err)
-       }
-       return &testDest{
-               privKey: priv,
-       }
-}
-
-func (d *testDest) sign(t testing.TB, tx *types.Tx, index uint32) {
-       txsighash := tx.SigHash(index)
-       prog, _ := vm.Assemble(fmt.Sprintf("0x%x TXSIGHASH EQUAL", txsighash.Bytes()))
-       h := sha3.Sum256(prog)
-       sig := ed25519.Sign(d.privKey, h[:])
-       tx.Inputs[index].SetArguments([][]byte{vm.Int64Bytes(0), sig, prog})
-}
-
-func (d testDest) controlProgram() ([]byte, error) {
-       pub := d.privKey.Public().(ed25519.PublicKey)
-       return vmutil.P2SPMultiSigProgram([]ed25519.PublicKey{pub}, 1)
-}
-
-type testAsset struct {
-       bc.AssetID
-       testDest
-}
-
-func newAsset(t testing.TB) *testAsset {
-       dest := newDest(t)
-       cp, _ := dest.controlProgram()
-       assetID := bc.ComputeAssetID(cp, 1, &bc.EmptyStringHash)
-
-       return &testAsset{
-               AssetID:  assetID,
-               testDest: *dest,
-       }
-}
-
-func issue(t testing.TB, asset *testAsset, dest *testDest, amount uint64) (*types.Tx, *testAsset, *testDest) {
-       if asset == nil {
-               asset = newAsset(t)
-       }
-       if dest == nil {
-               dest = newDest(t)
-       }
-       assetCP, _ := asset.controlProgram()
-       destCP, _ := dest.controlProgram()
-       tx := types.NewTx(types.TxData{
-               Version: 1,
-               Inputs: []*types.TxInput{
-                       types.NewIssuanceInput([]byte{1}, amount, assetCP, nil, nil),
-               },
-               Outputs: []*types.TxOutput{
-                       types.NewTxOutput(asset.AssetID, amount, destCP),
-               },
-       })
-       asset.sign(t, tx, 0)
-
-       return tx, asset, dest
-}
diff --git a/protocol/validation.go b/protocol/validation.go
deleted file mode 100644 (file)
index 9e1f112..0000000
+++ /dev/null
@@ -1,144 +0,0 @@
-package protocol
-
-import (
-       "time"
-
-       "github.com/bytom/consensus"
-       "github.com/bytom/consensus/difficulty"
-       "github.com/bytom/errors"
-       "github.com/bytom/protocol/bc"
-       "github.com/bytom/protocol/validation"
-)
-
-var (
-       errBadTimestamp             = errors.New("block timestamp is not in the vaild range")
-       errBadBits                  = errors.New("block bits is invaild")
-       errMismatchedBlock          = errors.New("mismatched block")
-       errMismatchedMerkleRoot     = errors.New("mismatched merkle root")
-       errMismatchedTxStatus       = errors.New("mismatched transaction status")
-       errMismatchedValue          = errors.New("mismatched value")
-       errMisorderedBlockHeight    = errors.New("misordered block height")
-       errMisorderedBlockTime      = errors.New("misordered block time")
-       errNoPrevBlock              = errors.New("no previous block")
-       errOverflow                 = errors.New("arithmetic overflow/underflow")
-       errOverBlockLimit           = errors.New("block's gas is over the limit")
-       errWorkProof                = errors.New("invalid difficulty proof of work")
-       errVersionRegression        = errors.New("version regression")
-       errWrongBlockSize           = errors.New("block size is too big")
-       errWrongTransactionStatus   = errors.New("transaction status is wrong")
-       errWrongCoinbaseTransaction = errors.New("wrong coinbase transaction")
-       errNotStandardTx            = errors.New("gas transaction is not standard transaction")
-)
-
-// ValidateBlock validates a block and the transactions within.
-// It does not run the consensus program; for that, see ValidateBlockSig.
-func (c *Chain) validateBlock(b *bc.Block) error {
-       parent := c.index.GetNode(b.PreviousBlockId)
-       if parent == nil {
-               return errors.WithDetailf(errNoPrevBlock, "height %d", b.Height)
-       }
-       if err := validateBlockAgainstPrev(b, parent); err != nil {
-               return err
-       }
-
-       if !difficulty.CheckProofOfWork(&b.ID, parent.CalcNextSeed(), b.BlockHeader.Bits) {
-               return errWorkProof
-       }
-
-       b.TransactionStatus = bc.NewTransactionStatus()
-       coinbaseValue := consensus.BlockSubsidy(b.BlockHeader.Height)
-       gasUsed := uint64(0)
-       for i, tx := range b.Transactions {
-               gasStatus, err := validation.ValidateTx(tx, b)
-               gasOnlyTx := false
-               if err != nil {
-                       if gasStatus == nil || !gasStatus.GasVaild {
-                               return errors.Wrapf(err, "validity of transaction %d of %d", i, len(b.Transactions))
-                       }
-                       gasOnlyTx = true
-               }
-               b.TransactionStatus.SetStatus(i, gasOnlyTx)
-               coinbaseValue += gasStatus.BTMValue
-               gasUsed += uint64(gasStatus.GasUsed)
-       }
-
-       if gasUsed > consensus.MaxBlockGas {
-               return errOverBlockLimit
-       }
-
-       // check the coinbase output entry value
-       if err := validateCoinbase(b.Transactions[0], coinbaseValue); err != nil {
-               return err
-       }
-
-       txRoot, err := bc.TxMerkleRoot(b.Transactions)
-       if err != nil {
-               return errors.Wrap(err, "computing transaction merkle root")
-       }
-
-       if txRoot != *b.TransactionsRoot {
-               return errors.WithDetailf(errMismatchedMerkleRoot, "computed %x, current block wants %x", txRoot.Bytes(), b.TransactionsRoot.Bytes())
-       }
-
-       txStatusHash, err := bc.TxStatusMerkleRoot(b.TransactionStatus.VerifyStatus)
-       if err != nil {
-               return err
-       }
-
-       if txStatusHash != *b.TransactionStatusHash {
-               return errMismatchedTxStatus
-       }
-       return nil
-}
-
-func validateBlockTime(b *bc.Block, parent *BlockNode) error {
-       if b.Timestamp > uint64(time.Now().Unix())+consensus.MaxTimeOffsetSeconds {
-               return errBadTimestamp
-       }
-
-       if b.Timestamp <= parent.CalcPastMedianTime() {
-               return errBadTimestamp
-       }
-       return nil
-}
-
-func validateCoinbase(tx *bc.Tx, value uint64) error {
-       resultEntry := tx.Entries[*tx.TxHeader.ResultIds[0]]
-       output, ok := resultEntry.(*bc.Output)
-       if !ok {
-               return errors.Wrap(errWrongCoinbaseTransaction, "decode output")
-       }
-
-       if output.Source.Value.Amount != value {
-               return errors.Wrap(errWrongCoinbaseTransaction, "dismatch output value")
-       }
-
-       inputEntry := tx.Entries[tx.InputIDs[0]]
-       input, ok := inputEntry.(*bc.Coinbase)
-       if !ok {
-               return errors.Wrap(errWrongCoinbaseTransaction, "decode input")
-       }
-       if input.Arbitrary != nil && len(input.Arbitrary) > consensus.CoinbaseArbitrarySizeLimit {
-               return errors.Wrap(errWrongCoinbaseTransaction, "coinbase arbitrary is over size")
-       }
-       return nil
-}
-
-func validateBlockAgainstPrev(b *bc.Block, parent *BlockNode) error {
-       if b.Version < parent.version {
-               return errors.WithDetailf(errVersionRegression, "previous block verson %d, current block version %d", parent.version, b.Version)
-       }
-       if b.Height != parent.height+1 {
-               return errors.WithDetailf(errMisorderedBlockHeight, "previous block height %d, current block height %d", parent.height, b.Height)
-       }
-       if b.Bits != parent.CalcNextBits() {
-               return errBadBits
-       }
-       if parent.Hash != *b.PreviousBlockId {
-               return errors.WithDetailf(errMismatchedBlock, "previous block ID %x, current block wants %x", parent.Hash.Bytes(), b.PreviousBlockId.Bytes())
-       }
-       if err := validateBlockTime(b, parent); err != nil {
-               return err
-       }
-       return nil
-}
diff --git a/protocol/validation/block.go b/protocol/validation/block.go
new file mode 100644 (file)
index 0000000..e1f904e
--- /dev/null
@@ -0,0 +1,118 @@
+package validation
+
+import (
+       "time"
+
+       "github.com/bytom/consensus"
+       "github.com/bytom/consensus/difficulty"
+       "github.com/bytom/errors"
+       "github.com/bytom/protocol/bc"
+       "github.com/bytom/protocol/state"
+)
+
+var (
+       errBadTimestamp          = errors.New("block timestamp is not in the vaild range")
+       errBadBits               = errors.New("block bits is invaild")
+       errMismatchedBlock       = errors.New("mismatched block")
+       errMismatchedMerkleRoot  = errors.New("mismatched merkle root")
+       errMisorderedBlockHeight = errors.New("misordered block height")
+       errOverBlockLimit        = errors.New("block's gas is over the limit")
+       errWorkProof             = errors.New("invalid difficulty proof of work")
+       errVersionRegression     = errors.New("version regression")
+)
+
+func checkBlockTime(b *bc.Block, parent *state.BlockNode) error {
+       if b.Timestamp > uint64(time.Now().Unix())+consensus.MaxTimeOffsetSeconds {
+               return errBadTimestamp
+       }
+
+       if b.Timestamp <= parent.CalcPastMedianTime() {
+               return errBadTimestamp
+       }
+       return nil
+}
+
+func checkCoinbaseAmount(b *bc.Block, amount uint64) error {
+       if len(b.Transactions) == 0 {
+               return errors.Wrap(errWrongCoinbaseTransaction, "block is empty")
+       }
+
+       tx := b.Transactions[0]
+       output, err := tx.Output(*tx.TxHeader.ResultIds[0])
+       if err != nil {
+               return err
+       }
+
+       if output.Source.Value.Amount != amount {
+               return errors.Wrap(errWrongCoinbaseTransaction, "dismatch output amount")
+       }
+       return nil
+}
+
+// ValidateBlockHeader check the block's header
+func ValidateBlockHeader(b *bc.Block, parent *state.BlockNode) error {
+       if b.Version < parent.Version {
+               return errors.WithDetailf(errVersionRegression, "previous block verson %d, current block version %d", parent.Version, b.Version)
+       }
+       if b.Height != parent.Height+1 {
+               return errors.WithDetailf(errMisorderedBlockHeight, "previous block height %d, current block height %d", parent.Height, b.Height)
+       }
+       if b.Bits != parent.CalcNextBits() {
+               return errBadBits
+       }
+       if parent.Hash != *b.PreviousBlockId {
+               return errors.WithDetailf(errMismatchedBlock, "previous block ID %x, current block wants %x", parent.Hash.Bytes(), b.PreviousBlockId.Bytes())
+       }
+       if err := checkBlockTime(b, parent); err != nil {
+               return err
+       }
+       if !difficulty.CheckProofOfWork(&b.ID, parent.CalcNextSeed(), b.BlockHeader.Bits) {
+               return errWorkProof
+       }
+       return nil
+}
+
+// ValidateBlock validates a block and the transactions within.
+func ValidateBlock(b *bc.Block, parent *state.BlockNode) error {
+       if err := ValidateBlockHeader(b, parent); err != nil {
+               return err
+       }
+
+       blockGasSum := uint64(0)
+       coinbaseAmount := consensus.BlockSubsidy(b.BlockHeader.Height)
+       b.TransactionStatus = bc.NewTransactionStatus()
+
+       for i, tx := range b.Transactions {
+               gasStatus, err := ValidateTx(tx, b)
+               if !gasStatus.GasVaild {
+                       return errors.Wrapf(err, "validate of transaction %d of %d", i, len(b.Transactions))
+               }
+
+               b.TransactionStatus.SetStatus(i, err != nil)
+               coinbaseAmount += gasStatus.BTMValue
+               if blockGasSum += uint64(gasStatus.GasUsed); blockGasSum > consensus.MaxBlockGas {
+                       return errOverBlockLimit
+               }
+       }
+
+       if err := checkCoinbaseAmount(b, coinbaseAmount); err != nil {
+               return err
+       }
+
+       txMerkleRoot, err := bc.TxMerkleRoot(b.Transactions)
+       if err != nil {
+               return errors.Wrap(err, "computing transaction id merkle root")
+       }
+       if txMerkleRoot != *b.TransactionsRoot {
+               return errors.WithDetailf(errMismatchedMerkleRoot, "transaction id merkle root")
+       }
+
+       txStatusHash, err := bc.TxStatusMerkleRoot(b.TransactionStatus.VerifyStatus)
+       if err != nil {
+               return errors.Wrap(err, "computing transaction status merkle root")
+       }
+       if txStatusHash != *b.TransactionStatusHash {
+               return errors.WithDetailf(errMismatchedMerkleRoot, "transaction status merkle root")
+       }
+       return nil
+}
index b50ef96..d4f6de8 100644 (file)
@@ -3,30 +3,87 @@ package validation
 import (
        "testing"
 
+       "github.com/bytom/consensus"
        "github.com/bytom/protocol/bc"
        "github.com/bytom/protocol/bc/types"
+       "github.com/bytom/protocol/state"
 )
 
-func dummyValidateTx(*bc.Tx) error {
-       return nil
+func TestCheckBlockTime(t *testing.T) {
+       cases := []struct {
+               blockTime  uint64
+               parentTime uint64
+               err        error
+       }{
+               {
+                       blockTime:  1520000001,
+                       parentTime: 1520000000,
+                       err:        nil,
+               },
+               {
+                       blockTime:  1510000000,
+                       parentTime: 1520000000,
+                       err:        errBadTimestamp,
+               },
+               {
+                       blockTime:  9999999999,
+                       parentTime: 1520000000,
+                       err:        errBadTimestamp,
+               },
+       }
+
+       parent := &state.BlockNode{}
+       block := &bc.Block{
+               BlockHeader: &bc.BlockHeader{},
+       }
+
+       for i, c := range cases {
+               parent.Timestamp = c.parentTime
+               block.Timestamp = c.blockTime
+               if err := checkBlockTime(block, parent); rootErr(err) != c.err {
+                       t.Errorf("case %d got error %s, want %s", i, err, c.err)
+               }
+       }
 }
 
-func generate(tb testing.TB, prev *bc.Block) *bc.Block {
-       b := &types.Block{
-               BlockHeader: types.BlockHeader{
-                       Version:           1,
-                       Height:            prev.Height + 1,
-                       PreviousBlockHash: prev.ID,
-                       Timestamp:         prev.Timestamp + 1,
-                       BlockCommitment:   types.BlockCommitment{},
+func TestCheckCoinbaseAmount(t *testing.T) {
+       cases := []struct {
+               txs    []*types.Tx
+               amount uint64
+               err    error
+       }{
+               {
+                       txs: []*types.Tx{
+                               types.NewTx(types.TxData{
+                                       Inputs:  []*types.TxInput{types.NewCoinbaseInput(nil)},
+                                       Outputs: []*types.TxOutput{types.NewTxOutput(*consensus.BTMAssetID, 5000, nil)},
+                               }),
+                       },
+                       amount: 5000,
+                       err:    nil,
+               },
+               {
+                       txs: []*types.Tx{
+                               types.NewTx(types.TxData{
+                                       Inputs:  []*types.TxInput{types.NewCoinbaseInput(nil)},
+                                       Outputs: []*types.TxOutput{types.NewTxOutput(*consensus.BTMAssetID, 5000, nil)},
+                               }),
+                       },
+                       amount: 6000,
+                       err:    errWrongCoinbaseTransaction,
+               },
+               {
+                       txs:    []*types.Tx{},
+                       amount: 5000,
+                       err:    errWrongCoinbaseTransaction,
                },
        }
 
-       var err error
-       b.TransactionsMerkleRoot, err = bc.TxMerkleRoot(nil)
-       if err != nil {
-               tb.Fatal(err)
+       block := new(types.Block)
+       for i, c := range cases {
+               block.Transactions = c.txs
+               if err := checkCoinbaseAmount(types.MapBlock(block), c.amount); rootErr(err) != c.err {
+                       t.Errorf("case %d got error %s, want %s", i, err, c.err)
+               }
        }
-
-       return types.MapBlock(b)
 }
similarity index 65%
rename from protocol/validation/validation.go
rename to protocol/validation/tx.go
index 7dd1ab7..ce5f190 100644 (file)
@@ -11,11 +11,8 @@ import (
        "github.com/bytom/protocol/vm"
 )
 
-const (
-       muxGasCost = int64(10)
-       // timeRangeGash is the block height we will reach after 100 years
-       timeRangeGash = uint64(21024000)
-)
+// timeRangeGash is the block height we will reach after 100 years
+const timeRangeGash = uint64(21024000)
 
 // GasState record the gas usage status
 type GasState struct {
@@ -23,7 +20,7 @@ type GasState struct {
        GasLeft    int64
        GasUsed    int64
        GasVaild   bool
-       storageGas int64
+       StorageGas int64
 }
 
 func (g *GasState) setGas(BTMValue int64, txSize int64) error {
@@ -33,11 +30,6 @@ func (g *GasState) setGas(BTMValue int64, txSize int64) error {
 
        g.BTMValue = uint64(BTMValue)
 
-       if BTMValue == 0 {
-               g.GasLeft = muxGasCost
-               return nil
-       }
-
        var ok bool
        if g.GasLeft, ok = checked.DivInt64(BTMValue, consensus.VMGasRate); !ok {
                return errors.Wrap(errGasCalculate, "setGas calc gas amount")
@@ -47,7 +39,7 @@ func (g *GasState) setGas(BTMValue int64, txSize int64) error {
                g.GasLeft = consensus.MaxGasAmount
        }
 
-       if g.storageGas, ok = checked.MulInt64(txSize, consensus.StorageGasRate); !ok {
+       if g.StorageGas, ok = checked.MulInt64(txSize, consensus.StorageGasRate); !ok {
                return errors.Wrap(errGasCalculate, "setGas calc tx storage gas")
        }
        return nil
@@ -55,11 +47,11 @@ func (g *GasState) setGas(BTMValue int64, txSize int64) error {
 
 func (g *GasState) setGasVaild() error {
        var ok bool
-       if g.GasLeft, ok = checked.SubInt64(g.GasLeft, g.storageGas); !ok || g.GasLeft < 0 {
+       if g.GasLeft, ok = checked.SubInt64(g.GasLeft, g.StorageGas); !ok || g.GasLeft < 0 {
                return errors.Wrap(errGasCalculate, "setGasVaild calc gasLeft")
        }
 
-       if g.GasUsed, ok = checked.AddInt64(g.GasUsed, g.storageGas); !ok {
+       if g.GasUsed, ok = checked.AddInt64(g.GasUsed, g.StorageGas); !ok {
                return errors.Wrap(errGasCalculate, "setGasVaild calc gasUsed")
        }
 
@@ -79,7 +71,7 @@ func (g *GasState) updateUsage(gasLeft int64) error {
                return errors.Wrap(errGasCalculate, "updateUsage calc gas diff")
        }
 
-       if !g.GasVaild && (g.GasUsed > consensus.DefaultGasCredit || g.storageGas > g.GasLeft) {
+       if !g.GasVaild && (g.GasUsed > consensus.DefaultGasCredit || g.StorageGas > g.GasLeft) {
                return errOverGasCredit
        }
        return nil
@@ -88,64 +80,40 @@ func (g *GasState) updateUsage(gasLeft int64) error {
 // validationState contains the context that must propagate through
 // the transaction graph when validating entries.
 type validationState struct {
-       // The ID of the blockchain
-       block *bc.Block
-
-       // The enclosing transaction object
-       tx *bc.Tx
-
-       // The ID of the nearest enclosing entry
-       entryID bc.Hash
-
-       // The source position, for validating ValueSources
-       sourcePos uint64
-
-       // The destination position, for validating ValueDestinations
-       destPos uint64
-
-       // Memoized per-entry validation results
-       cache map[bc.Hash]error
-
+       block     *bc.Block
+       tx        *bc.Tx
        gasStatus *GasState
+       entryID   bc.Hash           // The ID of the nearest enclosing entry
+       sourcePos uint64            // The source position, for validate ValueSources
+       destPos   uint64            // The destination position, for validate ValueDestinations
+       cache     map[bc.Hash]error // Memoized per-entry validation results
 }
 
 var (
-       errGasCalculate             = errors.New("gas usage calculate got a math error")
-       errEmptyResults             = errors.New("transaction has no results")
-       errMismatchedAssetID        = errors.New("mismatched asset id")
-       errMismatchedBlock          = errors.New("mismatched block")
-       errMismatchedMerkleRoot     = errors.New("mismatched merkle root")
-       errMismatchedPosition       = errors.New("mismatched value source/dest positions")
-       errMismatchedReference      = errors.New("mismatched reference")
-       errMismatchedTxStatus       = errors.New("mismatched transaction status")
-       errMismatchedValue          = errors.New("mismatched value")
-       errMisorderedBlockHeight    = errors.New("misordered block height")
-       errMisorderedBlockTime      = errors.New("misordered block time")
-       errMissingField             = errors.New("missing required field")
-       errNoGas                    = errors.New("no gas input")
-       errNoPrevBlock              = errors.New("no previous block")
-       errNoSource                 = errors.New("no source for value")
-       errNonemptyExtHash          = errors.New("non-empty extension hash")
-       errOverflow                 = errors.New("arithmetic overflow/underflow")
-       errOverGasCredit            = errors.New("all gas credit has been spend")
-       errOverBlockLimit           = errors.New("block's gas is over the limit")
-       errPosition                 = errors.New("invalid source or destination position")
-       errWorkProof                = errors.New("invalid difficulty proof of work")
-       errTxVersion                = errors.New("invalid transaction version")
-       errUnbalanced               = errors.New("unbalanced")
-       errUntimelyTransaction      = errors.New("block timestamp outside transaction time range")
-       errVersionRegression        = errors.New("version regression")
-       errWrongBlockSize           = errors.New("block size is too big")
-       errWrongTransactionSize     = errors.New("transaction size is not in vaild range")
-       errWrongTransactionStatus   = errors.New("transaction status is wrong")
-       errWrongCoinbaseTransaction = errors.New("wrong coinbase transaction")
-       errWrongCoinbaseAsset       = errors.New("wrong coinbase asset id")
-       errNotStandardTx            = errors.New("gas transaction is not standard transaction")
+       errBadTimeRange              = errors.New("tx time range is invalid")
+       errCoinbaseArbitraryOversize = errors.New("coinbase arbitrary size is larger than limit")
+       errGasCalculate              = errors.New("gas usage calculate got a math error")
+       errEmptyResults              = errors.New("transaction has no results")
+       errMismatchedAssetID         = errors.New("mismatched asset id")
+       errMismatchedPosition        = errors.New("mismatched value source/dest positions")
+       errMismatchedReference       = errors.New("mismatched reference")
+       errMismatchedValue           = errors.New("mismatched value")
+       errMissingField              = errors.New("missing required field")
+       errNoSource                  = errors.New("no source for value")
+       errOverflow                  = errors.New("arithmetic overflow/underflow")
+       errOverGasCredit             = errors.New("all gas credit has been spend")
+       errPosition                  = errors.New("invalid source or destination position")
+       errTxVersion                 = errors.New("invalid transaction version")
+       errUnbalanced                = errors.New("unbalanced")
+       errWrongTransactionSize      = errors.New("transaction size is not in vaild range")
+       errWrongCoinbaseTransaction  = errors.New("wrong coinbase transaction")
+       errWrongCoinbaseAsset        = errors.New("wrong coinbase asset id")
+       errNotStandardTx             = errors.New("gas transaction is not standard transaction")
 )
 
 func checkValid(vs *validationState, e bc.Entry) (err error) {
-       entryID := bc.EntryID(e)
        var ok bool
+       entryID := bc.EntryID(e)
        if err, ok = vs.cache[entryID]; ok {
                return err
        }
@@ -156,37 +124,17 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
 
        switch e := e.(type) {
        case *bc.TxHeader:
-
                for i, resID := range e.ResultIds {
                        resultEntry := vs.tx.Entries[*resID]
                        vs2 := *vs
                        vs2.entryID = *resID
-                       err = checkValid(&vs2, resultEntry)
-                       if err != nil {
+                       if err = checkValid(&vs2, resultEntry); err != nil {
                                return errors.Wrapf(err, "checking result %d", i)
                        }
                }
 
-               if e.Version == 1 {
-                       if len(e.ResultIds) == 0 {
-                               return errEmptyResults
-                       }
-               }
-
-       case *bc.Coinbase:
-               if vs.block == nil || len(vs.block.Transactions) == 0 || vs.block.Transactions[0] != vs.tx {
-                       return errWrongCoinbaseTransaction
-               }
-
-               if *e.WitnessDestination.Value.AssetId != *consensus.BTMAssetID {
-                       return errWrongCoinbaseAsset
-               }
-
-               vs2 := *vs
-               vs2.destPos = 0
-               err = checkValidDest(&vs2, e.WitnessDestination)
-               if err != nil {
-                       return errors.Wrap(err, "checking coinbase destination")
+               if e.Version == 1 && len(e.ResultIds) == 0 {
+                       return errEmptyResults
                }
 
        case *bc.Mux:
@@ -222,14 +170,6 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                        }
                }
 
-               gasLeft, err := vm.Verify(NewTxVMContext(vs, e, e.Program, e.WitnessArguments), vs.gasStatus.GasLeft)
-               if err != nil {
-                       return errors.Wrap(err, "checking mux program")
-               }
-               if err = vs.gasStatus.updateUsage(gasLeft); err != nil {
-                       return err
-               }
-
                for _, BTMInputID := range vs.tx.GasInputIDs {
                        e, ok := vs.tx.Entries[BTMInputID]
                        if !ok {
@@ -239,28 +179,28 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                        vs2 := *vs
                        vs2.entryID = BTMInputID
                        if err := checkValid(&vs2, e); err != nil {
-                               return errors.Wrap(err, "checking value source")
+                               return errors.Wrap(err, "checking gas input")
                        }
                }
 
                for i, dest := range e.WitnessDestinations {
                        vs2 := *vs
                        vs2.destPos = uint64(i)
-                       err = checkValidDest(&vs2, dest)
-                       if err != nil {
+                       if err = checkValidDest(&vs2, dest); err != nil {
                                return errors.Wrapf(err, "checking mux destination %d", i)
                        }
                }
 
-               if err := vs.gasStatus.setGasVaild(); err != nil {
-                       return err
+               if len(vs.tx.GasInputIDs) > 0 {
+                       if err := vs.gasStatus.setGasVaild(); err != nil {
+                               return err
+                       }
                }
 
                for i, src := range e.Sources {
                        vs2 := *vs
                        vs2.sourcePos = uint64(i)
-                       err = checkValidSrc(&vs2, src)
-                       if err != nil {
+                       if err = checkValidSrc(&vs2, src); err != nil {
                                return errors.Wrapf(err, "checking mux source %d", i)
                        }
                }
@@ -268,16 +208,14 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
        case *bc.Output:
                vs2 := *vs
                vs2.sourcePos = 0
-               err = checkValidSrc(&vs2, e.Source)
-               if err != nil {
+               if err = checkValidSrc(&vs2, e.Source); err != nil {
                        return errors.Wrap(err, "checking output source")
                }
 
        case *bc.Retirement:
                vs2 := *vs
                vs2.sourcePos = 0
-               err = checkValidSrc(&vs2, e.Source)
-               if err != nil {
+               if err = checkValidSrc(&vs2, e.Source); err != nil {
                        return errors.Wrap(err, "checking retirement source")
                }
 
@@ -297,8 +235,7 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
 
                destVS := *vs
                destVS.destPos = 0
-               err = checkValidDest(&destVS, e.WitnessDestination)
-               if err != nil {
+               if err = checkValidDest(&destVS, e.WitnessDestination); err != nil {
                        return errors.Wrap(err, "checking issuance destination")
                }
 
@@ -310,6 +247,7 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
                if err != nil {
                        return errors.Wrap(err, "getting spend prevout")
                }
+
                gasLeft, err := vm.Verify(NewTxVMContext(vs, e, spentOutput.ControlProgram, e.WitnessArguments), vs.gasStatus.GasLeft)
                if err != nil {
                        return errors.Wrap(err, "checking control program")
@@ -335,11 +273,32 @@ func checkValid(vs *validationState, e bc.Entry) (err error) {
 
                vs2 := *vs
                vs2.destPos = 0
-               err = checkValidDest(&vs2, e.WitnessDestination)
-               if err != nil {
+               if err = checkValidDest(&vs2, e.WitnessDestination); err != nil {
                        return errors.Wrap(err, "checking spend destination")
                }
 
+       case *bc.Coinbase:
+               if vs.block == nil || len(vs.block.Transactions) == 0 || vs.block.Transactions[0] != vs.tx {
+                       return errWrongCoinbaseTransaction
+               }
+
+               if *e.WitnessDestination.Value.AssetId != *consensus.BTMAssetID {
+                       return errWrongCoinbaseAsset
+               }
+
+               if e.Arbitrary != nil && len(e.Arbitrary) > consensus.CoinbaseArbitrarySizeLimit {
+                       return errCoinbaseArbitraryOversize
+               }
+
+               vs2 := *vs
+               vs2.destPos = 0
+               if err = checkValidDest(&vs2, e.WitnessDestination); err != nil {
+                       return errors.Wrap(err, "checking coinbase destination")
+               }
+
+               // special case for coinbase transaction, it's valid unit all the verify has been passed
+               vs.gasStatus.GasVaild = true
+
        default:
                return fmt.Errorf("entry has unexpected type %T", e)
        }
@@ -362,10 +321,10 @@ func checkValidSrc(vstate *validationState, vs *bc.ValueSource) error {
        if !ok {
                return errors.Wrapf(bc.ErrMissingEntry, "entry for value source %x not found", vs.Ref.Bytes())
        }
+
        vstate2 := *vstate
        vstate2.entryID = *vs.Ref
-       err := checkValid(&vstate2, e)
-       if err != nil {
+       if err := checkValid(&vstate2, e); err != nil {
                return errors.Wrap(err, "checking value source")
        }
 
@@ -376,6 +335,7 @@ func checkValidSrc(vstate *validationState, vs *bc.ValueSource) error {
                        return errors.Wrapf(errPosition, "invalid position %d for coinbase source", vs.Position)
                }
                dest = ref.WitnessDestination
+
        case *bc.Issuance:
                if vs.Position != 0 {
                        return errors.Wrapf(errPosition, "invalid position %d for issuance source", vs.Position)
@@ -432,6 +392,7 @@ func checkValidDest(vs *validationState, vd *bc.ValueDestination) error {
        if !ok {
                return errors.Wrapf(bc.ErrMissingEntry, "entry for value destination %x not found", vd.Ref.Bytes())
        }
+
        var src *bc.ValueSource
        switch ref := e.(type) {
        case *bc.Output:
@@ -475,82 +436,78 @@ func checkValidDest(vs *validationState, vd *bc.ValueDestination) error {
        return nil
 }
 
-func validateStandardTx(tx *bc.Tx) error {
-       for _, id := range tx.InputIDs {
-               e, ok := tx.Entries[id]
-               if !ok {
-                       return errors.New("miss tx input entry")
+func checkStandardTx(tx *bc.Tx) error {
+       for _, id := range tx.GasInputIDs {
+               spend, err := tx.Spend(id)
+               if err != nil {
+                       return err
+               }
+               spentOutput, err := tx.Output(*spend.SpentOutputId)
+               if err != nil {
+                       return err
                }
-               if spend, ok := e.(*bc.Spend); ok {
-                       if *spend.WitnessDestination.Value.AssetId != *consensus.BTMAssetID {
-                               continue
-                       }
-                       spentOutput, err := tx.Output(*spend.SpentOutputId)
-                       if err != nil {
-                               return errors.Wrap(err, "getting spend prevout")
-                       }
 
-                       if !segwit.IsP2WScript(spentOutput.ControlProgram.Code) {
-                               return errNotStandardTx
-                       }
+               if !segwit.IsP2WScript(spentOutput.ControlProgram.Code) {
+                       return errNotStandardTx
                }
        }
 
        for _, id := range tx.ResultIds {
                e, ok := tx.Entries[*id]
                if !ok {
-                       return errors.New("miss tx output entry")
+                       return errors.Wrapf(bc.ErrMissingEntry, "id %x", id.Bytes())
                }
-               if output, ok := e.(*bc.Output); ok {
-                       if *output.Source.Value.AssetId != *consensus.BTMAssetID {
-                               continue
-                       }
-                       if !segwit.IsP2WScript(output.ControlProgram.Code) {
-                               return errNotStandardTx
-                       }
+
+               output, ok := e.(*bc.Output)
+               if !ok || *output.Source.Value.AssetId != *consensus.BTMAssetID {
+                       continue
+               }
+
+               if !segwit.IsP2WScript(output.ControlProgram.Code) {
+                       return errNotStandardTx
                }
        }
        return nil
 }
 
-// ValidateTx validates a transaction.
-func ValidateTx(tx *bc.Tx, block *bc.Block) (*GasState, error) {
-       if block.Version == 1 && tx.Version != 1 {
-               return nil, errors.WithDetailf(errTxVersion, "block version %d, transaction version %d", block.Version, tx.Version)
+func checkTimeRange(tx *bc.Tx, block *bc.Block) error {
+       if tx.TimeRange == 0 {
+               return nil
        }
 
-       if tx.TimeRange > timeRangeGash && tx.TimeRange < block.Timestamp {
-               return nil, errors.New("transaction max timestamp is lower than block's")
-       } else if tx.TimeRange != 0 && tx.TimeRange < block.Height {
-               return nil, errors.New("transaction max block height is lower than block's")
+       blockVal := block.Height
+       if tx.TimeRange > timeRangeGash {
+               blockVal = block.Timestamp
        }
 
-       if tx.TxHeader.SerializedSize > consensus.MaxTxSize || tx.TxHeader.SerializedSize == 0 {
-               return nil, errWrongTransactionSize
+       if tx.TimeRange < blockVal {
+               return errBadTimeRange
        }
+       return nil
+}
 
-       if len(tx.ResultIds) == 0 {
-               return nil, errors.New("tx didn't have any output")
+// ValidateTx validates a transaction.
+func ValidateTx(tx *bc.Tx, block *bc.Block) (*GasState, error) {
+       gasStatus := &GasState{GasVaild: false}
+       if block.Version == 1 && tx.Version != 1 {
+               return gasStatus, errors.WithDetailf(errTxVersion, "block version %d, transaction version %d", block.Version, tx.Version)
        }
-
-       if len(tx.GasInputIDs) == 0 && tx != block.Transactions[0] {
-               return nil, errors.New("tx didn't have gas input")
+       if tx.SerializedSize == 0 {
+               return gasStatus, errWrongTransactionSize
        }
-
-       if err := validateStandardTx(tx); err != nil {
-               return nil, err
+       if err := checkTimeRange(tx, block); err != nil {
+               return gasStatus, err
+       }
+       if err := checkStandardTx(tx); err != nil {
+               return gasStatus, err
        }
 
        vs := &validationState{
-               block:   block,
-               tx:      tx,
-               entryID: tx.ID,
-               gasStatus: &GasState{
-                       GasVaild: false,
-               },
-               cache: make(map[bc.Hash]error),
+               block:     block,
+               tx:        tx,
+               entryID:   tx.ID,
+               gasStatus: gasStatus,
+               cache:     make(map[bc.Hash]error),
        }
-
-       err := checkValid(vs, tx.TxHeader)
-       return vs.gasStatus, err
+       return vs.gasStatus, checkValid(vs, tx.TxHeader)
 }
similarity index 95%
rename from protocol/validation/validation_test.go
rename to protocol/validation/tx_test.go
index 8a5f80a..dd2efb5 100644 (file)
@@ -1,12 +1,10 @@
 package validation
 
 import (
-       "fmt"
        "math"
        "testing"
 
        "github.com/davecgh/go-spew/spew"
-       "github.com/golang/protobuf/proto"
 
        "github.com/bytom/consensus"
        "github.com/bytom/crypto/sha3pool"
@@ -141,13 +139,6 @@ func TestTxValidation(t *testing.T) {
                        desc: "base case",
                },
                {
-                       desc: "failing mux program",
-                       f: func() {
-                               mux.Program.Code = []byte{byte(vm.OP_FALSE)}
-                       },
-                       err: vm.ErrFalseVMResult,
-               },
-               {
                        desc: "unbalanced mux amounts",
                        f: func() {
                                mux.Sources[0].Value.Amount++
@@ -389,34 +380,6 @@ func TestTimeRange(t *testing.T) {
        }
 }
 
-func TestBlockHeaderValid(t *testing.T) {
-       base := bc.NewBlockHeader(1, 1, &bc.Hash{}, 1, &bc.Hash{}, &bc.Hash{}, 0, 0)
-       baseBytes, _ := proto.Marshal(base)
-
-       var bh bc.BlockHeader
-
-       cases := []struct {
-               f   func()
-               err error
-       }{
-               {},
-               {
-                       f: func() {
-                               bh.Version = 2
-                       },
-               },
-       }
-
-       for i, c := range cases {
-               t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
-                       proto.Unmarshal(baseBytes, &bh)
-                       if c.f != nil {
-                               c.f()
-                       }
-               })
-       }
-}
-
 // A txFixture is returned by sample (below) to produce a sample
 // transaction, which takes a separate, optional _input_ txFixture to
 // affect the transaction that's built. The components of the
index c9e3996..3325297 100644 (file)
@@ -79,15 +79,11 @@ func TestCheckOutput(t *testing.T) {
                t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
                        gotOk, err := txCtx.checkOutput(test.index, test.amount, test.assetID, test.vmVersion, test.code, false)
                        if g := errors.Root(err); g != test.wantErr {
-                               t.Errorf("checkOutput(%v, %v, %x, %v, %x) err = %v, want %v",
-                                       test.index, test.amount, test.assetID, test.vmVersion, test.code,
-                                       g, test.wantErr)
+                               t.Errorf("checkOutput(%v, %v, %x, %v, %x) err = %v, want %v", test.index, test.amount, test.assetID, test.vmVersion, test.code, g, test.wantErr)
                                return
                        }
                        if gotOk != test.wantOk {
-                               t.Errorf("checkOutput(%v, %v, %x, %v, %x) ok = %t, want %v",
-                                       test.index, test.amount, test.assetID, test.vmVersion, test.code,
-                                       gotOk, test.wantOk)
+                               t.Errorf("checkOutput(%v, %v, %x, %v, %x) ok = %t, want %v", test.index, test.amount, test.assetID, test.vmVersion, test.code, gotOk, test.wantOk)
                        }
 
                })
index 6c3a619..b58335a 100644 (file)
@@ -7,7 +7,6 @@ import (
        dbm "github.com/tendermint/tmlibs/db"
 
        "github.com/bytom/account"
-       "github.com/bytom/config"
        "github.com/bytom/mining"
        "github.com/bytom/test"
 )
@@ -24,9 +23,6 @@ func BenchmarkNewBlockTpl(b *testing.B) {
        accountManager := account.NewManager(testDB, chain)
 
        txPool := test.MockTxPool()
-       genesisBlock := config.GenerateGenesisBlock()
-       chain.SaveBlock(genesisBlock)
-       chain.ConnectBlock(genesisBlock)
 
        b.ResetTimer()
        for i := 0; i < b.N; i++ {