OSDN Git Service

new repo
[bytom/vapor.git] / vendor / golang.org / x / net / http2 / server_push_test.go
1 // Copyright 2016 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // +build go1.8
6
7 package http2
8
9 import (
10         "errors"
11         "fmt"
12         "io"
13         "io/ioutil"
14         "net/http"
15         "reflect"
16         "strconv"
17         "sync"
18         "testing"
19         "time"
20 )
21
22 func TestServer_Push_Success(t *testing.T) {
23         const (
24                 mainBody   = "<html>index page</html>"
25                 pushedBody = "<html>pushed page</html>"
26                 userAgent  = "testagent"
27                 cookie     = "testcookie"
28         )
29
30         var stURL string
31         checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
32                 if got, want := r.Method, wantMethod; got != want {
33                         return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
34                 }
35                 if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
36                         return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
37                 }
38                 if got, want := "https://"+r.Host, stURL; got != want {
39                         return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
40                 }
41                 if r.Body == nil {
42                         return fmt.Errorf("nil Body")
43                 }
44                 if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 {
45                         return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
46                 }
47                 return nil
48         }
49
50         errc := make(chan error, 3)
51         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
52                 switch r.URL.RequestURI() {
53                 case "/":
54                         // Push "/pushed?get" as a GET request, using an absolute URL.
55                         opt := &http.PushOptions{
56                                 Header: http.Header{
57                                         "User-Agent": {userAgent},
58                                 },
59                         }
60                         if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
61                                 errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
62                                 return
63                         }
64                         // Push "/pushed?head" as a HEAD request, using a path.
65                         opt = &http.PushOptions{
66                                 Method: "HEAD",
67                                 Header: http.Header{
68                                         "User-Agent": {userAgent},
69                                         "Cookie":     {cookie},
70                                 },
71                         }
72                         if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
73                                 errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
74                                 return
75                         }
76                         w.Header().Set("Content-Type", "text/html")
77                         w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
78                         w.WriteHeader(200)
79                         io.WriteString(w, mainBody)
80                         errc <- nil
81
82                 case "/pushed?get":
83                         wantH := http.Header{}
84                         wantH.Set("User-Agent", userAgent)
85                         if err := checkPromisedReq(r, "GET", wantH); err != nil {
86                                 errc <- fmt.Errorf("/pushed?get: %v", err)
87                                 return
88                         }
89                         w.Header().Set("Content-Type", "text/html")
90                         w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
91                         w.WriteHeader(200)
92                         io.WriteString(w, pushedBody)
93                         errc <- nil
94
95                 case "/pushed?head":
96                         wantH := http.Header{}
97                         wantH.Set("User-Agent", userAgent)
98                         wantH.Set("Cookie", cookie)
99                         if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
100                                 errc <- fmt.Errorf("/pushed?head: %v", err)
101                                 return
102                         }
103                         w.WriteHeader(204)
104                         errc <- nil
105
106                 default:
107                         errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
108                 }
109         })
110         stURL = st.ts.URL
111
112         // Send one request, which should push two responses.
113         st.greet()
114         getSlash(st)
115         for k := 0; k < 3; k++ {
116                 select {
117                 case <-time.After(2 * time.Second):
118                         t.Errorf("timeout waiting for handler %d to finish", k)
119                 case err := <-errc:
120                         if err != nil {
121                                 t.Fatal(err)
122                         }
123                 }
124         }
125
126         checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
127                 pp, ok := f.(*PushPromiseFrame)
128                 if !ok {
129                         return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
130                 }
131                 if !pp.HeadersEnded() {
132                         return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
133                 }
134                 if got, want := pp.PromiseID, promiseID; got != want {
135                         return fmt.Errorf("got PromiseID %v; want %v", got, want)
136                 }
137                 gotH := st.decodeHeader(pp.HeaderBlockFragment())
138                 if !reflect.DeepEqual(gotH, wantH) {
139                         return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
140                 }
141                 return nil
142         }
143         checkHeaders := func(f Frame, wantH [][2]string) error {
144                 hf, ok := f.(*HeadersFrame)
145                 if !ok {
146                         return fmt.Errorf("got a %T; want *HeadersFrame", f)
147                 }
148                 gotH := st.decodeHeader(hf.HeaderBlockFragment())
149                 if !reflect.DeepEqual(gotH, wantH) {
150                         return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
151                 }
152                 return nil
153         }
154         checkData := func(f Frame, wantData string) error {
155                 df, ok := f.(*DataFrame)
156                 if !ok {
157                         return fmt.Errorf("got a %T; want *DataFrame", f)
158                 }
159                 if gotData := string(df.Data()); gotData != wantData {
160                         return fmt.Errorf("got response data %q; want %q", gotData, wantData)
161                 }
162                 return nil
163         }
164
165         // Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA
166         // Stream 2 has HEADERS + DATA
167         // Stream 4 has HEADERS
168         expected := map[uint32][]func(Frame) error{
169                 1: {
170                         func(f Frame) error {
171                                 return checkPushPromise(f, 2, [][2]string{
172                                         {":method", "GET"},
173                                         {":scheme", "https"},
174                                         {":authority", st.ts.Listener.Addr().String()},
175                                         {":path", "/pushed?get"},
176                                         {"user-agent", userAgent},
177                                 })
178                         },
179                         func(f Frame) error {
180                                 return checkPushPromise(f, 4, [][2]string{
181                                         {":method", "HEAD"},
182                                         {":scheme", "https"},
183                                         {":authority", st.ts.Listener.Addr().String()},
184                                         {":path", "/pushed?head"},
185                                         {"cookie", cookie},
186                                         {"user-agent", userAgent},
187                                 })
188                         },
189                         func(f Frame) error {
190                                 return checkHeaders(f, [][2]string{
191                                         {":status", "200"},
192                                         {"content-type", "text/html"},
193                                         {"content-length", strconv.Itoa(len(mainBody))},
194                                 })
195                         },
196                         func(f Frame) error {
197                                 return checkData(f, mainBody)
198                         },
199                 },
200                 2: {
201                         func(f Frame) error {
202                                 return checkHeaders(f, [][2]string{
203                                         {":status", "200"},
204                                         {"content-type", "text/html"},
205                                         {"content-length", strconv.Itoa(len(pushedBody))},
206                                 })
207                         },
208                         func(f Frame) error {
209                                 return checkData(f, pushedBody)
210                         },
211                 },
212                 4: {
213                         func(f Frame) error {
214                                 return checkHeaders(f, [][2]string{
215                                         {":status", "204"},
216                                 })
217                         },
218                 },
219         }
220
221         consumed := map[uint32]int{}
222         for k := 0; len(expected) > 0; k++ {
223                 f, err := st.readFrame()
224                 if err != nil {
225                         for id, left := range expected {
226                                 t.Errorf("stream %d: missing %d frames", id, len(left))
227                         }
228                         t.Fatalf("readFrame %d: %v", k, err)
229                 }
230                 id := f.Header().StreamID
231                 label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
232                 if len(expected[id]) == 0 {
233                         t.Fatalf("%s: unexpected frame %#+v", label, f)
234                 }
235                 check := expected[id][0]
236                 expected[id] = expected[id][1:]
237                 if len(expected[id]) == 0 {
238                         delete(expected, id)
239                 }
240                 if err := check(f); err != nil {
241                         t.Fatalf("%s: %v", label, err)
242                 }
243                 consumed[id]++
244         }
245 }
246
247 func TestServer_Push_SuccessNoRace(t *testing.T) {
248         // Regression test for issue #18326. Ensure the request handler can mutate
249         // pushed request headers without racing with the PUSH_PROMISE write.
250         errc := make(chan error, 2)
251         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
252                 switch r.URL.RequestURI() {
253                 case "/":
254                         opt := &http.PushOptions{
255                                 Header: http.Header{"User-Agent": {"testagent"}},
256                         }
257                         if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
258                                 errc <- fmt.Errorf("error pushing: %v", err)
259                                 return
260                         }
261                         w.WriteHeader(200)
262                         errc <- nil
263
264                 case "/pushed":
265                         // Update request header, ensure there is no race.
266                         r.Header.Set("User-Agent", "newagent")
267                         r.Header.Set("Cookie", "cookie")
268                         w.WriteHeader(200)
269                         errc <- nil
270
271                 default:
272                         errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
273                 }
274         })
275
276         // Send one request, which should push one response.
277         st.greet()
278         getSlash(st)
279         for k := 0; k < 2; k++ {
280                 select {
281                 case <-time.After(2 * time.Second):
282                         t.Errorf("timeout waiting for handler %d to finish", k)
283                 case err := <-errc:
284                         if err != nil {
285                                 t.Fatal(err)
286                         }
287                 }
288         }
289 }
290
291 func TestServer_Push_RejectRecursivePush(t *testing.T) {
292         // Expect two requests, but might get three if there's a bug and the second push succeeds.
293         errc := make(chan error, 3)
294         handler := func(w http.ResponseWriter, r *http.Request) error {
295                 baseURL := "https://" + r.Host
296                 switch r.URL.Path {
297                 case "/":
298                         if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
299                                 return fmt.Errorf("first Push()=%v, want nil", err)
300                         }
301                         return nil
302
303                 case "/push1":
304                         if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
305                                 return fmt.Errorf("Push()=%v, want %v", got, want)
306                         }
307                         return nil
308
309                 default:
310                         return fmt.Errorf("unexpected path: %q", r.URL.Path)
311                 }
312         }
313         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
314                 errc <- handler(w, r)
315         })
316         defer st.Close()
317         st.greet()
318         getSlash(st)
319         if err := <-errc; err != nil {
320                 t.Errorf("First request failed: %v", err)
321         }
322         if err := <-errc; err != nil {
323                 t.Errorf("Second request failed: %v", err)
324         }
325 }
326
327 func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
328         // Expect one request, but might get two if there's a bug and the push succeeds.
329         errc := make(chan error, 2)
330         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
331                 errc <- doPush(w.(http.Pusher), r)
332         })
333         defer st.Close()
334         st.greet()
335         if err := st.fr.WriteSettings(settings...); err != nil {
336                 st.t.Fatalf("WriteSettings: %v", err)
337         }
338         st.wantSettingsAck()
339         getSlash(st)
340         if err := <-errc; err != nil {
341                 t.Error(err)
342         }
343         // Should not get a PUSH_PROMISE frame.
344         hf := st.wantHeaders()
345         if !hf.StreamEnded() {
346                 t.Error("stream should end after headers")
347         }
348 }
349
350 func TestServer_Push_RejectIfDisabled(t *testing.T) {
351         testServer_Push_RejectSingleRequest(t,
352                 func(p http.Pusher, r *http.Request) error {
353                         if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
354                                 return fmt.Errorf("Push()=%v, want %v", got, want)
355                         }
356                         return nil
357                 },
358                 Setting{SettingEnablePush, 0})
359 }
360
361 func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
362         testServer_Push_RejectSingleRequest(t,
363                 func(p http.Pusher, r *http.Request) error {
364                         if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
365                                 return fmt.Errorf("Push()=%v, want %v", got, want)
366                         }
367                         return nil
368                 },
369                 Setting{SettingMaxConcurrentStreams, 0})
370 }
371
372 func TestServer_Push_RejectWrongScheme(t *testing.T) {
373         testServer_Push_RejectSingleRequest(t,
374                 func(p http.Pusher, r *http.Request) error {
375                         if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
376                                 return errors.New("Push() should have failed (push target URL is http)")
377                         }
378                         return nil
379                 })
380 }
381
382 func TestServer_Push_RejectMissingHost(t *testing.T) {
383         testServer_Push_RejectSingleRequest(t,
384                 func(p http.Pusher, r *http.Request) error {
385                         if err := p.Push("https:pushed", nil); err == nil {
386                                 return errors.New("Push() should have failed (push target URL missing host)")
387                         }
388                         return nil
389                 })
390 }
391
392 func TestServer_Push_RejectRelativePath(t *testing.T) {
393         testServer_Push_RejectSingleRequest(t,
394                 func(p http.Pusher, r *http.Request) error {
395                         if err := p.Push("../test", nil); err == nil {
396                                 return errors.New("Push() should have failed (push target is a relative path)")
397                         }
398                         return nil
399                 })
400 }
401
402 func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
403         testServer_Push_RejectSingleRequest(t,
404                 func(p http.Pusher, r *http.Request) error {
405                         if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
406                                 return errors.New("Push() should have failed (cannot promise a POST)")
407                         }
408                         return nil
409                 })
410 }
411
412 func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
413         testServer_Push_RejectSingleRequest(t,
414                 func(p http.Pusher, r *http.Request) error {
415                         header := http.Header{
416                                 "Content-Length":   {"10"},
417                                 "Content-Encoding": {"gzip"},
418                                 "Trailer":          {"Foo"},
419                                 "Te":               {"trailers"},
420                                 "Host":             {"test.com"},
421                                 ":authority":       {"test.com"},
422                         }
423                         if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
424                                 return errors.New("Push() should have failed (forbidden headers)")
425                         }
426                         return nil
427                 })
428 }
429
430 func TestServer_Push_StateTransitions(t *testing.T) {
431         const body = "foo"
432
433         gotPromise := make(chan bool)
434         finishedPush := make(chan bool)
435
436         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
437                 switch r.URL.RequestURI() {
438                 case "/":
439                         if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
440                                 t.Errorf("Push error: %v", err)
441                         }
442                         // Don't finish this request until the push finishes so we don't
443                         // nondeterministically interleave output frames with the push.
444                         <-finishedPush
445                 case "/pushed":
446                         <-gotPromise
447                 }
448                 w.Header().Set("Content-Type", "text/html")
449                 w.Header().Set("Content-Length", strconv.Itoa(len(body)))
450                 w.WriteHeader(200)
451                 io.WriteString(w, body)
452         })
453         defer st.Close()
454
455         st.greet()
456         if st.stream(2) != nil {
457                 t.Fatal("stream 2 should be empty")
458         }
459         if got, want := st.streamState(2), stateIdle; got != want {
460                 t.Fatalf("streamState(2)=%v, want %v", got, want)
461         }
462         getSlash(st)
463         // After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote.
464         st.wantPushPromise()
465         if got, want := st.streamState(2), stateHalfClosedRemote; got != want {
466                 t.Fatalf("streamState(2)=%v, want %v", got, want)
467         }
468         // We stall the HTTP handler for "/pushed" until the above check. If we don't
469         // stall the handler, then the handler might write HEADERS and DATA and finish
470         // the stream before we check st.streamState(2) -- should that happen, we'll
471         // see stateClosed and fail the above check.
472         close(gotPromise)
473         st.wantHeaders()
474         if df := st.wantData(); !df.StreamEnded() {
475                 t.Fatal("expected END_STREAM flag on DATA")
476         }
477         if got, want := st.streamState(2), stateClosed; got != want {
478                 t.Fatalf("streamState(2)=%v, want %v", got, want)
479         }
480         close(finishedPush)
481 }
482
483 func TestServer_Push_RejectAfterGoAway(t *testing.T) {
484         var readyOnce sync.Once
485         ready := make(chan struct{})
486         errc := make(chan error, 2)
487         st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
488                 select {
489                 case <-ready:
490                 case <-time.After(5 * time.Second):
491                         errc <- fmt.Errorf("timeout waiting for GOAWAY to be processed")
492                 }
493                 if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
494                         errc <- fmt.Errorf("Push()=%v, want %v", got, want)
495                 }
496                 errc <- nil
497         })
498         defer st.Close()
499         st.greet()
500         getSlash(st)
501
502         // Send GOAWAY and wait for it to be processed.
503         st.fr.WriteGoAway(1, ErrCodeNo, nil)
504         go func() {
505                 for {
506                         select {
507                         case <-ready:
508                                 return
509                         default:
510                         }
511                         st.sc.serveMsgCh <- func(loopNum int) {
512                                 if !st.sc.pushEnabled {
513                                         readyOnce.Do(func() { close(ready) })
514                                 }
515                         }
516                 }
517         }()
518         if err := <-errc; err != nil {
519                 t.Error(err)
520         }
521 }