OSDN Git Service

update gm
[bytom/bytom.git] / protocol / bc / types / merkle.go
1 // +build !gm
2
3 package types
4
5 import (
6         "container/list"
7         "io"
8         "math"
9
10         "gopkg.in/fatih/set.v0"
11
12         "github.com/bytom/bytom/crypto/sha3pool"
13         "github.com/bytom/bytom/protocol/bc"
14 )
15
16 // merkleFlag represent the type of merkle tree node, it's used to generate the structure of merkle tree
17 // Bitcoin has only two flags, which zero means the hash of assist node. And one means the hash of the related
18 // transaction node or it's parents, which distinguish them according to the height of the tree. But in the bytom,
19 // the height of transaction node is not fixed, so we need three flags to distinguish these nodes.
20 const (
21         // FlagAssist represent assist node
22         FlagAssist = iota
23         // FlagTxParent represent the parent of transaction of node
24         FlagTxParent
25         // FlagTxLeaf represent transaction of node
26         FlagTxLeaf
27 )
28
29 var (
30         leafPrefix     = []byte{0x00}
31         interiorPrefix = []byte{0x01}
32 )
33
34 type merkleNode interface {
35         WriteTo(io.Writer) (int64, error)
36 }
37
38 func merkleRoot(nodes []merkleNode) (root bc.Hash, err error) {
39         switch {
40         case len(nodes) == 0:
41                 return bc.EmptyStringHash, nil
42
43         case len(nodes) == 1:
44                 root = leafMerkleHash(nodes[0])
45                 return root, nil
46
47         default:
48                 k := prevPowerOfTwo(len(nodes))
49                 left, err := merkleRoot(nodes[:k])
50                 if err != nil {
51                         return root, err
52                 }
53
54                 right, err := merkleRoot(nodes[k:])
55                 if err != nil {
56                         return root, err
57                 }
58
59                 root = interiorMerkleHash(&left, &right)
60                 return root, nil
61         }
62 }
63
64 func interiorMerkleHash(left merkleNode, right merkleNode) (hash bc.Hash) {
65         h := sha3pool.Get256()
66         defer sha3pool.Put256(h)
67         h.Write(interiorPrefix)
68         left.WriteTo(h)
69         right.WriteTo(h)
70         hash.ReadFrom(h)
71         return hash
72 }
73
74 func leafMerkleHash(node merkleNode) (hash bc.Hash) {
75         h := sha3pool.Get256()
76         defer sha3pool.Put256(h)
77         h.Write(leafPrefix)
78         node.WriteTo(h)
79         hash.ReadFrom(h)
80         return hash
81 }
82
83 type merkleTreeNode struct {
84         hash  bc.Hash
85         left  *merkleTreeNode
86         right *merkleTreeNode
87 }
88
89 // buildMerkleTree construct a merkle tree based on the provide node data
90 func buildMerkleTree(rawDatas []merkleNode) *merkleTreeNode {
91         switch len(rawDatas) {
92         case 0:
93                 return nil
94         case 1:
95                 rawData := rawDatas[0]
96                 merkleHash := leafMerkleHash(rawData)
97                 node := newMerkleTreeNode(merkleHash, nil, nil)
98                 return node
99         default:
100                 k := prevPowerOfTwo(len(rawDatas))
101                 left := buildMerkleTree(rawDatas[:k])
102                 right := buildMerkleTree(rawDatas[k:])
103                 merkleHash := interiorMerkleHash(&left.hash, &right.hash)
104                 node := newMerkleTreeNode(merkleHash, left, right)
105                 return node
106         }
107 }
108
109 func (node *merkleTreeNode) getMerkleTreeProof(merkleHashSet *set.Set) ([]*bc.Hash, []uint8) {
110         var hashes []*bc.Hash
111         var flags []uint8
112
113         if node.left == nil && node.right == nil {
114                 if key := node.hash.String(); merkleHashSet.Has(key) {
115                         hashes = append(hashes, &node.hash)
116                         flags = append(flags, FlagTxLeaf)
117                         return hashes, flags
118                 }
119                 return hashes, flags
120         }
121         var leftHashes, rightHashes []*bc.Hash
122         var leftFlags, rightFlags []uint8
123         if node.left != nil {
124                 leftHashes, leftFlags = node.left.getMerkleTreeProof(merkleHashSet)
125         }
126         if node.right != nil {
127                 rightHashes, rightFlags = node.right.getMerkleTreeProof(merkleHashSet)
128         }
129         leftFind, rightFind := len(leftHashes) > 0, len(rightHashes) > 0
130
131         if leftFind || rightFind {
132                 flags = append(flags, FlagTxParent)
133         } else {
134                 return hashes, flags
135         }
136
137         if leftFind {
138                 hashes = append(hashes, leftHashes...)
139                 flags = append(flags, leftFlags...)
140         } else {
141                 hashes = append(hashes, &node.left.hash)
142                 flags = append(flags, FlagAssist)
143         }
144
145         if rightFind {
146                 hashes = append(hashes, rightHashes...)
147                 flags = append(flags, rightFlags...)
148         } else {
149                 hashes = append(hashes, &node.right.hash)
150                 flags = append(flags, FlagAssist)
151         }
152         return hashes, flags
153 }
154
155 func getMerkleTreeProof(rawDatas []merkleNode, relatedRawDatas []merkleNode) ([]*bc.Hash, []uint8) {
156         merkleTree := buildMerkleTree(rawDatas)
157         if merkleTree == nil {
158                 return []*bc.Hash{}, []uint8{}
159         }
160         merkleHashSet := set.New()
161         for _, data := range relatedRawDatas {
162                 merkleHash := leafMerkleHash(data)
163                 merkleHashSet.Add(merkleHash.String())
164         }
165         if merkleHashSet.Size() == 0 {
166                 return []*bc.Hash{&merkleTree.hash}, []uint8{FlagAssist}
167         }
168         return merkleTree.getMerkleTreeProof(merkleHashSet)
169 }
170
171 func (node *merkleTreeNode) getMerkleTreeProofByFlags(flagList *list.List) []*bc.Hash {
172         var hashes []*bc.Hash
173
174         if flagList.Len() == 0 {
175                 return hashes
176         }
177         flagEle := flagList.Front()
178         flag := flagEle.Value.(uint8)
179         flagList.Remove(flagEle)
180
181         if flag == FlagTxLeaf || flag == FlagAssist {
182                 hashes = append(hashes, &node.hash)
183                 return hashes
184         }
185         if node.left != nil {
186                 leftHashes := node.left.getMerkleTreeProofByFlags(flagList)
187                 hashes = append(hashes, leftHashes...)
188         }
189         if node.right != nil {
190                 rightHashes := node.right.getMerkleTreeProofByFlags(flagList)
191                 hashes = append(hashes, rightHashes...)
192         }
193         return hashes
194 }
195
196 func getMerkleTreeProofByFlags(rawDatas []merkleNode, flagList *list.List) []*bc.Hash {
197         tree := buildMerkleTree(rawDatas)
198         return tree.getMerkleTreeProofByFlags(flagList)
199 }
200
201 // GetTxMerkleTreeProof return a proof of merkle tree, which used to proof the transaction does
202 // exist in the merkle tree
203 func GetTxMerkleTreeProof(txs []*Tx, relatedTxs []*Tx) ([]*bc.Hash, []uint8) {
204         var rawDatas []merkleNode
205         var relatedRawDatas []merkleNode
206         for _, tx := range txs {
207                 rawDatas = append(rawDatas, &tx.ID)
208         }
209         for _, relatedTx := range relatedTxs {
210                 relatedRawDatas = append(relatedRawDatas, &relatedTx.ID)
211         }
212         return getMerkleTreeProof(rawDatas, relatedRawDatas)
213 }
214
215 // getMerkleRootByProof caculate the merkle root hash according to the proof
216 func getMerkleRootByProof(hashList *list.List, flagList *list.List, merkleHashes *list.List) bc.Hash {
217         if flagList.Len() == 0 || hashList.Len() == 0 {
218                 return bc.EmptyStringHash
219         }
220         flagEle := flagList.Front()
221         flag := flagEle.Value.(uint8)
222         flagList.Remove(flagEle)
223         switch flag {
224         case FlagAssist:
225                 {
226                         hash := hashList.Front()
227                         hashList.Remove(hash)
228                         return hash.Value.(bc.Hash)
229                 }
230         case FlagTxLeaf:
231                 {
232                         if merkleHashes.Len() == 0 {
233                                 return bc.EmptyStringHash
234                         }
235                         hashEle := hashList.Front()
236                         hash := hashEle.Value.(bc.Hash)
237                         relatedHashEle := merkleHashes.Front()
238                         relatedHash := relatedHashEle.Value.(bc.Hash)
239                         if hash == relatedHash {
240                                 hashList.Remove(hashEle)
241                                 merkleHashes.Remove(relatedHashEle)
242                                 return hash
243                         }
244                 }
245         case FlagTxParent:
246                 {
247                         leftHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
248                         rightHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
249                         hash := interiorMerkleHash(&leftHash, &rightHash)
250                         return hash
251                 }
252         }
253         return bc.EmptyStringHash
254 }
255
256 func newMerkleTreeNode(merkleHash bc.Hash, left *merkleTreeNode, right *merkleTreeNode) *merkleTreeNode {
257         return &merkleTreeNode{
258                 hash:  merkleHash,
259                 left:  left,
260                 right: right,
261         }
262 }
263
264 // ValidateMerkleTreeProof caculate the merkle root according to the hash of node and the flags
265 // only if the merkle root by caculated equals to the specify merkle root, and the merkle tree
266 // contains all of the related raw datas, the validate result will be true.
267 func validateMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedNodes []merkleNode, merkleRoot bc.Hash) bool {
268         merkleHashes := list.New()
269         for _, relatedNode := range relatedNodes {
270                 merkleHashes.PushBack(leafMerkleHash(relatedNode))
271         }
272         hashList := list.New()
273         for _, hash := range hashes {
274                 hashList.PushBack(*hash)
275         }
276         flagList := list.New()
277         for _, flag := range flags {
278                 flagList.PushBack(flag)
279         }
280         root := getMerkleRootByProof(hashList, flagList, merkleHashes)
281         return root == merkleRoot && merkleHashes.Len() == 0
282 }
283
284 // ValidateTxMerkleTreeProof validate the merkle tree of transactions
285 func ValidateTxMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedHashes []*bc.Hash, merkleRoot bc.Hash) bool {
286         var relatedNodes []merkleNode
287         for _, hash := range relatedHashes {
288                 relatedNodes = append(relatedNodes, hash)
289         }
290         return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
291 }
292
293 // TxMerkleRoot creates a merkle tree from a slice of transactions
294 // and returns the root hash of the tree.
295 func TxMerkleRoot(transactions []*bc.Tx) (root bc.Hash, err error) {
296         nodes := []merkleNode{}
297         for _, tx := range transactions {
298                 nodes = append(nodes, &tx.ID)
299         }
300         return merkleRoot(nodes)
301 }
302
303 // prevPowerOfTwo returns the largest power of two that is smaller than a given number.
304 // In other words, for some input n, the prevPowerOfTwo k is a power of two such that
305 // k < n <= 2k. This is a helper function used during the calculation of a merkle tree.
306 func prevPowerOfTwo(n int) int {
307         // If the number is a power of two, divide it by 2 and return.
308         if n&(n-1) == 0 {
309                 return n / 2
310         }
311
312         // Otherwise, find the previous PoT.
313         exponent := uint(math.Log2(float64(n)))
314         return 1 << exponent // 2^exponent
315 }