From dae1b15e92f95c6750c8014d8c90a6e169980bdd Mon Sep 17 00:00:00 2001 From: Chengcheng Zhang <943420582@qq.com> Date: Fri, 27 Jul 2018 14:44:50 +0800 Subject: [PATCH] add sm2 (#1193) * add sm2 * fix import --- crypto/sm2/cert_pool.go | 229 +++++ crypto/sm2/p256.go | 1056 ++++++++++++++++++++ crypto/sm2/pkcs1.go | 132 +++ crypto/sm2/pkcs8.go | 488 +++++++++ crypto/sm2/sm2.go | 524 ++++++++++ crypto/sm2/sm2_test.go | 219 ++++ crypto/sm2/verify.go | 568 +++++++++++ crypto/sm2/x509.go | 2527 +++++++++++++++++++++++++++++++++++++++++++++++ crypto/sm3/sm3_test.go | 1 + 9 files changed, 5744 insertions(+) create mode 100644 crypto/sm2/cert_pool.go create mode 100644 crypto/sm2/p256.go create mode 100644 crypto/sm2/pkcs1.go create mode 100644 crypto/sm2/pkcs8.go create mode 100644 crypto/sm2/sm2.go create mode 100644 crypto/sm2/sm2_test.go create mode 100644 crypto/sm2/verify.go create mode 100644 crypto/sm2/x509.go diff --git a/crypto/sm2/cert_pool.go b/crypto/sm2/cert_pool.go new file mode 100644 index 00000000..8230c4d1 --- /dev/null +++ b/crypto/sm2/cert_pool.go @@ -0,0 +1,229 @@ +/* +Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sm2 + +import ( + "encoding/pem" + "errors" + "io/ioutil" + "os" + "runtime" + "sync" +) + +// Possible certificate files; stop after finding one. +var certFiles = []string{ + "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu/Gentoo etc. + "/etc/pki/tls/certs/ca-bundle.crt", // Fedora/RHEL 6 + "/etc/ssl/ca-bundle.pem", // OpenSUSE + "/etc/pki/tls/cacert.pem", // OpenELEC + "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", // CentOS/RHEL 7 +} + +// CertPool is a set of certificates. +type CertPool struct { + bySubjectKeyId map[string][]int + byName map[string][]int + certs []*Certificate +} + +// NewCertPool returns a new, empty CertPool. +func NewCertPool() *CertPool { + return &CertPool{ + bySubjectKeyId: make(map[string][]int), + byName: make(map[string][]int), + } +} + +// Possible directories with certificate files; stop after successfully +// reading at least one file from a directory. +var certDirectories = []string{ + "/etc/ssl/certs", // SLES10/SLES11, https://golang.org/issue/12139 + "/system/etc/security/cacerts", // Android +} + +var ( + once sync.Once + systemRoots *CertPool + systemRootsErr error +) + +func systemRootsPool() *CertPool { + once.Do(initSystemRoots) + return systemRoots +} + +func initSystemRoots() { + systemRoots, systemRootsErr = loadSystemRoots() +} + +func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) { + return nil, nil +} + +func loadSystemRoots() (*CertPool, error) { + roots := NewCertPool() + var firstErr error + for _, file := range certFiles { + data, err := ioutil.ReadFile(file) + if err == nil { + roots.AppendCertsFromPEM(data) + return roots, nil + } + if firstErr == nil && !os.IsNotExist(err) { + firstErr = err + } + } + + for _, directory := range certDirectories { + fis, err := ioutil.ReadDir(directory) + if err != nil { + if firstErr == nil && !os.IsNotExist(err) { + firstErr = err + } + continue + } + rootsAdded := false + for _, fi := range fis { + data, err := ioutil.ReadFile(directory + "/" + fi.Name()) + if err == nil && roots.AppendCertsFromPEM(data) { + rootsAdded = true + } + } + if rootsAdded { + return roots, nil + } + } + + return nil, firstErr +} + +// SystemCertPool returns a copy of the system cert pool. +// +// Any mutations to the returned pool are not written to disk and do +// not affect any other pool. +func SystemCertPool() (*CertPool, error) { + if runtime.GOOS == "windows" { + // Issue 16736, 18609: + return nil, errors.New("crypto/x509: system root pool is not available on Windows") + } + + return loadSystemRoots() +} + +// findVerifiedParents attempts to find certificates in s which have signed the +// given certificate. If any candidates were rejected then errCert will be set +// to one of them, arbitrarily, and err will contain the reason that it was +// rejected. +func (s *CertPool) findVerifiedParents(cert *Certificate) (parents []int, errCert *Certificate, err error) { + if s == nil { + return + } + var candidates []int + + if len(cert.AuthorityKeyId) > 0 { + candidates = s.bySubjectKeyId[string(cert.AuthorityKeyId)] + } + if len(candidates) == 0 { + candidates = s.byName[string(cert.RawIssuer)] + } + + for _, c := range candidates { + if err = cert.CheckSignatureFrom(s.certs[c]); err == nil { + parents = append(parents, c) + } else { + errCert = s.certs[c] + } + } + + return +} + +func (s *CertPool) contains(cert *Certificate) bool { + if s == nil { + return false + } + + candidates := s.byName[string(cert.RawSubject)] + for _, c := range candidates { + if s.certs[c].Equal(cert) { + return true + } + } + + return false +} + +// AddCert adds a certificate to a pool. +func (s *CertPool) AddCert(cert *Certificate) { + if cert == nil { + panic("adding nil Certificate to CertPool") + } + + // Check that the certificate isn't being added twice. + if s.contains(cert) { + return + } + + n := len(s.certs) + s.certs = append(s.certs, cert) + + if len(cert.SubjectKeyId) > 0 { + keyId := string(cert.SubjectKeyId) + s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], n) + } + name := string(cert.RawSubject) + s.byName[name] = append(s.byName[name], n) +} + +// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates. +// It appends any certificates found to s and reports whether any certificates +// were successfully parsed. +// +// On many Linux systems, /etc/ssl/cert.pem will contain the system wide set +// of root CAs in a format suitable for this function. +func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + cert, err := ParseCertificate(block.Bytes) + if err != nil { + continue + } + + s.AddCert(cert) + ok = true + } + + return +} + +// Subjects returns a list of the DER-encoded subjects of +// all of the certificates in the pool. +func (s *CertPool) Subjects() [][]byte { + res := make([][]byte, len(s.certs)) + for i, c := range s.certs { + res[i] = c.RawSubject + } + return res +} diff --git a/crypto/sm2/p256.go b/crypto/sm2/p256.go new file mode 100644 index 00000000..a7293238 --- /dev/null +++ b/crypto/sm2/p256.go @@ -0,0 +1,1056 @@ +/* +Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sm2 + +import ( + "crypto/elliptic" + "math/big" + "sync" +) + +/** 学习标准库p256的优化方法实现sm2的快速版本 + * 标准库的p256的代码实现有些晦涩难懂,当然sm2的同样如此,有兴趣的大家可以研究研究,最后神兽压阵。。。 + * + * ━━━━━━animal━━━━━━ + *    ┏┓   ┏┓ + *   ┏┛┻━━━┛┻┓ + *   ┃       ┃ + *   ┃   ━   ┃ + *   ┃ ┳┛ ┗┳ ┃ + *   ┃       ┃ + *   ┃   ┻   ┃ + *   ┃       ┃ + *   ┗━┓   ┏━┛ + *    ┃   ┃ + *    ┃   ┃ + *    ┃   ┗━━━┓ + *  ┃     ┣┓ + *   ┃     ┏┛ + *    ┗┓┓┏━┳┓┏┛ + *    ┃┫┫ ┃┫┫ + *    ┗┻┛ ┗┻┛ + * + * ━━━━━Kawaii ━━━━━━ + */ + +type sm2P256Curve struct { + RInverse *big.Int + *elliptic.CurveParams + a, b, gx, gy sm2P256FieldElement +} + +var initonce sync.Once +var sm2P256 sm2P256Curve + +type sm2P256FieldElement [9]uint32 +type sm2P256LargeFieldElement [17]uint64 + +const ( + bottom28Bits = 0xFFFFFFF + bottom29Bits = 0x1FFFFFFF +) + +func initP256Sm2() { + sm2P256.CurveParams = &elliptic.CurveParams{Name: "SM2-P-256"} // sm2 + A, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC", 16) + //SM2椭 椭 圆 曲 线 公 钥 密 码 算 法 推 荐 曲 线 参 数 + sm2P256.P, _ = new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16) + sm2P256.N, _ = new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", 16) + sm2P256.B, _ = new(big.Int).SetString("28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", 16) + sm2P256.Gx, _ = new(big.Int).SetString("32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", 16) + sm2P256.Gy, _ = new(big.Int).SetString("BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", 16) + sm2P256.RInverse, _ = new(big.Int).SetString("7ffffffd80000002fffffffe000000017ffffffe800000037ffffffc80000002", 16) + sm2P256.BitSize = 256 + sm2P256FromBig(&sm2P256.a, A) + sm2P256FromBig(&sm2P256.gx, sm2P256.Gx) + sm2P256FromBig(&sm2P256.gy, sm2P256.Gy) + sm2P256FromBig(&sm2P256.b, sm2P256.B) +} + +func P256Sm2() elliptic.Curve { + initonce.Do(initP256Sm2) + return sm2P256 +} + +func (curve sm2P256Curve) Params() *elliptic.CurveParams { + return sm2P256.CurveParams +} + +// y^2 = x^3 + ax + b +func (curve sm2P256Curve) IsOnCurve(X, Y *big.Int) bool { + var a, x, y, y2, x3 sm2P256FieldElement + + sm2P256FromBig(&x, X) + sm2P256FromBig(&y, Y) + + sm2P256Square(&x3, &x) // x3 = x ^ 2 + sm2P256Mul(&x3, &x3, &x) // x3 = x ^ 2 * x + sm2P256Mul(&a, &curve.a, &x) // a = a * x + sm2P256Add(&x3, &x3, &a) + sm2P256Add(&x3, &x3, &curve.b) + + sm2P256Square(&y2, &y) // y2 = y ^ 2 + return sm2P256ToBig(&x3).Cmp(sm2P256ToBig(&y2)) == 0 +} + +func zForAffine(x, y *big.Int) *big.Int { + z := new(big.Int) + if x.Sign() != 0 || y.Sign() != 0 { + z.SetInt64(1) + } + return z +} + +func (curve sm2P256Curve) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) { + var X1, Y1, Z1, X2, Y2, Z2, X3, Y3, Z3 sm2P256FieldElement + + z1 := zForAffine(x1, y1) + z2 := zForAffine(x2, y2) + sm2P256FromBig(&X1, x1) + sm2P256FromBig(&Y1, y1) + sm2P256FromBig(&Z1, z1) + sm2P256FromBig(&X2, x2) + sm2P256FromBig(&Y2, y2) + sm2P256FromBig(&Z2, z2) + sm2P256PointAdd(&X1, &Y1, &Z1, &X2, &Y2, &Z2, &X3, &Y3, &Z3) + return sm2P256ToAffine(&X3, &Y3, &Z3) +} + +func (curve sm2P256Curve) Double(x1, y1 *big.Int) (*big.Int, *big.Int) { + var X1, Y1, Z1 sm2P256FieldElement + + z1 := zForAffine(x1, y1) + sm2P256FromBig(&X1, x1) + sm2P256FromBig(&Y1, y1) + sm2P256FromBig(&Z1, z1) + sm2P256PointDouble(&X1, &Y1, &Z1, &X1, &Y1, &Z1) + return sm2P256ToAffine(&X1, &Y1, &Z1) +} + +func (curve sm2P256Curve) ScalarMult(x1, y1 *big.Int, k []byte) (*big.Int, *big.Int) { + var scalarReversed [32]byte + var X, Y, Z, X1, Y1 sm2P256FieldElement + + sm2P256FromBig(&X1, x1) + sm2P256FromBig(&Y1, y1) + sm2P256GetScalar(&scalarReversed, k) + sm2P256ScalarMult(&X, &Y, &Z, &X1, &Y1, &scalarReversed) + return sm2P256ToAffine(&X, &Y, &Z) +} + +func (curve sm2P256Curve) ScalarBaseMult(k []byte) (*big.Int, *big.Int) { + var scalarReversed [32]byte + var X, Y, Z sm2P256FieldElement + + sm2P256GetScalar(&scalarReversed, k) + sm2P256ScalarBaseMult(&X, &Y, &Z, &scalarReversed) + return sm2P256ToAffine(&X, &Y, &Z) +} + +var sm2P256Precomputed = [9 * 2 * 15 * 2]uint32{ + 0x830053d, 0x328990f, 0x6c04fe1, 0xc0f72e5, 0x1e19f3c, 0x666b093, 0x175a87b, 0xec38276, 0x222cf4b, + 0x185a1bba, 0x354e593, 0x1295fac1, 0xf2bc469, 0x47c60fa, 0xc19b8a9, 0xf63533e, 0x903ae6b, 0xc79acba, + 0x15b061a4, 0x33e020b, 0xdffb34b, 0xfcf2c8, 0x16582e08, 0x262f203, 0xfb34381, 0xa55452, 0x604f0ff, + 0x41f1f90, 0xd64ced2, 0xee377bf, 0x75f05f0, 0x189467ae, 0xe2244e, 0x1e7700e8, 0x3fbc464, 0x9612d2e, + 0x1341b3b8, 0xee84e23, 0x1edfa5b4, 0x14e6030, 0x19e87be9, 0x92f533c, 0x1665d96c, 0x226653e, 0xa238d3e, + 0xf5c62c, 0x95bb7a, 0x1f0e5a41, 0x28789c3, 0x1f251d23, 0x8726609, 0xe918910, 0x8096848, 0xf63d028, + 0x152296a1, 0x9f561a8, 0x14d376fb, 0x898788a, 0x61a95fb, 0xa59466d, 0x159a003d, 0x1ad1698, 0x93cca08, + 0x1b314662, 0x706e006, 0x11ce1e30, 0x97b710, 0x172fbc0d, 0x8f50158, 0x11c7ffe7, 0xd182cce, 0xc6ad9e8, + 0x12ea31b2, 0xc4e4f38, 0x175b0d96, 0xec06337, 0x75a9c12, 0xb001fdf, 0x93e82f5, 0x34607de, 0xb8035ed, + 0x17f97924, 0x75cf9e6, 0xdceaedd, 0x2529924, 0x1a10c5ff, 0xb1a54dc, 0x19464d8, 0x2d1997, 0xde6a110, + 0x1e276ee5, 0x95c510c, 0x1aca7c7a, 0xfe48aca, 0x121ad4d9, 0xe4132c6, 0x8239b9d, 0x40ea9cd, 0x816c7b, + 0x632d7a4, 0xa679813, 0x5911fcf, 0x82b0f7c, 0x57b0ad5, 0xbef65, 0xd541365, 0x7f9921f, 0xc62e7a, + 0x3f4b32d, 0x58e50e1, 0x6427aed, 0xdcdda67, 0xe8c2d3e, 0x6aa54a4, 0x18df4c35, 0x49a6a8e, 0x3cd3d0c, + 0xd7adf2, 0xcbca97, 0x1bda5f2d, 0x3258579, 0x606b1e6, 0x6fc1b5b, 0x1ac27317, 0x503ca16, 0xa677435, + 0x57bc73, 0x3992a42, 0xbab987b, 0xfab25eb, 0x128912a4, 0x90a1dc4, 0x1402d591, 0x9ffbcfc, 0xaa48856, + 0x7a7c2dc, 0xcefd08a, 0x1b29bda6, 0xa785641, 0x16462d8c, 0x76241b7, 0x79b6c3b, 0x204ae18, 0xf41212b, + 0x1f567a4d, 0xd6ce6db, 0xedf1784, 0x111df34, 0x85d7955, 0x55fc189, 0x1b7ae265, 0xf9281ac, 0xded7740, + 0xf19468b, 0x83763bb, 0x8ff7234, 0x3da7df8, 0x9590ac3, 0xdc96f2a, 0x16e44896, 0x7931009, 0x99d5acc, + 0x10f7b842, 0xaef5e84, 0xc0310d7, 0xdebac2c, 0x2a7b137, 0x4342344, 0x19633649, 0x3a10624, 0x4b4cb56, + 0x1d809c59, 0xac007f, 0x1f0f4bcd, 0xa1ab06e, 0xc5042cf, 0x82c0c77, 0x76c7563, 0x22c30f3, 0x3bf1568, + 0x7a895be, 0xfcca554, 0x12e90e4c, 0x7b4ab5f, 0x13aeb76b, 0x5887e2c, 0x1d7fe1e3, 0x908c8e3, 0x95800ee, + 0xb36bd54, 0xf08905d, 0x4e73ae8, 0xf5a7e48, 0xa67cb0, 0x50e1067, 0x1b944a0a, 0xf29c83a, 0xb23cfb9, + 0xbe1db1, 0x54de6e8, 0xd4707f2, 0x8ebcc2d, 0x2c77056, 0x1568ce4, 0x15fcc849, 0x4069712, 0xe2ed85f, + 0x2c5ff09, 0x42a6929, 0x628e7ea, 0xbd5b355, 0xaf0bd79, 0xaa03699, 0xdb99816, 0x4379cef, 0x81d57b, + 0x11237f01, 0xe2a820b, 0xfd53b95, 0x6beb5ee, 0x1aeb790c, 0xe470d53, 0x2c2cfee, 0x1c1d8d8, 0xa520fc4, + 0x1518e034, 0xa584dd4, 0x29e572b, 0xd4594fc, 0x141a8f6f, 0x8dfccf3, 0x5d20ba3, 0x2eb60c3, 0x9f16eb0, + 0x11cec356, 0xf039f84, 0x1b0990c1, 0xc91e526, 0x10b65bae, 0xf0616e8, 0x173fa3ff, 0xec8ccf9, 0xbe32790, + 0x11da3e79, 0xe2f35c7, 0x908875c, 0xdacf7bd, 0x538c165, 0x8d1487f, 0x7c31aed, 0x21af228, 0x7e1689d, + 0xdfc23ca, 0x24f15dc, 0x25ef3c4, 0x35248cd, 0x99a0f43, 0xa4b6ecc, 0xd066b3, 0x2481152, 0x37a7688, + 0x15a444b6, 0xb62300c, 0x4b841b, 0xa655e79, 0xd53226d, 0xbeb348a, 0x127f3c2, 0xb989247, 0x71a277d, + 0x19e9dfcb, 0xb8f92d0, 0xe2d226c, 0x390a8b0, 0x183cc462, 0x7bd8167, 0x1f32a552, 0x5e02db4, 0xa146ee9, + 0x1a003957, 0x1c95f61, 0x1eeec155, 0x26f811f, 0xf9596ba, 0x3082bfb, 0x96df083, 0x3e3a289, 0x7e2d8be, + 0x157a63e0, 0x99b8941, 0x1da7d345, 0xcc6cd0, 0x10beed9a, 0x48e83c0, 0x13aa2e25, 0x7cad710, 0x4029988, + 0x13dfa9dd, 0xb94f884, 0x1f4adfef, 0xb88543, 0x16f5f8dc, 0xa6a67f4, 0x14e274e2, 0x5e56cf4, 0x2f24ef, + 0x1e9ef967, 0xfe09bad, 0xfe079b3, 0xcc0ae9e, 0xb3edf6d, 0x3e961bc, 0x130d7831, 0x31043d6, 0xba986f9, + 0x1d28055, 0x65240ca, 0x4971fa3, 0x81b17f8, 0x11ec34a5, 0x8366ddc, 0x1471809, 0xfa5f1c6, 0xc911e15, + 0x8849491, 0xcf4c2e2, 0x14471b91, 0x39f75be, 0x445c21e, 0xf1585e9, 0x72cc11f, 0x4c79f0c, 0xe5522e1, + 0x1874c1ee, 0x4444211, 0x7914884, 0x3d1b133, 0x25ba3c, 0x4194f65, 0x1c0457ef, 0xac4899d, 0xe1fa66c, + 0x130a7918, 0x9b8d312, 0x4b1c5c8, 0x61ccac3, 0x18c8aa6f, 0xe93cb0a, 0xdccb12c, 0xde10825, 0x969737d, + 0xf58c0c3, 0x7cee6a9, 0xc2c329a, 0xc7f9ed9, 0x107b3981, 0x696a40e, 0x152847ff, 0x4d88754, 0xb141f47, + 0x5a16ffe, 0x3a7870a, 0x18667659, 0x3b72b03, 0xb1c9435, 0x9285394, 0xa00005a, 0x37506c, 0x2edc0bb, + 0x19afe392, 0xeb39cac, 0x177ef286, 0xdf87197, 0x19f844ed, 0x31fe8, 0x15f9bfd, 0x80dbec, 0x342e96e, + 0x497aced, 0xe88e909, 0x1f5fa9ba, 0x530a6ee, 0x1ef4e3f1, 0x69ffd12, 0x583006d, 0x2ecc9b1, 0x362db70, + 0x18c7bdc5, 0xf4bb3c5, 0x1c90b957, 0xf067c09, 0x9768f2b, 0xf73566a, 0x1939a900, 0x198c38a, 0x202a2a1, + 0x4bbf5a6, 0x4e265bc, 0x1f44b6e7, 0x185ca49, 0xa39e81b, 0x24aff5b, 0x4acc9c2, 0x638bdd3, 0xb65b2a8, + 0x6def8be, 0xb94537a, 0x10b81dee, 0xe00ec55, 0x2f2cdf7, 0xc20622d, 0x2d20f36, 0xe03c8c9, 0x898ea76, + 0x8e3921b, 0x8905bff, 0x1e94b6c8, 0xee7ad86, 0x154797f2, 0xa620863, 0x3fbd0d9, 0x1f3caab, 0x30c24bd, + 0x19d3892f, 0x59c17a2, 0x1ab4b0ae, 0xf8714ee, 0x90c4098, 0xa9c800d, 0x1910236b, 0xea808d3, 0x9ae2f31, + 0x1a15ad64, 0xa48c8d1, 0x184635a4, 0xb725ef1, 0x11921dcc, 0x3f866df, 0x16c27568, 0xbdf580a, 0xb08f55c, + 0x186ee1c, 0xb1627fa, 0x34e82f6, 0x933837e, 0xf311be5, 0xfedb03b, 0x167f72cd, 0xa5469c0, 0x9c82531, + 0xb92a24b, 0x14fdc8b, 0x141980d1, 0xbdc3a49, 0x7e02bb1, 0xaf4e6dd, 0x106d99e1, 0xd4616fc, 0x93c2717, + 0x1c0a0507, 0xc6d5fed, 0x9a03d8b, 0xa1d22b0, 0x127853e3, 0xc4ac6b8, 0x1a048cf7, 0x9afb72c, 0x65d485d, + 0x72d5998, 0xe9fa744, 0xe49e82c, 0x253cf80, 0x5f777ce, 0xa3799a5, 0x17270cbb, 0xc1d1ef0, 0xdf74977, + 0x114cb859, 0xfa8e037, 0xb8f3fe5, 0xc734cc6, 0x70d3d61, 0xeadac62, 0x12093dd0, 0x9add67d, 0x87200d6, + 0x175bcbb, 0xb29b49f, 0x1806b79c, 0x12fb61f, 0x170b3a10, 0x3aaf1cf, 0xa224085, 0x79d26af, 0x97759e2, + 0x92e19f1, 0xb32714d, 0x1f00d9f1, 0xc728619, 0x9e6f627, 0xe745e24, 0x18ea4ace, 0xfc60a41, 0x125f5b2, + 0xc3cf512, 0x39ed486, 0xf4d15fa, 0xf9167fd, 0x1c1f5dd5, 0xc21a53e, 0x1897930, 0x957a112, 0x21059a0, + 0x1f9e3ddc, 0xa4dfced, 0x8427f6f, 0x726fbe7, 0x1ea658f8, 0x2fdcd4c, 0x17e9b66f, 0xb2e7c2e, 0x39923bf, + 0x1bae104, 0x3973ce5, 0xc6f264c, 0x3511b84, 0x124195d7, 0x11996bd, 0x20be23d, 0xdc437c4, 0x4b4f16b, + 0x11902a0, 0x6c29cc9, 0x1d5ffbe6, 0xdb0b4c7, 0x10144c14, 0x2f2b719, 0x301189, 0x2343336, 0xa0bf2ac, +} + +func sm2P256GetScalar(b *[32]byte, a []byte) { + var scalarBytes []byte + + n := new(big.Int).SetBytes(a) + if n.Cmp(sm2P256.N) >= 0 { + n.Mod(n, sm2P256.N) + scalarBytes = n.Bytes() + } else { + scalarBytes = a + } + for i, v := range scalarBytes { + b[len(scalarBytes)-(1+i)] = v + } +} + +func sm2P256PointAddMixed(xOut, yOut, zOut, x1, y1, z1, x2, y2 *sm2P256FieldElement) { + var z1z1, z1z1z1, s2, u2, h, i, j, r, rr, v, tmp sm2P256FieldElement + + sm2P256Square(&z1z1, z1) + sm2P256Add(&tmp, z1, z1) + + sm2P256Mul(&u2, x2, &z1z1) + sm2P256Mul(&z1z1z1, z1, &z1z1) + sm2P256Mul(&s2, y2, &z1z1z1) + sm2P256Sub(&h, &u2, x1) + sm2P256Add(&i, &h, &h) + sm2P256Square(&i, &i) + sm2P256Mul(&j, &h, &i) + sm2P256Sub(&r, &s2, y1) + sm2P256Add(&r, &r, &r) + sm2P256Mul(&v, x1, &i) + + sm2P256Mul(zOut, &tmp, &h) + sm2P256Square(&rr, &r) + sm2P256Sub(xOut, &rr, &j) + sm2P256Sub(xOut, xOut, &v) + sm2P256Sub(xOut, xOut, &v) + + sm2P256Sub(&tmp, &v, xOut) + sm2P256Mul(yOut, &tmp, &r) + sm2P256Mul(&tmp, y1, &j) + sm2P256Sub(yOut, yOut, &tmp) + sm2P256Sub(yOut, yOut, &tmp) +} + +// sm2P256CopyConditional sets out=in if mask = 0xffffffff in constant time. +// +// On entry: mask is either 0 or 0xffffffff. +func sm2P256CopyConditional(out, in *sm2P256FieldElement, mask uint32) { + for i := 0; i < 9; i++ { + tmp := mask & (in[i] ^ out[i]) + out[i] ^= tmp + } +} + +// sm2P256SelectAffinePoint sets {out_x,out_y} to the index'th entry of table. +// On entry: index < 16, table[0] must be zero. +func sm2P256SelectAffinePoint(xOut, yOut *sm2P256FieldElement, table []uint32, index uint32) { + for i := range xOut { + xOut[i] = 0 + } + for i := range yOut { + yOut[i] = 0 + } + + for i := uint32(1); i < 16; i++ { + mask := i ^ index + mask |= mask >> 2 + mask |= mask >> 1 + mask &= 1 + mask-- + for j := range xOut { + xOut[j] |= table[0] & mask + table = table[1:] + } + for j := range yOut { + yOut[j] |= table[0] & mask + table = table[1:] + } + } +} + +// sm2P256SelectJacobianPoint sets {out_x,out_y,out_z} to the index'th entry of +// table. +// On entry: index < 16, table[0] must be zero. +func sm2P256SelectJacobianPoint(xOut, yOut, zOut *sm2P256FieldElement, table *[16][3]sm2P256FieldElement, index uint32) { + for i := range xOut { + xOut[i] = 0 + } + for i := range yOut { + yOut[i] = 0 + } + for i := range zOut { + zOut[i] = 0 + } + + // The implicit value at index 0 is all zero. We don't need to perform that + // iteration of the loop because we already set out_* to zero. + for i := uint32(1); i < 16; i++ { + mask := i ^ index + mask |= mask >> 2 + mask |= mask >> 1 + mask &= 1 + mask-- + for j := range xOut { + xOut[j] |= table[i][0][j] & mask + } + for j := range yOut { + yOut[j] |= table[i][1][j] & mask + } + for j := range zOut { + zOut[j] |= table[i][2][j] & mask + } + } +} + +// sm2P256GetBit returns the bit'th bit of scalar. +func sm2P256GetBit(scalar *[32]uint8, bit uint) uint32 { + return uint32(((scalar[bit>>3]) >> (bit & 7)) & 1) +} + +// sm2P256ScalarBaseMult sets {xOut,yOut,zOut} = scalar*G where scalar is a +// little-endian number. Note that the value of scalar must be less than the +// order of the group. +func sm2P256ScalarBaseMult(xOut, yOut, zOut *sm2P256FieldElement, scalar *[32]uint8) { + nIsInfinityMask := ^uint32(0) + var px, py, tx, ty, tz sm2P256FieldElement + var pIsNoninfiniteMask, mask, tableOffset uint32 + + for i := range xOut { + xOut[i] = 0 + } + for i := range yOut { + yOut[i] = 0 + } + for i := range zOut { + zOut[i] = 0 + } + + // The loop adds bits at positions 0, 64, 128 and 192, followed by + // positions 32,96,160 and 224 and does this 32 times. + for i := uint(0); i < 32; i++ { + if i != 0 { + sm2P256PointDouble(xOut, yOut, zOut, xOut, yOut, zOut) + } + tableOffset = 0 + for j := uint(0); j <= 32; j += 32 { + bit0 := sm2P256GetBit(scalar, 31-i+j) + bit1 := sm2P256GetBit(scalar, 95-i+j) + bit2 := sm2P256GetBit(scalar, 159-i+j) + bit3 := sm2P256GetBit(scalar, 223-i+j) + index := bit0 | (bit1 << 1) | (bit2 << 2) | (bit3 << 3) + + sm2P256SelectAffinePoint(&px, &py, sm2P256Precomputed[tableOffset:], index) + tableOffset += 30 * 9 + + // Since scalar is less than the order of the group, we know that + // {xOut,yOut,zOut} != {px,py,1}, unless both are zero, which we handle + // below. + sm2P256PointAddMixed(&tx, &ty, &tz, xOut, yOut, zOut, &px, &py) + // The result of pointAddMixed is incorrect if {xOut,yOut,zOut} is zero + // (a.k.a. the point at infinity). We handle that situation by + // copying the point from the table. + sm2P256CopyConditional(xOut, &px, nIsInfinityMask) + sm2P256CopyConditional(yOut, &py, nIsInfinityMask) + sm2P256CopyConditional(zOut, &sm2P256Factor[1], nIsInfinityMask) + + // Equally, the result is also wrong if the point from the table is + // zero, which happens when the index is zero. We handle that by + // only copying from {tx,ty,tz} to {xOut,yOut,zOut} if index != 0. + pIsNoninfiniteMask = nonZeroToAllOnes(index) + mask = pIsNoninfiniteMask & ^nIsInfinityMask + sm2P256CopyConditional(xOut, &tx, mask) + sm2P256CopyConditional(yOut, &ty, mask) + sm2P256CopyConditional(zOut, &tz, mask) + // If p was not zero, then n is now non-zero. + nIsInfinityMask &^= pIsNoninfiniteMask + } + } +} + +func sm2P256ScalarMult(xOut, yOut, zOut, x, y *sm2P256FieldElement, scalar *[32]uint8) { + var precomp [16][3]sm2P256FieldElement + var px, py, pz, tx, ty, tz sm2P256FieldElement + var nIsInfinityMask, index, pIsNoninfiniteMask, mask uint32 + + // We precompute 0,1,2,... times {x,y}. + precomp[1][0] = *x + precomp[1][1] = *y + precomp[1][2] = sm2P256Factor[1] + + for i := 2; i < 16; i += 2 { + sm2P256PointDouble(&precomp[i][0], &precomp[i][1], &precomp[i][2], &precomp[i/2][0], &precomp[i/2][1], &precomp[i/2][2]) + sm2P256PointAddMixed(&precomp[i+1][0], &precomp[i+1][1], &precomp[i+1][2], &precomp[i][0], &precomp[i][1], &precomp[i][2], x, y) + } + + for i := range xOut { + xOut[i] = 0 + } + for i := range yOut { + yOut[i] = 0 + } + for i := range zOut { + zOut[i] = 0 + } + nIsInfinityMask = ^uint32(0) + + // We add in a window of four bits each iteration and do this 64 times. + for i := 0; i < 64; i++ { + if i != 0 { + sm2P256PointDouble(xOut, yOut, zOut, xOut, yOut, zOut) + sm2P256PointDouble(xOut, yOut, zOut, xOut, yOut, zOut) + sm2P256PointDouble(xOut, yOut, zOut, xOut, yOut, zOut) + sm2P256PointDouble(xOut, yOut, zOut, xOut, yOut, zOut) + } + + index = uint32(scalar[31-i/2]) + if (i & 1) == 1 { + index &= 15 + } else { + index >>= 4 + } + + // See the comments in scalarBaseMult about handling infinities. + sm2P256SelectJacobianPoint(&px, &py, &pz, &precomp, index) + sm2P256PointAdd(xOut, yOut, zOut, &px, &py, &pz, &tx, &ty, &tz) + sm2P256CopyConditional(xOut, &px, nIsInfinityMask) + sm2P256CopyConditional(yOut, &py, nIsInfinityMask) + sm2P256CopyConditional(zOut, &pz, nIsInfinityMask) + + pIsNoninfiniteMask = nonZeroToAllOnes(index) + mask = pIsNoninfiniteMask & ^nIsInfinityMask + sm2P256CopyConditional(xOut, &tx, mask) + sm2P256CopyConditional(yOut, &ty, mask) + sm2P256CopyConditional(zOut, &tz, mask) + nIsInfinityMask &^= pIsNoninfiniteMask + } +} + +func sm2P256PointToAffine(xOut, yOut, x, y, z *sm2P256FieldElement) { + var zInv, zInvSq sm2P256FieldElement + + zz := sm2P256ToBig(z) + zz.ModInverse(zz, sm2P256.P) + sm2P256FromBig(&zInv, zz) + + sm2P256Square(&zInvSq, &zInv) + sm2P256Mul(xOut, x, &zInvSq) + sm2P256Mul(&zInv, &zInv, &zInvSq) + sm2P256Mul(yOut, y, &zInv) +} + +func sm2P256ToAffine(x, y, z *sm2P256FieldElement) (xOut, yOut *big.Int) { + var xx, yy sm2P256FieldElement + + sm2P256PointToAffine(&xx, &yy, x, y, z) + return sm2P256ToBig(&xx), sm2P256ToBig(&yy) +} + +var sm2P256Factor = []sm2P256FieldElement{ + sm2P256FieldElement{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + sm2P256FieldElement{0x2, 0x0, 0x1FFFFF00, 0x7FF, 0x0, 0x0, 0x0, 0x2000000, 0x0}, + sm2P256FieldElement{0x4, 0x0, 0x1FFFFE00, 0xFFF, 0x0, 0x0, 0x0, 0x4000000, 0x0}, + sm2P256FieldElement{0x6, 0x0, 0x1FFFFD00, 0x17FF, 0x0, 0x0, 0x0, 0x6000000, 0x0}, + sm2P256FieldElement{0x8, 0x0, 0x1FFFFC00, 0x1FFF, 0x0, 0x0, 0x0, 0x8000000, 0x0}, + sm2P256FieldElement{0xA, 0x0, 0x1FFFFB00, 0x27FF, 0x0, 0x0, 0x0, 0xA000000, 0x0}, + sm2P256FieldElement{0xC, 0x0, 0x1FFFFA00, 0x2FFF, 0x0, 0x0, 0x0, 0xC000000, 0x0}, + sm2P256FieldElement{0xE, 0x0, 0x1FFFF900, 0x37FF, 0x0, 0x0, 0x0, 0xE000000, 0x0}, + sm2P256FieldElement{0x10, 0x0, 0x1FFFF800, 0x3FFF, 0x0, 0x0, 0x0, 0x0, 0x01}, +} + +func sm2P256Scalar(b *sm2P256FieldElement, a int) { + sm2P256Mul(b, b, &sm2P256Factor[a]) +} + +// (x3, y3, z3) = (x1, y1, z1) + (x2, y2, z2) +func sm2P256PointAdd(x1, y1, z1, x2, y2, z2, x3, y3, z3 *sm2P256FieldElement) { + var u1, u2, z22, z12, z23, z13, s1, s2, h, h2, r, r2, tm sm2P256FieldElement + + if sm2P256ToBig(z1).Sign() == 0 { + sm2P256Dup(x3, x2) + sm2P256Dup(y3, y2) + sm2P256Dup(z3, z2) + return + } + + if sm2P256ToBig(z2).Sign() == 0 { + sm2P256Dup(x3, x1) + sm2P256Dup(y3, y1) + sm2P256Dup(z3, z1) + return + } + + sm2P256Square(&z12, z1) // z12 = z1 ^ 2 + sm2P256Square(&z22, z2) // z22 = z2 ^ 2 + + sm2P256Mul(&z13, &z12, z1) // z13 = z1 ^ 3 + sm2P256Mul(&z23, &z22, z2) // z23 = z2 ^ 3 + + sm2P256Mul(&u1, x1, &z22) // u1 = x1 * z2 ^ 2 + sm2P256Mul(&u2, x2, &z12) // u2 = x2 * z1 ^ 2 + + sm2P256Mul(&s1, y1, &z23) // s1 = y1 * z2 ^ 3 + sm2P256Mul(&s2, y2, &z13) // s2 = y2 * z1 ^ 3 + + if sm2P256ToBig(&u1).Cmp(sm2P256ToBig(&u2)) == 0 && + sm2P256ToBig(&s1).Cmp(sm2P256ToBig(&s2)) == 0 { + sm2P256PointDouble(x1, y1, z1, x1, y1, z1) + } + + sm2P256Sub(&h, &u2, &u1) // h = u2 - u1 + sm2P256Sub(&r, &s2, &s1) // r = s2 - s1 + + sm2P256Square(&r2, &r) // r2 = r ^ 2 + sm2P256Square(&h2, &h) // h2 = h ^ 2 + + sm2P256Mul(&tm, &h2, &h) // tm = h ^ 3 + sm2P256Sub(x3, &r2, &tm) + sm2P256Mul(&tm, &u1, &h2) + sm2P256Scalar(&tm, 2) // tm = 2 * (u1 * h ^ 2) + sm2P256Sub(x3, x3, &tm) // x3 = r ^ 2 - h ^ 3 - 2 * u1 * h ^ 2 + + sm2P256Mul(&tm, &u1, &h2) // tm = u1 * h ^ 2 + sm2P256Sub(&tm, &tm, x3) // tm = u1 * h ^ 2 - x3 + sm2P256Mul(y3, &r, &tm) + sm2P256Mul(&tm, &h2, &h) // tm = h ^ 3 + sm2P256Mul(&tm, &tm, &s1) // tm = s1 * h ^ 3 + sm2P256Sub(y3, y3, &tm) // y3 = r * (u1 * h ^ 2 - x3) - s1 * h ^ 3 + + sm2P256Mul(z3, z1, z2) + sm2P256Mul(z3, z3, &h) // z3 = z1 * z3 * h +} + +func sm2P256PointDouble(x3, y3, z3, x, y, z *sm2P256FieldElement) { + var s, m, m2, x2, y2, z2, z4, y4, az4 sm2P256FieldElement + + sm2P256Square(&x2, x) // x2 = x ^ 2 + sm2P256Square(&y2, y) // y2 = y ^ 2 + sm2P256Square(&z2, z) // z2 = z ^ 2 + + sm2P256Square(&z4, z) // z4 = z ^ 2 + sm2P256Mul(&z4, &z4, z) // z4 = z ^ 3 + sm2P256Mul(&z4, &z4, z) // z4 = z ^ 4 + + sm2P256Square(&y4, y) // y4 = y ^ 2 + sm2P256Mul(&y4, &y4, y) // y4 = y ^ 3 + sm2P256Mul(&y4, &y4, y) // y4 = y ^ 4 + sm2P256Scalar(&y4, 8) // y4 = 8 * y ^ 4 + + sm2P256Mul(&s, x, &y2) + sm2P256Scalar(&s, 4) // s = 4 * x * y ^ 2 + + sm2P256Dup(&m, &x2) + sm2P256Scalar(&m, 3) + sm2P256Mul(&az4, &sm2P256.a, &z4) + sm2P256Add(&m, &m, &az4) // m = 3 * x ^ 2 + a * z ^ 4 + + sm2P256Square(&m2, &m) // m2 = m ^ 2 + + sm2P256Add(z3, y, z) + sm2P256Square(z3, z3) + sm2P256Sub(z3, z3, &z2) + sm2P256Sub(z3, z3, &y2) // z' = (y + z) ^2 - z ^ 2 - y ^ 2 + + sm2P256Sub(x3, &m2, &s) + sm2P256Sub(x3, x3, &s) // x' = m2 - 2 * s + + sm2P256Sub(y3, &s, x3) + sm2P256Mul(y3, y3, &m) + sm2P256Sub(y3, y3, &y4) // y' = m * (s - x') - 8 * y ^ 4 +} + +// p256Zero31 is 0 mod p. +var sm2P256Zero31 = sm2P256FieldElement{0x7FFFFFF8, 0x3FFFFFFC, 0x800003FC, 0x3FFFDFFC, 0x7FFFFFFC, 0x3FFFFFFC, 0x7FFFFFFC, 0x37FFFFFC, 0x7FFFFFFC} + +// c = a + b +func sm2P256Add(c, a, b *sm2P256FieldElement) { + carry := uint32(0) + for i := 0; ; i++ { + c[i] = a[i] + b[i] + c[i] += carry + carry = c[i] >> 29 + c[i] &= bottom29Bits + i++ + if i == 9 { + break + } + c[i] = a[i] + b[i] + c[i] += carry + carry = c[i] >> 28 + c[i] &= bottom28Bits + } + sm2P256ReduceCarry(c, carry) +} + +// c = a - b +func sm2P256Sub(c, a, b *sm2P256FieldElement) { + var carry uint32 + + for i := 0; ; i++ { + c[i] = a[i] - b[i] + c[i] += sm2P256Zero31[i] + c[i] += carry + carry = c[i] >> 29 + c[i] &= bottom29Bits + i++ + if i == 9 { + break + } + c[i] = a[i] - b[i] + c[i] += sm2P256Zero31[i] + c[i] += carry + carry = c[i] >> 28 + c[i] &= bottom28Bits + } + sm2P256ReduceCarry(c, carry) +} + +// c = a * b +func sm2P256Mul(c, a, b *sm2P256FieldElement) { + var tmp sm2P256LargeFieldElement + + tmp[0] = uint64(a[0]) * uint64(b[0]) + tmp[1] = uint64(a[0])*(uint64(b[1])<<0) + + uint64(a[1])*(uint64(b[0])<<0) + tmp[2] = uint64(a[0])*(uint64(b[2])<<0) + + uint64(a[1])*(uint64(b[1])<<1) + + uint64(a[2])*(uint64(b[0])<<0) + tmp[3] = uint64(a[0])*(uint64(b[3])<<0) + + uint64(a[1])*(uint64(b[2])<<0) + + uint64(a[2])*(uint64(b[1])<<0) + + uint64(a[3])*(uint64(b[0])<<0) + tmp[4] = uint64(a[0])*(uint64(b[4])<<0) + + uint64(a[1])*(uint64(b[3])<<1) + + uint64(a[2])*(uint64(b[2])<<0) + + uint64(a[3])*(uint64(b[1])<<1) + + uint64(a[4])*(uint64(b[0])<<0) + tmp[5] = uint64(a[0])*(uint64(b[5])<<0) + + uint64(a[1])*(uint64(b[4])<<0) + + uint64(a[2])*(uint64(b[3])<<0) + + uint64(a[3])*(uint64(b[2])<<0) + + uint64(a[4])*(uint64(b[1])<<0) + + uint64(a[5])*(uint64(b[0])<<0) + tmp[6] = uint64(a[0])*(uint64(b[6])<<0) + + uint64(a[1])*(uint64(b[5])<<1) + + uint64(a[2])*(uint64(b[4])<<0) + + uint64(a[3])*(uint64(b[3])<<1) + + uint64(a[4])*(uint64(b[2])<<0) + + uint64(a[5])*(uint64(b[1])<<1) + + uint64(a[6])*(uint64(b[0])<<0) + tmp[7] = uint64(a[0])*(uint64(b[7])<<0) + + uint64(a[1])*(uint64(b[6])<<0) + + uint64(a[2])*(uint64(b[5])<<0) + + uint64(a[3])*(uint64(b[4])<<0) + + uint64(a[4])*(uint64(b[3])<<0) + + uint64(a[5])*(uint64(b[2])<<0) + + uint64(a[6])*(uint64(b[1])<<0) + + uint64(a[7])*(uint64(b[0])<<0) + // tmp[8] has the greatest value but doesn't overflow. See logic in + // p256Square. + tmp[8] = uint64(a[0])*(uint64(b[8])<<0) + + uint64(a[1])*(uint64(b[7])<<1) + + uint64(a[2])*(uint64(b[6])<<0) + + uint64(a[3])*(uint64(b[5])<<1) + + uint64(a[4])*(uint64(b[4])<<0) + + uint64(a[5])*(uint64(b[3])<<1) + + uint64(a[6])*(uint64(b[2])<<0) + + uint64(a[7])*(uint64(b[1])<<1) + + uint64(a[8])*(uint64(b[0])<<0) + tmp[9] = uint64(a[1])*(uint64(b[8])<<0) + + uint64(a[2])*(uint64(b[7])<<0) + + uint64(a[3])*(uint64(b[6])<<0) + + uint64(a[4])*(uint64(b[5])<<0) + + uint64(a[5])*(uint64(b[4])<<0) + + uint64(a[6])*(uint64(b[3])<<0) + + uint64(a[7])*(uint64(b[2])<<0) + + uint64(a[8])*(uint64(b[1])<<0) + tmp[10] = uint64(a[2])*(uint64(b[8])<<0) + + uint64(a[3])*(uint64(b[7])<<1) + + uint64(a[4])*(uint64(b[6])<<0) + + uint64(a[5])*(uint64(b[5])<<1) + + uint64(a[6])*(uint64(b[4])<<0) + + uint64(a[7])*(uint64(b[3])<<1) + + uint64(a[8])*(uint64(b[2])<<0) + tmp[11] = uint64(a[3])*(uint64(b[8])<<0) + + uint64(a[4])*(uint64(b[7])<<0) + + uint64(a[5])*(uint64(b[6])<<0) + + uint64(a[6])*(uint64(b[5])<<0) + + uint64(a[7])*(uint64(b[4])<<0) + + uint64(a[8])*(uint64(b[3])<<0) + tmp[12] = uint64(a[4])*(uint64(b[8])<<0) + + uint64(a[5])*(uint64(b[7])<<1) + + uint64(a[6])*(uint64(b[6])<<0) + + uint64(a[7])*(uint64(b[5])<<1) + + uint64(a[8])*(uint64(b[4])<<0) + tmp[13] = uint64(a[5])*(uint64(b[8])<<0) + + uint64(a[6])*(uint64(b[7])<<0) + + uint64(a[7])*(uint64(b[6])<<0) + + uint64(a[8])*(uint64(b[5])<<0) + tmp[14] = uint64(a[6])*(uint64(b[8])<<0) + + uint64(a[7])*(uint64(b[7])<<1) + + uint64(a[8])*(uint64(b[6])<<0) + tmp[15] = uint64(a[7])*(uint64(b[8])<<0) + + uint64(a[8])*(uint64(b[7])<<0) + tmp[16] = uint64(a[8]) * (uint64(b[8]) << 0) + sm2P256ReduceDegree(c, &tmp) +} + +// b = a * a +func sm2P256Square(b, a *sm2P256FieldElement) { + var tmp sm2P256LargeFieldElement + + tmp[0] = uint64(a[0]) * uint64(a[0]) + tmp[1] = uint64(a[0]) * (uint64(a[1]) << 1) + tmp[2] = uint64(a[0])*(uint64(a[2])<<1) + + uint64(a[1])*(uint64(a[1])<<1) + tmp[3] = uint64(a[0])*(uint64(a[3])<<1) + + uint64(a[1])*(uint64(a[2])<<1) + tmp[4] = uint64(a[0])*(uint64(a[4])<<1) + + uint64(a[1])*(uint64(a[3])<<2) + + uint64(a[2])*uint64(a[2]) + tmp[5] = uint64(a[0])*(uint64(a[5])<<1) + + uint64(a[1])*(uint64(a[4])<<1) + + uint64(a[2])*(uint64(a[3])<<1) + tmp[6] = uint64(a[0])*(uint64(a[6])<<1) + + uint64(a[1])*(uint64(a[5])<<2) + + uint64(a[2])*(uint64(a[4])<<1) + + uint64(a[3])*(uint64(a[3])<<1) + tmp[7] = uint64(a[0])*(uint64(a[7])<<1) + + uint64(a[1])*(uint64(a[6])<<1) + + uint64(a[2])*(uint64(a[5])<<1) + + uint64(a[3])*(uint64(a[4])<<1) + // tmp[8] has the greatest value of 2**61 + 2**60 + 2**61 + 2**60 + 2**60, + // which is < 2**64 as required. + tmp[8] = uint64(a[0])*(uint64(a[8])<<1) + + uint64(a[1])*(uint64(a[7])<<2) + + uint64(a[2])*(uint64(a[6])<<1) + + uint64(a[3])*(uint64(a[5])<<2) + + uint64(a[4])*uint64(a[4]) + tmp[9] = uint64(a[1])*(uint64(a[8])<<1) + + uint64(a[2])*(uint64(a[7])<<1) + + uint64(a[3])*(uint64(a[6])<<1) + + uint64(a[4])*(uint64(a[5])<<1) + tmp[10] = uint64(a[2])*(uint64(a[8])<<1) + + uint64(a[3])*(uint64(a[7])<<2) + + uint64(a[4])*(uint64(a[6])<<1) + + uint64(a[5])*(uint64(a[5])<<1) + tmp[11] = uint64(a[3])*(uint64(a[8])<<1) + + uint64(a[4])*(uint64(a[7])<<1) + + uint64(a[5])*(uint64(a[6])<<1) + tmp[12] = uint64(a[4])*(uint64(a[8])<<1) + + uint64(a[5])*(uint64(a[7])<<2) + + uint64(a[6])*uint64(a[6]) + tmp[13] = uint64(a[5])*(uint64(a[8])<<1) + + uint64(a[6])*(uint64(a[7])<<1) + tmp[14] = uint64(a[6])*(uint64(a[8])<<1) + + uint64(a[7])*(uint64(a[7])<<1) + tmp[15] = uint64(a[7]) * (uint64(a[8]) << 1) + tmp[16] = uint64(a[8]) * uint64(a[8]) + sm2P256ReduceDegree(b, &tmp) +} + +// nonZeroToAllOnes returns: +// 0xffffffff for 0 < x <= 2**31 +// 0 for x == 0 or x > 2**31. +func nonZeroToAllOnes(x uint32) uint32 { + return ((x - 1) >> 31) - 1 +} + +var sm2P256Carry = [8 * 9]uint32{ + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x2, 0x0, 0x1FFFFF00, 0x7FF, 0x0, 0x0, 0x0, 0x2000000, 0x0, + 0x4, 0x0, 0x1FFFFE00, 0xFFF, 0x0, 0x0, 0x0, 0x4000000, 0x0, + 0x6, 0x0, 0x1FFFFD00, 0x17FF, 0x0, 0x0, 0x0, 0x6000000, 0x0, + 0x8, 0x0, 0x1FFFFC00, 0x1FFF, 0x0, 0x0, 0x0, 0x8000000, 0x0, + 0xA, 0x0, 0x1FFFFB00, 0x27FF, 0x0, 0x0, 0x0, 0xA000000, 0x0, + 0xC, 0x0, 0x1FFFFA00, 0x2FFF, 0x0, 0x0, 0x0, 0xC000000, 0x0, + 0xE, 0x0, 0x1FFFF900, 0x37FF, 0x0, 0x0, 0x0, 0xE000000, 0x0, +} + +// carry < 2 ^ 3 +func sm2P256ReduceCarry(a *sm2P256FieldElement, carry uint32) { + a[0] += sm2P256Carry[carry*9+0] + a[2] += sm2P256Carry[carry*9+2] + a[3] += sm2P256Carry[carry*9+3] + a[7] += sm2P256Carry[carry*9+7] +} + +// 这代码真是丑比了,我也是对自己醉了。。。 +// 你最好别改这个代码,不然你会死的很惨。。 +func sm2P256ReduceDegree(a *sm2P256FieldElement, b *sm2P256LargeFieldElement) { + var tmp [18]uint32 + var carry, x, xMask uint32 + + // tmp + // 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 ... + // 29 | 28 | 29 | 28 | 29 | 28 | 29 | 28 | 29 | 28 | 29 ... + tmp[0] = uint32(b[0]) & bottom29Bits + tmp[1] = uint32(b[0]) >> 29 + tmp[1] |= (uint32(b[0]>>32) << 3) & bottom28Bits + tmp[1] += uint32(b[1]) & bottom28Bits + carry = tmp[1] >> 28 + tmp[1] &= bottom28Bits + for i := 2; i < 17; i++ { + tmp[i] = (uint32(b[i-2] >> 32)) >> 25 + tmp[i] += (uint32(b[i-1])) >> 28 + tmp[i] += (uint32(b[i-1]>>32) << 4) & bottom29Bits + tmp[i] += uint32(b[i]) & bottom29Bits + tmp[i] += carry + carry = tmp[i] >> 29 + tmp[i] &= bottom29Bits + + i++ + if i == 17 { + break + } + tmp[i] = uint32(b[i-2]>>32) >> 25 + tmp[i] += uint32(b[i-1]) >> 29 + tmp[i] += ((uint32(b[i-1] >> 32)) << 3) & bottom28Bits + tmp[i] += uint32(b[i]) & bottom28Bits + tmp[i] += carry + carry = tmp[i] >> 28 + tmp[i] &= bottom28Bits + } + tmp[17] = uint32(b[15]>>32) >> 25 + tmp[17] += uint32(b[16]) >> 29 + tmp[17] += uint32(b[16]>>32) << 3 + tmp[17] += carry + + for i := 0; ; i += 2 { + + tmp[i+1] += tmp[i] >> 29 + x = tmp[i] & bottom29Bits + tmp[i] = 0 + if x > 0 { + set4 := uint32(0) + set7 := uint32(0) + xMask = nonZeroToAllOnes(x) + tmp[i+2] += (x << 7) & bottom29Bits + tmp[i+3] += x >> 22 + if tmp[i+3] < 0x10000000 { + set4 = 1 + tmp[i+3] += 0x10000000 & xMask + tmp[i+3] -= (x << 10) & bottom28Bits + } else { + tmp[i+3] -= (x << 10) & bottom28Bits + } + if tmp[i+4] < 0x20000000 { + tmp[i+4] += 0x20000000 & xMask + tmp[i+4] -= set4 // 借位 + tmp[i+4] -= x >> 18 + if tmp[i+5] < 0x10000000 { + tmp[i+5] += 0x10000000 & xMask + tmp[i+5] -= 1 // 借位 + if tmp[i+6] < 0x20000000 { + set7 = 1 + tmp[i+6] += 0x20000000 & xMask + tmp[i+6] -= 1 // 借位 + } else { + tmp[i+6] -= 1 // 借位 + } + } else { + tmp[i+5] -= 1 + } + } else { + tmp[i+4] -= set4 // 借位 + tmp[i+4] -= x >> 18 + } + if tmp[i+7] < 0x10000000 { + tmp[i+7] += 0x10000000 & xMask + tmp[i+7] -= set7 + tmp[i+7] -= (x << 24) & bottom28Bits + tmp[i+8] += (x << 28) & bottom29Bits + if tmp[i+8] < 0x20000000 { + tmp[i+8] += 0x20000000 & xMask + tmp[i+8] -= 1 + tmp[i+8] -= x >> 4 + tmp[i+9] += ((x >> 1) - 1) & xMask + } else { + tmp[i+8] -= 1 + tmp[i+8] -= x >> 4 + tmp[i+9] += (x >> 1) & xMask + } + } else { + tmp[i+7] -= set7 // 借位 + tmp[i+7] -= (x << 24) & bottom28Bits + tmp[i+8] += (x << 28) & bottom29Bits + if tmp[i+8] < 0x20000000 { + tmp[i+8] += 0x20000000 & xMask + tmp[i+8] -= x >> 4 + tmp[i+9] += ((x >> 1) - 1) & xMask + } else { + tmp[i+8] -= x >> 4 + tmp[i+9] += (x >> 1) & xMask + } + } + + } + + if i+1 == 9 { + break + } + + tmp[i+2] += tmp[i+1] >> 28 + x = tmp[i+1] & bottom28Bits + tmp[i+1] = 0 + if x > 0 { + set5 := uint32(0) + set8 := uint32(0) + set9 := uint32(0) + xMask = nonZeroToAllOnes(x) + tmp[i+3] += (x << 7) & bottom28Bits + tmp[i+4] += x >> 21 + if tmp[i+4] < 0x20000000 { + set5 = 1 + tmp[i+4] += 0x20000000 & xMask + tmp[i+4] -= (x << 11) & bottom29Bits + } else { + tmp[i+4] -= (x << 11) & bottom29Bits + } + if tmp[i+5] < 0x10000000 { + tmp[i+5] += 0x10000000 & xMask + tmp[i+5] -= set5 // 借位 + tmp[i+5] -= x >> 18 + if tmp[i+6] < 0x20000000 { + tmp[i+6] += 0x20000000 & xMask + tmp[i+6] -= 1 // 借位 + if tmp[i+7] < 0x10000000 { + set8 = 1 + tmp[i+7] += 0x10000000 & xMask + tmp[i+7] -= 1 // 借位 + } else { + tmp[i+7] -= 1 // 借位 + } + } else { + tmp[i+6] -= 1 // 借位 + } + } else { + tmp[i+5] -= set5 // 借位 + tmp[i+5] -= x >> 18 + } + if tmp[i+8] < 0x20000000 { + set9 = 1 + tmp[i+8] += 0x20000000 & xMask + tmp[i+8] -= set8 + tmp[i+8] -= (x << 25) & bottom29Bits + } else { + tmp[i+8] -= set8 + tmp[i+8] -= (x << 25) & bottom29Bits + } + if tmp[i+9] < 0x10000000 { + tmp[i+9] += 0x10000000 & xMask + tmp[i+9] -= set9 // 借位 + tmp[i+9] -= x >> 4 + tmp[i+10] += (x - 1) & xMask + } else { + tmp[i+9] -= set9 // 借位 + tmp[i+9] -= x >> 4 + tmp[i+10] += x & xMask + } + } + } + + carry = uint32(0) + for i := 0; i < 8; i++ { + a[i] = tmp[i+9] + a[i] += carry + a[i] += (tmp[i+10] << 28) & bottom29Bits + carry = a[i] >> 29 + a[i] &= bottom29Bits + + i++ + a[i] = tmp[i+9] >> 1 + a[i] += carry + carry = a[i] >> 28 + a[i] &= bottom28Bits + } + a[8] = tmp[17] + a[8] += carry + carry = a[8] >> 29 + a[8] &= bottom29Bits + sm2P256ReduceCarry(a, carry) +} + +// b = a +func sm2P256Dup(b, a *sm2P256FieldElement) { + *b = *a +} + +// X = a * R mod P +func sm2P256FromBig(X *sm2P256FieldElement, a *big.Int) { + x := new(big.Int).Lsh(a, 257) + x.Mod(x, sm2P256.P) + for i := 0; i < 9; i++ { + if bits := x.Bits(); len(bits) > 0 { + X[i] = uint32(bits[0]) & bottom29Bits + } else { + X[i] = 0 + } + x.Rsh(x, 29) + i++ + if i == 9 { + break + } + if bits := x.Bits(); len(bits) > 0 { + X[i] = uint32(bits[0]) & bottom28Bits + } else { + X[i] = 0 + } + x.Rsh(x, 28) + } +} + +// X = r * R mod P +// r = X * R' mod P +func sm2P256ToBig(X *sm2P256FieldElement) *big.Int { + r, tm := new(big.Int), new(big.Int) + r.SetInt64(int64(X[8])) + for i := 7; i >= 0; i-- { + if (i & 1) == 0 { + r.Lsh(r, 29) + } else { + r.Lsh(r, 28) + } + tm.SetInt64(int64(X[i])) + r.Add(r, tm) + } + r.Mul(r, sm2P256.RInverse) + r.Mod(r, sm2P256.P) + return r +} diff --git a/crypto/sm2/pkcs1.go b/crypto/sm2/pkcs1.go new file mode 100644 index 00000000..8ef29b5f --- /dev/null +++ b/crypto/sm2/pkcs1.go @@ -0,0 +1,132 @@ +/* +Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sm2 + +import ( + "crypto/rsa" + "encoding/asn1" + "errors" + "math/big" +) + +// pkcs1PrivateKey is a structure which mirrors the PKCS#1 ASN.1 for an RSA private key. +type pkcs1PrivateKey struct { + Version int + N *big.Int + E int + D *big.Int + P *big.Int + Q *big.Int + // We ignore these values, if present, because rsa will calculate them. + Dp *big.Int `asn1:"optional"` + Dq *big.Int `asn1:"optional"` + Qinv *big.Int `asn1:"optional"` + + AdditionalPrimes []pkcs1AdditionalRSAPrime `asn1:"optional,omitempty"` +} + +type pkcs1AdditionalRSAPrime struct { + Prime *big.Int + + // We ignore these values because rsa will calculate them. + Exp *big.Int + Coeff *big.Int +} + +// ParsePKCS1PrivateKey returns an RSA private key from its ASN.1 PKCS#1 DER encoded form. +func ParsePKCS1PrivateKey(der []byte) (*rsa.PrivateKey, error) { + var priv pkcs1PrivateKey + rest, err := asn1.Unmarshal(der, &priv) + if len(rest) > 0 { + return nil, asn1.SyntaxError{Msg: "trailing data"} + } + if err != nil { + return nil, err + } + + if priv.Version > 1 { + return nil, errors.New("x509: unsupported private key version") + } + + if priv.N.Sign() <= 0 || priv.D.Sign() <= 0 || priv.P.Sign() <= 0 || priv.Q.Sign() <= 0 { + return nil, errors.New("x509: private key contains zero or negative value") + } + + key := new(rsa.PrivateKey) + key.PublicKey = rsa.PublicKey{ + E: priv.E, + N: priv.N, + } + + key.D = priv.D + key.Primes = make([]*big.Int, 2+len(priv.AdditionalPrimes)) + key.Primes[0] = priv.P + key.Primes[1] = priv.Q + for i, a := range priv.AdditionalPrimes { + if a.Prime.Sign() <= 0 { + return nil, errors.New("x509: private key contains zero or negative prime") + } + key.Primes[i+2] = a.Prime + // We ignore the other two values because rsa will calculate + // them as needed. + } + + err = key.Validate() + if err != nil { + return nil, err + } + key.Precompute() + + return key, nil +} + +// MarshalPKCS1PrivateKey converts a private key to ASN.1 DER encoded form. +func MarshalPKCS1PrivateKey(key *rsa.PrivateKey) []byte { + key.Precompute() + + version := 0 + if len(key.Primes) > 2 { + version = 1 + } + + priv := pkcs1PrivateKey{ + Version: version, + N: key.N, + E: key.PublicKey.E, + D: key.D, + P: key.Primes[0], + Q: key.Primes[1], + Dp: key.Precomputed.Dp, + Dq: key.Precomputed.Dq, + Qinv: key.Precomputed.Qinv, + } + + priv.AdditionalPrimes = make([]pkcs1AdditionalRSAPrime, len(key.Precomputed.CRTValues)) + for i, values := range key.Precomputed.CRTValues { + priv.AdditionalPrimes[i].Prime = key.Primes[2+i] + priv.AdditionalPrimes[i].Exp = values.Exp + priv.AdditionalPrimes[i].Coeff = values.Coeff + } + + b, _ := asn1.Marshal(priv) + return b +} + +// rsaPublicKey reflects the ASN.1 structure of a PKCS#1 public key. +type rsaPublicKey struct { + N *big.Int + E int +} diff --git a/crypto/sm2/pkcs8.go b/crypto/sm2/pkcs8.go new file mode 100644 index 00000000..192a664b --- /dev/null +++ b/crypto/sm2/pkcs8.go @@ -0,0 +1,488 @@ +/* +Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sm2 + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/elliptic" + "crypto/hmac" + "crypto/md5" + "crypto/rand" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/pem" + "errors" + "hash" + "io/ioutil" + "math/big" + "os" + "reflect" +) + +/* + * reference to RFC5959 and RFC2898 + */ + +var ( + oidPBES1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 3} // pbeWithMD5AndDES-CBC(PBES1) + oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} // id-PBES2(PBES2) + oidPBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12} // id-PBKDF2 + + oidKEYMD5 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 5} + oidKEYSHA1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 7} + oidKEYSHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 9} + oidKEYSHA512 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 11} + + oidAES128CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 2} + oidAES256CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 42} + + oidSM2 = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} +) + +// reference to https://www.rfc-editor.org/rfc/rfc5958.txt +type PrivateKeyInfo struct { + Version int // v1 or v2 + PrivateKeyAlgorithm []asn1.ObjectIdentifier + PrivateKey []byte +} + +// reference to https://www.rfc-editor.org/rfc/rfc5958.txt +type EncryptedPrivateKeyInfo struct { + EncryptionAlgorithm Pbes2Algorithms + EncryptedData []byte +} + +// reference to https://www.ietf.org/rfc/rfc2898.txt +type Pbes2Algorithms struct { + IdPBES2 asn1.ObjectIdentifier + Pbes2Params Pbes2Params +} + +// reference to https://www.ietf.org/rfc/rfc2898.txt +type Pbes2Params struct { + KeyDerivationFunc Pbes2KDfs // PBES2-KDFs + EncryptionScheme Pbes2Encs // PBES2-Encs +} + +// reference to https://www.ietf.org/rfc/rfc2898.txt +type Pbes2KDfs struct { + IdPBKDF2 asn1.ObjectIdentifier + Pkdf2Params Pkdf2Params +} + +type Pbes2Encs struct { + EncryAlgo asn1.ObjectIdentifier + IV []byte +} + +// reference to https://www.ietf.org/rfc/rfc2898.txt +type Pkdf2Params struct { + Salt []byte + IterationCount int + Prf pkix.AlgorithmIdentifier +} + +type sm2PrivateKey struct { + Version int + PrivateKey []byte + NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"` + PublicKey asn1.BitString `asn1:"optional,explicit,tag:1"` +} + +type pkcs8 struct { + Version int + Algo pkix.AlgorithmIdentifier + PrivateKey []byte +} + +// copy from crypto/pbkdf2.go +func pbkdf(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte { + prf := hmac.New(h, password) + hashLen := prf.Size() + numBlocks := (keyLen + hashLen - 1) / hashLen + + var buf [4]byte + dk := make([]byte, 0, numBlocks*hashLen) + U := make([]byte, hashLen) + for block := 1; block <= numBlocks; block++ { + // N.B.: || means concatenation, ^ means XOR + // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter + // U_1 = PRF(password, salt || uint(i)) + prf.Reset() + prf.Write(salt) + buf[0] = byte(block >> 24) + buf[1] = byte(block >> 16) + buf[2] = byte(block >> 8) + buf[3] = byte(block) + prf.Write(buf[:4]) + dk = prf.Sum(dk) + T := dk[len(dk)-hashLen:] + copy(U, T) + + // U_n = PRF(password, U_(n-1)) + for n := 2; n <= iter; n++ { + prf.Reset() + prf.Write(U) + U = U[:0] + U = prf.Sum(U) + for x := range U { + T[x] ^= U[x] + } + } + } + return dk[:keyLen] +} + +func ParseSm2PublicKey(der []byte) (*PublicKey, error) { + var pubkey pkixPublicKey + + if _, err := asn1.Unmarshal(der, &pubkey); err != nil { + return nil, err + } + if !reflect.DeepEqual(pubkey.Algo.Algorithm, oidSM2) { + return nil, errors.New("x509: not sm2 elliptic curve") + } + curve := P256Sm2() + x, y := elliptic.Unmarshal(curve, pubkey.BitString.Bytes) + pub := PublicKey{ + Curve: curve, + X: x, + Y: y, + } + return &pub, nil +} + +func MarshalSm2PublicKey(key *PublicKey) ([]byte, error) { + var r pkixPublicKey + var algo pkix.AlgorithmIdentifier + + algo.Algorithm = oidSM2 + algo.Parameters.Class = 0 + algo.Parameters.Tag = 6 + algo.Parameters.IsCompound = false + algo.Parameters.FullBytes = []byte{6, 8, 42, 129, 28, 207, 85, 1, 130, 45} // asn1.Marshal(asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301}) + r.Algo = algo + r.BitString = asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)} + return asn1.Marshal(r) +} + +func ParseSm2PrivateKey(der []byte) (*PrivateKey, error) { + var privKey sm2PrivateKey + + if _, err := asn1.Unmarshal(der, &privKey); err != nil { + return nil, errors.New("x509: failed to parse SM2 private key: " + err.Error()) + } + curve := P256Sm2() + k := new(big.Int).SetBytes(privKey.PrivateKey) + curveOrder := curve.Params().N + if k.Cmp(curveOrder) >= 0 { + return nil, errors.New("x509: invalid elliptic curve private key value") + } + priv := new(PrivateKey) + priv.Curve = curve + priv.D = k + privateKey := make([]byte, (curveOrder.BitLen()+7)/8) + for len(privKey.PrivateKey) > len(privateKey) { + if privKey.PrivateKey[0] != 0 { + return nil, errors.New("x509: invalid private key length") + } + privKey.PrivateKey = privKey.PrivateKey[1:] + } + copy(privateKey[len(privateKey)-len(privKey.PrivateKey):], privKey.PrivateKey) + priv.X, priv.Y = curve.ScalarBaseMult(privateKey) + return priv, nil +} + +func ParsePKCS8UnecryptedPrivateKey(der []byte) (*PrivateKey, error) { + var privKey pkcs8 + + if _, err := asn1.Unmarshal(der, &privKey); err != nil { + return nil, err + } + if !reflect.DeepEqual(privKey.Algo.Algorithm, oidSM2) { + return nil, errors.New("x509: not sm2 elliptic curve") + } + return ParseSm2PrivateKey(privKey.PrivateKey) +} + +func ParsePKCS8EcryptedPrivateKey(der, pwd []byte) (*PrivateKey, error) { + var keyInfo EncryptedPrivateKeyInfo + + _, err := asn1.Unmarshal(der, &keyInfo) + if err != nil { + return nil, errors.New("x509: unknown format") + } + if !reflect.DeepEqual(keyInfo.EncryptionAlgorithm.IdPBES2, oidPBES2) { + return nil, errors.New("x509: only support PBES2") + } + encryptionScheme := keyInfo.EncryptionAlgorithm.Pbes2Params.EncryptionScheme + keyDerivationFunc := keyInfo.EncryptionAlgorithm.Pbes2Params.KeyDerivationFunc + if !reflect.DeepEqual(keyDerivationFunc.IdPBKDF2, oidPBKDF2) { + return nil, errors.New("x509: only support PBKDF2") + } + pkdf2Params := keyDerivationFunc.Pkdf2Params + if !reflect.DeepEqual(encryptionScheme.EncryAlgo, oidAES128CBC) && + !reflect.DeepEqual(encryptionScheme.EncryAlgo, oidAES256CBC) { + return nil, errors.New("x509: unknow encryption algorithm") + } + iv := encryptionScheme.IV + salt := pkdf2Params.Salt + iter := pkdf2Params.IterationCount + encryptedKey := keyInfo.EncryptedData + var key []byte + switch { + case pkdf2Params.Prf.Algorithm.Equal(oidKEYMD5): + key = pbkdf(pwd, salt, iter, 32, md5.New) + break + case pkdf2Params.Prf.Algorithm.Equal(oidKEYSHA1): + key = pbkdf(pwd, salt, iter, 32, sha1.New) + break + case pkdf2Params.Prf.Algorithm.Equal(oidKEYSHA256): + key = pbkdf(pwd, salt, iter, 32, sha256.New) + break + case pkdf2Params.Prf.Algorithm.Equal(oidKEYSHA512): + key = pbkdf(pwd, salt, iter, 32, sha512.New) + break + default: + return nil, errors.New("x509: unknown hash algorithm") + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + mode := cipher.NewCBCDecrypter(block, iv) + mode.CryptBlocks(encryptedKey, encryptedKey) + rKey, err := ParsePKCS8UnecryptedPrivateKey(encryptedKey) + if err != nil { + return nil, errors.New("pkcs8: incorrect password") + } + return rKey, nil +} + +func ParsePKCS8PrivateKey(der, pwd []byte) (*PrivateKey, error) { + if pwd == nil { + return ParsePKCS8UnecryptedPrivateKey(der) + } + return ParsePKCS8EcryptedPrivateKey(der, pwd) +} + +func MarshalSm2UnecryptedPrivateKey(key *PrivateKey) ([]byte, error) { + var r pkcs8 + var priv sm2PrivateKey + var algo pkix.AlgorithmIdentifier + + algo.Algorithm = oidSM2 + algo.Parameters.Class = 0 + algo.Parameters.Tag = 6 + algo.Parameters.IsCompound = false + algo.Parameters.FullBytes = []byte{6, 8, 42, 129, 28, 207, 85, 1, 130, 45} // asn1.Marshal(asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301}) + priv.Version = 1 + priv.NamedCurveOID = oidNamedCurveP256SM2 + priv.PublicKey = asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)} + priv.PrivateKey = key.D.Bytes() + r.Version = 0 + r.Algo = algo + r.PrivateKey, _ = asn1.Marshal(priv) + return asn1.Marshal(r) +} + +func MarshalSm2EcryptedPrivateKey(PrivKey *PrivateKey, pwd []byte) ([]byte, error) { + der, err := MarshalSm2UnecryptedPrivateKey(PrivKey) + if err != nil { + return nil, err + } + iter := 2048 + salt := make([]byte, 8) + iv := make([]byte, 16) + rand.Reader.Read(salt) + rand.Reader.Read(iv) + key := pbkdf(pwd, salt, iter, 32, sha1.New) // 默认是SHA1 + padding := aes.BlockSize - len(der)%aes.BlockSize + if padding > 0 { + n := len(der) + der = append(der, make([]byte, padding)...) + for i := 0; i < padding; i++ { + der[n+i] = byte(padding) + } + } + encryptedKey := make([]byte, len(der)) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + mode := cipher.NewCBCEncrypter(block, iv) + mode.CryptBlocks(encryptedKey, der) + var algorithmIdentifier pkix.AlgorithmIdentifier + algorithmIdentifier.Algorithm = oidKEYSHA1 + algorithmIdentifier.Parameters.Tag = 5 + algorithmIdentifier.Parameters.IsCompound = false + algorithmIdentifier.Parameters.FullBytes = []byte{5, 0} + keyDerivationFunc := Pbes2KDfs{ + oidPBKDF2, + Pkdf2Params{ + salt, + iter, + algorithmIdentifier, + }, + } + encryptionScheme := Pbes2Encs{ + oidAES256CBC, + iv, + } + pbes2Algorithms := Pbes2Algorithms{ + oidPBES2, + Pbes2Params{ + keyDerivationFunc, + encryptionScheme, + }, + } + encryptedPkey := EncryptedPrivateKeyInfo{ + pbes2Algorithms, + encryptedKey, + } + return asn1.Marshal(encryptedPkey) +} + +func MarshalSm2PrivateKey(key *PrivateKey, pwd []byte) ([]byte, error) { + if pwd == nil { + return MarshalSm2UnecryptedPrivateKey(key) + } + return MarshalSm2EcryptedPrivateKey(key, pwd) +} + +func ReadPrivateKeyFromMem(data []byte, pwd []byte) (*PrivateKey, error) { + var block *pem.Block + + block, _ = pem.Decode(data) + if block == nil { + return nil, errors.New("failed to decode private key") + } + priv, err := ParsePKCS8PrivateKey(block.Bytes, pwd) + return priv, err +} + +func ReadPrivateKeyFromPem(FileName string, pwd []byte) (*PrivateKey, error) { + data, err := ioutil.ReadFile(FileName) + if err != nil { + return nil, err + } + return ReadPrivateKeyFromMem(data, pwd) +} + +func WritePrivateKeytoMem(key *PrivateKey, pwd []byte) ([]byte, error) { + var block *pem.Block + + der, err := MarshalSm2PrivateKey(key, pwd) + if err != nil { + return nil, err + } + if pwd != nil { + block = &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Bytes: der, + } + } else { + block = &pem.Block{ + Type: "PRIVATE KEY", + Bytes: der, + } + } + return pem.EncodeToMemory(block), nil +} + +func WritePrivateKeytoPem(FileName string, key *PrivateKey, pwd []byte) (bool, error) { + var block *pem.Block + + der, err := MarshalSm2PrivateKey(key, pwd) + if err != nil { + return false, err + } + if pwd != nil { + block = &pem.Block{ + Type: "ENCRYPTED PRIVATE KEY", + Bytes: der, + } + } else { + block = &pem.Block{ + Type: "PRIVATE KEY", + Bytes: der, + } + } + file, err := os.Create(FileName) + if err != nil { + return false, err + } + defer file.Close() + err = pem.Encode(file, block) + if err != nil { + return false, err + } + return true, nil +} + +func ReadPublicKeyFromMem(data []byte, _ []byte) (*PublicKey, error) { + block, _ := pem.Decode(data) + if block == nil || block.Type != "PUBLIC KEY" { + return nil, errors.New("failed to decode public key") + } + pub, err := ParseSm2PublicKey(block.Bytes) + return pub, err +} + +func ReadPublicKeyFromPem(FileName string, pwd []byte) (*PublicKey, error) { + data, err := ioutil.ReadFile(FileName) + if err != nil { + return nil, err + } + return ReadPublicKeyFromMem(data, pwd) +} + +func WritePublicKeytoMem(key *PublicKey, _ []byte) ([]byte, error) { + der, err := MarshalSm2PublicKey(key) + if err != nil { + return nil, err + } + block := &pem.Block{ + Type: "PUBLIC KEY", + Bytes: der, + } + return pem.EncodeToMemory(block), nil +} + +func WritePublicKeytoPem(FileName string, key *PublicKey, _ []byte) (bool, error) { + der, err := MarshalSm2PublicKey(key) + if err != nil { + return false, err + } + block := &pem.Block{ + Type: "PUBLIC KEY", + Bytes: der, + } + file, err := os.Create(FileName) + defer file.Close() + if err != nil { + return false, err + } + err = pem.Encode(file, block) + if err != nil { + return false, err + } + return true, nil +} diff --git a/crypto/sm2/sm2.go b/crypto/sm2/sm2.go new file mode 100644 index 00000000..897c1091 --- /dev/null +++ b/crypto/sm2/sm2.go @@ -0,0 +1,524 @@ +/* +Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sm2 + +// reference to ecdsa +import ( + "bytes" + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/elliptic" + "crypto/rand" + "crypto/sha512" + "encoding/asn1" + "encoding/binary" + "errors" + "io" + "math/big" + + "github.com/bytom/crypto/sm3" +) + +const ( + aesIV = "IV for CTR" +) + +type PublicKey struct { + elliptic.Curve + X, Y *big.Int +} + +type PrivateKey struct { + PublicKey + D *big.Int +} + +type sm2Signature struct { + R, S *big.Int +} + +// The SM2's private key contains the public key +func (priv *PrivateKey) Public() crypto.PublicKey { + return &priv.PublicKey +} + +func SignDigitToSignData(r, s *big.Int) ([]byte, error) { + return asn1.Marshal(sm2Signature{r, s}) +} + +func SignDataToSignDigit(sign []byte) (*big.Int, *big.Int, error) { + var sm2Sign sm2Signature + + _, err := asn1.Unmarshal(sign, &sm2Sign) + if err != nil { + return nil, nil, err + } + return sm2Sign.R, sm2Sign.S, nil +} + +// sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s +func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) { + r, s, err := Sign(priv, msg) + if err != nil { + return nil, err + } + return asn1.Marshal(sm2Signature{r, s}) +} + +func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) { + return Decrypt(priv, data) +} + +func (pub *PublicKey) Verify(msg []byte, sign []byte) bool { + var sm2Sign sm2Signature + + _, err := asn1.Unmarshal(sign, &sm2Sign) + if err != nil { + return false + } + return Verify(pub, msg, sm2Sign.R, sm2Sign.S) +} + +func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) { + return Encrypt(pub, data) +} + +var one = new(big.Int).SetInt64(1) + +func intToBytes(x int) []byte { + var buf = make([]byte, 4) + + binary.BigEndian.PutUint32(buf, uint32(x)) + return buf +} + +func kdf(x, y []byte, length int) ([]byte, bool) { + var c []byte + + ct := 1 + h := sm3.New() + x = append(x, y...) + for i, j := 0, (length+31)/32; i < j; i++ { + h.Reset() + h.Write(x) + h.Write(intToBytes(ct)) + hash := h.Sum(nil) + if i+1 == j && length%32 != 0 { + c = append(c, hash[:length%32]...) + } else { + c = append(c, hash...) + } + ct++ + } + for i := 0; i < length; i++ { + if c[i] != 0 { + return c, true + } + } + return c, false +} + +func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) { + params := c.Params() + b := make([]byte, params.BitSize/8+8) + _, err = io.ReadFull(rand, b) + if err != nil { + return + } + k = new(big.Int).SetBytes(b) + n := new(big.Int).Sub(params.N, one) + k.Mod(k, n) + k.Add(k, one) + return +} + +func GenerateKey() (*PrivateKey, error) { + c := P256Sm2() + k, err := randFieldElement(c, rand.Reader) + if err != nil { + return nil, err + } + priv := new(PrivateKey) + priv.PublicKey.Curve = c + priv.D = k + priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes()) + return priv, nil +} + +var errZeroParam = errors.New("zero parameter") + +func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) { + entropylen := (priv.Curve.Params().BitSize + 7) / 16 + if entropylen > 32 { + entropylen = 32 + } + entropy := make([]byte, entropylen) + _, err = io.ReadFull(rand.Reader, entropy) + if err != nil { + return + } + + // Initialize an SHA-512 hash context; digest ... + md := sha512.New() + md.Write(priv.D.Bytes()) // the private key, + md.Write(entropy) // the entropy, + md.Write(hash) // and the input hash; + key := md.Sum(nil)[:32] // and compute ChopMD-256(SHA-512), + // which is an indifferentiable MAC. + + // Create an AES-CTR instance to use as a CSPRNG. + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, err + } + + // Create a CSPRNG that xors a stream of zeros with + // the output of the AES-CTR instance. + csprng := cipher.StreamReader{ + R: zeroReader, + S: cipher.NewCTR(block, []byte(aesIV)), + } + + // See [NSA] 3.4.1 + c := priv.PublicKey.Curve + N := c.Params().N + if N.Sign() == 0 { + return nil, nil, errZeroParam + } + var k *big.Int + e := new(big.Int).SetBytes(hash) + for { // 调整算法细节以实现SM2 + for { + k, err = randFieldElement(c, csprng) + if err != nil { + r = nil + return + } + r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) + r.Add(r, e) + r.Mod(r, N) + if r.Sign() != 0 { + break + } + if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 { + break + } + } + rD := new(big.Int).Mul(priv.D, r) + s = new(big.Int).Sub(k, rD) + d1 := new(big.Int).Add(priv.D, one) + d1Inv := new(big.Int).ModInverse(d1, N) + s.Mul(s, d1Inv) + s.Mod(s, N) + if s.Sign() != 0 { + break + } + } + return +} + +func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool { + c := pub.Curve + N := c.Params().N + + if r.Sign() <= 0 || s.Sign() <= 0 { + return false + } + if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { + return false + } + + // 调整算法细节以实现SM2 + t := new(big.Int).Add(r, s) + t.Mod(t, N) + if t.Sign() == 0 { + return false + } + + var x *big.Int + x1, y1 := c.ScalarBaseMult(s.Bytes()) + x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) + x, _ = c.Add(x1, y1, x2, y2) + + e := new(big.Int).SetBytes(hash) + x.Add(x, e) + x.Mod(x, N) + return x.Cmp(r) == 0 +} + +func Sm2Sign(priv *PrivateKey, msg, uid []byte) (r, s *big.Int, err error) { + za, err := ZA(&priv.PublicKey, uid) + if err != nil { + return nil, nil, err + } + e, err := msgHash(za, msg) + if err != nil { + return nil, nil, err + } + c := priv.PublicKey.Curve + N := c.Params().N + if N.Sign() == 0 { + return nil, nil, errZeroParam + } + var k *big.Int + for { // 调整算法细节以实现SM2 + for { + k, err = randFieldElement(c, rand.Reader) + if err != nil { + r = nil + return + } + r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) + r.Add(r, e) + r.Mod(r, N) + if r.Sign() != 0 { + break + } + if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 { + break + } + } + rD := new(big.Int).Mul(priv.D, r) + s = new(big.Int).Sub(k, rD) + d1 := new(big.Int).Add(priv.D, one) + d1Inv := new(big.Int).ModInverse(d1, N) + s.Mul(s, d1Inv) + s.Mod(s, N) + if s.Sign() != 0 { + break + } + } + return +} + +func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool { + c := pub.Curve + N := c.Params().N + one := new(big.Int).SetInt64(1) + if r.Cmp(one) < 0 || s.Cmp(one) < 0 { + return false + } + if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { + return false + } + za, err := ZA(pub, uid) + if err != nil { + return false + } + e, err := msgHash(za, msg) + if err != nil { + return false + } + t := new(big.Int).Add(r, s) + t.Mod(t, N) + if t.Sign() == 0 { + return false + } + var x *big.Int + x1, y1 := c.ScalarBaseMult(s.Bytes()) + x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) + x, _ = c.Add(x1, y1, x2, y2) + + x.Add(x, e) + x.Mod(x, N) + return x.Cmp(r) == 0 +} + +func msgHash(za, msg []byte) (*big.Int, error) { + e := sm3.New() + e.Write(za) + e.Write(msg) + return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil +} + +// ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA) +func ZA(pub *PublicKey, uid []byte) ([]byte, error) { + za := sm3.New() + uidLen := len(uid) + if uidLen >= 8192 { + return []byte{}, errors.New("SM2: uid too large") + } + Entla := uint16(8 * uidLen) + za.Write([]byte{byte((Entla >> 8) & 0xFF)}) + za.Write([]byte{byte(Entla & 0xFF)}) + za.Write(uid) + za.Write(sm2P256ToBig(&sm2P256.a).Bytes()) + za.Write(sm2P256.B.Bytes()) + za.Write(sm2P256.Gx.Bytes()) + za.Write(sm2P256.Gy.Bytes()) + + xBuf := pub.X.Bytes() + yBuf := pub.Y.Bytes() + if n := len(xBuf); n < 32 { + xBuf = append(zeroByteSlice[:32-n], xBuf...) + } + za.Write(xBuf) + za.Write(yBuf) + return za.Sum(nil)[:32], nil +} + +// 32byte +var zeroByteSlice = []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, +} + +/* + * sm2密文结构如下: + * x + * y + * hash + * CipherText + */ +func Encrypt(pub *PublicKey, data []byte) ([]byte, error) { + length := len(data) + for { + c := []byte{} + curve := pub.Curve + k, err := randFieldElement(curve, rand.Reader) + if err != nil { + return nil, err + } + x1, y1 := curve.ScalarBaseMult(k.Bytes()) + x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) + x1Buf := x1.Bytes() + y1Buf := y1.Bytes() + x2Buf := x2.Bytes() + y2Buf := y2.Bytes() + if n := len(x1Buf); n < 32 { + x1Buf = append(zeroByteSlice[:32-n], x1Buf...) + } + if n := len(y1Buf); n < 32 { + y1Buf = append(zeroByteSlice[:32-n], y1Buf...) + } + if n := len(x2Buf); n < 32 { + x2Buf = append(zeroByteSlice[:32-n], x2Buf...) + } + if n := len(y2Buf); n < 32 { + y2Buf = append(zeroByteSlice[:32-n], y2Buf...) + } + c = append(c, x1Buf...) // x分量 + c = append(c, y1Buf...) // y分量 + tm := []byte{} + tm = append(tm, x2Buf...) + tm = append(tm, data...) + tm = append(tm, y2Buf...) + h := sm3.Sm3Sum(tm) + c = append(c, h...) + ct, ok := kdf(x2Buf, y2Buf, length) // 密文 + if !ok { + continue + } + c = append(c, ct...) + for i := 0; i < length; i++ { + c[96+i] ^= data[i] + } + return append([]byte{0x04}, c...), nil + } +} + +func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) { + data = data[1:] + length := len(data) - 96 + curve := priv.Curve + x := new(big.Int).SetBytes(data[:32]) + y := new(big.Int).SetBytes(data[32:64]) + x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes()) + x2Buf := x2.Bytes() + y2Buf := y2.Bytes() + if n := len(x2Buf); n < 32 { + x2Buf = append(zeroByteSlice[:32-n], x2Buf...) + } + if n := len(y2Buf); n < 32 { + y2Buf = append(zeroByteSlice[:32-n], y2Buf...) + } + c, ok := kdf(x2Buf, y2Buf, length) + if !ok { + return nil, errors.New("Decrypt: failed to decrypt") + } + for i := 0; i < length; i++ { + c[i] ^= data[i+96] + } + tm := []byte{} + tm = append(tm, x2Buf...) + tm = append(tm, c...) + tm = append(tm, y2Buf...) + h := sm3.Sm3Sum(tm) + if bytes.Compare(h, data[64:96]) != 0 { + return c, errors.New("Decrypt: failed to decrypt") + } + return c, nil +} + +type zr struct { + io.Reader +} + +func (z *zr) Read(dst []byte) (n int, err error) { + for i := range dst { + dst[i] = 0 + } + return len(dst), nil +} + +var zeroReader = &zr{} + +func getLastBit(a *big.Int) uint { + return a.Bit(0) +} + +func Compress(a *PublicKey) []byte { + buf := []byte{} + yp := getLastBit(a.Y) + buf = append(buf, a.X.Bytes()...) + if n := len(a.X.Bytes()); n < 32 { + buf = append(zeroByteSlice[:(32-n)], buf...) + } + buf = append([]byte{byte(yp)}, buf...) + return buf +} + +func Decompress(a []byte) *PublicKey { + var aa, xx, xx3 sm2P256FieldElement + + P256Sm2() + x := new(big.Int).SetBytes(a[1:]) + curve := sm2P256 + sm2P256FromBig(&xx, x) + sm2P256Square(&xx3, &xx) // x3 = x ^ 2 + sm2P256Mul(&xx3, &xx3, &xx) // x3 = x ^ 2 * x + sm2P256Mul(&aa, &curve.a, &xx) // a = a * x + sm2P256Add(&xx3, &xx3, &aa) + sm2P256Add(&xx3, &xx3, &curve.b) + + y2 := sm2P256ToBig(&xx3) + y := new(big.Int).ModSqrt(y2, sm2P256.P) + if getLastBit(y) != uint(a[0]) { + y.Sub(sm2P256.P, y) + } + return &PublicKey{ + Curve: P256Sm2(), + X: x, + Y: y, + } +} diff --git a/crypto/sm2/sm2_test.go b/crypto/sm2/sm2_test.go new file mode 100644 index 00000000..0820118a --- /dev/null +++ b/crypto/sm2/sm2_test.go @@ -0,0 +1,219 @@ +/* +Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sm2 + +import ( + "crypto/rand" + "crypto/x509/pkix" + "encoding/asn1" + "fmt" + "io/ioutil" + "log" + "math/big" + "net" + "os" + "testing" + "time" +) + +func TestSm2(t *testing.T) { + priv, err := GenerateKey() // 生成密钥对 + if err != nil { + log.Fatal(err) + } + fmt.Printf("%v\n", priv.Curve.IsOnCurve(priv.X, priv.Y)) // 验证是否为sm2的曲线 + pub := &priv.PublicKey + msg := []byte("123456") + d0, err := pub.Encrypt(msg) + if err != nil { + fmt.Printf("Error: failed to encrypt %s: %v\n", msg, err) + return + } + fmt.Printf("Cipher text = %v\n", d0) + d1, err := priv.Decrypt(d0) + if err != nil { + fmt.Printf("Error: failed to decrypt: %v\n", err) + } + fmt.Printf("clear text = %s\n", d1) + ok, err := WritePrivateKeytoPem("priv.pem", priv, nil) // 生成密钥文件 + if ok != true { + log.Fatal(err) + } + pubKey, _ := priv.Public().(*PublicKey) + ok, err = WritePublicKeytoPem("pub.pem", pubKey, nil) // 生成公钥文件 + if ok != true { + log.Fatal(err) + } + msg = []byte("test") + err = ioutil.WriteFile("ifile", msg, os.FileMode(0644)) // 生成测试文件 + if err != nil { + log.Fatal(err) + } + privKey, err := ReadPrivateKeyFromPem("priv.pem", nil) // 读取密钥 + if err != nil { + log.Fatal(err) + } + pubKey, err = ReadPublicKeyFromPem("pub.pem", nil) // 读取公钥 + if err != nil { + log.Fatal(err) + } + msg, _ = ioutil.ReadFile("ifile") // 从文件读取数据 + sign, err := privKey.Sign(rand.Reader, msg, nil) // 签名 + if err != nil { + log.Fatal(err) + } + err = ioutil.WriteFile("ofile", sign, os.FileMode(0644)) + if err != nil { + log.Fatal(err) + } + signdata, _ := ioutil.ReadFile("ofile") + ok = privKey.Verify(msg, signdata) // 密钥验证 + if ok != true { + fmt.Printf("Verify error\n") + } else { + fmt.Printf("Verify ok\n") + } + ok = pubKey.Verify(msg, signdata) // 公钥验证 + if ok != true { + fmt.Printf("Verify error\n") + } else { + fmt.Printf("Verify ok\n") + } + templateReq := CertificateRequest{ + Subject: pkix.Name{ + CommonName: "test.example.com", + Organization: []string{"Test"}, + }, + // SignatureAlgorithm: ECDSAWithSHA256, + SignatureAlgorithm: SM2WithSM3, + } + _, err = CreateCertificateRequestToPem("req.pem", &templateReq, privKey) + if err != nil { + log.Fatal(err) + } + req, err := ReadCertificateRequestFromPem("req.pem") + if err != nil { + log.Fatal(err) + } + err = req.CheckSignature() + if err != nil { + log.Fatal(err) + } else { + fmt.Printf("CheckSignature ok\n") + } + testExtKeyUsage := []ExtKeyUsage{ExtKeyUsageClientAuth, ExtKeyUsageServerAuth} + testUnknownExtKeyUsage := []asn1.ObjectIdentifier{[]int{1, 2, 3}, []int{2, 59, 1}} + extraExtensionData := []byte("extra extension") + commonName := "test.example.com" + template := Certificate{ + // SerialNumber is negative to ensure that negative + // values are parsed. This is due to the prevalence of + // buggy code that produces certificates with negative + // serial numbers. + SerialNumber: big.NewInt(-1), + Subject: pkix.Name{ + CommonName: commonName, + Organization: []string{"TEST"}, + Country: []string{"China"}, + ExtraNames: []pkix.AttributeTypeAndValue{ + { + Type: []int{2, 5, 4, 42}, + Value: "Gopher", + }, + // This should override the Country, above. + { + Type: []int{2, 5, 4, 6}, + Value: "NL", + }, + }, + }, + NotBefore: time.Unix(1000, 0), + NotAfter: time.Unix(100000, 0), + + // SignatureAlgorithm: ECDSAWithSHA256, + SignatureAlgorithm: SM2WithSM3, + + SubjectKeyId: []byte{1, 2, 3, 4}, + KeyUsage: KeyUsageCertSign, + + ExtKeyUsage: testExtKeyUsage, + UnknownExtKeyUsage: testUnknownExtKeyUsage, + + BasicConstraintsValid: true, + IsCA: true, + + OCSPServer: []string{"http://ocsp.example.com"}, + IssuingCertificateURL: []string{"http://crt.example.com/ca1.crt"}, + + DNSNames: []string{"test.example.com"}, + EmailAddresses: []string{"gopher@golang.org"}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1).To4(), net.ParseIP("2001:4860:0:2001::68")}, + + PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}}, + PermittedDNSDomains: []string{".example.com", "example.com"}, + + CRLDistributionPoints: []string{"http://crl1.example.com/ca1.crl", "http://crl2.example.com/ca1.crl"}, + + ExtraExtensions: []pkix.Extension{ + { + Id: []int{1, 2, 3, 4}, + Value: extraExtensionData, + }, + // This extension should override the SubjectKeyId, above. + { + Id: oidExtensionSubjectKeyId, + Critical: false, + Value: []byte{0x04, 0x04, 4, 3, 2, 1}, + }, + }, + } + pubKey, _ = priv.Public().(*PublicKey) + ok, _ = CreateCertificateToPem("cert.pem", &template, &template, pubKey, privKey) + if ok != true { + fmt.Printf("failed to create cert file\n") + } + cert, err := ReadCertificateFromPem("cert.pem") + if err != nil { + fmt.Printf("failed to read cert file") + } + err = cert.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature) + if err != nil { + log.Fatal(err) + } else { + fmt.Printf("CheckSignature ok\n") + } +} + +func BenchmarkSM2(t *testing.B) { + t.ReportAllocs() + for i := 0; i < t.N; i++ { + priv, err := GenerateKey() // 生成密钥对 + if err != nil { + log.Fatal(err) + } + msg := []byte("test") + sign, err := priv.Sign(rand.Reader, msg, nil) // 签名 + if err != nil { + log.Fatal(err) + } + ok := priv.Verify(msg, sign) // 密钥验证 + if ok != true { + fmt.Printf("Verify error\n") + } else { + fmt.Printf("Verify ok\n") + } + } +} diff --git a/crypto/sm2/verify.go b/crypto/sm2/verify.go new file mode 100644 index 00000000..e822b22b --- /dev/null +++ b/crypto/sm2/verify.go @@ -0,0 +1,568 @@ +/* +Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sm2 + +import ( + "bytes" + "errors" + "fmt" + "net" + "runtime" + "strings" + "time" + "unicode/utf8" +) + +type InvalidReason int + +const ( + // NotAuthorizedToSign results when a certificate is signed by another + // which isn't marked as a CA certificate. + NotAuthorizedToSign InvalidReason = iota + // Expired results when a certificate has expired, based on the time + // given in the VerifyOptions. + Expired + // CANotAuthorizedForThisName results when an intermediate or root + // certificate has a name constraint which doesn't include the name + // being checked. + CANotAuthorizedForThisName + // TooManyIntermediates results when a path length constraint is + // violated. + TooManyIntermediates + // IncompatibleUsage results when the certificate's key usage indicates + // that it may only be used for a different purpose. + IncompatibleUsage + // NameMismatch results when the subject name of a parent certificate + // does not match the issuer name in the child. + NameMismatch +) + +// CertificateInvalidError results when an odd error occurs. Users of this +// library probably want to handle all these errors uniformly. +type CertificateInvalidError struct { + Cert *Certificate + Reason InvalidReason +} + +func (e CertificateInvalidError) Error() string { + switch e.Reason { + case NotAuthorizedToSign: + return "x509: certificate is not authorized to sign other certificates" + case Expired: + return "x509: certificate has expired or is not yet valid" + case CANotAuthorizedForThisName: + return "x509: a root or intermediate certificate is not authorized to sign in this domain" + case TooManyIntermediates: + return "x509: too many intermediates for path length constraint" + case IncompatibleUsage: + return "x509: certificate specifies an incompatible key usage" + case NameMismatch: + return "x509: issuer name does not match subject from issuing certificate" + } + return "x509: unknown error" +} + +// HostnameError results when the set of authorized names doesn't match the +// requested name. +type HostnameError struct { + Certificate *Certificate + Host string +} + +func (h HostnameError) Error() string { + c := h.Certificate + + var valid string + if ip := net.ParseIP(h.Host); ip != nil { + // Trying to validate an IP + if len(c.IPAddresses) == 0 { + return "x509: cannot validate certificate for " + h.Host + " because it doesn't contain any IP SANs" + } + for _, san := range c.IPAddresses { + if len(valid) > 0 { + valid += ", " + } + valid += san.String() + } + } else { + if len(c.DNSNames) > 0 { + valid = strings.Join(c.DNSNames, ", ") + } else { + valid = c.Subject.CommonName + } + } + + if len(valid) == 0 { + return "x509: certificate is not valid for any names, but wanted to match " + h.Host + } + return "x509: certificate is valid for " + valid + ", not " + h.Host +} + +// UnknownAuthorityError results when the certificate issuer is unknown +type UnknownAuthorityError struct { + Cert *Certificate + // hintErr contains an error that may be helpful in determining why an + // authority wasn't found. + hintErr error + // hintCert contains a possible authority certificate that was rejected + // because of the error in hintErr. + hintCert *Certificate +} + +func (e UnknownAuthorityError) Error() string { + s := "x509: certificate signed by unknown authority" + if e.hintErr != nil { + certName := e.hintCert.Subject.CommonName + if len(certName) == 0 { + if len(e.hintCert.Subject.Organization) > 0 { + certName = e.hintCert.Subject.Organization[0] + } else { + certName = "serial:" + e.hintCert.SerialNumber.String() + } + } + s += fmt.Sprintf(" (possibly because of %q while trying to verify candidate authority certificate %q)", e.hintErr, certName) + } + return s +} + +// SystemRootsError results when we fail to load the system root certificates. +type SystemRootsError struct { + Err error +} + +func (se SystemRootsError) Error() string { + msg := "x509: failed to load system roots and no roots provided" + if se.Err != nil { + return msg + "; " + se.Err.Error() + } + return msg +} + +// errNotParsed is returned when a certificate without ASN.1 contents is +// verified. Platform-specific verification needs the ASN.1 contents. +var errNotParsed = errors.New("x509: missing ASN.1 contents; use ParseCertificate") + +// VerifyOptions contains parameters for Certificate.Verify. It's a structure +// because other PKIX verification APIs have ended up needing many options. +type VerifyOptions struct { + DNSName string + Intermediates *CertPool + Roots *CertPool // if nil, the system roots are used + CurrentTime time.Time // if zero, the current time is used + // KeyUsage specifies which Extended Key Usage values are acceptable. + // An empty list means ExtKeyUsageServerAuth. Key usage is considered a + // constraint down the chain which mirrors Windows CryptoAPI behavior, + // but not the spec. To accept any key usage, include ExtKeyUsageAny. + KeyUsages []ExtKeyUsage +} + +const ( + leafCertificate = iota + intermediateCertificate + rootCertificate +) + +func matchNameConstraint(domain, constraint string) bool { + // The meaning of zero length constraints is not specified, but this + // code follows NSS and accepts them as valid for everything. + if len(constraint) == 0 { + return true + } + + if len(domain) < len(constraint) { + return false + } + + prefixLen := len(domain) - len(constraint) + if !strings.EqualFold(domain[prefixLen:], constraint) { + return false + } + + if prefixLen == 0 { + return true + } + + isSubdomain := domain[prefixLen-1] == '.' + constraintHasLeadingDot := constraint[0] == '.' + return isSubdomain != constraintHasLeadingDot +} + +// isValid performs validity checks on the c. +func (c *Certificate) isValid(certType int, currentChain []*Certificate, opts *VerifyOptions) error { + if len(currentChain) > 0 { + child := currentChain[len(currentChain)-1] + if !bytes.Equal(child.RawIssuer, c.RawSubject) { + return CertificateInvalidError{c, NameMismatch} + } + } + now := opts.CurrentTime + if now.IsZero() { + now = time.Now() + } + if now.Before(c.NotBefore) || now.After(c.NotAfter) { + return CertificateInvalidError{c, Expired} + } + if len(c.PermittedDNSDomains) > 0 { + ok := false + for _, constraint := range c.PermittedDNSDomains { + ok = matchNameConstraint(opts.DNSName, constraint) + if ok { + break + } + } + + if !ok { + return CertificateInvalidError{c, CANotAuthorizedForThisName} + } + } + + // KeyUsage status flags are ignored. From Engineering Security, Peter + // Gutmann: A European government CA marked its signing certificates as + // being valid for encryption only, but no-one noticed. Another + // European CA marked its signature keys as not being valid for + // signatures. A different CA marked its own trusted root certificate + // as being invalid for certificate signing. Another national CA + // distributed a certificate to be used to encrypt data for the + // country’s tax authority that was marked as only being usable for + // digital signatures but not for encryption. Yet another CA reversed + // the order of the bit flags in the keyUsage due to confusion over + // encoding endianness, essentially setting a random keyUsage in + // certificates that it issued. Another CA created a self-invalidating + // certificate by adding a certificate policy statement stipulating + // that the certificate had to be used strictly as specified in the + // keyUsage, and a keyUsage containing a flag indicating that the RSA + // encryption key could only be used for Diffie-Hellman key agreement. + + if certType == intermediateCertificate && (!c.BasicConstraintsValid || !c.IsCA) { + return CertificateInvalidError{c, NotAuthorizedToSign} + } + + if c.BasicConstraintsValid && c.MaxPathLen >= 0 { + numIntermediates := len(currentChain) - 1 + if numIntermediates > c.MaxPathLen { + return CertificateInvalidError{c, TooManyIntermediates} + } + } + + return nil +} + +// Verify attempts to verify c by building one or more chains from c to a +// certificate in opts.Roots, using certificates in opts.Intermediates if +// needed. If successful, it returns one or more chains where the first +// element of the chain is c and the last element is from opts.Roots. +// +// If opts.Roots is nil and system roots are unavailable the returned error +// will be of type SystemRootsError. +// +// WARNING: this doesn't do any revocation checking. +func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) { + // Platform-specific verification needs the ASN.1 contents so + // this makes the behavior consistent across platforms. + if len(c.Raw) == 0 { + return nil, errNotParsed + } + if opts.Intermediates != nil { + for _, intermediate := range opts.Intermediates.certs { + if len(intermediate.Raw) == 0 { + return nil, errNotParsed + } + } + } + + // Use Windows's own verification and chain building. + if opts.Roots == nil && runtime.GOOS == "windows" { + return c.systemVerify(&opts) + } + + if len(c.UnhandledCriticalExtensions) > 0 { + return nil, UnhandledCriticalExtension{} + } + + if opts.Roots == nil { + opts.Roots = systemRootsPool() + if opts.Roots == nil { + return nil, SystemRootsError{systemRootsErr} + } + } + + err = c.isValid(leafCertificate, nil, &opts) + if err != nil { + return + } + + if len(opts.DNSName) > 0 { + err = c.VerifyHostname(opts.DNSName) + if err != nil { + return + } + } + + var candidateChains [][]*Certificate + if opts.Roots.contains(c) { + candidateChains = append(candidateChains, []*Certificate{c}) + } else { + if candidateChains, err = c.buildChains(make(map[int][][]*Certificate), []*Certificate{c}, &opts); err != nil { + return nil, err + } + } + + keyUsages := opts.KeyUsages + if len(keyUsages) == 0 { + keyUsages = []ExtKeyUsage{ExtKeyUsageServerAuth} + } + + // If any key usage is acceptable then we're done. + for _, usage := range keyUsages { + if usage == ExtKeyUsageAny { + chains = candidateChains + return + } + } + + for _, candidate := range candidateChains { + if checkChainForKeyUsage(candidate, keyUsages) { + chains = append(chains, candidate) + } + } + + if len(chains) == 0 { + err = CertificateInvalidError{c, IncompatibleUsage} + } + + return +} + +func appendToFreshChain(chain []*Certificate, cert *Certificate) []*Certificate { + n := make([]*Certificate, len(chain)+1) + copy(n, chain) + n[len(chain)] = cert + return n +} + +func (c *Certificate) buildChains(cache map[int][][]*Certificate, currentChain []*Certificate, opts *VerifyOptions) (chains [][]*Certificate, err error) { + possibleRoots, failedRoot, rootErr := opts.Roots.findVerifiedParents(c) +nextRoot: + for _, rootNum := range possibleRoots { + root := opts.Roots.certs[rootNum] + + for _, cert := range currentChain { + if cert.Equal(root) { + continue nextRoot + } + } + + err = root.isValid(rootCertificate, currentChain, opts) + if err != nil { + continue + } + chains = append(chains, appendToFreshChain(currentChain, root)) + } + + possibleIntermediates, failedIntermediate, intermediateErr := opts.Intermediates.findVerifiedParents(c) +nextIntermediate: + for _, intermediateNum := range possibleIntermediates { + intermediate := opts.Intermediates.certs[intermediateNum] + for _, cert := range currentChain { + if cert.Equal(intermediate) { + continue nextIntermediate + } + } + err = intermediate.isValid(intermediateCertificate, currentChain, opts) + if err != nil { + continue + } + var childChains [][]*Certificate + childChains, ok := cache[intermediateNum] + if !ok { + childChains, err = intermediate.buildChains(cache, appendToFreshChain(currentChain, intermediate), opts) + cache[intermediateNum] = childChains + } + chains = append(chains, childChains...) + } + + if len(chains) > 0 { + err = nil + } + + if len(chains) == 0 && err == nil { + hintErr := rootErr + hintCert := failedRoot + if hintErr == nil { + hintErr = intermediateErr + hintCert = failedIntermediate + } + err = UnknownAuthorityError{c, hintErr, hintCert} + } + + return +} + +func matchHostnames(pattern, host string) bool { + host = strings.TrimSuffix(host, ".") + pattern = strings.TrimSuffix(pattern, ".") + + if len(pattern) == 0 || len(host) == 0 { + return false + } + + patternParts := strings.Split(pattern, ".") + hostParts := strings.Split(host, ".") + + if len(patternParts) != len(hostParts) { + return false + } + + for i, patternPart := range patternParts { + if i == 0 && patternPart == "*" { + continue + } + if patternPart != hostParts[i] { + return false + } + } + + return true +} + +// toLowerCaseASCII returns a lower-case version of in. See RFC 6125 6.4.1. We use +// an explicitly ASCII function to avoid any sharp corners resulting from +// performing Unicode operations on DNS labels. +func toLowerCaseASCII(in string) string { + // If the string is already lower-case then there's nothing to do. + isAlreadyLowerCase := true + for _, c := range in { + if c == utf8.RuneError { + // If we get a UTF-8 error then there might be + // upper-case ASCII bytes in the invalid sequence. + isAlreadyLowerCase = false + break + } + if 'A' <= c && c <= 'Z' { + isAlreadyLowerCase = false + break + } + } + + if isAlreadyLowerCase { + return in + } + + out := []byte(in) + for i, c := range out { + if 'A' <= c && c <= 'Z' { + out[i] += 'a' - 'A' + } + } + return string(out) +} + +// VerifyHostname returns nil if c is a valid certificate for the named host. +// Otherwise it returns an error describing the mismatch. +func (c *Certificate) VerifyHostname(h string) error { + // IP addresses may be written in [ ]. + candidateIP := h + if len(h) >= 3 && h[0] == '[' && h[len(h)-1] == ']' { + candidateIP = h[1 : len(h)-1] + } + if ip := net.ParseIP(candidateIP); ip != nil { + // We only match IP addresses against IP SANs. + // https://tools.ietf.org/html/rfc6125#appendix-B.2 + for _, candidate := range c.IPAddresses { + if ip.Equal(candidate) { + return nil + } + } + return HostnameError{c, candidateIP} + } + + lowered := toLowerCaseASCII(h) + + if len(c.DNSNames) > 0 { + for _, match := range c.DNSNames { + if matchHostnames(toLowerCaseASCII(match), lowered) { + return nil + } + } + // If Subject Alt Name is given, we ignore the common name. + } else if matchHostnames(toLowerCaseASCII(c.Subject.CommonName), lowered) { + return nil + } + + return HostnameError{c, h} +} + +func checkChainForKeyUsage(chain []*Certificate, keyUsages []ExtKeyUsage) bool { + usages := make([]ExtKeyUsage, len(keyUsages)) + copy(usages, keyUsages) + + if len(chain) == 0 { + return false + } + + usagesRemaining := len(usages) + + // We walk down the list and cross out any usages that aren't supported + // by each certificate. If we cross out all the usages, then the chain + // is unacceptable. + +NextCert: + for i := len(chain) - 1; i >= 0; i-- { + cert := chain[i] + if len(cert.ExtKeyUsage) == 0 && len(cert.UnknownExtKeyUsage) == 0 { + // The certificate doesn't have any extended key usage specified. + continue + } + + for _, usage := range cert.ExtKeyUsage { + if usage == ExtKeyUsageAny { + // The certificate is explicitly good for any usage. + continue NextCert + } + } + + const invalidUsage ExtKeyUsage = -1 + + NextRequestedUsage: + for i, requestedUsage := range usages { + if requestedUsage == invalidUsage { + continue + } + + for _, usage := range cert.ExtKeyUsage { + if requestedUsage == usage { + continue NextRequestedUsage + } else if requestedUsage == ExtKeyUsageServerAuth && + (usage == ExtKeyUsageNetscapeServerGatedCrypto || + usage == ExtKeyUsageMicrosoftServerGatedCrypto) { + // In order to support COMODO + // certificate chains, we have to + // accept Netscape or Microsoft SGC + // usages as equal to ServerAuth. + continue NextRequestedUsage + } + } + + usages[i] = invalidUsage + usagesRemaining-- + if usagesRemaining == 0 { + return false + } + } + } + + return true +} diff --git a/crypto/sm2/x509.go b/crypto/sm2/x509.go new file mode 100644 index 00000000..855e3b13 --- /dev/null +++ b/crypto/sm2/x509.go @@ -0,0 +1,2527 @@ +/* +Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// crypto/x509 add sm2 support +package sm2 + +import ( + "bytes" + "crypto" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/md5" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/pem" + "errors" + "fmt" + "hash" + "io" + "io/ioutil" + "math/big" + "net" + "os" + "strconv" + "time" + + "golang.org/x/crypto/ripemd160" + "golang.org/x/crypto/sha3" + + "github.com/bytom/crypto/sm3" +) + +// pkixPublicKey reflects a PKIX public key structure. See SubjectPublicKeyInfo +// in RFC 3280. +type pkixPublicKey struct { + Algo pkix.AlgorithmIdentifier + BitString asn1.BitString +} + +// ParsePKIXPublicKey parses a DER encoded public key. These values are +// typically found in PEM blocks with "BEGIN PUBLIC KEY". +// +// Supported key types include RSA, DSA, and ECDSA. Unknown key +// types result in an error. +// +// On success, pub will be of type *rsa.PublicKey, *dsa.PublicKey, +// or *ecdsa.PublicKey. +func ParsePKIXPublicKey(derBytes []byte) (pub interface{}, err error) { + var pki publicKeyInfo + + if rest, err := asn1.Unmarshal(derBytes, &pki); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after ASN.1 of public-key") + } + algo := getPublicKeyAlgorithmFromOID(pki.Algorithm.Algorithm) + if algo == UnknownPublicKeyAlgorithm { + return nil, errors.New("x509: unknown public key algorithm") + } + return parsePublicKey(algo, &pki) +} + +func marshalPublicKey(pub interface{}) (publicKeyBytes []byte, publicKeyAlgorithm pkix.AlgorithmIdentifier, err error) { + switch pub := pub.(type) { + case *rsa.PublicKey: + publicKeyBytes, err = asn1.Marshal(rsaPublicKey{ + N: pub.N, + E: pub.E, + }) + if err != nil { + return nil, pkix.AlgorithmIdentifier{}, err + } + publicKeyAlgorithm.Algorithm = oidPublicKeyRSA + // This is a NULL parameters value which is required by + // https://tools.ietf.org/html/rfc3279#section-2.3.1. + publicKeyAlgorithm.Parameters = asn1.RawValue{ + Tag: 5, + } + case *ecdsa.PublicKey: + publicKeyBytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y) + oid, ok := oidFromNamedCurve(pub.Curve) + if !ok { + return nil, pkix.AlgorithmIdentifier{}, errors.New("x509: unsupported elliptic curve") + } + publicKeyAlgorithm.Algorithm = oidPublicKeyECDSA + var paramBytes []byte + paramBytes, err = asn1.Marshal(oid) + if err != nil { + return + } + publicKeyAlgorithm.Parameters.FullBytes = paramBytes + case *PublicKey: + publicKeyBytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y) + oid, ok := oidFromNamedCurve(pub.Curve) + if !ok { + return nil, pkix.AlgorithmIdentifier{}, errors.New("x509: unsupported SM2 curve") + } + publicKeyAlgorithm.Algorithm = oidPublicKeyECDSA + var paramBytes []byte + paramBytes, err = asn1.Marshal(oid) + if err != nil { + return + } + publicKeyAlgorithm.Parameters.FullBytes = paramBytes + default: + return nil, pkix.AlgorithmIdentifier{}, errors.New("x509: only RSA and ECDSA(SM2) public keys supported") + } + + return publicKeyBytes, publicKeyAlgorithm, nil +} + +// MarshalPKIXPublicKey serialises a public key to DER-encoded PKIX format. +func MarshalPKIXPublicKey(pub interface{}) ([]byte, error) { + var publicKeyBytes []byte + var publicKeyAlgorithm pkix.AlgorithmIdentifier + var err error + + if publicKeyBytes, publicKeyAlgorithm, err = marshalPublicKey(pub); err != nil { + return nil, err + } + + pkix := pkixPublicKey{ + Algo: publicKeyAlgorithm, + BitString: asn1.BitString{ + Bytes: publicKeyBytes, + BitLength: 8 * len(publicKeyBytes), + }, + } + + ret, _ := asn1.Marshal(pkix) + return ret, nil +} + +// These structures reflect the ASN.1 structure of X.509 certificates.: + +type certificate struct { + Raw asn1.RawContent + TBSCertificate tbsCertificate + SignatureAlgorithm pkix.AlgorithmIdentifier + SignatureValue asn1.BitString +} + +type tbsCertificate struct { + Raw asn1.RawContent + Version int `asn1:"optional,explicit,default:0,tag:0"` + SerialNumber *big.Int + SignatureAlgorithm pkix.AlgorithmIdentifier + Issuer asn1.RawValue + Validity validity + Subject asn1.RawValue + PublicKey publicKeyInfo + UniqueId asn1.BitString `asn1:"optional,tag:1"` + SubjectUniqueId asn1.BitString `asn1:"optional,tag:2"` + Extensions []pkix.Extension `asn1:"optional,explicit,tag:3"` +} + +type dsaAlgorithmParameters struct { + P, Q, G *big.Int +} + +type dsaSignature struct { + R, S *big.Int +} + +type ecdsaSignature dsaSignature + +type validity struct { + NotBefore, NotAfter time.Time +} + +type publicKeyInfo struct { + Raw asn1.RawContent + Algorithm pkix.AlgorithmIdentifier + PublicKey asn1.BitString +} + +// RFC 5280, 4.2.1.1 +type authKeyId struct { + Id []byte `asn1:"optional,tag:0"` +} + +type SignatureAlgorithm int + +type Hash uint + +func init() { + RegisterHash(MD4, nil) + RegisterHash(MD5, md5.New) + RegisterHash(SHA1, sha1.New) + RegisterHash(SHA224, sha256.New224) + RegisterHash(SHA256, sha256.New) + RegisterHash(SHA384, sha512.New384) + RegisterHash(SHA512, sha512.New) + RegisterHash(MD5SHA1, nil) + RegisterHash(RIPEMD160, ripemd160.New) + RegisterHash(SHA3_224, sha3.New224) + RegisterHash(SHA3_256, sha3.New256) + RegisterHash(SHA3_384, sha3.New384) + RegisterHash(SHA3_512, sha3.New512) + RegisterHash(SHA512_224, sha512.New512_224) + RegisterHash(SHA512_256, sha512.New512_256) + RegisterHash(SM3, sm3.New) +} + +// HashFunc simply returns the value of h so that Hash implements SignerOpts. +func (h Hash) HashFunc() crypto.Hash { + return crypto.Hash(h) +} + +const ( + MD4 Hash = 1 + iota // import golang.org/x/crypto/md4 + MD5 // import crypto/md5 + SHA1 // import crypto/sha1 + SHA224 // import crypto/sha256 + SHA256 // import crypto/sha256 + SHA384 // import crypto/sha512 + SHA512 // import crypto/sha512 + MD5SHA1 // no implementation; MD5+SHA1 used for TLS RSA + RIPEMD160 // import golang.org/x/crypto/ripemd160 + SHA3_224 // import golang.org/x/crypto/sha3 + SHA3_256 // import golang.org/x/crypto/sha3 + SHA3_384 // import golang.org/x/crypto/sha3 + SHA3_512 // import golang.org/x/crypto/sha3 + SHA512_224 // import crypto/sha512 + SHA512_256 // import crypto/sha512 + SM3 + maxHash +) + +var digestSizes = []uint8{ + MD4: 16, + MD5: 16, + SHA1: 20, + SHA224: 28, + SHA256: 32, + SHA384: 48, + SHA512: 64, + SHA512_224: 28, + SHA512_256: 32, + SHA3_224: 28, + SHA3_256: 32, + SHA3_384: 48, + SHA3_512: 64, + MD5SHA1: 36, + RIPEMD160: 20, + SM3: 32, +} + +// Size returns the length, in bytes, of a digest resulting from the given hash +// function. It doesn't require that the hash function in question be linked +// into the program. +func (h Hash) Size() int { + if h > 0 && h < maxHash { + return int(digestSizes[h]) + } + panic("crypto: Size of unknown hash function") +} + +var hashes = make([]func() hash.Hash, maxHash) + +// New returns a new hash.Hash calculating the given hash function. New panics +// if the hash function is not linked into the binary. +func (h Hash) New() hash.Hash { + if h > 0 && h < maxHash { + f := hashes[h] + if f != nil { + return f() + } + } + panic("crypto: requested hash function #" + strconv.Itoa(int(h)) + " is unavailable") +} + +// Available reports whether the given hash function is linked into the binary. +func (h Hash) Available() bool { + return h < maxHash && hashes[h] != nil +} + +// RegisterHash registers a function that returns a new instance of the given +// hash function. This is intended to be called from the init function in +// packages that implement hash functions. +func RegisterHash(h Hash, f func() hash.Hash) { + if h >= maxHash { + panic("crypto: RegisterHash of unknown hash function") + } + hashes[h] = f +} + +const ( + UnknownSignatureAlgorithm SignatureAlgorithm = iota + MD2WithRSA + MD5WithRSA + // SM3WithRSA reserve + SHA1WithRSA + SHA256WithRSA + SHA384WithRSA + SHA512WithRSA + DSAWithSHA1 + DSAWithSHA256 + ECDSAWithSHA1 + ECDSAWithSHA256 + ECDSAWithSHA384 + ECDSAWithSHA512 + SHA256WithRSAPSS + SHA384WithRSAPSS + SHA512WithRSAPSS + SM2WithSM3 + SM2WithSHA1 + SM2WithSHA256 +) + +func (algo SignatureAlgorithm) isRSAPSS() bool { + switch algo { + case SHA256WithRSAPSS, SHA384WithRSAPSS, SHA512WithRSAPSS: + return true + default: + return false + } +} + +var algoName = [...]string{ + MD2WithRSA: "MD2-RSA", + MD5WithRSA: "MD5-RSA", + SHA1WithRSA: "SHA1-RSA", + // SM3WithRSA: "SM3-RSA", reserve + SHA256WithRSA: "SHA256-RSA", + SHA384WithRSA: "SHA384-RSA", + SHA512WithRSA: "SHA512-RSA", + SHA256WithRSAPSS: "SHA256-RSAPSS", + SHA384WithRSAPSS: "SHA384-RSAPSS", + SHA512WithRSAPSS: "SHA512-RSAPSS", + DSAWithSHA1: "DSA-SHA1", + DSAWithSHA256: "DSA-SHA256", + ECDSAWithSHA1: "ECDSA-SHA1", + ECDSAWithSHA256: "ECDSA-SHA256", + ECDSAWithSHA384: "ECDSA-SHA384", + ECDSAWithSHA512: "ECDSA-SHA512", + SM2WithSM3: "SM2-SM3", + SM2WithSHA1: "SM2-SHA1", + SM2WithSHA256: "SM2-SHA256", +} + +func (algo SignatureAlgorithm) String() string { + if 0 < algo && int(algo) < len(algoName) { + return algoName[algo] + } + return strconv.Itoa(int(algo)) +} + +type PublicKeyAlgorithm int + +const ( + UnknownPublicKeyAlgorithm PublicKeyAlgorithm = iota + RSA + DSA + ECDSA +) + +// OIDs for signature algorithms +// +// pkcs-1 OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) rsadsi(113549) pkcs(1) 1 } +// +// +// RFC 3279 2.2.1 RSA Signature Algorithms +// +// md2WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 2 } +// +// md5WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 4 } +// +// sha-1WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 5 } +// +// dsaWithSha1 OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) x9-57(10040) x9cm(4) 3 } +// +// RFC 3279 2.2.3 ECDSA Signature Algorithm +// +// ecdsa-with-SHA1 OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) ansi-x962(10045) +// signatures(4) ecdsa-with-SHA1(1)} +// +// +// RFC 4055 5 PKCS #1 Version 1.5 +// +// sha256WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 11 } +// +// sha384WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 12 } +// +// sha512WithRSAEncryption OBJECT IDENTIFIER ::= { pkcs-1 13 } +// +// +// RFC 5758 3.1 DSA Signature Algorithms +// +// dsaWithSha256 OBJECT IDENTIFIER ::= { +// joint-iso-ccitt(2) country(16) us(840) organization(1) gov(101) +// csor(3) algorithms(4) id-dsa-with-sha2(3) 2} +// +// RFC 5758 3.2 ECDSA Signature Algorithm +// +// ecdsa-with-SHA256 OBJECT IDENTIFIER ::= { iso(1) member-body(2) +// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 2 } +// +// ecdsa-with-SHA384 OBJECT IDENTIFIER ::= { iso(1) member-body(2) +// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 3 } +// +// ecdsa-with-SHA512 OBJECT IDENTIFIER ::= { iso(1) member-body(2) +// us(840) ansi-X9-62(10045) signatures(4) ecdsa-with-SHA2(3) 4 } + +var ( + oidSignatureMD2WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 2} + oidSignatureMD5WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 4} + oidSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 5} + oidSignatureSHA256WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 11} + oidSignatureSHA384WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 12} + oidSignatureSHA512WithRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 13} + oidSignatureRSAPSS = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 10} + oidSignatureDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 3} + oidSignatureDSAWithSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 3, 2} + oidSignatureECDSAWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 1} + oidSignatureECDSAWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 2} + oidSignatureECDSAWithSHA384 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 3} + oidSignatureECDSAWithSHA512 = asn1.ObjectIdentifier{1, 2, 840, 10045, 4, 3, 4} + oidSignatureSM2WithSM3 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 501} + oidSignatureSM2WithSHA1 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 502} + oidSignatureSM2WithSHA256 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 503} + // oidSignatureSM3WithRSA = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 504} + + oidSM3 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 401, 1} + oidSHA256 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1} + oidSHA384 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 2} + oidSHA512 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 3} + + oidMGF1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 8} + + // oidISOSignatureSHA1WithRSA means the same as oidSignatureSHA1WithRSA + // but it's specified by ISO. Microsoft's makecert.exe has been known + // to produce certificates with this OID. + oidISOSignatureSHA1WithRSA = asn1.ObjectIdentifier{1, 3, 14, 3, 2, 29} +) + +var signatureAlgorithmDetails = []struct { + algo SignatureAlgorithm + oid asn1.ObjectIdentifier + pubKeyAlgo PublicKeyAlgorithm + hash Hash +}{ + {MD2WithRSA, oidSignatureMD2WithRSA, RSA, Hash(0) /* no value for MD2 */}, + {MD5WithRSA, oidSignatureMD5WithRSA, RSA, MD5}, + {SHA1WithRSA, oidSignatureSHA1WithRSA, RSA, SHA1}, + {SHA1WithRSA, oidISOSignatureSHA1WithRSA, RSA, SHA1}, + {SHA256WithRSA, oidSignatureSHA256WithRSA, RSA, SHA256}, + {SHA384WithRSA, oidSignatureSHA384WithRSA, RSA, SHA384}, + {SHA512WithRSA, oidSignatureSHA512WithRSA, RSA, SHA512}, + {SHA256WithRSAPSS, oidSignatureRSAPSS, RSA, SHA256}, + {SHA384WithRSAPSS, oidSignatureRSAPSS, RSA, SHA384}, + {SHA512WithRSAPSS, oidSignatureRSAPSS, RSA, SHA512}, + {DSAWithSHA1, oidSignatureDSAWithSHA1, DSA, SHA1}, + {DSAWithSHA256, oidSignatureDSAWithSHA256, DSA, SHA256}, + {ECDSAWithSHA1, oidSignatureECDSAWithSHA1, ECDSA, SHA1}, + {ECDSAWithSHA256, oidSignatureECDSAWithSHA256, ECDSA, SHA256}, + {ECDSAWithSHA384, oidSignatureECDSAWithSHA384, ECDSA, SHA384}, + {ECDSAWithSHA512, oidSignatureECDSAWithSHA512, ECDSA, SHA512}, + {SM2WithSM3, oidSignatureSM2WithSM3, ECDSA, SM3}, + {SM2WithSHA1, oidSignatureSM2WithSHA1, ECDSA, SHA1}, + {SM2WithSHA256, oidSignatureSM2WithSHA256, ECDSA, SHA256}, + // {SM3WithRSA, oidSignatureSM3WithRSA, RSA, SM3}, +} + +// pssParameters reflects the parameters in an AlgorithmIdentifier that +// specifies RSA PSS. See https://tools.ietf.org/html/rfc3447#appendix-A.2.3 +type pssParameters struct { + // The following three fields are not marked as + // optional because the default values specify SHA-1, + // which is no longer suitable for use in signatures. + Hash pkix.AlgorithmIdentifier `asn1:"explicit,tag:0"` + MGF pkix.AlgorithmIdentifier `asn1:"explicit,tag:1"` + SaltLength int `asn1:"explicit,tag:2"` + TrailerField int `asn1:"optional,explicit,tag:3,default:1"` +} + +// rsaPSSParameters returns an asn1.RawValue suitable for use as the Parameters +// in an AlgorithmIdentifier that specifies RSA PSS. +func rsaPSSParameters(hashFunc Hash) asn1.RawValue { + var hashOID asn1.ObjectIdentifier + + switch hashFunc { + case SHA256: + hashOID = oidSHA256 + case SHA384: + hashOID = oidSHA384 + case SHA512: + hashOID = oidSHA512 + } + + params := pssParameters{ + Hash: pkix.AlgorithmIdentifier{ + Algorithm: hashOID, + Parameters: asn1.RawValue{ + Tag: 5, /* ASN.1 NULL */ + }, + }, + MGF: pkix.AlgorithmIdentifier{ + Algorithm: oidMGF1, + }, + SaltLength: hashFunc.Size(), + TrailerField: 1, + } + + mgf1Params := pkix.AlgorithmIdentifier{ + Algorithm: hashOID, + Parameters: asn1.RawValue{ + Tag: 5, /* ASN.1 NULL */ + }, + } + + var err error + params.MGF.Parameters.FullBytes, err = asn1.Marshal(mgf1Params) + if err != nil { + panic(err) + } + + serialized, err := asn1.Marshal(params) + if err != nil { + panic(err) + } + + return asn1.RawValue{FullBytes: serialized} +} + +func getSignatureAlgorithmFromAI(ai pkix.AlgorithmIdentifier) SignatureAlgorithm { + if !ai.Algorithm.Equal(oidSignatureRSAPSS) { + for _, details := range signatureAlgorithmDetails { + if ai.Algorithm.Equal(details.oid) { + return details.algo + } + } + return UnknownSignatureAlgorithm + } + + // RSA PSS is special because it encodes important parameters + // in the Parameters. + + var params pssParameters + if _, err := asn1.Unmarshal(ai.Parameters.FullBytes, ¶ms); err != nil { + return UnknownSignatureAlgorithm + } + + var mgf1HashFunc pkix.AlgorithmIdentifier + if _, err := asn1.Unmarshal(params.MGF.Parameters.FullBytes, &mgf1HashFunc); err != nil { + return UnknownSignatureAlgorithm + } + + // PSS is greatly overburdened with options. This code forces + // them into three buckets by requiring that the MGF1 hash + // function always match the message hash function (as + // recommended in + // https://tools.ietf.org/html/rfc3447#section-8.1), that the + // salt length matches the hash length, and that the trailer + // field has the default value. + asn1NULL := []byte{0x05, 0x00} + if !bytes.Equal(params.Hash.Parameters.FullBytes, asn1NULL) || + !params.MGF.Algorithm.Equal(oidMGF1) || + !mgf1HashFunc.Algorithm.Equal(params.Hash.Algorithm) || + !bytes.Equal(mgf1HashFunc.Parameters.FullBytes, asn1NULL) || + params.TrailerField != 1 { + return UnknownSignatureAlgorithm + } + + switch { + case params.Hash.Algorithm.Equal(oidSHA256) && params.SaltLength == 32: + return SHA256WithRSAPSS + case params.Hash.Algorithm.Equal(oidSHA384) && params.SaltLength == 48: + return SHA384WithRSAPSS + case params.Hash.Algorithm.Equal(oidSHA512) && params.SaltLength == 64: + return SHA512WithRSAPSS + } + + return UnknownSignatureAlgorithm +} + +// RFC 3279, 2.3 Public Key Algorithms +// +// pkcs-1 OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840) +// rsadsi(113549) pkcs(1) 1 } +// +// rsaEncryption OBJECT IDENTIFIER ::== { pkcs1-1 1 } +// +// id-dsa OBJECT IDENTIFIER ::== { iso(1) member-body(2) us(840) +// x9-57(10040) x9cm(4) 1 } +// +// RFC 5480, 2.1.1 Unrestricted Algorithm Identifier and Parameters +// +// id-ecPublicKey OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) ansi-X9-62(10045) keyType(2) 1 } +var ( + oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} + oidPublicKeyDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1} + oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} +) + +func getPublicKeyAlgorithmFromOID(oid asn1.ObjectIdentifier) PublicKeyAlgorithm { + switch { + case oid.Equal(oidPublicKeyRSA): + return RSA + case oid.Equal(oidPublicKeyDSA): + return DSA + case oid.Equal(oidPublicKeyECDSA): + return ECDSA + } + return UnknownPublicKeyAlgorithm +} + +// RFC 5480, 2.1.1.1. Named Curve +// +// secp224r1 OBJECT IDENTIFIER ::= { +// iso(1) identified-organization(3) certicom(132) curve(0) 33 } +// +// secp256r1 OBJECT IDENTIFIER ::= { +// iso(1) member-body(2) us(840) ansi-X9-62(10045) curves(3) +// prime(1) 7 } +// +// secp384r1 OBJECT IDENTIFIER ::= { +// iso(1) identified-organization(3) certicom(132) curve(0) 34 } +// +// secp521r1 OBJECT IDENTIFIER ::= { +// iso(1) identified-organization(3) certicom(132) curve(0) 35 } +// +// NB: secp256r1 is equivalent to prime256v1 +var ( + oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33} + oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7} + oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34} + oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35} + oidNamedCurveP256SM2 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301} // I get the SM2 ID through parsing the pem file generated by gmssl +) + +func namedCurveFromOID(oid asn1.ObjectIdentifier) elliptic.Curve { + switch { + case oid.Equal(oidNamedCurveP224): + return elliptic.P224() + case oid.Equal(oidNamedCurveP256): + return elliptic.P256() + case oid.Equal(oidNamedCurveP384): + return elliptic.P384() + case oid.Equal(oidNamedCurveP521): + return elliptic.P521() + case oid.Equal(oidNamedCurveP256SM2): + return P256Sm2() + } + return nil +} + +func oidFromNamedCurve(curve elliptic.Curve) (asn1.ObjectIdentifier, bool) { + switch curve { + case elliptic.P224(): + return oidNamedCurveP224, true + case elliptic.P256(): + return oidNamedCurveP256, true + case elliptic.P384(): + return oidNamedCurveP384, true + case elliptic.P521(): + return oidNamedCurveP521, true + case P256Sm2(): + return oidNamedCurveP256SM2, true + } + return nil, false +} + +// KeyUsage represents the set of actions that are valid for a given key. It's +// a bitmap of the KeyUsage* constants. +type KeyUsage int + +const ( + KeyUsageDigitalSignature KeyUsage = 1 << iota + KeyUsageContentCommitment + KeyUsageKeyEncipherment + KeyUsageDataEncipherment + KeyUsageKeyAgreement + KeyUsageCertSign + KeyUsageCRLSign + KeyUsageEncipherOnly + KeyUsageDecipherOnly +) + +// RFC 5280, 4.2.1.12 Extended Key Usage +// +// anyExtendedKeyUsage OBJECT IDENTIFIER ::= { id-ce-extKeyUsage 0 } +// +// id-kp OBJECT IDENTIFIER ::= { id-pkix 3 } +// +// id-kp-serverAuth OBJECT IDENTIFIER ::= { id-kp 1 } +// id-kp-clientAuth OBJECT IDENTIFIER ::= { id-kp 2 } +// id-kp-codeSigning OBJECT IDENTIFIER ::= { id-kp 3 } +// id-kp-emailProtection OBJECT IDENTIFIER ::= { id-kp 4 } +// id-kp-timeStamping OBJECT IDENTIFIER ::= { id-kp 8 } +// id-kp-OCSPSigning OBJECT IDENTIFIER ::= { id-kp 9 } +var ( + oidExtKeyUsageAny = asn1.ObjectIdentifier{2, 5, 29, 37, 0} + oidExtKeyUsageServerAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1} + oidExtKeyUsageClientAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2} + oidExtKeyUsageCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 3} + oidExtKeyUsageEmailProtection = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 4} + oidExtKeyUsageIPSECEndSystem = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 5} + oidExtKeyUsageIPSECTunnel = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 6} + oidExtKeyUsageIPSECUser = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 7} + oidExtKeyUsageTimeStamping = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 8} + oidExtKeyUsageOCSPSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 9} + oidExtKeyUsageMicrosoftServerGatedCrypto = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 10, 3, 3} + oidExtKeyUsageNetscapeServerGatedCrypto = asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 4, 1} +) + +// ExtKeyUsage represents an extended set of actions that are valid for a given key. +// Each of the ExtKeyUsage* constants define a unique action. +type ExtKeyUsage int + +const ( + ExtKeyUsageAny ExtKeyUsage = iota + ExtKeyUsageServerAuth + ExtKeyUsageClientAuth + ExtKeyUsageCodeSigning + ExtKeyUsageEmailProtection + ExtKeyUsageIPSECEndSystem + ExtKeyUsageIPSECTunnel + ExtKeyUsageIPSECUser + ExtKeyUsageTimeStamping + ExtKeyUsageOCSPSigning + ExtKeyUsageMicrosoftServerGatedCrypto + ExtKeyUsageNetscapeServerGatedCrypto +) + +// extKeyUsageOIDs contains the mapping between an ExtKeyUsage and its OID. +var extKeyUsageOIDs = []struct { + extKeyUsage ExtKeyUsage + oid asn1.ObjectIdentifier +}{ + {ExtKeyUsageAny, oidExtKeyUsageAny}, + {ExtKeyUsageServerAuth, oidExtKeyUsageServerAuth}, + {ExtKeyUsageClientAuth, oidExtKeyUsageClientAuth}, + {ExtKeyUsageCodeSigning, oidExtKeyUsageCodeSigning}, + {ExtKeyUsageEmailProtection, oidExtKeyUsageEmailProtection}, + {ExtKeyUsageIPSECEndSystem, oidExtKeyUsageIPSECEndSystem}, + {ExtKeyUsageIPSECTunnel, oidExtKeyUsageIPSECTunnel}, + {ExtKeyUsageIPSECUser, oidExtKeyUsageIPSECUser}, + {ExtKeyUsageTimeStamping, oidExtKeyUsageTimeStamping}, + {ExtKeyUsageOCSPSigning, oidExtKeyUsageOCSPSigning}, + {ExtKeyUsageMicrosoftServerGatedCrypto, oidExtKeyUsageMicrosoftServerGatedCrypto}, + {ExtKeyUsageNetscapeServerGatedCrypto, oidExtKeyUsageNetscapeServerGatedCrypto}, +} + +func extKeyUsageFromOID(oid asn1.ObjectIdentifier) (eku ExtKeyUsage, ok bool) { + for _, pair := range extKeyUsageOIDs { + if oid.Equal(pair.oid) { + return pair.extKeyUsage, true + } + } + return +} + +func oidFromExtKeyUsage(eku ExtKeyUsage) (oid asn1.ObjectIdentifier, ok bool) { + for _, pair := range extKeyUsageOIDs { + if eku == pair.extKeyUsage { + return pair.oid, true + } + } + return +} + +// A Certificate represents an X.509 certificate. +type Certificate struct { + Raw []byte // Complete ASN.1 DER content (certificate, signature algorithm and signature). + RawTBSCertificate []byte // Certificate part of raw ASN.1 DER content. + RawSubjectPublicKeyInfo []byte // DER encoded SubjectPublicKeyInfo. + RawSubject []byte // DER encoded Subject + RawIssuer []byte // DER encoded Issuer + + Signature []byte + SignatureAlgorithm SignatureAlgorithm + + PublicKeyAlgorithm PublicKeyAlgorithm + PublicKey interface{} + + Version int + SerialNumber *big.Int + Issuer pkix.Name + Subject pkix.Name + NotBefore, NotAfter time.Time // Validity bounds. + KeyUsage KeyUsage + + // Extensions contains raw X.509 extensions. When parsing certificates, + // this can be used to extract non-critical extensions that are not + // parsed by this package. When marshaling certificates, the Extensions + // field is ignored, see ExtraExtensions. + Extensions []pkix.Extension + + // ExtraExtensions contains extensions to be copied, raw, into any + // marshaled certificates. Values override any extensions that would + // otherwise be produced based on the other fields. The ExtraExtensions + // field is not populated when parsing certificates, see Extensions. + ExtraExtensions []pkix.Extension + + // UnhandledCriticalExtensions contains a list of extension IDs that + // were not (fully) processed when parsing. Verify will fail if this + // slice is non-empty, unless verification is delegated to an OS + // library which understands all the critical extensions. + // + // Users can access these extensions using Extensions and can remove + // elements from this slice if they believe that they have been + // handled. + UnhandledCriticalExtensions []asn1.ObjectIdentifier + + ExtKeyUsage []ExtKeyUsage // Sequence of extended key usages. + UnknownExtKeyUsage []asn1.ObjectIdentifier // Encountered extended key usages unknown to this package. + + BasicConstraintsValid bool // if true then the next two fields are valid. + IsCA bool + MaxPathLen int + // MaxPathLenZero indicates that BasicConstraintsValid==true and + // MaxPathLen==0 should be interpreted as an actual maximum path length + // of zero. Otherwise, that combination is interpreted as MaxPathLen + // not being set. + MaxPathLenZero bool + + SubjectKeyId []byte + AuthorityKeyId []byte + + // RFC 5280, 4.2.2.1 (Authority Information Access) + OCSPServer []string + IssuingCertificateURL []string + + // Subject Alternate Name values + DNSNames []string + EmailAddresses []string + IPAddresses []net.IP + + // Name constraints + PermittedDNSDomainsCritical bool // if true then the name constraints are marked critical. + PermittedDNSDomains []string + + // CRL Distribution Points + CRLDistributionPoints []string + + PolicyIdentifiers []asn1.ObjectIdentifier +} + +// ErrUnsupportedAlgorithm results from attempting to perform an operation that +// involves algorithms that are not currently implemented. +var ErrUnsupportedAlgorithm = errors.New("x509: cannot verify signature: algorithm unimplemented") + +// An InsecureAlgorithmError +type InsecureAlgorithmError SignatureAlgorithm + +func (e InsecureAlgorithmError) Error() string { + return fmt.Sprintf("x509: cannot verify signature: insecure algorithm %v", SignatureAlgorithm(e)) +} + +// ConstraintViolationError results when a requested usage is not permitted by +// a certificate. For example: checking a signature when the public key isn't a +// certificate signing key. +type ConstraintViolationError struct{} + +func (ConstraintViolationError) Error() string { + return "x509: invalid signature: parent certificate cannot sign this kind of certificate" +} + +func (c *Certificate) Equal(other *Certificate) bool { + return bytes.Equal(c.Raw, other.Raw) +} + +// Entrust have a broken root certificate (CN=Entrust.net Certification +// Authority (2048)) which isn't marked as a CA certificate and is thus invalid +// according to PKIX. +// We recognise this certificate by its SubjectPublicKeyInfo and exempt it +// from the Basic Constraints requirement. +// See http://www.entrust.net/knowledge-base/technote.cfm?tn=7869 +// +// TODO(agl): remove this hack once their reissued root is sufficiently +// widespread. +var entrustBrokenSPKI = []byte{ + 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, + 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, + 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, 0x00, + 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, + 0x00, 0x97, 0xa3, 0x2d, 0x3c, 0x9e, 0xde, 0x05, + 0xda, 0x13, 0xc2, 0x11, 0x8d, 0x9d, 0x8e, 0xe3, + 0x7f, 0xc7, 0x4b, 0x7e, 0x5a, 0x9f, 0xb3, 0xff, + 0x62, 0xab, 0x73, 0xc8, 0x28, 0x6b, 0xba, 0x10, + 0x64, 0x82, 0x87, 0x13, 0xcd, 0x57, 0x18, 0xff, + 0x28, 0xce, 0xc0, 0xe6, 0x0e, 0x06, 0x91, 0x50, + 0x29, 0x83, 0xd1, 0xf2, 0xc3, 0x2a, 0xdb, 0xd8, + 0xdb, 0x4e, 0x04, 0xcc, 0x00, 0xeb, 0x8b, 0xb6, + 0x96, 0xdc, 0xbc, 0xaa, 0xfa, 0x52, 0x77, 0x04, + 0xc1, 0xdb, 0x19, 0xe4, 0xae, 0x9c, 0xfd, 0x3c, + 0x8b, 0x03, 0xef, 0x4d, 0xbc, 0x1a, 0x03, 0x65, + 0xf9, 0xc1, 0xb1, 0x3f, 0x72, 0x86, 0xf2, 0x38, + 0xaa, 0x19, 0xae, 0x10, 0x88, 0x78, 0x28, 0xda, + 0x75, 0xc3, 0x3d, 0x02, 0x82, 0x02, 0x9c, 0xb9, + 0xc1, 0x65, 0x77, 0x76, 0x24, 0x4c, 0x98, 0xf7, + 0x6d, 0x31, 0x38, 0xfb, 0xdb, 0xfe, 0xdb, 0x37, + 0x02, 0x76, 0xa1, 0x18, 0x97, 0xa6, 0xcc, 0xde, + 0x20, 0x09, 0x49, 0x36, 0x24, 0x69, 0x42, 0xf6, + 0xe4, 0x37, 0x62, 0xf1, 0x59, 0x6d, 0xa9, 0x3c, + 0xed, 0x34, 0x9c, 0xa3, 0x8e, 0xdb, 0xdc, 0x3a, + 0xd7, 0xf7, 0x0a, 0x6f, 0xef, 0x2e, 0xd8, 0xd5, + 0x93, 0x5a, 0x7a, 0xed, 0x08, 0x49, 0x68, 0xe2, + 0x41, 0xe3, 0x5a, 0x90, 0xc1, 0x86, 0x55, 0xfc, + 0x51, 0x43, 0x9d, 0xe0, 0xb2, 0xc4, 0x67, 0xb4, + 0xcb, 0x32, 0x31, 0x25, 0xf0, 0x54, 0x9f, 0x4b, + 0xd1, 0x6f, 0xdb, 0xd4, 0xdd, 0xfc, 0xaf, 0x5e, + 0x6c, 0x78, 0x90, 0x95, 0xde, 0xca, 0x3a, 0x48, + 0xb9, 0x79, 0x3c, 0x9b, 0x19, 0xd6, 0x75, 0x05, + 0xa0, 0xf9, 0x88, 0xd7, 0xc1, 0xe8, 0xa5, 0x09, + 0xe4, 0x1a, 0x15, 0xdc, 0x87, 0x23, 0xaa, 0xb2, + 0x75, 0x8c, 0x63, 0x25, 0x87, 0xd8, 0xf8, 0x3d, + 0xa6, 0xc2, 0xcc, 0x66, 0xff, 0xa5, 0x66, 0x68, + 0x55, 0x02, 0x03, 0x01, 0x00, 0x01, +} + +// CheckSignatureFrom verifies that the signature on c is a valid signature +// from parent. +func (c *Certificate) CheckSignatureFrom(parent *Certificate) error { + // RFC 5280, 4.2.1.9: + // "If the basic constraints extension is not present in a version 3 + // certificate, or the extension is present but the cA boolean is not + // asserted, then the certified public key MUST NOT be used to verify + // certificate signatures." + // (except for Entrust, see comment above entrustBrokenSPKI) + if (parent.Version == 3 && !parent.BasicConstraintsValid || + parent.BasicConstraintsValid && !parent.IsCA) && + !bytes.Equal(c.RawSubjectPublicKeyInfo, entrustBrokenSPKI) { + return ConstraintViolationError{} + } + + if parent.KeyUsage != 0 && parent.KeyUsage&KeyUsageCertSign == 0 { + return ConstraintViolationError{} + } + + if parent.PublicKeyAlgorithm == UnknownPublicKeyAlgorithm { + return ErrUnsupportedAlgorithm + } + + // TODO(agl): don't ignore the path length constraint. + + return parent.CheckSignature(c.SignatureAlgorithm, c.RawTBSCertificate, c.Signature) +} + +// CheckSignature verifies that signature is a valid signature over signed from +// c's public key. +func (c *Certificate) CheckSignature(algo SignatureAlgorithm, signed, signature []byte) error { + return checkSignature(algo, signed, signature, c.PublicKey) +} + +// CheckSignature verifies that signature is a valid signature over signed from +// a crypto.PublicKey. +func checkSignature(algo SignatureAlgorithm, signed, signature []byte, publicKey crypto.PublicKey) (err error) { + var hashType Hash + + switch algo { + case SHA1WithRSA, DSAWithSHA1, ECDSAWithSHA1, SM2WithSHA1: + hashType = SHA1 + case SHA256WithRSA, SHA256WithRSAPSS, DSAWithSHA256, ECDSAWithSHA256, SM2WithSHA256: + hashType = SHA256 + case SHA384WithRSA, SHA384WithRSAPSS, ECDSAWithSHA384: + hashType = SHA384 + case SHA512WithRSA, SHA512WithRSAPSS, ECDSAWithSHA512: + hashType = SHA512 + case MD2WithRSA, MD5WithRSA: + return InsecureAlgorithmError(algo) + case SM2WithSM3: // SM3WithRSA reserve + hashType = SM3 + default: + return ErrUnsupportedAlgorithm + } + + if !hashType.Available() { + return ErrUnsupportedAlgorithm + } + h := hashType.New() + + h.Write(signed) + digest := h.Sum(nil) + + switch pub := publicKey.(type) { + case *rsa.PublicKey: + if algo.isRSAPSS() { + return rsa.VerifyPSS(pub, crypto.Hash(hashType), digest, signature, &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}) + } else { + return rsa.VerifyPKCS1v15(pub, crypto.Hash(hashType), digest, signature) + } + case *dsa.PublicKey: + dsaSig := new(dsaSignature) + if rest, err := asn1.Unmarshal(signature, dsaSig); err != nil { + return err + } else if len(rest) != 0 { + return errors.New("x509: trailing data after DSA signature") + } + if dsaSig.R.Sign() <= 0 || dsaSig.S.Sign() <= 0 { + return errors.New("x509: DSA signature contained zero or negative values") + } + if !dsa.Verify(pub, digest, dsaSig.R, dsaSig.S) { + return errors.New("x509: DSA verification failure") + } + return + case *ecdsa.PublicKey: + ecdsaSig := new(ecdsaSignature) + if rest, err := asn1.Unmarshal(signature, ecdsaSig); err != nil { + return err + } else if len(rest) != 0 { + return errors.New("x509: trailing data after ECDSA signature") + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return errors.New("x509: ECDSA signature contained zero or negative values") + } + switch pub.Curve { + case P256Sm2(): + if !Verify(&PublicKey{ + Curve: pub.Curve, + X: pub.X, + Y: pub.Y, + }, digest, ecdsaSig.R, ecdsaSig.S) { + return errors.New("x509: SM2 verification failure") + } + default: + if !ecdsa.Verify(pub, digest, ecdsaSig.R, ecdsaSig.S) { + return errors.New("x509: ECDSA verification failure") + } + } + return + } + return ErrUnsupportedAlgorithm +} + +// CheckCRLSignature checks that the signature in crl is from c. +func (c *Certificate) CheckCRLSignature(crl *pkix.CertificateList) error { + algo := getSignatureAlgorithmFromAI(crl.SignatureAlgorithm) + return c.CheckSignature(algo, crl.TBSCertList.Raw, crl.SignatureValue.RightAlign()) +} + +type UnhandledCriticalExtension struct{} + +func (h UnhandledCriticalExtension) Error() string { + return "x509: unhandled critical extension" +} + +type basicConstraints struct { + IsCA bool `asn1:"optional"` + MaxPathLen int `asn1:"optional,default:-1"` +} + +// RFC 5280 4.2.1.4 +type policyInformation struct { + Policy asn1.ObjectIdentifier + // policyQualifiers omitted +} + +// RFC 5280, 4.2.1.10 +type nameConstraints struct { + Permitted []generalSubtree `asn1:"optional,tag:0"` + Excluded []generalSubtree `asn1:"optional,tag:1"` +} + +type generalSubtree struct { + Name string `asn1:"tag:2,optional,ia5"` +} + +// RFC 5280, 4.2.2.1 +type authorityInfoAccess struct { + Method asn1.ObjectIdentifier + Location asn1.RawValue +} + +// RFC 5280, 4.2.1.14 +type distributionPoint struct { + DistributionPoint distributionPointName `asn1:"optional,tag:0"` + Reason asn1.BitString `asn1:"optional,tag:1"` + CRLIssuer asn1.RawValue `asn1:"optional,tag:2"` +} + +type distributionPointName struct { + FullName asn1.RawValue `asn1:"optional,tag:0"` + RelativeName pkix.RDNSequence `asn1:"optional,tag:1"` +} + +// asn1Null is the ASN.1 encoding of a NULL value. +var asn1Null = []byte{5, 0} + +func parsePublicKey(algo PublicKeyAlgorithm, keyData *publicKeyInfo) (interface{}, error) { + asn1Data := keyData.PublicKey.RightAlign() + switch algo { + case RSA: + // RSA public keys must have a NULL in the parameters + // (https://tools.ietf.org/html/rfc3279#section-2.3.1). + if !bytes.Equal(keyData.Algorithm.Parameters.FullBytes, asn1Null) { + return nil, errors.New("x509: RSA key missing NULL parameters") + } + + p := new(rsaPublicKey) + rest, err := asn1.Unmarshal(asn1Data, p) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, errors.New("x509: trailing data after RSA public key") + } + + if p.N.Sign() <= 0 { + return nil, errors.New("x509: RSA modulus is not a positive number") + } + if p.E <= 0 { + return nil, errors.New("x509: RSA public exponent is not a positive number") + } + + pub := &rsa.PublicKey{ + E: p.E, + N: p.N, + } + return pub, nil + case DSA: + var p *big.Int + rest, err := asn1.Unmarshal(asn1Data, &p) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, errors.New("x509: trailing data after DSA public key") + } + paramsData := keyData.Algorithm.Parameters.FullBytes + params := new(dsaAlgorithmParameters) + rest, err = asn1.Unmarshal(paramsData, params) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, errors.New("x509: trailing data after DSA parameters") + } + if p.Sign() <= 0 || params.P.Sign() <= 0 || params.Q.Sign() <= 0 || params.G.Sign() <= 0 { + return nil, errors.New("x509: zero or negative DSA parameter") + } + pub := &dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: params.P, + Q: params.Q, + G: params.G, + }, + Y: p, + } + return pub, nil + case ECDSA: + paramsData := keyData.Algorithm.Parameters.FullBytes + namedCurveOID := new(asn1.ObjectIdentifier) + rest, err := asn1.Unmarshal(paramsData, namedCurveOID) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, errors.New("x509: trailing data after ECDSA parameters") + } + namedCurve := namedCurveFromOID(*namedCurveOID) + if namedCurve == nil { + return nil, errors.New("x509: unsupported elliptic curve") + } + x, y := elliptic.Unmarshal(namedCurve, asn1Data) + if x == nil { + return nil, errors.New("x509: failed to unmarshal elliptic curve point") + } + pub := &ecdsa.PublicKey{ + Curve: namedCurve, + X: x, + Y: y, + } + return pub, nil + default: + return nil, nil + } +} + +func parseSANExtension(value []byte) (dnsNames, emailAddresses []string, ipAddresses []net.IP, err error) { + // RFC 5280, 4.2.1.6 + + // SubjectAltName ::= GeneralNames + // + // GeneralNames ::= SEQUENCE SIZE (1..MAX) OF GeneralName + // + // GeneralName ::= CHOICE { + // otherName [0] OtherName, + // rfc822Name [1] IA5String, + // dNSName [2] IA5String, + // x400Address [3] ORAddress, + // directoryName [4] Name, + // ediPartyName [5] EDIPartyName, + // uniformResourceIdentifier [6] IA5String, + // iPAddress [7] OCTET STRING, + // registeredID [8] OBJECT IDENTIFIER } + var seq asn1.RawValue + var rest []byte + if rest, err = asn1.Unmarshal(value, &seq); err != nil { + return + } else if len(rest) != 0 { + err = errors.New("x509: trailing data after X.509 extension") + return + } + if !seq.IsCompound || seq.Tag != 16 || seq.Class != 0 { + err = asn1.StructuralError{Msg: "bad SAN sequence"} + return + } + + rest = seq.Bytes + for len(rest) > 0 { + var v asn1.RawValue + rest, err = asn1.Unmarshal(rest, &v) + if err != nil { + return + } + switch v.Tag { + case 1: + emailAddresses = append(emailAddresses, string(v.Bytes)) + case 2: + dnsNames = append(dnsNames, string(v.Bytes)) + case 7: + switch len(v.Bytes) { + case net.IPv4len, net.IPv6len: + ipAddresses = append(ipAddresses, v.Bytes) + default: + err = errors.New("x509: certificate contained IP address of length " + strconv.Itoa(len(v.Bytes))) + return + } + } + } + + return +} + +func parseCertificate(in *certificate) (*Certificate, error) { + out := new(Certificate) + out.Raw = in.Raw + out.RawTBSCertificate = in.TBSCertificate.Raw + out.RawSubjectPublicKeyInfo = in.TBSCertificate.PublicKey.Raw + out.RawSubject = in.TBSCertificate.Subject.FullBytes + out.RawIssuer = in.TBSCertificate.Issuer.FullBytes + + out.Signature = in.SignatureValue.RightAlign() + out.SignatureAlgorithm = + getSignatureAlgorithmFromAI(in.TBSCertificate.SignatureAlgorithm) + + out.PublicKeyAlgorithm = + getPublicKeyAlgorithmFromOID(in.TBSCertificate.PublicKey.Algorithm.Algorithm) + var err error + out.PublicKey, err = parsePublicKey(out.PublicKeyAlgorithm, &in.TBSCertificate.PublicKey) + if err != nil { + return nil, err + } + + out.Version = in.TBSCertificate.Version + 1 + out.SerialNumber = in.TBSCertificate.SerialNumber + + var issuer, subject pkix.RDNSequence + if rest, err := asn1.Unmarshal(in.TBSCertificate.Subject.FullBytes, &subject); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 subject") + } + if rest, err := asn1.Unmarshal(in.TBSCertificate.Issuer.FullBytes, &issuer); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 subject") + } + + out.Issuer.FillFromRDNSequence(&issuer) + out.Subject.FillFromRDNSequence(&subject) + + out.NotBefore = in.TBSCertificate.Validity.NotBefore + out.NotAfter = in.TBSCertificate.Validity.NotAfter + + for _, e := range in.TBSCertificate.Extensions { + out.Extensions = append(out.Extensions, e) + unhandled := false + + if len(e.Id) == 4 && e.Id[0] == 2 && e.Id[1] == 5 && e.Id[2] == 29 { + switch e.Id[3] { + case 15: + // RFC 5280, 4.2.1.3 + var usageBits asn1.BitString + if rest, err := asn1.Unmarshal(e.Value, &usageBits); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 KeyUsage") + } + + var usage int + for i := 0; i < 9; i++ { + if usageBits.At(i) != 0 { + usage |= 1 << uint(i) + } + } + out.KeyUsage = KeyUsage(usage) + + case 19: + // RFC 5280, 4.2.1.9 + var constraints basicConstraints + if rest, err := asn1.Unmarshal(e.Value, &constraints); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 BasicConstraints") + } + + out.BasicConstraintsValid = true + out.IsCA = constraints.IsCA + out.MaxPathLen = constraints.MaxPathLen + out.MaxPathLenZero = out.MaxPathLen == 0 + + case 17: + out.DNSNames, out.EmailAddresses, out.IPAddresses, err = parseSANExtension(e.Value) + if err != nil { + return nil, err + } + + if len(out.DNSNames) == 0 && len(out.EmailAddresses) == 0 && len(out.IPAddresses) == 0 { + // If we didn't parse anything then we do the critical check, below. + unhandled = true + } + + case 30: + // RFC 5280, 4.2.1.10 + + // NameConstraints ::= SEQUENCE { + // permittedSubtrees [0] GeneralSubtrees OPTIONAL, + // excludedSubtrees [1] GeneralSubtrees OPTIONAL } + // + // GeneralSubtrees ::= SEQUENCE SIZE (1..MAX) OF GeneralSubtree + // + // GeneralSubtree ::= SEQUENCE { + // base GeneralName, + // minimum [0] BaseDistance DEFAULT 0, + // maximum [1] BaseDistance OPTIONAL } + // + // BaseDistance ::= INTEGER (0..MAX) + + var constraints nameConstraints + if rest, err := asn1.Unmarshal(e.Value, &constraints); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 NameConstraints") + } + + if len(constraints.Excluded) > 0 && e.Critical { + return out, UnhandledCriticalExtension{} + } + + for _, subtree := range constraints.Permitted { + if len(subtree.Name) == 0 { + if e.Critical { + return out, UnhandledCriticalExtension{} + } + continue + } + out.PermittedDNSDomains = append(out.PermittedDNSDomains, subtree.Name) + } + + case 31: + // RFC 5280, 4.2.1.13 + + // CRLDistributionPoints ::= SEQUENCE SIZE (1..MAX) OF DistributionPoint + // + // DistributionPoint ::= SEQUENCE { + // distributionPoint [0] DistributionPointName OPTIONAL, + // reasons [1] ReasonFlags OPTIONAL, + // cRLIssuer [2] GeneralNames OPTIONAL } + // + // DistributionPointName ::= CHOICE { + // fullName [0] GeneralNames, + // nameRelativeToCRLIssuer [1] RelativeDistinguishedName } + + var cdp []distributionPoint + if rest, err := asn1.Unmarshal(e.Value, &cdp); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 CRL distribution point") + } + + for _, dp := range cdp { + // Per RFC 5280, 4.2.1.13, one of distributionPoint or cRLIssuer may be empty. + if len(dp.DistributionPoint.FullName.Bytes) == 0 { + continue + } + + var n asn1.RawValue + if _, err := asn1.Unmarshal(dp.DistributionPoint.FullName.Bytes, &n); err != nil { + return nil, err + } + // Trailing data after the fullName is + // allowed because other elements of + // the SEQUENCE can appear. + + if n.Tag == 6 { + out.CRLDistributionPoints = append(out.CRLDistributionPoints, string(n.Bytes)) + } + } + + case 35: + // RFC 5280, 4.2.1.1 + var a authKeyId + if rest, err := asn1.Unmarshal(e.Value, &a); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 authority key-id") + } + out.AuthorityKeyId = a.Id + + case 37: + // RFC 5280, 4.2.1.12. Extended Key Usage + + // id-ce-extKeyUsage OBJECT IDENTIFIER ::= { id-ce 37 } + // + // ExtKeyUsageSyntax ::= SEQUENCE SIZE (1..MAX) OF KeyPurposeId + // + // KeyPurposeId ::= OBJECT IDENTIFIER + + var keyUsage []asn1.ObjectIdentifier + if rest, err := asn1.Unmarshal(e.Value, &keyUsage); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 ExtendedKeyUsage") + } + + for _, u := range keyUsage { + if extKeyUsage, ok := extKeyUsageFromOID(u); ok { + out.ExtKeyUsage = append(out.ExtKeyUsage, extKeyUsage) + } else { + out.UnknownExtKeyUsage = append(out.UnknownExtKeyUsage, u) + } + } + + case 14: + // RFC 5280, 4.2.1.2 + var keyid []byte + if rest, err := asn1.Unmarshal(e.Value, &keyid); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 key-id") + } + out.SubjectKeyId = keyid + + case 32: + // RFC 5280 4.2.1.4: Certificate Policies + var policies []policyInformation + if rest, err := asn1.Unmarshal(e.Value, &policies); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 certificate policies") + } + out.PolicyIdentifiers = make([]asn1.ObjectIdentifier, len(policies)) + for i, policy := range policies { + out.PolicyIdentifiers[i] = policy.Policy + } + + default: + // Unknown extensions are recorded if critical. + unhandled = true + } + } else if e.Id.Equal(oidExtensionAuthorityInfoAccess) { + // RFC 5280 4.2.2.1: Authority Information Access + var aia []authorityInfoAccess + if rest, err := asn1.Unmarshal(e.Value, &aia); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 authority information") + } + + for _, v := range aia { + // GeneralName: uniformResourceIdentifier [6] IA5String + if v.Location.Tag != 6 { + continue + } + if v.Method.Equal(oidAuthorityInfoAccessOcsp) { + out.OCSPServer = append(out.OCSPServer, string(v.Location.Bytes)) + } else if v.Method.Equal(oidAuthorityInfoAccessIssuers) { + out.IssuingCertificateURL = append(out.IssuingCertificateURL, string(v.Location.Bytes)) + } + } + } else { + // Unknown extensions are recorded if critical. + unhandled = true + } + + if e.Critical && unhandled { + out.UnhandledCriticalExtensions = append(out.UnhandledCriticalExtensions, e.Id) + } + } + + return out, nil +} + +// ParseCertificate parses a single certificate from the given ASN.1 DER data. +func ParseCertificate(asn1Data []byte) (*Certificate, error) { + var cert certificate + rest, err := asn1.Unmarshal(asn1Data, &cert) + if err != nil { + return nil, err + } + if len(rest) > 0 { + return nil, asn1.SyntaxError{Msg: "trailing data"} + } + + return parseCertificate(&cert) +} + +// ParseCertificates parses one or more certificates from the given ASN.1 DER +// data. The certificates must be concatenated with no intermediate padding. +func ParseCertificates(asn1Data []byte) ([]*Certificate, error) { + var v []*certificate + + for len(asn1Data) > 0 { + cert := new(certificate) + var err error + asn1Data, err = asn1.Unmarshal(asn1Data, cert) + if err != nil { + return nil, err + } + v = append(v, cert) + } + + ret := make([]*Certificate, len(v)) + for i, ci := range v { + cert, err := parseCertificate(ci) + if err != nil { + return nil, err + } + ret[i] = cert + } + + return ret, nil +} + +func reverseBitsInAByte(in byte) byte { + b1 := in>>4 | in<<4 + b2 := b1>>2&0x33 | b1<<2&0xcc + b3 := b2>>1&0x55 | b2<<1&0xaa + return b3 +} + +// asn1BitLength returns the bit-length of bitString by considering the +// most-significant bit in a byte to be the "first" bit. This convention +// matches ASN.1, but differs from almost everything else. +func asn1BitLength(bitString []byte) int { + bitLen := len(bitString) * 8 + + for i := range bitString { + b := bitString[len(bitString)-i-1] + + for bit := uint(0); bit < 8; bit++ { + if (b>>bit)&1 == 1 { + return bitLen + } + bitLen-- + } + } + + return 0 +} + +var ( + oidExtensionSubjectKeyId = []int{2, 5, 29, 14} + oidExtensionKeyUsage = []int{2, 5, 29, 15} + oidExtensionExtendedKeyUsage = []int{2, 5, 29, 37} + oidExtensionAuthorityKeyId = []int{2, 5, 29, 35} + oidExtensionBasicConstraints = []int{2, 5, 29, 19} + oidExtensionSubjectAltName = []int{2, 5, 29, 17} + oidExtensionCertificatePolicies = []int{2, 5, 29, 32} + oidExtensionNameConstraints = []int{2, 5, 29, 30} + oidExtensionCRLDistributionPoints = []int{2, 5, 29, 31} + oidExtensionAuthorityInfoAccess = []int{1, 3, 6, 1, 5, 5, 7, 1, 1} +) + +var ( + oidAuthorityInfoAccessOcsp = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 1} + oidAuthorityInfoAccessIssuers = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 48, 2} +) + +// oidNotInExtensions returns whether an extension with the given oid exists in +// extensions. +func oidInExtensions(oid asn1.ObjectIdentifier, extensions []pkix.Extension) bool { + for _, e := range extensions { + if e.Id.Equal(oid) { + return true + } + } + return false +} + +// marshalSANs marshals a list of addresses into a the contents of an X.509 +// SubjectAlternativeName extension. +func marshalSANs(dnsNames, emailAddresses []string, ipAddresses []net.IP) (derBytes []byte, err error) { + var rawValues []asn1.RawValue + for _, name := range dnsNames { + rawValues = append(rawValues, asn1.RawValue{Tag: 2, Class: 2, Bytes: []byte(name)}) + } + for _, email := range emailAddresses { + rawValues = append(rawValues, asn1.RawValue{Tag: 1, Class: 2, Bytes: []byte(email)}) + } + for _, rawIP := range ipAddresses { + // If possible, we always want to encode IPv4 addresses in 4 bytes. + ip := rawIP.To4() + if ip == nil { + ip = rawIP + } + rawValues = append(rawValues, asn1.RawValue{Tag: 7, Class: 2, Bytes: ip}) + } + return asn1.Marshal(rawValues) +} + +func buildExtensions(template *Certificate) (ret []pkix.Extension, err error) { + ret = make([]pkix.Extension, 10 /* maximum number of elements. */) + n := 0 + + if template.KeyUsage != 0 && + !oidInExtensions(oidExtensionKeyUsage, template.ExtraExtensions) { + ret[n].Id = oidExtensionKeyUsage + ret[n].Critical = true + + var a [2]byte + a[0] = reverseBitsInAByte(byte(template.KeyUsage)) + a[1] = reverseBitsInAByte(byte(template.KeyUsage >> 8)) + + l := 1 + if a[1] != 0 { + l = 2 + } + + bitString := a[:l] + ret[n].Value, err = asn1.Marshal(asn1.BitString{Bytes: bitString, BitLength: asn1BitLength(bitString)}) + if err != nil { + return + } + n++ + } + + if (len(template.ExtKeyUsage) > 0 || len(template.UnknownExtKeyUsage) > 0) && + !oidInExtensions(oidExtensionExtendedKeyUsage, template.ExtraExtensions) { + ret[n].Id = oidExtensionExtendedKeyUsage + + var oids []asn1.ObjectIdentifier + for _, u := range template.ExtKeyUsage { + if oid, ok := oidFromExtKeyUsage(u); ok { + oids = append(oids, oid) + } else { + panic("internal error") + } + } + + oids = append(oids, template.UnknownExtKeyUsage...) + + ret[n].Value, err = asn1.Marshal(oids) + if err != nil { + return + } + n++ + } + + if template.BasicConstraintsValid && !oidInExtensions(oidExtensionBasicConstraints, template.ExtraExtensions) { + // Leaving MaxPathLen as zero indicates that no maximum path + // length is desired, unless MaxPathLenZero is set. A value of + // -1 causes encoding/asn1 to omit the value as desired. + maxPathLen := template.MaxPathLen + if maxPathLen == 0 && !template.MaxPathLenZero { + maxPathLen = -1 + } + ret[n].Id = oidExtensionBasicConstraints + ret[n].Value, err = asn1.Marshal(basicConstraints{template.IsCA, maxPathLen}) + ret[n].Critical = true + if err != nil { + return + } + n++ + } + + if len(template.SubjectKeyId) > 0 && !oidInExtensions(oidExtensionSubjectKeyId, template.ExtraExtensions) { + ret[n].Id = oidExtensionSubjectKeyId + ret[n].Value, err = asn1.Marshal(template.SubjectKeyId) + if err != nil { + return + } + n++ + } + + if len(template.AuthorityKeyId) > 0 && !oidInExtensions(oidExtensionAuthorityKeyId, template.ExtraExtensions) { + ret[n].Id = oidExtensionAuthorityKeyId + ret[n].Value, err = asn1.Marshal(authKeyId{template.AuthorityKeyId}) + if err != nil { + return + } + n++ + } + + if (len(template.OCSPServer) > 0 || len(template.IssuingCertificateURL) > 0) && + !oidInExtensions(oidExtensionAuthorityInfoAccess, template.ExtraExtensions) { + ret[n].Id = oidExtensionAuthorityInfoAccess + var aiaValues []authorityInfoAccess + for _, name := range template.OCSPServer { + aiaValues = append(aiaValues, authorityInfoAccess{ + Method: oidAuthorityInfoAccessOcsp, + Location: asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)}, + }) + } + for _, name := range template.IssuingCertificateURL { + aiaValues = append(aiaValues, authorityInfoAccess{ + Method: oidAuthorityInfoAccessIssuers, + Location: asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)}, + }) + } + ret[n].Value, err = asn1.Marshal(aiaValues) + if err != nil { + return + } + n++ + } + + if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0) && + !oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) { + ret[n].Id = oidExtensionSubjectAltName + ret[n].Value, err = marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses) + if err != nil { + return + } + n++ + } + + if len(template.PolicyIdentifiers) > 0 && + !oidInExtensions(oidExtensionCertificatePolicies, template.ExtraExtensions) { + ret[n].Id = oidExtensionCertificatePolicies + policies := make([]policyInformation, len(template.PolicyIdentifiers)) + for i, policy := range template.PolicyIdentifiers { + policies[i].Policy = policy + } + ret[n].Value, err = asn1.Marshal(policies) + if err != nil { + return + } + n++ + } + + if len(template.PermittedDNSDomains) > 0 && + !oidInExtensions(oidExtensionNameConstraints, template.ExtraExtensions) { + ret[n].Id = oidExtensionNameConstraints + ret[n].Critical = template.PermittedDNSDomainsCritical + + var out nameConstraints + out.Permitted = make([]generalSubtree, len(template.PermittedDNSDomains)) + for i, permitted := range template.PermittedDNSDomains { + out.Permitted[i] = generalSubtree{Name: permitted} + } + ret[n].Value, err = asn1.Marshal(out) + if err != nil { + return + } + n++ + } + + if len(template.CRLDistributionPoints) > 0 && + !oidInExtensions(oidExtensionCRLDistributionPoints, template.ExtraExtensions) { + ret[n].Id = oidExtensionCRLDistributionPoints + + var crlDp []distributionPoint + for _, name := range template.CRLDistributionPoints { + rawFullName, _ := asn1.Marshal(asn1.RawValue{Tag: 6, Class: 2, Bytes: []byte(name)}) + + dp := distributionPoint{ + DistributionPoint: distributionPointName{ + FullName: asn1.RawValue{Tag: 0, Class: 2, IsCompound: true, Bytes: rawFullName}, + }, + } + crlDp = append(crlDp, dp) + } + + ret[n].Value, err = asn1.Marshal(crlDp) + if err != nil { + return + } + n++ + } + + // Adding another extension here? Remember to update the maximum number + // of elements in the make() at the top of the function. + + return append(ret[:n], template.ExtraExtensions...), nil +} + +func subjectBytes(cert *Certificate) ([]byte, error) { + if len(cert.RawSubject) > 0 { + return cert.RawSubject, nil + } + + return asn1.Marshal(cert.Subject.ToRDNSequence()) +} + +// signingParamsForPublicKey returns the parameters to use for signing with +// priv. If requestedSigAlgo is not zero then it overrides the default +// signature algorithm. +func signingParamsForPublicKey(pub interface{}, requestedSigAlgo SignatureAlgorithm) (hashFunc Hash, sigAlgo pkix.AlgorithmIdentifier, err error) { + var pubType PublicKeyAlgorithm + + switch pub := pub.(type) { + case *rsa.PublicKey: + pubType = RSA + hashFunc = SHA256 + sigAlgo.Algorithm = oidSignatureSHA256WithRSA + sigAlgo.Parameters = asn1.RawValue{ + Tag: 5, + } + + case *ecdsa.PublicKey: + pubType = ECDSA + switch pub.Curve { + case elliptic.P224(), elliptic.P256(): + hashFunc = SHA256 + sigAlgo.Algorithm = oidSignatureECDSAWithSHA256 + case elliptic.P384(): + hashFunc = SHA384 + sigAlgo.Algorithm = oidSignatureECDSAWithSHA384 + case elliptic.P521(): + hashFunc = SHA512 + sigAlgo.Algorithm = oidSignatureECDSAWithSHA512 + default: + err = errors.New("x509: unknown elliptic curve") + } + case *PublicKey: + pubType = ECDSA + switch pub.Curve { + case P256Sm2(): + hashFunc = SM3 + sigAlgo.Algorithm = oidSignatureSM2WithSM3 + default: + err = errors.New("x509: unknown SM2 curve") + } + default: + err = errors.New("x509: only RSA and ECDSA keys supported") + } + + if err != nil { + return + } + + if requestedSigAlgo == 0 { + return + } + + found := false + for _, details := range signatureAlgorithmDetails { + if details.algo == requestedSigAlgo { + if details.pubKeyAlgo != pubType { + err = errors.New("x509: requested SignatureAlgorithm does not match private key type") + return + } + sigAlgo.Algorithm, hashFunc = details.oid, details.hash + if hashFunc == 0 { + err = errors.New("x509: cannot sign with hash function requested") + return + } + if requestedSigAlgo.isRSAPSS() { + sigAlgo.Parameters = rsaPSSParameters(hashFunc) + } + found = true + break + } + } + + if !found { + err = errors.New("x509: unknown SignatureAlgorithm") + } + + return +} + +// CreateCertificate creates a new certificate based on a template. The +// following members of template are used: SerialNumber, Subject, NotBefore, +// NotAfter, KeyUsage, ExtKeyUsage, UnknownExtKeyUsage, BasicConstraintsValid, +// IsCA, MaxPathLen, SubjectKeyId, DNSNames, PermittedDNSDomainsCritical, +// PermittedDNSDomains, SignatureAlgorithm. +// +// The certificate is signed by parent. If parent is equal to template then the +// certificate is self-signed. The parameter pub is the public key of the +// signee and priv is the private key of the signer. +// +// The returned slice is the certificate in DER encoding. +// +// All keys types that are implemented via crypto.Signer are supported (This +// includes *rsa.PublicKey and *ecdsa.PublicKey.) +func CreateCertificate(rand io.Reader, template, parent *Certificate, pub, priv interface{}) (cert []byte, err error) { + key, ok := priv.(crypto.Signer) + if !ok { + return nil, errors.New("x509: certificate private key does not implement crypto.Signer") + } + + if template.SerialNumber == nil { + return nil, errors.New("x509: no SerialNumber given") + } + + hashFunc, signatureAlgorithm, err := signingParamsForPublicKey(key.Public(), template.SignatureAlgorithm) + if err != nil { + return nil, err + } + + publicKeyBytes, publicKeyAlgorithm, err := marshalPublicKey(pub) + if err != nil { + return nil, err + } + + asn1Issuer, err := subjectBytes(parent) + if err != nil { + return + } + + asn1Subject, err := subjectBytes(template) + if err != nil { + return + } + + if !bytes.Equal(asn1Issuer, asn1Subject) && len(parent.SubjectKeyId) > 0 { + template.AuthorityKeyId = parent.SubjectKeyId + } + + extensions, err := buildExtensions(template) + if err != nil { + return + } + encodedPublicKey := asn1.BitString{BitLength: len(publicKeyBytes) * 8, Bytes: publicKeyBytes} + c := tbsCertificate{ + Version: 2, + SerialNumber: template.SerialNumber, + SignatureAlgorithm: signatureAlgorithm, + Issuer: asn1.RawValue{FullBytes: asn1Issuer}, + Validity: validity{template.NotBefore.UTC(), template.NotAfter.UTC()}, + Subject: asn1.RawValue{FullBytes: asn1Subject}, + PublicKey: publicKeyInfo{nil, publicKeyAlgorithm, encodedPublicKey}, + Extensions: extensions, + } + + tbsCertContents, err := asn1.Marshal(c) + if err != nil { + return + } + + c.Raw = tbsCertContents + + h := hashFunc.New() + h.Write(tbsCertContents) + digest := h.Sum(nil) + + var signerOpts crypto.SignerOpts + signerOpts = hashFunc + if template.SignatureAlgorithm != 0 && template.SignatureAlgorithm.isRSAPSS() { + signerOpts = &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + Hash: crypto.Hash(hashFunc), + } + } + + var signature []byte + signature, err = key.Sign(rand, digest, signerOpts) + if err != nil { + return + } + + return asn1.Marshal(certificate{ + nil, + c, + signatureAlgorithm, + asn1.BitString{Bytes: signature, BitLength: len(signature) * 8}, + }) +} + +// pemCRLPrefix is the magic string that indicates that we have a PEM encoded +// CRL. +var pemCRLPrefix = []byte("-----BEGIN X509 CRL") + +// pemType is the type of a PEM encoded CRL. +var pemType = "X509 CRL" + +// ParseCRL parses a CRL from the given bytes. It's often the case that PEM +// encoded CRLs will appear where they should be DER encoded, so this function +// will transparently handle PEM encoding as long as there isn't any leading +// garbage. +func ParseCRL(crlBytes []byte) (*pkix.CertificateList, error) { + if bytes.HasPrefix(crlBytes, pemCRLPrefix) { + block, _ := pem.Decode(crlBytes) + if block != nil && block.Type == pemType { + crlBytes = block.Bytes + } + } + return ParseDERCRL(crlBytes) +} + +// ParseDERCRL parses a DER encoded CRL from the given bytes. +func ParseDERCRL(derBytes []byte) (*pkix.CertificateList, error) { + certList := new(pkix.CertificateList) + if rest, err := asn1.Unmarshal(derBytes, certList); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after CRL") + } + return certList, nil +} + +// CreateCRL returns a DER encoded CRL, signed by this Certificate, that +// contains the given list of revoked certificates. +func (c *Certificate) CreateCRL(rand io.Reader, priv interface{}, revokedCerts []pkix.RevokedCertificate, now, expiry time.Time) (crlBytes []byte, err error) { + key, ok := priv.(crypto.Signer) + if !ok { + return nil, errors.New("x509: certificate private key does not implement crypto.Signer") + } + + hashFunc, signatureAlgorithm, err := signingParamsForPublicKey(key.Public(), 0) + if err != nil { + return nil, err + } + + // Force revocation times to UTC per RFC 5280. + revokedCertsUTC := make([]pkix.RevokedCertificate, len(revokedCerts)) + for i, rc := range revokedCerts { + rc.RevocationTime = rc.RevocationTime.UTC() + revokedCertsUTC[i] = rc + } + + tbsCertList := pkix.TBSCertificateList{ + Version: 1, + Signature: signatureAlgorithm, + Issuer: c.Subject.ToRDNSequence(), + ThisUpdate: now.UTC(), + NextUpdate: expiry.UTC(), + RevokedCertificates: revokedCertsUTC, + } + + // Authority Key Id + if len(c.SubjectKeyId) > 0 { + var aki pkix.Extension + aki.Id = oidExtensionAuthorityKeyId + aki.Value, err = asn1.Marshal(authKeyId{Id: c.SubjectKeyId}) + if err != nil { + return + } + tbsCertList.Extensions = append(tbsCertList.Extensions, aki) + } + + tbsCertListContents, err := asn1.Marshal(tbsCertList) + if err != nil { + return + } + + h := hashFunc.New() + h.Write(tbsCertListContents) + digest := h.Sum(nil) + + var signature []byte + signature, err = key.Sign(rand, digest, hashFunc) + if err != nil { + return + } + + return asn1.Marshal(pkix.CertificateList{ + TBSCertList: tbsCertList, + SignatureAlgorithm: signatureAlgorithm, + SignatureValue: asn1.BitString{Bytes: signature, BitLength: len(signature) * 8}, + }) +} + +// CertificateRequest represents a PKCS #10, certificate signature request. +type CertificateRequest struct { + Raw []byte // Complete ASN.1 DER content (CSR, signature algorithm and signature). + RawTBSCertificateRequest []byte // Certificate request info part of raw ASN.1 DER content. + RawSubjectPublicKeyInfo []byte // DER encoded SubjectPublicKeyInfo. + RawSubject []byte // DER encoded Subject. + + Version int + Signature []byte + SignatureAlgorithm SignatureAlgorithm + + PublicKeyAlgorithm PublicKeyAlgorithm + PublicKey interface{} + + Subject pkix.Name + + // Attributes is the dried husk of a bug and shouldn't be used. + Attributes []pkix.AttributeTypeAndValueSET + + // Extensions contains raw X.509 extensions. When parsing CSRs, this + // can be used to extract extensions that are not parsed by this + // package. + Extensions []pkix.Extension + + // ExtraExtensions contains extensions to be copied, raw, into any + // marshaled CSR. Values override any extensions that would otherwise + // be produced based on the other fields but are overridden by any + // extensions specified in Attributes. + // + // The ExtraExtensions field is not populated when parsing CSRs, see + // Extensions. + ExtraExtensions []pkix.Extension + + // Subject Alternate Name values. + DNSNames []string + EmailAddresses []string + IPAddresses []net.IP +} + +// These structures reflect the ASN.1 structure of X.509 certificate +// signature requests (see RFC 2986): + +type tbsCertificateRequest struct { + Raw asn1.RawContent + Version int + Subject asn1.RawValue + PublicKey publicKeyInfo + RawAttributes []asn1.RawValue `asn1:"tag:0"` +} + +type certificateRequest struct { + Raw asn1.RawContent + TBSCSR tbsCertificateRequest + SignatureAlgorithm pkix.AlgorithmIdentifier + SignatureValue asn1.BitString +} + +// oidExtensionRequest is a PKCS#9 OBJECT IDENTIFIER that indicates requested +// extensions in a CSR. +var oidExtensionRequest = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 14} + +// newRawAttributes converts AttributeTypeAndValueSETs from a template +// CertificateRequest's Attributes into tbsCertificateRequest RawAttributes. +func newRawAttributes(attributes []pkix.AttributeTypeAndValueSET) ([]asn1.RawValue, error) { + var rawAttributes []asn1.RawValue + b, err := asn1.Marshal(attributes) + if err != nil { + return nil, err + } + rest, err := asn1.Unmarshal(b, &rawAttributes) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, errors.New("x509: failed to unmarshal raw CSR Attributes") + } + return rawAttributes, nil +} + +// parseRawAttributes Unmarshals RawAttributes intos AttributeTypeAndValueSETs. +func parseRawAttributes(rawAttributes []asn1.RawValue) []pkix.AttributeTypeAndValueSET { + var attributes []pkix.AttributeTypeAndValueSET + for _, rawAttr := range rawAttributes { + var attr pkix.AttributeTypeAndValueSET + rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr) + // Ignore attributes that don't parse into pkix.AttributeTypeAndValueSET + // (i.e.: challengePassword or unstructuredName). + if err == nil && len(rest) == 0 { + attributes = append(attributes, attr) + } + } + return attributes +} + +// parseCSRExtensions parses the attributes from a CSR and extracts any +// requested extensions. +func parseCSRExtensions(rawAttributes []asn1.RawValue) ([]pkix.Extension, error) { + // pkcs10Attribute reflects the Attribute structure from section 4.1 of + // https://tools.ietf.org/html/rfc2986. + type pkcs10Attribute struct { + Id asn1.ObjectIdentifier + Values []asn1.RawValue `asn1:"set"` + } + + var ret []pkix.Extension + for _, rawAttr := range rawAttributes { + var attr pkcs10Attribute + if rest, err := asn1.Unmarshal(rawAttr.FullBytes, &attr); err != nil || len(rest) != 0 || len(attr.Values) == 0 { + // Ignore attributes that don't parse. + continue + } + + if !attr.Id.Equal(oidExtensionRequest) { + continue + } + + var extensions []pkix.Extension + if _, err := asn1.Unmarshal(attr.Values[0].FullBytes, &extensions); err != nil { + return nil, err + } + ret = append(ret, extensions...) + } + + return ret, nil +} + +// CreateCertificateRequest creates a new certificate request based on a template. +// The following members of template are used: Subject, Attributes, +// SignatureAlgorithm, Extensions, DNSNames, EmailAddresses, and IPAddresses. +// The private key is the private key of the signer. +// +// The returned slice is the certificate request in DER encoding. +// +// All keys types that are implemented via crypto.Signer are supported (This +// includes *rsa.PublicKey and *ecdsa.PublicKey.) +func CreateCertificateRequest(rand io.Reader, template *CertificateRequest, priv interface{}) (csr []byte, err error) { + key, ok := priv.(crypto.Signer) + if !ok { + return nil, errors.New("x509: certificate private key does not implement crypto.Signer") + } + + var hashFunc Hash + var sigAlgo pkix.AlgorithmIdentifier + hashFunc, sigAlgo, err = signingParamsForPublicKey(key.Public(), template.SignatureAlgorithm) + if err != nil { + return nil, err + } + + var publicKeyBytes []byte + var publicKeyAlgorithm pkix.AlgorithmIdentifier + publicKeyBytes, publicKeyAlgorithm, err = marshalPublicKey(key.Public()) + if err != nil { + return nil, err + } + + var extensions []pkix.Extension + + if (len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0) && + !oidInExtensions(oidExtensionSubjectAltName, template.ExtraExtensions) { + sanBytes, err := marshalSANs(template.DNSNames, template.EmailAddresses, template.IPAddresses) + if err != nil { + return nil, err + } + + extensions = append(extensions, pkix.Extension{ + Id: oidExtensionSubjectAltName, + Value: sanBytes, + }) + } + + extensions = append(extensions, template.ExtraExtensions...) + + var attributes []pkix.AttributeTypeAndValueSET + attributes = append(attributes, template.Attributes...) + + if len(extensions) > 0 { + // specifiedExtensions contains all the extensions that we + // found specified via template.Attributes. + specifiedExtensions := make(map[string]bool) + + for _, atvSet := range template.Attributes { + if !atvSet.Type.Equal(oidExtensionRequest) { + continue + } + + for _, atvs := range atvSet.Value { + for _, atv := range atvs { + specifiedExtensions[atv.Type.String()] = true + } + } + } + + atvs := make([]pkix.AttributeTypeAndValue, 0, len(extensions)) + for _, e := range extensions { + if specifiedExtensions[e.Id.String()] { + // Attributes already contained a value for + // this extension and it takes priority. + continue + } + + atvs = append(atvs, pkix.AttributeTypeAndValue{ + // There is no place for the critical flag in a CSR. + Type: e.Id, + Value: e.Value, + }) + } + + // Append the extensions to an existing attribute if possible. + appended := false + for _, atvSet := range attributes { + if !atvSet.Type.Equal(oidExtensionRequest) || len(atvSet.Value) == 0 { + continue + } + + atvSet.Value[0] = append(atvSet.Value[0], atvs...) + appended = true + break + } + + // Otherwise, add a new attribute for the extensions. + if !appended { + attributes = append(attributes, pkix.AttributeTypeAndValueSET{ + Type: oidExtensionRequest, + Value: [][]pkix.AttributeTypeAndValue{ + atvs, + }, + }) + } + } + + asn1Subject := template.RawSubject + if len(asn1Subject) == 0 { + asn1Subject, err = asn1.Marshal(template.Subject.ToRDNSequence()) + if err != nil { + return + } + } + + rawAttributes, err := newRawAttributes(attributes) + if err != nil { + return + } + + tbsCSR := tbsCertificateRequest{ + Version: 0, // PKCS #10, RFC 2986 + Subject: asn1.RawValue{FullBytes: asn1Subject}, + PublicKey: publicKeyInfo{ + Algorithm: publicKeyAlgorithm, + PublicKey: asn1.BitString{ + Bytes: publicKeyBytes, + BitLength: len(publicKeyBytes) * 8, + }, + }, + RawAttributes: rawAttributes, + } + + tbsCSRContents, err := asn1.Marshal(tbsCSR) + if err != nil { + return + } + tbsCSR.Raw = tbsCSRContents + + h := hashFunc.New() + h.Write(tbsCSRContents) + digest := h.Sum(nil) + + var signature []byte + signature, err = key.Sign(rand, digest, hashFunc) + if err != nil { + return + } + + return asn1.Marshal(certificateRequest{ + TBSCSR: tbsCSR, + SignatureAlgorithm: sigAlgo, + SignatureValue: asn1.BitString{ + Bytes: signature, + BitLength: len(signature) * 8, + }, + }) +} + +// ParseCertificateRequest parses a single certificate request from the +// given ASN.1 DER data. +func ParseCertificateRequest(asn1Data []byte) (*CertificateRequest, error) { + var csr certificateRequest + + rest, err := asn1.Unmarshal(asn1Data, &csr) + if err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, asn1.SyntaxError{Msg: "trailing data"} + } + + return parseCertificateRequest(&csr) +} + +func parseCertificateRequest(in *certificateRequest) (*CertificateRequest, error) { + out := &CertificateRequest{ + Raw: in.Raw, + RawTBSCertificateRequest: in.TBSCSR.Raw, + RawSubjectPublicKeyInfo: in.TBSCSR.PublicKey.Raw, + RawSubject: in.TBSCSR.Subject.FullBytes, + + Signature: in.SignatureValue.RightAlign(), + SignatureAlgorithm: getSignatureAlgorithmFromAI(in.SignatureAlgorithm), + + PublicKeyAlgorithm: getPublicKeyAlgorithmFromOID(in.TBSCSR.PublicKey.Algorithm.Algorithm), + + Version: in.TBSCSR.Version, + Attributes: parseRawAttributes(in.TBSCSR.RawAttributes), + } + + var err error + out.PublicKey, err = parsePublicKey(out.PublicKeyAlgorithm, &in.TBSCSR.PublicKey) + if err != nil { + return nil, err + } + + var subject pkix.RDNSequence + if rest, err := asn1.Unmarshal(in.TBSCSR.Subject.FullBytes, &subject); err != nil { + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after X.509 Subject") + } + + out.Subject.FillFromRDNSequence(&subject) + + if out.Extensions, err = parseCSRExtensions(in.TBSCSR.RawAttributes); err != nil { + return nil, err + } + + for _, extension := range out.Extensions { + if extension.Id.Equal(oidExtensionSubjectAltName) { + out.DNSNames, out.EmailAddresses, out.IPAddresses, err = parseSANExtension(extension.Value) + if err != nil { + return nil, err + } + } + } + + return out, nil +} + +// CheckSignature reports whether the signature on c is valid. +func (c *CertificateRequest) CheckSignature() error { + return checkSignature(c.SignatureAlgorithm, c.RawTBSCertificateRequest, c.Signature, c.PublicKey) +} + +func ReadCertificateRequestFromMem(data []byte) (*CertificateRequest, error) { + block, _ := pem.Decode(data) + if block == nil { + return nil, errors.New("failed to decode certificate request") + } + return ParseCertificateRequest(block.Bytes) +} + +func ReadCertificateRequestFromPem(FileName string) (*CertificateRequest, error) { + data, err := ioutil.ReadFile(FileName) + if err != nil { + return nil, err + } + return ReadCertificateRequestFromMem(data) +} + +func CreateCertificateRequestToMem(template *CertificateRequest, privKey *PrivateKey) ([]byte, error) { + der, err := CreateCertificateRequest(rand.Reader, template, privKey) + if err != nil { + return nil, err + } + block := &pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: der, + } + return pem.EncodeToMemory(block), nil +} + +func CreateCertificateRequestToPem(FileName string, template *CertificateRequest, + privKey *PrivateKey) (bool, error) { + der, err := CreateCertificateRequest(rand.Reader, template, privKey) + if err != nil { + return false, err + } + block := &pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: der, + } + file, err := os.Create(FileName) + if err != nil { + return false, err + } + defer file.Close() + err = pem.Encode(file, block) + if err != nil { + return false, err + } + return true, nil +} + +func ReadCertificateFromMem(data []byte) (*Certificate, error) { + block, _ := pem.Decode(data) + if block == nil { + return nil, errors.New("failed to decode certificate request") + } + return ParseCertificate(block.Bytes) +} + +func ReadCertificateFromPem(FileName string) (*Certificate, error) { + data, err := ioutil.ReadFile(FileName) + if err != nil { + return nil, err + } + return ReadCertificateFromMem(data) +} + +func CreateCertificateToMem(template, parent *Certificate, pubKey *PublicKey, privKey *PrivateKey) ([]byte, error) { + der, err := CreateCertificate(rand.Reader, template, parent, pubKey, privKey) + if err != nil { + return nil, err + } + block := &pem.Block{ + Type: "CERTIFICATE", + Bytes: der, + } + return pem.EncodeToMemory(block), nil +} + +func CreateCertificateToPem(FileName string, template, parent *Certificate, pubKey *PublicKey, privKey *PrivateKey) (bool, error) { + der, err := CreateCertificate(rand.Reader, template, parent, pubKey, privKey) + if err != nil { + return false, err + } + block := &pem.Block{ + Type: "CERTIFICATE", + Bytes: der, + } + file, err := os.Create(FileName) + if err != nil { + return false, err + } + defer file.Close() + err = pem.Encode(file, block) + if err != nil { + return false, err + } + return true, nil +} diff --git a/crypto/sm3/sm3_test.go b/crypto/sm3/sm3_test.go index fe1d1c60..428bcc24 100644 --- a/crypto/sm3/sm3_test.go +++ b/crypto/sm3/sm3_test.go @@ -45,6 +45,7 @@ func TestSm3(t *testing.T) { hw.Write(msg) hash := hw.Sum(nil) fmt.Println(hash) + fmt.Printf("hash = %d\n", len(hash)) fmt.Printf("%s\n", byteToString(hash)) hash1 := Sm3Sum(msg) fmt.Println(hash1) -- 2.11.0