+++ /dev/null
-package sm2
-
-// reference to ecdsa
-import (
- "bytes"
- "crypto"
- "crypto/aes"
- "crypto/cipher"
- "crypto/elliptic"
- "crypto/rand"
- "crypto/sha512"
- "encoding/asn1"
- "encoding/binary"
- "errors"
- "io"
- "math/big"
-
- "github.com/vapor/crypto/sm3"
-)
-
-const (
- aesIV = "IV for <SM2> CTR"
-)
-
-type PublicKey struct {
- elliptic.Curve
- X, Y *big.Int
-}
-
-type PrivateKey struct {
- PublicKey
- D *big.Int
-}
-
-type sm2Signature struct {
- R, S *big.Int
-}
-
-// The SM2's private key contains the public key
-func (priv *PrivateKey) Public() crypto.PublicKey {
- return &priv.PublicKey
-}
-
-func SignDigitToSignData(r, s *big.Int) ([]byte, error) {
- return asn1.Marshal(sm2Signature{r, s})
-}
-
-func SignDataToSignDigit(sign []byte) (*big.Int, *big.Int, error) {
- var sm2Sign sm2Signature
-
- _, err := asn1.Unmarshal(sign, &sm2Sign)
- if err != nil {
- return nil, nil, err
- }
- return sm2Sign.R, sm2Sign.S, nil
-}
-
-// sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
-func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) {
- r, s, err := Sign(priv, msg)
- if err != nil {
- return nil, err
- }
- return asn1.Marshal(sm2Signature{r, s})
-}
-
-func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) {
- return Decrypt(priv, data)
-}
-
-func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
- var sm2Sign sm2Signature
-
- _, err := asn1.Unmarshal(sign, &sm2Sign)
- if err != nil {
- return false
- }
- return Verify(pub, msg, sm2Sign.R, sm2Sign.S)
-}
-
-func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) {
- return Encrypt(pub, data)
-}
-
-var one = new(big.Int).SetInt64(1)
-
-func intToBytes(x int) []byte {
- var buf = make([]byte, 4)
-
- binary.BigEndian.PutUint32(buf, uint32(x))
- return buf
-}
-
-func kdf(x, y []byte, length int) ([]byte, bool) {
- var c []byte
-
- ct := 1
- h := sm3.New()
- x = append(x, y...)
- for i, j := 0, (length+31)/32; i < j; i++ {
- h.Reset()
- h.Write(x)
- h.Write(intToBytes(ct))
- hash := h.Sum(nil)
- if i+1 == j && length%32 != 0 {
- c = append(c, hash[:length%32]...)
- } else {
- c = append(c, hash...)
- }
- ct++
- }
- for i := 0; i < length; i++ {
- if c[i] != 0 {
- return c, true
- }
- }
- return c, false
-}
-
-func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
- params := c.Params()
- b := make([]byte, params.BitSize/8+8)
- _, err = io.ReadFull(rand, b)
- if err != nil {
- return
- }
- k = new(big.Int).SetBytes(b)
- n := new(big.Int).Sub(params.N, one)
- k.Mod(k, n)
- k.Add(k, one)
- return
-}
-
-func GenerateKey() (*PrivateKey, error) {
- c := P256Sm2()
- k, err := randFieldElement(c, rand.Reader)
- if err != nil {
- return nil, err
- }
- priv := new(PrivateKey)
- priv.PublicKey.Curve = c
- priv.D = k
- priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
- return priv, nil
-}
-
-var errZeroParam = errors.New("zero parameter")
-
-func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
- entropylen := (priv.Curve.Params().BitSize + 7) / 16
- if entropylen > 32 {
- entropylen = 32
- }
- entropy := make([]byte, entropylen)
- _, err = io.ReadFull(rand.Reader, entropy)
- if err != nil {
- return
- }
-
- // Initialize an SHA-512 hash context; digest ...
- md := sha512.New()
- md.Write(priv.D.Bytes()) // the private key,
- md.Write(entropy) // the entropy,
- md.Write(hash) // and the input hash;
- key := md.Sum(nil)[:32] // and compute ChopMD-256(SHA-512),
- // which is an indifferentiable MAC.
-
- // Create an AES-CTR instance to use as a CSPRNG.
- block, err := aes.NewCipher(key)
- if err != nil {
- return nil, nil, err
- }
-
- // Create a CSPRNG that xors a stream of zeros with
- // the output of the AES-CTR instance.
- csprng := cipher.StreamReader{
- R: zeroReader,
- S: cipher.NewCTR(block, []byte(aesIV)),
- }
-
- // See [NSA] 3.4.1
- c := priv.PublicKey.Curve
- N := c.Params().N
- if N.Sign() == 0 {
- return nil, nil, errZeroParam
- }
- var k *big.Int
- e := new(big.Int).SetBytes(hash)
- for { // 调整算法细节以实现SM2
- for {
- k, err = randFieldElement(c, csprng)
- if err != nil {
- r = nil
- return
- }
- r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
- r.Add(r, e)
- r.Mod(r, N)
- if r.Sign() != 0 {
- break
- }
- if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
- break
- }
- }
- rD := new(big.Int).Mul(priv.D, r)
- s = new(big.Int).Sub(k, rD)
- d1 := new(big.Int).Add(priv.D, one)
- d1Inv := new(big.Int).ModInverse(d1, N)
- s.Mul(s, d1Inv)
- s.Mod(s, N)
- if s.Sign() != 0 {
- break
- }
- }
- return
-}
-
-func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
- c := pub.Curve
- N := c.Params().N
-
- if r.Sign() <= 0 || s.Sign() <= 0 {
- return false
- }
- if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
- return false
- }
-
- // 调整算法细节以实现SM2
- t := new(big.Int).Add(r, s)
- t.Mod(t, N)
- if t.Sign() == 0 {
- return false
- }
-
- var x *big.Int
- x1, y1 := c.ScalarBaseMult(s.Bytes())
- x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
- x, _ = c.Add(x1, y1, x2, y2)
-
- e := new(big.Int).SetBytes(hash)
- x.Add(x, e)
- x.Mod(x, N)
- return x.Cmp(r) == 0
-}
-
-func Sm2Sign(priv *PrivateKey, msg, uid []byte) (r, s *big.Int, err error) {
- za, err := ZA(&priv.PublicKey, uid)
- if err != nil {
- return nil, nil, err
- }
- e, err := msgHash(za, msg)
- if err != nil {
- return nil, nil, err
- }
- c := priv.PublicKey.Curve
- N := c.Params().N
- if N.Sign() == 0 {
- return nil, nil, errZeroParam
- }
- var k *big.Int
- for { // 调整算法细节以实现SM2
- for {
- k, err = randFieldElement(c, rand.Reader)
- if err != nil {
- r = nil
- return
- }
- r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
- r.Add(r, e)
- r.Mod(r, N)
- if r.Sign() != 0 {
- break
- }
- if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
- break
- }
- }
- rD := new(big.Int).Mul(priv.D, r)
- s = new(big.Int).Sub(k, rD)
- d1 := new(big.Int).Add(priv.D, one)
- d1Inv := new(big.Int).ModInverse(d1, N)
- s.Mul(s, d1Inv)
- s.Mod(s, N)
- if s.Sign() != 0 {
- break
- }
- }
- return
-}
-
-func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
- c := pub.Curve
- N := c.Params().N
- one := new(big.Int).SetInt64(1)
- if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
- return false
- }
- if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
- return false
- }
- za, err := ZA(pub, uid)
- if err != nil {
- return false
- }
- e, err := msgHash(za, msg)
- if err != nil {
- return false
- }
- t := new(big.Int).Add(r, s)
- t.Mod(t, N)
- if t.Sign() == 0 {
- return false
- }
- var x *big.Int
- x1, y1 := c.ScalarBaseMult(s.Bytes())
- x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
- x, _ = c.Add(x1, y1, x2, y2)
-
- x.Add(x, e)
- x.Mod(x, N)
- return x.Cmp(r) == 0
-}
-
-func msgHash(za, msg []byte) (*big.Int, error) {
- e := sm3.New()
- e.Write(za)
- e.Write(msg)
- return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil
-}
-
-// ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
-func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
- za := sm3.New()
- uidLen := len(uid)
- if uidLen >= 8192 {
- return []byte{}, errors.New("SM2: uid too large")
- }
- Entla := uint16(8 * uidLen)
- za.Write([]byte{byte((Entla >> 8) & 0xFF)})
- za.Write([]byte{byte(Entla & 0xFF)})
- za.Write(uid)
- za.Write(sm2P256ToBig(&sm2P256.a).Bytes())
- za.Write(sm2P256.B.Bytes())
- za.Write(sm2P256.Gx.Bytes())
- za.Write(sm2P256.Gy.Bytes())
-
- xBuf := pub.X.Bytes()
- yBuf := pub.Y.Bytes()
- if n := len(xBuf); n < 32 {
- xBuf = append(zeroByteSlice[:32-n], xBuf...)
- }
- za.Write(xBuf)
- za.Write(yBuf)
- return za.Sum(nil)[:32], nil
-}
-
-// 32byte
-var zeroByteSlice = []byte{
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
-}
-
-/*
- * sm2密文结构如下:
- * x
- * y
- * hash
- * CipherText
- */
-func Encrypt(pub *PublicKey, data []byte) ([]byte, error) {
- length := len(data)
- for {
- c := []byte{}
- curve := pub.Curve
- k, err := randFieldElement(curve, rand.Reader)
- if err != nil {
- return nil, err
- }
- x1, y1 := curve.ScalarBaseMult(k.Bytes())
- x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
- x1Buf := x1.Bytes()
- y1Buf := y1.Bytes()
- x2Buf := x2.Bytes()
- y2Buf := y2.Bytes()
- if n := len(x1Buf); n < 32 {
- x1Buf = append(zeroByteSlice[:32-n], x1Buf...)
- }
- if n := len(y1Buf); n < 32 {
- y1Buf = append(zeroByteSlice[:32-n], y1Buf...)
- }
- if n := len(x2Buf); n < 32 {
- x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
- }
- if n := len(y2Buf); n < 32 {
- y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
- }
- c = append(c, x1Buf...) // x分量
- c = append(c, y1Buf...) // y分量
- tm := []byte{}
- tm = append(tm, x2Buf...)
- tm = append(tm, data...)
- tm = append(tm, y2Buf...)
- h := sm3.Sm3Sum(tm)
- c = append(c, h...)
- ct, ok := kdf(x2Buf, y2Buf, length) // 密文
- if !ok {
- continue
- }
- c = append(c, ct...)
- for i := 0; i < length; i++ {
- c[96+i] ^= data[i]
- }
- return append([]byte{0x04}, c...), nil
- }
-}
-
-func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) {
- data = data[1:]
- length := len(data) - 96
- curve := priv.Curve
- x := new(big.Int).SetBytes(data[:32])
- y := new(big.Int).SetBytes(data[32:64])
- x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
- x2Buf := x2.Bytes()
- y2Buf := y2.Bytes()
- if n := len(x2Buf); n < 32 {
- x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
- }
- if n := len(y2Buf); n < 32 {
- y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
- }
- c, ok := kdf(x2Buf, y2Buf, length)
- if !ok {
- return nil, errors.New("Decrypt: failed to decrypt")
- }
- for i := 0; i < length; i++ {
- c[i] ^= data[i+96]
- }
- tm := []byte{}
- tm = append(tm, x2Buf...)
- tm = append(tm, c...)
- tm = append(tm, y2Buf...)
- h := sm3.Sm3Sum(tm)
- if bytes.Compare(h, data[64:96]) != 0 {
- return c, errors.New("Decrypt: failed to decrypt")
- }
- return c, nil
-}
-
-type zr struct {
- io.Reader
-}
-
-func (z *zr) Read(dst []byte) (n int, err error) {
- for i := range dst {
- dst[i] = 0
- }
- return len(dst), nil
-}
-
-var zeroReader = &zr{}
-
-func getLastBit(a *big.Int) uint {
- return a.Bit(0)
-}
-
-func Compress(a *PublicKey) []byte {
- buf := []byte{}
- yp := getLastBit(a.Y)
- buf = append(buf, a.X.Bytes()...)
- if n := len(a.X.Bytes()); n < 32 {
- buf = append(zeroByteSlice[:(32-n)], buf...)
- }
- buf = append([]byte{byte(yp)}, buf...)
- return buf
-}
-
-func Decompress(a []byte) *PublicKey {
- var aa, xx, xx3 sm2P256FieldElement
-
- P256Sm2()
- x := new(big.Int).SetBytes(a[1:])
- curve := sm2P256
- sm2P256FromBig(&xx, x)
- sm2P256Square(&xx3, &xx) // x3 = x ^ 2
- sm2P256Mul(&xx3, &xx3, &xx) // x3 = x ^ 2 * x
- sm2P256Mul(&aa, &curve.a, &xx) // a = a * x
- sm2P256Add(&xx3, &xx3, &aa)
- sm2P256Add(&xx3, &xx3, &curve.b)
-
- y2 := sm2P256ToBig(&xx3)
- y := new(big.Int).ModSqrt(y2, sm2P256.P)
- if getLastBit(y) != uint(a[0]) {
- y.Sub(sm2P256.P, y)
- }
- return &PublicKey{
- Curve: P256Sm2(),
- X: x,
- Y: y,
- }
-}