--- /dev/null
+package sm2
+
+// reference to ecdsa
+import (
+ "bytes"
+ "crypto"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/sha512"
+ "encoding/asn1"
+ "encoding/binary"
+ "errors"
+ "io"
+ "math/big"
+
+ "github.com/bytom/crypto/sm3"
+)
+
+const (
+ aesIV = "IV for <SM2> CTR"
+)
+
+type PublicKey struct {
+ elliptic.Curve
+ X, Y *big.Int
+}
+
+type PrivateKey struct {
+ PublicKey
+ D *big.Int
+}
+
+type sm2Signature struct {
+ R, S *big.Int
+}
+
+// The SM2's private key contains the public key
+func (priv *PrivateKey) Public() crypto.PublicKey {
+ return &priv.PublicKey
+}
+
+func SignDigitToSignData(r, s *big.Int) ([]byte, error) {
+ return asn1.Marshal(sm2Signature{r, s})
+}
+
+func SignDataToSignDigit(sign []byte) (*big.Int, *big.Int, error) {
+ var sm2Sign sm2Signature
+
+ _, err := asn1.Unmarshal(sign, &sm2Sign)
+ if err != nil {
+ return nil, nil, err
+ }
+ return sm2Sign.R, sm2Sign.S, nil
+}
+
+// sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
+func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) {
+ r, s, err := Sign(priv, msg)
+ if err != nil {
+ return nil, err
+ }
+ return asn1.Marshal(sm2Signature{r, s})
+}
+
+func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) {
+ return Decrypt(priv, data)
+}
+
+func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
+ var sm2Sign sm2Signature
+
+ _, err := asn1.Unmarshal(sign, &sm2Sign)
+ if err != nil {
+ return false
+ }
+ return Verify(pub, msg, sm2Sign.R, sm2Sign.S)
+}
+
+func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) {
+ return Encrypt(pub, data)
+}
+
+var one = new(big.Int).SetInt64(1)
+
+func intToBytes(x int) []byte {
+ var buf = make([]byte, 4)
+
+ binary.BigEndian.PutUint32(buf, uint32(x))
+ return buf
+}
+
+func kdf(x, y []byte, length int) ([]byte, bool) {
+ var c []byte
+
+ ct := 1
+ h := sm3.New()
+ x = append(x, y...)
+ for i, j := 0, (length+31)/32; i < j; i++ {
+ h.Reset()
+ h.Write(x)
+ h.Write(intToBytes(ct))
+ hash := h.Sum(nil)
+ if i+1 == j && length%32 != 0 {
+ c = append(c, hash[:length%32]...)
+ } else {
+ c = append(c, hash...)
+ }
+ ct++
+ }
+ for i := 0; i < length; i++ {
+ if c[i] != 0 {
+ return c, true
+ }
+ }
+ return c, false
+}
+
+func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
+ params := c.Params()
+ b := make([]byte, params.BitSize/8+8)
+ _, err = io.ReadFull(rand, b)
+ if err != nil {
+ return
+ }
+ k = new(big.Int).SetBytes(b)
+ n := new(big.Int).Sub(params.N, one)
+ k.Mod(k, n)
+ k.Add(k, one)
+ return
+}
+
+func GenerateKey() (*PrivateKey, error) {
+ c := P256Sm2()
+ k, err := randFieldElement(c, rand.Reader)
+ if err != nil {
+ return nil, err
+ }
+ priv := new(PrivateKey)
+ priv.PublicKey.Curve = c
+ priv.D = k
+ priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
+ return priv, nil
+}
+
+var errZeroParam = errors.New("zero parameter")
+
+func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
+ entropylen := (priv.Curve.Params().BitSize + 7) / 16
+ if entropylen > 32 {
+ entropylen = 32
+ }
+ entropy := make([]byte, entropylen)
+ _, err = io.ReadFull(rand.Reader, entropy)
+ if err != nil {
+ return
+ }
+
+ // Initialize an SHA-512 hash context; digest ...
+ md := sha512.New()
+ md.Write(priv.D.Bytes()) // the private key,
+ md.Write(entropy) // the entropy,
+ md.Write(hash) // and the input hash;
+ key := md.Sum(nil)[:32] // and compute ChopMD-256(SHA-512),
+ // which is an indifferentiable MAC.
+
+ // Create an AES-CTR instance to use as a CSPRNG.
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Create a CSPRNG that xors a stream of zeros with
+ // the output of the AES-CTR instance.
+ csprng := cipher.StreamReader{
+ R: zeroReader,
+ S: cipher.NewCTR(block, []byte(aesIV)),
+ }
+
+ // See [NSA] 3.4.1
+ c := priv.PublicKey.Curve
+ N := c.Params().N
+ if N.Sign() == 0 {
+ return nil, nil, errZeroParam
+ }
+ var k *big.Int
+ e := new(big.Int).SetBytes(hash)
+ for { // 调整算法细节以实现SM2
+ for {
+ k, err = randFieldElement(c, csprng)
+ if err != nil {
+ r = nil
+ return
+ }
+ r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
+ r.Add(r, e)
+ r.Mod(r, N)
+ if r.Sign() != 0 {
+ break
+ }
+ if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
+ break
+ }
+ }
+ rD := new(big.Int).Mul(priv.D, r)
+ s = new(big.Int).Sub(k, rD)
+ d1 := new(big.Int).Add(priv.D, one)
+ d1Inv := new(big.Int).ModInverse(d1, N)
+ s.Mul(s, d1Inv)
+ s.Mod(s, N)
+ if s.Sign() != 0 {
+ break
+ }
+ }
+ return
+}
+
+func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
+ c := pub.Curve
+ N := c.Params().N
+
+ if r.Sign() <= 0 || s.Sign() <= 0 {
+ return false
+ }
+ if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
+ return false
+ }
+
+ // 调整算法细节以实现SM2
+ t := new(big.Int).Add(r, s)
+ t.Mod(t, N)
+ if t.Sign() == 0 {
+ return false
+ }
+
+ var x *big.Int
+ x1, y1 := c.ScalarBaseMult(s.Bytes())
+ x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
+ x, _ = c.Add(x1, y1, x2, y2)
+
+ e := new(big.Int).SetBytes(hash)
+ x.Add(x, e)
+ x.Mod(x, N)
+ return x.Cmp(r) == 0
+}
+
+func Sm2Sign(priv *PrivateKey, msg, uid []byte) (r, s *big.Int, err error) {
+ za, err := ZA(&priv.PublicKey, uid)
+ if err != nil {
+ return nil, nil, err
+ }
+ e, err := msgHash(za, msg)
+ if err != nil {
+ return nil, nil, err
+ }
+ c := priv.PublicKey.Curve
+ N := c.Params().N
+ if N.Sign() == 0 {
+ return nil, nil, errZeroParam
+ }
+ var k *big.Int
+ for { // 调整算法细节以实现SM2
+ for {
+ k, err = randFieldElement(c, rand.Reader)
+ if err != nil {
+ r = nil
+ return
+ }
+ r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
+ r.Add(r, e)
+ r.Mod(r, N)
+ if r.Sign() != 0 {
+ break
+ }
+ if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
+ break
+ }
+ }
+ rD := new(big.Int).Mul(priv.D, r)
+ s = new(big.Int).Sub(k, rD)
+ d1 := new(big.Int).Add(priv.D, one)
+ d1Inv := new(big.Int).ModInverse(d1, N)
+ s.Mul(s, d1Inv)
+ s.Mod(s, N)
+ if s.Sign() != 0 {
+ break
+ }
+ }
+ return
+}
+
+func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
+ c := pub.Curve
+ N := c.Params().N
+ one := new(big.Int).SetInt64(1)
+ if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
+ return false
+ }
+ if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
+ return false
+ }
+ za, err := ZA(pub, uid)
+ if err != nil {
+ return false
+ }
+ e, err := msgHash(za, msg)
+ if err != nil {
+ return false
+ }
+ t := new(big.Int).Add(r, s)
+ t.Mod(t, N)
+ if t.Sign() == 0 {
+ return false
+ }
+ var x *big.Int
+ x1, y1 := c.ScalarBaseMult(s.Bytes())
+ x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
+ x, _ = c.Add(x1, y1, x2, y2)
+
+ x.Add(x, e)
+ x.Mod(x, N)
+ return x.Cmp(r) == 0
+}
+
+func msgHash(za, msg []byte) (*big.Int, error) {
+ e := sm3.New()
+ e.Write(za)
+ e.Write(msg)
+ return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil
+}
+
+// ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
+func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
+ za := sm3.New()
+ uidLen := len(uid)
+ if uidLen >= 8192 {
+ return []byte{}, errors.New("SM2: uid too large")
+ }
+ Entla := uint16(8 * uidLen)
+ za.Write([]byte{byte((Entla >> 8) & 0xFF)})
+ za.Write([]byte{byte(Entla & 0xFF)})
+ za.Write(uid)
+ za.Write(sm2P256ToBig(&sm2P256.a).Bytes())
+ za.Write(sm2P256.B.Bytes())
+ za.Write(sm2P256.Gx.Bytes())
+ za.Write(sm2P256.Gy.Bytes())
+
+ xBuf := pub.X.Bytes()
+ yBuf := pub.Y.Bytes()
+ if n := len(xBuf); n < 32 {
+ xBuf = append(zeroByteSlice[:32-n], xBuf...)
+ }
+ za.Write(xBuf)
+ za.Write(yBuf)
+ return za.Sum(nil)[:32], nil
+}
+
+// 32byte
+var zeroByteSlice = []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+}
+
+/*
+ * sm2密文结构如下:
+ * x
+ * y
+ * hash
+ * CipherText
+ */
+func Encrypt(pub *PublicKey, data []byte) ([]byte, error) {
+ length := len(data)
+ for {
+ c := []byte{}
+ curve := pub.Curve
+ k, err := randFieldElement(curve, rand.Reader)
+ if err != nil {
+ return nil, err
+ }
+ x1, y1 := curve.ScalarBaseMult(k.Bytes())
+ x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
+ x1Buf := x1.Bytes()
+ y1Buf := y1.Bytes()
+ x2Buf := x2.Bytes()
+ y2Buf := y2.Bytes()
+ if n := len(x1Buf); n < 32 {
+ x1Buf = append(zeroByteSlice[:32-n], x1Buf...)
+ }
+ if n := len(y1Buf); n < 32 {
+ y1Buf = append(zeroByteSlice[:32-n], y1Buf...)
+ }
+ if n := len(x2Buf); n < 32 {
+ x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
+ }
+ if n := len(y2Buf); n < 32 {
+ y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
+ }
+ c = append(c, x1Buf...) // x分量
+ c = append(c, y1Buf...) // y分量
+ tm := []byte{}
+ tm = append(tm, x2Buf...)
+ tm = append(tm, data...)
+ tm = append(tm, y2Buf...)
+ h := sm3.Sm3Sum(tm)
+ c = append(c, h...)
+ ct, ok := kdf(x2Buf, y2Buf, length) // 密文
+ if !ok {
+ continue
+ }
+ c = append(c, ct...)
+ for i := 0; i < length; i++ {
+ c[96+i] ^= data[i]
+ }
+ return append([]byte{0x04}, c...), nil
+ }
+}
+
+func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) {
+ data = data[1:]
+ length := len(data) - 96
+ curve := priv.Curve
+ x := new(big.Int).SetBytes(data[:32])
+ y := new(big.Int).SetBytes(data[32:64])
+ x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
+ x2Buf := x2.Bytes()
+ y2Buf := y2.Bytes()
+ if n := len(x2Buf); n < 32 {
+ x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
+ }
+ if n := len(y2Buf); n < 32 {
+ y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
+ }
+ c, ok := kdf(x2Buf, y2Buf, length)
+ if !ok {
+ return nil, errors.New("Decrypt: failed to decrypt")
+ }
+ for i := 0; i < length; i++ {
+ c[i] ^= data[i+96]
+ }
+ tm := []byte{}
+ tm = append(tm, x2Buf...)
+ tm = append(tm, c...)
+ tm = append(tm, y2Buf...)
+ h := sm3.Sm3Sum(tm)
+ if bytes.Compare(h, data[64:96]) != 0 {
+ return c, errors.New("Decrypt: failed to decrypt")
+ }
+ return c, nil
+}
+
+type zr struct {
+ io.Reader
+}
+
+func (z *zr) Read(dst []byte) (n int, err error) {
+ for i := range dst {
+ dst[i] = 0
+ }
+ return len(dst), nil
+}
+
+var zeroReader = &zr{}
+
+func getLastBit(a *big.Int) uint {
+ return a.Bit(0)
+}
+
+func Compress(a *PublicKey) []byte {
+ buf := []byte{}
+ yp := getLastBit(a.Y)
+ buf = append(buf, a.X.Bytes()...)
+ if n := len(a.X.Bytes()); n < 32 {
+ buf = append(zeroByteSlice[:(32-n)], buf...)
+ }
+ buf = append([]byte{byte(yp)}, buf...)
+ return buf
+}
+
+func Decompress(a []byte) *PublicKey {
+ var aa, xx, xx3 sm2P256FieldElement
+
+ P256Sm2()
+ x := new(big.Int).SetBytes(a[1:])
+ curve := sm2P256
+ sm2P256FromBig(&xx, x)
+ sm2P256Square(&xx3, &xx) // x3 = x ^ 2
+ sm2P256Mul(&xx3, &xx3, &xx) // x3 = x ^ 2 * x
+ sm2P256Mul(&aa, &curve.a, &xx) // a = a * x
+ sm2P256Add(&xx3, &xx3, &aa)
+ sm2P256Add(&xx3, &xx3, &curve.b)
+
+ y2 := sm2P256ToBig(&xx3)
+ y := new(big.Int).ModSqrt(y2, sm2P256.P)
+ if getLastBit(y) != uint(a[0]) {
+ y.Sub(sm2P256.P, y)
+ }
+ return &PublicKey{
+ Curve: P256Sm2(),
+ X: x,
+ Y: y,
+ }
+}