OSDN Git Service

spv merkle tree proof (#1262)
authormuscle_boy <shenao.78@163.com>
Thu, 23 Aug 2018 02:50:08 +0000 (10:50 +0800)
committerPaladz <yzhu101@uottawa.ca>
Thu, 23 Aug 2018 02:50:08 +0000 (10:50 +0800)
* the transaction output amout prohibit set zero

* add network access control api

* format import code style

* refactor

* code refactor

* bug fix

* the struct node_info add json field

* estimate gas support multi-sign

* add testcase of estimate gas

* add testcase

* bug fix

* add test case

* test case refactor

* list-tx,list-address,list-utxo support partition

* list-addresses list-tx list-utxo support pagging

* refactor pagging

* fix save asset

* fix save external assets

* remove blank

* remove useless context

* remove redudant web address config

* fix bug

* remove useless ctx

* add spv message struct

* remove redundant

* refactor message struct

* refactor message struct

* add filter load message handler

* add debug log

* bug fix spv

* bug fix

* bug fix

* refactor

* refactor

* add merkle proof

* add merkle flags test case

* add multiset

* bug fix and refactor

* bug fix

* remove redundant code

* bug fix

* bug fix

* format code

* refactor merkle tree

* refactor

* refactor

* fix bug for make test

* bug fix

* move merkle tree to bc level

* NewMinedBlockMessage not broadcast to the spv node

* refactor

* refactor

* refactor

* merkle tree bug fix

* merkle tree bug fix

14 files changed:
config/genesis.go
consensus/server_flag.go
mining/mining.go
netsync/handle.go
netsync/message.go
netsync/peer.go
protocol/bc/merkle.go [deleted file]
protocol/bc/merkle_test.go [deleted file]
protocol/bc/types/merkle.go [new file with mode: 0644]
protocol/bc/types/merkle_test.go [new file with mode: 0644]
protocol/validation/block.go
test/block_test_util.go
test/chain_test_util.go
test/mock/chain.go

index 38ee157..4ca222b 100644 (file)
@@ -32,12 +32,12 @@ func mainNetGenesisBlock() *types.Block {
        tx := genesisTx()
        txStatus := bc.NewTransactionStatus()
        txStatus.SetStatus(0, false)
-       txStatusHash, err := bc.TxStatusMerkleRoot(txStatus.VerifyStatus)
+       txStatusHash, err := types.TxStatusMerkleRoot(txStatus.VerifyStatus)
        if err != nil {
                log.Panicf("fail on calc genesis tx status merkle root")
        }
 
-       merkleRoot, err := bc.TxMerkleRoot([]*bc.Tx{tx.Tx})
+       merkleRoot, err := types.TxMerkleRoot([]*bc.Tx{tx.Tx})
        if err != nil {
                log.Panicf("fail on calc genesis tx merkel root")
        }
@@ -63,12 +63,12 @@ func testNetGenesisBlock() *types.Block {
        tx := genesisTx()
        txStatus := bc.NewTransactionStatus()
        txStatus.SetStatus(0, false)
-       txStatusHash, err := bc.TxStatusMerkleRoot(txStatus.VerifyStatus)
+       txStatusHash, err := types.TxStatusMerkleRoot(txStatus.VerifyStatus)
        if err != nil {
                log.Panicf("fail on calc genesis tx status merkle root")
        }
 
-       merkleRoot, err := bc.TxMerkleRoot([]*bc.Tx{tx.Tx})
+       merkleRoot, err := types.TxMerkleRoot([]*bc.Tx{tx.Tx})
        if err != nil {
                log.Panicf("fail on calc genesis tx merkel root")
        }
@@ -94,12 +94,12 @@ func soloNetGenesisBlock() *types.Block {
        tx := genesisTx()
        txStatus := bc.NewTransactionStatus()
        txStatus.SetStatus(0, false)
-       txStatusHash, err := bc.TxStatusMerkleRoot(txStatus.VerifyStatus)
+       txStatusHash, err := types.TxStatusMerkleRoot(txStatus.VerifyStatus)
        if err != nil {
                log.Panicf("fail on calc genesis tx status merkle root")
        }
 
-       merkleRoot, err := bc.TxMerkleRoot([]*bc.Tx{tx.Tx})
+       merkleRoot, err := types.TxMerkleRoot([]*bc.Tx{tx.Tx})
        if err != nil {
                log.Panicf("fail on calc genesis tx merkel root")
        }
index 5346026..1679908 100644 (file)
@@ -9,8 +9,10 @@ const (
        SFFullNode ServiceFlag = 1 << iota
        // SFFastSync indicate peer support header first mode
        SFFastSync
+       // SFSPV indicate peer support spv mode
+       SFSPV
        // DefaultServices is the server that this node support
-       DefaultServices = SFFullNode | SFFastSync
+       DefaultServices = SFFullNode | SFFastSync | SFSPV
 )
 
 // IsEnable check does the flag support the input flag function
index bb18122..b4058b4 100644 (file)
@@ -147,11 +147,11 @@ func NewBlockTemplate(c *protocol.Chain, txPool *protocol.TxPool, accountManager
        }
        txEntries[0] = b.Transactions[0].Tx
 
-       b.BlockHeader.BlockCommitment.TransactionsMerkleRoot, err = bc.TxMerkleRoot(txEntries)
+       b.BlockHeader.BlockCommitment.TransactionsMerkleRoot, err = types.TxMerkleRoot(txEntries)
        if err != nil {
                return nil, err
        }
 
-       b.BlockHeader.BlockCommitment.TransactionStatusHash, err = bc.TxStatusMerkleRoot(txStatus.VerifyStatus)
+       b.BlockHeader.BlockCommitment.TransactionStatusHash, err = types.TxStatusMerkleRoot(txStatus.VerifyStatus)
        return b, err
 }
index 04eadda..40e0085 100644 (file)
@@ -36,6 +36,7 @@ type Chain interface {
        GetBlockByHeight(uint64) (*types.Block, error)
        GetHeaderByHash(*bc.Hash) (*types.BlockHeader, error)
        GetHeaderByHeight(uint64) (*types.BlockHeader, error)
+       GetTransactionStatus(*bc.Hash) (*bc.TransactionStatus, error)
        InMainChain(bc.Hash) bool
        ProcessBlock(*types.Block) (bool, error)
        ValidateTx(*types.Tx) (bool, error)
@@ -165,6 +166,10 @@ func (sm *SyncManager) handleBlocksMsg(peer *peer, msg *BlocksMessage) {
        sm.blockKeeper.processBlocks(peer.ID(), blocks)
 }
 
+func (sm *SyncManager) handleFilterAddMsg(peer *peer, msg *FilterAddMessage) {
+       peer.filterAdds.Add(hex.EncodeToString(msg.Address))
+}
+
 func (sm *SyncManager) handleFilterClearMsg(peer *peer) {
        peer.filterAdds.Clear()
 }
@@ -246,7 +251,36 @@ func (sm *SyncManager) handleGetHeadersMsg(peer *peer, msg *GetHeadersMessage) {
        }
 }
 
-func (sm *SyncManager) handleGetMerkleBlockMsg(peer *peer, msg *GetMerkleBlockMessage) {}
+func (sm *SyncManager) handleGetMerkleBlockMsg(peer *peer, msg *GetMerkleBlockMessage) {
+       var err error
+       var block *types.Block 
+       if msg.Height != 0 {
+               block, err = sm.chain.GetBlockByHeight(msg.Height)
+       } else {
+               block, err = sm.chain.GetBlockByHash(msg.GetHash())
+       }
+       if err != nil {
+               log.WithField("err", err).Warning("fail on handleGetMerkleBlockMsg get block from chain")
+               return
+       }
+
+       blockHash := block.Hash()
+       txStatus, err := sm.chain.GetTransactionStatus(&blockHash)
+       if err != nil {
+               log.WithField("err", err).Warning("fail on handleGetMerkleBlockMsg get transaction status")
+               return
+       }
+
+       ok, err := peer.sendMerkleBlock(block, txStatus)
+       if err != nil {
+               log.WithField("err", err).Error("fail on handleGetMerkleBlockMsg sentMerkleBlock")
+               return
+       }
+
+       if !ok {
+               sm.peers.removePeer(peer.ID())
+       }
+}
 
 func (sm *SyncManager) handleHeadersMsg(peer *peer, msg *HeadersMessage) {
        headers, err := msg.GetHeaders()
@@ -354,6 +388,9 @@ func (sm *SyncManager) processMsg(basePeer BasePeer, msgType byte, msg Blockchai
        case *FilterLoadMessage:
                sm.handleFilterLoadMsg(peer, msg)
 
+       case *FilterAddMessage:
+               sm.handleFilterAddMsg(peer, msg)
+
        case *FilterClearMessage:
                sm.handleFilterClearMsg(peer)
 
index 3549955..67fe2ed 100644 (file)
@@ -27,7 +27,8 @@ const (
        NewTransactionByte  = byte(0x30)
        NewMineBlockByte    = byte(0x40)
        FilterLoadByte      = byte(0x50)
-       FilterClearByte     = byte(0x51)
+       FilterAddByte       = byte(0x51)
+       FilterClearByte     = byte(0x52)
        MerkleRequestByte   = byte(0x60)
        MerkleResponseByte  = byte(0x61)
 
@@ -50,6 +51,7 @@ var _ = wire.RegisterInterface(
        wire.ConcreteType{&TransactionMessage{}, NewTransactionByte},
        wire.ConcreteType{&MineBlockMessage{}, NewMineBlockByte},
        wire.ConcreteType{&FilterLoadMessage{}, FilterLoadByte},
+       wire.ConcreteType{&FilterAddMessage{}, FilterAddByte},
        wire.ConcreteType{&FilterClearMessage{}, FilterClearByte},
        wire.ConcreteType{&GetMerkleBlockMessage{}, MerkleRequestByte},
        wire.ConcreteType{&MerkleBlockMessage{}, MerkleResponseByte},
@@ -353,6 +355,11 @@ type FilterLoadMessage struct {
        Addresses [][]byte
 }
 
+// FilterAddMessage tells the receiving peer to add address to the filter.
+type FilterAddMessage struct {
+       Address []byte
+}
+
 //FilterClearMessage tells the receiving peer to remove a previously-set filter.
 type FilterClearMessage struct{}
 
@@ -362,14 +369,65 @@ type GetMerkleBlockMessage struct {
        RawHash [32]byte
 }
 
+//GetHash reutrn the hash of the request
+func (m *GetMerkleBlockMessage) GetHash() *bc.Hash {
+       hash := bc.NewHash(m.RawHash)
+       return &hash
+}
+
 //MerkleBlockMessage return the merkle block to client
 type MerkleBlockMessage struct {
-       RawBlockHeader   []byte
-       TransactionCount uint64
-       TxHashes         [][32]byte
-       TxFlags          []byte
-       RawTxDatas       [][]byte
-       StatusHashes     [][32]byte
-       StatusFlags      []byte
-       RawTxStatuses    [][]byte
+       RawBlockHeader []byte
+       TxHashes       [][32]byte
+       RawTxDatas     [][]byte
+       StatusHashes   [][32]byte
+       RawTxStatuses  [][]byte
+       Flags          []byte
+}
+
+func (msg *MerkleBlockMessage) setRawBlockHeader(bh types.BlockHeader) error {
+       rawHeader, err := bh.MarshalText()
+       if err != nil {
+               return err
+       }
+
+       msg.RawBlockHeader = rawHeader
+       return nil
+}
+
+func (msg *MerkleBlockMessage) setTxInfo(txHashes []*bc.Hash, txFlags []uint8, relatedTxs []*types.Tx) error {
+       for _, txHash := range txHashes {
+               msg.TxHashes = append(msg.TxHashes, txHash.Byte32())
+       }
+       for _, tx := range relatedTxs {
+               rawTxData, err := tx.MarshalText()
+               if err != nil {
+                       return err
+               }
+
+               msg.RawTxDatas = append(msg.RawTxDatas, rawTxData)
+       }
+       msg.Flags = txFlags
+       return nil
+}
+
+func (msg *MerkleBlockMessage) setStatusInfo(statusHashes []*bc.Hash, relatedStatuses []*bc.TxVerifyResult) error {
+       for _, statusHash := range statusHashes {
+               msg.StatusHashes = append(msg.StatusHashes, statusHash.Byte32())
+       }
+
+       for _, status := range relatedStatuses {
+               rawStatusData, err := json.Marshal(status)
+               if err != nil {
+                       return err
+               }
+
+               msg.RawTxStatuses = append(msg.RawTxStatuses, rawStatusData)
+       }
+       return nil
+}
+
+//NewMerkleBlockMessage construct merkle block message
+func NewMerkleBlockMessage() *MerkleBlockMessage {
+       return &MerkleBlockMessage{}
 }
index 240eacc..8a7a1c2 100644 (file)
@@ -1,9 +1,9 @@
 package netsync
 
 import (
+       "encoding/hex"
        "net"
        "sync"
-       "encoding/hex"
 
        log "github.com/sirupsen/logrus"
        "gopkg.in/fatih/set.v0"
@@ -90,8 +90,8 @@ func (p *peer) addBanScore(persistent, transient uint64, reason string) bool {
 func (p *peer) addFilterAddresses(addresses [][]byte) {
        p.mtx.Lock()
        defer p.mtx.Unlock()
-       
-       if (!p.filterAdds.IsEmpty()) {
+
+       if !p.filterAdds.IsEmpty() {
                p.filterAdds.Clear()
        }
        for _, address := range addresses {
@@ -124,8 +124,20 @@ func (p *peer) getPeerInfo() *PeerInfo {
        }
 }
 
+func (p *peer) getRelatedTxAndStatus(txs []*types.Tx, txStatuses *bc.TransactionStatus) ([]*types.Tx, []*bc.TxVerifyResult) {
+       var relatedTxs []*types.Tx
+       var relatedStatuses []*bc.TxVerifyResult
+       for i, tx := range txs {
+               if p.isRelatedTx(tx) {
+                       relatedTxs = append(relatedTxs, tx)
+                       relatedStatuses = append(relatedStatuses, txStatuses.VerifyStatus[i])
+               }
+       }
+       return relatedTxs, relatedStatuses
+}
+
 func (p *peer) isRelatedTx(tx *types.Tx) bool {
-       for _, input := range(tx.Inputs) {
+       for _, input := range tx.Inputs {
                switch inp := input.TypedInput.(type) {
                case *types.SpendInput:
                        if p.filterAdds.Has(hex.EncodeToString(inp.ControlProgram)) {
@@ -133,7 +145,7 @@ func (p *peer) isRelatedTx(tx *types.Tx) bool {
                        }
                }
        }
-       for _, output := range(tx.Outputs) {
+       for _, output := range tx.Outputs {
                if p.filterAdds.Has(hex.EncodeToString(output.ControlProgram)) {
                        return true
                }
@@ -206,6 +218,28 @@ func (p *peer) sendHeaders(headers []*types.BlockHeader) (bool, error) {
        return ok, nil
 }
 
+func (p *peer) sendMerkleBlock(block *types.Block, txStatuses *bc.TransactionStatus) (bool, error) {
+       msg := NewMerkleBlockMessage()
+       if err := msg.setRawBlockHeader(block.BlockHeader); err != nil {
+               return false, err
+       }
+
+       relatedTxs, relatedStatuses := p.getRelatedTxAndStatus(block.Transactions, txStatuses)
+
+       txHashes, txFlags := types.GetTxMerkleTreeProof(block.Transactions, relatedTxs)
+       if err := msg.setTxInfo(txHashes, txFlags, relatedTxs); err != nil {
+               return false, nil
+       }
+       
+       statusHashes := types.GetStatusMerkleTreeProof(txStatuses.VerifyStatus, txFlags)
+       if err := msg.setStatusInfo(statusHashes, relatedStatuses); err != nil {
+               return false, nil
+       }
+
+       ok := p.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg})
+       return ok, nil
+}
+
 func (p *peer) sendTransactions(txs []*types.Tx) (bool, error) {
        for _, tx := range txs {
                if p.isSPVNode() && !p.isRelatedTx(tx) {
@@ -301,6 +335,9 @@ func (ps *peerSet) broadcastMinedBlock(block *types.Block) error {
        hash := block.Hash()
        peers := ps.peersWithoutBlock(&hash)
        for _, peer := range peers {
+               if peer.isSPVNode() {
+                       continue
+               }
                if ok := peer.TrySend(BlockchainChannel, struct{ BlockchainMessage }{msg}); !ok {
                        ps.removePeer(peer.ID())
                        continue
diff --git a/protocol/bc/merkle.go b/protocol/bc/merkle.go
deleted file mode 100644 (file)
index 054d443..0000000
+++ /dev/null
@@ -1,86 +0,0 @@
-package bc
-
-import (
-       "io"
-       "math"
-
-       "github.com/bytom/crypto/sha3pool"
-)
-
-var (
-       leafPrefix     = []byte{0x00}
-       interiorPrefix = []byte{0x01}
-)
-
-type merkleNode interface {
-       WriteTo(io.Writer) (int64, error)
-}
-
-func merkleRoot(nodes []merkleNode) (root Hash, err error) {
-       switch {
-       case len(nodes) == 0:
-               return EmptyStringHash, nil
-
-       case len(nodes) == 1:
-               h := sha3pool.Get256()
-               defer sha3pool.Put256(h)
-
-               h.Write(leafPrefix)
-               nodes[0].WriteTo(h)
-               root.ReadFrom(h)
-               return root, nil
-
-       default:
-               k := prevPowerOfTwo(len(nodes))
-               left, err := merkleRoot(nodes[:k])
-               if err != nil {
-                       return root, err
-               }
-
-               right, err := merkleRoot(nodes[k:])
-               if err != nil {
-                       return root, err
-               }
-
-               h := sha3pool.Get256()
-               defer sha3pool.Put256(h)
-               h.Write(interiorPrefix)
-               left.WriteTo(h)
-               right.WriteTo(h)
-               root.ReadFrom(h)
-               return root, nil
-       }
-}
-
-// TxStatusMerkleRoot creates a merkle tree from a slice of TxVerifyResult
-func TxStatusMerkleRoot(tvrs []*TxVerifyResult) (root Hash, err error) {
-       nodes := []merkleNode{}
-       for _, tvr := range tvrs {
-               nodes = append(nodes, tvr)
-       }
-       return merkleRoot(nodes)
-}
-
-// TxMerkleRoot creates a merkle tree from a slice of transactions
-// and returns the root hash of the tree.
-func TxMerkleRoot(transactions []*Tx) (root Hash, err error) {
-       nodes := []merkleNode{}
-       for _, tx := range transactions {
-               nodes = append(nodes, tx.ID)
-       }
-       return merkleRoot(nodes)
-}
-
-// prevPowerOfTwo returns the largest power of two that is smaller than a given number.
-// In other words, for some input n, the prevPowerOfTwo k is a power of two such that
-// k < n <= 2k. This is a helper function used during the calculation of a merkle tree.
-func prevPowerOfTwo(n int) int {
-       // If the number is a power of two, divide it by 2 and return.
-       if n&(n-1) == 0 {
-               return n / 2
-       }
-
-       // Otherwise, find the previous PoT.
-       exponent := uint(math.Log2(float64(n)))
-       return 1 << exponent // 2^exponent
-}
diff --git a/protocol/bc/merkle_test.go b/protocol/bc/merkle_test.go
deleted file mode 100644 (file)
index 1452b02..0000000
+++ /dev/null
@@ -1,144 +0,0 @@
-package bc_test
-
-import (
-       "testing"
-       "time"
-
-       . "github.com/bytom/protocol/bc"
-       "github.com/bytom/protocol/bc/types"
-       "github.com/bytom/protocol/vm"
-       "github.com/bytom/testutil"
-)
-
-func TestMerkleRoot(t *testing.T) {
-       cases := []struct {
-               witnesses [][][]byte
-               want      Hash
-       }{{
-               witnesses: [][][]byte{
-                       {
-                               {1},
-                               []byte("00000"),
-                       },
-               },
-               want: testutil.MustDecodeHash("fe34dbd5da0ce3656f423fd7aad7fc7e879353174d33a6446c2ed0e3f3512101"),
-       }, {
-               witnesses: [][][]byte{
-                       {
-                               {1},
-                               []byte("000000"),
-                       },
-                       {
-                               {1},
-                               []byte("111111"),
-                       },
-               },
-               want: testutil.MustDecodeHash("0e4b4c1af18b8f59997804d69f8f66879ad5e30027346ee003ff7c7a512e5554"),
-       }, {
-               witnesses: [][][]byte{
-                       {
-                               {1},
-                               []byte("000000"),
-                       },
-                       {
-                               {2},
-                               []byte("111111"),
-                               []byte("222222"),
-                       },
-               },
-               want: testutil.MustDecodeHash("0e4b4c1af18b8f59997804d69f8f66879ad5e30027346ee003ff7c7a512e5554"),
-       }}
-
-       for _, c := range cases {
-               var txs []*Tx
-               for _, wit := range c.witnesses {
-                       txs = append(txs, types.NewTx(types.TxData{
-                               Inputs: []*types.TxInput{
-                                       &types.TxInput{
-                                               AssetVersion: 1,
-                                               TypedInput: &types.SpendInput{
-                                                       Arguments: wit,
-                                                       SpendCommitment: types.SpendCommitment{
-                                                               AssetAmount: AssetAmount{
-                                                                       AssetId: &AssetID{V0: 0},
-                                                               },
-                                                       },
-                                               },
-                                       },
-                               },
-                       }).Tx)
-               }
-               got, err := TxMerkleRoot(txs)
-               if err != nil {
-                       t.Fatalf("unexpected error %s", err)
-               }
-               if got != c.want {
-                       t.Log("witnesses", c.witnesses)
-                       t.Errorf("got merkle root = %x want %x", got.Bytes(), c.want.Bytes())
-               }
-       }
-}
-
-func TestDuplicateLeaves(t *testing.T) {
-       trueProg := []byte{byte(vm.OP_TRUE)}
-       assetID := ComputeAssetID(trueProg, 1, &EmptyStringHash)
-       txs := make([]*Tx, 6)
-       for i := uint64(0); i < 6; i++ {
-               now := []byte(time.Now().String())
-               txs[i] = types.NewTx(types.TxData{
-                       Version: 1,
-                       Inputs:  []*types.TxInput{types.NewIssuanceInput(now, i, trueProg, nil, nil)},
-                       Outputs: []*types.TxOutput{types.NewTxOutput(assetID, i, trueProg)},
-               }).Tx
-       }
-
-       // first, get the root of an unbalanced tree
-       txns := []*Tx{txs[5], txs[4], txs[3], txs[2], txs[1], txs[0]}
-       root1, err := TxMerkleRoot(txns)
-       if err != nil {
-               t.Fatalf("unexpected error %s", err)
-       }
-
-       // now, get the root of a balanced tree that repeats leaves 0 and 1
-       txns = []*Tx{txs[5], txs[4], txs[3], txs[2], txs[1], txs[0], txs[1], txs[0]}
-       root2, err := TxMerkleRoot(txns)
-       if err != nil {
-               t.Fatalf("unexpected error %s", err)
-       }
-
-       if root1 == root2 {
-               t.Error("forged merkle tree by duplicating some leaves")
-       }
-}
-
-func TestAllDuplicateLeaves(t *testing.T) {
-       trueProg := []byte{byte(vm.OP_TRUE)}
-       assetID := ComputeAssetID(trueProg, 1, &EmptyStringHash)
-       now := []byte(time.Now().String())
-       issuanceInp := types.NewIssuanceInput(now, 1, trueProg, nil, nil)
-
-       tx := types.NewTx(types.TxData{
-               Version: 1,
-               Inputs:  []*types.TxInput{issuanceInp},
-               Outputs: []*types.TxOutput{types.NewTxOutput(assetID, 1, trueProg)},
-       }).Tx
-       tx1, tx2, tx3, tx4, tx5, tx6 := tx, tx, tx, tx, tx, tx
-
-       // first, get the root of an unbalanced tree
-       txs := []*Tx{tx6, tx5, tx4, tx3, tx2, tx1}
-       root1, err := TxMerkleRoot(txs)
-       if err != nil {
-               t.Fatalf("unexpected error %s", err)
-       }
-
-       // now, get the root of a balanced tree that repeats leaves 5 and 6
-       txs = []*Tx{tx6, tx5, tx6, tx5, tx4, tx3, tx2, tx1}
-       root2, err := TxMerkleRoot(txs)
-       if err != nil {
-               t.Fatalf("unexpected error %s", err)
-       }
-
-       if root1 == root2 {
-               t.Error("forged merkle tree with all duplicate leaves")
-       }
-}
diff --git a/protocol/bc/types/merkle.go b/protocol/bc/types/merkle.go
new file mode 100644 (file)
index 0000000..dca07d4
--- /dev/null
@@ -0,0 +1,334 @@
+package types
+
+import (
+       "container/list"
+       "io"
+       "math"
+
+       "gopkg.in/fatih/set.v0"
+
+       "github.com/bytom/crypto/sha3pool"
+       "github.com/bytom/protocol/bc"
+)
+
+// merkleFlag represent the type of merkle tree node, it's used to generate the structure of merkle tree
+// Bitcoin has only two flags, which zero means the hash of assist node. And one means the hash of the related
+// transaction node or it's parents, which distinguish them according to the height of the tree. But in the bytom,
+// the height of transaction node is not fixed, so we need three flags to distinguish these nodes.
+const (
+       // FlagAssist represent assist node
+       FlagAssist = iota
+       // FlagTxParent represent the parent of transaction of node
+       FlagTxParent
+       // FlagTxLeaf represent transaction of node
+       FlagTxLeaf
+)
+
+var (
+       leafPrefix     = []byte{0x00}
+       interiorPrefix = []byte{0x01}
+)
+
+type merkleNode interface {
+       WriteTo(io.Writer) (int64, error)
+}
+
+func merkleRoot(nodes []merkleNode) (root bc.Hash, err error) {
+       switch {
+       case len(nodes) == 0:
+               return bc.EmptyStringHash, nil
+
+       case len(nodes) == 1:
+               root = leafMerkleHash(nodes[0])
+               return root, nil
+
+       default:
+               k := prevPowerOfTwo(len(nodes))
+               left, err := merkleRoot(nodes[:k])
+               if err != nil {
+                       return root, err
+               }
+
+               right, err := merkleRoot(nodes[k:])
+               if err != nil {
+                       return root, err
+               }
+
+               root = interiorMerkleHash(&left, &right)
+               return root, nil
+       }
+}
+
+func interiorMerkleHash(left merkleNode, right merkleNode) (hash bc.Hash) {
+       h := sha3pool.Get256()
+       defer sha3pool.Put256(h)
+       h.Write(interiorPrefix)
+       left.WriteTo(h)
+       right.WriteTo(h)
+       hash.ReadFrom(h)
+       return hash
+}
+
+func leafMerkleHash(node merkleNode) (hash bc.Hash) {
+       h := sha3pool.Get256()
+       defer sha3pool.Put256(h)
+       h.Write(leafPrefix)
+       node.WriteTo(h)
+       hash.ReadFrom(h)
+       return hash
+}
+
+type merkleTreeNode struct {
+       hash  bc.Hash
+       left  *merkleTreeNode
+       right *merkleTreeNode
+}
+
+// buildMerkleTree construct a merkle tree based on the provide node data
+func buildMerkleTree(rawDatas []merkleNode) *merkleTreeNode {
+       switch len(rawDatas) {
+       case 0:
+               return nil
+       case 1:
+               rawData := rawDatas[0]
+               merkleHash := leafMerkleHash(rawData)
+               node := newMerkleTreeNode(merkleHash, nil, nil)
+               return node
+       default:
+               k := prevPowerOfTwo(len(rawDatas))
+               left := buildMerkleTree(rawDatas[:k])
+               right := buildMerkleTree(rawDatas[k:])
+               merkleHash := interiorMerkleHash(&left.hash, &right.hash)
+               node := newMerkleTreeNode(merkleHash, left, right)
+               return node
+       }
+}
+
+func (node *merkleTreeNode) getMerkleTreeProof(merkleHashSet *set.Set) ([]*bc.Hash, []uint8) {
+       var hashes []*bc.Hash
+       var flags []uint8
+
+       if node.left == nil && node.right == nil {
+               if key := node.hash.String(); merkleHashSet.Has(key) {
+                       hashes = append(hashes, &node.hash)
+                       flags = append(flags, FlagTxLeaf)
+                       return hashes, flags
+               }
+               return hashes, flags
+       }
+       var leftHashes, rightHashes []*bc.Hash
+       var leftFlags, rightFlags []uint8
+       if node.left != nil {
+               leftHashes, leftFlags = node.left.getMerkleTreeProof(merkleHashSet)
+       }
+       if node.right != nil {
+               rightHashes, rightFlags = node.right.getMerkleTreeProof(merkleHashSet)
+       }
+       leftFind, rightFind := len(leftHashes) > 0, len(rightHashes) > 0
+
+       if leftFind || rightFind {
+               flags = append(flags, FlagTxParent)
+       } else {
+               return hashes, flags
+       }
+
+       if leftFind {
+               hashes = append(hashes, leftHashes...)
+               flags = append(flags, leftFlags...)
+       } else {
+               hashes = append(hashes, &node.left.hash)
+               flags = append(flags, FlagAssist)
+       }
+
+       if rightFind {
+               hashes = append(hashes, rightHashes...)
+               flags = append(flags, rightFlags...)
+       } else {
+               hashes = append(hashes, &node.right.hash)
+               flags = append(flags, FlagAssist)
+       }
+       return hashes, flags
+}
+
+func getMerkleTreeProof(rawDatas []merkleNode, relatedRawDatas []merkleNode) ([]*bc.Hash, []uint8) {
+       merkleTree := buildMerkleTree(rawDatas)
+       if merkleTree == nil {
+               return []*bc.Hash{}, []uint8{}
+       }
+       merkleHashSet := set.New()
+       for _, data := range relatedRawDatas {
+               merkleHash := leafMerkleHash(data)
+               merkleHashSet.Add(merkleHash.String())
+       }
+       return merkleTree.getMerkleTreeProof(merkleHashSet)
+}
+
+func (node *merkleTreeNode) getMerkleTreeProofByFlags(flagList *list.List) []*bc.Hash {
+       var hashes []*bc.Hash
+
+       if flagList.Len() == 0 {
+               return hashes
+       }
+       flagEle := flagList.Front()
+       flag := flagEle.Value.(uint8)
+       flagList.Remove(flagEle)
+
+       if flag == FlagTxLeaf || flag == FlagAssist {
+               hashes = append(hashes, &node.hash)
+               return hashes
+       }
+       if node.left != nil {
+               leftHashes := node.left.getMerkleTreeProofByFlags(flagList)
+               hashes = append(hashes, leftHashes...)
+       }
+       if node.right != nil {
+               rightHashes := node.right.getMerkleTreeProofByFlags(flagList)
+               hashes = append(hashes, rightHashes...)
+       }
+       return hashes
+}
+
+func getMerkleTreeProofByFlags(rawDatas []merkleNode, flagList *list.List) []*bc.Hash {
+       tree := buildMerkleTree(rawDatas)
+       return tree.getMerkleTreeProofByFlags(flagList)
+}
+
+// GetTxMerkleTreeProof return a proof of merkle tree, which used to proof the transaction does
+// exist in the merkle tree
+func GetTxMerkleTreeProof(txs []*Tx, relatedTxs []*Tx) ([]*bc.Hash, []uint8) {
+       var rawDatas []merkleNode
+       var relatedRawDatas []merkleNode
+       for _, tx := range txs {
+               rawDatas = append(rawDatas, &tx.ID)
+       }
+       for _, relatedTx := range relatedTxs {
+               relatedRawDatas = append(relatedRawDatas, &relatedTx.ID)
+       }
+       return getMerkleTreeProof(rawDatas, relatedRawDatas)
+}
+
+// GetStatusMerkleTreeProof return a proof of merkle tree, which used to proof the status of transaction is valid
+func GetStatusMerkleTreeProof(statuses []*bc.TxVerifyResult, flags []uint8) []*bc.Hash {
+       var rawDatas []merkleNode
+       for _, status := range statuses {
+               rawDatas = append(rawDatas, status)
+       }
+       flagList := list.New()
+       for _, flag := range flags {
+               flagList.PushBack(flag)
+       }
+       return getMerkleTreeProofByFlags(rawDatas, flagList)
+}
+
+// getMerkleRootByProof caculate the merkle root hash according to the proof
+func getMerkleRootByProof(hashList *list.List, flagList *list.List, merkleHashes *list.List) bc.Hash {
+       if flagList.Len() == 0 {
+               return bc.EmptyStringHash
+       }
+       flagEle := flagList.Front()
+       flag := flagEle.Value.(uint8)
+       flagList.Remove(flagEle)
+       if flag == FlagAssist {
+               hash := hashList.Front()
+               hashList.Remove(hash)
+               return hash.Value.(bc.Hash)
+       }
+       if flag == FlagTxLeaf {
+               if hashList.Len() == 0 || merkleHashes.Len() == 0 {
+                       return bc.EmptyStringHash
+               }
+               hashEle := hashList.Front()
+               hash := hashEle.Value.(bc.Hash)
+               relatedHashEle := merkleHashes.Front()
+               relatedHash := relatedHashEle.Value.(bc.Hash)
+               if hash == relatedHash {
+                       hashList.Remove(hashEle)
+                       merkleHashes.Remove(relatedHashEle)
+                       return hash
+               }
+               return bc.EmptyStringHash
+       }
+       leftHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
+       rightHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
+       hash := interiorMerkleHash(&leftHash, &rightHash)
+       return hash
+}
+
+func newMerkleTreeNode(merkleHash bc.Hash, left *merkleTreeNode, right *merkleTreeNode) *merkleTreeNode {
+       return &merkleTreeNode{
+               hash:  merkleHash,
+               left:  left,
+               right: right,
+       }
+}
+
+// ValidateMerkleTreeProof caculate the merkle root according to the hash of node and the flags
+// only if the merkle root by caculated equals to the specify merkle root, and the merkle tree
+// contains all of the related raw datas, the validate result will be true.
+func validateMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedNodes []merkleNode, merkleRoot bc.Hash) bool {
+       merkleHashes := list.New()
+       for _, relatedNode := range relatedNodes {
+               merkleHashes.PushBack(leafMerkleHash(relatedNode))
+       }
+       hashList := list.New()
+       for _, hash := range hashes {
+               hashList.PushBack(*hash)
+       }
+       flagList := list.New()
+       for _, flag := range flags {
+               flagList.PushBack(flag)
+       }
+       root := getMerkleRootByProof(hashList, flagList, merkleHashes)
+       return root == merkleRoot && merkleHashes.Len() == 0
+}
+
+// ValidateTxMerkleTreeProof validate the merkle tree of transactions
+func ValidateTxMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedHashes []*bc.Hash, merkleRoot bc.Hash) bool {
+       var relatedNodes []merkleNode
+       for _, hash := range relatedHashes {
+               relatedNodes = append(relatedNodes, hash)
+       }
+       return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
+}
+
+// ValidateStatusMerkleTreeProof validate the merkle tree of transaction status
+func ValidateStatusMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedStatus []*bc.TxVerifyResult, merkleRoot bc.Hash) bool {
+       var relatedNodes []merkleNode
+       for _, result := range relatedStatus {
+               relatedNodes = append(relatedNodes, result)
+       }
+       return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
+}
+
+// TxStatusMerkleRoot creates a merkle tree from a slice of bc.TxVerifyResult
+func TxStatusMerkleRoot(tvrs []*bc.TxVerifyResult) (root bc.Hash, err error) {
+       nodes := []merkleNode{}
+       for _, tvr := range tvrs {
+               nodes = append(nodes, tvr)
+       }
+       return merkleRoot(nodes)
+}
+
+// TxMerkleRoot creates a merkle tree from a slice of transactions
+// and returns the root hash of the tree.
+func TxMerkleRoot(transactions []*bc.Tx) (root bc.Hash, err error) {
+       nodes := []merkleNode{}
+       for _, tx := range transactions {
+               nodes = append(nodes, &tx.ID)
+       }
+       return merkleRoot(nodes)
+}
+
+// prevPowerOfTwo returns the largest power of two that is smaller than a given number.
+// In other words, for some input n, the prevPowerOfTwo k is a power of two such that
+// k < n <= 2k. This is a helper function used during the calculation of a merkle tree.
+func prevPowerOfTwo(n int) int {
+       // If the number is a power of two, divide it by 2 and return.
+       if n&(n-1) == 0 {
+               return n / 2
+       }
+
+       // Otherwise, find the previous PoT.
+       exponent := uint(math.Log2(float64(n)))
+       return 1 << exponent // 2^exponent
+}
diff --git a/protocol/bc/types/merkle_test.go b/protocol/bc/types/merkle_test.go
new file mode 100644 (file)
index 0000000..77022e0
--- /dev/null
@@ -0,0 +1,208 @@
+package types
+
+import (
+       "math/rand"
+       "reflect"
+       "testing"
+       "time"
+
+       "github.com/bytom/protocol/bc"
+       "github.com/bytom/protocol/vm"
+       "github.com/bytom/testutil"
+)
+
+func TestMerkleRoot(t *testing.T) {
+       cases := []struct {
+               witnesses [][][]byte
+               want      bc.Hash
+       }{{
+               witnesses: [][][]byte{
+                       {
+                               {1},
+                               []byte("00000"),
+                       },
+               },
+               want: testutil.MustDecodeHash("fe34dbd5da0ce3656f423fd7aad7fc7e879353174d33a6446c2ed0e3f3512101"),
+       }, {
+               witnesses: [][][]byte{
+                       {
+                               {1},
+                               []byte("000000"),
+                       },
+                       {
+                               {1},
+                               []byte("111111"),
+                       },
+               },
+               want: testutil.MustDecodeHash("0e4b4c1af18b8f59997804d69f8f66879ad5e30027346ee003ff7c7a512e5554"),
+       }, {
+               witnesses: [][][]byte{
+                       {
+                               {1},
+                               []byte("000000"),
+                       },
+                       {
+                               {2},
+                               []byte("111111"),
+                               []byte("222222"),
+                       },
+               },
+               want: testutil.MustDecodeHash("0e4b4c1af18b8f59997804d69f8f66879ad5e30027346ee003ff7c7a512e5554"),
+       }}
+
+       for _, c := range cases {
+               var txs []*bc.Tx
+               for _, wit := range c.witnesses {
+                       txs = append(txs, NewTx(TxData{
+                               Inputs: []*TxInput{
+                                       &TxInput{
+                                               AssetVersion: 1,
+                                               TypedInput: &SpendInput{
+                                                       Arguments: wit,
+                                                       SpendCommitment: SpendCommitment{
+                                                               AssetAmount: bc.AssetAmount{
+                                                                       AssetId: &bc.AssetID{V0: 0},
+                                                               },
+                                                       },
+                                               },
+                                       },
+                               },
+                       }).Tx)
+               }
+               got, err := TxMerkleRoot(txs)
+               if err != nil {
+                       t.Fatalf("unexpected error %s", err)
+               }
+               if got != c.want {
+                       t.Log("witnesses", c.witnesses)
+                       t.Errorf("got merkle root = %x want %x", got.Bytes(), c.want.Bytes())
+               }
+       }
+}
+
+func TestDuplicateLeaves(t *testing.T) {
+       trueProg := []byte{byte(vm.OP_TRUE)}
+       assetID := bc.ComputeAssetID(trueProg, 1, &bc.EmptyStringHash)
+       txs := make([]*bc.Tx, 6)
+       for i := uint64(0); i < 6; i++ {
+               now := []byte(time.Now().String())
+               txs[i] = NewTx(TxData{
+                       Version: 1,
+                       Inputs:  []*TxInput{NewIssuanceInput(now, i, trueProg, nil, nil)},
+                       Outputs: []*TxOutput{NewTxOutput(assetID, i, trueProg)},
+               }).Tx
+       }
+
+       // first, get the root of an unbalanced tree
+       txns := []*bc.Tx{txs[5], txs[4], txs[3], txs[2], txs[1], txs[0]}
+       root1, err := TxMerkleRoot(txns)
+       if err != nil {
+               t.Fatalf("unexpected error %s", err)
+       }
+
+       // now, get the root of a balanced tree that repeats leaves 0 and 1
+       txns = []*bc.Tx{txs[5], txs[4], txs[3], txs[2], txs[1], txs[0], txs[1], txs[0]}
+       root2, err := TxMerkleRoot(txns)
+       if err != nil {
+               t.Fatalf("unexpected error %s", err)
+       }
+
+       if root1 == root2 {
+               t.Error("forged merkle tree by duplicating some leaves")
+       }
+}
+
+func TestAllDuplicateLeaves(t *testing.T) {
+       trueProg := []byte{byte(vm.OP_TRUE)}
+       assetID := bc.ComputeAssetID(trueProg, 1, &bc.EmptyStringHash)
+       now := []byte(time.Now().String())
+       issuanceInp := NewIssuanceInput(now, 1, trueProg, nil, nil)
+
+       tx := NewTx(TxData{
+               Version: 1,
+               Inputs:  []*TxInput{issuanceInp},
+               Outputs: []*TxOutput{NewTxOutput(assetID, 1, trueProg)},
+       }).Tx
+       tx1, tx2, tx3, tx4, tx5, tx6 := tx, tx, tx, tx, tx, tx
+
+       // first, get the root of an unbalanced tree
+       txs := []*bc.Tx{tx6, tx5, tx4, tx3, tx2, tx1}
+       root1, err := TxMerkleRoot(txs)
+       if err != nil {
+               t.Fatalf("unexpected error %s", err)
+       }
+
+       // now, get the root of a balanced tree that repeats leaves 5 and 6
+       txs = []*bc.Tx{tx6, tx5, tx6, tx5, tx4, tx3, tx2, tx1}
+       root2, err := TxMerkleRoot(txs)
+       if err != nil {
+               t.Fatalf("unexpected error %s", err)
+       }
+
+       if root1 == root2 {
+               t.Error("forged merkle tree with all duplicate leaves")
+       }
+}
+
+func TestTxMerkleProof(t *testing.T) {
+       var txs []*Tx
+       var bcTxs []*bc.Tx
+       trueProg := []byte{byte(vm.OP_TRUE)}
+       assetID := bc.ComputeAssetID(trueProg, 1, &bc.EmptyStringHash)
+       for i := 0; i < 10; i++ {
+               now := []byte(time.Now().String())
+               issuanceInp := NewIssuanceInput(now, 1, trueProg, nil, nil)
+               tx := NewTx(TxData{
+                       Version: 1,
+                       Inputs:  []*TxInput{issuanceInp},
+                       Outputs: []*TxOutput{NewTxOutput(assetID, 1, trueProg)},
+               })
+               txs = append(txs, tx)
+               bcTxs = append(bcTxs, tx.Tx)
+       }
+       root, err := TxMerkleRoot(bcTxs)
+       if err != nil {
+               t.Fatalf("unexpected error %s", err)
+       }
+
+       relatedTx := []*Tx{txs[0], txs[3], txs[7], txs[8]}
+       proofHashes, flags := GetTxMerkleTreeProof(txs, relatedTx)
+       if len(proofHashes) <= 0 {
+               t.Error("Can not find any tx id in the merkle tree")
+       }
+       expectFlags := []uint8{1, 1, 1, 1, 2, 0, 1, 0, 2, 1, 0, 1, 0, 2, 1, 2, 0}
+       if !reflect.DeepEqual(flags, expectFlags) {
+               t.Error("The flags is not equals expect flags", flags)
+       }
+       if len(proofHashes) != 9 {
+               t.Error("The length proof hashes is not equals expect length")
+       }
+       ids := []*bc.Hash{&txs[0].ID, &txs[3].ID, &txs[7].ID, &txs[8].ID}
+       if !ValidateTxMerkleTreeProof(proofHashes, flags, ids, root) {
+               t.Error("Merkle tree validate fail")
+       }
+}
+
+func TestStatusMerkleProof(t *testing.T) {
+       var statuses []*bc.TxVerifyResult
+       for i := 0; i < 10; i++ {
+               status := &bc.TxVerifyResult{}
+               fail := rand.Intn(2)
+               if fail == 0 {
+                       status.StatusFail = true
+               } else {
+                       status.StatusFail = false
+               }
+               statuses = append(statuses, status)
+       }
+       relatedStatuses := []*bc.TxVerifyResult{statuses[0], statuses[3], statuses[7], statuses[8]}
+       flags := []uint8{1, 1, 1, 1, 2, 0, 1, 0, 2, 1, 0, 1, 0, 2, 1, 2, 0}
+       hashes := GetStatusMerkleTreeProof(statuses, flags)
+       if len(hashes) != 9 {
+               t.Error("The length proof hashes is not equals expect length")
+       }
+       root, _ := TxStatusMerkleRoot(statuses)
+       if !ValidateStatusMerkleTreeProof(hashes, flags, relatedStatuses, root) {
+               t.Error("Merkle tree validate fail")
+       }
+}
index 6df1e1e..d63dedf 100644 (file)
@@ -7,6 +7,7 @@ import (
        "github.com/bytom/consensus/difficulty"
        "github.com/bytom/errors"
        "github.com/bytom/protocol/bc"
+       "github.com/bytom/protocol/bc/types"
        "github.com/bytom/protocol/state"
 )
 
@@ -99,7 +100,7 @@ func ValidateBlock(b *bc.Block, parent *state.BlockNode) error {
                return err
        }
 
-       txMerkleRoot, err := bc.TxMerkleRoot(b.Transactions)
+       txMerkleRoot, err := types.TxMerkleRoot(b.Transactions)
        if err != nil {
                return errors.Wrap(err, "computing transaction id merkle root")
        }
@@ -107,7 +108,7 @@ func ValidateBlock(b *bc.Block, parent *state.BlockNode) error {
                return errors.WithDetailf(errMismatchedMerkleRoot, "transaction id merkle root")
        }
 
-       txStatusHash, err := bc.TxStatusMerkleRoot(b.TransactionStatus.VerifyStatus)
+       txStatusHash, err := types.TxStatusMerkleRoot(b.TransactionStatus.VerifyStatus)
        if err != nil {
                return errors.Wrap(err, "computing transaction status merkle root")
        }
index 94088f1..8795ef7 100644 (file)
@@ -61,12 +61,12 @@ func NewBlock(chain *protocol.Chain, txs []*types.Tx, controlProgram []byte) (*t
 
        b.Transactions[0] = coinbaseTx
        txEntries[0] = coinbaseTx.Tx
-       b.TransactionsMerkleRoot, err = bc.TxMerkleRoot(txEntries)
+       b.TransactionsMerkleRoot, err = types.TxMerkleRoot(txEntries)
        if err != nil {
                return nil, err
        }
 
-       b.TransactionStatusHash, err = bc.TxStatusMerkleRoot(txStatus.VerifyStatus)
+       b.TransactionStatusHash, err = types.TxStatusMerkleRoot(txStatus.VerifyStatus)
        return b, err
 }
 
@@ -78,7 +78,7 @@ func ReplaceCoinbase(block *types.Block, coinbaseTx *types.Tx) (err error) {
                txEntires = append(txEntires, block.Transactions[i].Tx)
        }
 
-       block.TransactionsMerkleRoot, err = bc.TxMerkleRoot(txEntires)
+       block.TransactionsMerkleRoot, err = types.TxMerkleRoot(txEntires)
        return
 }
 
index f9eed1f..d0c6edc 100644 (file)
@@ -50,7 +50,7 @@ func (ctx *chainTestContext) validateStatus(block *types.Block) error {
                return err
        }
 
-       txStatusMerkleRoot, err := bc.TxStatusMerkleRoot(txStatus.VerifyStatus)
+       txStatusMerkleRoot, err := types.TxStatusMerkleRoot(txStatus.VerifyStatus)
        if err != nil {
                return err
        }
index a3c1c64..a25dd41 100644 (file)
@@ -67,6 +67,10 @@ func (c *Chain) GetHeaderByHeight(height uint64) (*types.BlockHeader, error) {
        return &block.BlockHeader, nil
 }
 
+func (c *Chain) GetTransactionStatus(hash *bc.Hash) (*bc.TransactionStatus, error) {
+       return nil, nil
+}
+
 func (c *Chain) InMainChain(hash bc.Hash) bool {
        block, ok := c.blockMap[hash]
        if !ok {