OSDN Git Service

new repo
[bytom/vapor.git] / vendor / golang.org / x / net / http2 / frame_test.go
1 // Copyright 2014 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package http2
6
7 import (
8         "bytes"
9         "fmt"
10         "io"
11         "reflect"
12         "strings"
13         "testing"
14         "unsafe"
15
16         "golang.org/x/net/http2/hpack"
17 )
18
19 func testFramer() (*Framer, *bytes.Buffer) {
20         buf := new(bytes.Buffer)
21         return NewFramer(buf, buf), buf
22 }
23
24 func TestFrameSizes(t *testing.T) {
25         // Catch people rearranging the FrameHeader fields.
26         if got, want := int(unsafe.Sizeof(FrameHeader{})), 12; got != want {
27                 t.Errorf("FrameHeader size = %d; want %d", got, want)
28         }
29 }
30
31 func TestFrameTypeString(t *testing.T) {
32         tests := []struct {
33                 ft   FrameType
34                 want string
35         }{
36                 {FrameData, "DATA"},
37                 {FramePing, "PING"},
38                 {FrameGoAway, "GOAWAY"},
39                 {0xf, "UNKNOWN_FRAME_TYPE_15"},
40         }
41
42         for i, tt := range tests {
43                 got := tt.ft.String()
44                 if got != tt.want {
45                         t.Errorf("%d. String(FrameType %d) = %q; want %q", i, int(tt.ft), got, tt.want)
46                 }
47         }
48 }
49
50 func TestWriteRST(t *testing.T) {
51         fr, buf := testFramer()
52         var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
53         var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4
54         fr.WriteRSTStream(streamID, ErrCode(errCode))
55         const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04"
56         if buf.String() != wantEnc {
57                 t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
58         }
59         f, err := fr.ReadFrame()
60         if err != nil {
61                 t.Fatal(err)
62         }
63         want := &RSTStreamFrame{
64                 FrameHeader: FrameHeader{
65                         valid:    true,
66                         Type:     0x3,
67                         Flags:    0x0,
68                         Length:   0x4,
69                         StreamID: 0x1020304,
70                 },
71                 ErrCode: 0x7060504,
72         }
73         if !reflect.DeepEqual(f, want) {
74                 t.Errorf("parsed back %#v; want %#v", f, want)
75         }
76 }
77
78 func TestWriteData(t *testing.T) {
79         fr, buf := testFramer()
80         var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
81         data := []byte("ABC")
82         fr.WriteData(streamID, true, data)
83         const wantEnc = "\x00\x00\x03\x00\x01\x01\x02\x03\x04ABC"
84         if buf.String() != wantEnc {
85                 t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
86         }
87         f, err := fr.ReadFrame()
88         if err != nil {
89                 t.Fatal(err)
90         }
91         df, ok := f.(*DataFrame)
92         if !ok {
93                 t.Fatalf("got %T; want *DataFrame", f)
94         }
95         if !bytes.Equal(df.Data(), data) {
96                 t.Errorf("got %q; want %q", df.Data(), data)
97         }
98         if f.Header().Flags&1 == 0 {
99                 t.Errorf("didn't see END_STREAM flag")
100         }
101 }
102
103 func TestWriteDataPadded(t *testing.T) {
104         tests := [...]struct {
105                 streamID   uint32
106                 endStream  bool
107                 data       []byte
108                 pad        []byte
109                 wantHeader FrameHeader
110         }{
111                 // Unpadded:
112                 0: {
113                         streamID:  1,
114                         endStream: true,
115                         data:      []byte("foo"),
116                         pad:       nil,
117                         wantHeader: FrameHeader{
118                                 Type:     FrameData,
119                                 Flags:    FlagDataEndStream,
120                                 Length:   3,
121                                 StreamID: 1,
122                         },
123                 },
124
125                 // Padded bit set, but no padding:
126                 1: {
127                         streamID:  1,
128                         endStream: true,
129                         data:      []byte("foo"),
130                         pad:       []byte{},
131                         wantHeader: FrameHeader{
132                                 Type:     FrameData,
133                                 Flags:    FlagDataEndStream | FlagDataPadded,
134                                 Length:   4,
135                                 StreamID: 1,
136                         },
137                 },
138
139                 // Padded bit set, with padding:
140                 2: {
141                         streamID:  1,
142                         endStream: false,
143                         data:      []byte("foo"),
144                         pad:       []byte{0, 0, 0},
145                         wantHeader: FrameHeader{
146                                 Type:     FrameData,
147                                 Flags:    FlagDataPadded,
148                                 Length:   7,
149                                 StreamID: 1,
150                         },
151                 },
152         }
153         for i, tt := range tests {
154                 fr, _ := testFramer()
155                 fr.WriteDataPadded(tt.streamID, tt.endStream, tt.data, tt.pad)
156                 f, err := fr.ReadFrame()
157                 if err != nil {
158                         t.Errorf("%d. ReadFrame: %v", i, err)
159                         continue
160                 }
161                 got := f.Header()
162                 tt.wantHeader.valid = true
163                 if got != tt.wantHeader {
164                         t.Errorf("%d. read %+v; want %+v", i, got, tt.wantHeader)
165                         continue
166                 }
167                 df := f.(*DataFrame)
168                 if !bytes.Equal(df.Data(), tt.data) {
169                         t.Errorf("%d. got %q; want %q", i, df.Data(), tt.data)
170                 }
171         }
172 }
173
174 func TestWriteHeaders(t *testing.T) {
175         tests := []struct {
176                 name      string
177                 p         HeadersFrameParam
178                 wantEnc   string
179                 wantFrame *HeadersFrame
180         }{
181                 {
182                         "basic",
183                         HeadersFrameParam{
184                                 StreamID:      42,
185                                 BlockFragment: []byte("abc"),
186                                 Priority:      PriorityParam{},
187                         },
188                         "\x00\x00\x03\x01\x00\x00\x00\x00*abc",
189                         &HeadersFrame{
190                                 FrameHeader: FrameHeader{
191                                         valid:    true,
192                                         StreamID: 42,
193                                         Type:     FrameHeaders,
194                                         Length:   uint32(len("abc")),
195                                 },
196                                 Priority:      PriorityParam{},
197                                 headerFragBuf: []byte("abc"),
198                         },
199                 },
200                 {
201                         "basic + end flags",
202                         HeadersFrameParam{
203                                 StreamID:      42,
204                                 BlockFragment: []byte("abc"),
205                                 EndStream:     true,
206                                 EndHeaders:    true,
207                                 Priority:      PriorityParam{},
208                         },
209                         "\x00\x00\x03\x01\x05\x00\x00\x00*abc",
210                         &HeadersFrame{
211                                 FrameHeader: FrameHeader{
212                                         valid:    true,
213                                         StreamID: 42,
214                                         Type:     FrameHeaders,
215                                         Flags:    FlagHeadersEndStream | FlagHeadersEndHeaders,
216                                         Length:   uint32(len("abc")),
217                                 },
218                                 Priority:      PriorityParam{},
219                                 headerFragBuf: []byte("abc"),
220                         },
221                 },
222                 {
223                         "with padding",
224                         HeadersFrameParam{
225                                 StreamID:      42,
226                                 BlockFragment: []byte("abc"),
227                                 EndStream:     true,
228                                 EndHeaders:    true,
229                                 PadLength:     5,
230                                 Priority:      PriorityParam{},
231                         },
232                         "\x00\x00\t\x01\r\x00\x00\x00*\x05abc\x00\x00\x00\x00\x00",
233                         &HeadersFrame{
234                                 FrameHeader: FrameHeader{
235                                         valid:    true,
236                                         StreamID: 42,
237                                         Type:     FrameHeaders,
238                                         Flags:    FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded,
239                                         Length:   uint32(1 + len("abc") + 5), // pad length + contents + padding
240                                 },
241                                 Priority:      PriorityParam{},
242                                 headerFragBuf: []byte("abc"),
243                         },
244                 },
245                 {
246                         "with priority",
247                         HeadersFrameParam{
248                                 StreamID:      42,
249                                 BlockFragment: []byte("abc"),
250                                 EndStream:     true,
251                                 EndHeaders:    true,
252                                 PadLength:     2,
253                                 Priority: PriorityParam{
254                                         StreamDep: 15,
255                                         Exclusive: true,
256                                         Weight:    127,
257                                 },
258                         },
259                         "\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x0f\u007fabc\x00\x00",
260                         &HeadersFrame{
261                                 FrameHeader: FrameHeader{
262                                         valid:    true,
263                                         StreamID: 42,
264                                         Type:     FrameHeaders,
265                                         Flags:    FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority,
266                                         Length:   uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding
267                                 },
268                                 Priority: PriorityParam{
269                                         StreamDep: 15,
270                                         Exclusive: true,
271                                         Weight:    127,
272                                 },
273                                 headerFragBuf: []byte("abc"),
274                         },
275                 },
276                 {
277                         "with priority stream dep zero", // golang.org/issue/15444
278                         HeadersFrameParam{
279                                 StreamID:      42,
280                                 BlockFragment: []byte("abc"),
281                                 EndStream:     true,
282                                 EndHeaders:    true,
283                                 PadLength:     2,
284                                 Priority: PriorityParam{
285                                         StreamDep: 0,
286                                         Exclusive: true,
287                                         Weight:    127,
288                                 },
289                         },
290                         "\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x00\u007fabc\x00\x00",
291                         &HeadersFrame{
292                                 FrameHeader: FrameHeader{
293                                         valid:    true,
294                                         StreamID: 42,
295                                         Type:     FrameHeaders,
296                                         Flags:    FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority,
297                                         Length:   uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding
298                                 },
299                                 Priority: PriorityParam{
300                                         StreamDep: 0,
301                                         Exclusive: true,
302                                         Weight:    127,
303                                 },
304                                 headerFragBuf: []byte("abc"),
305                         },
306                 },
307         }
308         for _, tt := range tests {
309                 fr, buf := testFramer()
310                 if err := fr.WriteHeaders(tt.p); err != nil {
311                         t.Errorf("test %q: %v", tt.name, err)
312                         continue
313                 }
314                 if buf.String() != tt.wantEnc {
315                         t.Errorf("test %q: encoded %q; want %q", tt.name, buf.Bytes(), tt.wantEnc)
316                 }
317                 f, err := fr.ReadFrame()
318                 if err != nil {
319                         t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
320                         continue
321                 }
322                 if !reflect.DeepEqual(f, tt.wantFrame) {
323                         t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
324                 }
325         }
326 }
327
328 func TestWriteInvalidStreamDep(t *testing.T) {
329         fr, _ := testFramer()
330         err := fr.WriteHeaders(HeadersFrameParam{
331                 StreamID: 42,
332                 Priority: PriorityParam{
333                         StreamDep: 1 << 31,
334                 },
335         })
336         if err != errDepStreamID {
337                 t.Errorf("header error = %v; want %q", err, errDepStreamID)
338         }
339
340         err = fr.WritePriority(2, PriorityParam{StreamDep: 1 << 31})
341         if err != errDepStreamID {
342                 t.Errorf("priority error = %v; want %q", err, errDepStreamID)
343         }
344 }
345
346 func TestWriteContinuation(t *testing.T) {
347         const streamID = 42
348         tests := []struct {
349                 name string
350                 end  bool
351                 frag []byte
352
353                 wantFrame *ContinuationFrame
354         }{
355                 {
356                         "not end",
357                         false,
358                         []byte("abc"),
359                         &ContinuationFrame{
360                                 FrameHeader: FrameHeader{
361                                         valid:    true,
362                                         StreamID: streamID,
363                                         Type:     FrameContinuation,
364                                         Length:   uint32(len("abc")),
365                                 },
366                                 headerFragBuf: []byte("abc"),
367                         },
368                 },
369                 {
370                         "end",
371                         true,
372                         []byte("def"),
373                         &ContinuationFrame{
374                                 FrameHeader: FrameHeader{
375                                         valid:    true,
376                                         StreamID: streamID,
377                                         Type:     FrameContinuation,
378                                         Flags:    FlagContinuationEndHeaders,
379                                         Length:   uint32(len("def")),
380                                 },
381                                 headerFragBuf: []byte("def"),
382                         },
383                 },
384         }
385         for _, tt := range tests {
386                 fr, _ := testFramer()
387                 if err := fr.WriteContinuation(streamID, tt.end, tt.frag); err != nil {
388                         t.Errorf("test %q: %v", tt.name, err)
389                         continue
390                 }
391                 fr.AllowIllegalReads = true
392                 f, err := fr.ReadFrame()
393                 if err != nil {
394                         t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
395                         continue
396                 }
397                 if !reflect.DeepEqual(f, tt.wantFrame) {
398                         t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
399                 }
400         }
401 }
402
403 func TestWritePriority(t *testing.T) {
404         const streamID = 42
405         tests := []struct {
406                 name      string
407                 priority  PriorityParam
408                 wantFrame *PriorityFrame
409         }{
410                 {
411                         "not exclusive",
412                         PriorityParam{
413                                 StreamDep: 2,
414                                 Exclusive: false,
415                                 Weight:    127,
416                         },
417                         &PriorityFrame{
418                                 FrameHeader{
419                                         valid:    true,
420                                         StreamID: streamID,
421                                         Type:     FramePriority,
422                                         Length:   5,
423                                 },
424                                 PriorityParam{
425                                         StreamDep: 2,
426                                         Exclusive: false,
427                                         Weight:    127,
428                                 },
429                         },
430                 },
431
432                 {
433                         "exclusive",
434                         PriorityParam{
435                                 StreamDep: 3,
436                                 Exclusive: true,
437                                 Weight:    77,
438                         },
439                         &PriorityFrame{
440                                 FrameHeader{
441                                         valid:    true,
442                                         StreamID: streamID,
443                                         Type:     FramePriority,
444                                         Length:   5,
445                                 },
446                                 PriorityParam{
447                                         StreamDep: 3,
448                                         Exclusive: true,
449                                         Weight:    77,
450                                 },
451                         },
452                 },
453         }
454         for _, tt := range tests {
455                 fr, _ := testFramer()
456                 if err := fr.WritePriority(streamID, tt.priority); err != nil {
457                         t.Errorf("test %q: %v", tt.name, err)
458                         continue
459                 }
460                 f, err := fr.ReadFrame()
461                 if err != nil {
462                         t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
463                         continue
464                 }
465                 if !reflect.DeepEqual(f, tt.wantFrame) {
466                         t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
467                 }
468         }
469 }
470
471 func TestWriteSettings(t *testing.T) {
472         fr, buf := testFramer()
473         settings := []Setting{{1, 2}, {3, 4}}
474         fr.WriteSettings(settings...)
475         const wantEnc = "\x00\x00\f\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x03\x00\x00\x00\x04"
476         if buf.String() != wantEnc {
477                 t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
478         }
479         f, err := fr.ReadFrame()
480         if err != nil {
481                 t.Fatal(err)
482         }
483         sf, ok := f.(*SettingsFrame)
484         if !ok {
485                 t.Fatalf("Got a %T; want a SettingsFrame", f)
486         }
487         var got []Setting
488         sf.ForeachSetting(func(s Setting) error {
489                 got = append(got, s)
490                 valBack, ok := sf.Value(s.ID)
491                 if !ok || valBack != s.Val {
492                         t.Errorf("Value(%d) = %v, %v; want %v, true", s.ID, valBack, ok, s.Val)
493                 }
494                 return nil
495         })
496         if !reflect.DeepEqual(settings, got) {
497                 t.Errorf("Read settings %+v != written settings %+v", got, settings)
498         }
499 }
500
501 func TestWriteSettingsAck(t *testing.T) {
502         fr, buf := testFramer()
503         fr.WriteSettingsAck()
504         const wantEnc = "\x00\x00\x00\x04\x01\x00\x00\x00\x00"
505         if buf.String() != wantEnc {
506                 t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
507         }
508 }
509
510 func TestWriteWindowUpdate(t *testing.T) {
511         fr, buf := testFramer()
512         const streamID = 1<<24 + 2<<16 + 3<<8 + 4
513         const incr = 7<<24 + 6<<16 + 5<<8 + 4
514         if err := fr.WriteWindowUpdate(streamID, incr); err != nil {
515                 t.Fatal(err)
516         }
517         const wantEnc = "\x00\x00\x04\x08\x00\x01\x02\x03\x04\x07\x06\x05\x04"
518         if buf.String() != wantEnc {
519                 t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
520         }
521         f, err := fr.ReadFrame()
522         if err != nil {
523                 t.Fatal(err)
524         }
525         want := &WindowUpdateFrame{
526                 FrameHeader: FrameHeader{
527                         valid:    true,
528                         Type:     0x8,
529                         Flags:    0x0,
530                         Length:   0x4,
531                         StreamID: 0x1020304,
532                 },
533                 Increment: 0x7060504,
534         }
535         if !reflect.DeepEqual(f, want) {
536                 t.Errorf("parsed back %#v; want %#v", f, want)
537         }
538 }
539
540 func TestWritePing(t *testing.T)    { testWritePing(t, false) }
541 func TestWritePingAck(t *testing.T) { testWritePing(t, true) }
542
543 func testWritePing(t *testing.T, ack bool) {
544         fr, buf := testFramer()
545         if err := fr.WritePing(ack, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
546                 t.Fatal(err)
547         }
548         var wantFlags Flags
549         if ack {
550                 wantFlags = FlagPingAck
551         }
552         var wantEnc = "\x00\x00\x08\x06" + string(wantFlags) + "\x00\x00\x00\x00" + "\x01\x02\x03\x04\x05\x06\x07\x08"
553         if buf.String() != wantEnc {
554                 t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
555         }
556
557         f, err := fr.ReadFrame()
558         if err != nil {
559                 t.Fatal(err)
560         }
561         want := &PingFrame{
562                 FrameHeader: FrameHeader{
563                         valid:    true,
564                         Type:     0x6,
565                         Flags:    wantFlags,
566                         Length:   0x8,
567                         StreamID: 0,
568                 },
569                 Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8},
570         }
571         if !reflect.DeepEqual(f, want) {
572                 t.Errorf("parsed back %#v; want %#v", f, want)
573         }
574 }
575
576 func TestReadFrameHeader(t *testing.T) {
577         tests := []struct {
578                 in   string
579                 want FrameHeader
580         }{
581                 {in: "\x00\x00\x00" + "\x00" + "\x00" + "\x00\x00\x00\x00", want: FrameHeader{}},
582                 {in: "\x01\x02\x03" + "\x04" + "\x05" + "\x06\x07\x08\x09", want: FrameHeader{
583                         Length: 66051, Type: 4, Flags: 5, StreamID: 101124105,
584                 }},
585                 // Ignore high bit:
586                 {in: "\xff\xff\xff" + "\xff" + "\xff" + "\xff\xff\xff\xff", want: FrameHeader{
587                         Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}},
588                 {in: "\xff\xff\xff" + "\xff" + "\xff" + "\x7f\xff\xff\xff", want: FrameHeader{
589                         Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}},
590         }
591         for i, tt := range tests {
592                 got, err := readFrameHeader(make([]byte, 9), strings.NewReader(tt.in))
593                 if err != nil {
594                         t.Errorf("%d. readFrameHeader(%q) = %v", i, tt.in, err)
595                         continue
596                 }
597                 tt.want.valid = true
598                 if got != tt.want {
599                         t.Errorf("%d. readFrameHeader(%q) = %+v; want %+v", i, tt.in, got, tt.want)
600                 }
601         }
602 }
603
604 func TestReadWriteFrameHeader(t *testing.T) {
605         tests := []struct {
606                 len      uint32
607                 typ      FrameType
608                 flags    Flags
609                 streamID uint32
610         }{
611                 {len: 0, typ: 255, flags: 1, streamID: 0},
612                 {len: 0, typ: 255, flags: 1, streamID: 1},
613                 {len: 0, typ: 255, flags: 1, streamID: 255},
614                 {len: 0, typ: 255, flags: 1, streamID: 256},
615                 {len: 0, typ: 255, flags: 1, streamID: 65535},
616                 {len: 0, typ: 255, flags: 1, streamID: 65536},
617
618                 {len: 0, typ: 1, flags: 255, streamID: 1},
619                 {len: 255, typ: 1, flags: 255, streamID: 1},
620                 {len: 256, typ: 1, flags: 255, streamID: 1},
621                 {len: 65535, typ: 1, flags: 255, streamID: 1},
622                 {len: 65536, typ: 1, flags: 255, streamID: 1},
623                 {len: 16777215, typ: 1, flags: 255, streamID: 1},
624         }
625         for _, tt := range tests {
626                 fr, buf := testFramer()
627                 fr.startWrite(tt.typ, tt.flags, tt.streamID)
628                 fr.writeBytes(make([]byte, tt.len))
629                 fr.endWrite()
630                 fh, err := ReadFrameHeader(buf)
631                 if err != nil {
632                         t.Errorf("ReadFrameHeader(%+v) = %v", tt, err)
633                         continue
634                 }
635                 if fh.Type != tt.typ || fh.Flags != tt.flags || fh.Length != tt.len || fh.StreamID != tt.streamID {
636                         t.Errorf("ReadFrameHeader(%+v) = %+v; mismatch", tt, fh)
637                 }
638         }
639
640 }
641
642 func TestWriteTooLargeFrame(t *testing.T) {
643         fr, _ := testFramer()
644         fr.startWrite(0, 1, 1)
645         fr.writeBytes(make([]byte, 1<<24))
646         err := fr.endWrite()
647         if err != ErrFrameTooLarge {
648                 t.Errorf("endWrite = %v; want errFrameTooLarge", err)
649         }
650 }
651
652 func TestWriteGoAway(t *testing.T) {
653         const debug = "foo"
654         fr, buf := testFramer()
655         if err := fr.WriteGoAway(0x01020304, 0x05060708, []byte(debug)); err != nil {
656                 t.Fatal(err)
657         }
658         const wantEnc = "\x00\x00\v\a\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08" + debug
659         if buf.String() != wantEnc {
660                 t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
661         }
662         f, err := fr.ReadFrame()
663         if err != nil {
664                 t.Fatal(err)
665         }
666         want := &GoAwayFrame{
667                 FrameHeader: FrameHeader{
668                         valid:    true,
669                         Type:     0x7,
670                         Flags:    0,
671                         Length:   uint32(4 + 4 + len(debug)),
672                         StreamID: 0,
673                 },
674                 LastStreamID: 0x01020304,
675                 ErrCode:      0x05060708,
676                 debugData:    []byte(debug),
677         }
678         if !reflect.DeepEqual(f, want) {
679                 t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
680         }
681         if got := string(f.(*GoAwayFrame).DebugData()); got != debug {
682                 t.Errorf("debug data = %q; want %q", got, debug)
683         }
684 }
685
686 func TestWritePushPromise(t *testing.T) {
687         pp := PushPromiseParam{
688                 StreamID:      42,
689                 PromiseID:     42,
690                 BlockFragment: []byte("abc"),
691         }
692         fr, buf := testFramer()
693         if err := fr.WritePushPromise(pp); err != nil {
694                 t.Fatal(err)
695         }
696         const wantEnc = "\x00\x00\x07\x05\x00\x00\x00\x00*\x00\x00\x00*abc"
697         if buf.String() != wantEnc {
698                 t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
699         }
700         f, err := fr.ReadFrame()
701         if err != nil {
702                 t.Fatal(err)
703         }
704         _, ok := f.(*PushPromiseFrame)
705         if !ok {
706                 t.Fatalf("got %T; want *PushPromiseFrame", f)
707         }
708         want := &PushPromiseFrame{
709                 FrameHeader: FrameHeader{
710                         valid:    true,
711                         Type:     0x5,
712                         Flags:    0x0,
713                         Length:   0x7,
714                         StreamID: 42,
715                 },
716                 PromiseID:     42,
717                 headerFragBuf: []byte("abc"),
718         }
719         if !reflect.DeepEqual(f, want) {
720                 t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
721         }
722 }
723
724 // test checkFrameOrder and that HEADERS and CONTINUATION frames can't be intermingled.
725 func TestReadFrameOrder(t *testing.T) {
726         head := func(f *Framer, id uint32, end bool) {
727                 f.WriteHeaders(HeadersFrameParam{
728                         StreamID:      id,
729                         BlockFragment: []byte("foo"), // unused, but non-empty
730                         EndHeaders:    end,
731                 })
732         }
733         cont := func(f *Framer, id uint32, end bool) {
734                 f.WriteContinuation(id, end, []byte("foo"))
735         }
736
737         tests := [...]struct {
738                 name    string
739                 w       func(*Framer)
740                 atLeast int
741                 wantErr string
742         }{
743                 0: {
744                         w: func(f *Framer) {
745                                 head(f, 1, true)
746                         },
747                 },
748                 1: {
749                         w: func(f *Framer) {
750                                 head(f, 1, true)
751                                 head(f, 2, true)
752                         },
753                 },
754                 2: {
755                         wantErr: "got HEADERS for stream 2; expected CONTINUATION following HEADERS for stream 1",
756                         w: func(f *Framer) {
757                                 head(f, 1, false)
758                                 head(f, 2, true)
759                         },
760                 },
761                 3: {
762                         wantErr: "got DATA for stream 1; expected CONTINUATION following HEADERS for stream 1",
763                         w: func(f *Framer) {
764                                 head(f, 1, false)
765                         },
766                 },
767                 4: {
768                         w: func(f *Framer) {
769                                 head(f, 1, false)
770                                 cont(f, 1, true)
771                                 head(f, 2, true)
772                         },
773                 },
774                 5: {
775                         wantErr: "got CONTINUATION for stream 2; expected stream 1",
776                         w: func(f *Framer) {
777                                 head(f, 1, false)
778                                 cont(f, 2, true)
779                                 head(f, 2, true)
780                         },
781                 },
782                 6: {
783                         wantErr: "unexpected CONTINUATION for stream 1",
784                         w: func(f *Framer) {
785                                 cont(f, 1, true)
786                         },
787                 },
788                 7: {
789                         wantErr: "unexpected CONTINUATION for stream 1",
790                         w: func(f *Framer) {
791                                 cont(f, 1, false)
792                         },
793                 },
794                 8: {
795                         wantErr: "HEADERS frame with stream ID 0",
796                         w: func(f *Framer) {
797                                 head(f, 0, true)
798                         },
799                 },
800                 9: {
801                         wantErr: "CONTINUATION frame with stream ID 0",
802                         w: func(f *Framer) {
803                                 cont(f, 0, true)
804                         },
805                 },
806                 10: {
807                         wantErr: "unexpected CONTINUATION for stream 1",
808                         atLeast: 5,
809                         w: func(f *Framer) {
810                                 head(f, 1, false)
811                                 cont(f, 1, false)
812                                 cont(f, 1, false)
813                                 cont(f, 1, false)
814                                 cont(f, 1, true)
815                                 cont(f, 1, false)
816                         },
817                 },
818         }
819         for i, tt := range tests {
820                 buf := new(bytes.Buffer)
821                 f := NewFramer(buf, buf)
822                 f.AllowIllegalWrites = true
823                 tt.w(f)
824                 f.WriteData(1, true, nil) // to test transition away from last step
825
826                 var err error
827                 n := 0
828                 var log bytes.Buffer
829                 for {
830                         var got Frame
831                         got, err = f.ReadFrame()
832                         fmt.Fprintf(&log, "  read %v, %v\n", got, err)
833                         if err != nil {
834                                 break
835                         }
836                         n++
837                 }
838                 if err == io.EOF {
839                         err = nil
840                 }
841                 ok := tt.wantErr == ""
842                 if ok && err != nil {
843                         t.Errorf("%d. after %d good frames, ReadFrame = %v; want success\n%s", i, n, err, log.Bytes())
844                         continue
845                 }
846                 if !ok && err != ConnectionError(ErrCodeProtocol) {
847                         t.Errorf("%d. after %d good frames, ReadFrame = %v; want ConnectionError(ErrCodeProtocol)\n%s", i, n, err, log.Bytes())
848                         continue
849                 }
850                 if !((f.errDetail == nil && tt.wantErr == "") || (fmt.Sprint(f.errDetail) == tt.wantErr)) {
851                         t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errDetail, tt.wantErr, log.Bytes())
852                 }
853                 if n < tt.atLeast {
854                         t.Errorf("%d. framer only read %d frames; want at least %d\n%s", i, n, tt.atLeast, log.Bytes())
855                 }
856         }
857 }
858
859 func TestMetaFrameHeader(t *testing.T) {
860         write := func(f *Framer, frags ...[]byte) {
861                 for i, frag := range frags {
862                         end := (i == len(frags)-1)
863                         if i == 0 {
864                                 f.WriteHeaders(HeadersFrameParam{
865                                         StreamID:      1,
866                                         BlockFragment: frag,
867                                         EndHeaders:    end,
868                                 })
869                         } else {
870                                 f.WriteContinuation(1, end, frag)
871                         }
872                 }
873         }
874
875         want := func(flags Flags, length uint32, pairs ...string) *MetaHeadersFrame {
876                 mh := &MetaHeadersFrame{
877                         HeadersFrame: &HeadersFrame{
878                                 FrameHeader: FrameHeader{
879                                         Type:     FrameHeaders,
880                                         Flags:    flags,
881                                         Length:   length,
882                                         StreamID: 1,
883                                 },
884                         },
885                         Fields: []hpack.HeaderField(nil),
886                 }
887                 for len(pairs) > 0 {
888                         mh.Fields = append(mh.Fields, hpack.HeaderField{
889                                 Name:  pairs[0],
890                                 Value: pairs[1],
891                         })
892                         pairs = pairs[2:]
893                 }
894                 return mh
895         }
896         truncated := func(mh *MetaHeadersFrame) *MetaHeadersFrame {
897                 mh.Truncated = true
898                 return mh
899         }
900
901         const noFlags Flags = 0
902
903         oneKBString := strings.Repeat("a", 1<<10)
904
905         tests := [...]struct {
906                 name              string
907                 w                 func(*Framer)
908                 want              interface{} // *MetaHeaderFrame or error
909                 wantErrReason     string
910                 maxHeaderListSize uint32
911         }{
912                 0: {
913                         name: "single_headers",
914                         w: func(f *Framer) {
915                                 var he hpackEncoder
916                                 all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/")
917                                 write(f, all)
918                         },
919                         want: want(FlagHeadersEndHeaders, 2, ":method", "GET", ":path", "/"),
920                 },
921                 1: {
922                         name: "with_continuation",
923                         w: func(f *Framer) {
924                                 var he hpackEncoder
925                                 all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar")
926                                 write(f, all[:1], all[1:])
927                         },
928                         want: want(noFlags, 1, ":method", "GET", ":path", "/", "foo", "bar"),
929                 },
930                 2: {
931                         name: "with_two_continuation",
932                         w: func(f *Framer) {
933                                 var he hpackEncoder
934                                 all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", "bar")
935                                 write(f, all[:2], all[2:4], all[4:])
936                         },
937                         want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", "bar"),
938                 },
939                 3: {
940                         name: "big_string_okay",
941                         w: func(f *Framer) {
942                                 var he hpackEncoder
943                                 all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString)
944                                 write(f, all[:2], all[2:])
945                         },
946                         want: want(noFlags, 2, ":method", "GET", ":path", "/", "foo", oneKBString),
947                 },
948                 4: {
949                         name: "big_string_error",
950                         w: func(f *Framer) {
951                                 var he hpackEncoder
952                                 all := he.encodeHeaderRaw(t, ":method", "GET", ":path", "/", "foo", oneKBString)
953                                 write(f, all[:2], all[2:])
954                         },
955                         maxHeaderListSize: (1 << 10) / 2,
956                         want:              ConnectionError(ErrCodeCompression),
957                 },
958                 5: {
959                         name: "max_header_list_truncated",
960                         w: func(f *Framer) {
961                                 var he hpackEncoder
962                                 var pairs = []string{":method", "GET", ":path", "/"}
963                                 for i := 0; i < 100; i++ {
964                                         pairs = append(pairs, "foo", "bar")
965                                 }
966                                 all := he.encodeHeaderRaw(t, pairs...)
967                                 write(f, all[:2], all[2:])
968                         },
969                         maxHeaderListSize: (1 << 10) / 2,
970                         want: truncated(want(noFlags, 2,
971                                 ":method", "GET",
972                                 ":path", "/",
973                                 "foo", "bar",
974                                 "foo", "bar",
975                                 "foo", "bar",
976                                 "foo", "bar",
977                                 "foo", "bar",
978                                 "foo", "bar",
979                                 "foo", "bar",
980                                 "foo", "bar",
981                                 "foo", "bar",
982                                 "foo", "bar",
983                                 "foo", "bar", // 11
984                         )),
985                 },
986                 6: {
987                         name: "pseudo_order",
988                         w: func(f *Framer) {
989                                 write(f, encodeHeaderRaw(t,
990                                         ":method", "GET",
991                                         "foo", "bar",
992                                         ":path", "/", // bogus
993                                 ))
994                         },
995                         want:          streamError(1, ErrCodeProtocol),
996                         wantErrReason: "pseudo header field after regular",
997                 },
998                 7: {
999                         name: "pseudo_unknown",
1000                         w: func(f *Framer) {
1001                                 write(f, encodeHeaderRaw(t,
1002                                         ":unknown", "foo", // bogus
1003                                         "foo", "bar",
1004                                 ))
1005                         },
1006                         want:          streamError(1, ErrCodeProtocol),
1007                         wantErrReason: "invalid pseudo-header \":unknown\"",
1008                 },
1009                 8: {
1010                         name: "pseudo_mix_request_response",
1011                         w: func(f *Framer) {
1012                                 write(f, encodeHeaderRaw(t,
1013                                         ":method", "GET",
1014                                         ":status", "100",
1015                                 ))
1016                         },
1017                         want:          streamError(1, ErrCodeProtocol),
1018                         wantErrReason: "mix of request and response pseudo headers",
1019                 },
1020                 9: {
1021                         name: "pseudo_dup",
1022                         w: func(f *Framer) {
1023                                 write(f, encodeHeaderRaw(t,
1024                                         ":method", "GET",
1025                                         ":method", "POST",
1026                                 ))
1027                         },
1028                         want:          streamError(1, ErrCodeProtocol),
1029                         wantErrReason: "duplicate pseudo-header \":method\"",
1030                 },
1031                 10: {
1032                         name: "trailer_okay_no_pseudo",
1033                         w:    func(f *Framer) { write(f, encodeHeaderRaw(t, "foo", "bar")) },
1034                         want: want(FlagHeadersEndHeaders, 8, "foo", "bar"),
1035                 },
1036                 11: {
1037                         name:          "invalid_field_name",
1038                         w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) },
1039                         want:          streamError(1, ErrCodeProtocol),
1040                         wantErrReason: "invalid header field name \"CapitalBad\"",
1041                 },
1042                 12: {
1043                         name:          "invalid_field_value",
1044                         w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) },
1045                         want:          streamError(1, ErrCodeProtocol),
1046                         wantErrReason: "invalid header field value \"bad_null\\x00\"",
1047                 },
1048         }
1049         for i, tt := range tests {
1050                 buf := new(bytes.Buffer)
1051                 f := NewFramer(buf, buf)
1052                 f.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
1053                 f.MaxHeaderListSize = tt.maxHeaderListSize
1054                 tt.w(f)
1055
1056                 name := tt.name
1057                 if name == "" {
1058                         name = fmt.Sprintf("test index %d", i)
1059                 }
1060
1061                 var got interface{}
1062                 var err error
1063                 got, err = f.ReadFrame()
1064                 if err != nil {
1065                         got = err
1066
1067                         // Ignore the StreamError.Cause field, if it matches the wantErrReason.
1068                         // The test table above predates the Cause field.
1069                         if se, ok := err.(StreamError); ok && se.Cause != nil && se.Cause.Error() == tt.wantErrReason {
1070                                 se.Cause = nil
1071                                 got = se
1072                         }
1073                 }
1074                 if !reflect.DeepEqual(got, tt.want) {
1075                         if mhg, ok := got.(*MetaHeadersFrame); ok {
1076                                 if mhw, ok := tt.want.(*MetaHeadersFrame); ok {
1077                                         hg := mhg.HeadersFrame
1078                                         hw := mhw.HeadersFrame
1079                                         if hg != nil && hw != nil && !reflect.DeepEqual(*hg, *hw) {
1080                                                 t.Errorf("%s: headers differ:\n got: %+v\nwant: %+v\n", name, *hg, *hw)
1081                                         }
1082                                 }
1083                         }
1084                         str := func(v interface{}) string {
1085                                 if _, ok := v.(error); ok {
1086                                         return fmt.Sprintf("error %v", v)
1087                                 } else {
1088                                         return fmt.Sprintf("value %#v", v)
1089                                 }
1090                         }
1091                         t.Errorf("%s:\n got: %v\nwant: %s", name, str(got), str(tt.want))
1092                 }
1093                 if tt.wantErrReason != "" && tt.wantErrReason != fmt.Sprint(f.errDetail) {
1094                         t.Errorf("%s: got error reason %q; want %q", name, f.errDetail, tt.wantErrReason)
1095                 }
1096         }
1097 }
1098
1099 func TestSetReuseFrames(t *testing.T) {
1100         fr, buf := testFramer()
1101         fr.SetReuseFrames()
1102
1103         // Check that DataFrames are reused. Note that
1104         // SetReuseFrames only currently implements reuse of DataFrames.
1105         firstDf := readAndVerifyDataFrame("ABC", 3, fr, buf, t)
1106
1107         for i := 0; i < 10; i++ {
1108                 df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t)
1109                 if df != firstDf {
1110                         t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf)
1111                 }
1112         }
1113
1114         for i := 0; i < 10; i++ {
1115                 df := readAndVerifyDataFrame("", 0, fr, buf, t)
1116                 if df != firstDf {
1117                         t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf)
1118                 }
1119         }
1120
1121         for i := 0; i < 10; i++ {
1122                 df := readAndVerifyDataFrame("HHH", 3, fr, buf, t)
1123                 if df != firstDf {
1124                         t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf)
1125                 }
1126         }
1127 }
1128
1129 func TestSetReuseFramesMoreThanOnce(t *testing.T) {
1130         fr, buf := testFramer()
1131         fr.SetReuseFrames()
1132
1133         firstDf := readAndVerifyDataFrame("ABC", 3, fr, buf, t)
1134         fr.SetReuseFrames()
1135
1136         for i := 0; i < 10; i++ {
1137                 df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t)
1138                 // SetReuseFrames should be idempotent
1139                 fr.SetReuseFrames()
1140                 if df != firstDf {
1141                         t.Errorf("Expected Framer to return references to the same DataFrame. Have %v and %v", &df, &firstDf)
1142                 }
1143         }
1144 }
1145
1146 func TestNoSetReuseFrames(t *testing.T) {
1147         fr, buf := testFramer()
1148         const numNewDataFrames = 10
1149         dfSoFar := make([]interface{}, numNewDataFrames)
1150
1151         // Check that DataFrames are not reused if SetReuseFrames wasn't called.
1152         // SetReuseFrames only currently implements reuse of DataFrames.
1153         for i := 0; i < numNewDataFrames; i++ {
1154                 df := readAndVerifyDataFrame("XYZ", 3, fr, buf, t)
1155                 for _, item := range dfSoFar {
1156                         if df == item {
1157                                 t.Errorf("Expected Framer to return new DataFrames since SetNoReuseFrames not set.")
1158                         }
1159                 }
1160                 dfSoFar[i] = df
1161         }
1162 }
1163
1164 func readAndVerifyDataFrame(data string, length byte, fr *Framer, buf *bytes.Buffer, t *testing.T) *DataFrame {
1165         var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
1166         fr.WriteData(streamID, true, []byte(data))
1167         wantEnc := "\x00\x00" + string(length) + "\x00\x01\x01\x02\x03\x04" + data
1168         if buf.String() != wantEnc {
1169                 t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
1170         }
1171         f, err := fr.ReadFrame()
1172         if err != nil {
1173                 t.Fatal(err)
1174         }
1175         df, ok := f.(*DataFrame)
1176         if !ok {
1177                 t.Fatalf("got %T; want *DataFrame", f)
1178         }
1179         if !bytes.Equal(df.Data(), []byte(data)) {
1180                 t.Errorf("got %q; want %q", df.Data(), []byte(data))
1181         }
1182         if f.Header().Flags&1 == 0 {
1183                 t.Errorf("didn't see END_STREAM flag")
1184         }
1185         return df
1186 }
1187
1188 func encodeHeaderRaw(t *testing.T, pairs ...string) []byte {
1189         var he hpackEncoder
1190         return he.encodeHeaderRaw(t, pairs...)
1191 }