10 "github.com/tjfoc/gmsm/sm3"
11 "gopkg.in/fatih/set.v0"
13 "github.com/bytom/bytom/protocol/bc"
16 // merkleFlag represents the type of a merkle tree node; it is used to generate the structure of the merkle tree.
17 // Bitcoin has only two flags: zero means the hash of an assist node, and one means the hash of a related
18 // transaction node or one of its parents, which are told apart by the height of the tree. In bytom,
19 // the height of a transaction node is not fixed, so three flags are needed to distinguish these nodes.
21 // FlagAssist represents an assist node
23 // FlagTxParent represents the parent of a transaction node
25 // FlagTxLeaf represents a transaction node
30 leafPrefix = []byte{0x00}
31 interiorPrefix = []byte{0x01}
34 type merkleNode interface {
35 WriteTo(io.Writer) (int64, error)
38 func merkleRoot(nodes []merkleNode) (root bc.Hash, err error) {
41 return bc.EmptyStringHash, nil
44 root = leafMerkleHash(nodes[0])
48 k := prevPowerOfTwo(len(nodes))
49 left, err := merkleRoot(nodes[:k])
54 right, err := merkleRoot(nodes[k:])
59 root = interiorMerkleHash(&left, &right)
64 func interiorMerkleHash(left merkleNode, right merkleNode) (hash bc.Hash) {
66 hasher.Write(interiorPrefix)
70 copy(b32[:], hasher.Sum(nil))
71 hash = bc.NewHash(b32)
75 func leafMerkleHash(node merkleNode) (hash bc.Hash) {
77 hasher.Write(leafPrefix)
80 copy(b32[:], hasher.Sum(nil))
81 hash = bc.NewHash(b32)
85 type merkleTreeNode struct {
91 // buildMerkleTree constructs a merkle tree from the provided node data
92 func buildMerkleTree(rawDatas []merkleNode) *merkleTreeNode {
93 switch len(rawDatas) {
97 rawData := rawDatas[0]
98 merkleHash := leafMerkleHash(rawData)
99 node := newMerkleTreeNode(merkleHash, nil, nil)
102 k := prevPowerOfTwo(len(rawDatas))
103 left := buildMerkleTree(rawDatas[:k])
104 right := buildMerkleTree(rawDatas[k:])
105 merkleHash := interiorMerkleHash(&left.hash, &right.hash)
106 node := newMerkleTreeNode(merkleHash, left, right)
111 func (node *merkleTreeNode) getMerkleTreeProof(merkleHashSet *set.Set) ([]*bc.Hash, []uint8) {
112 var hashes []*bc.Hash
115 if node.left == nil && node.right == nil {
116 if key := node.hash.String(); merkleHashSet.Has(key) {
117 hashes = append(hashes, &node.hash)
118 flags = append(flags, FlagTxLeaf)
123 var leftHashes, rightHashes []*bc.Hash
124 var leftFlags, rightFlags []uint8
125 if node.left != nil {
126 leftHashes, leftFlags = node.left.getMerkleTreeProof(merkleHashSet)
128 if node.right != nil {
129 rightHashes, rightFlags = node.right.getMerkleTreeProof(merkleHashSet)
131 leftFind, rightFind := len(leftHashes) > 0, len(rightHashes) > 0
133 if leftFind || rightFind {
134 flags = append(flags, FlagTxParent)
140 hashes = append(hashes, leftHashes...)
141 flags = append(flags, leftFlags...)
143 hashes = append(hashes, &node.left.hash)
144 flags = append(flags, FlagAssist)
148 hashes = append(hashes, rightHashes...)
149 flags = append(flags, rightFlags...)
151 hashes = append(hashes, &node.right.hash)
152 flags = append(flags, FlagAssist)
157 func getMerkleTreeProof(rawDatas []merkleNode, relatedRawDatas []merkleNode) ([]*bc.Hash, []uint8) {
158 merkleTree := buildMerkleTree(rawDatas)
159 if merkleTree == nil {
160 return []*bc.Hash{}, []uint8{}
162 merkleHashSet := set.New()
163 for _, data := range relatedRawDatas {
164 merkleHash := leafMerkleHash(data)
165 merkleHashSet.Add(merkleHash.String())
167 if merkleHashSet.Size() == 0 {
168 return []*bc.Hash{&merkleTree.hash}, []uint8{FlagAssist}
170 return merkleTree.getMerkleTreeProof(merkleHashSet)
173 func (node *merkleTreeNode) getMerkleTreeProofByFlags(flagList *list.List) []*bc.Hash {
174 var hashes []*bc.Hash
176 if flagList.Len() == 0 {
179 flagEle := flagList.Front()
180 flag := flagEle.Value.(uint8)
181 flagList.Remove(flagEle)
183 if flag == FlagTxLeaf || flag == FlagAssist {
184 hashes = append(hashes, &node.hash)
187 if node.left != nil {
188 leftHashes := node.left.getMerkleTreeProofByFlags(flagList)
189 hashes = append(hashes, leftHashes...)
191 if node.right != nil {
192 rightHashes := node.right.getMerkleTreeProofByFlags(flagList)
193 hashes = append(hashes, rightHashes...)
198 func getMerkleTreeProofByFlags(rawDatas []merkleNode, flagList *list.List) []*bc.Hash {
199 tree := buildMerkleTree(rawDatas)
200 return tree.getMerkleTreeProofByFlags(flagList)
203 // GetTxMerkleTreeProof returns a merkle tree proof, which is used to prove that the related
204 // transactions exist in the merkle tree
205 func GetTxMerkleTreeProof(txs []*Tx, relatedTxs []*Tx) ([]*bc.Hash, []uint8) {
206 var rawDatas []merkleNode
207 var relatedRawDatas []merkleNode
208 for _, tx := range txs {
209 rawDatas = append(rawDatas, &tx.ID)
211 for _, relatedTx := range relatedTxs {
212 relatedRawDatas = append(relatedRawDatas, &relatedTx.ID)
214 return getMerkleTreeProof(rawDatas, relatedRawDatas)
217 // getMerkleRootByProof calculates the merkle root hash according to the proof
218 func getMerkleRootByProof(hashList *list.List, flagList *list.List, merkleHashes *list.List) bc.Hash {
219 if flagList.Len() == 0 || hashList.Len() == 0 {
220 return bc.EmptyStringHash
222 flagEle := flagList.Front()
223 flag := flagEle.Value.(uint8)
224 flagList.Remove(flagEle)
228 hash := hashList.Front()
229 hashList.Remove(hash)
230 return hash.Value.(bc.Hash)
234 if merkleHashes.Len() == 0 {
235 return bc.EmptyStringHash
237 hashEle := hashList.Front()
238 hash := hashEle.Value.(bc.Hash)
239 relatedHashEle := merkleHashes.Front()
240 relatedHash := relatedHashEle.Value.(bc.Hash)
241 if hash == relatedHash {
242 hashList.Remove(hashEle)
243 merkleHashes.Remove(relatedHashEle)
249 leftHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
250 rightHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
251 hash := interiorMerkleHash(&leftHash, &rightHash)
255 return bc.EmptyStringHash
258 func newMerkleTreeNode(merkleHash bc.Hash, left *merkleTreeNode, right *merkleTreeNode) *merkleTreeNode {
259 return &merkleTreeNode{
266 // validateMerkleTreeProof calculates the merkle root from the node hashes and the flags.
267 // The result is true only if the calculated merkle root equals the specified merkle root and the
268 // merkle tree contains all of the related raw data.
269 func validateMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedNodes []merkleNode, merkleRoot bc.Hash) bool {
270 merkleHashes := list.New()
271 for _, relatedNode := range relatedNodes {
272 merkleHashes.PushBack(leafMerkleHash(relatedNode))
274 hashList := list.New()
275 for _, hash := range hashes {
276 hashList.PushBack(*hash)
278 flagList := list.New()
279 for _, flag := range flags {
280 flagList.PushBack(flag)
282 root := getMerkleRootByProof(hashList, flagList, merkleHashes)
283 return root == merkleRoot && merkleHashes.Len() == 0
286 // ValidateTxMerkleTreeProof validates the merkle tree proof of transactions
287 func ValidateTxMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedHashes []*bc.Hash, merkleRoot bc.Hash) bool {
288 var relatedNodes []merkleNode
289 for _, hash := range relatedHashes {
290 relatedNodes = append(relatedNodes, hash)
292 return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
295 // TxMerkleRoot creates a merkle tree from a slice of transactions
296 // and returns the root hash of the tree.
297 func TxMerkleRoot(transactions []*bc.Tx) (root bc.Hash, err error) {
298 nodes := []merkleNode{}
299 for _, tx := range transactions {
300 nodes = append(nodes, &tx.ID)
302 return merkleRoot(nodes)
305 // prevPowerOfTwo returns the largest power of two that is smaller than a given number.
306 // In other words, for some input n, the prevPowerOfTwo k is a power of two such that
307 // k < n <= 2k. This is a helper function used during the calculation of a merkle tree.
308 func prevPowerOfTwo(n int) int {
309 // If the number is a power of two, divide it by 2 and return.
314 // Otherwise, find the previous PoT.
315 exponent := uint(math.Log2(float64(n)))
316 return 1 << exponent // 2^exponent