1 // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
22 var _ net.Error = errWriteTimeout
24 type fakeNetConn struct {
29 func (c fakeNetConn) Close() error { return nil }
30 func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
31 func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
32 func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
33 func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
34 func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
39 localAddr = fakeAddr(1)
40 remoteAddr = fakeAddr(2)
43 func (a fakeAddr) Network() string {
47 func (a fakeAddr) String() string {
51 // newTestConn creates a connection backed by a fake network connection using
52 // default values for buffering.
53 func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
54 return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
57 func TestFraming(t *testing.T) {
58 frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
59 var readChunkers = []struct {
61 f func(io.Reader) io.Reader
63 {"half", iotest.HalfReader},
64 {"one", iotest.OneByteReader},
65 {"asis", func(r io.Reader) io.Reader { return r }},
67 writeBuf := make([]byte, 65537)
68 for i := range writeBuf {
71 var writers = []struct {
73 f func(w io.Writer, n int) (int, error)
75 {"iocopy", func(w io.Writer, n int) (int, error) {
76 nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
79 {"write", func(w io.Writer, n int) (int, error) {
80 return w.Write(writeBuf[:n])
82 {"string", func(w io.Writer, n int) (int, error) {
83 return io.WriteString(w, string(writeBuf[:n]))
87 for _, compress := range []bool{false, true} {
88 for _, isServer := range []bool{true, false} {
89 for _, chunker := range readChunkers {
91 var connBuf bytes.Buffer
92 wc := newTestConn(nil, &connBuf, isServer)
93 rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
95 wc.newCompressionWriter = compressNoContextTakeover
96 rc.newDecompressionReader = decompressNoContextTakeover
98 for _, n := range frameSizes {
99 for _, writer := range writers {
100 name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
102 w, err := wc.NextWriter(TextMessage)
104 t.Errorf("%s: wc.NextWriter() returned %v", name, err)
107 nn, err := writer.f(w, n)
108 if err != nil || nn != n {
109 t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
114 t.Errorf("%s: w.Close() returned %v", name, err)
118 opCode, r, err := rc.NextReader()
119 if err != nil || opCode != TextMessage {
120 t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
123 rbuf, err := ioutil.ReadAll(r)
125 t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
130 t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
134 for i, b := range rbuf {
136 t.Errorf("%s: bad byte at offset %d", name, i)
147 func TestControl(t *testing.T) {
148 const message = "this is a ping/pong messsage"
149 for _, isServer := range []bool{true, false} {
150 for _, isWriteControl := range []bool{true, false} {
151 name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
152 var connBuf bytes.Buffer
153 wc := newTestConn(nil, &connBuf, isServer)
154 rc := newTestConn(&connBuf, nil, !isServer)
156 wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
158 w, err := wc.NextWriter(PongMessage)
160 t.Errorf("%s: wc.NextWriter() returned %v", name, err)
163 if _, err := w.Write([]byte(message)); err != nil {
164 t.Errorf("%s: w.Write() returned %v", name, err)
167 if err := w.Close(); err != nil {
168 t.Errorf("%s: w.Close() returned %v", name, err)
171 var actualMessage string
172 rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
174 if actualMessage != message {
175 t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
183 // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
184 type simpleBufferPool struct {
188 func (p *simpleBufferPool) Get() interface{} {
194 func (p *simpleBufferPool) Put(v interface{}) {
198 func TestWriteBufferPool(t *testing.T) {
199 const message = "Now is the time for all good people to come to the aid of the party."
202 var pool simpleBufferPool
203 rc := newTestConn(&buf, nil, false)
205 // Specify writeBufferSize smaller than message size to ensure that pooling
206 // works with fragmented messages.
207 wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
209 if wc.writeBuf != nil {
210 t.Fatal("writeBuf not nil after create")
213 // Part 1: test NextWriter/Write/Close
215 w, err := wc.NextWriter(TextMessage)
217 t.Fatalf("wc.NextWriter() returned %v", err)
220 if wc.writeBuf == nil {
221 t.Fatal("writeBuf is nil after NextWriter")
224 writeBufAddr := &wc.writeBuf[0]
226 if _, err := io.WriteString(w, message); err != nil {
227 t.Fatalf("io.WriteString(w, message) returned %v", err)
230 if err := w.Close(); err != nil {
231 t.Fatalf("w.Close() returned %v", err)
234 if wc.writeBuf != nil {
235 t.Fatal("writeBuf not nil after w.Close()")
238 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
239 t.Fatal("writeBuf not returned to pool")
242 opCode, p, err := rc.ReadMessage()
243 if opCode != TextMessage || err != nil {
244 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
247 if s := string(p); s != message {
248 t.Fatalf("message is %s, want %s", s, message)
251 // Part 2: Test WriteMessage.
253 if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
254 t.Fatalf("wc.WriteMessage() returned %v", err)
257 if wc.writeBuf != nil {
258 t.Fatal("writeBuf not nil after wc.WriteMessage()")
261 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
262 t.Fatal("writeBuf not returned to pool after WriteMessage")
265 opCode, p, err = rc.ReadMessage()
266 if opCode != TextMessage || err != nil {
267 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
270 if s := string(p); s != message {
271 t.Fatalf("message is %s, want %s", s, message)
275 // TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
276 func TestWriteBufferPoolSync(t *testing.T) {
279 wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
280 rc := newTestConn(&buf, nil, false)
282 const message = "Hello World!"
283 for i := 0; i < 3; i++ {
284 if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
285 t.Fatalf("wc.WriteMessage() returned %v", err)
287 opCode, p, err := rc.ReadMessage()
288 if opCode != TextMessage || err != nil {
289 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
291 if s := string(p); s != message {
292 t.Fatalf("message is %s, want %s", s, message)
// errorWriter is an io.Writer that returns an error on all writes.
type errorWriter struct{}

// Write ignores its input entirely: it reports that zero bytes were written
// and returns a newly created error every time it is called.
func (errorWriter) Write([]byte) (int, error) {
	err := errors.New("Error!")
	return 0, err
}
302 // TestWriteBufferPoolError ensures that the write buffer is returned to the pool after an error
304 func TestWriteBufferPoolError(t *testing.T) {
306 // Part 1: Test NextWriter/Write/Close
308 var pool simpleBufferPool
309 wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
311 w, err := wc.NextWriter(TextMessage)
313 t.Fatalf("wc.NextWriter() returned %v", err)
316 if wc.writeBuf == nil {
317 t.Fatal("writeBuf is nil after NextWriter")
320 writeBufAddr := &wc.writeBuf[0]
322 if _, err := io.WriteString(w, "Hello"); err != nil {
323 t.Fatalf("io.WriteString(w, message) returned %v", err)
326 if err := w.Close(); err == nil {
327 t.Fatalf("w.Close() did not return error")
330 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
331 t.Fatal("writeBuf not returned to pool")
334 // Part 2: Test WriteMessage
336 wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
338 if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
339 t.Fatalf("wc.WriteMessage did not return error")
342 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
343 t.Fatal("writeBuf not returned to pool")
347 func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
350 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
352 var b1, b2 bytes.Buffer
353 wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
354 rc := newTestConn(&b1, &b2, true)
356 w, _ := wc.NextWriter(BinaryMessage)
357 w.Write(make([]byte, bufSize+bufSize/2))
358 wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
361 op, r, err := rc.NextReader()
362 if op != BinaryMessage || err != nil {
363 t.Fatalf("NextReader() returned %d, %v", op, err)
365 _, err = io.Copy(ioutil.Discard, r)
366 if !reflect.DeepEqual(err, expectedErr) {
367 t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
369 _, _, err = rc.NextReader()
370 if !reflect.DeepEqual(err, expectedErr) {
371 t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
375 func TestEOFWithinFrame(t *testing.T) {
380 wc := newTestConn(nil, &b, false)
381 rc := newTestConn(&b, nil, true)
383 w, _ := wc.NextWriter(BinaryMessage)
384 w.Write(make([]byte, bufSize))
392 op, r, err := rc.NextReader()
393 if err == errUnexpectedEOF {
396 if op != BinaryMessage || err != nil {
397 t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
399 _, err = io.Copy(ioutil.Discard, r)
400 if err != errUnexpectedEOF {
401 t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
403 _, _, err = rc.NextReader()
404 if err != errUnexpectedEOF {
405 t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
410 func TestEOFBeforeFinalFrame(t *testing.T) {
413 var b1, b2 bytes.Buffer
414 wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
415 rc := newTestConn(&b1, &b2, true)
417 w, _ := wc.NextWriter(BinaryMessage)
418 w.Write(make([]byte, bufSize+bufSize/2))
420 op, r, err := rc.NextReader()
421 if op != BinaryMessage || err != nil {
422 t.Fatalf("NextReader() returned %d, %v", op, err)
424 _, err = io.Copy(ioutil.Discard, r)
425 if err != errUnexpectedEOF {
426 t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
428 _, _, err = rc.NextReader()
429 if err != errUnexpectedEOF {
430 t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
434 func TestWriteAfterMessageWriterClose(t *testing.T) {
435 wc := newTestConn(nil, &bytes.Buffer{}, false)
436 w, _ := wc.NextWriter(BinaryMessage)
437 io.WriteString(w, "hello")
438 if err := w.Close(); err != nil {
439 t.Fatalf("unxpected error closing message writer, %v", err)
442 if _, err := io.WriteString(w, "world"); err == nil {
443 t.Fatalf("no error writing after close")
446 w, _ = wc.NextWriter(BinaryMessage)
447 io.WriteString(w, "hello")
449 // close w by getting next writer
450 _, err := wc.NextWriter(BinaryMessage)
452 t.Fatalf("unexpected error getting next writer, %v", err)
455 if _, err := io.WriteString(w, "world"); err == nil {
456 t.Fatalf("no error writing after close")
460 func TestReadLimit(t *testing.T) {
462 const readLimit = 512
463 message := make([]byte, readLimit+1)
465 var b1, b2 bytes.Buffer
466 wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
467 rc := newTestConn(&b1, &b2, true)
468 rc.SetReadLimit(readLimit)
470 // Send message at the limit with interleaved pong.
471 w, _ := wc.NextWriter(BinaryMessage)
472 w.Write(message[:readLimit-1])
473 wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
477 // Send message larger than the limit.
478 wc.WriteMessage(BinaryMessage, message[:readLimit+1])
480 op, _, err := rc.NextReader()
481 if op != BinaryMessage || err != nil {
482 t.Fatalf("1: NextReader() returned %d, %v", op, err)
484 op, r, err := rc.NextReader()
485 if op != BinaryMessage || err != nil {
486 t.Fatalf("2: NextReader() returned %d, %v", op, err)
488 _, err = io.Copy(ioutil.Discard, r)
489 if err != ErrReadLimit {
490 t.Fatalf("io.Copy() returned %v", err)
494 func TestAddrs(t *testing.T) {
495 c := newTestConn(nil, nil, true)
496 if c.LocalAddr() != localAddr {
497 t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
499 if c.RemoteAddr() != remoteAddr {
500 t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
504 func TestUnderlyingConn(t *testing.T) {
505 var b1, b2 bytes.Buffer
506 fc := fakeNetConn{Reader: &b1, Writer: &b2}
507 c := newConn(fc, true, 1024, 1024, nil, nil, nil)
508 ul := c.UnderlyingConn()
510 t.Fatalf("Underlying conn is not what it should be.")
514 func TestBufioReadBytes(t *testing.T) {
515 // Test calling bufio.ReadBytes for value longer than read buffer size.
517 m := make([]byte, 512)
520 var b1, b2 bytes.Buffer
521 wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
522 rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
524 w, _ := wc.NextWriter(BinaryMessage)
528 op, r, err := rc.NextReader()
529 if op != BinaryMessage || err != nil {
530 t.Fatalf("NextReader() returned %d, %v", op, err)
533 br := bufio.NewReader(r)
534 p, err := br.ReadBytes('\n')
536 t.Fatalf("ReadBytes() returned %v", err)
538 if len(p) != len(m) {
539 t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
543 var closeErrorTests = []struct {
548 {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
549 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
550 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
551 {errors.New("hello"), []int{CloseNormalClosure}, false},
554 func TestCloseError(t *testing.T) {
555 for _, tt := range closeErrorTests {
556 ok := IsCloseError(tt.err, tt.codes...)
558 t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
563 var unexpectedCloseErrorTests = []struct {
568 {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
569 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
570 {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
571 {errors.New("hello"), []int{CloseNormalClosure}, false},
574 func TestUnexpectedCloseErrors(t *testing.T) {
575 for _, tt := range unexpectedCloseErrorTests {
576 ok := IsUnexpectedCloseError(tt.err, tt.codes...)
578 t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
583 type blockingWriter struct {
587 func (w blockingWriter) Write(p []byte) (int, error) {
588 // Allow main to continue
590 // Wait for panic in main
595 func TestConcurrentWritePanic(t *testing.T) {
596 w := blockingWriter{make(chan struct{}), make(chan struct{})}
597 c := newTestConn(nil, w, false)
599 c.WriteMessage(TextMessage, []byte{})
602 // wait for goroutine to block in write.
607 if v := recover(); v != nil {
612 c.WriteMessage(TextMessage, []byte{})
613 t.Fatal("should not get here")
// failingReader is a zero-size reader handed to newTestConn as its io.Reader
// by TestFailedConnectionReadPanic; as the name says, its Read method is
// meant to fail.
type failingReader struct{}
618 func (r failingReader) Read(p []byte) (int, error) {
622 func TestFailedConnectionReadPanic(t *testing.T) {
623 c := newTestConn(failingReader{}, nil, false)
626 if v := recover(); v != nil {
631 for i := 0; i < 20000; i++ {
634 t.Fatal("should not get here")