OSDN Git Service

Thanos did someting
[bytom/vapor.git] / crypto / sm2 / sm2.go
diff --git a/crypto/sm2/sm2.go b/crypto/sm2/sm2.go
deleted file mode 100644 (file)
index 872bc48..0000000
+++ /dev/null
@@ -1,509 +0,0 @@
-package sm2
-
-// reference to ecdsa
-import (
-       "bytes"
-       "crypto"
-       "crypto/aes"
-       "crypto/cipher"
-       "crypto/elliptic"
-       "crypto/rand"
-       "crypto/sha512"
-       "encoding/asn1"
-       "encoding/binary"
-       "errors"
-       "io"
-       "math/big"
-
-       "github.com/vapor/crypto/sm3"
-)
-
-const (
-       aesIV = "IV for <SM2> CTR"
-)
-
-type PublicKey struct {
-       elliptic.Curve
-       X, Y *big.Int
-}
-
-type PrivateKey struct {
-       PublicKey
-       D *big.Int
-}
-
-type sm2Signature struct {
-       R, S *big.Int
-}
-
-// The SM2's private key contains the public key
-func (priv *PrivateKey) Public() crypto.PublicKey {
-       return &priv.PublicKey
-}
-
-func SignDigitToSignData(r, s *big.Int) ([]byte, error) {
-       return asn1.Marshal(sm2Signature{r, s})
-}
-
-func SignDataToSignDigit(sign []byte) (*big.Int, *big.Int, error) {
-       var sm2Sign sm2Signature
-
-       _, err := asn1.Unmarshal(sign, &sm2Sign)
-       if err != nil {
-               return nil, nil, err
-       }
-       return sm2Sign.R, sm2Sign.S, nil
-}
-
-// sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
-func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) {
-       r, s, err := Sign(priv, msg)
-       if err != nil {
-               return nil, err
-       }
-       return asn1.Marshal(sm2Signature{r, s})
-}
-
-func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) {
-       return Decrypt(priv, data)
-}
-
-func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
-       var sm2Sign sm2Signature
-
-       _, err := asn1.Unmarshal(sign, &sm2Sign)
-       if err != nil {
-               return false
-       }
-       return Verify(pub, msg, sm2Sign.R, sm2Sign.S)
-}
-
-func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) {
-       return Encrypt(pub, data)
-}
-
-var one = new(big.Int).SetInt64(1)
-
-func intToBytes(x int) []byte {
-       var buf = make([]byte, 4)
-
-       binary.BigEndian.PutUint32(buf, uint32(x))
-       return buf
-}
-
-func kdf(x, y []byte, length int) ([]byte, bool) {
-       var c []byte
-
-       ct := 1
-       h := sm3.New()
-       x = append(x, y...)
-       for i, j := 0, (length+31)/32; i < j; i++ {
-               h.Reset()
-               h.Write(x)
-               h.Write(intToBytes(ct))
-               hash := h.Sum(nil)
-               if i+1 == j && length%32 != 0 {
-                       c = append(c, hash[:length%32]...)
-               } else {
-                       c = append(c, hash...)
-               }
-               ct++
-       }
-       for i := 0; i < length; i++ {
-               if c[i] != 0 {
-                       return c, true
-               }
-       }
-       return c, false
-}
-
-func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
-       params := c.Params()
-       b := make([]byte, params.BitSize/8+8)
-       _, err = io.ReadFull(rand, b)
-       if err != nil {
-               return
-       }
-       k = new(big.Int).SetBytes(b)
-       n := new(big.Int).Sub(params.N, one)
-       k.Mod(k, n)
-       k.Add(k, one)
-       return
-}
-
-func GenerateKey() (*PrivateKey, error) {
-       c := P256Sm2()
-       k, err := randFieldElement(c, rand.Reader)
-       if err != nil {
-               return nil, err
-       }
-       priv := new(PrivateKey)
-       priv.PublicKey.Curve = c
-       priv.D = k
-       priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
-       return priv, nil
-}
-
-var errZeroParam = errors.New("zero parameter")
-
-func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
-       entropylen := (priv.Curve.Params().BitSize + 7) / 16
-       if entropylen > 32 {
-               entropylen = 32
-       }
-       entropy := make([]byte, entropylen)
-       _, err = io.ReadFull(rand.Reader, entropy)
-       if err != nil {
-               return
-       }
-
-       // Initialize an SHA-512 hash context; digest ...
-       md := sha512.New()
-       md.Write(priv.D.Bytes()) // the private key,
-       md.Write(entropy)        // the entropy,
-       md.Write(hash)           // and the input hash;
-       key := md.Sum(nil)[:32]  // and compute ChopMD-256(SHA-512),
-       // which is an indifferentiable MAC.
-
-       // Create an AES-CTR instance to use as a CSPRNG.
-       block, err := aes.NewCipher(key)
-       if err != nil {
-               return nil, nil, err
-       }
-
-       // Create a CSPRNG that xors a stream of zeros with
-       // the output of the AES-CTR instance.
-       csprng := cipher.StreamReader{
-               R: zeroReader,
-               S: cipher.NewCTR(block, []byte(aesIV)),
-       }
-
-       // See [NSA] 3.4.1
-       c := priv.PublicKey.Curve
-       N := c.Params().N
-       if N.Sign() == 0 {
-               return nil, nil, errZeroParam
-       }
-       var k *big.Int
-       e := new(big.Int).SetBytes(hash)
-       for { // 调整算法细节以实现SM2
-               for {
-                       k, err = randFieldElement(c, csprng)
-                       if err != nil {
-                               r = nil
-                               return
-                       }
-                       r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
-                       r.Add(r, e)
-                       r.Mod(r, N)
-                       if r.Sign() != 0 {
-                               break
-                       }
-                       if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
-                               break
-                       }
-               }
-               rD := new(big.Int).Mul(priv.D, r)
-               s = new(big.Int).Sub(k, rD)
-               d1 := new(big.Int).Add(priv.D, one)
-               d1Inv := new(big.Int).ModInverse(d1, N)
-               s.Mul(s, d1Inv)
-               s.Mod(s, N)
-               if s.Sign() != 0 {
-                       break
-               }
-       }
-       return
-}
-
-func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
-       c := pub.Curve
-       N := c.Params().N
-
-       if r.Sign() <= 0 || s.Sign() <= 0 {
-               return false
-       }
-       if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
-               return false
-       }
-
-       // 调整算法细节以实现SM2
-       t := new(big.Int).Add(r, s)
-       t.Mod(t, N)
-       if t.Sign() == 0 {
-               return false
-       }
-
-       var x *big.Int
-       x1, y1 := c.ScalarBaseMult(s.Bytes())
-       x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
-       x, _ = c.Add(x1, y1, x2, y2)
-
-       e := new(big.Int).SetBytes(hash)
-       x.Add(x, e)
-       x.Mod(x, N)
-       return x.Cmp(r) == 0
-}
-
-func Sm2Sign(priv *PrivateKey, msg, uid []byte) (r, s *big.Int, err error) {
-       za, err := ZA(&priv.PublicKey, uid)
-       if err != nil {
-               return nil, nil, err
-       }
-       e, err := msgHash(za, msg)
-       if err != nil {
-               return nil, nil, err
-       }
-       c := priv.PublicKey.Curve
-       N := c.Params().N
-       if N.Sign() == 0 {
-               return nil, nil, errZeroParam
-       }
-       var k *big.Int
-       for { // 调整算法细节以实现SM2
-               for {
-                       k, err = randFieldElement(c, rand.Reader)
-                       if err != nil {
-                               r = nil
-                               return
-                       }
-                       r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
-                       r.Add(r, e)
-                       r.Mod(r, N)
-                       if r.Sign() != 0 {
-                               break
-                       }
-                       if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
-                               break
-                       }
-               }
-               rD := new(big.Int).Mul(priv.D, r)
-               s = new(big.Int).Sub(k, rD)
-               d1 := new(big.Int).Add(priv.D, one)
-               d1Inv := new(big.Int).ModInverse(d1, N)
-               s.Mul(s, d1Inv)
-               s.Mod(s, N)
-               if s.Sign() != 0 {
-                       break
-               }
-       }
-       return
-}
-
-func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
-       c := pub.Curve
-       N := c.Params().N
-       one := new(big.Int).SetInt64(1)
-       if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
-               return false
-       }
-       if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
-               return false
-       }
-       za, err := ZA(pub, uid)
-       if err != nil {
-               return false
-       }
-       e, err := msgHash(za, msg)
-       if err != nil {
-               return false
-       }
-       t := new(big.Int).Add(r, s)
-       t.Mod(t, N)
-       if t.Sign() == 0 {
-               return false
-       }
-       var x *big.Int
-       x1, y1 := c.ScalarBaseMult(s.Bytes())
-       x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
-       x, _ = c.Add(x1, y1, x2, y2)
-
-       x.Add(x, e)
-       x.Mod(x, N)
-       return x.Cmp(r) == 0
-}
-
-func msgHash(za, msg []byte) (*big.Int, error) {
-       e := sm3.New()
-       e.Write(za)
-       e.Write(msg)
-       return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil
-}
-
-// ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
-func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
-       za := sm3.New()
-       uidLen := len(uid)
-       if uidLen >= 8192 {
-               return []byte{}, errors.New("SM2: uid too large")
-       }
-       Entla := uint16(8 * uidLen)
-       za.Write([]byte{byte((Entla >> 8) & 0xFF)})
-       za.Write([]byte{byte(Entla & 0xFF)})
-       za.Write(uid)
-       za.Write(sm2P256ToBig(&sm2P256.a).Bytes())
-       za.Write(sm2P256.B.Bytes())
-       za.Write(sm2P256.Gx.Bytes())
-       za.Write(sm2P256.Gy.Bytes())
-
-       xBuf := pub.X.Bytes()
-       yBuf := pub.Y.Bytes()
-       if n := len(xBuf); n < 32 {
-               xBuf = append(zeroByteSlice[:32-n], xBuf...)
-       }
-       za.Write(xBuf)
-       za.Write(yBuf)
-       return za.Sum(nil)[:32], nil
-}
-
-// 32byte
-var zeroByteSlice = []byte{
-       0, 0, 0, 0,
-       0, 0, 0, 0,
-       0, 0, 0, 0,
-       0, 0, 0, 0,
-       0, 0, 0, 0,
-       0, 0, 0, 0,
-       0, 0, 0, 0,
-       0, 0, 0, 0,
-}
-
-/*
- * sm2密文结构如下:
- *  x
- *  y
- *  hash
- *  CipherText
- */
-func Encrypt(pub *PublicKey, data []byte) ([]byte, error) {
-       length := len(data)
-       for {
-               c := []byte{}
-               curve := pub.Curve
-               k, err := randFieldElement(curve, rand.Reader)
-               if err != nil {
-                       return nil, err
-               }
-               x1, y1 := curve.ScalarBaseMult(k.Bytes())
-               x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
-               x1Buf := x1.Bytes()
-               y1Buf := y1.Bytes()
-               x2Buf := x2.Bytes()
-               y2Buf := y2.Bytes()
-               if n := len(x1Buf); n < 32 {
-                       x1Buf = append(zeroByteSlice[:32-n], x1Buf...)
-               }
-               if n := len(y1Buf); n < 32 {
-                       y1Buf = append(zeroByteSlice[:32-n], y1Buf...)
-               }
-               if n := len(x2Buf); n < 32 {
-                       x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
-               }
-               if n := len(y2Buf); n < 32 {
-                       y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
-               }
-               c = append(c, x1Buf...) // x分量
-               c = append(c, y1Buf...) // y分量
-               tm := []byte{}
-               tm = append(tm, x2Buf...)
-               tm = append(tm, data...)
-               tm = append(tm, y2Buf...)
-               h := sm3.Sm3Sum(tm)
-               c = append(c, h...)
-               ct, ok := kdf(x2Buf, y2Buf, length) // 密文
-               if !ok {
-                       continue
-               }
-               c = append(c, ct...)
-               for i := 0; i < length; i++ {
-                       c[96+i] ^= data[i]
-               }
-               return append([]byte{0x04}, c...), nil
-       }
-}
-
-func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) {
-       data = data[1:]
-       length := len(data) - 96
-       curve := priv.Curve
-       x := new(big.Int).SetBytes(data[:32])
-       y := new(big.Int).SetBytes(data[32:64])
-       x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
-       x2Buf := x2.Bytes()
-       y2Buf := y2.Bytes()
-       if n := len(x2Buf); n < 32 {
-               x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
-       }
-       if n := len(y2Buf); n < 32 {
-               y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
-       }
-       c, ok := kdf(x2Buf, y2Buf, length)
-       if !ok {
-               return nil, errors.New("Decrypt: failed to decrypt")
-       }
-       for i := 0; i < length; i++ {
-               c[i] ^= data[i+96]
-       }
-       tm := []byte{}
-       tm = append(tm, x2Buf...)
-       tm = append(tm, c...)
-       tm = append(tm, y2Buf...)
-       h := sm3.Sm3Sum(tm)
-       if bytes.Compare(h, data[64:96]) != 0 {
-               return c, errors.New("Decrypt: failed to decrypt")
-       }
-       return c, nil
-}
-
-type zr struct {
-       io.Reader
-}
-
-func (z *zr) Read(dst []byte) (n int, err error) {
-       for i := range dst {
-               dst[i] = 0
-       }
-       return len(dst), nil
-}
-
-var zeroReader = &zr{}
-
-func getLastBit(a *big.Int) uint {
-       return a.Bit(0)
-}
-
-func Compress(a *PublicKey) []byte {
-       buf := []byte{}
-       yp := getLastBit(a.Y)
-       buf = append(buf, a.X.Bytes()...)
-       if n := len(a.X.Bytes()); n < 32 {
-               buf = append(zeroByteSlice[:(32-n)], buf...)
-       }
-       buf = append([]byte{byte(yp)}, buf...)
-       return buf
-}
-
-func Decompress(a []byte) *PublicKey {
-       var aa, xx, xx3 sm2P256FieldElement
-
-       P256Sm2()
-       x := new(big.Int).SetBytes(a[1:])
-       curve := sm2P256
-       sm2P256FromBig(&xx, x)
-       sm2P256Square(&xx3, &xx)       // x3 = x ^ 2
-       sm2P256Mul(&xx3, &xx3, &xx)    // x3 = x ^ 2 * x
-       sm2P256Mul(&aa, &curve.a, &xx) // a = a * x
-       sm2P256Add(&xx3, &xx3, &aa)
-       sm2P256Add(&xx3, &xx3, &curve.b)
-
-       y2 := sm2P256ToBig(&xx3)
-       y := new(big.Int).ModSqrt(y2, sm2P256.P)
-       if getLastBit(y) != uint(a[0]) {
-               y.Sub(sm2P256.P, y)
-       }
-       return &PublicKey{
-               Curve: P256Sm2(),
-               X:     x,
-               Y:     y,
-       }
-}