OSDN Git Service

add sm2 (#1193)
[bytom/bytom.git] / crypto / sm2 / pkcs8.go
1 /*
2 Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7         http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 */
15
16 package sm2
17
18 import (
19         "crypto/aes"
20         "crypto/cipher"
21         "crypto/elliptic"
22         "crypto/hmac"
23         "crypto/md5"
24         "crypto/rand"
25         "crypto/sha1"
26         "crypto/sha256"
27         "crypto/sha512"
28         "crypto/x509/pkix"
29         "encoding/asn1"
30         "encoding/pem"
31         "errors"
32         "hash"
33         "io/ioutil"
34         "math/big"
35         "os"
36         "reflect"
37 )
38
39 /*
40  * reference to RFC5959 and RFC2898
41  */
42
43 var (
44         oidPBES1  = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 3}  // pbeWithMD5AndDES-CBC(PBES1)
45         oidPBES2  = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} // id-PBES2(PBES2)
46         oidPBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12} // id-PBKDF2
47
48         oidKEYMD5    = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 5}
49         oidKEYSHA1   = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 7}
50         oidKEYSHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 9}
51         oidKEYSHA512 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 11}
52
53         oidAES128CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 2}
54         oidAES256CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 42}
55
56         oidSM2 = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1}
57 )
58
59 // reference to https://www.rfc-editor.org/rfc/rfc5958.txt
60 type PrivateKeyInfo struct {
61         Version             int // v1 or v2
62         PrivateKeyAlgorithm []asn1.ObjectIdentifier
63         PrivateKey          []byte
64 }
65
66 // reference to https://www.rfc-editor.org/rfc/rfc5958.txt
67 type EncryptedPrivateKeyInfo struct {
68         EncryptionAlgorithm Pbes2Algorithms
69         EncryptedData       []byte
70 }
71
72 // reference to https://www.ietf.org/rfc/rfc2898.txt
73 type Pbes2Algorithms struct {
74         IdPBES2     asn1.ObjectIdentifier
75         Pbes2Params Pbes2Params
76 }
77
78 // reference to https://www.ietf.org/rfc/rfc2898.txt
79 type Pbes2Params struct {
80         KeyDerivationFunc Pbes2KDfs // PBES2-KDFs
81         EncryptionScheme  Pbes2Encs // PBES2-Encs
82 }
83
84 // reference to https://www.ietf.org/rfc/rfc2898.txt
85 type Pbes2KDfs struct {
86         IdPBKDF2    asn1.ObjectIdentifier
87         Pkdf2Params Pkdf2Params
88 }
89
90 type Pbes2Encs struct {
91         EncryAlgo asn1.ObjectIdentifier
92         IV        []byte
93 }
94
95 // reference to https://www.ietf.org/rfc/rfc2898.txt
96 type Pkdf2Params struct {
97         Salt           []byte
98         IterationCount int
99         Prf            pkix.AlgorithmIdentifier
100 }
101
102 type sm2PrivateKey struct {
103         Version       int
104         PrivateKey    []byte
105         NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"`
106         PublicKey     asn1.BitString        `asn1:"optional,explicit,tag:1"`
107 }
108
109 type pkcs8 struct {
110         Version    int
111         Algo       pkix.AlgorithmIdentifier
112         PrivateKey []byte
113 }
114
115 // copy from crypto/pbkdf2.go
116 func pbkdf(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte {
117         prf := hmac.New(h, password)
118         hashLen := prf.Size()
119         numBlocks := (keyLen + hashLen - 1) / hashLen
120
121         var buf [4]byte
122         dk := make([]byte, 0, numBlocks*hashLen)
123         U := make([]byte, hashLen)
124         for block := 1; block <= numBlocks; block++ {
125                 // N.B.: || means concatenation, ^ means XOR
126                 // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter
127                 // U_1 = PRF(password, salt || uint(i))
128                 prf.Reset()
129                 prf.Write(salt)
130                 buf[0] = byte(block >> 24)
131                 buf[1] = byte(block >> 16)
132                 buf[2] = byte(block >> 8)
133                 buf[3] = byte(block)
134                 prf.Write(buf[:4])
135                 dk = prf.Sum(dk)
136                 T := dk[len(dk)-hashLen:]
137                 copy(U, T)
138
139                 // U_n = PRF(password, U_(n-1))
140                 for n := 2; n <= iter; n++ {
141                         prf.Reset()
142                         prf.Write(U)
143                         U = U[:0]
144                         U = prf.Sum(U)
145                         for x := range U {
146                                 T[x] ^= U[x]
147                         }
148                 }
149         }
150         return dk[:keyLen]
151 }
152
153 func ParseSm2PublicKey(der []byte) (*PublicKey, error) {
154         var pubkey pkixPublicKey
155
156         if _, err := asn1.Unmarshal(der, &pubkey); err != nil {
157                 return nil, err
158         }
159         if !reflect.DeepEqual(pubkey.Algo.Algorithm, oidSM2) {
160                 return nil, errors.New("x509: not sm2 elliptic curve")
161         }
162         curve := P256Sm2()
163         x, y := elliptic.Unmarshal(curve, pubkey.BitString.Bytes)
164         pub := PublicKey{
165                 Curve: curve,
166                 X:     x,
167                 Y:     y,
168         }
169         return &pub, nil
170 }
171
172 func MarshalSm2PublicKey(key *PublicKey) ([]byte, error) {
173         var r pkixPublicKey
174         var algo pkix.AlgorithmIdentifier
175
176         algo.Algorithm = oidSM2
177         algo.Parameters.Class = 0
178         algo.Parameters.Tag = 6
179         algo.Parameters.IsCompound = false
180         algo.Parameters.FullBytes = []byte{6, 8, 42, 129, 28, 207, 85, 1, 130, 45} // asn1.Marshal(asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301})
181         r.Algo = algo
182         r.BitString = asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)}
183         return asn1.Marshal(r)
184 }
185
186 func ParseSm2PrivateKey(der []byte) (*PrivateKey, error) {
187         var privKey sm2PrivateKey
188
189         if _, err := asn1.Unmarshal(der, &privKey); err != nil {
190                 return nil, errors.New("x509: failed to parse SM2 private key: " + err.Error())
191         }
192         curve := P256Sm2()
193         k := new(big.Int).SetBytes(privKey.PrivateKey)
194         curveOrder := curve.Params().N
195         if k.Cmp(curveOrder) >= 0 {
196                 return nil, errors.New("x509: invalid elliptic curve private key value")
197         }
198         priv := new(PrivateKey)
199         priv.Curve = curve
200         priv.D = k
201         privateKey := make([]byte, (curveOrder.BitLen()+7)/8)
202         for len(privKey.PrivateKey) > len(privateKey) {
203                 if privKey.PrivateKey[0] != 0 {
204                         return nil, errors.New("x509: invalid private key length")
205                 }
206                 privKey.PrivateKey = privKey.PrivateKey[1:]
207         }
208         copy(privateKey[len(privateKey)-len(privKey.PrivateKey):], privKey.PrivateKey)
209         priv.X, priv.Y = curve.ScalarBaseMult(privateKey)
210         return priv, nil
211 }
212
213 func ParsePKCS8UnecryptedPrivateKey(der []byte) (*PrivateKey, error) {
214         var privKey pkcs8
215
216         if _, err := asn1.Unmarshal(der, &privKey); err != nil {
217                 return nil, err
218         }
219         if !reflect.DeepEqual(privKey.Algo.Algorithm, oidSM2) {
220                 return nil, errors.New("x509: not sm2 elliptic curve")
221         }
222         return ParseSm2PrivateKey(privKey.PrivateKey)
223 }
224
225 func ParsePKCS8EcryptedPrivateKey(der, pwd []byte) (*PrivateKey, error) {
226         var keyInfo EncryptedPrivateKeyInfo
227
228         _, err := asn1.Unmarshal(der, &keyInfo)
229         if err != nil {
230                 return nil, errors.New("x509: unknown format")
231         }
232         if !reflect.DeepEqual(keyInfo.EncryptionAlgorithm.IdPBES2, oidPBES2) {
233                 return nil, errors.New("x509: only support PBES2")
234         }
235         encryptionScheme := keyInfo.EncryptionAlgorithm.Pbes2Params.EncryptionScheme
236         keyDerivationFunc := keyInfo.EncryptionAlgorithm.Pbes2Params.KeyDerivationFunc
237         if !reflect.DeepEqual(keyDerivationFunc.IdPBKDF2, oidPBKDF2) {
238                 return nil, errors.New("x509: only support PBKDF2")
239         }
240         pkdf2Params := keyDerivationFunc.Pkdf2Params
241         if !reflect.DeepEqual(encryptionScheme.EncryAlgo, oidAES128CBC) &&
242                 !reflect.DeepEqual(encryptionScheme.EncryAlgo, oidAES256CBC) {
243                 return nil, errors.New("x509: unknow encryption algorithm")
244         }
245         iv := encryptionScheme.IV
246         salt := pkdf2Params.Salt
247         iter := pkdf2Params.IterationCount
248         encryptedKey := keyInfo.EncryptedData
249         var key []byte
250         switch {
251         case pkdf2Params.Prf.Algorithm.Equal(oidKEYMD5):
252                 key = pbkdf(pwd, salt, iter, 32, md5.New)
253                 break
254         case pkdf2Params.Prf.Algorithm.Equal(oidKEYSHA1):
255                 key = pbkdf(pwd, salt, iter, 32, sha1.New)
256                 break
257         case pkdf2Params.Prf.Algorithm.Equal(oidKEYSHA256):
258                 key = pbkdf(pwd, salt, iter, 32, sha256.New)
259                 break
260         case pkdf2Params.Prf.Algorithm.Equal(oidKEYSHA512):
261                 key = pbkdf(pwd, salt, iter, 32, sha512.New)
262                 break
263         default:
264                 return nil, errors.New("x509: unknown hash algorithm")
265         }
266         block, err := aes.NewCipher(key)
267         if err != nil {
268                 return nil, err
269         }
270         mode := cipher.NewCBCDecrypter(block, iv)
271         mode.CryptBlocks(encryptedKey, encryptedKey)
272         rKey, err := ParsePKCS8UnecryptedPrivateKey(encryptedKey)
273         if err != nil {
274                 return nil, errors.New("pkcs8: incorrect password")
275         }
276         return rKey, nil
277 }
278
279 func ParsePKCS8PrivateKey(der, pwd []byte) (*PrivateKey, error) {
280         if pwd == nil {
281                 return ParsePKCS8UnecryptedPrivateKey(der)
282         }
283         return ParsePKCS8EcryptedPrivateKey(der, pwd)
284 }
285
286 func MarshalSm2UnecryptedPrivateKey(key *PrivateKey) ([]byte, error) {
287         var r pkcs8
288         var priv sm2PrivateKey
289         var algo pkix.AlgorithmIdentifier
290
291         algo.Algorithm = oidSM2
292         algo.Parameters.Class = 0
293         algo.Parameters.Tag = 6
294         algo.Parameters.IsCompound = false
295         algo.Parameters.FullBytes = []byte{6, 8, 42, 129, 28, 207, 85, 1, 130, 45} // asn1.Marshal(asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301})
296         priv.Version = 1
297         priv.NamedCurveOID = oidNamedCurveP256SM2
298         priv.PublicKey = asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)}
299         priv.PrivateKey = key.D.Bytes()
300         r.Version = 0
301         r.Algo = algo
302         r.PrivateKey, _ = asn1.Marshal(priv)
303         return asn1.Marshal(r)
304 }
305
306 func MarshalSm2EcryptedPrivateKey(PrivKey *PrivateKey, pwd []byte) ([]byte, error) {
307         der, err := MarshalSm2UnecryptedPrivateKey(PrivKey)
308         if err != nil {
309                 return nil, err
310         }
311         iter := 2048
312         salt := make([]byte, 8)
313         iv := make([]byte, 16)
314         rand.Reader.Read(salt)
315         rand.Reader.Read(iv)
316         key := pbkdf(pwd, salt, iter, 32, sha1.New) // 默认是SHA1
317         padding := aes.BlockSize - len(der)%aes.BlockSize
318         if padding > 0 {
319                 n := len(der)
320                 der = append(der, make([]byte, padding)...)
321                 for i := 0; i < padding; i++ {
322                         der[n+i] = byte(padding)
323                 }
324         }
325         encryptedKey := make([]byte, len(der))
326         block, err := aes.NewCipher(key)
327         if err != nil {
328                 return nil, err
329         }
330         mode := cipher.NewCBCEncrypter(block, iv)
331         mode.CryptBlocks(encryptedKey, der)
332         var algorithmIdentifier pkix.AlgorithmIdentifier
333         algorithmIdentifier.Algorithm = oidKEYSHA1
334         algorithmIdentifier.Parameters.Tag = 5
335         algorithmIdentifier.Parameters.IsCompound = false
336         algorithmIdentifier.Parameters.FullBytes = []byte{5, 0}
337         keyDerivationFunc := Pbes2KDfs{
338                 oidPBKDF2,
339                 Pkdf2Params{
340                         salt,
341                         iter,
342                         algorithmIdentifier,
343                 },
344         }
345         encryptionScheme := Pbes2Encs{
346                 oidAES256CBC,
347                 iv,
348         }
349         pbes2Algorithms := Pbes2Algorithms{
350                 oidPBES2,
351                 Pbes2Params{
352                         keyDerivationFunc,
353                         encryptionScheme,
354                 },
355         }
356         encryptedPkey := EncryptedPrivateKeyInfo{
357                 pbes2Algorithms,
358                 encryptedKey,
359         }
360         return asn1.Marshal(encryptedPkey)
361 }
362
363 func MarshalSm2PrivateKey(key *PrivateKey, pwd []byte) ([]byte, error) {
364         if pwd == nil {
365                 return MarshalSm2UnecryptedPrivateKey(key)
366         }
367         return MarshalSm2EcryptedPrivateKey(key, pwd)
368 }
369
370 func ReadPrivateKeyFromMem(data []byte, pwd []byte) (*PrivateKey, error) {
371         var block *pem.Block
372
373         block, _ = pem.Decode(data)
374         if block == nil {
375                 return nil, errors.New("failed to decode private key")
376         }
377         priv, err := ParsePKCS8PrivateKey(block.Bytes, pwd)
378         return priv, err
379 }
380
381 func ReadPrivateKeyFromPem(FileName string, pwd []byte) (*PrivateKey, error) {
382         data, err := ioutil.ReadFile(FileName)
383         if err != nil {
384                 return nil, err
385         }
386         return ReadPrivateKeyFromMem(data, pwd)
387 }
388
389 func WritePrivateKeytoMem(key *PrivateKey, pwd []byte) ([]byte, error) {
390         var block *pem.Block
391
392         der, err := MarshalSm2PrivateKey(key, pwd)
393         if err != nil {
394                 return nil, err
395         }
396         if pwd != nil {
397                 block = &pem.Block{
398                         Type:  "ENCRYPTED PRIVATE KEY",
399                         Bytes: der,
400                 }
401         } else {
402                 block = &pem.Block{
403                         Type:  "PRIVATE KEY",
404                         Bytes: der,
405                 }
406         }
407         return pem.EncodeToMemory(block), nil
408 }
409
410 func WritePrivateKeytoPem(FileName string, key *PrivateKey, pwd []byte) (bool, error) {
411         var block *pem.Block
412
413         der, err := MarshalSm2PrivateKey(key, pwd)
414         if err != nil {
415                 return false, err
416         }
417         if pwd != nil {
418                 block = &pem.Block{
419                         Type:  "ENCRYPTED PRIVATE KEY",
420                         Bytes: der,
421                 }
422         } else {
423                 block = &pem.Block{
424                         Type:  "PRIVATE KEY",
425                         Bytes: der,
426                 }
427         }
428         file, err := os.Create(FileName)
429         if err != nil {
430                 return false, err
431         }
432         defer file.Close()
433         err = pem.Encode(file, block)
434         if err != nil {
435                 return false, err
436         }
437         return true, nil
438 }
439
440 func ReadPublicKeyFromMem(data []byte, _ []byte) (*PublicKey, error) {
441         block, _ := pem.Decode(data)
442         if block == nil || block.Type != "PUBLIC KEY" {
443                 return nil, errors.New("failed to decode public key")
444         }
445         pub, err := ParseSm2PublicKey(block.Bytes)
446         return pub, err
447 }
448
449 func ReadPublicKeyFromPem(FileName string, pwd []byte) (*PublicKey, error) {
450         data, err := ioutil.ReadFile(FileName)
451         if err != nil {
452                 return nil, err
453         }
454         return ReadPublicKeyFromMem(data, pwd)
455 }
456
457 func WritePublicKeytoMem(key *PublicKey, _ []byte) ([]byte, error) {
458         der, err := MarshalSm2PublicKey(key)
459         if err != nil {
460                 return nil, err
461         }
462         block := &pem.Block{
463                 Type:  "PUBLIC KEY",
464                 Bytes: der,
465         }
466         return pem.EncodeToMemory(block), nil
467 }
468
469 func WritePublicKeytoPem(FileName string, key *PublicKey, _ []byte) (bool, error) {
470         der, err := MarshalSm2PublicKey(key)
471         if err != nil {
472                 return false, err
473         }
474         block := &pem.Block{
475                 Type:  "PUBLIC KEY",
476                 Bytes: der,
477         }
478         file, err := os.Create(FileName)
479         defer file.Close()
480         if err != nil {
481                 return false, err
482         }
483         err = pem.Encode(file, block)
484         if err != nil {
485                 return false, err
486         }
487         return true, nil
488 }