OSDN Git Service

Merge pull request #201 from Bytom/v0.1
[bytom/vapor.git] / vendor / github.com / bytom / crypto / sm2 / sm2.go
diff --git a/vendor/github.com/bytom/crypto/sm2/sm2.go b/vendor/github.com/bytom/crypto/sm2/sm2.go
new file mode 100644 (file)
index 0000000..8339c9b
--- /dev/null
@@ -0,0 +1,509 @@
+package sm2
+
+// reference to ecdsa
+import (
+       "bytes"
+       "crypto"
+       "crypto/aes"
+       "crypto/cipher"
+       "crypto/elliptic"
+       "crypto/rand"
+       "crypto/sha512"
+       "encoding/asn1"
+       "encoding/binary"
+       "errors"
+       "io"
+       "math/big"
+
+       "github.com/bytom/crypto/sm3"
+)
+
+const (
+       aesIV = "IV for <SM2> CTR"
+)
+
+type PublicKey struct {
+       elliptic.Curve
+       X, Y *big.Int
+}
+
+type PrivateKey struct {
+       PublicKey
+       D *big.Int
+}
+
+type sm2Signature struct {
+       R, S *big.Int
+}
+
+// The SM2's private key contains the public key
+func (priv *PrivateKey) Public() crypto.PublicKey {
+       return &priv.PublicKey
+}
+
+func SignDigitToSignData(r, s *big.Int) ([]byte, error) {
+       return asn1.Marshal(sm2Signature{r, s})
+}
+
+func SignDataToSignDigit(sign []byte) (*big.Int, *big.Int, error) {
+       var sm2Sign sm2Signature
+
+       _, err := asn1.Unmarshal(sign, &sm2Sign)
+       if err != nil {
+               return nil, nil, err
+       }
+       return sm2Sign.R, sm2Sign.S, nil
+}
+
+// sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
+func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) {
+       r, s, err := Sign(priv, msg)
+       if err != nil {
+               return nil, err
+       }
+       return asn1.Marshal(sm2Signature{r, s})
+}
+
+func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) {
+       return Decrypt(priv, data)
+}
+
+func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
+       var sm2Sign sm2Signature
+
+       _, err := asn1.Unmarshal(sign, &sm2Sign)
+       if err != nil {
+               return false
+       }
+       return Verify(pub, msg, sm2Sign.R, sm2Sign.S)
+}
+
+func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) {
+       return Encrypt(pub, data)
+}
+
+var one = new(big.Int).SetInt64(1)
+
+func intToBytes(x int) []byte {
+       var buf = make([]byte, 4)
+
+       binary.BigEndian.PutUint32(buf, uint32(x))
+       return buf
+}
+
+func kdf(x, y []byte, length int) ([]byte, bool) {
+       var c []byte
+
+       ct := 1
+       h := sm3.New()
+       x = append(x, y...)
+       for i, j := 0, (length+31)/32; i < j; i++ {
+               h.Reset()
+               h.Write(x)
+               h.Write(intToBytes(ct))
+               hash := h.Sum(nil)
+               if i+1 == j && length%32 != 0 {
+                       c = append(c, hash[:length%32]...)
+               } else {
+                       c = append(c, hash...)
+               }
+               ct++
+       }
+       for i := 0; i < length; i++ {
+               if c[i] != 0 {
+                       return c, true
+               }
+       }
+       return c, false
+}
+
+func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
+       params := c.Params()
+       b := make([]byte, params.BitSize/8+8)
+       _, err = io.ReadFull(rand, b)
+       if err != nil {
+               return
+       }
+       k = new(big.Int).SetBytes(b)
+       n := new(big.Int).Sub(params.N, one)
+       k.Mod(k, n)
+       k.Add(k, one)
+       return
+}
+
+func GenerateKey() (*PrivateKey, error) {
+       c := P256Sm2()
+       k, err := randFieldElement(c, rand.Reader)
+       if err != nil {
+               return nil, err
+       }
+       priv := new(PrivateKey)
+       priv.PublicKey.Curve = c
+       priv.D = k
+       priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
+       return priv, nil
+}
+
+var errZeroParam = errors.New("zero parameter")
+
+func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
+       entropylen := (priv.Curve.Params().BitSize + 7) / 16
+       if entropylen > 32 {
+               entropylen = 32
+       }
+       entropy := make([]byte, entropylen)
+       _, err = io.ReadFull(rand.Reader, entropy)
+       if err != nil {
+               return
+       }
+
+       // Initialize an SHA-512 hash context; digest ...
+       md := sha512.New()
+       md.Write(priv.D.Bytes()) // the private key,
+       md.Write(entropy)        // the entropy,
+       md.Write(hash)           // and the input hash;
+       key := md.Sum(nil)[:32]  // and compute ChopMD-256(SHA-512),
+       // which is an indifferentiable MAC.
+
+       // Create an AES-CTR instance to use as a CSPRNG.
+       block, err := aes.NewCipher(key)
+       if err != nil {
+               return nil, nil, err
+       }
+
+       // Create a CSPRNG that xors a stream of zeros with
+       // the output of the AES-CTR instance.
+       csprng := cipher.StreamReader{
+               R: zeroReader,
+               S: cipher.NewCTR(block, []byte(aesIV)),
+       }
+
+       // See [NSA] 3.4.1
+       c := priv.PublicKey.Curve
+       N := c.Params().N
+       if N.Sign() == 0 {
+               return nil, nil, errZeroParam
+       }
+       var k *big.Int
+       e := new(big.Int).SetBytes(hash)
+       for { // 调整算法细节以实现SM2
+               for {
+                       k, err = randFieldElement(c, csprng)
+                       if err != nil {
+                               r = nil
+                               return
+                       }
+                       r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
+                       r.Add(r, e)
+                       r.Mod(r, N)
+                       if r.Sign() != 0 {
+                               break
+                       }
+                       if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
+                               break
+                       }
+               }
+               rD := new(big.Int).Mul(priv.D, r)
+               s = new(big.Int).Sub(k, rD)
+               d1 := new(big.Int).Add(priv.D, one)
+               d1Inv := new(big.Int).ModInverse(d1, N)
+               s.Mul(s, d1Inv)
+               s.Mod(s, N)
+               if s.Sign() != 0 {
+                       break
+               }
+       }
+       return
+}
+
+func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
+       c := pub.Curve
+       N := c.Params().N
+
+       if r.Sign() <= 0 || s.Sign() <= 0 {
+               return false
+       }
+       if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
+               return false
+       }
+
+       // 调整算法细节以实现SM2
+       t := new(big.Int).Add(r, s)
+       t.Mod(t, N)
+       if t.Sign() == 0 {
+               return false
+       }
+
+       var x *big.Int
+       x1, y1 := c.ScalarBaseMult(s.Bytes())
+       x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
+       x, _ = c.Add(x1, y1, x2, y2)
+
+       e := new(big.Int).SetBytes(hash)
+       x.Add(x, e)
+       x.Mod(x, N)
+       return x.Cmp(r) == 0
+}
+
+func Sm2Sign(priv *PrivateKey, msg, uid []byte) (r, s *big.Int, err error) {
+       za, err := ZA(&priv.PublicKey, uid)
+       if err != nil {
+               return nil, nil, err
+       }
+       e, err := msgHash(za, msg)
+       if err != nil {
+               return nil, nil, err
+       }
+       c := priv.PublicKey.Curve
+       N := c.Params().N
+       if N.Sign() == 0 {
+               return nil, nil, errZeroParam
+       }
+       var k *big.Int
+       for { // 调整算法细节以实现SM2
+               for {
+                       k, err = randFieldElement(c, rand.Reader)
+                       if err != nil {
+                               r = nil
+                               return
+                       }
+                       r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
+                       r.Add(r, e)
+                       r.Mod(r, N)
+                       if r.Sign() != 0 {
+                               break
+                       }
+                       if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
+                               break
+                       }
+               }
+               rD := new(big.Int).Mul(priv.D, r)
+               s = new(big.Int).Sub(k, rD)
+               d1 := new(big.Int).Add(priv.D, one)
+               d1Inv := new(big.Int).ModInverse(d1, N)
+               s.Mul(s, d1Inv)
+               s.Mod(s, N)
+               if s.Sign() != 0 {
+                       break
+               }
+       }
+       return
+}
+
+func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
+       c := pub.Curve
+       N := c.Params().N
+       one := new(big.Int).SetInt64(1)
+       if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
+               return false
+       }
+       if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
+               return false
+       }
+       za, err := ZA(pub, uid)
+       if err != nil {
+               return false
+       }
+       e, err := msgHash(za, msg)
+       if err != nil {
+               return false
+       }
+       t := new(big.Int).Add(r, s)
+       t.Mod(t, N)
+       if t.Sign() == 0 {
+               return false
+       }
+       var x *big.Int
+       x1, y1 := c.ScalarBaseMult(s.Bytes())
+       x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
+       x, _ = c.Add(x1, y1, x2, y2)
+
+       x.Add(x, e)
+       x.Mod(x, N)
+       return x.Cmp(r) == 0
+}
+
+func msgHash(za, msg []byte) (*big.Int, error) {
+       e := sm3.New()
+       e.Write(za)
+       e.Write(msg)
+       return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil
+}
+
+// ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
+func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
+       za := sm3.New()
+       uidLen := len(uid)
+       if uidLen >= 8192 {
+               return []byte{}, errors.New("SM2: uid too large")
+       }
+       Entla := uint16(8 * uidLen)
+       za.Write([]byte{byte((Entla >> 8) & 0xFF)})
+       za.Write([]byte{byte(Entla & 0xFF)})
+       za.Write(uid)
+       za.Write(sm2P256ToBig(&sm2P256.a).Bytes())
+       za.Write(sm2P256.B.Bytes())
+       za.Write(sm2P256.Gx.Bytes())
+       za.Write(sm2P256.Gy.Bytes())
+
+       xBuf := pub.X.Bytes()
+       yBuf := pub.Y.Bytes()
+       if n := len(xBuf); n < 32 {
+               xBuf = append(zeroByteSlice[:32-n], xBuf...)
+       }
+       za.Write(xBuf)
+       za.Write(yBuf)
+       return za.Sum(nil)[:32], nil
+}
+
+// 32byte
+var zeroByteSlice = []byte{
+       0, 0, 0, 0,
+       0, 0, 0, 0,
+       0, 0, 0, 0,
+       0, 0, 0, 0,
+       0, 0, 0, 0,
+       0, 0, 0, 0,
+       0, 0, 0, 0,
+       0, 0, 0, 0,
+}
+
+/*
+ * sm2密文结构如下:
+ *  x
+ *  y
+ *  hash
+ *  CipherText
+ */
+func Encrypt(pub *PublicKey, data []byte) ([]byte, error) {
+       length := len(data)
+       for {
+               c := []byte{}
+               curve := pub.Curve
+               k, err := randFieldElement(curve, rand.Reader)
+               if err != nil {
+                       return nil, err
+               }
+               x1, y1 := curve.ScalarBaseMult(k.Bytes())
+               x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
+               x1Buf := x1.Bytes()
+               y1Buf := y1.Bytes()
+               x2Buf := x2.Bytes()
+               y2Buf := y2.Bytes()
+               if n := len(x1Buf); n < 32 {
+                       x1Buf = append(zeroByteSlice[:32-n], x1Buf...)
+               }
+               if n := len(y1Buf); n < 32 {
+                       y1Buf = append(zeroByteSlice[:32-n], y1Buf...)
+               }
+               if n := len(x2Buf); n < 32 {
+                       x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
+               }
+               if n := len(y2Buf); n < 32 {
+                       y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
+               }
+               c = append(c, x1Buf...) // x分量
+               c = append(c, y1Buf...) // y分量
+               tm := []byte{}
+               tm = append(tm, x2Buf...)
+               tm = append(tm, data...)
+               tm = append(tm, y2Buf...)
+               h := sm3.Sm3Sum(tm)
+               c = append(c, h...)
+               ct, ok := kdf(x2Buf, y2Buf, length) // 密文
+               if !ok {
+                       continue
+               }
+               c = append(c, ct...)
+               for i := 0; i < length; i++ {
+                       c[96+i] ^= data[i]
+               }
+               return append([]byte{0x04}, c...), nil
+       }
+}
+
+func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) {
+       data = data[1:]
+       length := len(data) - 96
+       curve := priv.Curve
+       x := new(big.Int).SetBytes(data[:32])
+       y := new(big.Int).SetBytes(data[32:64])
+       x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
+       x2Buf := x2.Bytes()
+       y2Buf := y2.Bytes()
+       if n := len(x2Buf); n < 32 {
+               x2Buf = append(zeroByteSlice[:32-n], x2Buf...)
+       }
+       if n := len(y2Buf); n < 32 {
+               y2Buf = append(zeroByteSlice[:32-n], y2Buf...)
+       }
+       c, ok := kdf(x2Buf, y2Buf, length)
+       if !ok {
+               return nil, errors.New("Decrypt: failed to decrypt")
+       }
+       for i := 0; i < length; i++ {
+               c[i] ^= data[i+96]
+       }
+       tm := []byte{}
+       tm = append(tm, x2Buf...)
+       tm = append(tm, c...)
+       tm = append(tm, y2Buf...)
+       h := sm3.Sm3Sum(tm)
+       if bytes.Compare(h, data[64:96]) != 0 {
+               return c, errors.New("Decrypt: failed to decrypt")
+       }
+       return c, nil
+}
+
+type zr struct {
+       io.Reader
+}
+
+func (z *zr) Read(dst []byte) (n int, err error) {
+       for i := range dst {
+               dst[i] = 0
+       }
+       return len(dst), nil
+}
+
+var zeroReader = &zr{}
+
+func getLastBit(a *big.Int) uint {
+       return a.Bit(0)
+}
+
+func Compress(a *PublicKey) []byte {
+       buf := []byte{}
+       yp := getLastBit(a.Y)
+       buf = append(buf, a.X.Bytes()...)
+       if n := len(a.X.Bytes()); n < 32 {
+               buf = append(zeroByteSlice[:(32-n)], buf...)
+       }
+       buf = append([]byte{byte(yp)}, buf...)
+       return buf
+}
+
+func Decompress(a []byte) *PublicKey {
+       var aa, xx, xx3 sm2P256FieldElement
+
+       P256Sm2()
+       x := new(big.Int).SetBytes(a[1:])
+       curve := sm2P256
+       sm2P256FromBig(&xx, x)
+       sm2P256Square(&xx3, &xx)       // x3 = x ^ 2
+       sm2P256Mul(&xx3, &xx3, &xx)    // x3 = x ^ 2 * x
+       sm2P256Mul(&aa, &curve.a, &xx) // a = a * x
+       sm2P256Add(&xx3, &xx3, &aa)
+       sm2P256Add(&xx3, &xx3, &curve.b)
+
+       y2 := sm2P256ToBig(&xx3)
+       y := new(big.Int).ModSqrt(y2, sm2P256.P)
+       if getLastBit(y) != uint(a[0]) {
+               y.Sub(sm2P256.P, y)
+       }
+       return &PublicKey{
+               Curve: P256Sm2(),
+               X:     x,
+               Y:     y,
+       }
+}