OSDN Git Service

Add merkle proof check
authorYahtoo Ma <yahtoo.ma@gmail.com>
Thu, 23 Aug 2018 08:27:44 +0000 (16:27 +0800)
committerYahtoo Ma <yahtoo.ma@gmail.com>
Fri, 24 Aug 2018 04:16:08 +0000 (12:16 +0800)
netsync/block_keeper.go
netsync/message.go
protocol/bc/types/merkle.go [new file with mode: 0644]
protocol/bc/types/merkleBlock.go
protocol/bc/types/merkle_test.go [new file with mode: 0644]
protocol/block.go

index 7433ff5..7e611a6 100644 (file)
@@ -166,7 +166,12 @@ func (bk *blockKeeper) fastBlockSync(checkPoint *consensus.Checkpoint) error {
                if err != nil {
                        return err
                }
-               bk.VerifyMerkleBlock(fastHeader.Value.(*types.BlockHeader), merkleBlock)
+               if err := bk.VerifyBlockHeader(fastHeader.Value.(*types.BlockHeader), merkleBlock); err != nil {
+                       return err
+               }
+               if err := bk.VerifyMerkleBlock(merkleBlock); err != nil {
+                       return err
+               }
                blockHash := merkleBlock.Hash()
                if blockHash != fastHeader.Value.(*types.BlockHeader).Hash() {
                        return errPeerMisbehave
@@ -283,12 +288,14 @@ func (bk *blockKeeper) processHeaders(peerID string, headers []*types.BlockHeade
 func (bk *blockKeeper) regularBlockSync(wantHeight uint64) error {
        i := bk.chain.BestBlockHeight() + 1
        for i <= wantHeight {
-               block, err := bk.requireMerkleBlock(i, nil)
+               merkleBlock, err := bk.requireMerkleBlock(i, nil)
                if err != nil {
                        return err
                }
-
-               isOrphan, err := bk.ProcessMerkleBlock(block)
+               if err := bk.VerifyMerkleBlock(merkleBlock); err != nil {
+                       return err
+               }
+               isOrphan, err := bk.ProcessMerkleBlock(merkleBlock)
                if err != nil {
                        return err
                }
@@ -347,12 +354,35 @@ func (bk *blockKeeper) requireMerkleBlock(height uint64, hash *bc.Hash) (*types.
        }
 }
 
-func (bk *blockKeeper) VerifyMerkleBlock(header *types.BlockHeader, merkleBlock *types.MerkleBlock) bool {
+func (bk *blockKeeper) VerifyBlockHeader(header *types.BlockHeader, merkleBlock *types.MerkleBlock) error {
        if header.Hash() != merkleBlock.BlockHeader.Hash() {
-               return false
+               return errors.New("BlockHeader mismatch")
        }
+       return nil
+}
 
-       return true
+func (bk *blockKeeper) VerifyMerkleBlock(merkleBlock *types.MerkleBlock) error {
+       if len(merkleBlock.Transactions) == 0 {
+               return nil
+       }
+
+       var proofHashes []*bc.Hash
+       for _, v := range merkleBlock.TxHashes {
+               hash := bc.NewHash(v)
+               proofHashes = append(proofHashes, &hash)
+       }
+
+       flags := merkleBlock.Flags
+
+       var relatedTxHashes []*bc.Hash
+       for _, v := range merkleBlock.Transactions {
+               relatedTxHashes = append(relatedTxHashes, &v.ID)
+       }
+       if !types.ValidateTxMerkleTreeProof(proofHashes, flags, relatedTxHashes, merkleBlock.BlockHeader.TransactionsMerkleRoot) {
+               return errors.New("merkle proof check error")
+       }
+
+       return nil
 }
 
 func (bk *blockKeeper) requireBlocks(locator []*bc.Hash, stopHash *bc.Hash) ([]*types.Block, error) {
index 3c9b62d..d1ffa87 100644 (file)
@@ -399,12 +399,18 @@ func (m *MerkleBlockMessage) GetMerkleBlock() *types.MerkleBlock {
                BlockHeader:  types.BlockHeader{},
                Transactions: []*types.Tx{},
        }
-
        merkleBlock.BlockHeader.UnmarshalText(m.RawBlockHeader)
+
        for _, rawTx := range m.RawTxDatas {
                tx := &types.Tx{}
                tx.UnmarshalText(rawTx)
                merkleBlock.Transactions = append(merkleBlock.Transactions, tx)
        }
+
+       merkleBlock.TxHashes = m.TxHashes
+       merkleBlock.Flags = m.Flags
+       merkleBlock.StatusHashes = m.StatusHashes
+       merkleBlock.RawTxStatuses = m.RawTxStatuses
+
        return merkleBlock
 }
diff --git a/protocol/bc/types/merkle.go b/protocol/bc/types/merkle.go
new file mode 100644 (file)
index 0000000..8b05dc7
--- /dev/null
@@ -0,0 +1,336 @@
+package types
+
+import (
+       "container/list"
+       "io"
+       "math"
+
+       "gopkg.in/fatih/set.v0"
+
+       "github.com/bytom/crypto/sha3pool"
+       "github.com/bytom/protocol/bc"
+)
+
+// merkleFlag represent the type of merkle tree node, it's used to generate the structure of merkle tree
+// Bitcoin has only two flags, which zero means the hash of assist node. And one means the hash of the related
+// transaction node or it's parents, which distinguish them according to the height of the tree. But in the bytom,
+// the height of transaction node is not fixed, so we need three flags to distinguish these nodes.
+const (
+       // FlagAssist represent assist node
+       FlagAssist = iota
+       // FlagTxParent represent the parent of transaction of node
+       FlagTxParent
+       // FlagTxLeaf represent transaction of node
+       FlagTxLeaf
+)
+
+var (
+       leafPrefix     = []byte{0x00}
+       interiorPrefix = []byte{0x01}
+)
+
+type merkleNode interface {
+       WriteTo(io.Writer) (int64, error)
+}
+
+func merkleRoot(nodes []merkleNode) (root bc.Hash, err error) {
+       switch {
+       case len(nodes) == 0:
+               return bc.EmptyStringHash, nil
+
+       case len(nodes) == 1:
+               root = leafMerkleHash(nodes[0])
+               return root, nil
+
+       default:
+               k := prevPowerOfTwo(len(nodes))
+               left, err := merkleRoot(nodes[:k])
+               if err != nil {
+                       return root, err
+               }
+
+               right, err := merkleRoot(nodes[k:])
+               if err != nil {
+                       return root, err
+               }
+
+               root = interiorMerkleHash(&left, &right)
+               return root, nil
+       }
+}
+
+func interiorMerkleHash(left merkleNode, right merkleNode) (hash bc.Hash) {
+       h := sha3pool.Get256()
+       defer sha3pool.Put256(h)
+       h.Write(interiorPrefix)
+       left.WriteTo(h)
+       right.WriteTo(h)
+       hash.ReadFrom(h)
+       return hash
+}
+
+func leafMerkleHash(node merkleNode) (hash bc.Hash) {
+       h := sha3pool.Get256()
+       defer sha3pool.Put256(h)
+       h.Write(leafPrefix)
+       node.WriteTo(h)
+       hash.ReadFrom(h)
+       return hash
+}
+
+type merkleTreeNode struct {
+       hash  bc.Hash
+       left  *merkleTreeNode
+       right *merkleTreeNode
+}
+
+// buildMerkleTree construct a merkle tree based on the provide node data
+func buildMerkleTree(rawDatas []merkleNode) *merkleTreeNode {
+       switch len(rawDatas) {
+       case 0:
+               return nil
+       case 1:
+               rawData := rawDatas[0]
+               merkleHash := leafMerkleHash(rawData)
+               node := newMerkleTreeNode(merkleHash, nil, nil)
+               return node
+       default:
+               k := prevPowerOfTwo(len(rawDatas))
+               left := buildMerkleTree(rawDatas[:k])
+               right := buildMerkleTree(rawDatas[k:])
+               merkleHash := interiorMerkleHash(&left.hash, &right.hash)
+               node := newMerkleTreeNode(merkleHash, left, right)
+               return node
+       }
+}
+
+func (node *merkleTreeNode) getMerkleTreeProof(merkleHashSet *set.Set) ([]*bc.Hash, []uint8) {
+       var hashes []*bc.Hash
+       var flags []uint8
+
+       if node.left == nil && node.right == nil {
+               if key := node.hash.String(); merkleHashSet.Has(key) {
+                       hashes = append(hashes, &node.hash)
+                       flags = append(flags, FlagTxLeaf)
+                       return hashes, flags
+               }
+               return hashes, flags
+       }
+       var leftHashes, rightHashes []*bc.Hash
+       var leftFlags, rightFlags []uint8
+       if node.left != nil {
+               leftHashes, leftFlags = node.left.getMerkleTreeProof(merkleHashSet)
+       }
+       if node.right != nil {
+               rightHashes, rightFlags = node.right.getMerkleTreeProof(merkleHashSet)
+       }
+       leftFind, rightFind := len(leftHashes) > 0, len(rightHashes) > 0
+
+       if leftFind || rightFind {
+               flags = append(flags, FlagTxParent)
+       } else {
+               return hashes, flags
+       }
+
+       if leftFind {
+               hashes = append(hashes, leftHashes...)
+               flags = append(flags, leftFlags...)
+       } else {
+               hashes = append(hashes, &node.left.hash)
+               flags = append(flags, FlagAssist)
+       }
+
+       if rightFind {
+               hashes = append(hashes, rightHashes...)
+               flags = append(flags, rightFlags...)
+       } else {
+               hashes = append(hashes, &node.right.hash)
+               flags = append(flags, FlagAssist)
+       }
+       return hashes, flags
+}
+
+func getMerkleTreeProof(rawDatas []merkleNode, relatedRawDatas []merkleNode) ([]*bc.Hash, []uint8) {
+       merkleTree := buildMerkleTree(rawDatas)
+       if merkleTree == nil {
+               return []*bc.Hash{}, []uint8{}
+       }
+       merkleHashSet := set.New()
+       for _, data := range relatedRawDatas {
+               merkleHash := leafMerkleHash(data)
+               merkleHashSet.Add(merkleHash.String())
+       }
+       return merkleTree.getMerkleTreeProof(merkleHashSet)
+}
+
+func (node *merkleTreeNode) getMerkleTreeProofByFlags(flagList *list.List) []*bc.Hash {
+       var hashes []*bc.Hash
+
+       if flagList.Len() == 0 {
+               return hashes
+       }
+       flagEle := flagList.Front()
+       flag := flagEle.Value.(uint8)
+       flagList.Remove(flagEle)
+
+       if flag == FlagTxLeaf || flag == FlagAssist {
+               hashes = append(hashes, &node.hash)
+               return hashes
+       }
+       if node.left != nil {
+               leftHashes := node.left.getMerkleTreeProofByFlags(flagList)
+               hashes = append(hashes, leftHashes...)
+       }
+       if node.right != nil {
+               rightHashes := node.right.getMerkleTreeProofByFlags(flagList)
+               hashes = append(hashes, rightHashes...)
+       }
+       return hashes
+}
+
+func getMerkleTreeProofByFlags(rawDatas []merkleNode, flagList *list.List) []*bc.Hash {
+       tree := buildMerkleTree(rawDatas)
+       return tree.getMerkleTreeProofByFlags(flagList)
+}
+
+// GetTxMerkleTreeProof return a proof of merkle tree, which used to proof the transaction does
+// exist in the merkle tree
+func GetTxMerkleTreeProof(txs []*Tx, relatedTxs []*Tx) ([]*bc.Hash, []uint8) {
+       var rawDatas []merkleNode
+       var relatedRawDatas []merkleNode
+       for _, tx := range txs {
+               temp := tx.ID
+               rawDatas = append(rawDatas, &temp)
+       }
+       for _, relatedTx := range relatedTxs {
+               temp := relatedTx.ID
+               relatedRawDatas = append(relatedRawDatas, &temp)
+       }
+       return getMerkleTreeProof(rawDatas, relatedRawDatas)
+}
+
+// GetStatusMerkleTreeProof return a proof of merkle tree, which used to proof the status of transaction is valid
+func GetStatusMerkleTreeProof(statuses []*bc.TxVerifyResult, flags []uint8) []*bc.Hash {
+       var rawDatas []merkleNode
+       for _, status := range statuses {
+               rawDatas = append(rawDatas, status)
+       }
+       flagList := list.New()
+       for _, flag := range flags {
+               flagList.PushBack(flag)
+       }
+       return getMerkleTreeProofByFlags(rawDatas, flagList)
+}
+
+// getMerkleRootByProof caculate the merkle root hash according to the proof
+func getMerkleRootByProof(hashList *list.List, flagList *list.List, merkleHashes *list.List) bc.Hash {
+       if flagList.Len() == 0 {
+               return bc.EmptyStringHash
+       }
+       flagEle := flagList.Front()
+       flag := flagEle.Value.(uint8)
+       flagList.Remove(flagEle)
+       if flag == FlagAssist {
+               hash := hashList.Front()
+               hashList.Remove(hash)
+               return hash.Value.(bc.Hash)
+       }
+       if flag == FlagTxLeaf {
+               if hashList.Len() == 0 || merkleHashes.Len() == 0 {
+                       return bc.EmptyStringHash
+               }
+               hashEle := hashList.Front()
+               hash := hashEle.Value.(bc.Hash)
+               relatedHashEle := merkleHashes.Front()
+               relatedHash := relatedHashEle.Value.(bc.Hash)
+               if hash == relatedHash {
+                       hashList.Remove(hashEle)
+                       merkleHashes.Remove(relatedHashEle)
+                       return hash
+               }
+               return bc.EmptyStringHash
+       }
+       leftHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
+       rightHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
+       hash := interiorMerkleHash(&leftHash, &rightHash)
+       return hash
+}
+
+func newMerkleTreeNode(merkleHash bc.Hash, left *merkleTreeNode, right *merkleTreeNode) *merkleTreeNode {
+       return &merkleTreeNode{
+               hash:  merkleHash,
+               left:  left,
+               right: right,
+       }
+}
+
+// ValidateMerkleTreeProof caculate the merkle root according to the hash of node and the flags
+// only if the merkle root by caculated equals to the specify merkle root, and the merkle tree
+// contains all of the related raw datas, the validate result will be true.
+func validateMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedNodes []merkleNode, merkleRoot bc.Hash) bool {
+       merkleHashes := list.New()
+       for _, relatedNode := range relatedNodes {
+               merkleHashes.PushBack(leafMerkleHash(relatedNode))
+       }
+       hashList := list.New()
+       for _, hash := range hashes {
+               hashList.PushBack(*hash)
+       }
+       flagList := list.New()
+       for _, flag := range flags {
+               flagList.PushBack(flag)
+       }
+       root := getMerkleRootByProof(hashList, flagList, merkleHashes)
+       return root == merkleRoot && merkleHashes.Len() == 0
+}
+
+// ValidateTxMerkleTreeProof validate the merkle tree of transactions
+func ValidateTxMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedHashes []*bc.Hash, merkleRoot bc.Hash) bool {
+       var relatedNodes []merkleNode
+       for _, hash := range relatedHashes {
+               relatedNodes = append(relatedNodes, hash)
+       }
+       return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
+}
+
+// ValidateStatusMerkleTreeProof validate the merkle tree of transaction status
+func ValidateStatusMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedStatus []*bc.TxVerifyResult, merkleRoot bc.Hash) bool {
+       var relatedNodes []merkleNode
+       for _, result := range relatedStatus {
+               relatedNodes = append(relatedNodes, result)
+       }
+       return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
+}
+
+// TxStatusMerkleRoot creates a merkle tree from a slice of bc.TxVerifyResult
+func TxStatusMerkleRoot(tvrs []*bc.TxVerifyResult) (root bc.Hash, err error) {
+       nodes := []merkleNode{}
+       for _, tvr := range tvrs {
+               nodes = append(nodes, tvr)
+       }
+       return merkleRoot(nodes)
+}
+
+// TxMerkleRoot creates a merkle tree from a slice of transactions
+// and returns the root hash of the tree.
+func TxMerkleRoot(transactions []*bc.Tx) (root bc.Hash, err error) {
+       nodes := []merkleNode{}
+       for _, tx := range transactions {
+               nodes = append(nodes, &tx.ID)
+       }
+       return merkleRoot(nodes)
+}
+
+// prevPowerOfTwo returns the largest power of two that is smaller than a given number.
+// In other words, for some input n, the prevPowerOfTwo k is a power of two such that
+// k < n <= 2k. This is a helper function used during the calculation of a merkle tree.
+func prevPowerOfTwo(n int) int {
+       // If the number is a power of two, divide it by 2 and return.
+       if n&(n-1) == 0 {
+               return n / 2
+       }
+
+       // Otherwise, find the previous PoT.
+       exponent := uint(math.Log2(float64(n)))
+       return 1 << exponent // 2^exponent
+}
index 3cf09d0..6ba5d51 100644 (file)
@@ -22,7 +22,11 @@ const (
 // it contains.
 type MerkleBlock struct {
        BlockHeader
-       Transactions []*Tx
+       Transactions  []*Tx
+       TxHashes      [][32]byte
+       StatusHashes  [][32]byte
+       RawTxStatuses [][]byte
+       Flags         []byte
 }
 
 // MarshalText fulfills the json.Marshaler interface. This guarantees that
diff --git a/protocol/bc/types/merkle_test.go b/protocol/bc/types/merkle_test.go
new file mode 100644 (file)
index 0000000..d0681de
--- /dev/null
@@ -0,0 +1,208 @@
+package types
+
+import (
+       "math/rand"
+       "reflect"
+       "testing"
+       "time"
+
+       "github.com/bytom/protocol/bc"
+       "github.com/bytom/protocol/vm"
+       "github.com/bytom/testutil"
+)
+
+func TestMerkleRoot(t *testing.T) {
+       cases := []struct {
+               witnesses [][][]byte
+               want      bc.Hash
+       }{{
+               witnesses: [][][]byte{
+                       {
+                               {1},
+                               []byte("00000"),
+                       },
+               },
+               want: testutil.MustDecodeHash("fe34dbd5da0ce3656f423fd7aad7fc7e879353174d33a6446c2ed0e3f3512101"),
+       }, {
+               witnesses: [][][]byte{
+                       {
+                               {1},
+                               []byte("000000"),
+                       },
+                       {
+                               {1},
+                               []byte("111111"),
+                       },
+               },
+               want: testutil.MustDecodeHash("0e4b4c1af18b8f59997804d69f8f66879ad5e30027346ee003ff7c7a512e5554"),
+       }, {
+               witnesses: [][][]byte{
+                       {
+                               {1},
+                               []byte("000000"),
+                       },
+                       {
+                               {2},
+                               []byte("111111"),
+                               []byte("222222"),
+                       },
+               },
+               want: testutil.MustDecodeHash("0e4b4c1af18b8f59997804d69f8f66879ad5e30027346ee003ff7c7a512e5554"),
+       }}
+
+       for _, c := range cases {
+               var txs []*bc.Tx
+               for _, wit := range c.witnesses {
+                       txs = append(txs, NewTx(TxData{
+                               Inputs: []*TxInput{
+                                       &TxInput{
+                                               AssetVersion: 1,
+                                               TypedInput: &SpendInput{
+                                                       Arguments: wit,
+                                                       SpendCommitment: SpendCommitment{
+                                                               AssetAmount: bc.AssetAmount{
+                                                                       AssetId: &bc.AssetID{V0: 0},
+                                                               },
+                                                       },
+                                               },
+                                       },
+                               },
+                       }).Tx)
+               }
+               got, err := TxMerkleRoot(txs)
+               if err != nil {
+                       t.Fatalf("unexpected error %s", err)
+               }
+               if got != c.want {
+                       t.Log("witnesses", c.witnesses)
+                       t.Errorf("got merkle root = %x want %x", got.Bytes(), c.want.Bytes())
+               }
+       }
+}
+
+func TestDuplicateLeaves(t *testing.T) {
+       trueProg := []byte{byte(vm.OP_TRUE)}
+       assetID := bc.ComputeAssetID(trueProg, 1, &bc.EmptyStringHash)
+       txs := make([]*bc.Tx, 6)
+       for i := uint64(0); i < 6; i++ {
+               now := []byte(time.Now().String())
+               txs[i] = NewTx(TxData{
+                       Version: 1,
+                       Inputs:  []*TxInput{NewIssuanceInput(now, i, trueProg, nil, nil)},
+                       Outputs: []*TxOutput{NewTxOutput(assetID, i, trueProg)},
+               }).Tx
+       }
+
+       // first, get the root of an unbalanced tree
+       txns := []*bc.Tx{txs[5], txs[4], txs[3], txs[2], txs[1], txs[0]}
+       root1, err := TxMerkleRoot(txns)
+       if err != nil {
+               t.Fatalf("unexpected error %s", err)
+       }
+
+       // now, get the root of a balanced tree that repeats leaves 0 and 1
+       txns = []*bc.Tx{txs[5], txs[4], txs[3], txs[2], txs[1], txs[0], txs[1], txs[0]}
+       root2, err := TxMerkleRoot(txns)
+       if err != nil {
+               t.Fatalf("unexpected error %s", err)
+       }
+
+       if root1 == root2 {
+               t.Error("forged merkle tree by duplicating some leaves")
+       }
+}
+
+func TestAllDuplicateLeaves(t *testing.T) {
+       trueProg := []byte{byte(vm.OP_TRUE)}
+       assetID := bc.ComputeAssetID(trueProg, 1, &bc.EmptyStringHash)
+       now := []byte(time.Now().String())
+       issuanceInp := NewIssuanceInput(now, 1, trueProg, nil, nil)
+
+       tx := NewTx(TxData{
+               Version: 1,
+               Inputs:  []*TxInput{issuanceInp},
+               Outputs: []*TxOutput{NewTxOutput(assetID, 1, trueProg)},
+       }).Tx
+       tx1, tx2, tx3, tx4, tx5, tx6 := tx, tx, tx, tx, tx, tx
+
+       // first, get the root of an unbalanced tree
+       txs := []*bc.Tx{tx6, tx5, tx4, tx3, tx2, tx1}
+       root1, err := TxMerkleRoot(txs)
+       if err != nil {
+               t.Fatalf("unexpected error %s", err)
+       }
+
+       // now, get the root of a balanced tree that repeats leaves 5 and 6
+       txs = []*bc.Tx{tx6, tx5, tx6, tx5, tx4, tx3, tx2, tx1}
+       root2, err := TxMerkleRoot(txs)
+       if err != nil {
+               t.Fatalf("unexpected error %s", err)
+       }
+
+       if root1 == root2 {
+               t.Error("forged merkle tree with all duplicate leaves")
+       }
+}
+
+func TestTxMerkleProof(t *testing.T) {
+       var txs []*Tx
+       var bcTxs []*bc.Tx
+       trueProg := []byte{byte(vm.OP_TRUE)}
+       assetID := bc.ComputeAssetID(trueProg, 1, &bc.EmptyStringHash)
+       for i := 0; i < 18; i++ {
+               now := []byte(time.Now().String())
+               issuanceInp := NewIssuanceInput(now, 1, trueProg, nil, nil)
+               tx := NewTx(TxData{
+                       Version: 1,
+                       Inputs:  []*TxInput{issuanceInp},
+                       Outputs: []*TxOutput{NewTxOutput(assetID, 1, trueProg)},
+               })
+               txs = append(txs, tx)
+               bcTxs = append(bcTxs, tx.Tx)
+       }
+       root, err := TxMerkleRoot(bcTxs)
+       if err != nil {
+               t.Fatalf("unexpected error %s", err)
+       }
+
+       relatedTx := []*Tx{txs[17]}
+       proofHashes, flags := GetTxMerkleTreeProof(txs, relatedTx)
+       if len(proofHashes) <= 0 {
+               t.Error("Can not find any tx id in the merkle tree")
+       }
+       expectFlags := []uint8{1, 1, 1, 1, 2, 0, 1, 0, 2, 1, 0, 1, 0, 2, 1, 2, 0}
+       if !reflect.DeepEqual(flags, expectFlags) {
+               t.Error("The flags is not equals expect flags", flags)
+       }
+       if len(proofHashes) != 9 {
+               t.Error("The length proof hashes is not equals expect length")
+       }
+       ids := []*bc.Hash{&txs[0].ID, &txs[3].ID, &txs[7].ID, &txs[8].ID}
+       if !ValidateTxMerkleTreeProof(proofHashes, flags, ids, root) {
+               t.Error("Merkle tree validate fail")
+       }
+}
+
+func TestStatusMerkleProof(t *testing.T) {
+       var statuses []*bc.TxVerifyResult
+       for i := 0; i < 10; i++ {
+               status := &bc.TxVerifyResult{}
+               fail := rand.Intn(2)
+               if fail == 0 {
+                       status.StatusFail = true
+               } else {
+                       status.StatusFail = false
+               }
+               statuses = append(statuses, status)
+       }
+       relatedStatuses := []*bc.TxVerifyResult{statuses[0], statuses[3], statuses[7], statuses[8]}
+       flags := []uint8{1, 1, 1, 1, 2, 0, 1, 0, 2, 1, 0, 1, 0, 2, 1, 2, 0}
+       hashes := GetStatusMerkleTreeProof(statuses, flags)
+       if len(hashes) != 9 {
+               t.Error("The length proof hashes is not equals expect length")
+       }
+       root, _ := TxStatusMerkleRoot(statuses)
+       if !ValidateStatusMerkleTreeProof(hashes, flags, relatedStatuses, root) {
+               t.Error("Merkle tree validate fail")
+       }
+}
index 72a2dd7..46da3d3 100644 (file)
@@ -75,18 +75,7 @@ func (c *Chain) calcReorganizeNodes(node *state.BlockNode) ([]*state.BlockNode,
 
 func (c *Chain) connectBlock(block *types.Block) (err error) {
        bcBlock := types.MapBlock(block)
-       if bcBlock.TransactionStatus, err = c.store.GetTransactionStatus(&bcBlock.ID); err != nil {
-               return err
-       }
-
        utxoView := state.NewUtxoViewpoint()
-       if err := c.store.GetTransactionsUtxo(utxoView, bcBlock.Transactions); err != nil {
-               return err
-       }
-       if err := utxoView.ApplyBlock(bcBlock, bcBlock.TransactionStatus); err != nil {
-               return err
-       }
-
        node := c.index.GetNode(&bcBlock.ID)
        if err := c.setState(node, utxoView); err != nil {
                return err
@@ -99,51 +88,7 @@ func (c *Chain) connectBlock(block *types.Block) (err error) {
 }
 
 func (c *Chain) reorganizeChain(node *state.BlockNode) error {
-       attachNodes, detachNodes := c.calcReorganizeNodes(node)
        utxoView := state.NewUtxoViewpoint()
-
-       for _, detachNode := range detachNodes {
-               b, err := c.store.GetBlock(&detachNode.Hash)
-               if err != nil {
-                       return err
-               }
-
-               detachBlock := types.MapBlock(b)
-               if err := c.store.GetTransactionsUtxo(utxoView, detachBlock.Transactions); err != nil {
-                       return err
-               }
-               txStatus, err := c.GetTransactionStatus(&detachBlock.ID)
-               if err != nil {
-                       return err
-               }
-               if err := utxoView.DetachBlock(detachBlock, txStatus); err != nil {
-                       return err
-               }
-
-               log.WithFields(log.Fields{"height": node.Height, "hash": node.Hash.String()}).Debug("detach from mainchain")
-       }
-
-       for _, attachNode := range attachNodes {
-               b, err := c.store.GetBlock(&attachNode.Hash)
-               if err != nil {
-                       return err
-               }
-
-               attachBlock := types.MapBlock(b)
-               if err := c.store.GetTransactionsUtxo(utxoView, attachBlock.Transactions); err != nil {
-                       return err
-               }
-               txStatus, err := c.GetTransactionStatus(&attachBlock.ID)
-               if err != nil {
-                       return err
-               }
-               if err := utxoView.ApplyBlock(attachBlock, txStatus); err != nil {
-                       return err
-               }
-
-               log.WithFields(log.Fields{"height": node.Height, "hash": node.Hash.String()}).Debug("attach from mainchain")
-       }
-
        return c.setState(node, utxoView)
 }