1 // Copyright 2016 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
30 "golang.org/x/crypto/acme"
33 var discoTmpl = template.Must(template.New("disco").Parse(`{
34 "new-reg": "{{.}}/new-reg",
35 "new-authz": "{{.}}/new-authz",
36 "new-cert": "{{.}}/new-cert"
39 var authzTmpl = template.Must(template.New("authz").Parse(`{
43 "uri": "{{.}}/challenge/1",
48 "uri": "{{.}}/challenge/2",
55 type memCache struct {
57 keyData map[string][]byte
60 func (m *memCache) Get(ctx context.Context, key string) ([]byte, error) {
64 v, ok := m.keyData[key]
66 return nil, ErrCacheMiss
71 func (m *memCache) Put(ctx context.Context, key string, data []byte) error {
79 func (m *memCache) Delete(ctx context.Context, key string) error {
83 delete(m.keyData, key)
87 func newMemCache() *memCache {
89 keyData: make(map[string][]byte),
93 func dummyCert(pub interface{}, san ...string) ([]byte, error) {
94 return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...)
97 func dateDummyCert(pub interface{}, start, end time.Time, san ...string) ([]byte, error) {
98 // use EC key to run faster on 386
99 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
103 t := &x509.Certificate{
104 SerialNumber: big.NewInt(1),
107 BasicConstraintsValid: true,
108 KeyUsage: x509.KeyUsageKeyEncipherment,
114 return x509.CreateCertificate(rand.Reader, t, t, pub, key)
117 func decodePayload(v interface{}, r io.Reader) error {
118 var req struct{ Payload string }
119 if err := json.NewDecoder(r).Decode(&req); err != nil {
122 payload, err := base64.RawURLEncoding.DecodeString(req.Payload)
126 return json.Unmarshal(payload, v)
129 func TestGetCertificate(t *testing.T) {
130 man := &Manager{Prompt: AcceptTOS}
131 defer man.stopRenew()
132 hello := &tls.ClientHelloInfo{ServerName: "example.org"}
133 testGetCertificate(t, man, "example.org", hello)
136 func TestGetCertificate_trailingDot(t *testing.T) {
137 man := &Manager{Prompt: AcceptTOS}
138 defer man.stopRenew()
139 hello := &tls.ClientHelloInfo{ServerName: "example.org."}
140 testGetCertificate(t, man, "example.org", hello)
143 func TestGetCertificate_ForceRSA(t *testing.T) {
146 Cache: newMemCache(),
149 defer man.stopRenew()
150 hello := &tls.ClientHelloInfo{ServerName: "example.org"}
151 testGetCertificate(t, man, "example.org", hello)
153 cert, err := man.cacheGet(context.Background(), "example.org")
155 t.Fatalf("man.cacheGet: %v", err)
157 if _, ok := cert.PrivateKey.(*rsa.PrivateKey); !ok {
158 t.Errorf("cert.PrivateKey is %T; want *rsa.PrivateKey", cert.PrivateKey)
162 func TestGetCertificate_nilPrompt(t *testing.T) {
164 defer man.stopRenew()
165 url, finish := startACMEServerStub(t, man, "example.org")
167 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
171 man.Client = &acme.Client{
175 hello := &tls.ClientHelloInfo{ServerName: "example.org"}
176 if _, err := man.GetCertificate(hello); err == nil {
177 t.Error("got certificate for example.org; wanted error")
181 func TestGetCertificate_expiredCache(t *testing.T) {
182 // Make an expired cert and cache it.
183 pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
187 tmpl := &x509.Certificate{
188 SerialNumber: big.NewInt(1),
189 Subject: pkix.Name{CommonName: "example.org"},
190 NotAfter: time.Now(),
192 pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &pk.PublicKey, pk)
196 tlscert := &tls.Certificate{
197 Certificate: [][]byte{pub},
201 man := &Manager{Prompt: AcceptTOS, Cache: newMemCache()}
202 defer man.stopRenew()
203 if err := man.cachePut(context.Background(), "example.org", tlscert); err != nil {
204 t.Fatalf("man.cachePut: %v", err)
207 // The expired cached cert should trigger a new cert issuance
208 // and return without an error.
209 hello := &tls.ClientHelloInfo{ServerName: "example.org"}
210 testGetCertificate(t, man, "example.org", hello)
213 func TestGetCertificate_failedAttempt(t *testing.T) {
214 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
215 w.WriteHeader(http.StatusBadRequest)
219 const example = "example.org"
220 d := createCertRetryAfter
221 f := testDidRemoveState
223 createCertRetryAfter = d
224 testDidRemoveState = f
226 createCertRetryAfter = 0
227 done := make(chan struct{})
228 testDidRemoveState = func(domain string) {
229 if domain != example {
230 t.Errorf("testDidRemoveState: domain = %q; want %q", domain, example)
235 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
241 Client: &acme.Client{
243 DirectoryURL: ts.URL,
246 defer man.stopRenew()
247 hello := &tls.ClientHelloInfo{ServerName: example}
248 if _, err := man.GetCertificate(hello); err == nil {
249 t.Error("GetCertificate: err is nil")
252 case <-time.After(5 * time.Second):
253 t.Errorf("took too long to remove the %q state", example)
256 defer man.stateMu.Unlock()
257 if v, exist := man.state[example]; exist {
258 t.Errorf("state exists for %q: %+v", example, v)
263 // startACMEServerStub runs an ACME server
264 // The domain argument is the expected domain name of a certificate request.
265 func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string, finish func()) {
266 // echo token-02 | shasum -a 256
267 // then divide result in 2 parts separated by dot
268 tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid"
269 verifyTokenCert := func() {
270 hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
271 _, err := man.GetCertificate(hello)
273 t.Errorf("verifyTokenCert: GetCertificate(%q): %v", tokenCertName, err)
278 // ACME CA server stub
279 var ca *httptest.Server
280 ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
281 w.Header().Set("Replay-Nonce", "nonce")
282 if r.Method == "HEAD" {
290 if err := discoTmpl.Execute(w, ca.URL); err != nil {
291 t.Errorf("discoTmpl: %v", err)
293 // client key registration
295 w.Write([]byte("{}"))
296 // domain authorization
298 w.Header().Set("Location", ca.URL+"/authz/1")
299 w.WriteHeader(http.StatusCreated)
300 if err := authzTmpl.Execute(w, ca.URL); err != nil {
301 t.Errorf("authzTmpl: %v", err)
303 // accept tls-sni-02 challenge
306 w.Write([]byte("{}"))
307 // authorization status
309 w.Write([]byte(`{"status": "valid"}`))
313 CSR string `json:"csr"`
315 decodePayload(&req, r.Body)
316 b, _ := base64.RawURLEncoding.DecodeString(req.CSR)
317 csr, err := x509.ParseCertificateRequest(b)
319 t.Errorf("new-cert: CSR: %v", err)
321 if csr.Subject.CommonName != domain {
322 t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain)
324 der, err := dummyCert(csr.PublicKey, domain)
326 t.Errorf("new-cert: dummyCert: %v", err)
328 chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL)
329 w.Header().Set("Link", chainUp)
330 w.WriteHeader(http.StatusCreated)
334 der, err := dummyCert(nil, "ca")
336 t.Errorf("ca-cert: dummyCert: %v", err)
340 t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
346 // make sure token cert was removed
347 cancel := make(chan struct{})
348 done := make(chan struct{})
351 tick := time.NewTicker(100 * time.Millisecond)
354 hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
355 if _, err := man.GetCertificate(hello); err != nil {
367 case <-time.After(5 * time.Second):
369 t.Error("token cert was not removed")
373 return ca.URL, finish
376 // tests man.GetCertificate flow using the provided hello argument.
377 // The domain argument is the expected domain name of a certificate request.
378 func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
379 url, finish := startACMEServerStub(t, man, domain)
382 // use EC key to run faster on 386
383 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
387 man.Client = &acme.Client{
392 // simulate tls.Config.GetCertificate
393 var tlscert *tls.Certificate
394 done := make(chan struct{})
396 tlscert, err = man.GetCertificate(hello)
400 case <-time.After(time.Minute):
401 t.Fatal("man.GetCertificate took too long to return")
405 t.Fatalf("man.GetCertificate: %v", err)
408 // verify the tlscert is the same we responded with from the CA stub
409 if len(tlscert.Certificate) == 0 {
410 t.Fatal("len(tlscert.Certificate) is 0")
412 cert, err := x509.ParseCertificate(tlscert.Certificate[0])
414 t.Fatalf("x509.ParseCertificate: %v", err)
416 if len(cert.DNSNames) == 0 || cert.DNSNames[0] != domain {
417 t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain)
422 func TestAccountKeyCache(t *testing.T) {
423 m := Manager{Cache: newMemCache()}
424 ctx := context.Background()
425 k1, err := m.accountKey(ctx)
429 k2, err := m.accountKey(ctx)
433 if !reflect.DeepEqual(k1, k2) {
434 t.Errorf("account keys don't match: k1 = %#v; k2 = %#v", k1, k2)
438 func TestCache(t *testing.T) {
439 privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
443 tmpl := &x509.Certificate{
444 SerialNumber: big.NewInt(1),
445 Subject: pkix.Name{CommonName: "example.org"},
446 NotAfter: time.Now().Add(time.Hour),
448 pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privKey.PublicKey, privKey)
452 tlscert := &tls.Certificate{
453 Certificate: [][]byte{pub},
457 man := &Manager{Cache: newMemCache()}
458 defer man.stopRenew()
459 ctx := context.Background()
460 if err := man.cachePut(ctx, "example.org", tlscert); err != nil {
461 t.Fatalf("man.cachePut: %v", err)
463 res, err := man.cacheGet(ctx, "example.org")
465 t.Fatalf("man.cacheGet: %v", err)
468 t.Fatal("res is nil")
472 func TestHostWhitelist(t *testing.T) {
473 policy := HostWhitelist("example.com", "example.org", "*.example.net")
478 {"example.com", true},
479 {"example.org", true},
480 {"one.example.com", false},
481 {"two.example.org", false},
482 {"three.example.net", false},
485 for i, test := range tt {
486 err := policy(nil, test.host)
487 if err != nil && test.allow {
488 t.Errorf("%d: policy(%q): %v; want nil", i, test.host, err)
490 if err == nil && !test.allow {
491 t.Errorf("%d: policy(%q): nil; want an error", i, test.host)
496 func TestValidCert(t *testing.T) {
497 key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
501 key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
505 key3, err := rsa.GenerateKey(rand.Reader, 512)
509 cert1, err := dummyCert(key1.Public(), "example.org")
513 cert2, err := dummyCert(key2.Public(), "example.org")
517 cert3, err := dummyCert(key3.Public(), "example.org")
522 early, err := dateDummyCert(key1.Public(), now.Add(time.Hour), now.Add(2*time.Hour), "example.org")
526 expired, err := dateDummyCert(key1.Public(), now.Add(-2*time.Hour), now.Add(-time.Hour), "example.org")
537 {"example.org", key1, [][]byte{cert1}, true},
538 {"example.org", key3, [][]byte{cert3}, true},
539 {"example.org", key1, [][]byte{cert1, cert2, cert3}, true},
540 {"example.org", key1, [][]byte{cert1, {1}}, false},
541 {"example.org", key1, [][]byte{{1}}, false},
542 {"example.org", key1, [][]byte{cert2}, false},
543 {"example.org", key2, [][]byte{cert1}, false},
544 {"example.org", key1, [][]byte{cert3}, false},
545 {"example.org", key3, [][]byte{cert1}, false},
546 {"example.net", key1, [][]byte{cert1}, false},
547 {"example.org", key1, [][]byte{early}, false},
548 {"example.org", key1, [][]byte{expired}, false},
550 for i, test := range tt {
551 leaf, err := validCert(test.domain, test.cert, test.key)
552 if err != nil && test.ok {
553 t.Errorf("%d: err = %v", i, err)
555 if err == nil && !test.ok {
556 t.Errorf("%d: err is nil", i)
558 if err == nil && test.ok && leaf == nil {
559 t.Errorf("%d: leaf is nil", i)
564 type cacheGetFunc func(ctx context.Context, key string) ([]byte, error)
566 func (f cacheGetFunc) Get(ctx context.Context, key string) ([]byte, error) {
570 func (f cacheGetFunc) Put(ctx context.Context, key string, data []byte) error {
571 return fmt.Errorf("unsupported Put of %q = %q", key, data)
574 func (f cacheGetFunc) Delete(ctx context.Context, key string) error {
575 return fmt.Errorf("unsupported Delete of %q", key)
578 func TestManagerGetCertificateBogusSNI(t *testing.T) {
581 Cache: cacheGetFunc(func(ctx context.Context, key string) ([]byte, error) {
582 return nil, fmt.Errorf("cache.Get of %s", key)
589 {"foo.com", "cache.Get of foo.com"},
590 {"foo.com.", "cache.Get of foo.com"},
591 {`a\b.com`, "acme/autocert: server name contains invalid character"},
592 {`a/b.com`, "acme/autocert: server name contains invalid character"},
593 {"", "acme/autocert: missing server name"},
594 {"foo", "acme/autocert: server name component count invalid"},
595 {".foo", "acme/autocert: server name component count invalid"},
596 {"foo.", "acme/autocert: server name component count invalid"},
597 {"fo.o", "cache.Get of fo.o"},
599 for _, tt := range tests {
600 _, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.name})
601 got := fmt.Sprint(err)
602 if got != tt.wantErr {
603 t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr)