OSDN Git Service

new repo
[bytom/vapor.git] / vendor / golang.org / x / net / http2 / transport_test.go
1 // Copyright 2015 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package http2
6
7 import (
8         "bufio"
9         "bytes"
10         "crypto/tls"
11         "errors"
12         "flag"
13         "fmt"
14         "io"
15         "io/ioutil"
16         "math/rand"
17         "net"
18         "net/http"
19         "net/http/httptest"
20         "net/url"
21         "os"
22         "reflect"
23         "runtime"
24         "sort"
25         "strconv"
26         "strings"
27         "sync"
28         "sync/atomic"
29         "testing"
30         "time"
31
32         "golang.org/x/net/http2/hpack"
33 )
34
var (
	// extNet gates tests that need real network access (see TestTransportExternal).
	extNet = flag.Bool("extnet", false, "do external network tests")

	// transportHost is the remote host contacted by TestTransportExternal.
	transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")

	// insecure is registered but not read anywhere in the visible part of this file.
	insecure = flag.Bool("insecure", false, "insecure TLS dials") // TODO: dead code. remove?

	// tlsConfigInsecure disables certificate verification so test clients
	// accept the (presumably self-signed) certificates of the test servers.
	tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}
)
42
// testContext implements the four context.Context methods for a context
// that is never cancelled: Done returns a fresh channel nobody closes,
// Deadline reports none, Value finds nothing, and Err must never be called.
type testContext struct{}

// Done returns a new, never-closed channel on each call.
func (testContext) Done() <-chan struct{} {
	never := make(chan struct{})
	return never
}

// Err panics; callers are expected to check Done first.
func (testContext) Err() error {
	panic("should not be called")
}

// Deadline reports that no deadline is set.
func (testContext) Deadline() (deadline time.Time, ok bool) {
	var zero time.Time
	return zero, false
}

// Value always reports no value for key.
func (testContext) Value(key interface{}) interface{} {
	return nil
}
49
50 func TestTransportExternal(t *testing.T) {
51         if !*extNet {
52                 t.Skip("skipping external network test")
53         }
54         req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
55         rt := &Transport{TLSClientConfig: tlsConfigInsecure}
56         res, err := rt.RoundTrip(req)
57         if err != nil {
58                 t.Fatalf("%v", err)
59         }
60         res.Write(os.Stdout)
61 }
62
63 type fakeTLSConn struct {
64         net.Conn
65 }
66
67 func (c *fakeTLSConn) ConnectionState() tls.ConnectionState {
68         return tls.ConnectionState{
69                 Version:     tls.VersionTLS12,
70                 CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
71         }
72 }
73
74 func startH2cServer(t *testing.T) net.Listener {
75         h2Server := &Server{}
76         l := newLocalListener(t)
77         go func() {
78                 conn, err := l.Accept()
79                 if err != nil {
80                         t.Error(err)
81                         return
82                 }
83                 h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
84                         fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil)
85                 })})
86         }()
87         return l
88 }
89
90 func TestTransportH2c(t *testing.T) {
91         l := startH2cServer(t)
92         defer l.Close()
93         req, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil)
94         if err != nil {
95                 t.Fatal(err)
96         }
97         tr := &Transport{
98                 AllowHTTP: true,
99                 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
100                         return net.Dial(network, addr)
101                 },
102         }
103         res, err := tr.RoundTrip(req)
104         if err != nil {
105                 t.Fatal(err)
106         }
107         if res.ProtoMajor != 2 {
108                 t.Fatal("proto not h2c")
109         }
110         body, err := ioutil.ReadAll(res.Body)
111         if err != nil {
112                 t.Fatal(err)
113         }
114         if got, want := string(body), "Hello, /foobar, http: true"; got != want {
115                 t.Fatalf("response got %v, want %v", got, want)
116         }
117 }
118
119 func TestTransport(t *testing.T) {
120         const body = "sup"
121         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
122                 io.WriteString(w, body)
123         }, optOnlyServer)
124         defer st.Close()
125
126         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
127         defer tr.CloseIdleConnections()
128
129         req, err := http.NewRequest("GET", st.ts.URL, nil)
130         if err != nil {
131                 t.Fatal(err)
132         }
133         res, err := tr.RoundTrip(req)
134         if err != nil {
135                 t.Fatal(err)
136         }
137         defer res.Body.Close()
138
139         t.Logf("Got res: %+v", res)
140         if g, w := res.StatusCode, 200; g != w {
141                 t.Errorf("StatusCode = %v; want %v", g, w)
142         }
143         if g, w := res.Status, "200 OK"; g != w {
144                 t.Errorf("Status = %q; want %q", g, w)
145         }
146         wantHeader := http.Header{
147                 "Content-Length": []string{"3"},
148                 "Content-Type":   []string{"text/plain; charset=utf-8"},
149                 "Date":           []string{"XXX"}, // see cleanDate
150         }
151         cleanDate(res)
152         if !reflect.DeepEqual(res.Header, wantHeader) {
153                 t.Errorf("res Header = %v; want %v", res.Header, wantHeader)
154         }
155         if res.Request != req {
156                 t.Errorf("Response.Request = %p; want %p", res.Request, req)
157         }
158         if res.TLS == nil {
159                 t.Error("Response.TLS = nil; want non-nil")
160         }
161         slurp, err := ioutil.ReadAll(res.Body)
162         if err != nil {
163                 t.Errorf("Body read: %v", err)
164         } else if string(slurp) != body {
165                 t.Errorf("Body = %q; want %q", slurp, body)
166         }
167 }
168
169 func onSameConn(t *testing.T, modReq func(*http.Request)) bool {
170         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
171                 io.WriteString(w, r.RemoteAddr)
172         }, optOnlyServer, func(c net.Conn, st http.ConnState) {
173                 t.Logf("conn %v is now state %v", c.RemoteAddr(), st)
174         })
175         defer st.Close()
176         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
177         defer tr.CloseIdleConnections()
178         get := func() string {
179                 req, err := http.NewRequest("GET", st.ts.URL, nil)
180                 if err != nil {
181                         t.Fatal(err)
182                 }
183                 modReq(req)
184                 res, err := tr.RoundTrip(req)
185                 if err != nil {
186                         t.Fatal(err)
187                 }
188                 defer res.Body.Close()
189                 slurp, err := ioutil.ReadAll(res.Body)
190                 if err != nil {
191                         t.Fatalf("Body read: %v", err)
192                 }
193                 addr := strings.TrimSpace(string(slurp))
194                 if addr == "" {
195                         t.Fatalf("didn't get an addr in response")
196                 }
197                 return addr
198         }
199         first := get()
200         second := get()
201         return first == second
202 }
203
204 func TestTransportReusesConns(t *testing.T) {
205         if !onSameConn(t, func(*http.Request) {}) {
206                 t.Errorf("first and second responses were on different connections")
207         }
208 }
209
210 func TestTransportReusesConn_RequestClose(t *testing.T) {
211         if onSameConn(t, func(r *http.Request) { r.Close = true }) {
212                 t.Errorf("first and second responses were not on different connections")
213         }
214 }
215
216 func TestTransportReusesConn_ConnClose(t *testing.T) {
217         if onSameConn(t, func(r *http.Request) { r.Header.Set("Connection", "close") }) {
218                 t.Errorf("first and second responses were not on different connections")
219         }
220 }
221
222 // Tests that the Transport only keeps one pending dial open per destination address.
223 // https://golang.org/issue/13397
224 func TestTransportGroupsPendingDials(t *testing.T) {
225         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
226                 io.WriteString(w, r.RemoteAddr)
227         }, optOnlyServer)
228         defer st.Close()
229         tr := &Transport{
230                 TLSClientConfig: tlsConfigInsecure,
231         }
232         defer tr.CloseIdleConnections()
233         var (
234                 mu    sync.Mutex
235                 dials = map[string]int{}
236         )
237         var wg sync.WaitGroup
238         for i := 0; i < 10; i++ {
239                 wg.Add(1)
240                 go func() {
241                         defer wg.Done()
242                         req, err := http.NewRequest("GET", st.ts.URL, nil)
243                         if err != nil {
244                                 t.Error(err)
245                                 return
246                         }
247                         res, err := tr.RoundTrip(req)
248                         if err != nil {
249                                 t.Error(err)
250                                 return
251                         }
252                         defer res.Body.Close()
253                         slurp, err := ioutil.ReadAll(res.Body)
254                         if err != nil {
255                                 t.Errorf("Body read: %v", err)
256                         }
257                         addr := strings.TrimSpace(string(slurp))
258                         if addr == "" {
259                                 t.Errorf("didn't get an addr in response")
260                         }
261                         mu.Lock()
262                         dials[addr]++
263                         mu.Unlock()
264                 }()
265         }
266         wg.Wait()
267         if len(dials) != 1 {
268                 t.Errorf("saw %d dials; want 1: %v", len(dials), dials)
269         }
270         tr.CloseIdleConnections()
271         if err := retry(50, 10*time.Millisecond, func() error {
272                 cp, ok := tr.connPool().(*clientConnPool)
273                 if !ok {
274                         return fmt.Errorf("Conn pool is %T; want *clientConnPool", tr.connPool())
275                 }
276                 cp.mu.Lock()
277                 defer cp.mu.Unlock()
278                 if len(cp.dialing) != 0 {
279                         return fmt.Errorf("dialing map = %v; want empty", cp.dialing)
280                 }
281                 if len(cp.conns) != 0 {
282                         return fmt.Errorf("conns = %v; want empty", cp.conns)
283                 }
284                 if len(cp.keys) != 0 {
285                         return fmt.Errorf("keys = %v; want empty", cp.keys)
286                 }
287                 return nil
288         }); err != nil {
289                 t.Errorf("State of pool after CloseIdleConnections: %v", err)
290         }
291 }
292
// retry calls fn up to tries times, stopping at the first nil error, and
// sleeps for delay between consecutive attempts. It returns nil on success,
// otherwise the error from the last attempt. If tries <= 0, fn is never
// called and nil is returned.
func retry(tries int, delay time.Duration, fn func() error) error {
	var err error
	for i := 0; i < tries; i++ {
		if i > 0 {
			// Sleep only between attempts; sleeping after the final
			// failure would just delay reporting it.
			time.Sleep(delay)
		}
		if err = fn(); err == nil {
			return nil
		}
	}
	return err
}
304
305 func TestTransportAbortClosesPipes(t *testing.T) {
306         shutdown := make(chan struct{})
307         st := newServerTester(t,
308                 func(w http.ResponseWriter, r *http.Request) {
309                         w.(http.Flusher).Flush()
310                         <-shutdown
311                 },
312                 optOnlyServer,
313         )
314         defer st.Close()
315         defer close(shutdown) // we must shutdown before st.Close() to avoid hanging
316
317         done := make(chan struct{})
318         requestMade := make(chan struct{})
319         go func() {
320                 defer close(done)
321                 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
322                 req, err := http.NewRequest("GET", st.ts.URL, nil)
323                 if err != nil {
324                         t.Fatal(err)
325                 }
326                 res, err := tr.RoundTrip(req)
327                 if err != nil {
328                         t.Fatal(err)
329                 }
330                 defer res.Body.Close()
331                 close(requestMade)
332                 _, err = ioutil.ReadAll(res.Body)
333                 if err == nil {
334                         t.Error("expected error from res.Body.Read")
335                 }
336         }()
337
338         <-requestMade
339         // Now force the serve loop to end, via closing the connection.
340         st.closeConn()
341         // deadlock? that's a bug.
342         select {
343         case <-done:
344         case <-time.After(3 * time.Second):
345                 t.Fatal("timeout")
346         }
347 }
348
349 // TODO: merge this with TestTransportBody to make TestTransportRequest? This
350 // could be a table-driven test with extra goodies.
351 func TestTransportPath(t *testing.T) {
352         gotc := make(chan *url.URL, 1)
353         st := newServerTester(t,
354                 func(w http.ResponseWriter, r *http.Request) {
355                         gotc <- r.URL
356                 },
357                 optOnlyServer,
358         )
359         defer st.Close()
360
361         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
362         defer tr.CloseIdleConnections()
363         const (
364                 path  = "/testpath"
365                 query = "q=1"
366         )
367         surl := st.ts.URL + path + "?" + query
368         req, err := http.NewRequest("POST", surl, nil)
369         if err != nil {
370                 t.Fatal(err)
371         }
372         c := &http.Client{Transport: tr}
373         res, err := c.Do(req)
374         if err != nil {
375                 t.Fatal(err)
376         }
377         defer res.Body.Close()
378         got := <-gotc
379         if got.Path != path {
380                 t.Errorf("Read Path = %q; want %q", got.Path, path)
381         }
382         if got.RawQuery != query {
383                 t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query)
384         }
385 }
386
// randString returns an n-byte string of pseudo-random byte values in
// [0, 256), so it is not necessarily valid UTF-8. The generator is seeded
// with n, so a given length always produces the same string.
func randString(n int) string {
	src := rand.New(rand.NewSource(int64(n)))
	var sb strings.Builder
	sb.Grow(n)
	for sb.Len() < n {
		sb.WriteByte(byte(src.Intn(256)))
	}
	return sb.String()
}
395
// panicReader is a request Body whose Read and Close both panic. It lets a
// test prove that the code under test never touches the body.
type panicReader struct{}

// Read always panics with "unexpected Read".
func (p panicReader) Read(buf []byte) (int, error) {
	panic("unexpected Read")
}

// Close always panics with "unexpected Close".
func (p panicReader) Close() error {
	panic("unexpected Close")
}
400
401 func TestActualContentLength(t *testing.T) {
402         tests := []struct {
403                 req  *http.Request
404                 want int64
405         }{
406                 // Verify we don't read from Body:
407                 0: {
408                         req:  &http.Request{Body: panicReader{}},
409                         want: -1,
410                 },
411                 // nil Body means 0, regardless of ContentLength:
412                 1: {
413                         req:  &http.Request{Body: nil, ContentLength: 5},
414                         want: 0,
415                 },
416                 // ContentLength is used if set.
417                 2: {
418                         req:  &http.Request{Body: panicReader{}, ContentLength: 5},
419                         want: 5,
420                 },
421                 // http.NoBody means 0, not -1.
422                 3: {
423                         req:  &http.Request{Body: go18httpNoBody()},
424                         want: 0,
425                 },
426         }
427         for i, tt := range tests {
428                 got := actualContentLength(tt.req)
429                 if got != tt.want {
430                         t.Errorf("test[%d]: got %d; want %d", i, got, tt.want)
431                 }
432         }
433 }
434
435 func TestTransportBody(t *testing.T) {
436         bodyTests := []struct {
437                 body         string
438                 noContentLen bool
439         }{
440                 {body: "some message"},
441                 {body: "some message", noContentLen: true},
442                 {body: strings.Repeat("a", 1<<20), noContentLen: true},
443                 {body: strings.Repeat("a", 1<<20)},
444                 {body: randString(16<<10 - 1)},
445                 {body: randString(16 << 10)},
446                 {body: randString(16<<10 + 1)},
447                 {body: randString(512<<10 - 1)},
448                 {body: randString(512 << 10)},
449                 {body: randString(512<<10 + 1)},
450                 {body: randString(1<<20 - 1)},
451                 {body: randString(1 << 20)},
452                 {body: randString(1<<20 + 2)},
453         }
454
455         type reqInfo struct {
456                 req   *http.Request
457                 slurp []byte
458                 err   error
459         }
460         gotc := make(chan reqInfo, 1)
461         st := newServerTester(t,
462                 func(w http.ResponseWriter, r *http.Request) {
463                         slurp, err := ioutil.ReadAll(r.Body)
464                         if err != nil {
465                                 gotc <- reqInfo{err: err}
466                         } else {
467                                 gotc <- reqInfo{req: r, slurp: slurp}
468                         }
469                 },
470                 optOnlyServer,
471         )
472         defer st.Close()
473
474         for i, tt := range bodyTests {
475                 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
476                 defer tr.CloseIdleConnections()
477
478                 var body io.Reader = strings.NewReader(tt.body)
479                 if tt.noContentLen {
480                         body = struct{ io.Reader }{body} // just a Reader, hiding concrete type and other methods
481                 }
482                 req, err := http.NewRequest("POST", st.ts.URL, body)
483                 if err != nil {
484                         t.Fatalf("#%d: %v", i, err)
485                 }
486                 c := &http.Client{Transport: tr}
487                 res, err := c.Do(req)
488                 if err != nil {
489                         t.Fatalf("#%d: %v", i, err)
490                 }
491                 defer res.Body.Close()
492                 ri := <-gotc
493                 if ri.err != nil {
494                         t.Errorf("#%d: read error: %v", i, ri.err)
495                         continue
496                 }
497                 if got := string(ri.slurp); got != tt.body {
498                         t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
499                 }
500                 wantLen := int64(len(tt.body))
501                 if tt.noContentLen && tt.body != "" {
502                         wantLen = -1
503                 }
504                 if ri.req.ContentLength != wantLen {
505                         t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
506                 }
507         }
508 }
509
// shortString abbreviates v for error messages. Strings of at most 100 bytes
// are returned unchanged. Longer ones keep the first and last 50 bytes, with
// a marker in between giving the number of bytes omitted.
func shortString(v string) string {
	const maxLen = 100
	if len(v) > maxLen {
		const half = maxLen / 2
		head, tail := v[:half], v[len(v)-half:]
		return fmt.Sprintf("%v[...%d bytes omitted...]%v", head, len(v)-maxLen, tail)
	}
	return v
}
517
518 func TestTransportDialTLS(t *testing.T) {
519         var mu sync.Mutex // guards following
520         var gotReq, didDial bool
521
522         ts := newServerTester(t,
523                 func(w http.ResponseWriter, r *http.Request) {
524                         mu.Lock()
525                         gotReq = true
526                         mu.Unlock()
527                 },
528                 optOnlyServer,
529         )
530         defer ts.Close()
531         tr := &Transport{
532                 DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
533                         mu.Lock()
534                         didDial = true
535                         mu.Unlock()
536                         cfg.InsecureSkipVerify = true
537                         c, err := tls.Dial(netw, addr, cfg)
538                         if err != nil {
539                                 return nil, err
540                         }
541                         return c, c.Handshake()
542                 },
543         }
544         defer tr.CloseIdleConnections()
545         client := &http.Client{Transport: tr}
546         res, err := client.Get(ts.ts.URL)
547         if err != nil {
548                 t.Fatal(err)
549         }
550         res.Body.Close()
551         mu.Lock()
552         if !gotReq {
553                 t.Error("didn't get request")
554         }
555         if !didDial {
556                 t.Error("didn't use dial hook")
557         }
558 }
559
560 func TestConfigureTransport(t *testing.T) {
561         t1 := &http.Transport{}
562         err := ConfigureTransport(t1)
563         if err == errTransportVersion {
564                 t.Skip(err)
565         }
566         if err != nil {
567                 t.Fatal(err)
568         }
569         if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) {
570                 // Laziness, to avoid buildtags.
571                 t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got)
572         }
573         wantNextProtos := []string{"h2", "http/1.1"}
574         if t1.TLSClientConfig == nil {
575                 t.Errorf("nil t1.TLSClientConfig")
576         } else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) {
577                 t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos)
578         }
579         if err := ConfigureTransport(t1); err == nil {
580                 t.Error("unexpected success on second call to ConfigureTransport")
581         }
582
583         // And does it work?
584         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
585                 io.WriteString(w, r.Proto)
586         }, optOnlyServer)
587         defer st.Close()
588
589         t1.TLSClientConfig.InsecureSkipVerify = true
590         c := &http.Client{Transport: t1}
591         res, err := c.Get(st.ts.URL)
592         if err != nil {
593                 t.Fatal(err)
594         }
595         slurp, err := ioutil.ReadAll(res.Body)
596         if err != nil {
597                 t.Fatal(err)
598         }
599         if got, want := string(slurp), "HTTP/2.0"; got != want {
600                 t.Errorf("body = %q; want %q", got, want)
601         }
602 }
603
// capitalizeReader wraps r and upper-cases ASCII letters a-z in every chunk
// it reads. All other bytes pass through untouched.
type capitalizeReader struct {
	r io.Reader
}

// Read reads from the underlying reader, then rewrites p[:n] in place. It
// returns the underlying reader's n and err unchanged.
func (cr capitalizeReader) Read(p []byte) (n int, err error) {
	n, err = cr.r.Read(p)
	for i := 0; i < n; i++ {
		if c := p[i]; 'a' <= c && c <= 'z' {
			p[i] = c - ('a' - 'A')
		}
	}
	return n, err
}
617
// flushWriter writes to w and, if w also implements http.Flusher, flushes
// after every Write.
type flushWriter struct {
	w io.Writer
}

// Write forwards p to the wrapped writer and then flushes it when possible.
// The flush happens even if the write returned an error; the write's n and
// err are returned as-is.
func (fw flushWriter) Write(p []byte) (n int, err error) {
	n, err = fw.w.Write(p)
	flusher, canFlush := fw.w.(http.Flusher)
	if canFlush {
		flusher.Flush()
	}
	return n, err
}
629
630 type clientTester struct {
631         t      *testing.T
632         tr     *Transport
633         sc, cc net.Conn // server and client conn
634         fr     *Framer  // server's framer
635         client func() error
636         server func() error
637 }
638
639 func newClientTester(t *testing.T) *clientTester {
640         var dialOnce struct {
641                 sync.Mutex
642                 dialed bool
643         }
644         ct := &clientTester{
645                 t: t,
646         }
647         ct.tr = &Transport{
648                 TLSClientConfig: tlsConfigInsecure,
649                 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
650                         dialOnce.Lock()
651                         defer dialOnce.Unlock()
652                         if dialOnce.dialed {
653                                 return nil, errors.New("only one dial allowed in test mode")
654                         }
655                         dialOnce.dialed = true
656                         return ct.cc, nil
657                 },
658         }
659
660         ln := newLocalListener(t)
661         cc, err := net.Dial("tcp", ln.Addr().String())
662         if err != nil {
663                 t.Fatal(err)
664
665         }
666         sc, err := ln.Accept()
667         if err != nil {
668                 t.Fatal(err)
669         }
670         ln.Close()
671         ct.cc = cc
672         ct.sc = sc
673         ct.fr = NewFramer(sc, sc)
674         return ct
675 }
676
// newLocalListener listens on an ephemeral loopback TCP port. It tries IPv4
// (127.0.0.1) first and falls back to IPv6 ([::1]); if both fail it calls
// t.Fatal with the IPv6 error.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp4", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
688
689 func (ct *clientTester) greet(settings ...Setting) {
690         buf := make([]byte, len(ClientPreface))
691         _, err := io.ReadFull(ct.sc, buf)
692         if err != nil {
693                 ct.t.Fatalf("reading client preface: %v", err)
694         }
695         f, err := ct.fr.ReadFrame()
696         if err != nil {
697                 ct.t.Fatalf("Reading client settings frame: %v", err)
698         }
699         if sf, ok := f.(*SettingsFrame); !ok {
700                 ct.t.Fatalf("Wanted client settings frame; got %v", f)
701                 _ = sf // stash it away?
702         }
703         if err := ct.fr.WriteSettings(settings...); err != nil {
704                 ct.t.Fatal(err)
705         }
706         if err := ct.fr.WriteSettingsAck(); err != nil {
707                 ct.t.Fatal(err)
708         }
709 }
710
711 func (ct *clientTester) readNonSettingsFrame() (Frame, error) {
712         for {
713                 f, err := ct.fr.ReadFrame()
714                 if err != nil {
715                         return nil, err
716                 }
717                 if _, ok := f.(*SettingsFrame); ok {
718                         continue
719                 }
720                 return f, nil
721         }
722 }
723
724 func (ct *clientTester) cleanup() {
725         ct.tr.CloseIdleConnections()
726 }
727
728 func (ct *clientTester) run() {
729         errc := make(chan error, 2)
730         ct.start("client", errc, ct.client)
731         ct.start("server", errc, ct.server)
732         defer ct.cleanup()
733         for i := 0; i < 2; i++ {
734                 if err := <-errc; err != nil {
735                         ct.t.Error(err)
736                         return
737                 }
738         }
739 }
740
741 func (ct *clientTester) start(which string, errc chan<- error, fn func() error) {
742         go func() {
743                 finished := false
744                 var err error
745                 defer func() {
746                         if !finished {
747                                 err = fmt.Errorf("%s goroutine didn't finish.", which)
748                         } else if err != nil {
749                                 err = fmt.Errorf("%s: %v", which, err)
750                         }
751                         errc <- err
752                 }()
753                 err = fn()
754                 finished = true
755         }()
756 }
757
758 func (ct *clientTester) readFrame() (Frame, error) {
759         return readFrameTimeout(ct.fr, 2*time.Second)
760 }
761
762 func (ct *clientTester) firstHeaders() (*HeadersFrame, error) {
763         for {
764                 f, err := ct.readFrame()
765                 if err != nil {
766                         return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
767                 }
768                 switch f.(type) {
769                 case *WindowUpdateFrame, *SettingsFrame:
770                         continue
771                 }
772                 hf, ok := f.(*HeadersFrame)
773                 if !ok {
774                         return nil, fmt.Errorf("Got %T; want HeadersFrame", f)
775                 }
776                 return hf, nil
777         }
778 }
779
// countingReader is an endless io.Reader. Each Read fills the whole buffer
// with p[i] = byte(i) and atomically adds len(p) to *n. It never returns an
// error.
type countingReader struct {
	n *int64
}

func (cr countingReader) Read(p []byte) (n int, err error) {
	n = len(p)
	for i := 0; i < n; i++ {
		p[i] = byte(i)
	}
	atomic.AddInt64(cr.n, int64(n))
	return n, nil
}
791
792 func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
793 func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
794
795 func testTransportReqBodyAfterResponse(t *testing.T, status int) {
796         const bodySize = 10 << 20
797         clientDone := make(chan struct{})
798         ct := newClientTester(t)
799         ct.client = func() error {
800                 defer ct.cc.(*net.TCPConn).CloseWrite()
801                 defer close(clientDone)
802
803                 var n int64 // atomic
804                 req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize))
805                 if err != nil {
806                         return err
807                 }
808                 res, err := ct.tr.RoundTrip(req)
809                 if err != nil {
810                         return fmt.Errorf("RoundTrip: %v", err)
811                 }
812                 defer res.Body.Close()
813                 if res.StatusCode != status {
814                         return fmt.Errorf("status code = %v; want %v", res.StatusCode, status)
815                 }
816                 slurp, err := ioutil.ReadAll(res.Body)
817                 if err != nil {
818                         return fmt.Errorf("Slurp: %v", err)
819                 }
820                 if len(slurp) > 0 {
821                         return fmt.Errorf("unexpected body: %q", slurp)
822                 }
823                 if status == 200 {
824                         if got := atomic.LoadInt64(&n); got != bodySize {
825                                 return fmt.Errorf("For 200 response, Transport wrote %d bytes; want %d", got, bodySize)
826                         }
827                 } else {
828                         if got := atomic.LoadInt64(&n); got == 0 || got >= bodySize {
829                                 return fmt.Errorf("For %d response, Transport wrote %d bytes; want (0,%d) exclusive", status, got, bodySize)
830                         }
831                 }
832                 return nil
833         }
834         ct.server = func() error {
835                 ct.greet()
836                 var buf bytes.Buffer
837                 enc := hpack.NewEncoder(&buf)
838                 var dataRecv int64
839                 var closed bool
840                 for {
841                         f, err := ct.fr.ReadFrame()
842                         if err != nil {
843                                 select {
844                                 case <-clientDone:
845                                         // If the client's done, it
846                                         // will have reported any
847                                         // errors on its side.
848                                         return nil
849                                 default:
850                                         return err
851                                 }
852                         }
853                         //println(fmt.Sprintf("server got frame: %v", f))
854                         switch f := f.(type) {
855                         case *WindowUpdateFrame, *SettingsFrame:
856                         case *HeadersFrame:
857                                 if !f.HeadersEnded() {
858                                         return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
859                                 }
860                                 if f.StreamEnded() {
861                                         return fmt.Errorf("headers contains END_STREAM unexpectedly: %v", f)
862                                 }
863                         case *DataFrame:
864                                 dataLen := len(f.Data())
865                                 if dataLen > 0 {
866                                         if dataRecv == 0 {
867                                                 enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
868                                                 ct.fr.WriteHeaders(HeadersFrameParam{
869                                                         StreamID:      f.StreamID,
870                                                         EndHeaders:    true,
871                                                         EndStream:     false,
872                                                         BlockFragment: buf.Bytes(),
873                                                 })
874                                         }
875                                         if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
876                                                 return err
877                                         }
878                                         if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
879                                                 return err
880                                         }
881                                 }
882                                 dataRecv += int64(dataLen)
883
884                                 if !closed && ((status != 200 && dataRecv > 0) ||
885                                         (status == 200 && dataRecv == bodySize)) {
886                                         closed = true
887                                         if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil {
888                                                 return err
889                                         }
890                                 }
891                         default:
892                                 return fmt.Errorf("Unexpected client frame %v", f)
893                         }
894                 }
895         }
896         ct.run()
897 }
898
899 // See golang.org/issue/13444
900 func TestTransportFullDuplex(t *testing.T) {
901         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
902                 w.WriteHeader(200) // redundant but for clarity
903                 w.(http.Flusher).Flush()
904                 io.Copy(flushWriter{w}, capitalizeReader{r.Body})
905                 fmt.Fprintf(w, "bye.\n")
906         }, optOnlyServer)
907         defer st.Close()
908
909         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
910         defer tr.CloseIdleConnections()
911         c := &http.Client{Transport: tr}
912
913         pr, pw := io.Pipe()
914         req, err := http.NewRequest("PUT", st.ts.URL, ioutil.NopCloser(pr))
915         if err != nil {
916                 t.Fatal(err)
917         }
918         req.ContentLength = -1
919         res, err := c.Do(req)
920         if err != nil {
921                 t.Fatal(err)
922         }
923         defer res.Body.Close()
924         if res.StatusCode != 200 {
925                 t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
926         }
927         bs := bufio.NewScanner(res.Body)
928         want := func(v string) {
929                 if !bs.Scan() {
930                         t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
931                 }
932         }
933         write := func(v string) {
934                 _, err := io.WriteString(pw, v)
935                 if err != nil {
936                         t.Fatalf("pipe write: %v", err)
937                 }
938         }
939         write("foo\n")
940         want("FOO")
941         write("bar\n")
942         want("BAR")
943         pw.Close()
944         want("bye.")
945         if err := bs.Err(); err != nil {
946                 t.Fatal(err)
947         }
948 }
949
950 func TestTransportConnectRequest(t *testing.T) {
951         gotc := make(chan *http.Request, 1)
952         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
953                 gotc <- r
954         }, optOnlyServer)
955         defer st.Close()
956
957         u, err := url.Parse(st.ts.URL)
958         if err != nil {
959                 t.Fatal(err)
960         }
961
962         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
963         defer tr.CloseIdleConnections()
964         c := &http.Client{Transport: tr}
965
966         tests := []struct {
967                 req  *http.Request
968                 want string
969         }{
970                 {
971                         req: &http.Request{
972                                 Method: "CONNECT",
973                                 Header: http.Header{},
974                                 URL:    u,
975                         },
976                         want: u.Host,
977                 },
978                 {
979                         req: &http.Request{
980                                 Method: "CONNECT",
981                                 Header: http.Header{},
982                                 URL:    u,
983                                 Host:   "example.com:123",
984                         },
985                         want: "example.com:123",
986                 },
987         }
988
989         for i, tt := range tests {
990                 res, err := c.Do(tt.req)
991                 if err != nil {
992                         t.Errorf("%d. RoundTrip = %v", i, err)
993                         continue
994                 }
995                 res.Body.Close()
996                 req := <-gotc
997                 if req.Method != "CONNECT" {
998                         t.Errorf("method = %q; want CONNECT", req.Method)
999                 }
1000                 if req.Host != tt.want {
1001                         t.Errorf("Host = %q; want %q", req.Host, tt.want)
1002                 }
1003                 if req.URL.Host != tt.want {
1004                         t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
1005                 }
1006         }
1007 }
1008
// headerType selects how the fake server in the response-pattern and
// invalid-trailer tests emits a header block.
type headerType int

const (
	// noHeader: the header block is omitted.
	noHeader headerType = iota
	// oneHeader: the block is sent as a single HEADERS frame.
	oneHeader
	// splitHeader: the block is deliberately broken into a HEADERS frame
	// followed by a CONTINUATION frame.
	splitHeader
)

// Short aliases used by the TestTransportResPattern_* functions:
// f0..f2 pick a headerType, and d0/d1 say whether a DATA frame is sent.
const (
	f0, f1, f2 = noHeader, oneHeader, splitHeader
	d0, d1     = false, true
)
1024
1025 // Test all 36 combinations of response frame orders:
1026 //    (3 ways of 100-continue) * (2 ways of headers) * (2 ways of data) * (3 ways of trailers):func TestTransportResponsePattern_00f0(t *testing.T) { testTransportResponsePattern(h0, h1, false, h0) }
1027 // Generated by http://play.golang.org/p/SScqYKJYXd
1028 func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
1029 func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
1030 func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
1031 func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
1032 func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
1033 func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
1034 func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
1035 func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
1036 func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
1037 func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
1038 func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
1039 func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
1040 func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
1041 func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
1042 func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
1043 func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
1044 func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
1045 func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
1046 func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
1047 func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
1048 func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
1049 func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
1050 func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
1051 func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
1052 func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
1053 func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
1054 func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
1055 func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
1056 func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
1057 func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
1058 func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
1059 func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
1060 func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
1061 func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
1062 func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
1063 func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
1064
1065 func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
1066         const reqBody = "some request body"
1067         const resBody = "some response body"
1068
1069         if resHeader == noHeader {
1070                 // TODO: test 100-continue followed by immediate
1071                 // server stream reset, without headers in the middle?
1072                 panic("invalid combination")
1073         }
1074
1075         ct := newClientTester(t)
1076         ct.client = func() error {
1077                 req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
1078                 if expect100Continue != noHeader {
1079                         req.Header.Set("Expect", "100-continue")
1080                 }
1081                 res, err := ct.tr.RoundTrip(req)
1082                 if err != nil {
1083                         return fmt.Errorf("RoundTrip: %v", err)
1084                 }
1085                 defer res.Body.Close()
1086                 if res.StatusCode != 200 {
1087                         return fmt.Errorf("status code = %v; want 200", res.StatusCode)
1088                 }
1089                 slurp, err := ioutil.ReadAll(res.Body)
1090                 if err != nil {
1091                         return fmt.Errorf("Slurp: %v", err)
1092                 }
1093                 wantBody := resBody
1094                 if !withData {
1095                         wantBody = ""
1096                 }
1097                 if string(slurp) != wantBody {
1098                         return fmt.Errorf("body = %q; want %q", slurp, wantBody)
1099                 }
1100                 if trailers == noHeader {
1101                         if len(res.Trailer) > 0 {
1102                                 t.Errorf("Trailer = %v; want none", res.Trailer)
1103                         }
1104                 } else {
1105                         want := http.Header{"Some-Trailer": {"some-value"}}
1106                         if !reflect.DeepEqual(res.Trailer, want) {
1107                                 t.Errorf("Trailer = %v; want %v", res.Trailer, want)
1108                         }
1109                 }
1110                 return nil
1111         }
1112         ct.server = func() error {
1113                 ct.greet()
1114                 var buf bytes.Buffer
1115                 enc := hpack.NewEncoder(&buf)
1116
1117                 for {
1118                         f, err := ct.fr.ReadFrame()
1119                         if err != nil {
1120                                 return err
1121                         }
1122                         endStream := false
1123                         send := func(mode headerType) {
1124                                 hbf := buf.Bytes()
1125                                 switch mode {
1126                                 case oneHeader:
1127                                         ct.fr.WriteHeaders(HeadersFrameParam{
1128                                                 StreamID:      f.Header().StreamID,
1129                                                 EndHeaders:    true,
1130                                                 EndStream:     endStream,
1131                                                 BlockFragment: hbf,
1132                                         })
1133                                 case splitHeader:
1134                                         if len(hbf) < 2 {
1135                                                 panic("too small")
1136                                         }
1137                                         ct.fr.WriteHeaders(HeadersFrameParam{
1138                                                 StreamID:      f.Header().StreamID,
1139                                                 EndHeaders:    false,
1140                                                 EndStream:     endStream,
1141                                                 BlockFragment: hbf[:1],
1142                                         })
1143                                         ct.fr.WriteContinuation(f.Header().StreamID, true, hbf[1:])
1144                                 default:
1145                                         panic("bogus mode")
1146                                 }
1147                         }
1148                         switch f := f.(type) {
1149                         case *WindowUpdateFrame, *SettingsFrame:
1150                         case *DataFrame:
1151                                 if !f.StreamEnded() {
1152                                         // No need to send flow control tokens. The test request body is tiny.
1153                                         continue
1154                                 }
1155                                 // Response headers (1+ frames; 1 or 2 in this test, but never 0)
1156                                 {
1157                                         buf.Reset()
1158                                         enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
1159                                         enc.WriteField(hpack.HeaderField{Name: "x-foo", Value: "blah"})
1160                                         enc.WriteField(hpack.HeaderField{Name: "x-bar", Value: "more"})
1161                                         if trailers != noHeader {
1162                                                 enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "some-trailer"})
1163                                         }
1164                                         endStream = withData == false && trailers == noHeader
1165                                         send(resHeader)
1166                                 }
1167                                 if withData {
1168                                         endStream = trailers == noHeader
1169                                         ct.fr.WriteData(f.StreamID, endStream, []byte(resBody))
1170                                 }
1171                                 if trailers != noHeader {
1172                                         endStream = true
1173                                         buf.Reset()
1174                                         enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "some-value"})
1175                                         send(trailers)
1176                                 }
1177                                 if endStream {
1178                                         return nil
1179                                 }
1180                         case *HeadersFrame:
1181                                 if expect100Continue != noHeader {
1182                                         buf.Reset()
1183                                         enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
1184                                         send(expect100Continue)
1185                                 }
1186                         }
1187                 }
1188         }
1189         ct.run()
1190 }
1191
1192 func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
1193         ct := newClientTester(t)
1194         ct.client = func() error {
1195                 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
1196                 res, err := ct.tr.RoundTrip(req)
1197                 if err != nil {
1198                         return fmt.Errorf("RoundTrip: %v", err)
1199                 }
1200                 defer res.Body.Close()
1201                 if res.StatusCode != 200 {
1202                         return fmt.Errorf("status code = %v; want 200", res.StatusCode)
1203                 }
1204                 slurp, err := ioutil.ReadAll(res.Body)
1205                 if err != nil {
1206                         return fmt.Errorf("res.Body ReadAll error = %q, %v; want %v", slurp, err, nil)
1207                 }
1208                 if len(slurp) > 0 {
1209                         return fmt.Errorf("body = %q; want nothing", slurp)
1210                 }
1211                 if _, ok := res.Trailer["Some-Trailer"]; !ok {
1212                         return fmt.Errorf("expected Some-Trailer")
1213                 }
1214                 return nil
1215         }
1216         ct.server = func() error {
1217                 ct.greet()
1218
1219                 var n int
1220                 var hf *HeadersFrame
1221                 for hf == nil && n < 10 {
1222                         f, err := ct.fr.ReadFrame()
1223                         if err != nil {
1224                                 return err
1225                         }
1226                         hf, _ = f.(*HeadersFrame)
1227                         n++
1228                 }
1229
1230                 var buf bytes.Buffer
1231                 enc := hpack.NewEncoder(&buf)
1232
1233                 // send headers without Trailer header
1234                 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
1235                 ct.fr.WriteHeaders(HeadersFrameParam{
1236                         StreamID:      hf.StreamID,
1237                         EndHeaders:    true,
1238                         EndStream:     false,
1239                         BlockFragment: buf.Bytes(),
1240                 })
1241
1242                 // send trailers
1243                 buf.Reset()
1244                 enc.WriteField(hpack.HeaderField{Name: "some-trailer", Value: "I'm an undeclared Trailer!"})
1245                 ct.fr.WriteHeaders(HeadersFrameParam{
1246                         StreamID:      hf.StreamID,
1247                         EndHeaders:    true,
1248                         EndStream:     true,
1249                         BlockFragment: buf.Bytes(),
1250                 })
1251                 return nil
1252         }
1253         ct.run()
1254 }
1255
1256 func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
1257         testTransportInvalidTrailer_Pseudo(t, oneHeader)
1258 }
1259 func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
1260         testTransportInvalidTrailer_Pseudo(t, splitHeader)
1261 }
1262 func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
1263         testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"), func(enc *hpack.Encoder) {
1264                 enc.WriteField(hpack.HeaderField{Name: ":colon", Value: "foo"})
1265                 enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
1266         })
1267 }
1268
1269 func TestTransportInvalidTrailer_Capital1(t *testing.T) {
1270         testTransportInvalidTrailer_Capital(t, oneHeader)
1271 }
1272 func TestTransportInvalidTrailer_Capital2(t *testing.T) {
1273         testTransportInvalidTrailer_Capital(t, splitHeader)
1274 }
1275 func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
1276         testInvalidTrailer(t, trailers, headerFieldNameError("Capital"), func(enc *hpack.Encoder) {
1277                 enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
1278                 enc.WriteField(hpack.HeaderField{Name: "Capital", Value: "bad"})
1279         })
1280 }
1281 func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
1282         testInvalidTrailer(t, oneHeader, headerFieldNameError(""), func(enc *hpack.Encoder) {
1283                 enc.WriteField(hpack.HeaderField{Name: "", Value: "bad"})
1284         })
1285 }
1286 func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
1287         testInvalidTrailer(t, oneHeader, headerFieldValueError("has\nnewline"), func(enc *hpack.Encoder) {
1288                 enc.WriteField(hpack.HeaderField{Name: "x", Value: "has\nnewline"})
1289         })
1290 }
1291
1292 func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeTrailer func(*hpack.Encoder)) {
1293         ct := newClientTester(t)
1294         ct.client = func() error {
1295                 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
1296                 res, err := ct.tr.RoundTrip(req)
1297                 if err != nil {
1298                         return fmt.Errorf("RoundTrip: %v", err)
1299                 }
1300                 defer res.Body.Close()
1301                 if res.StatusCode != 200 {
1302                         return fmt.Errorf("status code = %v; want 200", res.StatusCode)
1303                 }
1304                 slurp, err := ioutil.ReadAll(res.Body)
1305                 se, ok := err.(StreamError)
1306                 if !ok || se.Cause != wantErr {
1307                         return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
1308                 }
1309                 if len(slurp) > 0 {
1310                         return fmt.Errorf("body = %q; want nothing", slurp)
1311                 }
1312                 return nil
1313         }
1314         ct.server = func() error {
1315                 ct.greet()
1316                 var buf bytes.Buffer
1317                 enc := hpack.NewEncoder(&buf)
1318
1319                 for {
1320                         f, err := ct.fr.ReadFrame()
1321                         if err != nil {
1322                                 return err
1323                         }
1324                         switch f := f.(type) {
1325                         case *HeadersFrame:
1326                                 var endStream bool
1327                                 send := func(mode headerType) {
1328                                         hbf := buf.Bytes()
1329                                         switch mode {
1330                                         case oneHeader:
1331                                                 ct.fr.WriteHeaders(HeadersFrameParam{
1332                                                         StreamID:      f.StreamID,
1333                                                         EndHeaders:    true,
1334                                                         EndStream:     endStream,
1335                                                         BlockFragment: hbf,
1336                                                 })
1337                                         case splitHeader:
1338                                                 if len(hbf) < 2 {
1339                                                         panic("too small")
1340                                                 }
1341                                                 ct.fr.WriteHeaders(HeadersFrameParam{
1342                                                         StreamID:      f.StreamID,
1343                                                         EndHeaders:    false,
1344                                                         EndStream:     endStream,
1345                                                         BlockFragment: hbf[:1],
1346                                                 })
1347                                                 ct.fr.WriteContinuation(f.StreamID, true, hbf[1:])
1348                                         default:
1349                                                 panic("bogus mode")
1350                                         }
1351                                 }
1352                                 // Response headers (1+ frames; 1 or 2 in this test, but never 0)
1353                                 {
1354                                         buf.Reset()
1355                                         enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
1356                                         enc.WriteField(hpack.HeaderField{Name: "trailer", Value: "declared"})
1357                                         endStream = false
1358                                         send(oneHeader)
1359                                 }
1360                                 // Trailers:
1361                                 {
1362                                         endStream = true
1363                                         buf.Reset()
1364                                         writeTrailer(enc)
1365                                         send(trailers)
1366                                 }
1367                                 return nil
1368                         }
1369                 }
1370         }
1371         ct.run()
1372 }
1373
// headerListSize returns the HTTP2 header list size of h: for every value
// of every key, the name length plus the value length (uncompressed, in
// bytes) plus 32 octets of per-field overhead, summed.
//   http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
//   http://httpwg.org/specs/rfc7540.html#MaxHeaderBlock
func headerListSize(h http.Header) uint32 {
	const perFieldOverhead = 32
	var total uint32
	for name, values := range h {
		for _, value := range values {
			total += uint32(len(name) + len(value) + perFieldOverhead)
		}
	}
	return total
}
1386
1387 // padHeaders adds data to an http.Header until headerListSize(h) ==
1388 // limit. Due to the way header list sizes are calculated, padHeaders
1389 // cannot add fewer than len("Pad-Headers") + 32 bytes to h, and will
1390 // call t.Fatal if asked to do so. PadHeaders first reserves enough
1391 // space for an empty "Pad-Headers" key, then adds as many copies of
1392 // filler as possible. Any remaining bytes necessary to push the
1393 // header list size up to limit are added to h["Pad-Headers"].
1394 func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) {
1395         if limit > 0xffffffff {
1396                 t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit)
1397         }
1398         hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
1399         minPadding := uint64(hf.Size())
1400         size := uint64(headerListSize(h))
1401
1402         minlimit := size + minPadding
1403         if limit < minlimit {
1404                 t.Fatalf("padHeaders: limit %v < %v", limit, minlimit)
1405         }
1406
1407         // Use a fixed-width format for name so that fieldSize
1408         // remains constant.
1409         nameFmt := "Pad-Headers-%06d"
1410         hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler}
1411         fieldSize := uint64(hf.Size())
1412
1413         // Add as many complete filler values as possible, leaving
1414         // room for at least one empty "Pad-Headers" key.
1415         limit = limit - minPadding
1416         for i := 0; size+fieldSize < limit; i++ {
1417                 name := fmt.Sprintf(nameFmt, i)
1418                 h.Add(name, filler)
1419                 size += fieldSize
1420         }
1421
1422         // Add enough bytes to reach limit.
1423         remain := limit - size
1424         lastValue := strings.Repeat("*", int(remain))
1425         h.Add("Pad-Headers", lastValue)
1426 }
1427
1428 func TestPadHeaders(t *testing.T) {
1429         check := func(h http.Header, limit uint32, fillerLen int) {
1430                 if h == nil {
1431                         h = make(http.Header)
1432                 }
1433                 filler := strings.Repeat("f", fillerLen)
1434                 padHeaders(t, h, uint64(limit), filler)
1435                 gotSize := headerListSize(h)
1436                 if gotSize != limit {
1437                         t.Errorf("Got size = %v; want %v", gotSize, limit)
1438                 }
1439         }
1440         // Try all possible combinations for small fillerLen and limit.
1441         hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
1442         minLimit := hf.Size()
1443         for limit := minLimit; limit <= 128; limit++ {
1444                 for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ {
1445                         check(nil, limit, fillerLen)
1446                 }
1447         }
1448
1449         // Try a few tests with larger limits, plus cumulative
1450         // tests. Since these tests are cumulative, tests[i+1].limit
1451         // must be >= tests[i].limit + minLimit. See the comment on
1452         // padHeaders for more info on why the limit arg has this
1453         // restriction.
1454         tests := []struct {
1455                 fillerLen int
1456                 limit     uint32
1457         }{
1458                 {
1459                         fillerLen: 64,
1460                         limit:     1024,
1461                 },
1462                 {
1463                         fillerLen: 1024,
1464                         limit:     1286,
1465                 },
1466                 {
1467                         fillerLen: 256,
1468                         limit:     2048,
1469                 },
1470                 {
1471                         fillerLen: 1024,
1472                         limit:     10 * 1024,
1473                 },
1474                 {
1475                         fillerLen: 1023,
1476                         limit:     11 * 1024,
1477                 },
1478         }
1479         h := make(http.Header)
1480         for _, tc := range tests {
1481                 check(nil, tc.limit, tc.fillerLen)
1482                 check(h, tc.limit, tc.fillerLen)
1483         }
1484 }
1485
// TestTransportChecksRequestHeaderListSize checks that the Transport picks up
// the server's advertised maximum header list size (peerMaxHeaderListSize)
// and that RoundTrip fails with errRequestHeaderListSize for requests whose
// headers or trailers exceed it, while requests that stay under the limit
// still succeed.
func TestTransportChecksRequestHeaderListSize(t *testing.T) {
	st := newServerTester(t,
		func(w http.ResponseWriter, r *http.Request) {
			// Consume body & force client to send
			// trailers before writing response.
			// ioutil.ReadAll returns non-nil err for
			// requests that attempt to send greater than
			// maxHeaderListSize bytes of trailers, since
			// those requests generate a stream reset.
			ioutil.ReadAll(r.Body)
			r.Body.Close()
		},
		func(ts *httptest.Server) {
			ts.Config.MaxHeaderBytes = 16 << 10
		},
		optOnlyServer,
		optQuiet,
	)
	defer st.Close()

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	// checkRoundTrip sends req and reports a test error (prefixed with desc)
	// unless RoundTrip returns exactly wantErr. On success it also requires a
	// non-nil 200 response; on an expected error it requires a nil response.
	checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
		res, err := tr.RoundTrip(req)
		if err != wantErr {
			if res != nil {
				res.Body.Close()
			}
			t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr)
			return
		}
		if err == nil {
			if res == nil {
				t.Errorf("%v: response nil; want non-nil.", desc)
				return
			}
			defer res.Body.Close()
			if res.StatusCode != http.StatusOK {
				t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK)
			}
			return
		}
		if res != nil {
			t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err)
		}
	}
	// headerListSizeForRequest returns the total header list size of req as
	// the Transport would encode it. It runs cc.encodeHeaders on a throwaway
	// ClientConn whose peer limit is effectively unlimited, then decodes the
	// resulting HPACK block and sums hf.Size() over every decoded field.
	headerListSizeForRequest := func(req *http.Request) (size uint64) {
		contentLen := actualContentLength(req)
		trailers, err := commaSeparatedTrailers(req)
		if err != nil {
			t.Fatalf("headerListSizeForRequest: %v", err)
		}
		cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
		cc.henc = hpack.NewEncoder(&cc.hbuf)
		// encodeHeaders is called with cc.mu held, mirroring how this
		// test always takes cc.mu around ClientConn state.
		cc.mu.Lock()
		hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen)
		cc.mu.Unlock()
		if err != nil {
			t.Fatalf("headerListSizeForRequest: %v", err)
		}
		hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) {
			size += uint64(hf.Size())
		})
		if len(hdrs) > 0 {
			if _, err := hpackDec.Write(hdrs); err != nil {
				t.Fatalf("headerListSizeForRequest: %v", err)
			}
		}
		return size
	}
	// Create a new Request for each test, rather than reusing the
	// same Request, to avoid a race when modifying req.Headers.
	// See https://github.com/golang/go/issues/21316
	newRequest := func() *http.Request {
		// Body must be non-nil to enable writing trailers.
		body := strings.NewReader("hello")
		req, err := http.NewRequest("POST", st.ts.URL, body)
		if err != nil {
			t.Fatalf("newRequest: NewRequest: %v", err)
		}
		return req
	}

	// Make an arbitrary request to ensure we get the server's
	// settings frame and initialize peerMaxHeaderListSize.
	req := newRequest()
	checkRoundTrip(req, nil, "Initial request")

	// Get the ClientConn associated with the request and validate
	// peerMaxHeaderListSize.
	addr := authorityAddr(req.URL.Scheme, req.URL.Host)
	cc, err := tr.connPool().GetClientConn(req, addr)
	if err != nil {
		t.Fatalf("GetClientConn: %v", err)
	}
	cc.mu.Lock()
	peerSize := cc.peerMaxHeaderListSize
	cc.mu.Unlock()
	st.scMu.Lock()
	wantSize := uint64(st.sc.maxHeaderListSize())
	st.scMu.Unlock()
	if peerSize != wantSize {
		t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize)
	}

	// Sanity check peerSize. (*serverConn) maxHeaderListSize adds
	// 320 bytes of padding.
	wantHeaderBytes := uint64(st.ts.Config.MaxHeaderBytes) + 320
	if peerSize != wantHeaderBytes {
		t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes)
	}

	// Pad headers & trailers, but stay under peerSize.
	req = newRequest()
	req.Header = make(http.Header)
	req.Trailer = make(http.Header)
	filler := strings.Repeat("*", 1024)
	padHeaders(t, req.Trailer, peerSize, filler)
	// cc.encodeHeaders adds some default headers to the request,
	// so we need to leave room for those.
	defaultBytes := headerListSizeForRequest(req)
	padHeaders(t, req.Header, peerSize-defaultBytes, filler)
	checkRoundTrip(req, nil, "Headers & Trailers under limit")

	// Add enough header bytes to push us over peerSize.
	req = newRequest()
	req.Header = make(http.Header)
	padHeaders(t, req.Header, peerSize, filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit")

	// Push trailers over the limit.
	req = newRequest()
	req.Trailer = make(http.Header)
	padHeaders(t, req.Trailer, peerSize+1, filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit")

	// Send headers with a single large value. The value alone is peerSize
	// bytes, so the field's size (which also counts its name) is over the
	// limit by itself.
	req = newRequest()
	filler = strings.Repeat("*", int(peerSize))
	req.Header = make(http.Header)
	req.Header.Set("Big", filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Single large header")

	// Send trailers with a single large value (reusing the peerSize-byte
	// filler from the previous case).
	req = newRequest()
	req.Trailer = make(http.Header)
	req.Trailer.Set("Big", filler)
	checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer")
}
1636
1637 func TestTransportChecksResponseHeaderListSize(t *testing.T) {
1638         ct := newClientTester(t)
1639         ct.client = func() error {
1640                 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
1641                 res, err := ct.tr.RoundTrip(req)
1642                 if err != errResponseHeaderListSize {
1643                         if res != nil {
1644                                 res.Body.Close()
1645                         }
1646                         size := int64(0)
1647                         for k, vv := range res.Header {
1648                                 for _, v := range vv {
1649                                         size += int64(len(k)) + int64(len(v)) + 32
1650                                 }
1651                         }
1652                         return fmt.Errorf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
1653                 }
1654                 return nil
1655         }
1656         ct.server = func() error {
1657                 ct.greet()
1658                 var buf bytes.Buffer
1659                 enc := hpack.NewEncoder(&buf)
1660
1661                 for {
1662                         f, err := ct.fr.ReadFrame()
1663                         if err != nil {
1664                                 return err
1665                         }
1666                         switch f := f.(type) {
1667                         case *HeadersFrame:
1668                                 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
1669                                 large := strings.Repeat("a", 1<<10)
1670                                 for i := 0; i < 5042; i++ {
1671                                         enc.WriteField(hpack.HeaderField{Name: large, Value: large})
1672                                 }
1673                                 if size, want := buf.Len(), 6329; size != want {
1674                                         // Note: this number might change if
1675                                         // our hpack implementation
1676                                         // changes. That's fine. This is
1677                                         // just a sanity check that our
1678                                         // response can fit in a single
1679                                         // header block fragment frame.
1680                                         return fmt.Errorf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
1681                                 }
1682                                 ct.fr.WriteHeaders(HeadersFrameParam{
1683                                         StreamID:      f.StreamID,
1684                                         EndHeaders:    true,
1685                                         EndStream:     true,
1686                                         BlockFragment: buf.Bytes(),
1687                                 })
1688                                 return nil
1689                         }
1690                 }
1691         }
1692         ct.run()
1693 }
1694
1695 // Test that the the Transport returns a typed error from Response.Body.Read calls
1696 // when the server sends an error. (here we use a panic, since that should generate
1697 // a stream error, but others like cancel should be similar)
1698 func TestTransportBodyReadErrorType(t *testing.T) {
1699         doPanic := make(chan bool, 1)
1700         st := newServerTester(t,
1701                 func(w http.ResponseWriter, r *http.Request) {
1702                         w.(http.Flusher).Flush() // force headers out
1703                         <-doPanic
1704                         panic("boom")
1705                 },
1706                 optOnlyServer,
1707                 optQuiet,
1708         )
1709         defer st.Close()
1710
1711         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
1712         defer tr.CloseIdleConnections()
1713         c := &http.Client{Transport: tr}
1714
1715         res, err := c.Get(st.ts.URL)
1716         if err != nil {
1717                 t.Fatal(err)
1718         }
1719         defer res.Body.Close()
1720         doPanic <- true
1721         buf := make([]byte, 100)
1722         n, err := res.Body.Read(buf)
1723         want := StreamError{StreamID: 0x1, Code: 0x2}
1724         if !reflect.DeepEqual(want, err) {
1725                 t.Errorf("Read = %v, %#v; want error %#v", n, err, want)
1726         }
1727 }
1728
1729 // golang.org/issue/13924
1730 // This used to fail after many iterations, especially with -race:
1731 // go test -v -run=TestTransportDoubleCloseOnWriteError -count=500 -race
1732 func TestTransportDoubleCloseOnWriteError(t *testing.T) {
1733         var (
1734                 mu   sync.Mutex
1735                 conn net.Conn // to close if set
1736         )
1737
1738         st := newServerTester(t,
1739                 func(w http.ResponseWriter, r *http.Request) {
1740                         mu.Lock()
1741                         defer mu.Unlock()
1742                         if conn != nil {
1743                                 conn.Close()
1744                         }
1745                 },
1746                 optOnlyServer,
1747         )
1748         defer st.Close()
1749
1750         tr := &Transport{
1751                 TLSClientConfig: tlsConfigInsecure,
1752                 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
1753                         tc, err := tls.Dial(network, addr, cfg)
1754                         if err != nil {
1755                                 return nil, err
1756                         }
1757                         mu.Lock()
1758                         defer mu.Unlock()
1759                         conn = tc
1760                         return tc, nil
1761                 },
1762         }
1763         defer tr.CloseIdleConnections()
1764         c := &http.Client{Transport: tr}
1765         c.Get(st.ts.URL)
1766 }
1767
1768 // Test that the http1 Transport.DisableKeepAlives option is respected
1769 // and connections are closed as soon as idle.
1770 // See golang.org/issue/14008
1771 func TestTransportDisableKeepAlives(t *testing.T) {
1772         st := newServerTester(t,
1773                 func(w http.ResponseWriter, r *http.Request) {
1774                         io.WriteString(w, "hi")
1775                 },
1776                 optOnlyServer,
1777         )
1778         defer st.Close()
1779
1780         connClosed := make(chan struct{}) // closed on tls.Conn.Close
1781         tr := &Transport{
1782                 t1: &http.Transport{
1783                         DisableKeepAlives: true,
1784                 },
1785                 TLSClientConfig: tlsConfigInsecure,
1786                 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
1787                         tc, err := tls.Dial(network, addr, cfg)
1788                         if err != nil {
1789                                 return nil, err
1790                         }
1791                         return &noteCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
1792                 },
1793         }
1794         c := &http.Client{Transport: tr}
1795         res, err := c.Get(st.ts.URL)
1796         if err != nil {
1797                 t.Fatal(err)
1798         }
1799         if _, err := ioutil.ReadAll(res.Body); err != nil {
1800                 t.Fatal(err)
1801         }
1802         defer res.Body.Close()
1803
1804         select {
1805         case <-connClosed:
1806         case <-time.After(1 * time.Second):
1807                 t.Errorf("timeout")
1808         }
1809
1810 }
1811
1812 // Test concurrent requests with Transport.DisableKeepAlives. We can share connections,
1813 // but when things are totally idle, it still needs to close.
1814 func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
1815         const D = 25 * time.Millisecond
1816         st := newServerTester(t,
1817                 func(w http.ResponseWriter, r *http.Request) {
1818                         time.Sleep(D)
1819                         io.WriteString(w, "hi")
1820                 },
1821                 optOnlyServer,
1822         )
1823         defer st.Close()
1824
1825         var dials int32
1826         var conns sync.WaitGroup
1827         tr := &Transport{
1828                 t1: &http.Transport{
1829                         DisableKeepAlives: true,
1830                 },
1831                 TLSClientConfig: tlsConfigInsecure,
1832                 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
1833                         tc, err := tls.Dial(network, addr, cfg)
1834                         if err != nil {
1835                                 return nil, err
1836                         }
1837                         atomic.AddInt32(&dials, 1)
1838                         conns.Add(1)
1839                         return &noteCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
1840                 },
1841         }
1842         c := &http.Client{Transport: tr}
1843         var reqs sync.WaitGroup
1844         const N = 20
1845         for i := 0; i < N; i++ {
1846                 reqs.Add(1)
1847                 if i == N-1 {
1848                         // For the final request, try to make all the
1849                         // others close. This isn't verified in the
1850                         // count, other than the Log statement, since
1851                         // it's so timing dependent. This test is
1852                         // really to make sure we don't interrupt a
1853                         // valid request.
1854                         time.Sleep(D * 2)
1855                 }
1856                 go func() {
1857                         defer reqs.Done()
1858                         res, err := c.Get(st.ts.URL)
1859                         if err != nil {
1860                                 t.Error(err)
1861                                 return
1862                         }
1863                         if _, err := ioutil.ReadAll(res.Body); err != nil {
1864                                 t.Error(err)
1865                                 return
1866                         }
1867                         res.Body.Close()
1868                 }()
1869         }
1870         reqs.Wait()
1871         conns.Wait()
1872         t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
1873 }
1874
// noteCloseConn wraps a net.Conn and invokes closefn the first time Close is
// called; later Close calls skip closefn but still close the wrapped Conn.
type noteCloseConn struct {
	net.Conn
	onceClose sync.Once
	closefn   func()
}

// Close runs closefn at most once across all calls, then closes the
// underlying connection and returns its error.
func (c *noteCloseConn) Close() error {
	notify := c.closefn
	c.onceClose.Do(notify)
	return c.Conn.Close()
}
1885
// isTimeout reports whether err represents a timeout. A *url.Error (the
// wrapper returned by http.Client) is unwrapped and its inner error checked;
// otherwise a net.Error is asked via its Timeout method. A nil error, or any
// other kind of error, is not a timeout.
func isTimeout(err error) bool {
	if ue, ok := err.(*url.Error); ok {
		return isTimeout(ue.Err)
	}
	if ne, ok := err.(net.Error); ok {
		return ne.Timeout()
	}
	return false
}
1897
1898 // Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent.
1899 func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
1900         testTransportResponseHeaderTimeout(t, false)
1901 }
1902 func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
1903         testTransportResponseHeaderTimeout(t, true)
1904 }
1905
1906 func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
1907         ct := newClientTester(t)
1908         ct.tr.t1 = &http.Transport{
1909                 ResponseHeaderTimeout: 5 * time.Millisecond,
1910         }
1911         ct.client = func() error {
1912                 c := &http.Client{Transport: ct.tr}
1913                 var err error
1914                 var n int64
1915                 const bodySize = 4 << 20
1916                 if body {
1917                         _, err = c.Post("https://dummy.tld/", "text/foo", io.LimitReader(countingReader{&n}, bodySize))
1918                 } else {
1919                         _, err = c.Get("https://dummy.tld/")
1920                 }
1921                 if !isTimeout(err) {
1922                         t.Errorf("client expected timeout error; got %#v", err)
1923                 }
1924                 if body && n != bodySize {
1925                         t.Errorf("only read %d bytes of body; want %d", n, bodySize)
1926                 }
1927                 return nil
1928         }
1929         ct.server = func() error {
1930                 ct.greet()
1931                 for {
1932                         f, err := ct.fr.ReadFrame()
1933                         if err != nil {
1934                                 t.Logf("ReadFrame: %v", err)
1935                                 return nil
1936                         }
1937                         switch f := f.(type) {
1938                         case *DataFrame:
1939                                 dataLen := len(f.Data())
1940                                 if dataLen > 0 {
1941                                         if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
1942                                                 return err
1943                                         }
1944                                         if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
1945                                                 return err
1946                                         }
1947                                 }
1948                         case *RSTStreamFrame:
1949                                 if f.StreamID == 1 && f.ErrCode == ErrCodeCancel {
1950                                         return nil
1951                                 }
1952                         }
1953                 }
1954         }
1955         ct.run()
1956 }
1957
1958 func TestTransportDisableCompression(t *testing.T) {
1959         const body = "sup"
1960         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1961                 want := http.Header{
1962                         "User-Agent": []string{"Go-http-client/2.0"},
1963                 }
1964                 if !reflect.DeepEqual(r.Header, want) {
1965                         t.Errorf("request headers = %v; want %v", r.Header, want)
1966                 }
1967         }, optOnlyServer)
1968         defer st.Close()
1969
1970         tr := &Transport{
1971                 TLSClientConfig: tlsConfigInsecure,
1972                 t1: &http.Transport{
1973                         DisableCompression: true,
1974                 },
1975         }
1976         defer tr.CloseIdleConnections()
1977
1978         req, err := http.NewRequest("GET", st.ts.URL, nil)
1979         if err != nil {
1980                 t.Fatal(err)
1981         }
1982         res, err := tr.RoundTrip(req)
1983         if err != nil {
1984                 t.Fatal(err)
1985         }
1986         defer res.Body.Close()
1987 }
1988
1989 // RFC 7540 section 8.1.2.2
1990 func TestTransportRejectsConnHeaders(t *testing.T) {
1991         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1992                 var got []string
1993                 for k := range r.Header {
1994                         got = append(got, k)
1995                 }
1996                 sort.Strings(got)
1997                 w.Header().Set("Got-Header", strings.Join(got, ","))
1998         }, optOnlyServer)
1999         defer st.Close()
2000
2001         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2002         defer tr.CloseIdleConnections()
2003
2004         tests := []struct {
2005                 key   string
2006                 value []string
2007                 want  string
2008         }{
2009                 {
2010                         key:   "Upgrade",
2011                         value: []string{"anything"},
2012                         want:  "ERROR: http2: invalid Upgrade request header: [\"anything\"]",
2013                 },
2014                 {
2015                         key:   "Connection",
2016                         value: []string{"foo"},
2017                         want:  "ERROR: http2: invalid Connection request header: [\"foo\"]",
2018                 },
2019                 {
2020                         key:   "Connection",
2021                         value: []string{"close"},
2022                         want:  "Accept-Encoding,User-Agent",
2023                 },
2024                 {
2025                         key:   "Connection",
2026                         value: []string{"close", "something-else"},
2027                         want:  "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]",
2028                 },
2029                 {
2030                         key:   "Connection",
2031                         value: []string{"keep-alive"},
2032                         want:  "Accept-Encoding,User-Agent",
2033                 },
2034                 {
2035                         key:   "Proxy-Connection", // just deleted and ignored
2036                         value: []string{"keep-alive"},
2037                         want:  "Accept-Encoding,User-Agent",
2038                 },
2039                 {
2040                         key:   "Transfer-Encoding",
2041                         value: []string{""},
2042                         want:  "Accept-Encoding,User-Agent",
2043                 },
2044                 {
2045                         key:   "Transfer-Encoding",
2046                         value: []string{"foo"},
2047                         want:  "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]",
2048                 },
2049                 {
2050                         key:   "Transfer-Encoding",
2051                         value: []string{"chunked"},
2052                         want:  "Accept-Encoding,User-Agent",
2053                 },
2054                 {
2055                         key:   "Transfer-Encoding",
2056                         value: []string{"chunked", "other"},
2057                         want:  "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]",
2058                 },
2059                 {
2060                         key:   "Content-Length",
2061                         value: []string{"123"},
2062                         want:  "Accept-Encoding,User-Agent",
2063                 },
2064                 {
2065                         key:   "Keep-Alive",
2066                         value: []string{"doop"},
2067                         want:  "Accept-Encoding,User-Agent",
2068                 },
2069         }
2070
2071         for _, tt := range tests {
2072                 req, _ := http.NewRequest("GET", st.ts.URL, nil)
2073                 req.Header[tt.key] = tt.value
2074                 res, err := tr.RoundTrip(req)
2075                 var got string
2076                 if err != nil {
2077                         got = fmt.Sprintf("ERROR: %v", err)
2078                 } else {
2079                         got = res.Header.Get("Got-Header")
2080                         res.Body.Close()
2081                 }
2082                 if got != tt.want {
2083                         t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
2084                 }
2085         }
2086 }
2087
2088 // golang.org/issue/14048
2089 func TestTransportFailsOnInvalidHeaders(t *testing.T) {
2090         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2091                 var got []string
2092                 for k := range r.Header {
2093                         got = append(got, k)
2094                 }
2095                 sort.Strings(got)
2096                 w.Header().Set("Got-Header", strings.Join(got, ","))
2097         }, optOnlyServer)
2098         defer st.Close()
2099
2100         tests := [...]struct {
2101                 h       http.Header
2102                 wantErr string
2103         }{
2104                 0: {
2105                         h:       http.Header{"with space": {"foo"}},
2106                         wantErr: `invalid HTTP header name "with space"`,
2107                 },
2108                 1: {
2109                         h:       http.Header{"name": {"Брэд"}},
2110                         wantErr: "", // okay
2111                 },
2112                 2: {
2113                         h:       http.Header{"имя": {"Brad"}},
2114                         wantErr: `invalid HTTP header name "имя"`,
2115                 },
2116                 3: {
2117                         h:       http.Header{"foo": {"foo\x01bar"}},
2118                         wantErr: `invalid HTTP header value "foo\x01bar" for header "foo"`,
2119                 },
2120         }
2121
2122         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2123         defer tr.CloseIdleConnections()
2124
2125         for i, tt := range tests {
2126                 req, _ := http.NewRequest("GET", st.ts.URL, nil)
2127                 req.Header = tt.h
2128                 res, err := tr.RoundTrip(req)
2129                 var bad bool
2130                 if tt.wantErr == "" {
2131                         if err != nil {
2132                                 bad = true
2133                                 t.Errorf("case %d: error = %v; want no error", i, err)
2134                         }
2135                 } else {
2136                         if !strings.Contains(fmt.Sprint(err), tt.wantErr) {
2137                                 bad = true
2138                                 t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr)
2139                         }
2140                 }
2141                 if err == nil {
2142                         if bad {
2143                                 t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header"))
2144                         }
2145                         res.Body.Close()
2146                 }
2147         }
2148 }
2149
2150 // Tests that gzipReader doesn't crash on a second Read call following
2151 // the first Read call's gzip.NewReader returning an error.
2152 func TestGzipReader_DoubleReadCrash(t *testing.T) {
2153         gz := &gzipReader{
2154                 body: ioutil.NopCloser(strings.NewReader("0123456789")),
2155         }
2156         var buf [1]byte
2157         n, err1 := gz.Read(buf[:])
2158         if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") {
2159                 t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1)
2160         }
2161         n, err2 := gz.Read(buf[:])
2162         if n != 0 || err2 != err1 {
2163                 t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1)
2164         }
2165 }
2166
2167 func TestTransportNewTLSConfig(t *testing.T) {
2168         tests := [...]struct {
2169                 conf *tls.Config
2170                 host string
2171                 want *tls.Config
2172         }{
2173                 // Normal case.
2174                 0: {
2175                         conf: nil,
2176                         host: "foo.com",
2177                         want: &tls.Config{
2178                                 ServerName: "foo.com",
2179                                 NextProtos: []string{NextProtoTLS},
2180                         },
2181                 },
2182
2183                 // User-provided name (bar.com) takes precedence:
2184                 1: {
2185                         conf: &tls.Config{
2186                                 ServerName: "bar.com",
2187                         },
2188                         host: "foo.com",
2189                         want: &tls.Config{
2190                                 ServerName: "bar.com",
2191                                 NextProtos: []string{NextProtoTLS},
2192                         },
2193                 },
2194
2195                 // NextProto is prepended:
2196                 2: {
2197                         conf: &tls.Config{
2198                                 NextProtos: []string{"foo", "bar"},
2199                         },
2200                         host: "example.com",
2201                         want: &tls.Config{
2202                                 ServerName: "example.com",
2203                                 NextProtos: []string{NextProtoTLS, "foo", "bar"},
2204                         },
2205                 },
2206
2207                 // NextProto is not duplicated:
2208                 3: {
2209                         conf: &tls.Config{
2210                                 NextProtos: []string{"foo", "bar", NextProtoTLS},
2211                         },
2212                         host: "example.com",
2213                         want: &tls.Config{
2214                                 ServerName: "example.com",
2215                                 NextProtos: []string{"foo", "bar", NextProtoTLS},
2216                         },
2217                 },
2218         }
2219         for i, tt := range tests {
2220                 // Ignore the session ticket keys part, which ends up populating
2221                 // unexported fields in the Config:
2222                 if tt.conf != nil {
2223                         tt.conf.SessionTicketsDisabled = true
2224                 }
2225
2226                 tr := &Transport{TLSClientConfig: tt.conf}
2227                 got := tr.newTLSConfig(tt.host)
2228
2229                 got.SessionTicketsDisabled = false
2230
2231                 if !reflect.DeepEqual(got, tt.want) {
2232                         t.Errorf("%d. got %#v; want %#v", i, got, tt.want)
2233                 }
2234         }
2235 }
2236
2237 // The Google GFE responds to HEAD requests with a HEADERS frame
2238 // without END_STREAM, followed by a 0-length DATA frame with
2239 // END_STREAM. Make sure we don't get confused by that. (We did.)
2240 func TestTransportReadHeadResponse(t *testing.T) {
2241         ct := newClientTester(t)
2242         clientDone := make(chan struct{})
2243         ct.client = func() error {
2244                 defer close(clientDone)
2245                 req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
2246                 res, err := ct.tr.RoundTrip(req)
2247                 if err != nil {
2248                         return err
2249                 }
2250                 if res.ContentLength != 123 {
2251                         return fmt.Errorf("Content-Length = %d; want 123", res.ContentLength)
2252                 }
2253                 slurp, err := ioutil.ReadAll(res.Body)
2254                 if err != nil {
2255                         return fmt.Errorf("ReadAll: %v", err)
2256                 }
2257                 if len(slurp) > 0 {
2258                         return fmt.Errorf("Unexpected non-empty ReadAll body: %q", slurp)
2259                 }
2260                 return nil
2261         }
2262         ct.server = func() error {
2263                 ct.greet()
2264                 for {
2265                         f, err := ct.fr.ReadFrame()
2266                         if err != nil {
2267                                 t.Logf("ReadFrame: %v", err)
2268                                 return nil
2269                         }
2270                         hf, ok := f.(*HeadersFrame)
2271                         if !ok {
2272                                 continue
2273                         }
2274                         var buf bytes.Buffer
2275                         enc := hpack.NewEncoder(&buf)
2276                         enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
2277                         enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
2278                         ct.fr.WriteHeaders(HeadersFrameParam{
2279                                 StreamID:      hf.StreamID,
2280                                 EndHeaders:    true,
2281                                 EndStream:     false, // as the GFE does
2282                                 BlockFragment: buf.Bytes(),
2283                         })
2284                         ct.fr.WriteData(hf.StreamID, true, nil)
2285
2286                         <-clientDone
2287                         return nil
2288                 }
2289         }
2290         ct.run()
2291 }
2292
// neverEnding is an endless io.Reader. Every Read fills all of p with the
// underlying byte value and reports len(p) bytes read with a nil error, so it
// never reaches EOF. Wrap it in io.LimitReader to get a finite stream.
type neverEnding byte

func (b neverEnding) Read(p []byte) (int, error) {
	fill := byte(b)
	for i := 0; i < len(p); i++ {
		p[i] = fill
	}
	return len(p), nil
}
2301
2302 // golang.org/issue/15425: test that a handler closing the request
2303 // body doesn't terminate the stream to the peer. (It just stops
2304 // readability from the handler's side, and eventually the client
2305 // runs out of flow control tokens)
2306 func TestTransportHandlerBodyClose(t *testing.T) {
2307         const bodySize = 10 << 20
2308         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2309                 r.Body.Close()
2310                 io.Copy(w, io.LimitReader(neverEnding('A'), bodySize))
2311         }, optOnlyServer)
2312         defer st.Close()
2313
2314         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2315         defer tr.CloseIdleConnections()
2316
2317         g0 := runtime.NumGoroutine()
2318
2319         const numReq = 10
2320         for i := 0; i < numReq; i++ {
2321                 req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
2322                 if err != nil {
2323                         t.Fatal(err)
2324                 }
2325                 res, err := tr.RoundTrip(req)
2326                 if err != nil {
2327                         t.Fatal(err)
2328                 }
2329                 n, err := io.Copy(ioutil.Discard, res.Body)
2330                 res.Body.Close()
2331                 if n != bodySize || err != nil {
2332                         t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize)
2333                 }
2334         }
2335         tr.CloseIdleConnections()
2336
2337         gd := runtime.NumGoroutine() - g0
2338         if gd > numReq/2 {
2339                 t.Errorf("appeared to leak goroutines")
2340         }
2341
2342 }
2343
2344 // https://golang.org/issue/15930
2345 func TestTransportFlowControl(t *testing.T) {
2346         const bufLen = 64 << 10
2347         var total int64 = 100 << 20 // 100MB
2348         if testing.Short() {
2349                 total = 10 << 20
2350         }
2351
2352         var wrote int64 // updated atomically
2353         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2354                 b := make([]byte, bufLen)
2355                 for wrote < total {
2356                         n, err := w.Write(b)
2357                         atomic.AddInt64(&wrote, int64(n))
2358                         if err != nil {
2359                                 t.Errorf("ResponseWriter.Write error: %v", err)
2360                                 break
2361                         }
2362                         w.(http.Flusher).Flush()
2363                 }
2364         }, optOnlyServer)
2365
2366         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2367         defer tr.CloseIdleConnections()
2368         req, err := http.NewRequest("GET", st.ts.URL, nil)
2369         if err != nil {
2370                 t.Fatal("NewRequest error:", err)
2371         }
2372         resp, err := tr.RoundTrip(req)
2373         if err != nil {
2374                 t.Fatal("RoundTrip error:", err)
2375         }
2376         defer resp.Body.Close()
2377
2378         var read int64
2379         b := make([]byte, bufLen)
2380         for {
2381                 n, err := resp.Body.Read(b)
2382                 if err == io.EOF {
2383                         break
2384                 }
2385                 if err != nil {
2386                         t.Fatal("Read error:", err)
2387                 }
2388                 read += int64(n)
2389
2390                 const max = transportDefaultStreamFlow
2391                 if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max {
2392                         t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read)
2393                 }
2394
2395                 // Let the server get ahead of the client.
2396                 time.Sleep(1 * time.Millisecond)
2397         }
2398 }
2399
2400 // golang.org/issue/14627 -- if the server sends a GOAWAY frame, make
2401 // the Transport remember it and return it back to users (via
2402 // RoundTrip or request body reads) if needed (e.g. if the server
2403 // proceeds to close the TCP connection before the client gets its
2404 // response)
2405 func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
2406         testTransportUsesGoAwayDebugError(t, false)
2407 }
2408
2409 func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
2410         testTransportUsesGoAwayDebugError(t, true)
2411 }
2412
2413 func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
2414         ct := newClientTester(t)
2415         clientDone := make(chan struct{})
2416
2417         const goAwayErrCode = ErrCodeHTTP11Required // arbitrary
2418         const goAwayDebugData = "some debug data"
2419
2420         ct.client = func() error {
2421                 defer close(clientDone)
2422                 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2423                 res, err := ct.tr.RoundTrip(req)
2424                 if failMidBody {
2425                         if err != nil {
2426                                 return fmt.Errorf("unexpected client RoundTrip error: %v", err)
2427                         }
2428                         _, err = io.Copy(ioutil.Discard, res.Body)
2429                         res.Body.Close()
2430                 }
2431                 want := GoAwayError{
2432                         LastStreamID: 5,
2433                         ErrCode:      goAwayErrCode,
2434                         DebugData:    goAwayDebugData,
2435                 }
2436                 if !reflect.DeepEqual(err, want) {
2437                         t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want)
2438                 }
2439                 return nil
2440         }
2441         ct.server = func() error {
2442                 ct.greet()
2443                 for {
2444                         f, err := ct.fr.ReadFrame()
2445                         if err != nil {
2446                                 t.Logf("ReadFrame: %v", err)
2447                                 return nil
2448                         }
2449                         hf, ok := f.(*HeadersFrame)
2450                         if !ok {
2451                                 continue
2452                         }
2453                         if failMidBody {
2454                                 var buf bytes.Buffer
2455                                 enc := hpack.NewEncoder(&buf)
2456                                 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
2457                                 enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
2458                                 ct.fr.WriteHeaders(HeadersFrameParam{
2459                                         StreamID:      hf.StreamID,
2460                                         EndHeaders:    true,
2461                                         EndStream:     false,
2462                                         BlockFragment: buf.Bytes(),
2463                                 })
2464                         }
2465                         // Write two GOAWAY frames, to test that the Transport takes
2466                         // the interesting parts of both.
2467                         ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
2468                         ct.fr.WriteGoAway(5, goAwayErrCode, nil)
2469                         ct.sc.(*net.TCPConn).CloseWrite()
2470                         <-clientDone
2471                         return nil
2472                 }
2473         }
2474         ct.run()
2475 }
2476
// testTransportReturnsUnusedFlowControl has the server answer a GET with a
// response declaring content-length 5000. The client reads a single byte
// and then closes the response body. Ignoring SETTINGS frames, the server
// then expects an RST_STREAM with ErrCodeCancel, followed by a WINDOW_UPDATE
// whose increment is 4999 (the unread remainder).
//
// oneDataFrame selects how the body is sent. If true, all 5000 bytes go out
// in one DATA frame before the client closes. If false, 1 byte is sent, the
// server waits for the client to close, and only then sends the remaining
// 4999 bytes.
func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
	ct := newClientTester(t)

	// Closed by the client right after it closes the response body.
	clientClosed := make(chan struct{})
	// Closed by the server right after it writes the first DATA frame.
	serverWroteFirstByte := make(chan struct{})

	ct.client = func() error {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return err
		}
		// Don't read until the server has written body data.
		<-serverWroteFirstByte

		if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
			return fmt.Errorf("body read = %v, %v; want 1, nil", n, err)
		}
		res.Body.Close() // leaving 4999 bytes unread
		close(clientClosed)

		return nil
	}
	ct.server = func() error {
		ct.greet()

		// Wait for the request HEADERS, skipping WINDOW_UPDATE and SETTINGS.
		var hf *HeadersFrame
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
			}
			switch f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
				continue
			}
			var ok bool
			hf, ok = f.(*HeadersFrame)
			if !ok {
				return fmt.Errorf("Got %T; want HeadersFrame", f)
			}
			break // outside the switch, so this exits the for loop
		}

		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
		enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
		ct.fr.WriteHeaders(HeadersFrameParam{
			StreamID:      hf.StreamID,
			EndHeaders:    true,
			EndStream:     false,
			BlockFragment: buf.Bytes(),
		})

		// Two cases:
		// - Send one DATA frame with 5000 bytes.
		// - Send two DATA frames with 1 and 4999 bytes each.
		//
		// In both cases, the client should consume one byte of data,
		// refund that byte, then refund the following 4999 bytes.
		//
		// In the second case, the server waits for the client to close the
		// response body before sending the second DATA frame. This tests the
		// case where the client receives a DATA frame after it has reset the
		// stream.
		if oneDataFrame {
			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 5000))
			close(serverWroteFirstByte)
			<-clientClosed
		} else {
			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 1))
			close(serverWroteFirstByte)
			<-clientClosed
			ct.fr.WriteData(hf.StreamID, false /* don't end stream */, make([]byte, 4999))
		}

		// Small state machine: expect RST_STREAM(cancel) first, then a
		// WINDOW_UPDATE of 4999. SETTINGS frames are ignored throughout.
		waitingFor := "RSTStreamFrame"
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err)
			}
			if _, ok := f.(*SettingsFrame); ok {
				continue
			}
			switch waitingFor {
			case "RSTStreamFrame":
				if rf, ok := f.(*RSTStreamFrame); !ok || rf.ErrCode != ErrCodeCancel {
					return fmt.Errorf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
				}
				waitingFor = "WindowUpdateFrame"
			case "WindowUpdateFrame":
				if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != 4999 {
					return fmt.Errorf("Expected WindowUpdateFrame for 4999 bytes; got %v", summarizeFrame(f))
				}
				return nil
			}
		}
	}
	ct.run()
}
2577
2578 // See golang.org/issue/16481
2579 func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) {
2580         testTransportReturnsUnusedFlowControl(t, true)
2581 }
2582
2583 // See golang.org/issue/20469
2584 func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) {
2585         testTransportReturnsUnusedFlowControl(t, false)
2586 }
2587
// Issue 16612: adjust flow control on open streams when transport
// receives SETTINGS with INITIAL_WINDOW_SIZE from server.
//
// The client POSTs a bodySize (1 MiB) body. The server reads the client
// preface directly instead of calling ct.greet. Once it has received at
// least initialWindowSize/2 bytes of DATA, it sends SETTINGS raising
// INITIAL_WINDOW_SIZE to bodySize, a connection-level WINDOW_UPDATE of
// bodySize, and a SETTINGS ack. When the request body's END_STREAM arrives,
// it replies with a 200 HEADERS frame that ends the stream.
func TestTransportAdjustsFlowControl(t *testing.T) {
	ct := newClientTester(t)
	// Closed when the client finishes, so the server can treat a subsequent
	// ReadFrame error as normal shutdown rather than a failure.
	clientDone := make(chan struct{})

	const bodySize = 1 << 20

	ct.client = func() error {
		// Deferred calls run LIFO: clientDone is closed first, then the
		// client half-closes its TCP connection, which ends the server's
		// ReadFrame loop with an error.
		defer ct.cc.(*net.TCPConn).CloseWrite()
		defer close(clientDone)

		// The anonymous struct hides the reader's concrete type from
		// http.NewRequest.
		req, _ := http.NewRequest("POST", "https://dummy.tld/", struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return err
		}
		res.Body.Close()
		return nil
	}
	ct.server = func() error {
		_, err := io.ReadFull(ct.sc, make([]byte, len(ClientPreface)))
		if err != nil {
			return fmt.Errorf("reading client preface: %v", err)
		}

		var gotBytes int64
		var sentSettings bool
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					return nil
				default:
					return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
				}
			}
			switch f := f.(type) {
			case *DataFrame:
				gotBytes += int64(len(f.Data()))
				// After we've got half the client's
				// initial flow control window's worth
				// of request body data, give it just
				// enough flow control to finish.
				if gotBytes >= initialWindowSize/2 && !sentSettings {
					sentSettings = true

					ct.fr.WriteSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
					ct.fr.WriteWindowUpdate(0, bodySize)
					// NOTE(review): presumably acks the client's own SETTINGS,
					// since ct.greet was not used — confirm.
					ct.fr.WriteSettingsAck()
				}

				if f.StreamEnded() {
					var buf bytes.Buffer
					enc := hpack.NewEncoder(&buf)
					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
					ct.fr.WriteHeaders(HeadersFrameParam{
						StreamID:      f.StreamID,
						EndHeaders:    true,
						EndStream:     true,
						BlockFragment: buf.Bytes(),
					})
				}
			}
		}
	}
	ct.run()
}
2657
2658 // See golang.org/issue/16556
2659 func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
2660         ct := newClientTester(t)
2661
2662         unblockClient := make(chan bool, 1)
2663
2664         ct.client = func() error {
2665                 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2666                 res, err := ct.tr.RoundTrip(req)
2667                 if err != nil {
2668                         return err
2669                 }
2670                 defer res.Body.Close()
2671                 <-unblockClient
2672                 return nil
2673         }
2674         ct.server = func() error {
2675                 ct.greet()
2676
2677                 var hf *HeadersFrame
2678                 for {
2679                         f, err := ct.fr.ReadFrame()
2680                         if err != nil {
2681                                 return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
2682                         }
2683                         switch f.(type) {
2684                         case *WindowUpdateFrame, *SettingsFrame:
2685                                 continue
2686                         }
2687                         var ok bool
2688                         hf, ok = f.(*HeadersFrame)
2689                         if !ok {
2690                                 return fmt.Errorf("Got %T; want HeadersFrame", f)
2691                         }
2692                         break
2693                 }
2694
2695                 var buf bytes.Buffer
2696                 enc := hpack.NewEncoder(&buf)
2697                 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
2698                 enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"})
2699                 ct.fr.WriteHeaders(HeadersFrameParam{
2700                         StreamID:      hf.StreamID,
2701                         EndHeaders:    true,
2702                         EndStream:     false,
2703                         BlockFragment: buf.Bytes(),
2704                 })
2705                 pad := make([]byte, 5)
2706                 ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream
2707
2708                 f, err := ct.readNonSettingsFrame()
2709                 if err != nil {
2710                         return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err)
2711                 }
2712                 wantBack := uint32(len(pad)) + 1 // one byte for the length of the padding
2713                 if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 {
2714                         return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f))
2715                 }
2716
2717                 f, err = ct.readNonSettingsFrame()
2718                 if err != nil {
2719                         return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err)
2720                 }
2721                 if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 {
2722                         return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f))
2723                 }
2724                 unblockClient <- true
2725                 return nil
2726         }
2727         ct.run()
2728 }
2729
2730 // golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a
2731 // StreamError as a result of the response HEADERS
2732 func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
2733         ct := newClientTester(t)
2734
2735         ct.client = func() error {
2736                 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2737                 res, err := ct.tr.RoundTrip(req)
2738                 if err == nil {
2739                         res.Body.Close()
2740                         return errors.New("unexpected successful GET")
2741                 }
2742                 want := StreamError{1, ErrCodeProtocol, headerFieldNameError("  content-type")}
2743                 if !reflect.DeepEqual(want, err) {
2744                         t.Errorf("RoundTrip error = %#v; want %#v", err, want)
2745                 }
2746                 return nil
2747         }
2748         ct.server = func() error {
2749                 ct.greet()
2750
2751                 hf, err := ct.firstHeaders()
2752                 if err != nil {
2753                         return err
2754                 }
2755
2756                 var buf bytes.Buffer
2757                 enc := hpack.NewEncoder(&buf)
2758                 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
2759                 enc.WriteField(hpack.HeaderField{Name: "  content-type", Value: "bogus"}) // bogus spaces
2760                 ct.fr.WriteHeaders(HeadersFrameParam{
2761                         StreamID:      hf.StreamID,
2762                         EndHeaders:    true,
2763                         EndStream:     false,
2764                         BlockFragment: buf.Bytes(),
2765                 })
2766
2767                 for {
2768                         fr, err := ct.readFrame()
2769                         if err != nil {
2770                                 return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
2771                         }
2772                         if _, ok := fr.(*SettingsFrame); ok {
2773                                 continue
2774                         }
2775                         if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
2776                                 t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
2777                         }
2778                         break
2779                 }
2780
2781                 return nil
2782         }
2783         ct.run()
2784 }
2785
// byteAndEOFReader is an io.Reader whose Read delivers its underlying byte
// and io.EOF in the same call, returning (1, io.EOF). Read panics if given
// an empty buffer, since such a call could never make progress.
type byteAndEOFReader byte

func (r byteAndEOFReader) Read(p []byte) (n int, err error) {
	if len(p) < 1 {
		panic("unexpected useless call")
	}
	p[0] = byte(r)
	n, err = 1, io.EOF
	return n, err
}
2797
2798 // Issue 16788: the Transport had a regression where it started
2799 // sending a spurious DATA frame with a duplicate END_STREAM bit after
2800 // the request body writer goroutine had already read an EOF from the
2801 // Request.Body and included the END_STREAM on a data-carrying DATA
2802 // frame.
2803 //
2804 // Notably, to trigger this, the requests need to use a Request.Body
2805 // which returns (non-0, io.EOF) and also needs to set the ContentLength
2806 // explicitly.
2807 func TestTransportBodyDoubleEndStream(t *testing.T) {
2808         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2809                 // Nothing.
2810         }, optOnlyServer)
2811         defer st.Close()
2812
2813         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2814         defer tr.CloseIdleConnections()
2815
2816         for i := 0; i < 2; i++ {
2817                 req, _ := http.NewRequest("POST", st.ts.URL, byteAndEOFReader('a'))
2818                 req.ContentLength = 1
2819                 res, err := tr.RoundTrip(req)
2820                 if err != nil {
2821                         t.Fatalf("failure on req %d: %v", i+1, err)
2822                 }
2823                 defer res.Body.Close()
2824         }
2825 }
2826
2827 // golang.org/issue/16847, golang.org/issue/19103
2828 func TestTransportRequestPathPseudo(t *testing.T) {
2829         type result struct {
2830                 path string
2831                 err  string
2832         }
2833         tests := []struct {
2834                 req  *http.Request
2835                 want result
2836         }{
2837                 0: {
2838                         req: &http.Request{
2839                                 Method: "GET",
2840                                 URL: &url.URL{
2841                                         Host: "foo.com",
2842                                         Path: "/foo",
2843                                 },
2844                         },
2845                         want: result{path: "/foo"},
2846                 },
2847                 // In Go 1.7, we accepted paths of "//foo".
2848                 // In Go 1.8, we rejected it (issue 16847).
2849                 // In Go 1.9, we accepted it again (issue 19103).
2850                 1: {
2851                         req: &http.Request{
2852                                 Method: "GET",
2853                                 URL: &url.URL{
2854                                         Host: "foo.com",
2855                                         Path: "//foo",
2856                                 },
2857                         },
2858                         want: result{path: "//foo"},
2859                 },
2860
2861                 // Opaque with //$Matching_Hostname/path
2862                 2: {
2863                         req: &http.Request{
2864                                 Method: "GET",
2865                                 URL: &url.URL{
2866                                         Scheme: "https",
2867                                         Opaque: "//foo.com/path",
2868                                         Host:   "foo.com",
2869                                         Path:   "/ignored",
2870                                 },
2871                         },
2872                         want: result{path: "/path"},
2873                 },
2874
2875                 // Opaque with some other Request.Host instead:
2876                 3: {
2877                         req: &http.Request{
2878                                 Method: "GET",
2879                                 Host:   "bar.com",
2880                                 URL: &url.URL{
2881                                         Scheme: "https",
2882                                         Opaque: "//bar.com/path",
2883                                         Host:   "foo.com",
2884                                         Path:   "/ignored",
2885                                 },
2886                         },
2887                         want: result{path: "/path"},
2888                 },
2889
2890                 // Opaque without the leading "//":
2891                 4: {
2892                         req: &http.Request{
2893                                 Method: "GET",
2894                                 URL: &url.URL{
2895                                         Opaque: "/path",
2896                                         Host:   "foo.com",
2897                                         Path:   "/ignored",
2898                                 },
2899                         },
2900                         want: result{path: "/path"},
2901                 },
2902
2903                 // Opaque we can't handle:
2904                 5: {
2905                         req: &http.Request{
2906                                 Method: "GET",
2907                                 URL: &url.URL{
2908                                         Scheme: "https",
2909                                         Opaque: "//unknown_host/path",
2910                                         Host:   "foo.com",
2911                                         Path:   "/ignored",
2912                                 },
2913                         },
2914                         want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`},
2915                 },
2916
2917                 // A CONNECT request:
2918                 6: {
2919                         req: &http.Request{
2920                                 Method: "CONNECT",
2921                                 URL: &url.URL{
2922                                         Host: "foo.com",
2923                                 },
2924                         },
2925                         want: result{},
2926                 },
2927         }
2928         for i, tt := range tests {
2929                 cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
2930                 cc.henc = hpack.NewEncoder(&cc.hbuf)
2931                 cc.mu.Lock()
2932                 hdrs, err := cc.encodeHeaders(tt.req, false, "", -1)
2933                 cc.mu.Unlock()
2934                 var got result
2935                 hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
2936                         if f.Name == ":path" {
2937                                 got.path = f.Value
2938                         }
2939                 })
2940                 if err != nil {
2941                         got.err = err.Error()
2942                 } else if len(hdrs) > 0 {
2943                         if _, err := hpackDec.Write(hdrs); err != nil {
2944                                 t.Errorf("%d. bogus hpack: %v", i, err)
2945                                 continue
2946                         }
2947                 }
2948                 if got != tt.want {
2949                         t.Errorf("%d. got %+v; want %+v", i, got, tt.want)
2950                 }
2951
2952         }
2953
2954 }
2955
2956 // golang.org/issue/17071 -- don't sniff the first byte of the request body
2957 // before we've determined that the ClientConn is usable.
2958 func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
2959         const body = "foo"
2960         req, _ := http.NewRequest("POST", "http://foo.com/", ioutil.NopCloser(strings.NewReader(body)))
2961         cc := &ClientConn{
2962                 closed: true,
2963         }
2964         _, err := cc.RoundTrip(req)
2965         if err != errClientConnUnusable {
2966                 t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err)
2967         }
2968         slurp, err := ioutil.ReadAll(req.Body)
2969         if err != nil {
2970                 t.Errorf("ReadAll = %v", err)
2971         }
2972         if string(slurp) != body {
2973                 t.Errorf("Body = %q; want %q", slurp, body)
2974         }
2975 }
2976
2977 func TestClientConnPing(t *testing.T) {
2978         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer)
2979         defer st.Close()
2980         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2981         defer tr.CloseIdleConnections()
2982         cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
2983         if err != nil {
2984                 t.Fatal(err)
2985         }
2986         if err = cc.Ping(testContext{}); err != nil {
2987                 t.Fatal(err)
2988         }
2989 }
2990
2991 // Issue 16974: if the server sent a DATA frame after the user
2992 // canceled the Transport's Request, the Transport previously wrote to a
2993 // closed pipe, got an error, and ended up closing the whole TCP
2994 // connection.
2995 func TestTransportCancelDataResponseRace(t *testing.T) {
2996         cancel := make(chan struct{})
2997         clientGotError := make(chan bool, 1)
2998
2999         const msg = "Hello."
3000         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3001                 if strings.Contains(r.URL.Path, "/hello") {
3002                         time.Sleep(50 * time.Millisecond)
3003                         io.WriteString(w, msg)
3004                         return
3005                 }
3006                 for i := 0; i < 50; i++ {
3007                         io.WriteString(w, "Some data.")
3008                         w.(http.Flusher).Flush()
3009                         if i == 2 {
3010                                 close(cancel)
3011                                 <-clientGotError
3012                         }
3013                         time.Sleep(10 * time.Millisecond)
3014                 }
3015         }, optOnlyServer)
3016         defer st.Close()
3017
3018         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3019         defer tr.CloseIdleConnections()
3020
3021         c := &http.Client{Transport: tr}
3022         req, _ := http.NewRequest("GET", st.ts.URL, nil)
3023         req.Cancel = cancel
3024         res, err := c.Do(req)
3025         if err != nil {
3026                 t.Fatal(err)
3027         }
3028         if _, err = io.Copy(ioutil.Discard, res.Body); err == nil {
3029                 t.Fatal("unexpected success")
3030         }
3031         clientGotError <- true
3032
3033         res, err = c.Get(st.ts.URL + "/hello")
3034         if err != nil {
3035                 t.Fatal(err)
3036         }
3037         slurp, err := ioutil.ReadAll(res.Body)
3038         if err != nil {
3039                 t.Fatal(err)
3040         }
3041         if string(slurp) != msg {
3042                 t.Errorf("Got = %q; want %q", slurp, msg)
3043         }
3044 }
3045
3046 func TestTransportRetryAfterGOAWAY(t *testing.T) {
3047         var dialer struct {
3048                 sync.Mutex
3049                 count int
3050         }
3051         ct1 := make(chan *clientTester)
3052         ct2 := make(chan *clientTester)
3053
3054         ln := newLocalListener(t)
3055         defer ln.Close()
3056
3057         tr := &Transport{
3058                 TLSClientConfig: tlsConfigInsecure,
3059         }
3060         tr.DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
3061                 dialer.Lock()
3062                 defer dialer.Unlock()
3063                 dialer.count++
3064                 if dialer.count == 3 {
3065                         return nil, errors.New("unexpected number of dials")
3066                 }
3067                 cc, err := net.Dial("tcp", ln.Addr().String())
3068                 if err != nil {
3069                         return nil, fmt.Errorf("dial error: %v", err)
3070                 }
3071                 sc, err := ln.Accept()
3072                 if err != nil {
3073                         return nil, fmt.Errorf("accept error: %v", err)
3074                 }
3075                 ct := &clientTester{
3076                         t:  t,
3077                         tr: tr,
3078                         cc: cc,
3079                         sc: sc,
3080                         fr: NewFramer(sc, sc),
3081                 }
3082                 switch dialer.count {
3083                 case 1:
3084                         ct1 <- ct
3085                 case 2:
3086                         ct2 <- ct
3087                 }
3088                 return cc, nil
3089         }
3090
3091         errs := make(chan error, 3)
3092         done := make(chan struct{})
3093         defer close(done)
3094
3095         // Client.
3096         go func() {
3097                 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3098                 res, err := tr.RoundTrip(req)
3099                 if res != nil {
3100                         res.Body.Close()
3101                         if got := res.Header.Get("Foo"); got != "bar" {
3102                                 err = fmt.Errorf("foo header = %q; want bar", got)
3103                         }
3104                 }
3105                 if err != nil {
3106                         err = fmt.Errorf("RoundTrip: %v", err)
3107                 }
3108                 errs <- err
3109         }()
3110
3111         connToClose := make(chan io.Closer, 2)
3112
3113         // Server for the first request.
3114         go func() {
3115                 var ct *clientTester
3116                 select {
3117                 case ct = <-ct1:
3118                 case <-done:
3119                         return
3120                 }
3121
3122                 connToClose <- ct.cc
3123                 ct.greet()
3124                 hf, err := ct.firstHeaders()
3125                 if err != nil {
3126                         errs <- fmt.Errorf("server1 failed reading HEADERS: %v", err)
3127                         return
3128                 }
3129                 t.Logf("server1 got %v", hf)
3130                 if err := ct.fr.WriteGoAway(0 /*max id*/, ErrCodeNo, nil); err != nil {
3131                         errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
3132                         return
3133                 }
3134                 errs <- nil
3135         }()
3136
3137         // Server for the second request.
3138         go func() {
3139                 var ct *clientTester
3140                 select {
3141                 case ct = <-ct2:
3142                 case <-done:
3143                         return
3144                 }
3145
3146                 connToClose <- ct.cc
3147                 ct.greet()
3148                 hf, err := ct.firstHeaders()
3149                 if err != nil {
3150                         errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
3151                         return
3152                 }
3153                 t.Logf("server2 got %v", hf)
3154
3155                 var buf bytes.Buffer
3156                 enc := hpack.NewEncoder(&buf)
3157                 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
3158                 enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
3159                 err = ct.fr.WriteHeaders(HeadersFrameParam{
3160                         StreamID:      hf.StreamID,
3161                         EndHeaders:    true,
3162                         EndStream:     false,
3163                         BlockFragment: buf.Bytes(),
3164                 })
3165                 if err != nil {
3166                         errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err)
3167                 } else {
3168                         errs <- nil
3169                 }
3170         }()
3171
3172         for k := 0; k < 3; k++ {
3173                 select {
3174                 case err := <-errs:
3175                         if err != nil {
3176                                 t.Error(err)
3177                         }
3178                 case <-time.After(1 * time.Second):
3179                         t.Errorf("timed out")
3180                 }
3181         }
3182
3183         for {
3184                 select {
3185                 case c := <-connToClose:
3186                         c.Close()
3187                 default:
3188                         return
3189                 }
3190         }
3191 }
3192
3193 func TestTransportRetryAfterRefusedStream(t *testing.T) {
3194         clientDone := make(chan struct{})
3195         ct := newClientTester(t)
3196         ct.client = func() error {
3197                 defer ct.cc.(*net.TCPConn).CloseWrite()
3198                 defer close(clientDone)
3199                 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3200                 resp, err := ct.tr.RoundTrip(req)
3201                 if err != nil {
3202                         return fmt.Errorf("RoundTrip: %v", err)
3203                 }
3204                 resp.Body.Close()
3205                 if resp.StatusCode != 204 {
3206                         return fmt.Errorf("Status = %v; want 204", resp.StatusCode)
3207                 }
3208                 return nil
3209         }
3210         ct.server = func() error {
3211                 ct.greet()
3212                 var buf bytes.Buffer
3213                 enc := hpack.NewEncoder(&buf)
3214                 nreq := 0
3215
3216                 for {
3217                         f, err := ct.fr.ReadFrame()
3218                         if err != nil {
3219                                 select {
3220                                 case <-clientDone:
3221                                         // If the client's done, it
3222                                         // will have reported any
3223                                         // errors on its side.
3224                                         return nil
3225                                 default:
3226                                         return err
3227                                 }
3228                         }
3229                         switch f := f.(type) {
3230                         case *WindowUpdateFrame, *SettingsFrame:
3231                         case *HeadersFrame:
3232                                 if !f.HeadersEnded() {
3233                                         return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
3234                                 }
3235                                 nreq++
3236                                 if nreq == 1 {
3237                                         ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
3238                                 } else {
3239                                         enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
3240                                         ct.fr.WriteHeaders(HeadersFrameParam{
3241                                                 StreamID:      f.StreamID,
3242                                                 EndHeaders:    true,
3243                                                 EndStream:     true,
3244                                                 BlockFragment: buf.Bytes(),
3245                                         })
3246                                 }
3247                         default:
3248                                 return fmt.Errorf("Unexpected client frame %v", f)
3249                         }
3250                 }
3251         }
3252         ct.run()
3253 }
3254
3255 func TestTransportRetryHasLimit(t *testing.T) {
3256         // Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s.
3257         if testing.Short() {
3258                 t.Skip("skipping long test in short mode")
3259         }
3260         clientDone := make(chan struct{})
3261         ct := newClientTester(t)
3262         ct.client = func() error {
3263                 defer ct.cc.(*net.TCPConn).CloseWrite()
3264                 defer close(clientDone)
3265                 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3266                 resp, err := ct.tr.RoundTrip(req)
3267                 if err == nil {
3268                         return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
3269                 }
3270                 t.Logf("expected error, got: %v", err)
3271                 return nil
3272         }
3273         ct.server = func() error {
3274                 ct.greet()
3275                 for {
3276                         f, err := ct.fr.ReadFrame()
3277                         if err != nil {
3278                                 select {
3279                                 case <-clientDone:
3280                                         // If the client's done, it
3281                                         // will have reported any
3282                                         // errors on its side.
3283                                         return nil
3284                                 default:
3285                                         return err
3286                                 }
3287                         }
3288                         switch f := f.(type) {
3289                         case *WindowUpdateFrame, *SettingsFrame:
3290                         case *HeadersFrame:
3291                                 if !f.HeadersEnded() {
3292                                         return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
3293                                 }
3294                                 ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
3295                         default:
3296                                 return fmt.Errorf("Unexpected client frame %v", f)
3297                         }
3298                 }
3299         }
3300         ct.run()
3301 }
3302
3303 func TestTransportResponseDataBeforeHeaders(t *testing.T) {
3304         ct := newClientTester(t)
3305         ct.client = func() error {
3306                 defer ct.cc.(*net.TCPConn).CloseWrite()
3307                 req := httptest.NewRequest("GET", "https://dummy.tld/", nil)
3308                 // First request is normal to ensure the check is per stream and not per connection.
3309                 _, err := ct.tr.RoundTrip(req)
3310                 if err != nil {
3311                         return fmt.Errorf("RoundTrip expected no error, got: %v", err)
3312                 }
3313                 // Second request returns a DATA frame with no HEADERS.
3314                 resp, err := ct.tr.RoundTrip(req)
3315                 if err == nil {
3316                         return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
3317                 }
3318                 if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
3319                         return fmt.Errorf("expected stream PROTOCOL_ERROR, got: %v", err)
3320                 }
3321                 return nil
3322         }
3323         ct.server = func() error {
3324                 ct.greet()
3325                 for {
3326                         f, err := ct.fr.ReadFrame()
3327                         if err == io.EOF {
3328                                 return nil
3329                         } else if err != nil {
3330                                 return err
3331                         }
3332                         switch f := f.(type) {
3333                         case *WindowUpdateFrame, *SettingsFrame:
3334                         case *HeadersFrame:
3335                                 switch f.StreamID {
3336                                 case 1:
3337                                         // Send a valid response to first request.
3338                                         var buf bytes.Buffer
3339                                         enc := hpack.NewEncoder(&buf)
3340                                         enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
3341                                         ct.fr.WriteHeaders(HeadersFrameParam{
3342                                                 StreamID:      f.StreamID,
3343                                                 EndHeaders:    true,
3344                                                 EndStream:     true,
3345                                                 BlockFragment: buf.Bytes(),
3346                                         })
3347                                 case 3:
3348                                         ct.fr.WriteData(f.StreamID, true, []byte("payload"))
3349                                 }
3350                         default:
3351                                 return fmt.Errorf("Unexpected client frame %v", f)
3352                         }
3353                 }
3354         }
3355         ct.run()
3356 }
3357 func TestTransportRequestsStallAtServerLimit(t *testing.T) {
3358         const maxConcurrent = 2
3359
3360         greet := make(chan struct{})      // server sends initial SETTINGS frame
3361         gotRequest := make(chan struct{}) // server received a request
3362         clientDone := make(chan struct{})
3363
3364         // Collect errors from goroutines.
3365         var wg sync.WaitGroup
3366         errs := make(chan error, 100)
3367         defer func() {
3368                 wg.Wait()
3369                 close(errs)
3370                 for err := range errs {
3371                         t.Error(err)
3372                 }
3373         }()
3374
3375         // We will send maxConcurrent+2 requests. This checker goroutine waits for the
3376         // following stages:
3377         //   1. The first maxConcurrent requests are received by the server.
3378         //   2. The client will cancel the next request
3379         //   3. The server is unblocked so it can service the first maxConcurrent requests
3380         //   4. The client will send the final request
3381         wg.Add(1)
3382         unblockClient := make(chan struct{})
3383         clientRequestCancelled := make(chan struct{})
3384         unblockServer := make(chan struct{})
3385         go func() {
3386                 defer wg.Done()
3387                 // Stage 1.
3388                 for k := 0; k < maxConcurrent; k++ {
3389                         <-gotRequest
3390                 }
3391                 // Stage 2.
3392                 close(unblockClient)
3393                 <-clientRequestCancelled
3394                 // Stage 3: give some time for the final RoundTrip call to be scheduled and
3395                 // verify that the final request is not sent.
3396                 time.Sleep(50 * time.Millisecond)
3397                 select {
3398                 case <-gotRequest:
3399                         errs <- errors.New("last request did not stall")
3400                         close(unblockServer)
3401                         return
3402                 default:
3403                 }
3404                 close(unblockServer)
3405                 // Stage 4.
3406                 <-gotRequest
3407         }()
3408
3409         ct := newClientTester(t)
3410         ct.client = func() error {
3411                 var wg sync.WaitGroup
3412                 defer func() {
3413                         wg.Wait()
3414                         close(clientDone)
3415                         ct.cc.(*net.TCPConn).CloseWrite()
3416                 }()
3417                 for k := 0; k < maxConcurrent+2; k++ {
3418                         wg.Add(1)
3419                         go func(k int) {
3420                                 defer wg.Done()
3421                                 // Don't send the second request until after receiving SETTINGS from the server
3422                                 // to avoid a race where we use the default SettingMaxConcurrentStreams, which
3423                                 // is much larger than maxConcurrent. We have to send the first request before
3424                                 // waiting because the first request triggers the dial and greet.
3425                                 if k > 0 {
3426                                         <-greet
3427                                 }
3428                                 // Block until maxConcurrent requests are sent before sending any more.
3429                                 if k >= maxConcurrent {
3430                                         <-unblockClient
3431                                 }
3432                                 req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
3433                                 if k == maxConcurrent {
3434                                         // This request will be canceled.
3435                                         cancel := make(chan struct{})
3436                                         req.Cancel = cancel
3437                                         close(cancel)
3438                                         _, err := ct.tr.RoundTrip(req)
3439                                         close(clientRequestCancelled)
3440                                         if err == nil {
3441                                                 errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
3442                                                 return
3443                                         }
3444                                 } else {
3445                                         resp, err := ct.tr.RoundTrip(req)
3446                                         if err != nil {
3447                                                 errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
3448                                                 return
3449                                         }
3450                                         ioutil.ReadAll(resp.Body)
3451                                         resp.Body.Close()
3452                                         if resp.StatusCode != 204 {
3453                                                 errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
3454                                                 return
3455                                         }
3456                                 }
3457                         }(k)
3458                 }
3459                 return nil
3460         }
3461
3462         ct.server = func() error {
3463                 var wg sync.WaitGroup
3464                 defer wg.Wait()
3465
3466                 ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
3467
3468                 // Server write loop.
3469                 var buf bytes.Buffer
3470                 enc := hpack.NewEncoder(&buf)
3471                 writeResp := make(chan uint32, maxConcurrent+1)
3472
3473                 wg.Add(1)
3474                 go func() {
3475                         defer wg.Done()
3476                         <-unblockServer
3477                         for id := range writeResp {
3478                                 buf.Reset()
3479                                 enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
3480                                 ct.fr.WriteHeaders(HeadersFrameParam{
3481                                         StreamID:      id,
3482                                         EndHeaders:    true,
3483                                         EndStream:     true,
3484                                         BlockFragment: buf.Bytes(),
3485                                 })
3486                         }
3487                 }()
3488
3489                 // Server read loop.
3490                 var nreq int
3491                 for {
3492                         f, err := ct.fr.ReadFrame()
3493                         if err != nil {
3494                                 select {
3495                                 case <-clientDone:
3496                                         // If the client's done, it will have reported any errors on its side.
3497                                         return nil
3498                                 default:
3499                                         return err
3500                                 }
3501                         }
3502                         switch f := f.(type) {
3503                         case *WindowUpdateFrame:
3504                         case *SettingsFrame:
3505                                 // Wait for the client SETTINGS ack until ending the greet.
3506                                 close(greet)
3507                         case *HeadersFrame:
3508                                 if !f.HeadersEnded() {
3509                                         return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
3510                                 }
3511                                 gotRequest <- struct{}{}
3512                                 nreq++
3513                                 writeResp <- f.StreamID
3514                                 if nreq == maxConcurrent+1 {
3515                                         close(writeResp)
3516                                 }
3517                         default:
3518                                 return fmt.Errorf("Unexpected client frame %v", f)
3519                         }
3520                 }
3521         }
3522
3523         ct.run()
3524 }
3525
3526 func TestAuthorityAddr(t *testing.T) {
3527         tests := []struct {
3528                 scheme, authority string
3529                 want              string
3530         }{
3531                 {"http", "foo.com", "foo.com:80"},
3532                 {"https", "foo.com", "foo.com:443"},
3533                 {"https", "foo.com:1234", "foo.com:1234"},
3534                 {"https", "1.2.3.4:1234", "1.2.3.4:1234"},
3535                 {"https", "1.2.3.4", "1.2.3.4:443"},
3536                 {"https", "[::1]:1234", "[::1]:1234"},
3537                 {"https", "[::1]", "[::1]:443"},
3538         }
3539         for _, tt := range tests {
3540                 got := authorityAddr(tt.scheme, tt.authority)
3541                 if got != tt.want {
3542                         t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want)
3543                 }
3544         }
3545 }
3546
3547 // Issue 20448: stop allocating for DATA frames' payload after
3548 // Response.Body.Close is called.
3549 func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
3550         megabyteZero := make([]byte, 1<<20)
3551
3552         writeErr := make(chan error, 1)
3553
3554         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3555                 w.(http.Flusher).Flush()
3556                 var sum int64
3557                 for i := 0; i < 100; i++ {
3558                         n, err := w.Write(megabyteZero)
3559                         sum += int64(n)
3560                         if err != nil {
3561                                 writeErr <- err
3562                                 return
3563                         }
3564                 }
3565                 t.Logf("wrote all %d bytes", sum)
3566                 writeErr <- nil
3567         }, optOnlyServer)
3568         defer st.Close()
3569
3570         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3571         defer tr.CloseIdleConnections()
3572         c := &http.Client{Transport: tr}
3573         res, err := c.Get(st.ts.URL)
3574         if err != nil {
3575                 t.Fatal(err)
3576         }
3577         var buf [1]byte
3578         if _, err := res.Body.Read(buf[:]); err != nil {
3579                 t.Error(err)
3580         }
3581         if err := res.Body.Close(); err != nil {
3582                 t.Error(err)
3583         }
3584
3585         trb, ok := res.Body.(transportResponseBody)
3586         if !ok {
3587                 t.Fatalf("res.Body = %T; want transportResponseBody", res.Body)
3588         }
3589         if trb.cs.bufPipe.b != nil {
3590                 t.Errorf("response body pipe is still open")
3591         }
3592
3593         gotErr := <-writeErr
3594         if gotErr == nil {
3595                 t.Errorf("Handler unexpectedly managed to write its entire response without getting an error")
3596         } else if gotErr != errStreamClosed {
3597                 t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr)
3598         }
3599 }
3600
3601 // Issue 18891: make sure Request.Body == NoBody means no DATA frame
3602 // is ever sent, even if empty.
3603 func TestTransportNoBodyMeansNoDATA(t *testing.T) {
3604         ct := newClientTester(t)
3605
3606         unblockClient := make(chan bool)
3607
3608         ct.client = func() error {
3609                 req, _ := http.NewRequest("GET", "https://dummy.tld/", go18httpNoBody())
3610                 ct.tr.RoundTrip(req)
3611                 <-unblockClient
3612                 return nil
3613         }
3614         ct.server = func() error {
3615                 defer close(unblockClient)
3616                 defer ct.cc.(*net.TCPConn).Close()
3617                 ct.greet()
3618
3619                 for {
3620                         f, err := ct.fr.ReadFrame()
3621                         if err != nil {
3622                                 return fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
3623                         }
3624                         switch f := f.(type) {
3625                         default:
3626                                 return fmt.Errorf("Got %T; want HeadersFrame", f)
3627                         case *WindowUpdateFrame, *SettingsFrame:
3628                                 continue
3629                         case *HeadersFrame:
3630                                 if !f.StreamEnded() {
3631                                         return fmt.Errorf("got headers frame without END_STREAM")
3632                                 }
3633                                 return nil
3634                         }
3635                 }
3636         }
3637         ct.run()
3638 }
3639
3640 func benchSimpleRoundTrip(b *testing.B, nHeaders int) {
3641         defer disableGoroutineTracking()()
3642         b.ReportAllocs()
3643         st := newServerTester(b,
3644                 func(w http.ResponseWriter, r *http.Request) {
3645                 },
3646                 optOnlyServer,
3647                 optQuiet,
3648         )
3649         defer st.Close()
3650
3651         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3652         defer tr.CloseIdleConnections()
3653
3654         req, err := http.NewRequest("GET", st.ts.URL, nil)
3655         if err != nil {
3656                 b.Fatal(err)
3657         }
3658
3659         for i := 0; i < nHeaders; i++ {
3660                 name := fmt.Sprint("A-", i)
3661                 req.Header.Set(name, "*")
3662         }
3663
3664         b.ResetTimer()
3665
3666         for i := 0; i < b.N; i++ {
3667                 res, err := tr.RoundTrip(req)
3668                 if err != nil {
3669                         if res != nil {
3670                                 res.Body.Close()
3671                         }
3672                         b.Fatalf("RoundTrip err = %v; want nil", err)
3673                 }
3674                 res.Body.Close()
3675                 if res.StatusCode != http.StatusOK {
3676                         b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3677                 }
3678         }
3679 }
3680
3681 func BenchmarkClientRequestHeaders(b *testing.B) {
3682         b.Run("   0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0) })
3683         b.Run("  10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10) })
3684         b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100) })
3685         b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000) })
3686 }