OSDN Git Service

new repo
[bytom/vapor.git] / vendor / golang.org / x / net / http2 / server_test.go
1 // Copyright 2014 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package http2
6
7 import (
8         "bytes"
9         "crypto/tls"
10         "errors"
11         "flag"
12         "fmt"
13         "io"
14         "io/ioutil"
15         "log"
16         "net"
17         "net/http"
18         "net/http/httptest"
19         "os"
20         "os/exec"
21         "reflect"
22         "runtime"
23         "strconv"
24         "strings"
25         "sync"
26         "sync/atomic"
27         "testing"
28         "time"
29
30         "golang.org/x/net/http2/hpack"
31 )
32
// stderrVerbose selects whether stderrv mirrors test logging to os.Stderr.
var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered")

// stderrv returns the writer that verbose test output should also be copied
// to: os.Stderr when -stderr_verbose is set, otherwise a writer that drops
// everything.
func stderrv() io.Writer {
	if !*stderrVerbose {
		return ioutil.Discard
	}
	return os.Stderr
}
42
43 type serverTester struct {
44         cc             net.Conn // client conn
45         t              testing.TB
46         ts             *httptest.Server
47         fr             *Framer
48         serverLogBuf   bytes.Buffer // logger for httptest.Server
49         logFilter      []string     // substrings to filter out
50         scMu           sync.Mutex   // guards sc
51         sc             *serverConn
52         hpackDec       *hpack.Decoder
53         decodedHeaders [][2]string
54
55         // If http2debug!=2, then we capture Frame debug logs that will be written
56         // to t.Log after a test fails. The read and write logs use separate locks
57         // and buffers so we don't accidentally introduce synchronization between
58         // the read and write goroutines, which may hide data races.
59         frameReadLogMu   sync.Mutex
60         frameReadLogBuf  bytes.Buffer
61         frameWriteLogMu  sync.Mutex
62         frameWriteLogBuf bytes.Buffer
63
64         // writing headers:
65         headerBuf bytes.Buffer
66         hpackEnc  *hpack.Encoder
67 }
68
69 func init() {
70         testHookOnPanicMu = new(sync.Mutex)
71 }
72
73 func resetHooks() {
74         testHookOnPanicMu.Lock()
75         testHookOnPanic = nil
76         testHookOnPanicMu.Unlock()
77 }
78
// serverTesterOpt is a flag-style option accepted by newServerTester.
type serverTesterOpt string

// Flag options understood by newServerTester.
var (
	optOnlyServer        = serverTesterOpt("only_server")        // start the server but don't dial a client conn
	optQuiet             = serverTesterOpt("quiet_logging")      // discard the server's ErrorLog output
	optFramerReuseFrames = serverTesterOpt("frame_reuse_frames") // call SetReuseFrames on the client Framer
)
84
85 func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
86         resetHooks()
87
88         ts := httptest.NewUnstartedServer(handler)
89
90         tlsConfig := &tls.Config{
91                 InsecureSkipVerify: true,
92                 NextProtos:         []string{NextProtoTLS},
93         }
94
95         var onlyServer, quiet, framerReuseFrames bool
96         h2server := new(Server)
97         for _, opt := range opts {
98                 switch v := opt.(type) {
99                 case func(*tls.Config):
100                         v(tlsConfig)
101                 case func(*httptest.Server):
102                         v(ts)
103                 case func(*Server):
104                         v(h2server)
105                 case serverTesterOpt:
106                         switch v {
107                         case optOnlyServer:
108                                 onlyServer = true
109                         case optQuiet:
110                                 quiet = true
111                         case optFramerReuseFrames:
112                                 framerReuseFrames = true
113                         }
114                 case func(net.Conn, http.ConnState):
115                         ts.Config.ConnState = v
116                 default:
117                         t.Fatalf("unknown newServerTester option type %T", v)
118                 }
119         }
120
121         ConfigureServer(ts.Config, h2server)
122
123         st := &serverTester{
124                 t:  t,
125                 ts: ts,
126         }
127         st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
128         st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
129
130         ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
131         if quiet {
132                 ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
133         } else {
134                 ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
135         }
136         ts.StartTLS()
137
138         if VerboseLogs {
139                 t.Logf("Running test server at: %s", ts.URL)
140         }
141         testHookGetServerConn = func(v *serverConn) {
142                 st.scMu.Lock()
143                 defer st.scMu.Unlock()
144                 st.sc = v
145         }
146         log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st}))
147         if !onlyServer {
148                 cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
149                 if err != nil {
150                         t.Fatal(err)
151                 }
152                 st.cc = cc
153                 st.fr = NewFramer(cc, cc)
154                 if framerReuseFrames {
155                         st.fr.SetReuseFrames()
156                 }
157                 if !logFrameReads && !logFrameWrites {
158                         st.fr.debugReadLoggerf = func(m string, v ...interface{}) {
159                                 m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
160                                 st.frameReadLogMu.Lock()
161                                 fmt.Fprintf(&st.frameReadLogBuf, m, v...)
162                                 st.frameReadLogMu.Unlock()
163                         }
164                         st.fr.debugWriteLoggerf = func(m string, v ...interface{}) {
165                                 m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
166                                 st.frameWriteLogMu.Lock()
167                                 fmt.Fprintf(&st.frameWriteLogBuf, m, v...)
168                                 st.frameWriteLogMu.Unlock()
169                         }
170                         st.fr.logReads = true
171                         st.fr.logWrites = true
172                 }
173         }
174         return st
175 }
176
177 func (st *serverTester) closeConn() {
178         st.scMu.Lock()
179         defer st.scMu.Unlock()
180         st.sc.conn.Close()
181 }
182
183 func (st *serverTester) addLogFilter(phrase string) {
184         st.logFilter = append(st.logFilter, phrase)
185 }
186
187 func (st *serverTester) stream(id uint32) *stream {
188         ch := make(chan *stream, 1)
189         st.sc.serveMsgCh <- func(int) {
190                 ch <- st.sc.streams[id]
191         }
192         return <-ch
193 }
194
195 func (st *serverTester) streamState(id uint32) streamState {
196         ch := make(chan streamState, 1)
197         st.sc.serveMsgCh <- func(int) {
198                 state, _ := st.sc.state(id)
199                 ch <- state
200         }
201         return <-ch
202 }
203
204 // loopNum reports how many times this conn's select loop has gone around.
205 func (st *serverTester) loopNum() int {
206         lastc := make(chan int, 1)
207         st.sc.serveMsgCh <- func(loopNum int) {
208                 lastc <- loopNum
209         }
210         return <-lastc
211 }
212
213 // awaitIdle heuristically awaits for the server conn's select loop to be idle.
214 // The heuristic is that the server connection's serve loop must schedule
215 // 50 times in a row without any channel sends or receives occurring.
216 func (st *serverTester) awaitIdle() {
217         remain := 50
218         last := st.loopNum()
219         for remain > 0 {
220                 n := st.loopNum()
221                 if n == last+1 {
222                         remain--
223                 } else {
224                         remain = 50
225                 }
226                 last = n
227         }
228 }
229
230 func (st *serverTester) Close() {
231         if st.t.Failed() {
232                 st.frameReadLogMu.Lock()
233                 if st.frameReadLogBuf.Len() > 0 {
234                         st.t.Logf("Framer read log:\n%s", st.frameReadLogBuf.String())
235                 }
236                 st.frameReadLogMu.Unlock()
237
238                 st.frameWriteLogMu.Lock()
239                 if st.frameWriteLogBuf.Len() > 0 {
240                         st.t.Logf("Framer write log:\n%s", st.frameWriteLogBuf.String())
241                 }
242                 st.frameWriteLogMu.Unlock()
243
244                 // If we failed already (and are likely in a Fatal,
245                 // unwindowing), force close the connection, so the
246                 // httptest.Server doesn't wait forever for the conn
247                 // to close.
248                 if st.cc != nil {
249                         st.cc.Close()
250                 }
251         }
252         st.ts.Close()
253         if st.cc != nil {
254                 st.cc.Close()
255         }
256         log.SetOutput(os.Stderr)
257 }
258
259 // greet initiates the client's HTTP/2 connection into a state where
260 // frames may be sent.
261 func (st *serverTester) greet() {
262         st.greetAndCheckSettings(func(Setting) error { return nil })
263 }
264
265 func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) {
266         st.writePreface()
267         st.writeInitialSettings()
268         st.wantSettings().ForeachSetting(checkSetting)
269         st.writeSettingsAck()
270
271         // The initial WINDOW_UPDATE and SETTINGS ACK can come in any order.
272         var gotSettingsAck bool
273         var gotWindowUpdate bool
274
275         for i := 0; i < 2; i++ {
276                 f, err := st.readFrame()
277                 if err != nil {
278                         st.t.Fatal(err)
279                 }
280                 switch f := f.(type) {
281                 case *SettingsFrame:
282                         if !f.Header().Flags.Has(FlagSettingsAck) {
283                                 st.t.Fatal("Settings Frame didn't have ACK set")
284                         }
285                         gotSettingsAck = true
286
287                 case *WindowUpdateFrame:
288                         if f.FrameHeader.StreamID != 0 {
289                                 st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
290                         }
291                         incr := uint32((&Server{}).initialConnRecvWindowSize() - initialWindowSize)
292                         if f.Increment != incr {
293                                 st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr)
294                         }
295                         gotWindowUpdate = true
296
297                 default:
298                         st.t.Fatalf("Wanting a settings ACK or window update, received a %T", f)
299                 }
300         }
301
302         if !gotSettingsAck {
303                 st.t.Fatalf("Didn't get a settings ACK")
304         }
305         if !gotWindowUpdate {
306                 st.t.Fatalf("Didn't get a window update")
307         }
308 }
309
310 func (st *serverTester) writePreface() {
311         n, err := st.cc.Write(clientPreface)
312         if err != nil {
313                 st.t.Fatalf("Error writing client preface: %v", err)
314         }
315         if n != len(clientPreface) {
316                 st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface))
317         }
318 }
319
320 func (st *serverTester) writeInitialSettings() {
321         if err := st.fr.WriteSettings(); err != nil {
322                 st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err)
323         }
324 }
325
326 func (st *serverTester) writeSettingsAck() {
327         if err := st.fr.WriteSettingsAck(); err != nil {
328                 st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err)
329         }
330 }
331
332 func (st *serverTester) writeHeaders(p HeadersFrameParam) {
333         if err := st.fr.WriteHeaders(p); err != nil {
334                 st.t.Fatalf("Error writing HEADERS: %v", err)
335         }
336 }
337
338 func (st *serverTester) writePriority(id uint32, p PriorityParam) {
339         if err := st.fr.WritePriority(id, p); err != nil {
340                 st.t.Fatalf("Error writing PRIORITY: %v", err)
341         }
342 }
343
344 func (st *serverTester) encodeHeaderField(k, v string) {
345         err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
346         if err != nil {
347                 st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
348         }
349 }
350
351 // encodeHeaderRaw is the magic-free version of encodeHeader.
352 // It takes 0 or more (k, v) pairs and encodes them.
353 func (st *serverTester) encodeHeaderRaw(headers ...string) []byte {
354         if len(headers)%2 == 1 {
355                 panic("odd number of kv args")
356         }
357         st.headerBuf.Reset()
358         for len(headers) > 0 {
359                 k, v := headers[0], headers[1]
360                 st.encodeHeaderField(k, v)
361                 headers = headers[2:]
362         }
363         return st.headerBuf.Bytes()
364 }
365
366 // encodeHeader encodes headers and returns their HPACK bytes. headers
367 // must contain an even number of key/value pairs. There may be
368 // multiple pairs for keys (e.g. "cookie").  The :method, :path, and
369 // :scheme headers default to GET, / and https. The :authority header
370 // defaults to st.ts.Listener.Addr().
371 func (st *serverTester) encodeHeader(headers ...string) []byte {
372         if len(headers)%2 == 1 {
373                 panic("odd number of kv args")
374         }
375
376         st.headerBuf.Reset()
377         defaultAuthority := st.ts.Listener.Addr().String()
378
379         if len(headers) == 0 {
380                 // Fast path, mostly for benchmarks, so test code doesn't pollute
381                 // profiles when we're looking to improve server allocations.
382                 st.encodeHeaderField(":method", "GET")
383                 st.encodeHeaderField(":scheme", "https")
384                 st.encodeHeaderField(":authority", defaultAuthority)
385                 st.encodeHeaderField(":path", "/")
386                 return st.headerBuf.Bytes()
387         }
388
389         if len(headers) == 2 && headers[0] == ":method" {
390                 // Another fast path for benchmarks.
391                 st.encodeHeaderField(":method", headers[1])
392                 st.encodeHeaderField(":scheme", "https")
393                 st.encodeHeaderField(":authority", defaultAuthority)
394                 st.encodeHeaderField(":path", "/")
395                 return st.headerBuf.Bytes()
396         }
397
398         pseudoCount := map[string]int{}
399         keys := []string{":method", ":scheme", ":authority", ":path"}
400         vals := map[string][]string{
401                 ":method":    {"GET"},
402                 ":scheme":    {"https"},
403                 ":authority": {defaultAuthority},
404                 ":path":      {"/"},
405         }
406         for len(headers) > 0 {
407                 k, v := headers[0], headers[1]
408                 headers = headers[2:]
409                 if _, ok := vals[k]; !ok {
410                         keys = append(keys, k)
411                 }
412                 if strings.HasPrefix(k, ":") {
413                         pseudoCount[k]++
414                         if pseudoCount[k] == 1 {
415                                 vals[k] = []string{v}
416                         } else {
417                                 // Allows testing of invalid headers w/ dup pseudo fields.
418                                 vals[k] = append(vals[k], v)
419                         }
420                 } else {
421                         vals[k] = append(vals[k], v)
422                 }
423         }
424         for _, k := range keys {
425                 for _, v := range vals[k] {
426                         st.encodeHeaderField(k, v)
427                 }
428         }
429         return st.headerBuf.Bytes()
430 }
431
432 // bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set.
433 func (st *serverTester) bodylessReq1(headers ...string) {
434         st.writeHeaders(HeadersFrameParam{
435                 StreamID:      1, // clients send odd numbers
436                 BlockFragment: st.encodeHeader(headers...),
437                 EndStream:     true,
438                 EndHeaders:    true,
439         })
440 }
441
442 func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
443         if err := st.fr.WriteData(streamID, endStream, data); err != nil {
444                 st.t.Fatalf("Error writing DATA: %v", err)
445         }
446 }
447
448 func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
449         if err := st.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
450                 st.t.Fatalf("Error writing DATA: %v", err)
451         }
452 }
453
454 func readFrameTimeout(fr *Framer, wait time.Duration) (Frame, error) {
455         ch := make(chan interface{}, 1)
456         go func() {
457                 fr, err := fr.ReadFrame()
458                 if err != nil {
459                         ch <- err
460                 } else {
461                         ch <- fr
462                 }
463         }()
464         t := time.NewTimer(wait)
465         select {
466         case v := <-ch:
467                 t.Stop()
468                 if fr, ok := v.(Frame); ok {
469                         return fr, nil
470                 }
471                 return nil, v.(error)
472         case <-t.C:
473                 return nil, errors.New("timeout waiting for frame")
474         }
475 }
476
477 func (st *serverTester) readFrame() (Frame, error) {
478         return readFrameTimeout(st.fr, 2*time.Second)
479 }
480
481 func (st *serverTester) wantHeaders() *HeadersFrame {
482         f, err := st.readFrame()
483         if err != nil {
484                 st.t.Fatalf("Error while expecting a HEADERS frame: %v", err)
485         }
486         hf, ok := f.(*HeadersFrame)
487         if !ok {
488                 st.t.Fatalf("got a %T; want *HeadersFrame", f)
489         }
490         return hf
491 }
492
493 func (st *serverTester) wantContinuation() *ContinuationFrame {
494         f, err := st.readFrame()
495         if err != nil {
496                 st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err)
497         }
498         cf, ok := f.(*ContinuationFrame)
499         if !ok {
500                 st.t.Fatalf("got a %T; want *ContinuationFrame", f)
501         }
502         return cf
503 }
504
505 func (st *serverTester) wantData() *DataFrame {
506         f, err := st.readFrame()
507         if err != nil {
508                 st.t.Fatalf("Error while expecting a DATA frame: %v", err)
509         }
510         df, ok := f.(*DataFrame)
511         if !ok {
512                 st.t.Fatalf("got a %T; want *DataFrame", f)
513         }
514         return df
515 }
516
517 func (st *serverTester) wantSettings() *SettingsFrame {
518         f, err := st.readFrame()
519         if err != nil {
520                 st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
521         }
522         sf, ok := f.(*SettingsFrame)
523         if !ok {
524                 st.t.Fatalf("got a %T; want *SettingsFrame", f)
525         }
526         return sf
527 }
528
529 func (st *serverTester) wantPing() *PingFrame {
530         f, err := st.readFrame()
531         if err != nil {
532                 st.t.Fatalf("Error while expecting a PING frame: %v", err)
533         }
534         pf, ok := f.(*PingFrame)
535         if !ok {
536                 st.t.Fatalf("got a %T; want *PingFrame", f)
537         }
538         return pf
539 }
540
541 func (st *serverTester) wantGoAway() *GoAwayFrame {
542         f, err := st.readFrame()
543         if err != nil {
544                 st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err)
545         }
546         gf, ok := f.(*GoAwayFrame)
547         if !ok {
548                 st.t.Fatalf("got a %T; want *GoAwayFrame", f)
549         }
550         return gf
551 }
552
553 func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) {
554         f, err := st.readFrame()
555         if err != nil {
556                 st.t.Fatalf("Error while expecting an RSTStream frame: %v", err)
557         }
558         rs, ok := f.(*RSTStreamFrame)
559         if !ok {
560                 st.t.Fatalf("got a %T; want *RSTStreamFrame", f)
561         }
562         if rs.FrameHeader.StreamID != streamID {
563                 st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID)
564         }
565         if rs.ErrCode != errCode {
566                 st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode)
567         }
568 }
569
570 func (st *serverTester) wantWindowUpdate(streamID, incr uint32) {
571         f, err := st.readFrame()
572         if err != nil {
573                 st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err)
574         }
575         wu, ok := f.(*WindowUpdateFrame)
576         if !ok {
577                 st.t.Fatalf("got a %T; want *WindowUpdateFrame", f)
578         }
579         if wu.FrameHeader.StreamID != streamID {
580                 st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
581         }
582         if wu.Increment != incr {
583                 st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
584         }
585 }
586
587 func (st *serverTester) wantSettingsAck() {
588         f, err := st.readFrame()
589         if err != nil {
590                 st.t.Fatal(err)
591         }
592         sf, ok := f.(*SettingsFrame)
593         if !ok {
594                 st.t.Fatalf("Wanting a settings ACK, received a %T", f)
595         }
596         if !sf.Header().Flags.Has(FlagSettingsAck) {
597                 st.t.Fatal("Settings Frame didn't have ACK set")
598         }
599 }
600
601 func (st *serverTester) wantPushPromise() *PushPromiseFrame {
602         f, err := st.readFrame()
603         if err != nil {
604                 st.t.Fatal(err)
605         }
606         ppf, ok := f.(*PushPromiseFrame)
607         if !ok {
608                 st.t.Fatalf("Wanted PushPromise, received %T", ppf)
609         }
610         return ppf
611 }
612
613 func TestServer(t *testing.T) {
614         gotReq := make(chan bool, 1)
615         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
616                 w.Header().Set("Foo", "Bar")
617                 gotReq <- true
618         })
619         defer st.Close()
620
621         covers("3.5", `
622                 The server connection preface consists of a potentially empty
623                 SETTINGS frame ([SETTINGS]) that MUST be the first frame the
624                 server sends in the HTTP/2 connection.
625         `)
626
627         st.greet()
628         st.writeHeaders(HeadersFrameParam{
629                 StreamID:      1, // clients send odd numbers
630                 BlockFragment: st.encodeHeader(),
631                 EndStream:     true, // no DATA frames
632                 EndHeaders:    true,
633         })
634
635         select {
636         case <-gotReq:
637         case <-time.After(2 * time.Second):
638                 t.Error("timeout waiting for request")
639         }
640 }
641
642 func TestServer_Request_Get(t *testing.T) {
643         testServerRequest(t, func(st *serverTester) {
644                 st.writeHeaders(HeadersFrameParam{
645                         StreamID:      1, // clients send odd numbers
646                         BlockFragment: st.encodeHeader("foo-bar", "some-value"),
647                         EndStream:     true, // no DATA frames
648                         EndHeaders:    true,
649                 })
650         }, func(r *http.Request) {
651                 if r.Method != "GET" {
652                         t.Errorf("Method = %q; want GET", r.Method)
653                 }
654                 if r.URL.Path != "/" {
655                         t.Errorf("URL.Path = %q; want /", r.URL.Path)
656                 }
657                 if r.ContentLength != 0 {
658                         t.Errorf("ContentLength = %v; want 0", r.ContentLength)
659                 }
660                 if r.Close {
661                         t.Error("Close = true; want false")
662                 }
663                 if !strings.Contains(r.RemoteAddr, ":") {
664                         t.Errorf("RemoteAddr = %q; want something with a colon", r.RemoteAddr)
665                 }
666                 if r.Proto != "HTTP/2.0" || r.ProtoMajor != 2 || r.ProtoMinor != 0 {
667                         t.Errorf("Proto = %q Major=%v,Minor=%v; want HTTP/2.0", r.Proto, r.ProtoMajor, r.ProtoMinor)
668                 }
669                 wantHeader := http.Header{
670                         "Foo-Bar": []string{"some-value"},
671                 }
672                 if !reflect.DeepEqual(r.Header, wantHeader) {
673                         t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
674                 }
675                 if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
676                         t.Errorf("Read = %d, %v; want 0, EOF", n, err)
677                 }
678         })
679 }
680
681 func TestServer_Request_Get_PathSlashes(t *testing.T) {
682         testServerRequest(t, func(st *serverTester) {
683                 st.writeHeaders(HeadersFrameParam{
684                         StreamID:      1, // clients send odd numbers
685                         BlockFragment: st.encodeHeader(":path", "/%2f/"),
686                         EndStream:     true, // no DATA frames
687                         EndHeaders:    true,
688                 })
689         }, func(r *http.Request) {
690                 if r.RequestURI != "/%2f/" {
691                         t.Errorf("RequestURI = %q; want /%%2f/", r.RequestURI)
692                 }
693                 if r.URL.Path != "///" {
694                         t.Errorf("URL.Path = %q; want ///", r.URL.Path)
695                 }
696         })
697 }
698
699 // TODO: add a test with EndStream=true on the HEADERS but setting a
700 // Content-Length anyway. Should we just omit it and force it to
701 // zero?
702
703 func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
704         testServerRequest(t, func(st *serverTester) {
705                 st.writeHeaders(HeadersFrameParam{
706                         StreamID:      1, // clients send odd numbers
707                         BlockFragment: st.encodeHeader(":method", "POST"),
708                         EndStream:     true,
709                         EndHeaders:    true,
710                 })
711         }, func(r *http.Request) {
712                 if r.Method != "POST" {
713                         t.Errorf("Method = %q; want POST", r.Method)
714                 }
715                 if r.ContentLength != 0 {
716                         t.Errorf("ContentLength = %v; want 0", r.ContentLength)
717                 }
718                 if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
719                         t.Errorf("Read = %d, %v; want 0, EOF", n, err)
720                 }
721         })
722 }
723
724 func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) {
725         testBodyContents(t, -1, "", func(st *serverTester) {
726                 st.writeHeaders(HeadersFrameParam{
727                         StreamID:      1, // clients send odd numbers
728                         BlockFragment: st.encodeHeader(":method", "POST"),
729                         EndStream:     false, // to say DATA frames are coming
730                         EndHeaders:    true,
731                 })
732                 st.writeData(1, true, nil) // just kidding. empty body.
733         })
734 }
735
736 func TestServer_Request_Post_Body_OneData(t *testing.T) {
737         const content = "Some content"
738         testBodyContents(t, -1, content, func(st *serverTester) {
739                 st.writeHeaders(HeadersFrameParam{
740                         StreamID:      1, // clients send odd numbers
741                         BlockFragment: st.encodeHeader(":method", "POST"),
742                         EndStream:     false, // to say DATA frames are coming
743                         EndHeaders:    true,
744                 })
745                 st.writeData(1, true, []byte(content))
746         })
747 }
748
749 func TestServer_Request_Post_Body_TwoData(t *testing.T) {
750         const content = "Some content"
751         testBodyContents(t, -1, content, func(st *serverTester) {
752                 st.writeHeaders(HeadersFrameParam{
753                         StreamID:      1, // clients send odd numbers
754                         BlockFragment: st.encodeHeader(":method", "POST"),
755                         EndStream:     false, // to say DATA frames are coming
756                         EndHeaders:    true,
757                 })
758                 st.writeData(1, false, []byte(content[:5]))
759                 st.writeData(1, true, []byte(content[5:]))
760         })
761 }
762
763 func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) {
764         const content = "Some content"
765         testBodyContents(t, int64(len(content)), content, func(st *serverTester) {
766                 st.writeHeaders(HeadersFrameParam{
767                         StreamID: 1, // clients send odd numbers
768                         BlockFragment: st.encodeHeader(
769                                 ":method", "POST",
770                                 "content-length", strconv.Itoa(len(content)),
771                         ),
772                         EndStream:  false, // to say DATA frames are coming
773                         EndHeaders: true,
774                 })
775                 st.writeData(1, true, []byte(content))
776         })
777 }
778
779 func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) {
780         testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes",
781                 func(st *serverTester) {
782                         st.writeHeaders(HeadersFrameParam{
783                                 StreamID: 1, // clients send odd numbers
784                                 BlockFragment: st.encodeHeader(
785                                         ":method", "POST",
786                                         "content-length", "3",
787                                 ),
788                                 EndStream:  false, // to say DATA frames are coming
789                                 EndHeaders: true,
790                         })
791                         st.writeData(1, true, []byte("12"))
792                 })
793 }
794
795 func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) {
796         testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes",
797                 func(st *serverTester) {
798                         st.writeHeaders(HeadersFrameParam{
799                                 StreamID: 1, // clients send odd numbers
800                                 BlockFragment: st.encodeHeader(
801                                         ":method", "POST",
802                                         "content-length", "4",
803                                 ),
804                                 EndStream:  false, // to say DATA frames are coming
805                                 EndHeaders: true,
806                         })
807                         st.writeData(1, true, []byte("12345"))
808                 })
809 }
810
811 func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) {
812         testServerRequest(t, write, func(r *http.Request) {
813                 if r.Method != "POST" {
814                         t.Errorf("Method = %q; want POST", r.Method)
815                 }
816                 if r.ContentLength != wantContentLength {
817                         t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
818                 }
819                 all, err := ioutil.ReadAll(r.Body)
820                 if err != nil {
821                         t.Fatal(err)
822                 }
823                 if string(all) != wantBody {
824                         t.Errorf("Read = %q; want %q", all, wantBody)
825                 }
826                 if err := r.Body.Close(); err != nil {
827                         t.Fatalf("Close: %v", err)
828                 }
829         })
830 }
831
832 func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) {
833         testServerRequest(t, write, func(r *http.Request) {
834                 if r.Method != "POST" {
835                         t.Errorf("Method = %q; want POST", r.Method)
836                 }
837                 if r.ContentLength != wantContentLength {
838                         t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
839                 }
840                 all, err := ioutil.ReadAll(r.Body)
841                 if err == nil {
842                         t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.",
843                                 wantReadError, all)
844                 }
845                 if !strings.Contains(err.Error(), wantReadError) {
846                         t.Fatalf("Body.Read = %v; want substring %q", err, wantReadError)
847                 }
848                 if err := r.Body.Close(); err != nil {
849                         t.Fatalf("Close: %v", err)
850                 }
851         })
852 }
853
854 // Using a Host header, instead of :authority
855 func TestServer_Request_Get_Host(t *testing.T) {
856         const host = "example.com"
857         testServerRequest(t, func(st *serverTester) {
858                 st.writeHeaders(HeadersFrameParam{
859                         StreamID:      1, // clients send odd numbers
860                         BlockFragment: st.encodeHeader(":authority", "", "host", host),
861                         EndStream:     true,
862                         EndHeaders:    true,
863                 })
864         }, func(r *http.Request) {
865                 if r.Host != host {
866                         t.Errorf("Host = %q; want %q", r.Host, host)
867                 }
868         })
869 }
870
871 // Using an :authority pseudo-header, instead of Host
872 func TestServer_Request_Get_Authority(t *testing.T) {
873         const host = "example.com"
874         testServerRequest(t, func(st *serverTester) {
875                 st.writeHeaders(HeadersFrameParam{
876                         StreamID:      1, // clients send odd numbers
877                         BlockFragment: st.encodeHeader(":authority", host),
878                         EndStream:     true,
879                         EndHeaders:    true,
880                 })
881         }, func(r *http.Request) {
882                 if r.Host != host {
883                         t.Errorf("Host = %q; want %q", r.Host, host)
884                 }
885         })
886 }
887
888 func TestServer_Request_WithContinuation(t *testing.T) {
889         wantHeader := http.Header{
890                 "Foo-One":   []string{"value-one"},
891                 "Foo-Two":   []string{"value-two"},
892                 "Foo-Three": []string{"value-three"},
893         }
894         testServerRequest(t, func(st *serverTester) {
895                 fullHeaders := st.encodeHeader(
896                         "foo-one", "value-one",
897                         "foo-two", "value-two",
898                         "foo-three", "value-three",
899                 )
900                 remain := fullHeaders
901                 chunks := 0
902                 for len(remain) > 0 {
903                         const maxChunkSize = 5
904                         chunk := remain
905                         if len(chunk) > maxChunkSize {
906                                 chunk = chunk[:maxChunkSize]
907                         }
908                         remain = remain[len(chunk):]
909
910                         if chunks == 0 {
911                                 st.writeHeaders(HeadersFrameParam{
912                                         StreamID:      1, // clients send odd numbers
913                                         BlockFragment: chunk,
914                                         EndStream:     true,  // no DATA frames
915                                         EndHeaders:    false, // we'll have continuation frames
916                                 })
917                         } else {
918                                 err := st.fr.WriteContinuation(1, len(remain) == 0, chunk)
919                                 if err != nil {
920                                         t.Fatal(err)
921                                 }
922                         }
923                         chunks++
924                 }
925                 if chunks < 2 {
926                         t.Fatal("too few chunks")
927                 }
928         }, func(r *http.Request) {
929                 if !reflect.DeepEqual(r.Header, wantHeader) {
930                         t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
931                 }
932         })
933 }
934
935 // Concatenated cookie headers. ("8.1.2.5 Compressing the Cookie Header Field")
936 func TestServer_Request_CookieConcat(t *testing.T) {
937         const host = "example.com"
938         testServerRequest(t, func(st *serverTester) {
939                 st.bodylessReq1(
940                         ":authority", host,
941                         "cookie", "a=b",
942                         "cookie", "c=d",
943                         "cookie", "e=f",
944                 )
945         }, func(r *http.Request) {
946                 const want = "a=b; c=d; e=f"
947                 if got := r.Header.Get("Cookie"); got != want {
948                         t.Errorf("Cookie = %q; want %q", got, want)
949                 }
950         })
951 }
952
953 func TestServer_Request_Reject_CapitalHeader(t *testing.T) {
954         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("UPPER", "v") })
955 }
956
957 func TestServer_Request_Reject_HeaderFieldNameColon(t *testing.T) {
958         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has:colon", "v") })
959 }
960
961 func TestServer_Request_Reject_HeaderFieldNameNULL(t *testing.T) {
962         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has\x00null", "v") })
963 }
964
965 func TestServer_Request_Reject_HeaderFieldNameEmpty(t *testing.T) {
966         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("", "v") })
967 }
968
969 func TestServer_Request_Reject_HeaderFieldValueNewline(t *testing.T) {
970         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\nnewline") })
971 }
972
973 func TestServer_Request_Reject_HeaderFieldValueCR(t *testing.T) {
974         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\rcarriage") })
975 }
976
977 func TestServer_Request_Reject_HeaderFieldValueDEL(t *testing.T) {
978         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\x7fdel") })
979 }
980
981 func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) {
982         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "") })
983 }
984
985 func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) {
986         // 8.1.2.3 Request Pseudo-Header Fields
987         // "All HTTP/2 requests MUST include exactly one valid value" ...
988         testRejectRequest(t, func(st *serverTester) {
989                 st.addLogFilter("duplicate pseudo-header")
990                 st.bodylessReq1(":method", "GET", ":method", "POST")
991         })
992 }
993
994 func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
995         // 8.1.2.3 Request Pseudo-Header Fields
996         // "All pseudo-header fields MUST appear in the header block
997         // before regular header fields. Any request or response that
998         // contains a pseudo-header field that appears in a header
999         // block after a regular header field MUST be treated as
1000         // malformed (Section 8.1.2.6)."
1001         testRejectRequest(t, func(st *serverTester) {
1002                 st.addLogFilter("pseudo-header after regular header")
1003                 var buf bytes.Buffer
1004                 enc := hpack.NewEncoder(&buf)
1005                 enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
1006                 enc.WriteField(hpack.HeaderField{Name: "regular", Value: "foobar"})
1007                 enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
1008                 enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
1009                 st.writeHeaders(HeadersFrameParam{
1010                         StreamID:      1, // clients send odd numbers
1011                         BlockFragment: buf.Bytes(),
1012                         EndStream:     true,
1013                         EndHeaders:    true,
1014                 })
1015         })
1016 }
1017
1018 func TestServer_Request_Reject_Pseudo_Missing_path(t *testing.T) {
1019         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":path", "") })
1020 }
1021
1022 func TestServer_Request_Reject_Pseudo_Missing_scheme(t *testing.T) {
1023         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "") })
1024 }
1025
1026 func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) {
1027         testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "bogus") })
1028 }
1029
1030 func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
1031         testRejectRequest(t, func(st *serverTester) {
1032                 st.addLogFilter(`invalid pseudo-header ":unknown_thing"`)
1033                 st.bodylessReq1(":unknown_thing", "")
1034         })
1035 }
1036
1037 func testRejectRequest(t *testing.T, send func(*serverTester)) {
1038         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1039                 t.Error("server request made it to handler; should've been rejected")
1040         })
1041         defer st.Close()
1042
1043         st.greet()
1044         send(st)
1045         st.wantRSTStream(1, ErrCodeProtocol)
1046 }
1047
1048 func testRejectRequestWithProtocolError(t *testing.T, send func(*serverTester)) {
1049         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1050                 t.Error("server request made it to handler; should've been rejected")
1051         }, optQuiet)
1052         defer st.Close()
1053
1054         st.greet()
1055         send(st)
1056         gf := st.wantGoAway()
1057         if gf.ErrCode != ErrCodeProtocol {
1058                 t.Errorf("err code = %v; want %v", gf.ErrCode, ErrCodeProtocol)
1059         }
1060 }
1061
1062 // Section 5.1, on idle connections: "Receiving any frame other than
1063 // HEADERS or PRIORITY on a stream in this state MUST be treated as a
1064 // connection error (Section 5.4.1) of type PROTOCOL_ERROR."
1065 func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) {
1066         testRejectRequestWithProtocolError(t, func(st *serverTester) {
1067                 st.fr.WriteWindowUpdate(123, 456)
1068         })
1069 }
1070 func TestRejectFrameOnIdle_Data(t *testing.T) {
1071         testRejectRequestWithProtocolError(t, func(st *serverTester) {
1072                 st.fr.WriteData(123, true, nil)
1073         })
1074 }
1075 func TestRejectFrameOnIdle_RSTStream(t *testing.T) {
1076         testRejectRequestWithProtocolError(t, func(st *serverTester) {
1077                 st.fr.WriteRSTStream(123, ErrCodeCancel)
1078         })
1079 }
1080
1081 func TestServer_Request_Connect(t *testing.T) {
1082         testServerRequest(t, func(st *serverTester) {
1083                 st.writeHeaders(HeadersFrameParam{
1084                         StreamID: 1,
1085                         BlockFragment: st.encodeHeaderRaw(
1086                                 ":method", "CONNECT",
1087                                 ":authority", "example.com:123",
1088                         ),
1089                         EndStream:  true,
1090                         EndHeaders: true,
1091                 })
1092         }, func(r *http.Request) {
1093                 if g, w := r.Method, "CONNECT"; g != w {
1094                         t.Errorf("Method = %q; want %q", g, w)
1095                 }
1096                 if g, w := r.RequestURI, "example.com:123"; g != w {
1097                         t.Errorf("RequestURI = %q; want %q", g, w)
1098                 }
1099                 if g, w := r.URL.Host, "example.com:123"; g != w {
1100                         t.Errorf("URL.Host = %q; want %q", g, w)
1101                 }
1102         })
1103 }
1104
1105 func TestServer_Request_Connect_InvalidPath(t *testing.T) {
1106         testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1107                 st.writeHeaders(HeadersFrameParam{
1108                         StreamID: 1,
1109                         BlockFragment: st.encodeHeaderRaw(
1110                                 ":method", "CONNECT",
1111                                 ":authority", "example.com:123",
1112                                 ":path", "/bogus",
1113                         ),
1114                         EndStream:  true,
1115                         EndHeaders: true,
1116                 })
1117         })
1118 }
1119
1120 func TestServer_Request_Connect_InvalidScheme(t *testing.T) {
1121         testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1122                 st.writeHeaders(HeadersFrameParam{
1123                         StreamID: 1,
1124                         BlockFragment: st.encodeHeaderRaw(
1125                                 ":method", "CONNECT",
1126                                 ":authority", "example.com:123",
1127                                 ":scheme", "https",
1128                         ),
1129                         EndStream:  true,
1130                         EndHeaders: true,
1131                 })
1132         })
1133 }
1134
1135 func TestServer_Ping(t *testing.T) {
1136         st := newServerTester(t, nil)
1137         defer st.Close()
1138         st.greet()
1139
1140         // Server should ignore this one, since it has ACK set.
1141         ackPingData := [8]byte{1, 2, 4, 8, 16, 32, 64, 128}
1142         if err := st.fr.WritePing(true, ackPingData); err != nil {
1143                 t.Fatal(err)
1144         }
1145
1146         // But the server should reply to this one, since ACK is false.
1147         pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
1148         if err := st.fr.WritePing(false, pingData); err != nil {
1149                 t.Fatal(err)
1150         }
1151
1152         pf := st.wantPing()
1153         if !pf.Flags.Has(FlagPingAck) {
1154                 t.Error("response ping doesn't have ACK set")
1155         }
1156         if pf.Data != pingData {
1157                 t.Errorf("response ping has data %q; want %q", pf.Data, pingData)
1158         }
1159 }
1160
1161 func TestServer_RejectsLargeFrames(t *testing.T) {
1162         if runtime.GOOS == "windows" {
1163                 t.Skip("see golang.org/issue/13434")
1164         }
1165
1166         st := newServerTester(t, nil)
1167         defer st.Close()
1168         st.greet()
1169
1170         // Write too large of a frame (too large by one byte)
1171         // We ignore the return value because it's expected that the server
1172         // will only read the first 9 bytes (the headre) and then disconnect.
1173         st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, defaultMaxReadFrameSize+1))
1174
1175         gf := st.wantGoAway()
1176         if gf.ErrCode != ErrCodeFrameSize {
1177                 t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFrameSize)
1178         }
1179         if st.serverLogBuf.Len() != 0 {
1180                 // Previously we spun here for a bit until the GOAWAY disconnect
1181                 // timer fired, logging while we fired.
1182                 t.Errorf("unexpected server output: %.500s\n", st.serverLogBuf.Bytes())
1183         }
1184 }
1185
1186 func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
1187         puppet := newHandlerPuppet()
1188         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1189                 puppet.act(w, r)
1190         })
1191         defer st.Close()
1192         defer puppet.done()
1193
1194         st.greet()
1195
1196         st.writeHeaders(HeadersFrameParam{
1197                 StreamID:      1, // clients send odd numbers
1198                 BlockFragment: st.encodeHeader(":method", "POST"),
1199                 EndStream:     false, // data coming
1200                 EndHeaders:    true,
1201         })
1202         st.writeData(1, false, []byte("abcdef"))
1203         puppet.do(readBodyHandler(t, "abc"))
1204         st.wantWindowUpdate(0, 3)
1205         st.wantWindowUpdate(1, 3)
1206
1207         puppet.do(readBodyHandler(t, "def"))
1208         st.wantWindowUpdate(0, 3)
1209         st.wantWindowUpdate(1, 3)
1210
1211         st.writeData(1, true, []byte("ghijkl")) // END_STREAM here
1212         puppet.do(readBodyHandler(t, "ghi"))
1213         puppet.do(readBodyHandler(t, "jkl"))
1214         st.wantWindowUpdate(0, 3)
1215         st.wantWindowUpdate(0, 3) // no more stream-level, since END_STREAM
1216 }
1217
1218 // the version of the TestServer_Handler_Sends_WindowUpdate with padding.
1219 // See golang.org/issue/16556
1220 func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) {
1221         puppet := newHandlerPuppet()
1222         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1223                 puppet.act(w, r)
1224         })
1225         defer st.Close()
1226         defer puppet.done()
1227
1228         st.greet()
1229
1230         st.writeHeaders(HeadersFrameParam{
1231                 StreamID:      1,
1232                 BlockFragment: st.encodeHeader(":method", "POST"),
1233                 EndStream:     false,
1234                 EndHeaders:    true,
1235         })
1236         st.writeDataPadded(1, false, []byte("abcdef"), []byte{0, 0, 0, 0})
1237
1238         // Expect to immediately get our 5 bytes of padding back for
1239         // both the connection and stream (4 bytes of padding + 1 byte of length)
1240         st.wantWindowUpdate(0, 5)
1241         st.wantWindowUpdate(1, 5)
1242
1243         puppet.do(readBodyHandler(t, "abc"))
1244         st.wantWindowUpdate(0, 3)
1245         st.wantWindowUpdate(1, 3)
1246
1247         puppet.do(readBodyHandler(t, "def"))
1248         st.wantWindowUpdate(0, 3)
1249         st.wantWindowUpdate(1, 3)
1250 }
1251
1252 func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) {
1253         st := newServerTester(t, nil)
1254         defer st.Close()
1255         st.greet()
1256         if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
1257                 t.Fatal(err)
1258         }
1259         gf := st.wantGoAway()
1260         if gf.ErrCode != ErrCodeFlowControl {
1261                 t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFlowControl)
1262         }
1263         if gf.LastStreamID != 0 {
1264                 t.Errorf("GOAWAY last stream ID = %v; want %v", gf.LastStreamID, 0)
1265         }
1266 }
1267
1268 func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
1269         inHandler := make(chan bool)
1270         blockHandler := make(chan bool)
1271         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1272                 inHandler <- true
1273                 <-blockHandler
1274         })
1275         defer st.Close()
1276         defer close(blockHandler)
1277         st.greet()
1278         st.writeHeaders(HeadersFrameParam{
1279                 StreamID:      1,
1280                 BlockFragment: st.encodeHeader(":method", "POST"),
1281                 EndStream:     false, // keep it open
1282                 EndHeaders:    true,
1283         })
1284         <-inHandler
1285         // Send a bogus window update:
1286         if err := st.fr.WriteWindowUpdate(1, 1<<31-1); err != nil {
1287                 t.Fatal(err)
1288         }
1289         st.wantRSTStream(1, ErrCodeFlowControl)
1290 }
1291
1292 // testServerPostUnblock sends a hanging POST with unsent data to handler,
1293 // then runs fn once in the handler, and verifies that the error returned from
1294 // handler is acceptable. It fails if takes over 5 seconds for handler to exit.
1295 func testServerPostUnblock(t *testing.T,
1296         handler func(http.ResponseWriter, *http.Request) error,
1297         fn func(*serverTester),
1298         checkErr func(error),
1299         otherHeaders ...string) {
1300         inHandler := make(chan bool)
1301         errc := make(chan error, 1)
1302         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1303                 inHandler <- true
1304                 errc <- handler(w, r)
1305         })
1306         defer st.Close()
1307         st.greet()
1308         st.writeHeaders(HeadersFrameParam{
1309                 StreamID:      1,
1310                 BlockFragment: st.encodeHeader(append([]string{":method", "POST"}, otherHeaders...)...),
1311                 EndStream:     false, // keep it open
1312                 EndHeaders:    true,
1313         })
1314         <-inHandler
1315         fn(st)
1316         select {
1317         case err := <-errc:
1318                 if checkErr != nil {
1319                         checkErr(err)
1320                 }
1321         case <-time.After(5 * time.Second):
1322                 t.Fatal("timeout waiting for Handler to return")
1323         }
1324 }
1325
1326 func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
1327         testServerPostUnblock(t,
1328                 func(w http.ResponseWriter, r *http.Request) (err error) {
1329                         _, err = r.Body.Read(make([]byte, 1))
1330                         return
1331                 },
1332                 func(st *serverTester) {
1333                         if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1334                                 t.Fatal(err)
1335                         }
1336                 },
1337                 func(err error) {
1338                         want := StreamError{StreamID: 0x1, Code: 0x8}
1339                         if !reflect.DeepEqual(err, want) {
1340                                 t.Errorf("Read error = %v; want %v", err, want)
1341                         }
1342                 },
1343         )
1344 }
1345
1346 func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
1347         // Run this test a bunch, because it doesn't always
1348         // deadlock. But with a bunch, it did.
1349         n := 50
1350         if testing.Short() {
1351                 n = 5
1352         }
1353         for i := 0; i < n; i++ {
1354                 testServer_RSTStream_Unblocks_Header_Write(t)
1355         }
1356 }
1357
1358 func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
1359         inHandler := make(chan bool, 1)
1360         unblockHandler := make(chan bool, 1)
1361         headerWritten := make(chan bool, 1)
1362         wroteRST := make(chan bool, 1)
1363
1364         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1365                 inHandler <- true
1366                 <-wroteRST
1367                 w.Header().Set("foo", "bar")
1368                 w.WriteHeader(200)
1369                 w.(http.Flusher).Flush()
1370                 headerWritten <- true
1371                 <-unblockHandler
1372         })
1373         defer st.Close()
1374
1375         st.greet()
1376         st.writeHeaders(HeadersFrameParam{
1377                 StreamID:      1,
1378                 BlockFragment: st.encodeHeader(":method", "POST"),
1379                 EndStream:     false, // keep it open
1380                 EndHeaders:    true,
1381         })
1382         <-inHandler
1383         if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1384                 t.Fatal(err)
1385         }
1386         wroteRST <- true
1387         st.awaitIdle()
1388         select {
1389         case <-headerWritten:
1390         case <-time.After(2 * time.Second):
1391                 t.Error("timeout waiting for header write")
1392         }
1393         unblockHandler <- true
1394 }
1395
1396 func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
1397         testServerPostUnblock(t,
1398                 func(w http.ResponseWriter, r *http.Request) (err error) {
1399                         _, err = r.Body.Read(make([]byte, 1))
1400                         return
1401                 },
1402                 func(st *serverTester) { st.cc.Close() },
1403                 func(err error) {
1404                         if err == nil {
1405                                 t.Error("unexpected nil error from Request.Body.Read")
1406                         }
1407                 },
1408         )
1409 }
1410
// blockUntilClosed is a handler body for testServerPostUnblock. It asserts
// that w implements http.CloseNotifier, waits until the CloseNotify channel
// delivers a value, and then returns nil.
var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error {
	closed := w.(http.CloseNotifier).CloseNotify()
	<-closed
	return nil
}
1415
1416 func TestServer_CloseNotify_After_RSTStream(t *testing.T) {
1417         testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
1418                 if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
1419                         t.Fatal(err)
1420                 }
1421         }, nil)
1422 }
1423
1424 func TestServer_CloseNotify_After_ConnClose(t *testing.T) {
1425         testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { st.cc.Close() }, nil)
1426 }
1427
1428 // that CloseNotify unblocks after a stream error due to the client's
1429 // problem that's unrelated to them explicitly canceling it (which is
1430 // TestServer_CloseNotify_After_RSTStream above)
1431 func TestServer_CloseNotify_After_StreamError(t *testing.T) {
1432         testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
1433                 // data longer than declared Content-Length => stream error
1434                 st.writeData(1, true, []byte("1234"))
1435         }, nil, "content-length", "3")
1436 }
1437
1438 func TestServer_StateTransitions(t *testing.T) {
1439         var st *serverTester
1440         inHandler := make(chan bool)
1441         writeData := make(chan bool)
1442         leaveHandler := make(chan bool)
1443         st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1444                 inHandler <- true
1445                 if st.stream(1) == nil {
1446                         t.Errorf("nil stream 1 in handler")
1447                 }
1448                 if got, want := st.streamState(1), stateOpen; got != want {
1449                         t.Errorf("in handler, state is %v; want %v", got, want)
1450                 }
1451                 writeData <- true
1452                 if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
1453                         t.Errorf("body read = %d, %v; want 0, EOF", n, err)
1454                 }
1455                 if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
1456                         t.Errorf("in handler, state is %v; want %v", got, want)
1457                 }
1458
1459                 <-leaveHandler
1460         })
1461         st.greet()
1462         if st.stream(1) != nil {
1463                 t.Fatal("stream 1 should be empty")
1464         }
1465         if got := st.streamState(1); got != stateIdle {
1466                 t.Fatalf("stream 1 should be idle; got %v", got)
1467         }
1468
1469         st.writeHeaders(HeadersFrameParam{
1470                 StreamID:      1,
1471                 BlockFragment: st.encodeHeader(":method", "POST"),
1472                 EndStream:     false, // keep it open
1473                 EndHeaders:    true,
1474         })
1475         <-inHandler
1476         <-writeData
1477         st.writeData(1, true, nil)
1478
1479         leaveHandler <- true
1480         hf := st.wantHeaders()
1481         if !hf.StreamEnded() {
1482                 t.Fatal("expected END_STREAM flag")
1483         }
1484
1485         if got, want := st.streamState(1), stateClosed; got != want {
1486                 t.Errorf("at end, state is %v; want %v", got, want)
1487         }
1488         if st.stream(1) != nil {
1489                 t.Fatal("at end, stream 1 should be gone")
1490         }
1491 }
1492
1493 // test HEADERS w/o EndHeaders + another HEADERS (should get rejected)
1494 func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
1495         testServerRejectsConn(t, func(st *serverTester) {
1496                 st.writeHeaders(HeadersFrameParam{
1497                         StreamID:      1,
1498                         BlockFragment: st.encodeHeader(),
1499                         EndStream:     true,
1500                         EndHeaders:    false,
1501                 })
1502                 st.writeHeaders(HeadersFrameParam{ // Not a continuation.
1503                         StreamID:      3, // different stream.
1504                         BlockFragment: st.encodeHeader(),
1505                         EndStream:     true,
1506                         EndHeaders:    true,
1507                 })
1508         })
1509 }
1510
1511 // test HEADERS w/o EndHeaders + PING (should get rejected)
1512 func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
1513         testServerRejectsConn(t, func(st *serverTester) {
1514                 st.writeHeaders(HeadersFrameParam{
1515                         StreamID:      1,
1516                         BlockFragment: st.encodeHeader(),
1517                         EndStream:     true,
1518                         EndHeaders:    false,
1519                 })
1520                 if err := st.fr.WritePing(false, [8]byte{}); err != nil {
1521                         t.Fatal(err)
1522                 }
1523         })
1524 }
1525
1526 // test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected)
1527 func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
1528         testServerRejectsConn(t, func(st *serverTester) {
1529                 st.writeHeaders(HeadersFrameParam{
1530                         StreamID:      1,
1531                         BlockFragment: st.encodeHeader(),
1532                         EndStream:     true,
1533                         EndHeaders:    true,
1534                 })
1535                 st.wantHeaders()
1536                 if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
1537                         t.Fatal(err)
1538                 }
1539         })
1540 }
1541
1542 // test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID
1543 func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
1544         testServerRejectsConn(t, func(st *serverTester) {
1545                 st.writeHeaders(HeadersFrameParam{
1546                         StreamID:      1,
1547                         BlockFragment: st.encodeHeader(),
1548                         EndStream:     true,
1549                         EndHeaders:    false,
1550                 })
1551                 if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
1552                         t.Fatal(err)
1553                 }
1554         })
1555 }
1556
1557 // No HEADERS on stream 0.
1558 func TestServer_Rejects_Headers0(t *testing.T) {
1559         testServerRejectsConn(t, func(st *serverTester) {
1560                 st.fr.AllowIllegalWrites = true
1561                 st.writeHeaders(HeadersFrameParam{
1562                         StreamID:      0,
1563                         BlockFragment: st.encodeHeader(),
1564                         EndStream:     true,
1565                         EndHeaders:    true,
1566                 })
1567         })
1568 }
1569
1570 // No CONTINUATION on stream 0.
1571 func TestServer_Rejects_Continuation0(t *testing.T) {
1572         testServerRejectsConn(t, func(st *serverTester) {
1573                 st.fr.AllowIllegalWrites = true
1574                 if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
1575                         t.Fatal(err)
1576                 }
1577         })
1578 }
1579
1580 // No PRIORITY on stream 0.
1581 func TestServer_Rejects_Priority0(t *testing.T) {
1582         testServerRejectsConn(t, func(st *serverTester) {
1583                 st.fr.AllowIllegalWrites = true
1584                 st.writePriority(0, PriorityParam{StreamDep: 1})
1585         })
1586 }
1587
1588 // No HEADERS frame with a self-dependence.
1589 func TestServer_Rejects_HeadersSelfDependence(t *testing.T) {
1590         testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1591                 st.fr.AllowIllegalWrites = true
1592                 st.writeHeaders(HeadersFrameParam{
1593                         StreamID:      1,
1594                         BlockFragment: st.encodeHeader(),
1595                         EndStream:     true,
1596                         EndHeaders:    true,
1597                         Priority:      PriorityParam{StreamDep: 1},
1598                 })
1599         })
1600 }
1601
1602 // No PRIORTY frame with a self-dependence.
1603 func TestServer_Rejects_PrioritySelfDependence(t *testing.T) {
1604         testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1605                 st.fr.AllowIllegalWrites = true
1606                 st.writePriority(1, PriorityParam{StreamDep: 1})
1607         })
1608 }
1609
1610 func TestServer_Rejects_PushPromise(t *testing.T) {
1611         testServerRejectsConn(t, func(st *serverTester) {
1612                 pp := PushPromiseParam{
1613                         StreamID:  1,
1614                         PromiseID: 3,
1615                 }
1616                 if err := st.fr.WritePushPromise(pp); err != nil {
1617                         t.Fatal(err)
1618                 }
1619         })
1620 }
1621
1622 // testServerRejectsConn tests that the server hangs up with a GOAWAY
1623 // frame and a server close after the client does something
1624 // deserving a CONNECTION_ERROR.
1625 func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) {
1626         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1627         st.addLogFilter("connection error: PROTOCOL_ERROR")
1628         defer st.Close()
1629         st.greet()
1630         writeReq(st)
1631
1632         st.wantGoAway()
1633         errc := make(chan error, 1)
1634         go func() {
1635                 fr, err := st.fr.ReadFrame()
1636                 if err == nil {
1637                         err = fmt.Errorf("got frame of type %T", fr)
1638                 }
1639                 errc <- err
1640         }()
1641         select {
1642         case err := <-errc:
1643                 if err != io.EOF {
1644                         t.Errorf("ReadFrame = %v; want io.EOF", err)
1645                 }
1646         case <-time.After(2 * time.Second):
1647                 t.Error("timeout waiting for disconnect")
1648         }
1649 }
1650
1651 // testServerRejectsStream tests that the server sends a RST_STREAM with the provided
1652 // error code after a client sends a bogus request.
1653 func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) {
1654         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1655         defer st.Close()
1656         st.greet()
1657         writeReq(st)
1658         st.wantRSTStream(1, code)
1659 }
1660
1661 // testServerRequest sets up an idle HTTP/2 connection and lets you
1662 // write a single request with writeReq, and then verify that the
1663 // *http.Request is built correctly in checkReq.
1664 func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) {
1665         gotReq := make(chan bool, 1)
1666         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1667                 if r.Body == nil {
1668                         t.Fatal("nil Body")
1669                 }
1670                 checkReq(r)
1671                 gotReq <- true
1672         })
1673         defer st.Close()
1674
1675         st.greet()
1676         writeReq(st)
1677
1678         select {
1679         case <-gotReq:
1680         case <-time.After(2 * time.Second):
1681                 t.Error("timeout waiting for request")
1682         }
1683 }
1684
1685 func getSlash(st *serverTester) { st.bodylessReq1() }
1686
1687 func TestServer_Response_NoData(t *testing.T) {
1688         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1689                 // Nothing.
1690                 return nil
1691         }, func(st *serverTester) {
1692                 getSlash(st)
1693                 hf := st.wantHeaders()
1694                 if !hf.StreamEnded() {
1695                         t.Fatal("want END_STREAM flag")
1696                 }
1697                 if !hf.HeadersEnded() {
1698                         t.Fatal("want END_HEADERS flag")
1699                 }
1700         })
1701 }
1702
1703 func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
1704         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1705                 w.Header().Set("Foo-Bar", "some-value")
1706                 return nil
1707         }, func(st *serverTester) {
1708                 getSlash(st)
1709                 hf := st.wantHeaders()
1710                 if !hf.StreamEnded() {
1711                         t.Fatal("want END_STREAM flag")
1712                 }
1713                 if !hf.HeadersEnded() {
1714                         t.Fatal("want END_HEADERS flag")
1715                 }
1716                 goth := st.decodeHeader(hf.HeaderBlockFragment())
1717                 wanth := [][2]string{
1718                         {":status", "200"},
1719                         {"foo-bar", "some-value"},
1720                         {"content-type", "text/plain; charset=utf-8"},
1721                         {"content-length", "0"},
1722                 }
1723                 if !reflect.DeepEqual(goth, wanth) {
1724                         t.Errorf("Got headers %v; want %v", goth, wanth)
1725                 }
1726         })
1727 }
1728
1729 func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
1730         const msg = "<html>this is HTML."
1731         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1732                 w.Header().Set("Content-Type", "foo/bar")
1733                 io.WriteString(w, msg)
1734                 return nil
1735         }, func(st *serverTester) {
1736                 getSlash(st)
1737                 hf := st.wantHeaders()
1738                 if hf.StreamEnded() {
1739                         t.Fatal("don't want END_STREAM, expecting data")
1740                 }
1741                 if !hf.HeadersEnded() {
1742                         t.Fatal("want END_HEADERS flag")
1743                 }
1744                 goth := st.decodeHeader(hf.HeaderBlockFragment())
1745                 wanth := [][2]string{
1746                         {":status", "200"},
1747                         {"content-type", "foo/bar"},
1748                         {"content-length", strconv.Itoa(len(msg))},
1749                 }
1750                 if !reflect.DeepEqual(goth, wanth) {
1751                         t.Errorf("Got headers %v; want %v", goth, wanth)
1752                 }
1753                 df := st.wantData()
1754                 if !df.StreamEnded() {
1755                         t.Error("expected DATA to have END_STREAM flag")
1756                 }
1757                 if got := string(df.Data()); got != msg {
1758                         t.Errorf("got DATA %q; want %q", got, msg)
1759                 }
1760         })
1761 }
1762
1763 func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
1764         const msg = "hi"
1765         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1766                 w.Header().Set("Transfer-Encoding", "chunked") // should be stripped
1767                 io.WriteString(w, msg)
1768                 return nil
1769         }, func(st *serverTester) {
1770                 getSlash(st)
1771                 hf := st.wantHeaders()
1772                 goth := st.decodeHeader(hf.HeaderBlockFragment())
1773                 wanth := [][2]string{
1774                         {":status", "200"},
1775                         {"content-type", "text/plain; charset=utf-8"},
1776                         {"content-length", strconv.Itoa(len(msg))},
1777                 }
1778                 if !reflect.DeepEqual(goth, wanth) {
1779                         t.Errorf("Got headers %v; want %v", goth, wanth)
1780                 }
1781         })
1782 }
1783
1784 // Header accessed only after the initial write.
1785 func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) {
1786         const msg = "<html>this is HTML."
1787         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1788                 io.WriteString(w, msg)
1789                 w.Header().Set("foo", "should be ignored")
1790                 return nil
1791         }, func(st *serverTester) {
1792                 getSlash(st)
1793                 hf := st.wantHeaders()
1794                 if hf.StreamEnded() {
1795                         t.Fatal("unexpected END_STREAM")
1796                 }
1797                 if !hf.HeadersEnded() {
1798                         t.Fatal("want END_HEADERS flag")
1799                 }
1800                 goth := st.decodeHeader(hf.HeaderBlockFragment())
1801                 wanth := [][2]string{
1802                         {":status", "200"},
1803                         {"content-type", "text/html; charset=utf-8"},
1804                         {"content-length", strconv.Itoa(len(msg))},
1805                 }
1806                 if !reflect.DeepEqual(goth, wanth) {
1807                         t.Errorf("Got headers %v; want %v", goth, wanth)
1808                 }
1809         })
1810 }
1811
1812 // Header accessed before the initial write and later mutated.
1813 func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) {
1814         const msg = "<html>this is HTML."
1815         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1816                 w.Header().Set("foo", "proper value")
1817                 io.WriteString(w, msg)
1818                 w.Header().Set("foo", "should be ignored")
1819                 return nil
1820         }, func(st *serverTester) {
1821                 getSlash(st)
1822                 hf := st.wantHeaders()
1823                 if hf.StreamEnded() {
1824                         t.Fatal("unexpected END_STREAM")
1825                 }
1826                 if !hf.HeadersEnded() {
1827                         t.Fatal("want END_HEADERS flag")
1828                 }
1829                 goth := st.decodeHeader(hf.HeaderBlockFragment())
1830                 wanth := [][2]string{
1831                         {":status", "200"},
1832                         {"foo", "proper value"},
1833                         {"content-type", "text/html; charset=utf-8"},
1834                         {"content-length", strconv.Itoa(len(msg))},
1835                 }
1836                 if !reflect.DeepEqual(goth, wanth) {
1837                         t.Errorf("Got headers %v; want %v", goth, wanth)
1838                 }
1839         })
1840 }
1841
1842 func TestServer_Response_Data_SniffLenType(t *testing.T) {
1843         const msg = "<html>this is HTML."
1844         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1845                 io.WriteString(w, msg)
1846                 return nil
1847         }, func(st *serverTester) {
1848                 getSlash(st)
1849                 hf := st.wantHeaders()
1850                 if hf.StreamEnded() {
1851                         t.Fatal("don't want END_STREAM, expecting data")
1852                 }
1853                 if !hf.HeadersEnded() {
1854                         t.Fatal("want END_HEADERS flag")
1855                 }
1856                 goth := st.decodeHeader(hf.HeaderBlockFragment())
1857                 wanth := [][2]string{
1858                         {":status", "200"},
1859                         {"content-type", "text/html; charset=utf-8"},
1860                         {"content-length", strconv.Itoa(len(msg))},
1861                 }
1862                 if !reflect.DeepEqual(goth, wanth) {
1863                         t.Errorf("Got headers %v; want %v", goth, wanth)
1864                 }
1865                 df := st.wantData()
1866                 if !df.StreamEnded() {
1867                         t.Error("expected DATA to have END_STREAM flag")
1868                 }
1869                 if got := string(df.Data()); got != msg {
1870                         t.Errorf("got DATA %q; want %q", got, msg)
1871                 }
1872         })
1873 }
1874
1875 func TestServer_Response_Header_Flush_MidWrite(t *testing.T) {
1876         const msg = "<html>this is HTML"
1877         const msg2 = ", and this is the next chunk"
1878         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1879                 io.WriteString(w, msg)
1880                 w.(http.Flusher).Flush()
1881                 io.WriteString(w, msg2)
1882                 return nil
1883         }, func(st *serverTester) {
1884                 getSlash(st)
1885                 hf := st.wantHeaders()
1886                 if hf.StreamEnded() {
1887                         t.Fatal("unexpected END_STREAM flag")
1888                 }
1889                 if !hf.HeadersEnded() {
1890                         t.Fatal("want END_HEADERS flag")
1891                 }
1892                 goth := st.decodeHeader(hf.HeaderBlockFragment())
1893                 wanth := [][2]string{
1894                         {":status", "200"},
1895                         {"content-type", "text/html; charset=utf-8"}, // sniffed
1896                         // and no content-length
1897                 }
1898                 if !reflect.DeepEqual(goth, wanth) {
1899                         t.Errorf("Got headers %v; want %v", goth, wanth)
1900                 }
1901                 {
1902                         df := st.wantData()
1903                         if df.StreamEnded() {
1904                                 t.Error("unexpected END_STREAM flag")
1905                         }
1906                         if got := string(df.Data()); got != msg {
1907                                 t.Errorf("got DATA %q; want %q", got, msg)
1908                         }
1909                 }
1910                 {
1911                         df := st.wantData()
1912                         if !df.StreamEnded() {
1913                                 t.Error("wanted END_STREAM flag on last data chunk")
1914                         }
1915                         if got := string(df.Data()); got != msg2 {
1916                                 t.Errorf("got DATA %q; want %q", got, msg2)
1917                         }
1918                 }
1919         })
1920 }
1921
1922 func TestServer_Response_LargeWrite(t *testing.T) {
1923         const size = 1 << 20
1924         const maxFrameSize = 16 << 10
1925         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1926                 n, err := w.Write(bytes.Repeat([]byte("a"), size))
1927                 if err != nil {
1928                         return fmt.Errorf("Write error: %v", err)
1929                 }
1930                 if n != size {
1931                         return fmt.Errorf("wrong size %d from Write", n)
1932                 }
1933                 return nil
1934         }, func(st *serverTester) {
1935                 if err := st.fr.WriteSettings(
1936                         Setting{SettingInitialWindowSize, 0},
1937                         Setting{SettingMaxFrameSize, maxFrameSize},
1938                 ); err != nil {
1939                         t.Fatal(err)
1940                 }
1941                 st.wantSettingsAck()
1942
1943                 getSlash(st) // make the single request
1944
1945                 // Give the handler quota to write:
1946                 if err := st.fr.WriteWindowUpdate(1, size); err != nil {
1947                         t.Fatal(err)
1948                 }
1949                 // Give the handler quota to write to connection-level
1950                 // window as well
1951                 if err := st.fr.WriteWindowUpdate(0, size); err != nil {
1952                         t.Fatal(err)
1953                 }
1954                 hf := st.wantHeaders()
1955                 if hf.StreamEnded() {
1956                         t.Fatal("unexpected END_STREAM flag")
1957                 }
1958                 if !hf.HeadersEnded() {
1959                         t.Fatal("want END_HEADERS flag")
1960                 }
1961                 goth := st.decodeHeader(hf.HeaderBlockFragment())
1962                 wanth := [][2]string{
1963                         {":status", "200"},
1964                         {"content-type", "text/plain; charset=utf-8"}, // sniffed
1965                         // and no content-length
1966                 }
1967                 if !reflect.DeepEqual(goth, wanth) {
1968                         t.Errorf("Got headers %v; want %v", goth, wanth)
1969                 }
1970                 var bytes, frames int
1971                 for {
1972                         df := st.wantData()
1973                         bytes += len(df.Data())
1974                         frames++
1975                         for _, b := range df.Data() {
1976                                 if b != 'a' {
1977                                         t.Fatal("non-'a' byte seen in DATA")
1978                                 }
1979                         }
1980                         if df.StreamEnded() {
1981                                 break
1982                         }
1983                 }
1984                 if bytes != size {
1985                         t.Errorf("Got %d bytes; want %d", bytes, size)
1986                 }
1987                 if want := int(size / maxFrameSize); frames < want || frames > want*2 {
1988                         t.Errorf("Got %d frames; want %d", frames, size)
1989                 }
1990         })
1991 }
1992
1993 // Test that the handler can't write more than the client allows
1994 func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) {
1995         // Make these reads. Before each read, the client adds exactly enough
1996         // flow-control to satisfy the read. Numbers chosen arbitrarily.
1997         reads := []int{123, 1, 13, 127}
1998         size := 0
1999         for _, n := range reads {
2000                 size += n
2001         }
2002
2003         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2004                 w.(http.Flusher).Flush()
2005                 n, err := w.Write(bytes.Repeat([]byte("a"), size))
2006                 if err != nil {
2007                         return fmt.Errorf("Write error: %v", err)
2008                 }
2009                 if n != size {
2010                         return fmt.Errorf("wrong size %d from Write", n)
2011                 }
2012                 return nil
2013         }, func(st *serverTester) {
2014                 // Set the window size to something explicit for this test.
2015                 // It's also how much initial data we expect.
2016                 if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, uint32(reads[0])}); err != nil {
2017                         t.Fatal(err)
2018                 }
2019                 st.wantSettingsAck()
2020
2021                 getSlash(st) // make the single request
2022
2023                 hf := st.wantHeaders()
2024                 if hf.StreamEnded() {
2025                         t.Fatal("unexpected END_STREAM flag")
2026                 }
2027                 if !hf.HeadersEnded() {
2028                         t.Fatal("want END_HEADERS flag")
2029                 }
2030
2031                 df := st.wantData()
2032                 if got := len(df.Data()); got != reads[0] {
2033                         t.Fatalf("Initial window size = %d but got DATA with %d bytes", reads[0], got)
2034                 }
2035
2036                 for _, quota := range reads[1:] {
2037                         if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil {
2038                                 t.Fatal(err)
2039                         }
2040                         df := st.wantData()
2041                         if int(quota) != len(df.Data()) {
2042                                 t.Fatalf("read %d bytes after giving %d quota", len(df.Data()), quota)
2043                         }
2044                 }
2045         })
2046 }
2047
2048 // Test that the handler blocked in a Write is unblocked if the server sends a RST_STREAM.
2049 func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) {
2050         const size = 1 << 20
2051         const maxFrameSize = 16 << 10
2052         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2053                 w.(http.Flusher).Flush()
2054                 errc := make(chan error, 1)
2055                 go func() {
2056                         _, err := w.Write(bytes.Repeat([]byte("a"), size))
2057                         errc <- err
2058                 }()
2059                 select {
2060                 case err := <-errc:
2061                         if err == nil {
2062                                 return errors.New("unexpected nil error from Write in handler")
2063                         }
2064                         return nil
2065                 case <-time.After(2 * time.Second):
2066                         return errors.New("timeout waiting for Write in handler")
2067                 }
2068         }, func(st *serverTester) {
2069                 if err := st.fr.WriteSettings(
2070                         Setting{SettingInitialWindowSize, 0},
2071                         Setting{SettingMaxFrameSize, maxFrameSize},
2072                 ); err != nil {
2073                         t.Fatal(err)
2074                 }
2075                 st.wantSettingsAck()
2076
2077                 getSlash(st) // make the single request
2078
2079                 hf := st.wantHeaders()
2080                 if hf.StreamEnded() {
2081                         t.Fatal("unexpected END_STREAM flag")
2082                 }
2083                 if !hf.HeadersEnded() {
2084                         t.Fatal("want END_HEADERS flag")
2085                 }
2086
2087                 if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
2088                         t.Fatal(err)
2089                 }
2090         })
2091 }
2092
2093 func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) {
2094         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2095                 w.(http.Flusher).Flush()
2096                 // Nothing; send empty DATA
2097                 return nil
2098         }, func(st *serverTester) {
2099                 // Handler gets no data quota:
2100                 if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, 0}); err != nil {
2101                         t.Fatal(err)
2102                 }
2103                 st.wantSettingsAck()
2104
2105                 getSlash(st) // make the single request
2106
2107                 hf := st.wantHeaders()
2108                 if hf.StreamEnded() {
2109                         t.Fatal("unexpected END_STREAM flag")
2110                 }
2111                 if !hf.HeadersEnded() {
2112                         t.Fatal("want END_HEADERS flag")
2113                 }
2114
2115                 df := st.wantData()
2116                 if got := len(df.Data()); got != 0 {
2117                         t.Fatalf("unexpected %d DATA bytes; want 0", got)
2118                 }
2119                 if !df.StreamEnded() {
2120                         t.Fatal("DATA didn't have END_STREAM")
2121                 }
2122         })
2123 }
2124
2125 func TestServer_Response_Automatic100Continue(t *testing.T) {
2126         const msg = "foo"
2127         const reply = "bar"
2128         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2129                 if v := r.Header.Get("Expect"); v != "" {
2130                         t.Errorf("Expect header = %q; want empty", v)
2131                 }
2132                 buf := make([]byte, len(msg))
2133                 // This read should trigger the 100-continue being sent.
2134                 if n, err := io.ReadFull(r.Body, buf); err != nil || n != len(msg) || string(buf) != msg {
2135                         return fmt.Errorf("ReadFull = %q, %v; want %q, nil", buf[:n], err, msg)
2136                 }
2137                 _, err := io.WriteString(w, reply)
2138                 return err
2139         }, func(st *serverTester) {
2140                 st.writeHeaders(HeadersFrameParam{
2141                         StreamID:      1, // clients send odd numbers
2142                         BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-continue"),
2143                         EndStream:     false,
2144                         EndHeaders:    true,
2145                 })
2146                 hf := st.wantHeaders()
2147                 if hf.StreamEnded() {
2148                         t.Fatal("unexpected END_STREAM flag")
2149                 }
2150                 if !hf.HeadersEnded() {
2151                         t.Fatal("want END_HEADERS flag")
2152                 }
2153                 goth := st.decodeHeader(hf.HeaderBlockFragment())
2154                 wanth := [][2]string{
2155                         {":status", "100"},
2156                 }
2157                 if !reflect.DeepEqual(goth, wanth) {
2158                         t.Fatalf("Got headers %v; want %v", goth, wanth)
2159                 }
2160
2161                 // Okay, they sent status 100, so we can send our
2162                 // gigantic and/or sensitive "foo" payload now.
2163                 st.writeData(1, true, []byte(msg))
2164
2165                 st.wantWindowUpdate(0, uint32(len(msg)))
2166
2167                 hf = st.wantHeaders()
2168                 if hf.StreamEnded() {
2169                         t.Fatal("expected data to follow")
2170                 }
2171                 if !hf.HeadersEnded() {
2172                         t.Fatal("want END_HEADERS flag")
2173                 }
2174                 goth = st.decodeHeader(hf.HeaderBlockFragment())
2175                 wanth = [][2]string{
2176                         {":status", "200"},
2177                         {"content-type", "text/plain; charset=utf-8"},
2178                         {"content-length", strconv.Itoa(len(reply))},
2179                 }
2180                 if !reflect.DeepEqual(goth, wanth) {
2181                         t.Errorf("Got headers %v; want %v", goth, wanth)
2182                 }
2183
2184                 df := st.wantData()
2185                 if string(df.Data()) != reply {
2186                         t.Errorf("Client read %q; want %q", df.Data(), reply)
2187                 }
2188                 if !df.StreamEnded() {
2189                         t.Errorf("expect data stream end")
2190                 }
2191         })
2192 }
2193
2194 func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
2195         errc := make(chan error, 1)
2196         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2197                 p := []byte("some data.\n")
2198                 for {
2199                         _, err := w.Write(p)
2200                         if err != nil {
2201                                 errc <- err
2202                                 return nil
2203                         }
2204                 }
2205         }, func(st *serverTester) {
2206                 st.writeHeaders(HeadersFrameParam{
2207                         StreamID:      1,
2208                         BlockFragment: st.encodeHeader(),
2209                         EndStream:     false,
2210                         EndHeaders:    true,
2211                 })
2212                 hf := st.wantHeaders()
2213                 if hf.StreamEnded() {
2214                         t.Fatal("unexpected END_STREAM flag")
2215                 }
2216                 if !hf.HeadersEnded() {
2217                         t.Fatal("want END_HEADERS flag")
2218                 }
2219                 // Close the connection and wait for the handler to (hopefully) notice.
2220                 st.cc.Close()
2221                 select {
2222                 case <-errc:
2223                 case <-time.After(5 * time.Second):
2224                         t.Error("timeout")
2225                 }
2226         })
2227 }
2228
2229 func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
2230         const testPath = "/some/path"
2231
2232         inHandler := make(chan uint32)
2233         leaveHandler := make(chan bool)
2234         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2235                 id := w.(*responseWriter).rws.stream.id
2236                 inHandler <- id
2237                 if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath {
2238                         t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath)
2239                 }
2240                 <-leaveHandler
2241         })
2242         defer st.Close()
2243         st.greet()
2244         nextStreamID := uint32(1)
2245         streamID := func() uint32 {
2246                 defer func() { nextStreamID += 2 }()
2247                 return nextStreamID
2248         }
2249         sendReq := func(id uint32, headers ...string) {
2250                 st.writeHeaders(HeadersFrameParam{
2251                         StreamID:      id,
2252                         BlockFragment: st.encodeHeader(headers...),
2253                         EndStream:     true,
2254                         EndHeaders:    true,
2255                 })
2256         }
2257         for i := 0; i < defaultMaxStreams; i++ {
2258                 sendReq(streamID())
2259                 <-inHandler
2260         }
2261         defer func() {
2262                 for i := 0; i < defaultMaxStreams; i++ {
2263                         leaveHandler <- true
2264                 }
2265         }()
2266
2267         // And this one should cross the limit:
2268         // (It's also sent as a CONTINUATION, to verify we still track the decoder context,
2269         // even if we're rejecting it)
2270         rejectID := streamID()
2271         headerBlock := st.encodeHeader(":path", testPath)
2272         frag1, frag2 := headerBlock[:3], headerBlock[3:]
2273         st.writeHeaders(HeadersFrameParam{
2274                 StreamID:      rejectID,
2275                 BlockFragment: frag1,
2276                 EndStream:     true,
2277                 EndHeaders:    false, // CONTINUATION coming
2278         })
2279         if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil {
2280                 t.Fatal(err)
2281         }
2282         st.wantRSTStream(rejectID, ErrCodeProtocol)
2283
2284         // But let a handler finish:
2285         leaveHandler <- true
2286         st.wantHeaders()
2287
2288         // And now another stream should be able to start:
2289         goodID := streamID()
2290         sendReq(goodID, ":path", testPath)
2291         select {
2292         case got := <-inHandler:
2293                 if got != goodID {
2294                         t.Errorf("Got stream %d; want %d", got, goodID)
2295                 }
2296         case <-time.After(3 * time.Second):
2297                 t.Error("timeout waiting for handler")
2298         }
2299 }
2300
2301 // So many response headers that the server needs to use CONTINUATION frames:
2302 func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) {
2303         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2304                 h := w.Header()
2305                 for i := 0; i < 5000; i++ {
2306                         h.Set(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("x-value-%d", i))
2307                 }
2308                 return nil
2309         }, func(st *serverTester) {
2310                 getSlash(st)
2311                 hf := st.wantHeaders()
2312                 if hf.HeadersEnded() {
2313                         t.Fatal("got unwanted END_HEADERS flag")
2314                 }
2315                 n := 0
2316                 for {
2317                         n++
2318                         cf := st.wantContinuation()
2319                         if cf.HeadersEnded() {
2320                                 break
2321                         }
2322                 }
2323                 if n < 5 {
2324                         t.Errorf("Only got %d CONTINUATION frames; expected 5+ (currently 6)", n)
2325                 }
2326         })
2327 }
2328
2329 // This previously crashed (reported by Mathieu Lonjaret as observed
2330 // while using Camlistore) because we got a DATA frame from the client
2331 // after the handler exited and our logic at the time was wrong,
2332 // keeping a stream in the map in stateClosed, which tickled an
2333 // invariant check later when we tried to remove that stream (via
2334 // defer sc.closeAllStreamsOnConnClose) when the serverConn serve loop
2335 // ended.
2336 func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) {
2337         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2338                 // nothing
2339                 return nil
2340         }, func(st *serverTester) {
2341                 st.writeHeaders(HeadersFrameParam{
2342                         StreamID:      1,
2343                         BlockFragment: st.encodeHeader(),
2344                         EndStream:     false, // DATA is coming
2345                         EndHeaders:    true,
2346                 })
2347                 hf := st.wantHeaders()
2348                 if !hf.HeadersEnded() || !hf.StreamEnded() {
2349                         t.Fatalf("want END_HEADERS+END_STREAM, got %v", hf)
2350                 }
2351
2352                 // Sent when the a Handler closes while a client has
2353                 // indicated it's still sending DATA:
2354                 st.wantRSTStream(1, ErrCodeNo)
2355
2356                 // Now the handler has ended, so it's ended its
2357                 // stream, but the client hasn't closed its side
2358                 // (stateClosedLocal).  So send more data and verify
2359                 // it doesn't crash with an internal invariant panic, like
2360                 // it did before.
2361                 st.writeData(1, true, []byte("foo"))
2362
2363                 // Get our flow control bytes back, since the handler didn't get them.
2364                 st.wantWindowUpdate(0, uint32(len("foo")))
2365
2366                 // Sent after a peer sends data anyway (admittedly the
2367                 // previous RST_STREAM might've still been in-flight),
2368                 // but they'll get the more friendly 'cancel' code
2369                 // first.
2370                 st.wantRSTStream(1, ErrCodeStreamClosed)
2371
2372                 // Set up a bunch of machinery to record the panic we saw
2373                 // previously.
2374                 var (
2375                         panMu    sync.Mutex
2376                         panicVal interface{}
2377                 )
2378
2379                 testHookOnPanicMu.Lock()
2380                 testHookOnPanic = func(sc *serverConn, pv interface{}) bool {
2381                         panMu.Lock()
2382                         panicVal = pv
2383                         panMu.Unlock()
2384                         return true
2385                 }
2386                 testHookOnPanicMu.Unlock()
2387
2388                 // Now force the serve loop to end, via closing the connection.
2389                 st.cc.Close()
2390                 select {
2391                 case <-st.sc.doneServing:
2392                         // Loop has exited.
2393                         panMu.Lock()
2394                         got := panicVal
2395                         panMu.Unlock()
2396                         if got != nil {
2397                                 t.Errorf("Got panic: %v", got)
2398                         }
2399                 case <-time.After(5 * time.Second):
2400                         t.Error("timeout")
2401                 }
2402         })
2403 }
2404
2405 func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) }
2406 func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) }
2407
2408 func testRejectTLS(t *testing.T, max uint16) {
2409         st := newServerTester(t, nil, func(c *tls.Config) {
2410                 c.MaxVersion = max
2411         })
2412         defer st.Close()
2413         gf := st.wantGoAway()
2414         if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
2415                 t.Errorf("Got error code %v; want %v", got, want)
2416         }
2417 }
2418
2419 func TestServer_Rejects_TLSBadCipher(t *testing.T) {
2420         st := newServerTester(t, nil, func(c *tls.Config) {
2421                 // Only list bad ones:
2422                 c.CipherSuites = []uint16{
2423                         tls.TLS_RSA_WITH_RC4_128_SHA,
2424                         tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
2425                         tls.TLS_RSA_WITH_AES_128_CBC_SHA,
2426                         tls.TLS_RSA_WITH_AES_256_CBC_SHA,
2427                         tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
2428                         tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
2429                         tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
2430                         tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
2431                         tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
2432                         tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
2433                         tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
2434                         cipher_TLS_RSA_WITH_AES_128_CBC_SHA256,
2435                 }
2436         })
2437         defer st.Close()
2438         gf := st.wantGoAway()
2439         if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want {
2440                 t.Errorf("Got error code %v; want %v", got, want)
2441         }
2442 }
2443
2444 func TestServer_Advertises_Common_Cipher(t *testing.T) {
2445         const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
2446         st := newServerTester(t, nil, func(c *tls.Config) {
2447                 // Have the client only support the one required by the spec.
2448                 c.CipherSuites = []uint16{requiredSuite}
2449         }, func(ts *httptest.Server) {
2450                 var srv *http.Server = ts.Config
2451                 // Have the server configured with no specific cipher suites.
2452                 // This tests that Go's defaults include the required one.
2453                 srv.TLSConfig = nil
2454         })
2455         defer st.Close()
2456         st.greet()
2457 }
2458
2459 func (st *serverTester) onHeaderField(f hpack.HeaderField) {
2460         if f.Name == "date" {
2461                 return
2462         }
2463         st.decodedHeaders = append(st.decodedHeaders, [2]string{f.Name, f.Value})
2464 }
2465
2466 func (st *serverTester) decodeHeader(headerBlock []byte) (pairs [][2]string) {
2467         st.decodedHeaders = nil
2468         if _, err := st.hpackDec.Write(headerBlock); err != nil {
2469                 st.t.Fatalf("hpack decoding error: %v", err)
2470         }
2471         if err := st.hpackDec.Close(); err != nil {
2472                 st.t.Fatalf("hpack decoding error: %v", err)
2473         }
2474         return st.decodedHeaders
2475 }
2476
2477 // testServerResponse sets up an idle HTTP/2 connection. The client function should
2478 // write a single request that must be handled by the handler. This waits up to 5s
2479 // for client to return, then up to an additional 2s for the handler to return.
2480 func testServerResponse(t testing.TB,
2481         handler func(http.ResponseWriter, *http.Request) error,
2482         client func(*serverTester),
2483 ) {
2484         errc := make(chan error, 1)
2485         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2486                 if r.Body == nil {
2487                         t.Fatal("nil Body")
2488                 }
2489                 errc <- handler(w, r)
2490         })
2491         defer st.Close()
2492
2493         donec := make(chan bool)
2494         go func() {
2495                 defer close(donec)
2496                 st.greet()
2497                 client(st)
2498         }()
2499
2500         select {
2501         case <-donec:
2502         case <-time.After(5 * time.Second):
2503                 t.Fatal("timeout in client")
2504         }
2505
2506         select {
2507         case err := <-errc:
2508                 if err != nil {
2509                         t.Fatalf("Error in handler: %v", err)
2510                 }
2511         case <-time.After(2 * time.Second):
2512                 t.Fatal("timeout in handler")
2513         }
2514 }
2515
// readBodyHandler returns an http handler func that reads exactly len(want)
// bytes from r.Body. It reports an error on t if that read fails or if the
// bytes read differ from want. Any body bytes beyond len(want) are not read.
func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *http.Request) {
	n := len(want)
	return func(_ http.ResponseWriter, r *http.Request) {
		got := make([]byte, n)
		if _, err := io.ReadFull(r.Body, got); err != nil {
			t.Error(err)
			return
		}
		if string(got) != want {
			t.Errorf("read %q; want %q", got, want)
		}
	}
}
2532
2533 // TestServerWithCurl currently fails, hence the LenientCipherSuites test. See:
2534 //   https://github.com/tatsuhiro-t/nghttp2/issues/140 &
2535 //   http://sourceforge.net/p/curl/bugs/1472/
2536 func TestServerWithCurl(t *testing.T)                     { testServerWithCurl(t, false) }
2537 func TestServerWithCurl_LenientCipherSuites(t *testing.T) { testServerWithCurl(t, true) }
2538
2539 func testServerWithCurl(t *testing.T, permitProhibitedCipherSuites bool) {
2540         if runtime.GOOS != "linux" {
2541                 t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
2542         }
2543         if testing.Short() {
2544                 t.Skip("skipping curl test in short mode")
2545         }
2546         requireCurl(t)
2547         var gotConn int32
2548         testHookOnConn = func() { atomic.StoreInt32(&gotConn, 1) }
2549
2550         const msg = "Hello from curl!\n"
2551         ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2552                 w.Header().Set("Foo", "Bar")
2553                 w.Header().Set("Client-Proto", r.Proto)
2554                 io.WriteString(w, msg)
2555         }))
2556         ConfigureServer(ts.Config, &Server{
2557                 PermitProhibitedCipherSuites: permitProhibitedCipherSuites,
2558         })
2559         ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config
2560         ts.StartTLS()
2561         defer ts.Close()
2562
2563         t.Logf("Running test server for curl to hit at: %s", ts.URL)
2564         container := curl(t, "--silent", "--http2", "--insecure", "-v", ts.URL)
2565         defer kill(container)
2566         resc := make(chan interface{}, 1)
2567         go func() {
2568                 res, err := dockerLogs(container)
2569                 if err != nil {
2570                         resc <- err
2571                 } else {
2572                         resc <- res
2573                 }
2574         }()
2575         select {
2576         case res := <-resc:
2577                 if err, ok := res.(error); ok {
2578                         t.Fatal(err)
2579                 }
2580                 body := string(res.([]byte))
2581                 // Search for both "key: value" and "key:value", since curl changed their format
2582                 // Our Dockerfile contains the latest version (no space), but just in case people
2583                 // didn't rebuild, check both.
2584                 if !strings.Contains(body, "foo: Bar") && !strings.Contains(body, "foo:Bar") {
2585                         t.Errorf("didn't see foo: Bar header")
2586                         t.Logf("Got: %s", body)
2587                 }
2588                 if !strings.Contains(body, "client-proto: HTTP/2") && !strings.Contains(body, "client-proto:HTTP/2") {
2589                         t.Errorf("didn't see client-proto: HTTP/2 header")
2590                         t.Logf("Got: %s", res)
2591                 }
2592                 if !strings.Contains(string(res.([]byte)), msg) {
2593                         t.Errorf("didn't see %q content", msg)
2594                         t.Logf("Got: %s", res)
2595                 }
2596         case <-time.After(3 * time.Second):
2597                 t.Errorf("timeout waiting for curl")
2598         }
2599
2600         if atomic.LoadInt32(&gotConn) == 0 {
2601                 t.Error("never saw an http2 connection")
2602         }
2603 }
2604
var (
	// doh2load is set by the -h2load command-line flag; TestServerWithH2Load
	// is skipped unless it is true.
	doh2load = flag.Bool("h2load", false, "Run h2load test")
)
2606
2607 func TestServerWithH2Load(t *testing.T) {
2608         if !*doh2load {
2609                 t.Skip("Skipping without --h2load flag.")
2610         }
2611         if runtime.GOOS != "linux" {
2612                 t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway")
2613         }
2614         requireH2load(t)
2615
2616         msg := strings.Repeat("Hello, h2load!\n", 5000)
2617         ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2618                 io.WriteString(w, msg)
2619                 w.(http.Flusher).Flush()
2620                 io.WriteString(w, msg)
2621         }))
2622         ts.StartTLS()
2623         defer ts.Close()
2624
2625         cmd := exec.Command("docker", "run", "--net=host", "--entrypoint=/usr/local/bin/h2load", "gohttp2/curl",
2626                 "-n100000", "-c100", "-m100", ts.URL)
2627         cmd.Stdout = os.Stdout
2628         cmd.Stderr = os.Stderr
2629         if err := cmd.Run(); err != nil {
2630                 t.Fatal(err)
2631         }
2632 }
2633
2634 // Issue 12843
2635 func TestServerDoS_MaxHeaderListSize(t *testing.T) {
2636         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
2637         defer st.Close()
2638
2639         // shake hands
2640         frameSize := defaultMaxReadFrameSize
2641         var advHeaderListSize *uint32
2642         st.greetAndCheckSettings(func(s Setting) error {
2643                 switch s.ID {
2644                 case SettingMaxFrameSize:
2645                         if s.Val < minMaxFrameSize {
2646                                 frameSize = minMaxFrameSize
2647                         } else if s.Val > maxFrameSize {
2648                                 frameSize = maxFrameSize
2649                         } else {
2650                                 frameSize = int(s.Val)
2651                         }
2652                 case SettingMaxHeaderListSize:
2653                         advHeaderListSize = &s.Val
2654                 }
2655                 return nil
2656         })
2657
2658         if advHeaderListSize == nil {
2659                 t.Errorf("server didn't advertise a max header list size")
2660         } else if *advHeaderListSize == 0 {
2661                 t.Errorf("server advertised a max header list size of 0")
2662         }
2663
2664         st.encodeHeaderField(":method", "GET")
2665         st.encodeHeaderField(":path", "/")
2666         st.encodeHeaderField(":scheme", "https")
2667         cookie := strings.Repeat("*", 4058)
2668         st.encodeHeaderField("cookie", cookie)
2669         st.writeHeaders(HeadersFrameParam{
2670                 StreamID:      1,
2671                 BlockFragment: st.headerBuf.Bytes(),
2672                 EndStream:     true,
2673                 EndHeaders:    false,
2674         })
2675
2676         // Capture the short encoding of a duplicate ~4K cookie, now
2677         // that we've already sent it once.
2678         st.headerBuf.Reset()
2679         st.encodeHeaderField("cookie", cookie)
2680
2681         // Now send 1MB of it.
2682         const size = 1 << 20
2683         b := bytes.Repeat(st.headerBuf.Bytes(), size/st.headerBuf.Len())
2684         for len(b) > 0 {
2685                 chunk := b
2686                 if len(chunk) > frameSize {
2687                         chunk = chunk[:frameSize]
2688                 }
2689                 b = b[len(chunk):]
2690                 st.fr.WriteContinuation(1, len(b) == 0, chunk)
2691         }
2692
2693         h := st.wantHeaders()
2694         if !h.HeadersEnded() {
2695                 t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
2696         }
2697         headers := st.decodeHeader(h.HeaderBlockFragment())
2698         want := [][2]string{
2699                 {":status", "431"},
2700                 {"content-type", "text/html; charset=utf-8"},
2701                 {"content-length", "63"},
2702         }
2703         if !reflect.DeepEqual(headers, want) {
2704                 t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
2705         }
2706 }
2707
2708 func TestCompressionErrorOnWrite(t *testing.T) {
2709         const maxStrLen = 8 << 10
2710         var serverConfig *http.Server
2711         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2712                 // No response body.
2713         }, func(ts *httptest.Server) {
2714                 serverConfig = ts.Config
2715                 serverConfig.MaxHeaderBytes = maxStrLen
2716         })
2717         st.addLogFilter("connection error: COMPRESSION_ERROR")
2718         defer st.Close()
2719         st.greet()
2720
2721         maxAllowed := st.sc.framer.maxHeaderStringLen()
2722
2723         // Crank this up, now that we have a conn connected with the
2724         // hpack.Decoder's max string length set has been initialized
2725         // from the earlier low ~8K value. We want this higher so don't
2726         // hit the max header list size. We only want to test hitting
2727         // the max string size.
2728         serverConfig.MaxHeaderBytes = 1 << 20
2729
2730         // First a request with a header that's exactly the max allowed size
2731         // for the hpack compression. It's still too long for the header list
2732         // size, so we'll get the 431 error, but that keeps the compression
2733         // context still valid.
2734         hbf := st.encodeHeader("foo", strings.Repeat("a", maxAllowed))
2735
2736         st.writeHeaders(HeadersFrameParam{
2737                 StreamID:      1,
2738                 BlockFragment: hbf,
2739                 EndStream:     true,
2740                 EndHeaders:    true,
2741         })
2742         h := st.wantHeaders()
2743         if !h.HeadersEnded() {
2744                 t.Fatalf("Got HEADERS without END_HEADERS set: %v", h)
2745         }
2746         headers := st.decodeHeader(h.HeaderBlockFragment())
2747         want := [][2]string{
2748                 {":status", "431"},
2749                 {"content-type", "text/html; charset=utf-8"},
2750                 {"content-length", "63"},
2751         }
2752         if !reflect.DeepEqual(headers, want) {
2753                 t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
2754         }
2755         df := st.wantData()
2756         if !strings.Contains(string(df.Data()), "HTTP Error 431") {
2757                 t.Errorf("Unexpected data body: %q", df.Data())
2758         }
2759         if !df.StreamEnded() {
2760                 t.Fatalf("expect data stream end")
2761         }
2762
2763         // And now send one that's just one byte too big.
2764         hbf = st.encodeHeader("bar", strings.Repeat("b", maxAllowed+1))
2765         st.writeHeaders(HeadersFrameParam{
2766                 StreamID:      3,
2767                 BlockFragment: hbf,
2768                 EndStream:     true,
2769                 EndHeaders:    true,
2770         })
2771         ga := st.wantGoAway()
2772         if ga.ErrCode != ErrCodeCompression {
2773                 t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
2774         }
2775 }
2776
2777 func TestCompressionErrorOnClose(t *testing.T) {
2778         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2779                 // No response body.
2780         })
2781         st.addLogFilter("connection error: COMPRESSION_ERROR")
2782         defer st.Close()
2783         st.greet()
2784
2785         hbf := st.encodeHeader("foo", "bar")
2786         hbf = hbf[:len(hbf)-1] // truncate one byte from the end, so hpack.Decoder.Close fails.
2787         st.writeHeaders(HeadersFrameParam{
2788                 StreamID:      1,
2789                 BlockFragment: hbf,
2790                 EndStream:     true,
2791                 EndHeaders:    true,
2792         })
2793         ga := st.wantGoAway()
2794         if ga.ErrCode != ErrCodeCompression {
2795                 t.Errorf("GOAWAY err = %v; want ErrCodeCompression", ga.ErrCode)
2796         }
2797 }
2798
2799 // test that a server handler can read trailers from a client
2800 func TestServerReadsTrailers(t *testing.T) {
2801         const testBody = "some test body"
2802         writeReq := func(st *serverTester) {
2803                 st.writeHeaders(HeadersFrameParam{
2804                         StreamID:      1, // clients send odd numbers
2805                         BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"),
2806                         EndStream:     false,
2807                         EndHeaders:    true,
2808                 })
2809                 st.writeData(1, false, []byte(testBody))
2810                 st.writeHeaders(HeadersFrameParam{
2811                         StreamID: 1, // clients send odd numbers
2812                         BlockFragment: st.encodeHeaderRaw(
2813                                 "foo", "foov",
2814                                 "bar", "barv",
2815                                 "baz", "bazv",
2816                                 "surprise", "wasn't declared; shouldn't show up",
2817                         ),
2818                         EndStream:  true,
2819                         EndHeaders: true,
2820                 })
2821         }
2822         checkReq := func(r *http.Request) {
2823                 wantTrailer := http.Header{
2824                         "Foo": nil,
2825                         "Bar": nil,
2826                         "Baz": nil,
2827                 }
2828                 if !reflect.DeepEqual(r.Trailer, wantTrailer) {
2829                         t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer)
2830                 }
2831                 slurp, err := ioutil.ReadAll(r.Body)
2832                 if string(slurp) != testBody {
2833                         t.Errorf("read body %q; want %q", slurp, testBody)
2834                 }
2835                 if err != nil {
2836                         t.Fatalf("Body slurp: %v", err)
2837                 }
2838                 wantTrailerAfter := http.Header{
2839                         "Foo": {"foov"},
2840                         "Bar": {"barv"},
2841                         "Baz": {"bazv"},
2842                 }
2843                 if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) {
2844                         t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter)
2845                 }
2846         }
2847         testServerRequest(t, writeReq, checkReq)
2848 }
2849
2850 // test that a server handler can send trailers
2851 func TestServerWritesTrailers_WithFlush(t *testing.T)    { testServerWritesTrailers(t, true) }
2852 func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) }
2853
2854 func testServerWritesTrailers(t *testing.T, withFlush bool) {
2855         // See https://httpwg.github.io/specs/rfc7540.html#rfc.section.8.1.3
2856         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2857                 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
2858                 w.Header().Add("Trailer", "Server-Trailer-C")
2859                 w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer") // filtered
2860
2861                 // Regular headers:
2862                 w.Header().Set("Foo", "Bar")
2863                 w.Header().Set("Content-Length", "5") // len("Hello")
2864
2865                 io.WriteString(w, "Hello")
2866                 if withFlush {
2867                         w.(http.Flusher).Flush()
2868                 }
2869                 w.Header().Set("Server-Trailer-A", "valuea")
2870                 w.Header().Set("Server-Trailer-C", "valuec") // skipping B
2871                 // After a flush, random keys like Server-Surprise shouldn't show up:
2872                 w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")
2873                 // But we do permit promoting keys to trailers after a
2874                 // flush if they start with the magic
2875                 // otherwise-invalid "Trailer:" prefix:
2876                 w.Header().Set("Trailer:Post-Header-Trailer", "hi1")
2877                 w.Header().Set("Trailer:post-header-trailer2", "hi2")
2878                 w.Header().Set("Trailer:Range", "invalid")
2879                 w.Header().Set("Trailer:Foo\x01Bogus", "invalid")
2880                 w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 2616 14.40")
2881                 w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 2616 14.40")
2882                 w.Header().Set("Trailer", "should not be included; Forbidden by RFC 2616 14.40")
2883                 return nil
2884         }, func(st *serverTester) {
2885                 getSlash(st)
2886                 hf := st.wantHeaders()
2887                 if hf.StreamEnded() {
2888                         t.Fatal("response HEADERS had END_STREAM")
2889                 }
2890                 if !hf.HeadersEnded() {
2891                         t.Fatal("response HEADERS didn't have END_HEADERS")
2892                 }
2893                 goth := st.decodeHeader(hf.HeaderBlockFragment())
2894                 wanth := [][2]string{
2895                         {":status", "200"},
2896                         {"foo", "Bar"},
2897                         {"trailer", "Server-Trailer-A, Server-Trailer-B"},
2898                         {"trailer", "Server-Trailer-C"},
2899                         {"trailer", "Transfer-Encoding, Content-Length, Trailer"},
2900                         {"content-type", "text/plain; charset=utf-8"},
2901                         {"content-length", "5"},
2902                 }
2903                 if !reflect.DeepEqual(goth, wanth) {
2904                         t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
2905                 }
2906                 df := st.wantData()
2907                 if string(df.Data()) != "Hello" {
2908                         t.Fatalf("Client read %q; want Hello", df.Data())
2909                 }
2910                 if df.StreamEnded() {
2911                         t.Fatalf("data frame had STREAM_ENDED")
2912                 }
2913                 tf := st.wantHeaders() // for the trailers
2914                 if !tf.StreamEnded() {
2915                         t.Fatalf("trailers HEADERS lacked END_STREAM")
2916                 }
2917                 if !tf.HeadersEnded() {
2918                         t.Fatalf("trailers HEADERS lacked END_HEADERS")
2919                 }
2920                 wanth = [][2]string{
2921                         {"post-header-trailer", "hi1"},
2922                         {"post-header-trailer2", "hi2"},
2923                         {"server-trailer-a", "valuea"},
2924                         {"server-trailer-c", "valuec"},
2925                 }
2926                 goth = st.decodeHeader(tf.HeaderBlockFragment())
2927                 if !reflect.DeepEqual(goth, wanth) {
2928                         t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
2929                 }
2930         })
2931 }
2932
2933 // validate transmitted header field names & values
2934 // golang.org/issue/14048
2935 func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
2936         testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2937                 w.Header().Add("OK1", "x")
2938                 w.Header().Add("Bad:Colon", "x") // colon (non-token byte) in key
2939                 w.Header().Add("Bad1\x00", "x")  // null in key
2940                 w.Header().Add("Bad2", "x\x00y") // null in value
2941                 return nil
2942         }, func(st *serverTester) {
2943                 getSlash(st)
2944                 hf := st.wantHeaders()
2945                 if !hf.StreamEnded() {
2946                         t.Error("response HEADERS lacked END_STREAM")
2947                 }
2948                 if !hf.HeadersEnded() {
2949                         t.Fatal("response HEADERS didn't have END_HEADERS")
2950                 }
2951                 goth := st.decodeHeader(hf.HeaderBlockFragment())
2952                 wanth := [][2]string{
2953                         {":status", "200"},
2954                         {"ok1", "x"},
2955                         {"content-type", "text/plain; charset=utf-8"},
2956                         {"content-length", "0"},
2957                 }
2958                 if !reflect.DeepEqual(goth, wanth) {
2959                         t.Errorf("Header mismatch.\n got: %v\nwant: %v", goth, wanth)
2960                 }
2961         })
2962 }
2963
2964 func BenchmarkServerGets(b *testing.B) {
2965         defer disableGoroutineTracking()()
2966         b.ReportAllocs()
2967
2968         const msg = "Hello, world"
2969         st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
2970                 io.WriteString(w, msg)
2971         })
2972         defer st.Close()
2973         st.greet()
2974
2975         // Give the server quota to reply. (plus it has the the 64KB)
2976         if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
2977                 b.Fatal(err)
2978         }
2979
2980         for i := 0; i < b.N; i++ {
2981                 id := 1 + uint32(i)*2
2982                 st.writeHeaders(HeadersFrameParam{
2983                         StreamID:      id,
2984                         BlockFragment: st.encodeHeader(),
2985                         EndStream:     true,
2986                         EndHeaders:    true,
2987                 })
2988                 st.wantHeaders()
2989                 df := st.wantData()
2990                 if !df.StreamEnded() {
2991                         b.Fatalf("DATA didn't have END_STREAM; got %v", df)
2992                 }
2993         }
2994 }
2995
2996 func BenchmarkServerPosts(b *testing.B) {
2997         defer disableGoroutineTracking()()
2998         b.ReportAllocs()
2999
3000         const msg = "Hello, world"
3001         st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3002                 // Consume the (empty) body from th peer before replying, otherwise
3003                 // the server will sometimes (depending on scheduling) send the peer a
3004                 // a RST_STREAM with the CANCEL error code.
3005                 if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
3006                         b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
3007                 }
3008                 io.WriteString(w, msg)
3009         })
3010         defer st.Close()
3011         st.greet()
3012
3013         // Give the server quota to reply. (plus it has the the 64KB)
3014         if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3015                 b.Fatal(err)
3016         }
3017
3018         for i := 0; i < b.N; i++ {
3019                 id := 1 + uint32(i)*2
3020                 st.writeHeaders(HeadersFrameParam{
3021                         StreamID:      id,
3022                         BlockFragment: st.encodeHeader(":method", "POST"),
3023                         EndStream:     false,
3024                         EndHeaders:    true,
3025                 })
3026                 st.writeData(id, true, nil)
3027                 st.wantHeaders()
3028                 df := st.wantData()
3029                 if !df.StreamEnded() {
3030                         b.Fatalf("DATA didn't have END_STREAM; got %v", df)
3031                 }
3032         }
3033 }
3034
3035 // Send a stream of messages from server to client in separate data frames.
3036 // Brings up performance issues seen in long streams.
3037 // Created to show problem in go issue #18502
3038 func BenchmarkServerToClientStreamDefaultOptions(b *testing.B) {
3039         benchmarkServerToClientStream(b)
3040 }
3041
3042 // Justification for Change-Id: Iad93420ef6c3918f54249d867098f1dadfa324d8
3043 // Expect to see memory/alloc reduction by opting in to Frame reuse with the Framer.
3044 func BenchmarkServerToClientStreamReuseFrames(b *testing.B) {
3045         benchmarkServerToClientStream(b, optFramerReuseFrames)
3046 }
3047
3048 func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
3049         defer disableGoroutineTracking()()
3050         b.ReportAllocs()
3051         const msgLen = 1
3052         // default window size
3053         const windowSize = 1<<16 - 1
3054
3055         // next message to send from the server and for the client to expect
3056         nextMsg := func(i int) []byte {
3057                 msg := make([]byte, msgLen)
3058                 msg[0] = byte(i)
3059                 if len(msg) != msgLen {
3060                         panic("invalid test setup msg length")
3061                 }
3062                 return msg
3063         }
3064
3065         st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3066                 // Consume the (empty) body from th peer before replying, otherwise
3067                 // the server will sometimes (depending on scheduling) send the peer a
3068                 // a RST_STREAM with the CANCEL error code.
3069                 if n, err := io.Copy(ioutil.Discard, r.Body); n != 0 || err != nil {
3070                         b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
3071                 }
3072                 for i := 0; i < b.N; i += 1 {
3073                         w.Write(nextMsg(i))
3074                         w.(http.Flusher).Flush()
3075                 }
3076         }, newServerOpts...)
3077         defer st.Close()
3078         st.greet()
3079
3080         const id = uint32(1)
3081
3082         st.writeHeaders(HeadersFrameParam{
3083                 StreamID:      id,
3084                 BlockFragment: st.encodeHeader(":method", "POST"),
3085                 EndStream:     false,
3086                 EndHeaders:    true,
3087         })
3088
3089         st.writeData(id, true, nil)
3090         st.wantHeaders()
3091
3092         var pendingWindowUpdate = uint32(0)
3093
3094         for i := 0; i < b.N; i += 1 {
3095                 expected := nextMsg(i)
3096                 df := st.wantData()
3097                 if bytes.Compare(expected, df.data) != 0 {
3098                         b.Fatalf("Bad message received; want %v; got %v", expected, df.data)
3099                 }
3100                 // try to send infrequent but large window updates so they don't overwhelm the test
3101                 pendingWindowUpdate += uint32(len(df.data))
3102                 if pendingWindowUpdate >= windowSize/2 {
3103                         if err := st.fr.WriteWindowUpdate(0, pendingWindowUpdate); err != nil {
3104                                 b.Fatal(err)
3105                         }
3106                         if err := st.fr.WriteWindowUpdate(id, pendingWindowUpdate); err != nil {
3107                                 b.Fatal(err)
3108                         }
3109                         pendingWindowUpdate = 0
3110                 }
3111         }
3112         df := st.wantData()
3113         if !df.StreamEnded() {
3114                 b.Fatalf("DATA didn't have END_STREAM; got %v", df)
3115         }
3116 }
3117
3118 // go-fuzz bug, originally reported at https://github.com/bradfitz/http2/issues/53
3119 // Verify we don't hang.
3120 func TestIssue53(t *testing.T) {
3121         const data = "PRI * HTTP/2.0\r\n\r\nSM" +
3122                 "\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad"
3123         s := &http.Server{
3124                 ErrorLog: log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags),
3125                 Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
3126                         w.Write([]byte("hello"))
3127                 }),
3128         }
3129         s2 := &Server{
3130                 MaxReadFrameSize:             1 << 16,
3131                 PermitProhibitedCipherSuites: true,
3132         }
3133         c := &issue53Conn{[]byte(data), false, false}
3134         s2.ServeConn(c, &ServeConnOpts{BaseConfig: s})
3135         if !c.closed {
3136                 t.Fatal("connection is not closed")
3137         }
3138 }
3139
// issue53Conn is an in-memory net.Conn used by TestIssue53.
//   - Read hands out a fixed byte slice and then reports io.EOF.
//   - Write discards its input and only records that it happened.
//   - Close only records that it was called.
//   - Both addresses are a fixed loopback TCP address.
//   - The deadline setters do nothing.
type issue53Conn struct {
	data    []byte // bytes not yet returned by Read
	closed  bool   // set once Close is called
	written bool   // set once Write is called
}

// Read copies as much of the remaining data as fits into p and consumes it.
// It returns 0, io.EOF once no data is left.
func (c *issue53Conn) Read(p []byte) (int, error) {
	if len(c.data) > 0 {
		n := copy(p, c.data)
		c.data = c.data[n:]
		return n, nil
	}
	return 0, io.EOF
}

// Write records that a write occurred and reports all of p as written.
func (c *issue53Conn) Write(p []byte) (int, error) {
	c.written = true
	return len(p), nil
}

// Close marks the connection closed; it never fails.
func (c *issue53Conn) Close() error {
	c.closed = true
	return nil
}

// LocalAddr returns 127.0.0.1:49706.
func (c *issue53Conn) LocalAddr() net.Addr {
	return &net.TCPAddr{Port: 49706, IP: net.IPv4(127, 0, 0, 1)}
}

// RemoteAddr returns 127.0.0.1:49706.
func (c *issue53Conn) RemoteAddr() net.Addr {
	return &net.TCPAddr{Port: 49706, IP: net.IPv4(127, 0, 0, 1)}
}

// SetDeadline is a no-op.
func (c *issue53Conn) SetDeadline(t time.Time) error { return nil }

// SetReadDeadline is a no-op.
func (c *issue53Conn) SetReadDeadline(t time.Time) error { return nil }

// SetWriteDeadline is a no-op.
func (c *issue53Conn) SetWriteDeadline(t time.Time) error { return nil }
3174
3175 // golang.org/issue/12895
3176 func TestConfigureServer(t *testing.T) {
3177         tests := []struct {
3178                 name      string
3179                 tlsConfig *tls.Config
3180                 wantErr   string
3181         }{
3182                 {
3183                         name: "empty server",
3184                 },
3185                 {
3186                         name: "just the required cipher suite",
3187                         tlsConfig: &tls.Config{
3188                                 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
3189                         },
3190                 },
3191                 {
3192                         name: "missing required cipher suite",
3193                         tlsConfig: &tls.Config{
3194                                 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384},
3195                         },
3196                         wantErr: "is missing HTTP/2-required TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
3197                 },
3198                 {
3199                         name: "required after bad",
3200                         tlsConfig: &tls.Config{
3201                                 CipherSuites: []uint16{tls.TLS_RSA_WITH_RC4_128_SHA, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
3202                         },
3203                         wantErr: "contains an HTTP/2-approved cipher suite (0xc02f), but it comes after",
3204                 },
3205                 {
3206                         name: "bad after required",
3207                         tlsConfig: &tls.Config{
3208                                 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_RC4_128_SHA},
3209                         },
3210                 },
3211         }
3212         for _, tt := range tests {
3213                 srv := &http.Server{TLSConfig: tt.tlsConfig}
3214                 err := ConfigureServer(srv, nil)
3215                 if (err != nil) != (tt.wantErr != "") {
3216                         if tt.wantErr != "" {
3217                                 t.Errorf("%s: success, but want error", tt.name)
3218                         } else {
3219                                 t.Errorf("%s: unexpected error: %v", tt.name, err)
3220                         }
3221                 }
3222                 if err != nil && tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr) {
3223                         t.Errorf("%s: err = %v; want substring %q", tt.name, err, tt.wantErr)
3224                 }
3225                 if err == nil && !srv.TLSConfig.PreferServerCipherSuites {
3226                         t.Errorf("%s: PreferServerCipherSuite is false; want true", tt.name)
3227                 }
3228         }
3229 }
3230
3231 func TestServerRejectHeadWithBody(t *testing.T) {
3232         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3233                 // No response body.
3234         })
3235         defer st.Close()
3236         st.greet()
3237         st.writeHeaders(HeadersFrameParam{
3238                 StreamID:      1, // clients send odd numbers
3239                 BlockFragment: st.encodeHeader(":method", "HEAD"),
3240                 EndStream:     false, // what we're testing, a bogus HEAD request with body
3241                 EndHeaders:    true,
3242         })
3243         st.wantRSTStream(1, ErrCodeProtocol)
3244 }
3245
3246 func TestServerNoAutoContentLengthOnHead(t *testing.T) {
3247         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3248                 // No response body. (or smaller than one frame)
3249         })
3250         defer st.Close()
3251         st.greet()
3252         st.writeHeaders(HeadersFrameParam{
3253                 StreamID:      1, // clients send odd numbers
3254                 BlockFragment: st.encodeHeader(":method", "HEAD"),
3255                 EndStream:     true,
3256                 EndHeaders:    true,
3257         })
3258         h := st.wantHeaders()
3259         headers := st.decodeHeader(h.HeaderBlockFragment())
3260         want := [][2]string{
3261                 {":status", "200"},
3262                 {"content-type", "text/plain; charset=utf-8"},
3263         }
3264         if !reflect.DeepEqual(headers, want) {
3265                 t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
3266         }
3267 }
3268
3269 // golang.org/issue/13495
3270 func TestServerNoDuplicateContentType(t *testing.T) {
3271         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3272                 w.Header()["Content-Type"] = []string{""}
3273                 fmt.Fprintf(w, "<html><head></head><body>hi</body></html>")
3274         })
3275         defer st.Close()
3276         st.greet()
3277         st.writeHeaders(HeadersFrameParam{
3278                 StreamID:      1,
3279                 BlockFragment: st.encodeHeader(),
3280                 EndStream:     true,
3281                 EndHeaders:    true,
3282         })
3283         h := st.wantHeaders()
3284         headers := st.decodeHeader(h.HeaderBlockFragment())
3285         want := [][2]string{
3286                 {":status", "200"},
3287                 {"content-type", ""},
3288                 {"content-length", "41"},
3289         }
3290         if !reflect.DeepEqual(headers, want) {
3291                 t.Errorf("Headers mismatch.\n got: %q\nwant: %q\n", headers, want)
3292         }
3293 }
3294
3295 func disableGoroutineTracking() (restore func()) {
3296         old := DebugGoroutines
3297         DebugGoroutines = false
3298         return func() { DebugGoroutines = old }
3299 }
3300
3301 func BenchmarkServer_GetRequest(b *testing.B) {
3302         defer disableGoroutineTracking()()
3303         b.ReportAllocs()
3304         const msg = "Hello, world."
3305         st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3306                 n, err := io.Copy(ioutil.Discard, r.Body)
3307                 if err != nil || n > 0 {
3308                         b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
3309                 }
3310                 io.WriteString(w, msg)
3311         })
3312         defer st.Close()
3313
3314         st.greet()
3315         // Give the server quota to reply. (plus it has the the 64KB)
3316         if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3317                 b.Fatal(err)
3318         }
3319         hbf := st.encodeHeader(":method", "GET")
3320         for i := 0; i < b.N; i++ {
3321                 streamID := uint32(1 + 2*i)
3322                 st.writeHeaders(HeadersFrameParam{
3323                         StreamID:      streamID,
3324                         BlockFragment: hbf,
3325                         EndStream:     true,
3326                         EndHeaders:    true,
3327                 })
3328                 st.wantHeaders()
3329                 st.wantData()
3330         }
3331 }
3332
3333 func BenchmarkServer_PostRequest(b *testing.B) {
3334         defer disableGoroutineTracking()()
3335         b.ReportAllocs()
3336         const msg = "Hello, world."
3337         st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) {
3338                 n, err := io.Copy(ioutil.Discard, r.Body)
3339                 if err != nil || n > 0 {
3340                         b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
3341                 }
3342                 io.WriteString(w, msg)
3343         })
3344         defer st.Close()
3345         st.greet()
3346         // Give the server quota to reply. (plus it has the the 64KB)
3347         if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3348                 b.Fatal(err)
3349         }
3350         hbf := st.encodeHeader(":method", "POST")
3351         for i := 0; i < b.N; i++ {
3352                 streamID := uint32(1 + 2*i)
3353                 st.writeHeaders(HeadersFrameParam{
3354                         StreamID:      streamID,
3355                         BlockFragment: hbf,
3356                         EndStream:     false,
3357                         EndHeaders:    true,
3358                 })
3359                 st.writeData(streamID, true, nil)
3360                 st.wantHeaders()
3361                 st.wantData()
3362         }
3363 }
3364
// connStateConn wraps a net.Conn and adds a ConnectionState method that
// reports a fixed tls.ConnectionState. It lets tests present a connection
// that is not a *tls.Conn but still exposes TLS state.
type connStateConn struct {
	net.Conn                     // underlying transport; all I/O is delegated
	cs       tls.ConnectionState // returned verbatim by ConnectionState
}

// ConnectionState returns the stored TLS connection state.
func (c connStateConn) ConnectionState() tls.ConnectionState {
	state := c.cs
	return state
}
3371
3372 // golang.org/issue/12737 -- handle any net.Conn, not just
3373 // *tls.Conn.
3374 func TestServerHandleCustomConn(t *testing.T) {
3375         var s Server
3376         c1, c2 := net.Pipe()
3377         clientDone := make(chan struct{})
3378         handlerDone := make(chan struct{})
3379         var req *http.Request
3380         go func() {
3381                 defer close(clientDone)
3382                 defer c2.Close()
3383                 fr := NewFramer(c2, c2)
3384                 io.WriteString(c2, ClientPreface)
3385                 fr.WriteSettings()
3386                 fr.WriteSettingsAck()
3387                 f, err := fr.ReadFrame()
3388                 if err != nil {
3389                         t.Error(err)
3390                         return
3391                 }
3392                 if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() {
3393                         t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f))
3394                         return
3395                 }
3396                 f, err = fr.ReadFrame()
3397                 if err != nil {
3398                         t.Error(err)
3399                         return
3400                 }
3401                 if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() {
3402                         t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f))
3403                         return
3404                 }
3405                 var henc hpackEncoder
3406                 fr.WriteHeaders(HeadersFrameParam{
3407                         StreamID:      1,
3408                         BlockFragment: henc.encodeHeaderRaw(t, ":method", "GET", ":path", "/", ":scheme", "https", ":authority", "foo.com"),
3409                         EndStream:     true,
3410                         EndHeaders:    true,
3411                 })
3412                 go io.Copy(ioutil.Discard, c2)
3413                 <-handlerDone
3414         }()
3415         const testString = "my custom ConnectionState"
3416         fakeConnState := tls.ConnectionState{
3417                 ServerName:  testString,
3418                 Version:     tls.VersionTLS12,
3419                 CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
3420         }
3421         go s.ServeConn(connStateConn{c1, fakeConnState}, &ServeConnOpts{
3422                 BaseConfig: &http.Server{
3423                         Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3424                                 defer close(handlerDone)
3425                                 req = r
3426                         }),
3427                 }})
3428         select {
3429         case <-clientDone:
3430         case <-time.After(5 * time.Second):
3431                 t.Fatal("timeout waiting for handler")
3432         }
3433         if req.TLS == nil {
3434                 t.Fatalf("Request.TLS is nil. Got: %#v", req)
3435         }
3436         if req.TLS.ServerName != testString {
3437                 t.Fatalf("Request.TLS = %+v; want ServerName of %q", req.TLS, testString)
3438         }
3439 }
3440
3441 // golang.org/issue/14214
3442 func TestServer_Rejects_ConnHeaders(t *testing.T) {
3443         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3444                 t.Error("should not get to Handler")
3445         })
3446         defer st.Close()
3447         st.greet()
3448         st.bodylessReq1("connection", "foo")
3449         hf := st.wantHeaders()
3450         goth := st.decodeHeader(hf.HeaderBlockFragment())
3451         wanth := [][2]string{
3452                 {":status", "400"},
3453                 {"content-type", "text/plain; charset=utf-8"},
3454                 {"x-content-type-options", "nosniff"},
3455                 {"content-length", "51"},
3456         }
3457         if !reflect.DeepEqual(goth, wanth) {
3458                 t.Errorf("Got headers %v; want %v", goth, wanth)
3459         }
3460 }
3461
3462 type hpackEncoder struct {
3463         enc *hpack.Encoder
3464         buf bytes.Buffer
3465 }
3466
3467 func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte {
3468         if len(headers)%2 == 1 {
3469                 panic("odd number of kv args")
3470         }
3471         he.buf.Reset()
3472         if he.enc == nil {
3473                 he.enc = hpack.NewEncoder(&he.buf)
3474         }
3475         for len(headers) > 0 {
3476                 k, v := headers[0], headers[1]
3477                 err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v})
3478                 if err != nil {
3479                         t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
3480                 }
3481                 headers = headers[2:]
3482         }
3483         return he.buf.Bytes()
3484 }
3485
3486 func TestCheckValidHTTP2Request(t *testing.T) {
3487         tests := []struct {
3488                 h    http.Header
3489                 want error
3490         }{
3491                 {
3492                         h:    http.Header{"Te": {"trailers"}},
3493                         want: nil,
3494                 },
3495                 {
3496                         h:    http.Header{"Te": {"trailers", "bogus"}},
3497                         want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`),
3498                 },
3499                 {
3500                         h:    http.Header{"Foo": {""}},
3501                         want: nil,
3502                 },
3503                 {
3504                         h:    http.Header{"Connection": {""}},
3505                         want: errors.New(`request header "Connection" is not valid in HTTP/2`),
3506                 },
3507                 {
3508                         h:    http.Header{"Proxy-Connection": {""}},
3509                         want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`),
3510                 },
3511                 {
3512                         h:    http.Header{"Keep-Alive": {""}},
3513                         want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`),
3514                 },
3515                 {
3516                         h:    http.Header{"Upgrade": {""}},
3517                         want: errors.New(`request header "Upgrade" is not valid in HTTP/2`),
3518                 },
3519         }
3520         for i, tt := range tests {
3521                 got := checkValidHTTP2RequestHeaders(tt.h)
3522                 if !reflect.DeepEqual(got, tt.want) {
3523                         t.Errorf("%d. checkValidHTTP2Request = %v; want %v", i, got, tt.want)
3524                 }
3525         }
3526 }
3527
3528 // golang.org/issue/14030
3529 func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
3530         const msg = "Hello"
3531         const msg2 = "World"
3532
3533         doRead := make(chan bool, 1)
3534         defer close(doRead) // fallback cleanup
3535
3536         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3537                 io.WriteString(w, msg)
3538                 w.(http.Flusher).Flush()
3539
3540                 // Do a read, which might force a 100-continue status to be sent.
3541                 <-doRead
3542                 r.Body.Read(make([]byte, 10))
3543
3544                 io.WriteString(w, msg2)
3545
3546         }, optOnlyServer)
3547         defer st.Close()
3548
3549         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3550         defer tr.CloseIdleConnections()
3551
3552         req, _ := http.NewRequest("POST", st.ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
3553         req.Header.Set("Expect", "100-continue")
3554
3555         res, err := tr.RoundTrip(req)
3556         if err != nil {
3557                 t.Fatal(err)
3558         }
3559         defer res.Body.Close()
3560
3561         buf := make([]byte, len(msg))
3562         if _, err := io.ReadFull(res.Body, buf); err != nil {
3563                 t.Fatal(err)
3564         }
3565         if string(buf) != msg {
3566                 t.Fatalf("msg = %q; want %q", buf, msg)
3567         }
3568
3569         doRead <- true
3570
3571         if _, err := io.ReadFull(res.Body, buf); err != nil {
3572                 t.Fatal(err)
3573         }
3574         if string(buf) != msg2 {
3575                 t.Fatalf("second msg = %q; want %q", buf, msg2)
3576         }
3577 }
3578
// funcReader adapts a plain function to the io.Reader interface.
type funcReader func([]byte) (n int, err error)

// Read calls f with p and returns its results unchanged.
func (f funcReader) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
3582
3583 // golang.org/issue/16481 -- return flow control when streams close with unread data.
3584 // (The Server version of the bug. See also TestUnreadFlowControlReturned_Transport)
3585 func TestUnreadFlowControlReturned_Server(t *testing.T) {
3586         unblock := make(chan bool, 1)
3587         defer close(unblock)
3588
3589         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3590                 // Don't read the 16KB request body. Wait until the client's
3591                 // done sending it and then return. This should cause the Server
3592                 // to then return those 16KB of flow control to the client.
3593                 <-unblock
3594         }, optOnlyServer)
3595         defer st.Close()
3596
3597         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3598         defer tr.CloseIdleConnections()
3599
3600         // This previously hung on the 4th iteration.
3601         for i := 0; i < 6; i++ {
3602                 body := io.MultiReader(
3603                         io.LimitReader(neverEnding('A'), 16<<10),
3604                         funcReader(func([]byte) (n int, err error) {
3605                                 unblock <- true
3606                                 return 0, io.EOF
3607                         }),
3608                 )
3609                 req, _ := http.NewRequest("POST", st.ts.URL, body)
3610                 res, err := tr.RoundTrip(req)
3611                 if err != nil {
3612                         t.Fatal(err)
3613                 }
3614                 res.Body.Close()
3615         }
3616
3617 }
3618
3619 func TestServerIdleTimeout(t *testing.T) {
3620         if testing.Short() {
3621                 t.Skip("skipping in short mode")
3622         }
3623
3624         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3625         }, func(h2s *Server) {
3626                 h2s.IdleTimeout = 500 * time.Millisecond
3627         })
3628         defer st.Close()
3629
3630         st.greet()
3631         ga := st.wantGoAway()
3632         if ga.ErrCode != ErrCodeNo {
3633                 t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
3634         }
3635 }
3636
3637 func TestServerIdleTimeout_AfterRequest(t *testing.T) {
3638         if testing.Short() {
3639                 t.Skip("skipping in short mode")
3640         }
3641         const timeout = 250 * time.Millisecond
3642
3643         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3644                 time.Sleep(timeout * 2)
3645         }, func(h2s *Server) {
3646                 h2s.IdleTimeout = timeout
3647         })
3648         defer st.Close()
3649
3650         st.greet()
3651
3652         // Send a request which takes twice the timeout. Verifies the
3653         // idle timeout doesn't fire while we're in a request:
3654         st.bodylessReq1()
3655         st.wantHeaders()
3656
3657         // But the idle timeout should be rearmed after the request
3658         // is done:
3659         ga := st.wantGoAway()
3660         if ga.ErrCode != ErrCodeNo {
3661                 t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
3662         }
3663 }
3664
3665 // grpc-go closes the Request.Body currently with a Read.
3666 // Verify that it doesn't race.
3667 // See https://github.com/grpc/grpc-go/pull/938
3668 func TestRequestBodyReadCloseRace(t *testing.T) {
3669         for i := 0; i < 100; i++ {
3670                 body := &requestBody{
3671                         pipe: &pipe{
3672                                 b: new(bytes.Buffer),
3673                         },
3674                 }
3675                 body.pipe.CloseWithError(io.EOF)
3676
3677                 done := make(chan bool, 1)
3678                 buf := make([]byte, 10)
3679                 go func() {
3680                         time.Sleep(1 * time.Millisecond)
3681                         body.Close()
3682                         done <- true
3683                 }()
3684                 body.Read(buf)
3685                 <-done
3686         }
3687 }
3688
3689 func TestIssue20704Race(t *testing.T) {
3690         if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" {
3691                 t.Skip("skipping in short mode")
3692         }
3693         const (
3694                 itemSize  = 1 << 10
3695                 itemCount = 100
3696         )
3697
3698         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3699                 for i := 0; i < itemCount; i++ {
3700                         _, err := w.Write(make([]byte, itemSize))
3701                         if err != nil {
3702                                 return
3703                         }
3704                 }
3705         }, optOnlyServer)
3706         defer st.Close()
3707
3708         tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3709         defer tr.CloseIdleConnections()
3710         cl := &http.Client{Transport: tr}
3711
3712         for i := 0; i < 1000; i++ {
3713                 resp, err := cl.Get(st.ts.URL)
3714                 if err != nil {
3715                         t.Fatal(err)
3716                 }
3717                 // Force a RST stream to the server by closing without
3718                 // reading the body:
3719                 resp.Body.Close()
3720         }
3721 }