OSDN Git Service

update gm
[bytom/bytom.git] / protocol / bc / types / merkle_gm.go
1 // +build gm
2
3 package types
4
5 import (
6         "container/list"
7         "io"
8         "math"
9
10         "github.com/tjfoc/gmsm/sm3"
11         "gopkg.in/fatih/set.v0"
12
13         "github.com/bytom/bytom/protocol/bc"
14 )
15
16 // merkleFlag represent the type of merkle tree node, it's used to generate the structure of merkle tree
17 // Bitcoin has only two flags, which zero means the hash of assist node. And one means the hash of the related
18 // transaction node or it's parents, which distinguish them according to the height of the tree. But in the bytom,
19 // the height of transaction node is not fixed, so we need three flags to distinguish these nodes.
20 const (
21         // FlagAssist represent assist node
22         FlagAssist = iota
23         // FlagTxParent represent the parent of transaction of node
24         FlagTxParent
25         // FlagTxLeaf represent transaction of node
26         FlagTxLeaf
27 )
28
29 var (
30         leafPrefix     = []byte{0x00}
31         interiorPrefix = []byte{0x01}
32 )
33
34 type merkleNode interface {
35         WriteTo(io.Writer) (int64, error)
36 }
37
38 func merkleRoot(nodes []merkleNode) (root bc.Hash, err error) {
39         switch {
40         case len(nodes) == 0:
41                 return bc.EmptyStringHash, nil
42
43         case len(nodes) == 1:
44                 root = leafMerkleHash(nodes[0])
45                 return root, nil
46
47         default:
48                 k := prevPowerOfTwo(len(nodes))
49                 left, err := merkleRoot(nodes[:k])
50                 if err != nil {
51                         return root, err
52                 }
53
54                 right, err := merkleRoot(nodes[k:])
55                 if err != nil {
56                         return root, err
57                 }
58
59                 root = interiorMerkleHash(&left, &right)
60                 return root, nil
61         }
62 }
63
64 func interiorMerkleHash(left merkleNode, right merkleNode) (hash bc.Hash) {
65         hasher := sm3.New()
66         hasher.Write(interiorPrefix)
67         left.WriteTo(hasher)
68         right.WriteTo(hasher)
69         var b32 [32]byte
70         copy(b32[:], hasher.Sum(nil))
71         hash = bc.NewHash(b32)
72         return hash
73 }
74
75 func leafMerkleHash(node merkleNode) (hash bc.Hash) {
76         hasher := sm3.New()
77         hasher.Write(leafPrefix)
78         node.WriteTo(hasher)
79         var b32 [32]byte
80         copy(b32[:], hasher.Sum(nil))
81         hash = bc.NewHash(b32)
82         return hash
83 }
84
85 type merkleTreeNode struct {
86         hash  bc.Hash
87         left  *merkleTreeNode
88         right *merkleTreeNode
89 }
90
91 // buildMerkleTree construct a merkle tree based on the provide node data
92 func buildMerkleTree(rawDatas []merkleNode) *merkleTreeNode {
93         switch len(rawDatas) {
94         case 0:
95                 return nil
96         case 1:
97                 rawData := rawDatas[0]
98                 merkleHash := leafMerkleHash(rawData)
99                 node := newMerkleTreeNode(merkleHash, nil, nil)
100                 return node
101         default:
102                 k := prevPowerOfTwo(len(rawDatas))
103                 left := buildMerkleTree(rawDatas[:k])
104                 right := buildMerkleTree(rawDatas[k:])
105                 merkleHash := interiorMerkleHash(&left.hash, &right.hash)
106                 node := newMerkleTreeNode(merkleHash, left, right)
107                 return node
108         }
109 }
110
111 func (node *merkleTreeNode) getMerkleTreeProof(merkleHashSet *set.Set) ([]*bc.Hash, []uint8) {
112         var hashes []*bc.Hash
113         var flags []uint8
114
115         if node.left == nil && node.right == nil {
116                 if key := node.hash.String(); merkleHashSet.Has(key) {
117                         hashes = append(hashes, &node.hash)
118                         flags = append(flags, FlagTxLeaf)
119                         return hashes, flags
120                 }
121                 return hashes, flags
122         }
123         var leftHashes, rightHashes []*bc.Hash
124         var leftFlags, rightFlags []uint8
125         if node.left != nil {
126                 leftHashes, leftFlags = node.left.getMerkleTreeProof(merkleHashSet)
127         }
128         if node.right != nil {
129                 rightHashes, rightFlags = node.right.getMerkleTreeProof(merkleHashSet)
130         }
131         leftFind, rightFind := len(leftHashes) > 0, len(rightHashes) > 0
132
133         if leftFind || rightFind {
134                 flags = append(flags, FlagTxParent)
135         } else {
136                 return hashes, flags
137         }
138
139         if leftFind {
140                 hashes = append(hashes, leftHashes...)
141                 flags = append(flags, leftFlags...)
142         } else {
143                 hashes = append(hashes, &node.left.hash)
144                 flags = append(flags, FlagAssist)
145         }
146
147         if rightFind {
148                 hashes = append(hashes, rightHashes...)
149                 flags = append(flags, rightFlags...)
150         } else {
151                 hashes = append(hashes, &node.right.hash)
152                 flags = append(flags, FlagAssist)
153         }
154         return hashes, flags
155 }
156
157 func getMerkleTreeProof(rawDatas []merkleNode, relatedRawDatas []merkleNode) ([]*bc.Hash, []uint8) {
158         merkleTree := buildMerkleTree(rawDatas)
159         if merkleTree == nil {
160                 return []*bc.Hash{}, []uint8{}
161         }
162         merkleHashSet := set.New()
163         for _, data := range relatedRawDatas {
164                 merkleHash := leafMerkleHash(data)
165                 merkleHashSet.Add(merkleHash.String())
166         }
167         if merkleHashSet.Size() == 0 {
168                 return []*bc.Hash{&merkleTree.hash}, []uint8{FlagAssist}
169         }
170         return merkleTree.getMerkleTreeProof(merkleHashSet)
171 }
172
173 func (node *merkleTreeNode) getMerkleTreeProofByFlags(flagList *list.List) []*bc.Hash {
174         var hashes []*bc.Hash
175
176         if flagList.Len() == 0 {
177                 return hashes
178         }
179         flagEle := flagList.Front()
180         flag := flagEle.Value.(uint8)
181         flagList.Remove(flagEle)
182
183         if flag == FlagTxLeaf || flag == FlagAssist {
184                 hashes = append(hashes, &node.hash)
185                 return hashes
186         }
187         if node.left != nil {
188                 leftHashes := node.left.getMerkleTreeProofByFlags(flagList)
189                 hashes = append(hashes, leftHashes...)
190         }
191         if node.right != nil {
192                 rightHashes := node.right.getMerkleTreeProofByFlags(flagList)
193                 hashes = append(hashes, rightHashes...)
194         }
195         return hashes
196 }
197
198 func getMerkleTreeProofByFlags(rawDatas []merkleNode, flagList *list.List) []*bc.Hash {
199         tree := buildMerkleTree(rawDatas)
200         return tree.getMerkleTreeProofByFlags(flagList)
201 }
202
203 // GetTxMerkleTreeProof return a proof of merkle tree, which used to proof the transaction does
204 // exist in the merkle tree
205 func GetTxMerkleTreeProof(txs []*Tx, relatedTxs []*Tx) ([]*bc.Hash, []uint8) {
206         var rawDatas []merkleNode
207         var relatedRawDatas []merkleNode
208         for _, tx := range txs {
209                 rawDatas = append(rawDatas, &tx.ID)
210         }
211         for _, relatedTx := range relatedTxs {
212                 relatedRawDatas = append(relatedRawDatas, &relatedTx.ID)
213         }
214         return getMerkleTreeProof(rawDatas, relatedRawDatas)
215 }
216
217 // getMerkleRootByProof caculate the merkle root hash according to the proof
218 func getMerkleRootByProof(hashList *list.List, flagList *list.List, merkleHashes *list.List) bc.Hash {
219         if flagList.Len() == 0 || hashList.Len() == 0 {
220                 return bc.EmptyStringHash
221         }
222         flagEle := flagList.Front()
223         flag := flagEle.Value.(uint8)
224         flagList.Remove(flagEle)
225         switch flag {
226         case FlagAssist:
227                 {
228                         hash := hashList.Front()
229                         hashList.Remove(hash)
230                         return hash.Value.(bc.Hash)
231                 }
232         case FlagTxLeaf:
233                 {
234                         if merkleHashes.Len() == 0 {
235                                 return bc.EmptyStringHash
236                         }
237                         hashEle := hashList.Front()
238                         hash := hashEle.Value.(bc.Hash)
239                         relatedHashEle := merkleHashes.Front()
240                         relatedHash := relatedHashEle.Value.(bc.Hash)
241                         if hash == relatedHash {
242                                 hashList.Remove(hashEle)
243                                 merkleHashes.Remove(relatedHashEle)
244                                 return hash
245                         }
246                 }
247         case FlagTxParent:
248                 {
249                         leftHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
250                         rightHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
251                         hash := interiorMerkleHash(&leftHash, &rightHash)
252                         return hash
253                 }
254         }
255         return bc.EmptyStringHash
256 }
257
258 func newMerkleTreeNode(merkleHash bc.Hash, left *merkleTreeNode, right *merkleTreeNode) *merkleTreeNode {
259         return &merkleTreeNode{
260                 hash:  merkleHash,
261                 left:  left,
262                 right: right,
263         }
264 }
265
266 // ValidateMerkleTreeProof caculate the merkle root according to the hash of node and the flags
267 // only if the merkle root by caculated equals to the specify merkle root, and the merkle tree
268 // contains all of the related raw datas, the validate result will be true.
269 func validateMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedNodes []merkleNode, merkleRoot bc.Hash) bool {
270         merkleHashes := list.New()
271         for _, relatedNode := range relatedNodes {
272                 merkleHashes.PushBack(leafMerkleHash(relatedNode))
273         }
274         hashList := list.New()
275         for _, hash := range hashes {
276                 hashList.PushBack(*hash)
277         }
278         flagList := list.New()
279         for _, flag := range flags {
280                 flagList.PushBack(flag)
281         }
282         root := getMerkleRootByProof(hashList, flagList, merkleHashes)
283         return root == merkleRoot && merkleHashes.Len() == 0
284 }
285
286 // ValidateTxMerkleTreeProof validate the merkle tree of transactions
287 func ValidateTxMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedHashes []*bc.Hash, merkleRoot bc.Hash) bool {
288         var relatedNodes []merkleNode
289         for _, hash := range relatedHashes {
290                 relatedNodes = append(relatedNodes, hash)
291         }
292         return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
293 }
294
295 // TxMerkleRoot creates a merkle tree from a slice of transactions
296 // and returns the root hash of the tree.
297 func TxMerkleRoot(transactions []*bc.Tx) (root bc.Hash, err error) {
298         nodes := []merkleNode{}
299         for _, tx := range transactions {
300                 nodes = append(nodes, &tx.ID)
301         }
302         return merkleRoot(nodes)
303 }
304
305 // prevPowerOfTwo returns the largest power of two that is smaller than a given number.
306 // In other words, for some input n, the prevPowerOfTwo k is a power of two such that
307 // k < n <= 2k. This is a helper function used during the calculation of a merkle tree.
308 func prevPowerOfTwo(n int) int {
309         // If the number is a power of two, divide it by 2 and return.
310         if n&(n-1) == 0 {
311                 return n / 2
312         }
313
314         // Otherwise, find the previous PoT.
315         exponent := uint(math.Log2(float64(n)))
316         return 1 << exponent // 2^exponent
317 }