OSDN Git Service

new repo
[bytom/vapor.git] / vendor / golang.org / x / crypto / acme / autocert / autocert_test.go
1 // Copyright 2016 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package autocert
6
7 import (
8         "context"
9         "crypto"
10         "crypto/ecdsa"
11         "crypto/elliptic"
12         "crypto/rand"
13         "crypto/rsa"
14         "crypto/tls"
15         "crypto/x509"
16         "crypto/x509/pkix"
17         "encoding/base64"
18         "encoding/json"
19         "fmt"
20         "html/template"
21         "io"
22         "math/big"
23         "net/http"
24         "net/http/httptest"
25         "reflect"
26         "sync"
27         "testing"
28         "time"
29
30         "golang.org/x/crypto/acme"
31 )
32
33 var discoTmpl = template.Must(template.New("disco").Parse(`{
34         "new-reg": "{{.}}/new-reg",
35         "new-authz": "{{.}}/new-authz",
36         "new-cert": "{{.}}/new-cert"
37 }`))
38
39 var authzTmpl = template.Must(template.New("authz").Parse(`{
40         "status": "pending",
41         "challenges": [
42                 {
43                         "uri": "{{.}}/challenge/1",
44                         "type": "tls-sni-01",
45                         "token": "token-01"
46                 },
47                 {
48                         "uri": "{{.}}/challenge/2",
49                         "type": "tls-sni-02",
50                         "token": "token-02"
51                 }
52         ]
53 }`))
54
55 type memCache struct {
56         mu      sync.Mutex
57         keyData map[string][]byte
58 }
59
60 func (m *memCache) Get(ctx context.Context, key string) ([]byte, error) {
61         m.mu.Lock()
62         defer m.mu.Unlock()
63
64         v, ok := m.keyData[key]
65         if !ok {
66                 return nil, ErrCacheMiss
67         }
68         return v, nil
69 }
70
71 func (m *memCache) Put(ctx context.Context, key string, data []byte) error {
72         m.mu.Lock()
73         defer m.mu.Unlock()
74
75         m.keyData[key] = data
76         return nil
77 }
78
79 func (m *memCache) Delete(ctx context.Context, key string) error {
80         m.mu.Lock()
81         defer m.mu.Unlock()
82
83         delete(m.keyData, key)
84         return nil
85 }
86
87 func newMemCache() *memCache {
88         return &memCache{
89                 keyData: make(map[string][]byte),
90         }
91 }
92
93 func dummyCert(pub interface{}, san ...string) ([]byte, error) {
94         return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...)
95 }
96
97 func dateDummyCert(pub interface{}, start, end time.Time, san ...string) ([]byte, error) {
98         // use EC key to run faster on 386
99         key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
100         if err != nil {
101                 return nil, err
102         }
103         t := &x509.Certificate{
104                 SerialNumber:          big.NewInt(1),
105                 NotBefore:             start,
106                 NotAfter:              end,
107                 BasicConstraintsValid: true,
108                 KeyUsage:              x509.KeyUsageKeyEncipherment,
109                 DNSNames:              san,
110         }
111         if pub == nil {
112                 pub = &key.PublicKey
113         }
114         return x509.CreateCertificate(rand.Reader, t, t, pub, key)
115 }
116
117 func decodePayload(v interface{}, r io.Reader) error {
118         var req struct{ Payload string }
119         if err := json.NewDecoder(r).Decode(&req); err != nil {
120                 return err
121         }
122         payload, err := base64.RawURLEncoding.DecodeString(req.Payload)
123         if err != nil {
124                 return err
125         }
126         return json.Unmarshal(payload, v)
127 }
128
129 func TestGetCertificate(t *testing.T) {
130         man := &Manager{Prompt: AcceptTOS}
131         defer man.stopRenew()
132         hello := &tls.ClientHelloInfo{ServerName: "example.org"}
133         testGetCertificate(t, man, "example.org", hello)
134 }
135
136 func TestGetCertificate_trailingDot(t *testing.T) {
137         man := &Manager{Prompt: AcceptTOS}
138         defer man.stopRenew()
139         hello := &tls.ClientHelloInfo{ServerName: "example.org."}
140         testGetCertificate(t, man, "example.org", hello)
141 }
142
143 func TestGetCertificate_ForceRSA(t *testing.T) {
144         man := &Manager{
145                 Prompt:   AcceptTOS,
146                 Cache:    newMemCache(),
147                 ForceRSA: true,
148         }
149         defer man.stopRenew()
150         hello := &tls.ClientHelloInfo{ServerName: "example.org"}
151         testGetCertificate(t, man, "example.org", hello)
152
153         cert, err := man.cacheGet(context.Background(), "example.org")
154         if err != nil {
155                 t.Fatalf("man.cacheGet: %v", err)
156         }
157         if _, ok := cert.PrivateKey.(*rsa.PrivateKey); !ok {
158                 t.Errorf("cert.PrivateKey is %T; want *rsa.PrivateKey", cert.PrivateKey)
159         }
160 }
161
162 func TestGetCertificate_nilPrompt(t *testing.T) {
163         man := &Manager{}
164         defer man.stopRenew()
165         url, finish := startACMEServerStub(t, man, "example.org")
166         defer finish()
167         key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
168         if err != nil {
169                 t.Fatal(err)
170         }
171         man.Client = &acme.Client{
172                 Key:          key,
173                 DirectoryURL: url,
174         }
175         hello := &tls.ClientHelloInfo{ServerName: "example.org"}
176         if _, err := man.GetCertificate(hello); err == nil {
177                 t.Error("got certificate for example.org; wanted error")
178         }
179 }
180
181 func TestGetCertificate_expiredCache(t *testing.T) {
182         // Make an expired cert and cache it.
183         pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
184         if err != nil {
185                 t.Fatal(err)
186         }
187         tmpl := &x509.Certificate{
188                 SerialNumber: big.NewInt(1),
189                 Subject:      pkix.Name{CommonName: "example.org"},
190                 NotAfter:     time.Now(),
191         }
192         pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &pk.PublicKey, pk)
193         if err != nil {
194                 t.Fatal(err)
195         }
196         tlscert := &tls.Certificate{
197                 Certificate: [][]byte{pub},
198                 PrivateKey:  pk,
199         }
200
201         man := &Manager{Prompt: AcceptTOS, Cache: newMemCache()}
202         defer man.stopRenew()
203         if err := man.cachePut(context.Background(), "example.org", tlscert); err != nil {
204                 t.Fatalf("man.cachePut: %v", err)
205         }
206
207         // The expired cached cert should trigger a new cert issuance
208         // and return without an error.
209         hello := &tls.ClientHelloInfo{ServerName: "example.org"}
210         testGetCertificate(t, man, "example.org", hello)
211 }
212
213 func TestGetCertificate_failedAttempt(t *testing.T) {
214         ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
215                 w.WriteHeader(http.StatusBadRequest)
216         }))
217         defer ts.Close()
218
219         const example = "example.org"
220         d := createCertRetryAfter
221         f := testDidRemoveState
222         defer func() {
223                 createCertRetryAfter = d
224                 testDidRemoveState = f
225         }()
226         createCertRetryAfter = 0
227         done := make(chan struct{})
228         testDidRemoveState = func(domain string) {
229                 if domain != example {
230                         t.Errorf("testDidRemoveState: domain = %q; want %q", domain, example)
231                 }
232                 close(done)
233         }
234
235         key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
236         if err != nil {
237                 t.Fatal(err)
238         }
239         man := &Manager{
240                 Prompt: AcceptTOS,
241                 Client: &acme.Client{
242                         Key:          key,
243                         DirectoryURL: ts.URL,
244                 },
245         }
246         defer man.stopRenew()
247         hello := &tls.ClientHelloInfo{ServerName: example}
248         if _, err := man.GetCertificate(hello); err == nil {
249                 t.Error("GetCertificate: err is nil")
250         }
251         select {
252         case <-time.After(5 * time.Second):
253                 t.Errorf("took too long to remove the %q state", example)
254         case <-done:
255                 man.stateMu.Lock()
256                 defer man.stateMu.Unlock()
257                 if v, exist := man.state[example]; exist {
258                         t.Errorf("state exists for %q: %+v", example, v)
259                 }
260         }
261 }
262
263 // startACMEServerStub runs an ACME server
264 // The domain argument is the expected domain name of a certificate request.
265 func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string, finish func()) {
266         // echo token-02 | shasum -a 256
267         // then divide result in 2 parts separated by dot
268         tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid"
269         verifyTokenCert := func() {
270                 hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
271                 _, err := man.GetCertificate(hello)
272                 if err != nil {
273                         t.Errorf("verifyTokenCert: GetCertificate(%q): %v", tokenCertName, err)
274                         return
275                 }
276         }
277
278         // ACME CA server stub
279         var ca *httptest.Server
280         ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
281                 w.Header().Set("Replay-Nonce", "nonce")
282                 if r.Method == "HEAD" {
283                         // a nonce request
284                         return
285                 }
286
287                 switch r.URL.Path {
288                 // discovery
289                 case "/":
290                         if err := discoTmpl.Execute(w, ca.URL); err != nil {
291                                 t.Errorf("discoTmpl: %v", err)
292                         }
293                 // client key registration
294                 case "/new-reg":
295                         w.Write([]byte("{}"))
296                 // domain authorization
297                 case "/new-authz":
298                         w.Header().Set("Location", ca.URL+"/authz/1")
299                         w.WriteHeader(http.StatusCreated)
300                         if err := authzTmpl.Execute(w, ca.URL); err != nil {
301                                 t.Errorf("authzTmpl: %v", err)
302                         }
303                 // accept tls-sni-02 challenge
304                 case "/challenge/2":
305                         verifyTokenCert()
306                         w.Write([]byte("{}"))
307                 // authorization status
308                 case "/authz/1":
309                         w.Write([]byte(`{"status": "valid"}`))
310                 // cert request
311                 case "/new-cert":
312                         var req struct {
313                                 CSR string `json:"csr"`
314                         }
315                         decodePayload(&req, r.Body)
316                         b, _ := base64.RawURLEncoding.DecodeString(req.CSR)
317                         csr, err := x509.ParseCertificateRequest(b)
318                         if err != nil {
319                                 t.Errorf("new-cert: CSR: %v", err)
320                         }
321                         if csr.Subject.CommonName != domain {
322                                 t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain)
323                         }
324                         der, err := dummyCert(csr.PublicKey, domain)
325                         if err != nil {
326                                 t.Errorf("new-cert: dummyCert: %v", err)
327                         }
328                         chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL)
329                         w.Header().Set("Link", chainUp)
330                         w.WriteHeader(http.StatusCreated)
331                         w.Write(der)
332                 // CA chain cert
333                 case "/ca-cert":
334                         der, err := dummyCert(nil, "ca")
335                         if err != nil {
336                                 t.Errorf("ca-cert: dummyCert: %v", err)
337                         }
338                         w.Write(der)
339                 default:
340                         t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
341                 }
342         }))
343         finish = func() {
344                 ca.Close()
345
346                 // make sure token cert was removed
347                 cancel := make(chan struct{})
348                 done := make(chan struct{})
349                 go func() {
350                         defer close(done)
351                         tick := time.NewTicker(100 * time.Millisecond)
352                         defer tick.Stop()
353                         for {
354                                 hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
355                                 if _, err := man.GetCertificate(hello); err != nil {
356                                         return
357                                 }
358                                 select {
359                                 case <-tick.C:
360                                 case <-cancel:
361                                         return
362                                 }
363                         }
364                 }()
365                 select {
366                 case <-done:
367                 case <-time.After(5 * time.Second):
368                         close(cancel)
369                         t.Error("token cert was not removed")
370                         <-done
371                 }
372         }
373         return ca.URL, finish
374 }
375
376 // tests man.GetCertificate flow using the provided hello argument.
377 // The domain argument is the expected domain name of a certificate request.
378 func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
379         url, finish := startACMEServerStub(t, man, domain)
380         defer finish()
381
382         // use EC key to run faster on 386
383         key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
384         if err != nil {
385                 t.Fatal(err)
386         }
387         man.Client = &acme.Client{
388                 Key:          key,
389                 DirectoryURL: url,
390         }
391
392         // simulate tls.Config.GetCertificate
393         var tlscert *tls.Certificate
394         done := make(chan struct{})
395         go func() {
396                 tlscert, err = man.GetCertificate(hello)
397                 close(done)
398         }()
399         select {
400         case <-time.After(time.Minute):
401                 t.Fatal("man.GetCertificate took too long to return")
402         case <-done:
403         }
404         if err != nil {
405                 t.Fatalf("man.GetCertificate: %v", err)
406         }
407
408         // verify the tlscert is the same we responded with from the CA stub
409         if len(tlscert.Certificate) == 0 {
410                 t.Fatal("len(tlscert.Certificate) is 0")
411         }
412         cert, err := x509.ParseCertificate(tlscert.Certificate[0])
413         if err != nil {
414                 t.Fatalf("x509.ParseCertificate: %v", err)
415         }
416         if len(cert.DNSNames) == 0 || cert.DNSNames[0] != domain {
417                 t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain)
418         }
419
420 }
421
422 func TestAccountKeyCache(t *testing.T) {
423         m := Manager{Cache: newMemCache()}
424         ctx := context.Background()
425         k1, err := m.accountKey(ctx)
426         if err != nil {
427                 t.Fatal(err)
428         }
429         k2, err := m.accountKey(ctx)
430         if err != nil {
431                 t.Fatal(err)
432         }
433         if !reflect.DeepEqual(k1, k2) {
434                 t.Errorf("account keys don't match: k1 = %#v; k2 = %#v", k1, k2)
435         }
436 }
437
438 func TestCache(t *testing.T) {
439         privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
440         if err != nil {
441                 t.Fatal(err)
442         }
443         tmpl := &x509.Certificate{
444                 SerialNumber: big.NewInt(1),
445                 Subject:      pkix.Name{CommonName: "example.org"},
446                 NotAfter:     time.Now().Add(time.Hour),
447         }
448         pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privKey.PublicKey, privKey)
449         if err != nil {
450                 t.Fatal(err)
451         }
452         tlscert := &tls.Certificate{
453                 Certificate: [][]byte{pub},
454                 PrivateKey:  privKey,
455         }
456
457         man := &Manager{Cache: newMemCache()}
458         defer man.stopRenew()
459         ctx := context.Background()
460         if err := man.cachePut(ctx, "example.org", tlscert); err != nil {
461                 t.Fatalf("man.cachePut: %v", err)
462         }
463         res, err := man.cacheGet(ctx, "example.org")
464         if err != nil {
465                 t.Fatalf("man.cacheGet: %v", err)
466         }
467         if res == nil {
468                 t.Fatal("res is nil")
469         }
470 }
471
472 func TestHostWhitelist(t *testing.T) {
473         policy := HostWhitelist("example.com", "example.org", "*.example.net")
474         tt := []struct {
475                 host  string
476                 allow bool
477         }{
478                 {"example.com", true},
479                 {"example.org", true},
480                 {"one.example.com", false},
481                 {"two.example.org", false},
482                 {"three.example.net", false},
483                 {"dummy", false},
484         }
485         for i, test := range tt {
486                 err := policy(nil, test.host)
487                 if err != nil && test.allow {
488                         t.Errorf("%d: policy(%q): %v; want nil", i, test.host, err)
489                 }
490                 if err == nil && !test.allow {
491                         t.Errorf("%d: policy(%q): nil; want an error", i, test.host)
492                 }
493         }
494 }
495
496 func TestValidCert(t *testing.T) {
497         key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
498         if err != nil {
499                 t.Fatal(err)
500         }
501         key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
502         if err != nil {
503                 t.Fatal(err)
504         }
505         key3, err := rsa.GenerateKey(rand.Reader, 512)
506         if err != nil {
507                 t.Fatal(err)
508         }
509         cert1, err := dummyCert(key1.Public(), "example.org")
510         if err != nil {
511                 t.Fatal(err)
512         }
513         cert2, err := dummyCert(key2.Public(), "example.org")
514         if err != nil {
515                 t.Fatal(err)
516         }
517         cert3, err := dummyCert(key3.Public(), "example.org")
518         if err != nil {
519                 t.Fatal(err)
520         }
521         now := time.Now()
522         early, err := dateDummyCert(key1.Public(), now.Add(time.Hour), now.Add(2*time.Hour), "example.org")
523         if err != nil {
524                 t.Fatal(err)
525         }
526         expired, err := dateDummyCert(key1.Public(), now.Add(-2*time.Hour), now.Add(-time.Hour), "example.org")
527         if err != nil {
528                 t.Fatal(err)
529         }
530
531         tt := []struct {
532                 domain string
533                 key    crypto.Signer
534                 cert   [][]byte
535                 ok     bool
536         }{
537                 {"example.org", key1, [][]byte{cert1}, true},
538                 {"example.org", key3, [][]byte{cert3}, true},
539                 {"example.org", key1, [][]byte{cert1, cert2, cert3}, true},
540                 {"example.org", key1, [][]byte{cert1, {1}}, false},
541                 {"example.org", key1, [][]byte{{1}}, false},
542                 {"example.org", key1, [][]byte{cert2}, false},
543                 {"example.org", key2, [][]byte{cert1}, false},
544                 {"example.org", key1, [][]byte{cert3}, false},
545                 {"example.org", key3, [][]byte{cert1}, false},
546                 {"example.net", key1, [][]byte{cert1}, false},
547                 {"example.org", key1, [][]byte{early}, false},
548                 {"example.org", key1, [][]byte{expired}, false},
549         }
550         for i, test := range tt {
551                 leaf, err := validCert(test.domain, test.cert, test.key)
552                 if err != nil && test.ok {
553                         t.Errorf("%d: err = %v", i, err)
554                 }
555                 if err == nil && !test.ok {
556                         t.Errorf("%d: err is nil", i)
557                 }
558                 if err == nil && test.ok && leaf == nil {
559                         t.Errorf("%d: leaf is nil", i)
560                 }
561         }
562 }
563
564 type cacheGetFunc func(ctx context.Context, key string) ([]byte, error)
565
566 func (f cacheGetFunc) Get(ctx context.Context, key string) ([]byte, error) {
567         return f(ctx, key)
568 }
569
570 func (f cacheGetFunc) Put(ctx context.Context, key string, data []byte) error {
571         return fmt.Errorf("unsupported Put of %q = %q", key, data)
572 }
573
574 func (f cacheGetFunc) Delete(ctx context.Context, key string) error {
575         return fmt.Errorf("unsupported Delete of %q", key)
576 }
577
578 func TestManagerGetCertificateBogusSNI(t *testing.T) {
579         m := Manager{
580                 Prompt: AcceptTOS,
581                 Cache: cacheGetFunc(func(ctx context.Context, key string) ([]byte, error) {
582                         return nil, fmt.Errorf("cache.Get of %s", key)
583                 }),
584         }
585         tests := []struct {
586                 name    string
587                 wantErr string
588         }{
589                 {"foo.com", "cache.Get of foo.com"},
590                 {"foo.com.", "cache.Get of foo.com"},
591                 {`a\b.com`, "acme/autocert: server name contains invalid character"},
592                 {`a/b.com`, "acme/autocert: server name contains invalid character"},
593                 {"", "acme/autocert: missing server name"},
594                 {"foo", "acme/autocert: server name component count invalid"},
595                 {".foo", "acme/autocert: server name component count invalid"},
596                 {"foo.", "acme/autocert: server name component count invalid"},
597                 {"fo.o", "cache.Get of fo.o"},
598         }
599         for _, tt := range tests {
600                 _, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.name})
601                 got := fmt.Sprint(err)
602                 if got != tt.wantErr {
603                         t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr)
604                 }
605         }
606 }