OSDN Git Service

Hulk did something
[bytom/vapor.git] / crypto / sm2 / sm2.go
1 package sm2
2
3 // reference to ecdsa
4 import (
5         "bytes"
6         "crypto"
7         "crypto/aes"
8         "crypto/cipher"
9         "crypto/elliptic"
10         "crypto/rand"
11         "crypto/sha512"
12         "encoding/asn1"
13         "encoding/binary"
14         "errors"
15         "io"
16         "math/big"
17
18         "github.com/vapor/crypto/sm3"
19 )
20
21 const (
22         aesIV = "IV for <SM2> CTR"
23 )
24
25 type PublicKey struct {
26         elliptic.Curve
27         X, Y *big.Int
28 }
29
30 type PrivateKey struct {
31         PublicKey
32         D *big.Int
33 }
34
35 type sm2Signature struct {
36         R, S *big.Int
37 }
38
39 // The SM2's private key contains the public key
40 func (priv *PrivateKey) Public() crypto.PublicKey {
41         return &priv.PublicKey
42 }
43
44 func SignDigitToSignData(r, s *big.Int) ([]byte, error) {
45         return asn1.Marshal(sm2Signature{r, s})
46 }
47
48 func SignDataToSignDigit(sign []byte) (*big.Int, *big.Int, error) {
49         var sm2Sign sm2Signature
50
51         _, err := asn1.Unmarshal(sign, &sm2Sign)
52         if err != nil {
53                 return nil, nil, err
54         }
55         return sm2Sign.R, sm2Sign.S, nil
56 }
57
58 // sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
59 func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) {
60         r, s, err := Sign(priv, msg)
61         if err != nil {
62                 return nil, err
63         }
64         return asn1.Marshal(sm2Signature{r, s})
65 }
66
67 func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) {
68         return Decrypt(priv, data)
69 }
70
71 func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
72         var sm2Sign sm2Signature
73
74         _, err := asn1.Unmarshal(sign, &sm2Sign)
75         if err != nil {
76                 return false
77         }
78         return Verify(pub, msg, sm2Sign.R, sm2Sign.S)
79 }
80
81 func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) {
82         return Encrypt(pub, data)
83 }
84
85 var one = new(big.Int).SetInt64(1)
86
87 func intToBytes(x int) []byte {
88         var buf = make([]byte, 4)
89
90         binary.BigEndian.PutUint32(buf, uint32(x))
91         return buf
92 }
93
94 func kdf(x, y []byte, length int) ([]byte, bool) {
95         var c []byte
96
97         ct := 1
98         h := sm3.New()
99         x = append(x, y...)
100         for i, j := 0, (length+31)/32; i < j; i++ {
101                 h.Reset()
102                 h.Write(x)
103                 h.Write(intToBytes(ct))
104                 hash := h.Sum(nil)
105                 if i+1 == j && length%32 != 0 {
106                         c = append(c, hash[:length%32]...)
107                 } else {
108                         c = append(c, hash...)
109                 }
110                 ct++
111         }
112         for i := 0; i < length; i++ {
113                 if c[i] != 0 {
114                         return c, true
115                 }
116         }
117         return c, false
118 }
119
120 func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
121         params := c.Params()
122         b := make([]byte, params.BitSize/8+8)
123         _, err = io.ReadFull(rand, b)
124         if err != nil {
125                 return
126         }
127         k = new(big.Int).SetBytes(b)
128         n := new(big.Int).Sub(params.N, one)
129         k.Mod(k, n)
130         k.Add(k, one)
131         return
132 }
133
134 func GenerateKey() (*PrivateKey, error) {
135         c := P256Sm2()
136         k, err := randFieldElement(c, rand.Reader)
137         if err != nil {
138                 return nil, err
139         }
140         priv := new(PrivateKey)
141         priv.PublicKey.Curve = c
142         priv.D = k
143         priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
144         return priv, nil
145 }
146
147 var errZeroParam = errors.New("zero parameter")
148
149 func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
150         entropylen := (priv.Curve.Params().BitSize + 7) / 16
151         if entropylen > 32 {
152                 entropylen = 32
153         }
154         entropy := make([]byte, entropylen)
155         _, err = io.ReadFull(rand.Reader, entropy)
156         if err != nil {
157                 return
158         }
159
160         // Initialize an SHA-512 hash context; digest ...
161         md := sha512.New()
162         md.Write(priv.D.Bytes()) // the private key,
163         md.Write(entropy)        // the entropy,
164         md.Write(hash)           // and the input hash;
165         key := md.Sum(nil)[:32]  // and compute ChopMD-256(SHA-512),
166         // which is an indifferentiable MAC.
167
168         // Create an AES-CTR instance to use as a CSPRNG.
169         block, err := aes.NewCipher(key)
170         if err != nil {
171                 return nil, nil, err
172         }
173
174         // Create a CSPRNG that xors a stream of zeros with
175         // the output of the AES-CTR instance.
176         csprng := cipher.StreamReader{
177                 R: zeroReader,
178                 S: cipher.NewCTR(block, []byte(aesIV)),
179         }
180
181         // See [NSA] 3.4.1
182         c := priv.PublicKey.Curve
183         N := c.Params().N
184         if N.Sign() == 0 {
185                 return nil, nil, errZeroParam
186         }
187         var k *big.Int
188         e := new(big.Int).SetBytes(hash)
189         for { // 调整算法细节以实现SM2
190                 for {
191                         k, err = randFieldElement(c, csprng)
192                         if err != nil {
193                                 r = nil
194                                 return
195                         }
196                         r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
197                         r.Add(r, e)
198                         r.Mod(r, N)
199                         if r.Sign() != 0 {
200                                 break
201                         }
202                         if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
203                                 break
204                         }
205                 }
206                 rD := new(big.Int).Mul(priv.D, r)
207                 s = new(big.Int).Sub(k, rD)
208                 d1 := new(big.Int).Add(priv.D, one)
209                 d1Inv := new(big.Int).ModInverse(d1, N)
210                 s.Mul(s, d1Inv)
211                 s.Mod(s, N)
212                 if s.Sign() != 0 {
213                         break
214                 }
215         }
216         return
217 }
218
219 func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
220         c := pub.Curve
221         N := c.Params().N
222
223         if r.Sign() <= 0 || s.Sign() <= 0 {
224                 return false
225         }
226         if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
227                 return false
228         }
229
230         // 调整算法细节以实现SM2
231         t := new(big.Int).Add(r, s)
232         t.Mod(t, N)
233         if t.Sign() == 0 {
234                 return false
235         }
236
237         var x *big.Int
238         x1, y1 := c.ScalarBaseMult(s.Bytes())
239         x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
240         x, _ = c.Add(x1, y1, x2, y2)
241
242         e := new(big.Int).SetBytes(hash)
243         x.Add(x, e)
244         x.Mod(x, N)
245         return x.Cmp(r) == 0
246 }
247
248 func Sm2Sign(priv *PrivateKey, msg, uid []byte) (r, s *big.Int, err error) {
249         za, err := ZA(&priv.PublicKey, uid)
250         if err != nil {
251                 return nil, nil, err
252         }
253         e, err := msgHash(za, msg)
254         if err != nil {
255                 return nil, nil, err
256         }
257         c := priv.PublicKey.Curve
258         N := c.Params().N
259         if N.Sign() == 0 {
260                 return nil, nil, errZeroParam
261         }
262         var k *big.Int
263         for { // 调整算法细节以实现SM2
264                 for {
265                         k, err = randFieldElement(c, rand.Reader)
266                         if err != nil {
267                                 r = nil
268                                 return
269                         }
270                         r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
271                         r.Add(r, e)
272                         r.Mod(r, N)
273                         if r.Sign() != 0 {
274                                 break
275                         }
276                         if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
277                                 break
278                         }
279                 }
280                 rD := new(big.Int).Mul(priv.D, r)
281                 s = new(big.Int).Sub(k, rD)
282                 d1 := new(big.Int).Add(priv.D, one)
283                 d1Inv := new(big.Int).ModInverse(d1, N)
284                 s.Mul(s, d1Inv)
285                 s.Mod(s, N)
286                 if s.Sign() != 0 {
287                         break
288                 }
289         }
290         return
291 }
292
293 func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
294         c := pub.Curve
295         N := c.Params().N
296         one := new(big.Int).SetInt64(1)
297         if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
298                 return false
299         }
300         if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
301                 return false
302         }
303         za, err := ZA(pub, uid)
304         if err != nil {
305                 return false
306         }
307         e, err := msgHash(za, msg)
308         if err != nil {
309                 return false
310         }
311         t := new(big.Int).Add(r, s)
312         t.Mod(t, N)
313         if t.Sign() == 0 {
314                 return false
315         }
316         var x *big.Int
317         x1, y1 := c.ScalarBaseMult(s.Bytes())
318         x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
319         x, _ = c.Add(x1, y1, x2, y2)
320
321         x.Add(x, e)
322         x.Mod(x, N)
323         return x.Cmp(r) == 0
324 }
325
326 func msgHash(za, msg []byte) (*big.Int, error) {
327         e := sm3.New()
328         e.Write(za)
329         e.Write(msg)
330         return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil
331 }
332
333 // ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
334 func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
335         za := sm3.New()
336         uidLen := len(uid)
337         if uidLen >= 8192 {
338                 return []byte{}, errors.New("SM2: uid too large")
339         }
340         Entla := uint16(8 * uidLen)
341         za.Write([]byte{byte((Entla >> 8) & 0xFF)})
342         za.Write([]byte{byte(Entla & 0xFF)})
343         za.Write(uid)
344         za.Write(sm2P256ToBig(&sm2P256.a).Bytes())
345         za.Write(sm2P256.B.Bytes())
346         za.Write(sm2P256.Gx.Bytes())
347         za.Write(sm2P256.Gy.Bytes())
348
349         xBuf := pub.X.Bytes()
350         yBuf := pub.Y.Bytes()
351         if n := len(xBuf); n < 32 {
352                 xBuf = append(zeroByteSlice[:32-n], xBuf...)
353         }
354         za.Write(xBuf)
355         za.Write(yBuf)
356         return za.Sum(nil)[:32], nil
357 }
358
359 // 32byte
360 var zeroByteSlice = []byte{
361         0, 0, 0, 0,
362         0, 0, 0, 0,
363         0, 0, 0, 0,
364         0, 0, 0, 0,
365         0, 0, 0, 0,
366         0, 0, 0, 0,
367         0, 0, 0, 0,
368         0, 0, 0, 0,
369 }
370
371 /*
372  * sm2密文结构如下:
373  *  x
374  *  y
375  *  hash
376  *  CipherText
377  */
378 func Encrypt(pub *PublicKey, data []byte) ([]byte, error) {
379         length := len(data)
380         for {
381                 c := []byte{}
382                 curve := pub.Curve
383                 k, err := randFieldElement(curve, rand.Reader)
384                 if err != nil {
385                         return nil, err
386                 }
387                 x1, y1 := curve.ScalarBaseMult(k.Bytes())
388                 x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
389                 x1Buf := x1.Bytes()
390                 y1Buf := y1.Bytes()
391                 x2Buf := x2.Bytes()
392                 y2Buf := y2.Bytes()
393                 if n := len(x1Buf); n < 32 {
394                         x1Buf = append(zeroByteSlice[:32-n], x1Buf...)
395                 }
396                 if n := len(y1Buf); n < 32 {
397                         y1Buf = append(zeroByteSlice[:32-n], y1Buf...)
398                 }
399                 if n := len(x2Buf); n < 32 {
400                         x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
401                 }
402                 if n := len(y2Buf); n < 32 {
403                         y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
404                 }
405                 c = append(c, x1Buf...) // x分量
406                 c = append(c, y1Buf...) // y分量
407                 tm := []byte{}
408                 tm = append(tm, x2Buf...)
409                 tm = append(tm, data...)
410                 tm = append(tm, y2Buf...)
411                 h := sm3.Sm3Sum(tm)
412                 c = append(c, h...)
413                 ct, ok := kdf(x2Buf, y2Buf, length) // 密文
414                 if !ok {
415                         continue
416                 }
417                 c = append(c, ct...)
418                 for i := 0; i < length; i++ {
419                         c[96+i] ^= data[i]
420                 }
421                 return append([]byte{0x04}, c...), nil
422         }
423 }
424
425 func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) {
426         data = data[1:]
427         length := len(data) - 96
428         curve := priv.Curve
429         x := new(big.Int).SetBytes(data[:32])
430         y := new(big.Int).SetBytes(data[32:64])
431         x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
432         x2Buf := x2.Bytes()
433         y2Buf := y2.Bytes()
434         if n := len(x2Buf); n < 32 {
435                 x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
436         }
437         if n := len(y2Buf); n < 32 {
438                 y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
439         }
440         c, ok := kdf(x2Buf, y2Buf, length)
441         if !ok {
442                 return nil, errors.New("Decrypt: failed to decrypt")
443         }
444         for i := 0; i < length; i++ {
445                 c[i] ^= data[i+96]
446         }
447         tm := []byte{}
448         tm = append(tm, x2Buf...)
449         tm = append(tm, c...)
450         tm = append(tm, y2Buf...)
451         h := sm3.Sm3Sum(tm)
452         if bytes.Compare(h, data[64:96]) != 0 {
453                 return c, errors.New("Decrypt: failed to decrypt")
454         }
455         return c, nil
456 }
457
458 type zr struct {
459         io.Reader
460 }
461
462 func (z *zr) Read(dst []byte) (n int, err error) {
463         for i := range dst {
464                 dst[i] = 0
465         }
466         return len(dst), nil
467 }
468
469 var zeroReader = &zr{}
470
471 func getLastBit(a *big.Int) uint {
472         return a.Bit(0)
473 }
474
475 func Compress(a *PublicKey) []byte {
476         buf := []byte{}
477         yp := getLastBit(a.Y)
478         buf = append(buf, a.X.Bytes()...)
479         if n := len(a.X.Bytes()); n < 32 {
480                 buf = append(zeroByteSlice[:(32-n)], buf...)
481         }
482         buf = append([]byte{byte(yp)}, buf...)
483         return buf
484 }
485
486 func Decompress(a []byte) *PublicKey {
487         var aa, xx, xx3 sm2P256FieldElement
488
489         P256Sm2()
490         x := new(big.Int).SetBytes(a[1:])
491         curve := sm2P256
492         sm2P256FromBig(&xx, x)
493         sm2P256Square(&xx3, &xx)       // x3 = x ^ 2
494         sm2P256Mul(&xx3, &xx3, &xx)    // x3 = x ^ 2 * x
495         sm2P256Mul(&aa, &curve.a, &xx) // a = a * x
496         sm2P256Add(&xx3, &xx3, &aa)
497         sm2P256Add(&xx3, &xx3, &curve.b)
498
499         y2 := sm2P256ToBig(&xx3)
500         y := new(big.Int).ModSqrt(y2, sm2P256.P)
501         if getLastBit(y) != uint(a[0]) {
502                 y.Sub(sm2P256.P, y)
503         }
504         return &PublicKey{
505                 Curve: P256Sm2(),
506                 X:     x,
507                 Y:     y,
508         }
509 }