OSDN Git Service

netsync add test case (#365)
[bytom/vapor.git] / test / mock / chain.go
index b1601b1..69ef7c4 100644 (file)
@@ -4,23 +4,34 @@ import (
        "errors"
        "math/rand"
 
+       "github.com/vapor/protocol"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
 )
 
+var (
+       ErrFoundHeaderByHash   = errors.New("can't find header by hash")
+       ErrFoundHeaderByHeight = errors.New("can't find header by height")
+)
+
+type mempool interface {
+       AddTx(tx *types.Tx)
+}
+
 type Chain struct {
        bestBlockHeader *types.BlockHeader
        heightMap       map[uint64]*types.Block
        blockMap        map[bc.Hash]*types.Block
-
-       prevOrphans map[bc.Hash]*types.Block
+       prevOrphans     map[bc.Hash]*types.Block
+       mempool         mempool
 }
 
-func NewChain() *Chain {
+func NewChain(mempool *Mempool) *Chain {
        return &Chain{
                heightMap:   map[uint64]*types.Block{},
                blockMap:    map[bc.Hash]*types.Block{},
                prevOrphans: make(map[bc.Hash]*types.Block),
+               mempool:     mempool,
        }
 }
 
@@ -32,6 +43,10 @@ func (c *Chain) BestBlockHeight() uint64 {
        return c.bestBlockHeader.Height
 }
 
+func (c *Chain) LastIrreversibleHeader() *types.BlockHeader {
+       return c.bestBlockHeader
+}
+
 func (c *Chain) CalcNextSeed(hash *bc.Hash) (*bc.Hash, error) {
        return &bc.Hash{V0: hash.V1, V1: hash.V2, V2: hash.V3, V3: hash.V0}, nil
 }
@@ -55,7 +70,7 @@ func (c *Chain) GetBlockByHeight(height uint64) (*types.Block, error) {
 func (c *Chain) GetHeaderByHash(hash *bc.Hash) (*types.BlockHeader, error) {
        block, ok := c.blockMap[*hash]
        if !ok {
-               return nil, errors.New("can't find block")
+               return nil, ErrFoundHeaderByHash
        }
        return &block.BlockHeader, nil
 }
@@ -63,7 +78,7 @@ func (c *Chain) GetHeaderByHash(hash *bc.Hash) (*types.BlockHeader, error) {
 func (c *Chain) GetHeaderByHeight(height uint64) (*types.BlockHeader, error) {
        block, ok := c.heightMap[height]
        if !ok {
-               return nil, errors.New("can't find block")
+               return nil, ErrFoundHeaderByHeight
        }
        return &block.BlockHeader, nil
 }
@@ -98,6 +113,10 @@ func (c *Chain) InMainChain(hash bc.Hash) bool {
 }
 
 func (c *Chain) ProcessBlock(block *types.Block) (bool, error) {
+       if block.TransactionsMerkleRoot == bc.NewHash([32]byte{0x1}) {
+               return false, protocol.ErrBadStateRoot
+       }
+
        if c.bestBlockHeader.Hash() == block.PreviousBlockHash {
                c.heightMap[block.Height] = block
                c.blockMap[block.Hash()] = block
@@ -137,6 +156,7 @@ func (c *Chain) SetBlockByHeight(height uint64, block *types.Block) {
        c.blockMap[block.Hash()] = block
 }
 
-func (c *Chain) ValidateTx(*types.Tx) (bool, error) {
+func (c *Chain) ValidateTx(tx *types.Tx) (bool, error) {
+       c.mempool.AddTx(tx)
        return false, nil
 }