package types

import (
	"container/list"
	"io"
	"math"

	"gopkg.in/fatih/set.v0"

	"github.com/vapor/crypto/sha3pool"
	"github.com/vapor/protocol/bc"
)

// The merkle flags represent the type of a merkle tree node and are used to
// describe the structure of a merkle tree proof.
// Bitcoin has only two flags: zero means the hash of an assist node, and one
// means the hash of the related transaction node or one of its parents, which
// are told apart by the height of the tree. In bytom the height of a
// transaction node is not fixed, so three flags are needed to distinguish
// these nodes.
const (
	// FlagAssist represents an assist node.
	FlagAssist = iota
	// FlagTxParent represents the parent of a transaction node.
	FlagTxParent
	// FlagTxLeaf represents a transaction node.
	FlagTxLeaf
)

var (
	// leafPrefix is hashed ahead of a leaf's data.
	leafPrefix = []byte{0x00}
	// interiorPrefix is hashed ahead of the two children of an interior node.
	interiorPrefix = []byte{0x01}
)

// merkleNode is any value that can serialize itself into the hasher.
type merkleNode interface {
	WriteTo(io.Writer) (int64, error)
}

// merkleRoot computes the merkle root of nodes. An empty slice yields
// bc.EmptyStringHash, a single node yields its leaf hash, and larger inputs
// are split at prevPowerOfTwo(len(nodes)) and hashed recursively.
func merkleRoot(nodes []merkleNode) (root bc.Hash, err error) {
	switch {
	case len(nodes) == 0:
		return bc.EmptyStringHash, nil

	case len(nodes) == 1:
		root = leafMerkleHash(nodes[0])
		return root, nil

	default:
		k := prevPowerOfTwo(len(nodes))
		left, err := merkleRoot(nodes[:k])
		if err != nil {
			return root, err
		}

		right, err := merkleRoot(nodes[k:])
		if err != nil {
			return root, err
		}

		root = interiorMerkleHash(&left, &right)
		return root, nil
	}
}

// interiorMerkleHash hashes interiorPrefix followed by the left and right
// children and returns the resulting hash.
func interiorMerkleHash(left merkleNode, right merkleNode) (hash bc.Hash) {
	h := sha3pool.Get256()
	defer sha3pool.Put256(h)
	h.Write(interiorPrefix)
	left.WriteTo(h)
	right.WriteTo(h)
	hash.ReadFrom(h)
	return hash
}

// leafMerkleHash hashes leafPrefix followed by node and returns the
// resulting hash.
func leafMerkleHash(node merkleNode) (hash bc.Hash) {
	h := sha3pool.Get256()
	defer sha3pool.Put256(h)
	h.Write(leafPrefix)
	node.WriteTo(h)
	hash.ReadFrom(h)
	return hash
}

// merkleTreeNode is a node of an in-memory merkle tree. Leaves have nil
// children.
type merkleTreeNode struct {
	hash  bc.Hash
	left  *merkleTreeNode
	right *merkleTreeNode
}

// buildMerkleTree constructs a merkle tree from the provided node data.
// It returns nil when rawDatas is empty.
func buildMerkleTree(rawDatas []merkleNode) *merkleTreeNode {
	switch len(rawDatas) {
	case 0:
		return nil
	case 1:
		return newMerkleTreeNode(leafMerkleHash(rawDatas[0]), nil, nil)
	default:
		k := prevPowerOfTwo(len(rawDatas))
		left := buildMerkleTree(rawDatas[:k])
		right := buildMerkleTree(rawDatas[k:])
		merkleHash := interiorMerkleHash(&left.hash, &right.hash)
		return newMerkleTreeNode(merkleHash, left, right)
	}
}

// getMerkleTreeProof walks the subtree rooted at node and returns the proof
// hashes together with their flags for the leaves whose hash string is
// contained in merkleHashSet. Both results are empty when the subtree holds
// no related leaf.
func (node *merkleTreeNode) getMerkleTreeProof(merkleHashSet *set.Set) ([]*bc.Hash, []uint8) {
	var hashes []*bc.Hash
	var flags []uint8

	// Leaf: it is part of the proof only if it is one of the related hashes.
	if node.left == nil && node.right == nil {
		if key := node.hash.String(); merkleHashSet.Has(key) {
			hashes = append(hashes, &node.hash)
			flags = append(flags, FlagTxLeaf)
		}
		return hashes, flags
	}

	var leftHashes, rightHashes []*bc.Hash
	var leftFlags, rightFlags []uint8
	if node.left != nil {
		leftHashes, leftFlags = node.left.getMerkleTreeProof(merkleHashSet)
	}
	if node.right != nil {
		rightHashes, rightFlags = node.right.getMerkleTreeProof(merkleHashSet)
	}

	leftFind, rightFind := len(leftHashes) > 0, len(rightHashes) > 0
	if !leftFind && !rightFind {
		return hashes, flags
	}
	flags = append(flags, FlagTxParent)

	// A side without related leaves is represented by its own hash as an
	// assist node; otherwise the side's sub-proof is spliced in.
	if leftFind {
		hashes = append(hashes, leftHashes...)
		flags = append(flags, leftFlags...)
	} else {
		hashes = append(hashes, &node.left.hash)
		flags = append(flags, FlagAssist)
	}

	if rightFind {
		hashes = append(hashes, rightHashes...)
		flags = append(flags, rightFlags...)
	} else {
		hashes = append(hashes, &node.right.hash)
		flags = append(flags, FlagAssist)
	}
	return hashes, flags
}

// getMerkleTreeProof builds the merkle tree of rawDatas and returns the proof
// hashes and flags for relatedRawDatas. An empty tree yields empty slices; an
// empty related set yields the root hash flagged as FlagAssist.
func getMerkleTreeProof(rawDatas []merkleNode, relatedRawDatas []merkleNode) ([]*bc.Hash, []uint8) {
	merkleTree := buildMerkleTree(rawDatas)
	if merkleTree == nil {
		return []*bc.Hash{}, []uint8{}
	}

	merkleHashSet := set.New()
	for _, data := range relatedRawDatas {
		merkleHash := leafMerkleHash(data)
		merkleHashSet.Add(merkleHash.String())
	}

	if merkleHashSet.Size() == 0 {
		return []*bc.Hash{&merkleTree.hash}, []uint8{FlagAssist}
	}
	return merkleTree.getMerkleTreeProof(merkleHashSet)
}

// getMerkleTreeProofByFlags consumes flags from the front of flagList and
// collects the hashes of the nodes they select: a FlagTxLeaf or FlagAssist
// flag selects the current node, any other flag descends into the left and
// then the right child. It returns nil when the receiver is nil or the flag
// list is exhausted.
func (node *merkleTreeNode) getMerkleTreeProofByFlags(flagList *list.List) []*bc.Hash {
	// A nil receiver happens when the tree was built from empty data;
	// dereferencing it below would panic.
	if node == nil || flagList.Len() == 0 {
		return nil
	}

	var hashes []*bc.Hash
	flagEle := flagList.Front()
	flag := flagEle.Value.(uint8)
	flagList.Remove(flagEle)

	if flag == FlagTxLeaf || flag == FlagAssist {
		hashes = append(hashes, &node.hash)
		return hashes
	}

	if node.left != nil {
		leftHashes := node.left.getMerkleTreeProofByFlags(flagList)
		hashes = append(hashes, leftHashes...)
	}
	if node.right != nil {
		rightHashes := node.right.getMerkleTreeProofByFlags(flagList)
		hashes = append(hashes, rightHashes...)
	}
	return hashes
}

// getMerkleTreeProofByFlags builds the merkle tree of rawDatas and returns
// the hashes selected by flagList. It returns nil when rawDatas is empty.
func getMerkleTreeProofByFlags(rawDatas []merkleNode, flagList *list.List) []*bc.Hash {
	tree := buildMerkleTree(rawDatas)
	return tree.getMerkleTreeProofByFlags(flagList)
}

// GetTxMerkleTreeProof returns a merkle tree proof that is used to prove the
// related transactions exist in the merkle tree built from txs.
func GetTxMerkleTreeProof(txs []*Tx, relatedTxs []*Tx) ([]*bc.Hash, []uint8) {
	rawDatas := make([]merkleNode, 0, len(txs))
	for _, tx := range txs {
		rawDatas = append(rawDatas, &tx.ID)
	}

	relatedRawDatas := make([]merkleNode, 0, len(relatedTxs))
	for _, relatedTx := range relatedTxs {
		relatedRawDatas = append(relatedRawDatas, &relatedTx.ID)
	}
	return getMerkleTreeProof(rawDatas, relatedRawDatas)
}

// GetStatusMerkleTreeProof returns a merkle tree proof that is used to prove
// the status of a transaction is valid. The proof shape is driven by flags.
func GetStatusMerkleTreeProof(statuses []*bc.TxVerifyResult, flags []uint8) []*bc.Hash {
	rawDatas := make([]merkleNode, 0, len(statuses))
	for _, status := range statuses {
		rawDatas = append(rawDatas, status)
	}

	flagList := list.New()
	for _, flag := range flags {
		flagList.PushBack(flag)
	}
	return getMerkleTreeProofByFlags(rawDatas, flagList)
}

// getMerkleRootByProof calculates the merkle root hash according to the
// proof. It consumes entries from the front of hashList, flagList and
// merkleHashes as it goes, and returns bc.EmptyStringHash when the flag or
// hash list is exhausted or a leaf hash does not match the next related hash.
func getMerkleRootByProof(hashList *list.List, flagList *list.List, merkleHashes *list.List) bc.Hash {
	if flagList.Len() == 0 || hashList.Len() == 0 {
		return bc.EmptyStringHash
	}

	flagEle := flagList.Front()
	flag := flagEle.Value.(uint8)
	flagList.Remove(flagEle)

	switch flag {
	case FlagAssist:
		hashEle := hashList.Front()
		hashList.Remove(hashEle)
		return hashEle.Value.(bc.Hash)

	case FlagTxLeaf:
		if merkleHashes.Len() == 0 {
			return bc.EmptyStringHash
		}
		hashEle := hashList.Front()
		hash := hashEle.Value.(bc.Hash)
		relatedHashEle := merkleHashes.Front()
		relatedHash := relatedHashEle.Value.(bc.Hash)
		// Only a matching leaf consumes its entries; a mismatch falls
		// through to the empty hash below.
		if hash == relatedHash {
			hashList.Remove(hashEle)
			merkleHashes.Remove(relatedHashEle)
			return hash
		}

	case FlagTxParent:
		leftHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
		rightHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
		return interiorMerkleHash(&leftHash, &rightHash)
	}
	return bc.EmptyStringHash
}

// newMerkleTreeNode allocates a tree node with the given hash and children.
func newMerkleTreeNode(merkleHash bc.Hash, left *merkleTreeNode, right *merkleTreeNode) *merkleTreeNode {
	return &merkleTreeNode{
		hash:  merkleHash,
		left:  left,
		right: right,
	}
}

// validateMerkleTreeProof calculates the merkle root according to the node
// hashes and the flags. The result is true only if the calculated merkle root
// equals the specified merkle root and the proof consumed all of the related
// nodes.
func validateMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedNodes []merkleNode, merkleRoot bc.Hash) bool {
	merkleHashes := list.New()
	for _, relatedNode := range relatedNodes {
		merkleHashes.PushBack(leafMerkleHash(relatedNode))
	}

	hashList := list.New()
	for _, hash := range hashes {
		hashList.PushBack(*hash)
	}

	flagList := list.New()
	for _, flag := range flags {
		flagList.PushBack(flag)
	}

	root := getMerkleRootByProof(hashList, flagList, merkleHashes)
	return root == merkleRoot && merkleHashes.Len() == 0
}

// ValidateTxMerkleTreeProof validates the merkle tree proof of transactions.
func ValidateTxMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedHashes []*bc.Hash, merkleRoot bc.Hash) bool {
	relatedNodes := make([]merkleNode, 0, len(relatedHashes))
	for _, hash := range relatedHashes {
		relatedNodes = append(relatedNodes, hash)
	}
	return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
}

// ValidateStatusMerkleTreeProof validates the merkle tree proof of
// transaction statuses.
func ValidateStatusMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedStatus []*bc.TxVerifyResult, merkleRoot bc.Hash) bool {
	relatedNodes := make([]merkleNode, 0, len(relatedStatus))
	for _, result := range relatedStatus {
		relatedNodes = append(relatedNodes, result)
	}
	return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
}

// TxStatusMerkleRoot creates a merkle tree from a slice of bc.TxVerifyResult
// and returns the root hash of the tree.
func TxStatusMerkleRoot(tvrs []*bc.TxVerifyResult) (root bc.Hash, err error) {
	nodes := make([]merkleNode, 0, len(tvrs))
	for _, tvr := range tvrs {
		nodes = append(nodes, tvr)
	}
	return merkleRoot(nodes)
}

// TxMerkleRoot creates a merkle tree from a slice of transactions
// and returns the root hash of the tree.
func TxMerkleRoot(transactions []*bc.Tx) (root bc.Hash, err error) {
	nodes := make([]merkleNode, 0, len(transactions))
	for _, tx := range transactions {
		nodes = append(nodes, &tx.ID)
	}
	return merkleRoot(nodes)
}

// prevPowerOfTwo returns the largest power of two that is smaller than a given number.
// In other words, for some input n, the prevPowerOfTwo k is a power of two such that
// k < n <= 2k. This is a helper function used during the calculation of a merkle tree.
// For n of 0 or 1 it returns 0.
func prevPowerOfTwo(n int) int {
	// If the number is a power of two, divide it by 2 and return.
	if n&(n-1) == 0 {
		return n / 2
	}

	// Otherwise, find the previous PoT.
	exponent := uint(math.Log2(float64(n)))
	return 1 << exponent // 2^exponent
}