OSDN Git Service

new repo
[bytom/vapor.git] / vendor / github.com / gorilla / websocket / conn_test.go
1 // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package websocket
6
7 import (
8         "bufio"
9         "bytes"
10         "errors"
11         "fmt"
12         "io"
13         "io/ioutil"
14         "net"
15         "reflect"
16         "sync"
17         "testing"
18         "testing/iotest"
19         "time"
20 )
21
22 var _ net.Error = errWriteTimeout
23
24 type fakeNetConn struct {
25         io.Reader
26         io.Writer
27 }
28
29 func (c fakeNetConn) Close() error                       { return nil }
30 func (c fakeNetConn) LocalAddr() net.Addr                { return localAddr }
31 func (c fakeNetConn) RemoteAddr() net.Addr               { return remoteAddr }
32 func (c fakeNetConn) SetDeadline(t time.Time) error      { return nil }
33 func (c fakeNetConn) SetReadDeadline(t time.Time) error  { return nil }
34 func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
35
36 type fakeAddr int
37
38 var (
39         localAddr  = fakeAddr(1)
40         remoteAddr = fakeAddr(2)
41 )
42
43 func (a fakeAddr) Network() string {
44         return "net"
45 }
46
47 func (a fakeAddr) String() string {
48         return "str"
49 }
50
51 // newTestConn creates a connnection backed by a fake network connection using
52 // default values for buffering.
53 func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
54         return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
55 }
56
57 func TestFraming(t *testing.T) {
58         frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
59         var readChunkers = []struct {
60                 name string
61                 f    func(io.Reader) io.Reader
62         }{
63                 {"half", iotest.HalfReader},
64                 {"one", iotest.OneByteReader},
65                 {"asis", func(r io.Reader) io.Reader { return r }},
66         }
67         writeBuf := make([]byte, 65537)
68         for i := range writeBuf {
69                 writeBuf[i] = byte(i)
70         }
71         var writers = []struct {
72                 name string
73                 f    func(w io.Writer, n int) (int, error)
74         }{
75                 {"iocopy", func(w io.Writer, n int) (int, error) {
76                         nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
77                         return int(nn), err
78                 }},
79                 {"write", func(w io.Writer, n int) (int, error) {
80                         return w.Write(writeBuf[:n])
81                 }},
82                 {"string", func(w io.Writer, n int) (int, error) {
83                         return io.WriteString(w, string(writeBuf[:n]))
84                 }},
85         }
86
87         for _, compress := range []bool{false, true} {
88                 for _, isServer := range []bool{true, false} {
89                         for _, chunker := range readChunkers {
90
91                                 var connBuf bytes.Buffer
92                                 wc := newTestConn(nil, &connBuf, isServer)
93                                 rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
94                                 if compress {
95                                         wc.newCompressionWriter = compressNoContextTakeover
96                                         rc.newDecompressionReader = decompressNoContextTakeover
97                                 }
98                                 for _, n := range frameSizes {
99                                         for _, writer := range writers {
100                                                 name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
101
102                                                 w, err := wc.NextWriter(TextMessage)
103                                                 if err != nil {
104                                                         t.Errorf("%s: wc.NextWriter() returned %v", name, err)
105                                                         continue
106                                                 }
107                                                 nn, err := writer.f(w, n)
108                                                 if err != nil || nn != n {
109                                                         t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
110                                                         continue
111                                                 }
112                                                 err = w.Close()
113                                                 if err != nil {
114                                                         t.Errorf("%s: w.Close() returned %v", name, err)
115                                                         continue
116                                                 }
117
118                                                 opCode, r, err := rc.NextReader()
119                                                 if err != nil || opCode != TextMessage {
120                                                         t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
121                                                         continue
122                                                 }
123                                                 rbuf, err := ioutil.ReadAll(r)
124                                                 if err != nil {
125                                                         t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
126                                                         continue
127                                                 }
128
129                                                 if len(rbuf) != n {
130                                                         t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
131                                                         continue
132                                                 }
133
134                                                 for i, b := range rbuf {
135                                                         if byte(i) != b {
136                                                                 t.Errorf("%s: bad byte at offset %d", name, i)
137                                                                 break
138                                                         }
139                                                 }
140                                         }
141                                 }
142                         }
143                 }
144         }
145 }
146
147 func TestControl(t *testing.T) {
148         const message = "this is a ping/pong messsage"
149         for _, isServer := range []bool{true, false} {
150                 for _, isWriteControl := range []bool{true, false} {
151                         name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
152                         var connBuf bytes.Buffer
153                         wc := newTestConn(nil, &connBuf, isServer)
154                         rc := newTestConn(&connBuf, nil, !isServer)
155                         if isWriteControl {
156                                 wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
157                         } else {
158                                 w, err := wc.NextWriter(PongMessage)
159                                 if err != nil {
160                                         t.Errorf("%s: wc.NextWriter() returned %v", name, err)
161                                         continue
162                                 }
163                                 if _, err := w.Write([]byte(message)); err != nil {
164                                         t.Errorf("%s: w.Write() returned %v", name, err)
165                                         continue
166                                 }
167                                 if err := w.Close(); err != nil {
168                                         t.Errorf("%s: w.Close() returned %v", name, err)
169                                         continue
170                                 }
171                                 var actualMessage string
172                                 rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
173                                 rc.NextReader()
174                                 if actualMessage != message {
175                                         t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
176                                         continue
177                                 }
178                         }
179                 }
180         }
181 }
182
183 // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
184 type simpleBufferPool struct {
185         v interface{}
186 }
187
188 func (p *simpleBufferPool) Get() interface{} {
189         v := p.v
190         p.v = nil
191         return v
192 }
193
194 func (p *simpleBufferPool) Put(v interface{}) {
195         p.v = v
196 }
197
198 func TestWriteBufferPool(t *testing.T) {
199         const message = "Now is the time for all good people to come to the aid of the party."
200
201         var buf bytes.Buffer
202         var pool simpleBufferPool
203         rc := newTestConn(&buf, nil, false)
204
205         // Specify writeBufferSize smaller than message size to ensure that pooling
206         // works with fragmented messages.
207         wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
208
209         if wc.writeBuf != nil {
210                 t.Fatal("writeBuf not nil after create")
211         }
212
213         // Part 1: test NextWriter/Write/Close
214
215         w, err := wc.NextWriter(TextMessage)
216         if err != nil {
217                 t.Fatalf("wc.NextWriter() returned %v", err)
218         }
219
220         if wc.writeBuf == nil {
221                 t.Fatal("writeBuf is nil after NextWriter")
222         }
223
224         writeBufAddr := &wc.writeBuf[0]
225
226         if _, err := io.WriteString(w, message); err != nil {
227                 t.Fatalf("io.WriteString(w, message) returned %v", err)
228         }
229
230         if err := w.Close(); err != nil {
231                 t.Fatalf("w.Close() returned %v", err)
232         }
233
234         if wc.writeBuf != nil {
235                 t.Fatal("writeBuf not nil after w.Close()")
236         }
237
238         if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
239                 t.Fatal("writeBuf not returned to pool")
240         }
241
242         opCode, p, err := rc.ReadMessage()
243         if opCode != TextMessage || err != nil {
244                 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
245         }
246
247         if s := string(p); s != message {
248                 t.Fatalf("message is %s, want %s", s, message)
249         }
250
251         // Part 2: Test WriteMessage.
252
253         if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
254                 t.Fatalf("wc.WriteMessage() returned %v", err)
255         }
256
257         if wc.writeBuf != nil {
258                 t.Fatal("writeBuf not nil after wc.WriteMessage()")
259         }
260
261         if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
262                 t.Fatal("writeBuf not returned to pool after WriteMessage")
263         }
264
265         opCode, p, err = rc.ReadMessage()
266         if opCode != TextMessage || err != nil {
267                 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
268         }
269
270         if s := string(p); s != message {
271                 t.Fatalf("message is %s, want %s", s, message)
272         }
273 }
274
275 // TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
276 func TestWriteBufferPoolSync(t *testing.T) {
277         var buf bytes.Buffer
278         var pool sync.Pool
279         wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
280         rc := newTestConn(&buf, nil, false)
281
282         const message = "Hello World!"
283         for i := 0; i < 3; i++ {
284                 if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
285                         t.Fatalf("wc.WriteMessage() returned %v", err)
286                 }
287                 opCode, p, err := rc.ReadMessage()
288                 if opCode != TextMessage || err != nil {
289                         t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
290                 }
291                 if s := string(p); s != message {
292                         t.Fatalf("message is %s, want %s", s, message)
293                 }
294         }
295 }
296
297 // errorWriter is an io.Writer than returns an error on all writes.
298 type errorWriter struct{}
299
300 func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("Error!") }
301
302 // TestWriteBufferPoolError ensures that buffer is returned to pool after error
303 // on write.
304 func TestWriteBufferPoolError(t *testing.T) {
305
306         // Part 1: Test NextWriter/Write/Close
307
308         var pool simpleBufferPool
309         wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
310
311         w, err := wc.NextWriter(TextMessage)
312         if err != nil {
313                 t.Fatalf("wc.NextWriter() returned %v", err)
314         }
315
316         if wc.writeBuf == nil {
317                 t.Fatal("writeBuf is nil after NextWriter")
318         }
319
320         writeBufAddr := &wc.writeBuf[0]
321
322         if _, err := io.WriteString(w, "Hello"); err != nil {
323                 t.Fatalf("io.WriteString(w, message) returned %v", err)
324         }
325
326         if err := w.Close(); err == nil {
327                 t.Fatalf("w.Close() did not return error")
328         }
329
330         if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
331                 t.Fatal("writeBuf not returned to pool")
332         }
333
334         // Part 2: Test WriteMessage
335
336         wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
337
338         if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
339                 t.Fatalf("wc.WriteMessage did not return error")
340         }
341
342         if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
343                 t.Fatal("writeBuf not returned to pool")
344         }
345 }
346
347 func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
348         const bufSize = 512
349
350         expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
351
352         var b1, b2 bytes.Buffer
353         wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
354         rc := newTestConn(&b1, &b2, true)
355
356         w, _ := wc.NextWriter(BinaryMessage)
357         w.Write(make([]byte, bufSize+bufSize/2))
358         wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
359         w.Close()
360
361         op, r, err := rc.NextReader()
362         if op != BinaryMessage || err != nil {
363                 t.Fatalf("NextReader() returned %d, %v", op, err)
364         }
365         _, err = io.Copy(ioutil.Discard, r)
366         if !reflect.DeepEqual(err, expectedErr) {
367                 t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
368         }
369         _, _, err = rc.NextReader()
370         if !reflect.DeepEqual(err, expectedErr) {
371                 t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
372         }
373 }
374
375 func TestEOFWithinFrame(t *testing.T) {
376         const bufSize = 64
377
378         for n := 0; ; n++ {
379                 var b bytes.Buffer
380                 wc := newTestConn(nil, &b, false)
381                 rc := newTestConn(&b, nil, true)
382
383                 w, _ := wc.NextWriter(BinaryMessage)
384                 w.Write(make([]byte, bufSize))
385                 w.Close()
386
387                 if n >= b.Len() {
388                         break
389                 }
390                 b.Truncate(n)
391
392                 op, r, err := rc.NextReader()
393                 if err == errUnexpectedEOF {
394                         continue
395                 }
396                 if op != BinaryMessage || err != nil {
397                         t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
398                 }
399                 _, err = io.Copy(ioutil.Discard, r)
400                 if err != errUnexpectedEOF {
401                         t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
402                 }
403                 _, _, err = rc.NextReader()
404                 if err != errUnexpectedEOF {
405                         t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
406                 }
407         }
408 }
409
410 func TestEOFBeforeFinalFrame(t *testing.T) {
411         const bufSize = 512
412
413         var b1, b2 bytes.Buffer
414         wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
415         rc := newTestConn(&b1, &b2, true)
416
417         w, _ := wc.NextWriter(BinaryMessage)
418         w.Write(make([]byte, bufSize+bufSize/2))
419
420         op, r, err := rc.NextReader()
421         if op != BinaryMessage || err != nil {
422                 t.Fatalf("NextReader() returned %d, %v", op, err)
423         }
424         _, err = io.Copy(ioutil.Discard, r)
425         if err != errUnexpectedEOF {
426                 t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
427         }
428         _, _, err = rc.NextReader()
429         if err != errUnexpectedEOF {
430                 t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
431         }
432 }
433
434 func TestWriteAfterMessageWriterClose(t *testing.T) {
435         wc := newTestConn(nil, &bytes.Buffer{}, false)
436         w, _ := wc.NextWriter(BinaryMessage)
437         io.WriteString(w, "hello")
438         if err := w.Close(); err != nil {
439                 t.Fatalf("unxpected error closing message writer, %v", err)
440         }
441
442         if _, err := io.WriteString(w, "world"); err == nil {
443                 t.Fatalf("no error writing after close")
444         }
445
446         w, _ = wc.NextWriter(BinaryMessage)
447         io.WriteString(w, "hello")
448
449         // close w by getting next writer
450         _, err := wc.NextWriter(BinaryMessage)
451         if err != nil {
452                 t.Fatalf("unexpected error getting next writer, %v", err)
453         }
454
455         if _, err := io.WriteString(w, "world"); err == nil {
456                 t.Fatalf("no error writing after close")
457         }
458 }
459
460 func TestReadLimit(t *testing.T) {
461
462         const readLimit = 512
463         message := make([]byte, readLimit+1)
464
465         var b1, b2 bytes.Buffer
466         wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
467         rc := newTestConn(&b1, &b2, true)
468         rc.SetReadLimit(readLimit)
469
470         // Send message at the limit with interleaved pong.
471         w, _ := wc.NextWriter(BinaryMessage)
472         w.Write(message[:readLimit-1])
473         wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
474         w.Write(message[:1])
475         w.Close()
476
477         // Send message larger than the limit.
478         wc.WriteMessage(BinaryMessage, message[:readLimit+1])
479
480         op, _, err := rc.NextReader()
481         if op != BinaryMessage || err != nil {
482                 t.Fatalf("1: NextReader() returned %d, %v", op, err)
483         }
484         op, r, err := rc.NextReader()
485         if op != BinaryMessage || err != nil {
486                 t.Fatalf("2: NextReader() returned %d, %v", op, err)
487         }
488         _, err = io.Copy(ioutil.Discard, r)
489         if err != ErrReadLimit {
490                 t.Fatalf("io.Copy() returned %v", err)
491         }
492 }
493
494 func TestAddrs(t *testing.T) {
495         c := newTestConn(nil, nil, true)
496         if c.LocalAddr() != localAddr {
497                 t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
498         }
499         if c.RemoteAddr() != remoteAddr {
500                 t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
501         }
502 }
503
504 func TestUnderlyingConn(t *testing.T) {
505         var b1, b2 bytes.Buffer
506         fc := fakeNetConn{Reader: &b1, Writer: &b2}
507         c := newConn(fc, true, 1024, 1024, nil, nil, nil)
508         ul := c.UnderlyingConn()
509         if ul != fc {
510                 t.Fatalf("Underlying conn is not what it should be.")
511         }
512 }
513
514 func TestBufioReadBytes(t *testing.T) {
515         // Test calling bufio.ReadBytes for value longer than read buffer size.
516
517         m := make([]byte, 512)
518         m[len(m)-1] = '\n'
519
520         var b1, b2 bytes.Buffer
521         wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
522         rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
523
524         w, _ := wc.NextWriter(BinaryMessage)
525         w.Write(m)
526         w.Close()
527
528         op, r, err := rc.NextReader()
529         if op != BinaryMessage || err != nil {
530                 t.Fatalf("NextReader() returned %d, %v", op, err)
531         }
532
533         br := bufio.NewReader(r)
534         p, err := br.ReadBytes('\n')
535         if err != nil {
536                 t.Fatalf("ReadBytes() returned %v", err)
537         }
538         if len(p) != len(m) {
539                 t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
540         }
541 }
542
543 var closeErrorTests = []struct {
544         err   error
545         codes []int
546         ok    bool
547 }{
548         {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
549         {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
550         {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
551         {errors.New("hello"), []int{CloseNormalClosure}, false},
552 }
553
554 func TestCloseError(t *testing.T) {
555         for _, tt := range closeErrorTests {
556                 ok := IsCloseError(tt.err, tt.codes...)
557                 if ok != tt.ok {
558                         t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
559                 }
560         }
561 }
562
563 var unexpectedCloseErrorTests = []struct {
564         err   error
565         codes []int
566         ok    bool
567 }{
568         {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
569         {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
570         {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
571         {errors.New("hello"), []int{CloseNormalClosure}, false},
572 }
573
574 func TestUnexpectedCloseErrors(t *testing.T) {
575         for _, tt := range unexpectedCloseErrorTests {
576                 ok := IsUnexpectedCloseError(tt.err, tt.codes...)
577                 if ok != tt.ok {
578                         t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
579                 }
580         }
581 }
582
583 type blockingWriter struct {
584         c1, c2 chan struct{}
585 }
586
587 func (w blockingWriter) Write(p []byte) (int, error) {
588         // Allow main to continue
589         close(w.c1)
590         // Wait for panic in main
591         <-w.c2
592         return len(p), nil
593 }
594
595 func TestConcurrentWritePanic(t *testing.T) {
596         w := blockingWriter{make(chan struct{}), make(chan struct{})}
597         c := newTestConn(nil, w, false)
598         go func() {
599                 c.WriteMessage(TextMessage, []byte{})
600         }()
601
602         // wait for goroutine to block in write.
603         <-w.c1
604
605         defer func() {
606                 close(w.c2)
607                 if v := recover(); v != nil {
608                         return
609                 }
610         }()
611
612         c.WriteMessage(TextMessage, []byte{})
613         t.Fatal("should not get here")
614 }
615
616 type failingReader struct{}
617
618 func (r failingReader) Read(p []byte) (int, error) {
619         return 0, io.EOF
620 }
621
622 func TestFailedConnectionReadPanic(t *testing.T) {
623         c := newTestConn(failingReader{}, nil, false)
624
625         defer func() {
626                 if v := recover(); v != nil {
627                         return
628                 }
629         }()
630
631         for i := 0; i < 20000; i++ {
632                 c.ReadMessage()
633         }
634         t.Fatal("should not get here")
635 }