OSDN Git Service

add sm2 (#1193)
[bytom/bytom.git] / crypto / sm2 / sm2.go
1 /*
2 Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7         http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 */
15
16 package sm2
17
18 // reference to ecdsa
19 import (
20         "bytes"
21         "crypto"
22         "crypto/aes"
23         "crypto/cipher"
24         "crypto/elliptic"
25         "crypto/rand"
26         "crypto/sha512"
27         "encoding/asn1"
28         "encoding/binary"
29         "errors"
30         "io"
31         "math/big"
32
33         "github.com/bytom/crypto/sm3"
34 )
35
36 const (
37         aesIV = "IV for <SM2> CTR"
38 )
39
40 type PublicKey struct {
41         elliptic.Curve
42         X, Y *big.Int
43 }
44
45 type PrivateKey struct {
46         PublicKey
47         D *big.Int
48 }
49
50 type sm2Signature struct {
51         R, S *big.Int
52 }
53
54 // The SM2's private key contains the public key
55 func (priv *PrivateKey) Public() crypto.PublicKey {
56         return &priv.PublicKey
57 }
58
59 func SignDigitToSignData(r, s *big.Int) ([]byte, error) {
60         return asn1.Marshal(sm2Signature{r, s})
61 }
62
63 func SignDataToSignDigit(sign []byte) (*big.Int, *big.Int, error) {
64         var sm2Sign sm2Signature
65
66         _, err := asn1.Unmarshal(sign, &sm2Sign)
67         if err != nil {
68                 return nil, nil, err
69         }
70         return sm2Sign.R, sm2Sign.S, nil
71 }
72
73 // sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
74 func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) {
75         r, s, err := Sign(priv, msg)
76         if err != nil {
77                 return nil, err
78         }
79         return asn1.Marshal(sm2Signature{r, s})
80 }
81
82 func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) {
83         return Decrypt(priv, data)
84 }
85
86 func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
87         var sm2Sign sm2Signature
88
89         _, err := asn1.Unmarshal(sign, &sm2Sign)
90         if err != nil {
91                 return false
92         }
93         return Verify(pub, msg, sm2Sign.R, sm2Sign.S)
94 }
95
96 func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) {
97         return Encrypt(pub, data)
98 }
99
100 var one = new(big.Int).SetInt64(1)
101
102 func intToBytes(x int) []byte {
103         var buf = make([]byte, 4)
104
105         binary.BigEndian.PutUint32(buf, uint32(x))
106         return buf
107 }
108
109 func kdf(x, y []byte, length int) ([]byte, bool) {
110         var c []byte
111
112         ct := 1
113         h := sm3.New()
114         x = append(x, y...)
115         for i, j := 0, (length+31)/32; i < j; i++ {
116                 h.Reset()
117                 h.Write(x)
118                 h.Write(intToBytes(ct))
119                 hash := h.Sum(nil)
120                 if i+1 == j && length%32 != 0 {
121                         c = append(c, hash[:length%32]...)
122                 } else {
123                         c = append(c, hash...)
124                 }
125                 ct++
126         }
127         for i := 0; i < length; i++ {
128                 if c[i] != 0 {
129                         return c, true
130                 }
131         }
132         return c, false
133 }
134
135 func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
136         params := c.Params()
137         b := make([]byte, params.BitSize/8+8)
138         _, err = io.ReadFull(rand, b)
139         if err != nil {
140                 return
141         }
142         k = new(big.Int).SetBytes(b)
143         n := new(big.Int).Sub(params.N, one)
144         k.Mod(k, n)
145         k.Add(k, one)
146         return
147 }
148
149 func GenerateKey() (*PrivateKey, error) {
150         c := P256Sm2()
151         k, err := randFieldElement(c, rand.Reader)
152         if err != nil {
153                 return nil, err
154         }
155         priv := new(PrivateKey)
156         priv.PublicKey.Curve = c
157         priv.D = k
158         priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
159         return priv, nil
160 }
161
162 var errZeroParam = errors.New("zero parameter")
163
164 func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
165         entropylen := (priv.Curve.Params().BitSize + 7) / 16
166         if entropylen > 32 {
167                 entropylen = 32
168         }
169         entropy := make([]byte, entropylen)
170         _, err = io.ReadFull(rand.Reader, entropy)
171         if err != nil {
172                 return
173         }
174
175         // Initialize an SHA-512 hash context; digest ...
176         md := sha512.New()
177         md.Write(priv.D.Bytes()) // the private key,
178         md.Write(entropy)        // the entropy,
179         md.Write(hash)           // and the input hash;
180         key := md.Sum(nil)[:32]  // and compute ChopMD-256(SHA-512),
181         // which is an indifferentiable MAC.
182
183         // Create an AES-CTR instance to use as a CSPRNG.
184         block, err := aes.NewCipher(key)
185         if err != nil {
186                 return nil, nil, err
187         }
188
189         // Create a CSPRNG that xors a stream of zeros with
190         // the output of the AES-CTR instance.
191         csprng := cipher.StreamReader{
192                 R: zeroReader,
193                 S: cipher.NewCTR(block, []byte(aesIV)),
194         }
195
196         // See [NSA] 3.4.1
197         c := priv.PublicKey.Curve
198         N := c.Params().N
199         if N.Sign() == 0 {
200                 return nil, nil, errZeroParam
201         }
202         var k *big.Int
203         e := new(big.Int).SetBytes(hash)
204         for { // 调整算法细节以实现SM2
205                 for {
206                         k, err = randFieldElement(c, csprng)
207                         if err != nil {
208                                 r = nil
209                                 return
210                         }
211                         r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
212                         r.Add(r, e)
213                         r.Mod(r, N)
214                         if r.Sign() != 0 {
215                                 break
216                         }
217                         if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
218                                 break
219                         }
220                 }
221                 rD := new(big.Int).Mul(priv.D, r)
222                 s = new(big.Int).Sub(k, rD)
223                 d1 := new(big.Int).Add(priv.D, one)
224                 d1Inv := new(big.Int).ModInverse(d1, N)
225                 s.Mul(s, d1Inv)
226                 s.Mod(s, N)
227                 if s.Sign() != 0 {
228                         break
229                 }
230         }
231         return
232 }
233
234 func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
235         c := pub.Curve
236         N := c.Params().N
237
238         if r.Sign() <= 0 || s.Sign() <= 0 {
239                 return false
240         }
241         if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
242                 return false
243         }
244
245         // 调整算法细节以实现SM2
246         t := new(big.Int).Add(r, s)
247         t.Mod(t, N)
248         if t.Sign() == 0 {
249                 return false
250         }
251
252         var x *big.Int
253         x1, y1 := c.ScalarBaseMult(s.Bytes())
254         x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
255         x, _ = c.Add(x1, y1, x2, y2)
256
257         e := new(big.Int).SetBytes(hash)
258         x.Add(x, e)
259         x.Mod(x, N)
260         return x.Cmp(r) == 0
261 }
262
263 func Sm2Sign(priv *PrivateKey, msg, uid []byte) (r, s *big.Int, err error) {
264         za, err := ZA(&priv.PublicKey, uid)
265         if err != nil {
266                 return nil, nil, err
267         }
268         e, err := msgHash(za, msg)
269         if err != nil {
270                 return nil, nil, err
271         }
272         c := priv.PublicKey.Curve
273         N := c.Params().N
274         if N.Sign() == 0 {
275                 return nil, nil, errZeroParam
276         }
277         var k *big.Int
278         for { // 调整算法细节以实现SM2
279                 for {
280                         k, err = randFieldElement(c, rand.Reader)
281                         if err != nil {
282                                 r = nil
283                                 return
284                         }
285                         r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
286                         r.Add(r, e)
287                         r.Mod(r, N)
288                         if r.Sign() != 0 {
289                                 break
290                         }
291                         if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
292                                 break
293                         }
294                 }
295                 rD := new(big.Int).Mul(priv.D, r)
296                 s = new(big.Int).Sub(k, rD)
297                 d1 := new(big.Int).Add(priv.D, one)
298                 d1Inv := new(big.Int).ModInverse(d1, N)
299                 s.Mul(s, d1Inv)
300                 s.Mod(s, N)
301                 if s.Sign() != 0 {
302                         break
303                 }
304         }
305         return
306 }
307
308 func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
309         c := pub.Curve
310         N := c.Params().N
311         one := new(big.Int).SetInt64(1)
312         if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
313                 return false
314         }
315         if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
316                 return false
317         }
318         za, err := ZA(pub, uid)
319         if err != nil {
320                 return false
321         }
322         e, err := msgHash(za, msg)
323         if err != nil {
324                 return false
325         }
326         t := new(big.Int).Add(r, s)
327         t.Mod(t, N)
328         if t.Sign() == 0 {
329                 return false
330         }
331         var x *big.Int
332         x1, y1 := c.ScalarBaseMult(s.Bytes())
333         x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
334         x, _ = c.Add(x1, y1, x2, y2)
335
336         x.Add(x, e)
337         x.Mod(x, N)
338         return x.Cmp(r) == 0
339 }
340
341 func msgHash(za, msg []byte) (*big.Int, error) {
342         e := sm3.New()
343         e.Write(za)
344         e.Write(msg)
345         return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil
346 }
347
348 // ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
349 func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
350         za := sm3.New()
351         uidLen := len(uid)
352         if uidLen >= 8192 {
353                 return []byte{}, errors.New("SM2: uid too large")
354         }
355         Entla := uint16(8 * uidLen)
356         za.Write([]byte{byte((Entla >> 8) & 0xFF)})
357         za.Write([]byte{byte(Entla & 0xFF)})
358         za.Write(uid)
359         za.Write(sm2P256ToBig(&sm2P256.a).Bytes())
360         za.Write(sm2P256.B.Bytes())
361         za.Write(sm2P256.Gx.Bytes())
362         za.Write(sm2P256.Gy.Bytes())
363
364         xBuf := pub.X.Bytes()
365         yBuf := pub.Y.Bytes()
366         if n := len(xBuf); n < 32 {
367                 xBuf = append(zeroByteSlice[:32-n], xBuf...)
368         }
369         za.Write(xBuf)
370         za.Write(yBuf)
371         return za.Sum(nil)[:32], nil
372 }
373
374 // 32byte
375 var zeroByteSlice = []byte{
376         0, 0, 0, 0,
377         0, 0, 0, 0,
378         0, 0, 0, 0,
379         0, 0, 0, 0,
380         0, 0, 0, 0,
381         0, 0, 0, 0,
382         0, 0, 0, 0,
383         0, 0, 0, 0,
384 }
385
386 /*
387  * sm2密文结构如下:
388  *  x
389  *  y
390  *  hash
391  *  CipherText
392  */
393 func Encrypt(pub *PublicKey, data []byte) ([]byte, error) {
394         length := len(data)
395         for {
396                 c := []byte{}
397                 curve := pub.Curve
398                 k, err := randFieldElement(curve, rand.Reader)
399                 if err != nil {
400                         return nil, err
401                 }
402                 x1, y1 := curve.ScalarBaseMult(k.Bytes())
403                 x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
404                 x1Buf := x1.Bytes()
405                 y1Buf := y1.Bytes()
406                 x2Buf := x2.Bytes()
407                 y2Buf := y2.Bytes()
408                 if n := len(x1Buf); n < 32 {
409                         x1Buf = append(zeroByteSlice[:32-n], x1Buf...)
410                 }
411                 if n := len(y1Buf); n < 32 {
412                         y1Buf = append(zeroByteSlice[:32-n], y1Buf...)
413                 }
414                 if n := len(x2Buf); n < 32 {
415                         x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
416                 }
417                 if n := len(y2Buf); n < 32 {
418                         y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
419                 }
420                 c = append(c, x1Buf...) // x分量
421                 c = append(c, y1Buf...) // y分量
422                 tm := []byte{}
423                 tm = append(tm, x2Buf...)
424                 tm = append(tm, data...)
425                 tm = append(tm, y2Buf...)
426                 h := sm3.Sm3Sum(tm)
427                 c = append(c, h...)
428                 ct, ok := kdf(x2Buf, y2Buf, length) // 密文
429                 if !ok {
430                         continue
431                 }
432                 c = append(c, ct...)
433                 for i := 0; i < length; i++ {
434                         c[96+i] ^= data[i]
435                 }
436                 return append([]byte{0x04}, c...), nil
437         }
438 }
439
440 func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) {
441         data = data[1:]
442         length := len(data) - 96
443         curve := priv.Curve
444         x := new(big.Int).SetBytes(data[:32])
445         y := new(big.Int).SetBytes(data[32:64])
446         x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
447         x2Buf := x2.Bytes()
448         y2Buf := y2.Bytes()
449         if n := len(x2Buf); n < 32 {
450                 x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
451         }
452         if n := len(y2Buf); n < 32 {
453                 y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
454         }
455         c, ok := kdf(x2Buf, y2Buf, length)
456         if !ok {
457                 return nil, errors.New("Decrypt: failed to decrypt")
458         }
459         for i := 0; i < length; i++ {
460                 c[i] ^= data[i+96]
461         }
462         tm := []byte{}
463         tm = append(tm, x2Buf...)
464         tm = append(tm, c...)
465         tm = append(tm, y2Buf...)
466         h := sm3.Sm3Sum(tm)
467         if bytes.Compare(h, data[64:96]) != 0 {
468                 return c, errors.New("Decrypt: failed to decrypt")
469         }
470         return c, nil
471 }
472
473 type zr struct {
474         io.Reader
475 }
476
477 func (z *zr) Read(dst []byte) (n int, err error) {
478         for i := range dst {
479                 dst[i] = 0
480         }
481         return len(dst), nil
482 }
483
484 var zeroReader = &zr{}
485
486 func getLastBit(a *big.Int) uint {
487         return a.Bit(0)
488 }
489
490 func Compress(a *PublicKey) []byte {
491         buf := []byte{}
492         yp := getLastBit(a.Y)
493         buf = append(buf, a.X.Bytes()...)
494         if n := len(a.X.Bytes()); n < 32 {
495                 buf = append(zeroByteSlice[:(32-n)], buf...)
496         }
497         buf = append([]byte{byte(yp)}, buf...)
498         return buf
499 }
500
501 func Decompress(a []byte) *PublicKey {
502         var aa, xx, xx3 sm2P256FieldElement
503
504         P256Sm2()
505         x := new(big.Int).SetBytes(a[1:])
506         curve := sm2P256
507         sm2P256FromBig(&xx, x)
508         sm2P256Square(&xx3, &xx)       // x3 = x ^ 2
509         sm2P256Mul(&xx3, &xx3, &xx)    // x3 = x ^ 2 * x
510         sm2P256Mul(&aa, &curve.a, &xx) // a = a * x
511         sm2P256Add(&xx3, &xx3, &aa)
512         sm2P256Add(&xx3, &xx3, &curve.b)
513
514         y2 := sm2P256ToBig(&xx3)
515         y := new(big.Int).ModSqrt(y2, sm2P256.P)
516         if getLastBit(y) != uint(a[0]) {
517                 y.Sub(sm2P256.P, y)
518         }
519         return &PublicKey{
520                 Curve: P256Sm2(),
521                 X:     x,
522                 Y:     y,
523         }
524 }