OSDN Git Service

Merge pull request #41 from Bytom/dev
[bytom/vapor.git] / protocol / bc / types / merkle.go
1 package types
2
3 import (
4         "container/list"
5         "io"
6         "math"
7
8         "gopkg.in/fatih/set.v0"
9
10         "github.com/vapor/crypto/sha3pool"
11         "github.com/vapor/protocol/bc"
12 )
13
14 // merkleFlag represent the type of merkle tree node, it's used to generate the structure of merkle tree
15 // Bitcoin has only two flags, which zero means the hash of assist node. And one means the hash of the related
16 // transaction node or it's parents, which distinguish them according to the height of the tree. But in the bytom,
17 // the height of transaction node is not fixed, so we need three flags to distinguish these nodes.
18 const (
19         // FlagAssist represent assist node
20         FlagAssist = iota
21         // FlagTxParent represent the parent of transaction of node
22         FlagTxParent
23         // FlagTxLeaf represent transaction of node
24         FlagTxLeaf
25 )
26
27 var (
28         leafPrefix     = []byte{0x00}
29         interiorPrefix = []byte{0x01}
30 )
31
32 type merkleNode interface {
33         WriteTo(io.Writer) (int64, error)
34 }
35
36 func merkleRoot(nodes []merkleNode) (root bc.Hash, err error) {
37         switch {
38         case len(nodes) == 0:
39                 return bc.EmptyStringHash, nil
40
41         case len(nodes) == 1:
42                 root = leafMerkleHash(nodes[0])
43                 return root, nil
44
45         default:
46                 k := prevPowerOfTwo(len(nodes))
47                 left, err := merkleRoot(nodes[:k])
48                 if err != nil {
49                         return root, err
50                 }
51
52                 right, err := merkleRoot(nodes[k:])
53                 if err != nil {
54                         return root, err
55                 }
56
57                 root = interiorMerkleHash(&left, &right)
58                 return root, nil
59         }
60 }
61
62 func interiorMerkleHash(left merkleNode, right merkleNode) (hash bc.Hash) {
63         h := sha3pool.Get256()
64         defer sha3pool.Put256(h)
65         h.Write(interiorPrefix)
66         left.WriteTo(h)
67         right.WriteTo(h)
68         hash.ReadFrom(h)
69         return hash
70 }
71
72 func leafMerkleHash(node merkleNode) (hash bc.Hash) {
73         h := sha3pool.Get256()
74         defer sha3pool.Put256(h)
75         h.Write(leafPrefix)
76         node.WriteTo(h)
77         hash.ReadFrom(h)
78         return hash
79 }
80
81 type merkleTreeNode struct {
82         hash  bc.Hash
83         left  *merkleTreeNode
84         right *merkleTreeNode
85 }
86
87 // buildMerkleTree construct a merkle tree based on the provide node data
88 func buildMerkleTree(rawDatas []merkleNode) *merkleTreeNode {
89         switch len(rawDatas) {
90         case 0:
91                 return nil
92         case 1:
93                 rawData := rawDatas[0]
94                 merkleHash := leafMerkleHash(rawData)
95                 node := newMerkleTreeNode(merkleHash, nil, nil)
96                 return node
97         default:
98                 k := prevPowerOfTwo(len(rawDatas))
99                 left := buildMerkleTree(rawDatas[:k])
100                 right := buildMerkleTree(rawDatas[k:])
101                 merkleHash := interiorMerkleHash(&left.hash, &right.hash)
102                 node := newMerkleTreeNode(merkleHash, left, right)
103                 return node
104         }
105 }
106
107 func (node *merkleTreeNode) getMerkleTreeProof(merkleHashSet *set.Set) ([]*bc.Hash, []uint8) {
108         var hashes []*bc.Hash
109         var flags []uint8
110
111         if node.left == nil && node.right == nil {
112                 if key := node.hash.String(); merkleHashSet.Has(key) {
113                         hashes = append(hashes, &node.hash)
114                         flags = append(flags, FlagTxLeaf)
115                         return hashes, flags
116                 }
117                 return hashes, flags
118         }
119         var leftHashes, rightHashes []*bc.Hash
120         var leftFlags, rightFlags []uint8
121         if node.left != nil {
122                 leftHashes, leftFlags = node.left.getMerkleTreeProof(merkleHashSet)
123         }
124         if node.right != nil {
125                 rightHashes, rightFlags = node.right.getMerkleTreeProof(merkleHashSet)
126         }
127         leftFind, rightFind := len(leftHashes) > 0, len(rightHashes) > 0
128
129         if leftFind || rightFind {
130                 flags = append(flags, FlagTxParent)
131         } else {
132                 return hashes, flags
133         }
134
135         if leftFind {
136                 hashes = append(hashes, leftHashes...)
137                 flags = append(flags, leftFlags...)
138         } else {
139                 hashes = append(hashes, &node.left.hash)
140                 flags = append(flags, FlagAssist)
141         }
142
143         if rightFind {
144                 hashes = append(hashes, rightHashes...)
145                 flags = append(flags, rightFlags...)
146         } else {
147                 hashes = append(hashes, &node.right.hash)
148                 flags = append(flags, FlagAssist)
149         }
150         return hashes, flags
151 }
152
153 func getMerkleTreeProof(rawDatas []merkleNode, relatedRawDatas []merkleNode) ([]*bc.Hash, []uint8) {
154         merkleTree := buildMerkleTree(rawDatas)
155         if merkleTree == nil {
156                 return []*bc.Hash{}, []uint8{}
157         }
158         merkleHashSet := set.New()
159         for _, data := range relatedRawDatas {
160                 merkleHash := leafMerkleHash(data)
161                 merkleHashSet.Add(merkleHash.String())
162         }
163         if merkleHashSet.Size() == 0 {
164                 return []*bc.Hash{&merkleTree.hash}, []uint8{FlagAssist}
165         }
166         return merkleTree.getMerkleTreeProof(merkleHashSet)
167 }
168
169 func (node *merkleTreeNode) getMerkleTreeProofByFlags(flagList *list.List) []*bc.Hash {
170         var hashes []*bc.Hash
171
172         if flagList.Len() == 0 {
173                 return hashes
174         }
175         flagEle := flagList.Front()
176         flag := flagEle.Value.(uint8)
177         flagList.Remove(flagEle)
178
179         if flag == FlagTxLeaf || flag == FlagAssist {
180                 hashes = append(hashes, &node.hash)
181                 return hashes
182         }
183         if node.left != nil {
184                 leftHashes := node.left.getMerkleTreeProofByFlags(flagList)
185                 hashes = append(hashes, leftHashes...)
186         }
187         if node.right != nil {
188                 rightHashes := node.right.getMerkleTreeProofByFlags(flagList)
189                 hashes = append(hashes, rightHashes...)
190         }
191         return hashes
192 }
193
194 func getMerkleTreeProofByFlags(rawDatas []merkleNode, flagList *list.List) []*bc.Hash {
195         tree := buildMerkleTree(rawDatas)
196         return tree.getMerkleTreeProofByFlags(flagList)
197 }
198
199 // GetTxMerkleTreeProof return a proof of merkle tree, which used to proof the transaction does
200 // exist in the merkle tree
201 func GetTxMerkleTreeProof(txs []*Tx, relatedTxs []*Tx) ([]*bc.Hash, []uint8) {
202         var rawDatas []merkleNode
203         var relatedRawDatas []merkleNode
204         for _, tx := range txs {
205                 rawDatas = append(rawDatas, &tx.ID)
206         }
207         for _, relatedTx := range relatedTxs {
208                 relatedRawDatas = append(relatedRawDatas, &relatedTx.ID)
209         }
210         return getMerkleTreeProof(rawDatas, relatedRawDatas)
211 }
212
213 // GetStatusMerkleTreeProof return a proof of merkle tree, which used to proof the status of transaction is valid
214 func GetStatusMerkleTreeProof(statuses []*bc.TxVerifyResult, flags []uint8) []*bc.Hash {
215         var rawDatas []merkleNode
216         for _, status := range statuses {
217                 rawDatas = append(rawDatas, status)
218         }
219         flagList := list.New()
220         for _, flag := range flags {
221                 flagList.PushBack(flag)
222         }
223         return getMerkleTreeProofByFlags(rawDatas, flagList)
224 }
225
226 // getMerkleRootByProof caculate the merkle root hash according to the proof
227 func getMerkleRootByProof(hashList *list.List, flagList *list.List, merkleHashes *list.List) bc.Hash {
228         if flagList.Len() == 0 || hashList.Len() == 0 {
229                 return bc.EmptyStringHash
230         }
231         flagEle := flagList.Front()
232         flag := flagEle.Value.(uint8)
233         flagList.Remove(flagEle)
234         switch flag {
235         case FlagAssist:
236                 {
237                         hash := hashList.Front()
238                         hashList.Remove(hash)
239                         return hash.Value.(bc.Hash)
240                 }
241         case FlagTxLeaf:
242                 {
243                         if merkleHashes.Len() == 0 {
244                                 return bc.EmptyStringHash
245                         }
246                         hashEle := hashList.Front()
247                         hash := hashEle.Value.(bc.Hash)
248                         relatedHashEle := merkleHashes.Front()
249                         relatedHash := relatedHashEle.Value.(bc.Hash)
250                         if hash == relatedHash {
251                                 hashList.Remove(hashEle)
252                                 merkleHashes.Remove(relatedHashEle)
253                                 return hash
254                         }
255                 }
256         case FlagTxParent:
257                 {
258                         leftHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
259                         rightHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
260                         hash := interiorMerkleHash(&leftHash, &rightHash)
261                         return hash
262                 }
263         }
264         return bc.EmptyStringHash
265 }
266
267 func newMerkleTreeNode(merkleHash bc.Hash, left *merkleTreeNode, right *merkleTreeNode) *merkleTreeNode {
268         return &merkleTreeNode{
269                 hash:  merkleHash,
270                 left:  left,
271                 right: right,
272         }
273 }
274
275 // ValidateMerkleTreeProof caculate the merkle root according to the hash of node and the flags
276 // only if the merkle root by caculated equals to the specify merkle root, and the merkle tree
277 // contains all of the related raw datas, the validate result will be true.
278 func validateMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedNodes []merkleNode, merkleRoot bc.Hash) bool {
279         merkleHashes := list.New()
280         for _, relatedNode := range relatedNodes {
281                 merkleHashes.PushBack(leafMerkleHash(relatedNode))
282         }
283         hashList := list.New()
284         for _, hash := range hashes {
285                 hashList.PushBack(*hash)
286         }
287         flagList := list.New()
288         for _, flag := range flags {
289                 flagList.PushBack(flag)
290         }
291         root := getMerkleRootByProof(hashList, flagList, merkleHashes)
292         return root == merkleRoot && merkleHashes.Len() == 0
293 }
294
295 // ValidateTxMerkleTreeProof validate the merkle tree of transactions
296 func ValidateTxMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedHashes []*bc.Hash, merkleRoot bc.Hash) bool {
297         var relatedNodes []merkleNode
298         for _, hash := range relatedHashes {
299                 relatedNodes = append(relatedNodes, hash)
300         }
301         return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
302 }
303
304 // ValidateStatusMerkleTreeProof validate the merkle tree of transaction status
305 func ValidateStatusMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedStatus []*bc.TxVerifyResult, merkleRoot bc.Hash) bool {
306         var relatedNodes []merkleNode
307         for _, result := range relatedStatus {
308                 relatedNodes = append(relatedNodes, result)
309         }
310         return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
311 }
312
313 // TxStatusMerkleRoot creates a merkle tree from a slice of bc.TxVerifyResult
314 func TxStatusMerkleRoot(tvrs []*bc.TxVerifyResult) (root bc.Hash, err error) {
315         nodes := []merkleNode{}
316         for _, tvr := range tvrs {
317                 nodes = append(nodes, tvr)
318         }
319         return merkleRoot(nodes)
320 }
321
322 // TxMerkleRoot creates a merkle tree from a slice of transactions
323 // and returns the root hash of the tree.
324 func TxMerkleRoot(transactions []*bc.Tx) (root bc.Hash, err error) {
325         nodes := []merkleNode{}
326         for _, tx := range transactions {
327                 nodes = append(nodes, &tx.ID)
328         }
329         return merkleRoot(nodes)
330 }
331
332 // prevPowerOfTwo returns the largest power of two that is smaller than a given number.
333 // In other words, for some input n, the prevPowerOfTwo k is a power of two such that
334 // k < n <= 2k. This is a helper function used during the calculation of a merkle tree.
335 func prevPowerOfTwo(n int) int {
336         // If the number is a power of two, divide it by 2 and return.
337         if n&(n-1) == 0 {
338                 return n / 2
339         }
340
341         // Otherwise, find the previous PoT.
342         exponent := uint(math.Log2(float64(n)))
343         return 1 << exponent // 2^exponent
344 }