X-Git-Url: http://git.osdn.net/view?p=bytom%2Fvapor.git;a=blobdiff_plain;f=crypto%2Fsm2%2Fsm2.go;fp=crypto%2Fsm2%2Fsm2.go;h=0000000000000000000000000000000000000000;hp=872bc48fa7b386f526df4f6c0a0db0b20d71e1ee;hb=d09b7a78d44dc259725902b8141cdba0d716b121;hpb=ee01d543fdfe1fd0a4d548965c66f7923ea7b062 diff --git a/crypto/sm2/sm2.go b/crypto/sm2/sm2.go deleted file mode 100644 index 872bc48f..00000000 --- a/crypto/sm2/sm2.go +++ /dev/null @@ -1,509 +0,0 @@ -package sm2 - -// reference to ecdsa -import ( - "bytes" - "crypto" - "crypto/aes" - "crypto/cipher" - "crypto/elliptic" - "crypto/rand" - "crypto/sha512" - "encoding/asn1" - "encoding/binary" - "errors" - "io" - "math/big" - - "github.com/vapor/crypto/sm3" -) - -const ( - aesIV = "IV for CTR" -) - -type PublicKey struct { - elliptic.Curve - X, Y *big.Int -} - -type PrivateKey struct { - PublicKey - D *big.Int -} - -type sm2Signature struct { - R, S *big.Int -} - -// The SM2's private key contains the public key -func (priv *PrivateKey) Public() crypto.PublicKey { - return &priv.PublicKey -} - -func SignDigitToSignData(r, s *big.Int) ([]byte, error) { - return asn1.Marshal(sm2Signature{r, s}) -} - -func SignDataToSignDigit(sign []byte) (*big.Int, *big.Int, error) { - var sm2Sign sm2Signature - - _, err := asn1.Unmarshal(sign, &sm2Sign) - if err != nil { - return nil, nil, err - } - return sm2Sign.R, sm2Sign.S, nil -} - -// sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s -func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) { - r, s, err := Sign(priv, msg) - if err != nil { - return nil, err - } - return asn1.Marshal(sm2Signature{r, s}) -} - -func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) { - return Decrypt(priv, data) -} - -func (pub *PublicKey) Verify(msg []byte, sign []byte) bool { - var sm2Sign sm2Signature - - _, err := asn1.Unmarshal(sign, &sm2Sign) - if err != nil { - return false - } - return Verify(pub, msg, sm2Sign.R, sm2Sign.S) -} - -func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) { - return Encrypt(pub, data) -} - -var one = new(big.Int).SetInt64(1) - -func intToBytes(x int) []byte { - var buf = make([]byte, 4) - - binary.BigEndian.PutUint32(buf, uint32(x)) - return buf -} - -func kdf(x, y []byte, length int) ([]byte, bool) { - var c []byte - - ct := 1 - h := sm3.New() - x = append(x, y...) - for i, j := 0, (length+31)/32; i < j; i++ { - h.Reset() - h.Write(x) - h.Write(intToBytes(ct)) - hash := h.Sum(nil) - if i+1 == j && length%32 != 0 { - c = append(c, hash[:length%32]...) - } else { - c = append(c, hash...) - } - ct++ - } - for i := 0; i < length; i++ { - if c[i] != 0 { - return c, true - } - } - return c, false -} - -func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) { - params := c.Params() - b := make([]byte, params.BitSize/8+8) - _, err = io.ReadFull(rand, b) - if err != nil { - return - } - k = new(big.Int).SetBytes(b) - n := new(big.Int).Sub(params.N, one) - k.Mod(k, n) - k.Add(k, one) - return -} - -func GenerateKey() (*PrivateKey, error) { - c := P256Sm2() - k, err := randFieldElement(c, rand.Reader) - if err != nil { - return nil, err - } - priv := new(PrivateKey) - priv.PublicKey.Curve = c - priv.D = k - priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes()) - return priv, nil -} - -var errZeroParam = errors.New("zero parameter") - -func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) { - entropylen := (priv.Curve.Params().BitSize + 7) / 16 - if entropylen > 32 { - entropylen = 32 - } - entropy := make([]byte, entropylen) - _, err = io.ReadFull(rand.Reader, entropy) - if err != nil { - return - } - - // Initialize an SHA-512 hash context; digest ... - md := sha512.New() - md.Write(priv.D.Bytes()) // the private key, - md.Write(entropy) // the entropy, - md.Write(hash) // and the input hash; - key := md.Sum(nil)[:32] // and compute ChopMD-256(SHA-512), - // which is an indifferentiable MAC. - - // Create an AES-CTR instance to use as a CSPRNG. - block, err := aes.NewCipher(key) - if err != nil { - return nil, nil, err - } - - // Create a CSPRNG that xors a stream of zeros with - // the output of the AES-CTR instance. - csprng := cipher.StreamReader{ - R: zeroReader, - S: cipher.NewCTR(block, []byte(aesIV)), - } - - // See [NSA] 3.4.1 - c := priv.PublicKey.Curve - N := c.Params().N - if N.Sign() == 0 { - return nil, nil, errZeroParam - } - var k *big.Int - e := new(big.Int).SetBytes(hash) - for { // 调整算法细节以实现SM2 - for { - k, err = randFieldElement(c, csprng) - if err != nil { - r = nil - return - } - r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) - r.Add(r, e) - r.Mod(r, N) - if r.Sign() != 0 { - break - } - if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 { - break - } - } - rD := new(big.Int).Mul(priv.D, r) - s = new(big.Int).Sub(k, rD) - d1 := new(big.Int).Add(priv.D, one) - d1Inv := new(big.Int).ModInverse(d1, N) - s.Mul(s, d1Inv) - s.Mod(s, N) - if s.Sign() != 0 { - break - } - } - return -} - -func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool { - c := pub.Curve - N := c.Params().N - - if r.Sign() <= 0 || s.Sign() <= 0 { - return false - } - if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { - return false - } - - // 调整算法细节以实现SM2 - t := new(big.Int).Add(r, s) - t.Mod(t, N) - if t.Sign() == 0 { - return false - } - - var x *big.Int - x1, y1 := c.ScalarBaseMult(s.Bytes()) - x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) - x, _ = c.Add(x1, y1, x2, y2) - - e := new(big.Int).SetBytes(hash) - x.Add(x, e) - x.Mod(x, N) - return x.Cmp(r) == 0 -} - -func Sm2Sign(priv *PrivateKey, msg, uid []byte) (r, s *big.Int, err error) { - za, err := ZA(&priv.PublicKey, uid) - if err != nil { - return nil, nil, err - } - e, err := msgHash(za, msg) - if err != nil { - return nil, nil, err - } - c := priv.PublicKey.Curve - N := c.Params().N - if N.Sign() == 0 { - return nil, nil, errZeroParam - } - var k *big.Int - for { // 调整算法细节以实现SM2 - for { - k, err = randFieldElement(c, rand.Reader) - if err != nil { - r = nil - return - } - r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) - r.Add(r, e) - r.Mod(r, N) - if r.Sign() != 0 { - break - } - if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 { - break - } - } - rD := new(big.Int).Mul(priv.D, r) - s = new(big.Int).Sub(k, rD) - d1 := new(big.Int).Add(priv.D, one) - d1Inv := new(big.Int).ModInverse(d1, N) - s.Mul(s, d1Inv) - s.Mod(s, N) - if s.Sign() != 0 { - break - } - } - return -} - -func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool { - c := pub.Curve - N := c.Params().N - one := new(big.Int).SetInt64(1) - if r.Cmp(one) < 0 || s.Cmp(one) < 0 { - return false - } - if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { - return false - } - za, err := ZA(pub, uid) - if err != nil { - return false - } - e, err := msgHash(za, msg) - if err != nil { - return false - } - t := new(big.Int).Add(r, s) - t.Mod(t, N) - if t.Sign() == 0 { - return false - } - var x *big.Int - x1, y1 := c.ScalarBaseMult(s.Bytes()) - x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) - x, _ = c.Add(x1, y1, x2, y2) - - x.Add(x, e) - x.Mod(x, N) - return x.Cmp(r) == 0 -} - -func msgHash(za, msg []byte) (*big.Int, error) { - e := sm3.New() - e.Write(za) - e.Write(msg) - return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil -} - -// ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA) -func ZA(pub *PublicKey, uid []byte) ([]byte, error) { - za := sm3.New() - uidLen := len(uid) - if uidLen >= 8192 { - return []byte{}, errors.New("SM2: uid too large") - } - Entla := uint16(8 * uidLen) - za.Write([]byte{byte((Entla >> 8) & 0xFF)}) - za.Write([]byte{byte(Entla & 0xFF)}) - za.Write(uid) - za.Write(sm2P256ToBig(&sm2P256.a).Bytes()) - za.Write(sm2P256.B.Bytes()) - za.Write(sm2P256.Gx.Bytes()) - za.Write(sm2P256.Gy.Bytes()) - - xBuf := pub.X.Bytes() - yBuf := pub.Y.Bytes() - if n := len(xBuf); n < 32 { - xBuf = append(zeroByteSlice[:32-n], xBuf...) - } - za.Write(xBuf) - za.Write(yBuf) - return za.Sum(nil)[:32], nil -} - -// 32byte -var zeroByteSlice = []byte{ - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0, 0, -} - -/* - * sm2密文结构如下: - * x - * y - * hash - * CipherText - */ -func Encrypt(pub *PublicKey, data []byte) ([]byte, error) { - length := len(data) - for { - c := []byte{} - curve := pub.Curve - k, err := randFieldElement(curve, rand.Reader) - if err != nil { - return nil, err - } - x1, y1 := curve.ScalarBaseMult(k.Bytes()) - x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) - x1Buf := x1.Bytes() - y1Buf := y1.Bytes() - x2Buf := x2.Bytes() - y2Buf := y2.Bytes() - if n := len(x1Buf); n < 32 { - x1Buf = append(zeroByteSlice[:32-n], x1Buf...) - } - if n := len(y1Buf); n < 32 { - y1Buf = append(zeroByteSlice[:32-n], y1Buf...) - } - if n := len(x2Buf); n < 32 { - x2Buf = append(zeroByteSlice[:32-n], x2Buf...) - } - if n := len(y2Buf); n < 32 { - y2Buf = append(zeroByteSlice[:32-n], y2Buf...) - } - c = append(c, x1Buf...) // x分量 - c = append(c, y1Buf...) // y分量 - tm := []byte{} - tm = append(tm, x2Buf...) - tm = append(tm, data...) - tm = append(tm, y2Buf...) - h := sm3.Sm3Sum(tm) - c = append(c, h...) - ct, ok := kdf(x2Buf, y2Buf, length) // 密文 - if !ok { - continue - } - c = append(c, ct...) - for i := 0; i < length; i++ { - c[96+i] ^= data[i] - } - return append([]byte{0x04}, c...), nil - } -} - -func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) { - data = data[1:] - length := len(data) - 96 - curve := priv.Curve - x := new(big.Int).SetBytes(data[:32]) - y := new(big.Int).SetBytes(data[32:64]) - x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes()) - x2Buf := x2.Bytes() - y2Buf := y2.Bytes() - if n := len(x2Buf); n < 32 { - x2Buf = append(zeroByteSlice[:32-n], x2Buf...) - } - if n := len(y2Buf); n < 32 { - y2Buf = append(zeroByteSlice[:32-n], y2Buf...) - } - c, ok := kdf(x2Buf, y2Buf, length) - if !ok { - return nil, errors.New("Decrypt: failed to decrypt") - } - for i := 0; i < length; i++ { - c[i] ^= data[i+96] - } - tm := []byte{} - tm = append(tm, x2Buf...) - tm = append(tm, c...) - tm = append(tm, y2Buf...) - h := sm3.Sm3Sum(tm) - if bytes.Compare(h, data[64:96]) != 0 { - return c, errors.New("Decrypt: failed to decrypt") - } - return c, nil -} - -type zr struct { - io.Reader -} - -func (z *zr) Read(dst []byte) (n int, err error) { - for i := range dst { - dst[i] = 0 - } - return len(dst), nil -} - -var zeroReader = &zr{} - -func getLastBit(a *big.Int) uint { - return a.Bit(0) -} - -func Compress(a *PublicKey) []byte { - buf := []byte{} - yp := getLastBit(a.Y) - buf = append(buf, a.X.Bytes()...) - if n := len(a.X.Bytes()); n < 32 { - buf = append(zeroByteSlice[:(32-n)], buf...) - } - buf = append([]byte{byte(yp)}, buf...) - return buf -} - -func Decompress(a []byte) *PublicKey { - var aa, xx, xx3 sm2P256FieldElement - - P256Sm2() - x := new(big.Int).SetBytes(a[1:]) - curve := sm2P256 - sm2P256FromBig(&xx, x) - sm2P256Square(&xx3, &xx) // x3 = x ^ 2 - sm2P256Mul(&xx3, &xx3, &xx) // x3 = x ^ 2 * x - sm2P256Mul(&aa, &curve.a, &xx) // a = a * x - sm2P256Add(&xx3, &xx3, &aa) - sm2P256Add(&xx3, &xx3, &curve.b) - - y2 := sm2P256ToBig(&xx3) - y := new(big.Int).ModSqrt(y2, sm2P256.P) - if getLastBit(y) != uint(a[0]) { - y.Sub(sm2P256.P, y) - } - return &PublicKey{ - Curve: P256Sm2(), - X: x, - Y: y, - } -}