2 Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
7 http://www.apache.org/licenses/LICENSE-2.0
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
33 "github.com/bytom/crypto/sm3"
37 aesIV = "IV for <SM2> CTR"
40 type PublicKey struct {
45 type PrivateKey struct {
50 type sm2Signature struct {
54 // The SM2's private key contains the public key
55 func (priv *PrivateKey) Public() crypto.PublicKey {
56 return &priv.PublicKey
59 func SignDigitToSignData(r, s *big.Int) ([]byte, error) {
60 return asn1.Marshal(sm2Signature{r, s})
63 func SignDataToSignDigit(sign []byte) (*big.Int, *big.Int, error) {
64 var sm2Sign sm2Signature
66 _, err := asn1.Unmarshal(sign, &sm2Sign)
70 return sm2Sign.R, sm2Sign.S, nil
73 // sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
74 func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) {
75 r, s, err := Sign(priv, msg)
79 return asn1.Marshal(sm2Signature{r, s})
82 func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) {
83 return Decrypt(priv, data)
86 func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
87 var sm2Sign sm2Signature
89 _, err := asn1.Unmarshal(sign, &sm2Sign)
93 return Verify(pub, msg, sm2Sign.R, sm2Sign.S)
96 func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) {
97 return Encrypt(pub, data)
100 var one = new(big.Int).SetInt64(1)
102 func intToBytes(x int) []byte {
103 var buf = make([]byte, 4)
105 binary.BigEndian.PutUint32(buf, uint32(x))
109 func kdf(x, y []byte, length int) ([]byte, bool) {
115 for i, j := 0, (length+31)/32; i < j; i++ {
118 h.Write(intToBytes(ct))
120 if i+1 == j && length%32 != 0 {
121 c = append(c, hash[:length%32]...)
123 c = append(c, hash...)
127 for i := 0; i < length; i++ {
135 func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
137 b := make([]byte, params.BitSize/8+8)
138 _, err = io.ReadFull(rand, b)
142 k = new(big.Int).SetBytes(b)
143 n := new(big.Int).Sub(params.N, one)
149 func GenerateKey() (*PrivateKey, error) {
151 k, err := randFieldElement(c, rand.Reader)
155 priv := new(PrivateKey)
156 priv.PublicKey.Curve = c
158 priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
162 var errZeroParam = errors.New("zero parameter")
164 func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
165 entropylen := (priv.Curve.Params().BitSize + 7) / 16
169 entropy := make([]byte, entropylen)
170 _, err = io.ReadFull(rand.Reader, entropy)
175 // Initialize an SHA-512 hash context; digest ...
177 md.Write(priv.D.Bytes()) // the private key,
178 md.Write(entropy) // the entropy,
179 md.Write(hash) // and the input hash;
180 key := md.Sum(nil)[:32] // and compute ChopMD-256(SHA-512),
181 // which is an indifferentiable MAC.
183 // Create an AES-CTR instance to use as a CSPRNG.
184 block, err := aes.NewCipher(key)
189 // Create a CSPRNG that xors a stream of zeros with
190 // the output of the AES-CTR instance.
191 csprng := cipher.StreamReader{
193 S: cipher.NewCTR(block, []byte(aesIV)),
197 c := priv.PublicKey.Curve
200 return nil, nil, errZeroParam
203 e := new(big.Int).SetBytes(hash)
204 for { // 调整算法细节以实现SM2
206 k, err = randFieldElement(c, csprng)
211 r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
217 if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
221 rD := new(big.Int).Mul(priv.D, r)
222 s = new(big.Int).Sub(k, rD)
223 d1 := new(big.Int).Add(priv.D, one)
224 d1Inv := new(big.Int).ModInverse(d1, N)
234 func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
238 if r.Sign() <= 0 || s.Sign() <= 0 {
241 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
246 t := new(big.Int).Add(r, s)
253 x1, y1 := c.ScalarBaseMult(s.Bytes())
254 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
255 x, _ = c.Add(x1, y1, x2, y2)
257 e := new(big.Int).SetBytes(hash)
263 func Sm2Sign(priv *PrivateKey, msg, uid []byte) (r, s *big.Int, err error) {
264 za, err := ZA(&priv.PublicKey, uid)
268 e, err := msgHash(za, msg)
272 c := priv.PublicKey.Curve
275 return nil, nil, errZeroParam
278 for { // 调整算法细节以实现SM2
280 k, err = randFieldElement(c, rand.Reader)
285 r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
291 if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
295 rD := new(big.Int).Mul(priv.D, r)
296 s = new(big.Int).Sub(k, rD)
297 d1 := new(big.Int).Add(priv.D, one)
298 d1Inv := new(big.Int).ModInverse(d1, N)
308 func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
311 one := new(big.Int).SetInt64(1)
312 if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
315 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
318 za, err := ZA(pub, uid)
322 e, err := msgHash(za, msg)
326 t := new(big.Int).Add(r, s)
332 x1, y1 := c.ScalarBaseMult(s.Bytes())
333 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
334 x, _ = c.Add(x1, y1, x2, y2)
341 func msgHash(za, msg []byte) (*big.Int, error) {
345 return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil
348 // ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
349 func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
353 return []byte{}, errors.New("SM2: uid too large")
355 Entla := uint16(8 * uidLen)
356 za.Write([]byte{byte((Entla >> 8) & 0xFF)})
357 za.Write([]byte{byte(Entla & 0xFF)})
359 za.Write(sm2P256ToBig(&sm2P256.a).Bytes())
360 za.Write(sm2P256.B.Bytes())
361 za.Write(sm2P256.Gx.Bytes())
362 za.Write(sm2P256.Gy.Bytes())
364 xBuf := pub.X.Bytes()
365 yBuf := pub.Y.Bytes()
366 if n := len(xBuf); n < 32 {
367 xBuf = append(zeroByteSlice[:32-n], xBuf...)
371 return za.Sum(nil)[:32], nil
375 var zeroByteSlice = []byte{
393 func Encrypt(pub *PublicKey, data []byte) ([]byte, error) {
398 k, err := randFieldElement(curve, rand.Reader)
402 x1, y1 := curve.ScalarBaseMult(k.Bytes())
403 x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
408 if n := len(x1Buf); n < 32 {
409 x1Buf = append(zeroByteSlice[:32-n], x1Buf...)
411 if n := len(y1Buf); n < 32 {
412 y1Buf = append(zeroByteSlice[:32-n], y1Buf...)
414 if n := len(x2Buf); n < 32 {
415 x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
417 if n := len(y2Buf); n < 32 {
418 y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
420 c = append(c, x1Buf...) // x分量
421 c = append(c, y1Buf...) // y分量
423 tm = append(tm, x2Buf...)
424 tm = append(tm, data...)
425 tm = append(tm, y2Buf...)
428 ct, ok := kdf(x2Buf, y2Buf, length) // 密文
433 for i := 0; i < length; i++ {
436 return append([]byte{0x04}, c...), nil
440 func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) {
442 length := len(data) - 96
444 x := new(big.Int).SetBytes(data[:32])
445 y := new(big.Int).SetBytes(data[32:64])
446 x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
449 if n := len(x2Buf); n < 32 {
450 x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
452 if n := len(y2Buf); n < 32 {
453 y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
455 c, ok := kdf(x2Buf, y2Buf, length)
457 return nil, errors.New("Decrypt: failed to decrypt")
459 for i := 0; i < length; i++ {
463 tm = append(tm, x2Buf...)
464 tm = append(tm, c...)
465 tm = append(tm, y2Buf...)
467 if bytes.Compare(h, data[64:96]) != 0 {
468 return c, errors.New("Decrypt: failed to decrypt")
477 func (z *zr) Read(dst []byte) (n int, err error) {
484 var zeroReader = &zr{}
486 func getLastBit(a *big.Int) uint {
490 func Compress(a *PublicKey) []byte {
492 yp := getLastBit(a.Y)
493 buf = append(buf, a.X.Bytes()...)
494 if n := len(a.X.Bytes()); n < 32 {
495 buf = append(zeroByteSlice[:(32-n)], buf...)
497 buf = append([]byte{byte(yp)}, buf...)
501 func Decompress(a []byte) *PublicKey {
502 var aa, xx, xx3 sm2P256FieldElement
505 x := new(big.Int).SetBytes(a[1:])
507 sm2P256FromBig(&xx, x)
508 sm2P256Square(&xx3, &xx) // x3 = x ^ 2
509 sm2P256Mul(&xx3, &xx3, &xx) // x3 = x ^ 2 * x
510 sm2P256Mul(&aa, &curve.a, &xx) // a = a * x
511 sm2P256Add(&xx3, &xx3, &aa)
512 sm2P256Add(&xx3, &xx3, &curve.b)
514 y2 := sm2P256ToBig(&xx3)
515 y := new(big.Int).ModSqrt(y2, sm2P256.P)
516 if getLastBit(y) != uint(a[0]) {