1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
28 func echoServer(ws *Conn) {
38 func countServer(ws *Conn) {
42 err := JSON.Receive(ws, &count)
47 count.S = strings.Repeat(count.S, count.N)
48 err = JSON.Send(ws, count)
55 type testCtrlAndDataHandler struct {
59 func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
60 h.hybiFrameHandler.conn.wio.Lock()
61 defer h.hybiFrameHandler.conn.wio.Unlock()
62 w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
71 func ctrlAndDataServer(ws *Conn) {
73 h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
79 if i%2 != 0 { // with or without payload
80 b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
82 if _, err := h.WritePing(b); err != nil {
85 if _, err := h.WritePong(b); err != nil { // unsolicited pong
88 time.Sleep(10 * time.Millisecond)
92 b := make([]byte, 128)
98 if _, err := ws.Write(b[:n]); err != nil {
104 func subProtocolHandshake(config *Config, req *http.Request) error {
105 for _, proto := range config.Protocol {
107 config.Protocol = []string{proto}
111 return ErrBadWebSocketProtocol
114 func subProtoServer(ws *Conn) {
115 for _, proto := range ws.Config().Protocol {
116 io.WriteString(ws, proto)
121 http.Handle("/echo", Handler(echoServer))
122 http.Handle("/count", Handler(countServer))
123 http.Handle("/ctrldata", Handler(ctrlAndDataServer))
125 Handshake: subProtocolHandshake,
126 Handler: Handler(subProtoServer),
128 http.Handle("/subproto", subproto)
129 server := httptest.NewServer(nil)
130 serverAddr = server.Listener.Addr().String()
131 log.Print("Test WebSocket server listening on ", serverAddr)
134 func newConfig(t *testing.T, path string) *Config {
135 config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
139 func TestEcho(t *testing.T) {
143 client, err := net.Dial("tcp", serverAddr)
145 t.Fatal("dialing", err)
147 conn, err := NewClient(newConfig(t, "/echo"), client)
149 t.Errorf("WebSocket handshake error: %v", err)
153 msg := []byte("hello, world\n")
154 if _, err := conn.Write(msg); err != nil {
155 t.Errorf("Write: %v", err)
157 var actual_msg = make([]byte, 512)
158 n, err := conn.Read(actual_msg)
160 t.Errorf("Read: %v", err)
162 actual_msg = actual_msg[0:n]
163 if !bytes.Equal(msg, actual_msg) {
164 t.Errorf("Echo: expected %q got %q", msg, actual_msg)
169 func TestAddr(t *testing.T) {
173 client, err := net.Dial("tcp", serverAddr)
175 t.Fatal("dialing", err)
177 conn, err := NewClient(newConfig(t, "/echo"), client)
179 t.Errorf("WebSocket handshake error: %v", err)
183 ra := conn.RemoteAddr().String()
184 if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
185 t.Errorf("Bad remote addr: %v", ra)
187 la := conn.LocalAddr().String()
188 if !strings.HasPrefix(la, "http://") {
189 t.Errorf("Bad local addr: %v", la)
194 func TestCount(t *testing.T) {
198 client, err := net.Dial("tcp", serverAddr)
200 t.Fatal("dialing", err)
202 conn, err := NewClient(newConfig(t, "/count"), client)
204 t.Errorf("WebSocket handshake error: %v", err)
210 if err := JSON.Send(conn, count); err != nil {
211 t.Errorf("Write: %v", err)
213 if err := JSON.Receive(conn, &count); err != nil {
214 t.Errorf("Read: %v", err)
217 t.Errorf("count: expected %d got %d", 1, count.N)
219 if count.S != "hello" {
220 t.Errorf("count: expected %q got %q", "hello", count.S)
222 if err := JSON.Send(conn, count); err != nil {
223 t.Errorf("Write: %v", err)
225 if err := JSON.Receive(conn, &count); err != nil {
226 t.Errorf("Read: %v", err)
229 t.Errorf("count: expected %d got %d", 2, count.N)
231 if count.S != "hellohello" {
232 t.Errorf("count: expected %q got %q", "hellohello", count.S)
237 func TestWithQuery(t *testing.T) {
240 client, err := net.Dial("tcp", serverAddr)
242 t.Fatal("dialing", err)
245 config := newConfig(t, "/echo")
246 config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
248 t.Fatal("location url", err)
251 ws, err := NewClient(config, client)
253 t.Errorf("WebSocket handshake: %v", err)
259 func testWithProtocol(t *testing.T, subproto []string) (string, error) {
262 client, err := net.Dial("tcp", serverAddr)
264 t.Fatal("dialing", err)
267 config := newConfig(t, "/subproto")
268 config.Protocol = subproto
270 ws, err := NewClient(config, client)
274 msg := make([]byte, 16)
275 n, err := ws.Read(msg)
280 return string(msg[:n]), nil
283 func TestWithProtocol(t *testing.T) {
284 proto, err := testWithProtocol(t, []string{"chat"})
286 t.Errorf("SubProto: unexpected error: %v", err)
289 t.Errorf("SubProto: expected %q, got %q", "chat", proto)
293 func TestWithTwoProtocol(t *testing.T) {
294 proto, err := testWithProtocol(t, []string{"test", "chat"})
296 t.Errorf("SubProto: unexpected error: %v", err)
299 t.Errorf("SubProto: expected %q, got %q", "chat", proto)
303 func TestWithBadProtocol(t *testing.T) {
304 _, err := testWithProtocol(t, []string{"test"})
305 if err != ErrBadStatus {
306 t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
310 func TestHTTP(t *testing.T) {
313 // If the client did not send a handshake that matches the protocol
314 // specification, the server MUST return an HTTP response with an
315 // appropriate error code (such as 400 Bad Request)
316 resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
318 t.Errorf("Get: error %#v", err)
322 t.Error("Get: resp is null")
325 if resp.StatusCode != http.StatusBadRequest {
326 t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
330 func TestTrailingSpaces(t *testing.T) {
331 // http://code.google.com/p/go/issues/detail?id=955
332 // The last runs of this create keys with trailing spaces that should not be
333 // generated by the client.
335 config := newConfig(t, "/echo")
336 for i := 0; i < 30; i++ {
338 ws, err := DialConfig(config)
340 t.Errorf("Dial #%d failed: %v", i, err)
347 func TestDialConfigBadVersion(t *testing.T) {
349 config := newConfig(t, "/echo")
350 config.Version = 1234
352 _, err := DialConfig(config)
354 if dialerr, ok := err.(*DialError); ok {
355 if dialerr.Err != ErrBadProtocolVersion {
356 t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
361 func TestDialConfigWithDialer(t *testing.T) {
363 config := newConfig(t, "/echo")
364 config.Dialer = &net.Dialer{
365 Deadline: time.Now().Add(-time.Minute),
367 _, err := DialConfig(config)
368 dialerr, ok := err.(*DialError)
370 t.Fatalf("DialError expected, got %#v", err)
372 neterr, ok := dialerr.Err.(*net.OpError)
374 t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
376 if !neterr.Timeout() {
377 t.Fatalf("expected timeout error, got %#v", neterr)
381 func TestSmallBuffer(t *testing.T) {
382 // http://code.google.com/p/go/issues/detail?id=1145
383 // Read should be able to handle reading a fragment of a frame.
387 client, err := net.Dial("tcp", serverAddr)
389 t.Fatal("dialing", err)
391 conn, err := NewClient(newConfig(t, "/echo"), client)
393 t.Errorf("WebSocket handshake error: %v", err)
397 msg := []byte("hello, world\n")
398 if _, err := conn.Write(msg); err != nil {
399 t.Errorf("Write: %v", err)
401 var small_msg = make([]byte, 8)
402 n, err := conn.Read(small_msg)
404 t.Errorf("Read: %v", err)
406 if !bytes.Equal(msg[:len(small_msg)], small_msg) {
407 t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
409 var second_msg = make([]byte, len(msg))
410 n, err = conn.Read(second_msg)
412 t.Errorf("Read: %v", err)
414 second_msg = second_msg[0:n]
415 if !bytes.Equal(msg[len(small_msg):], second_msg) {
416 t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
421 var parseAuthorityTests = []struct {
428 Host: "www.google.com",
435 Host: "www.google.com",
437 "www.google.com:443",
442 Host: "www.google.com:80",
449 Host: "www.google.com:443",
451 "www.google.com:443",
453 // some invalid ones for parseAuthority. parseAuthority doesn't
454 // concern itself with the scheme unless it actually knows about it
458 Host: "www.google.com",
465 Host: "www.google.com:80",
479 Host: "www.google.com",
485 func TestParseAuthority(t *testing.T) {
486 for _, tt := range parseAuthorityTests {
487 out := parseAuthority(tt.in)
489 t.Errorf("got %v; want %v", out, tt.out)
494 type closerConn struct {
496 closed int // count of the number of times Close was called
499 func (c *closerConn) Close() error {
501 return c.Conn.Close()
504 func TestClose(t *testing.T) {
505 if runtime.GOOS == "plan9" {
506 t.Skip("see golang.org/issue/11454")
511 conn, err := net.Dial("tcp", serverAddr)
513 t.Fatal("dialing", err)
516 cc := closerConn{Conn: conn}
518 client, err := NewClient(newConfig(t, "/echo"), &cc)
520 t.Fatalf("WebSocket handshake: %v", err)
523 // set the deadline to ten minutes ago, which will have expired by the time
524 // client.Close sends the close status frame.
525 conn.SetDeadline(time.Now().Add(-10 * time.Minute))
527 if err := client.Close(); err == nil {
528 t.Errorf("ws.Close(): expected error, got %v", err)
531 t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
535 var originTests = []struct {
542 "Origin": []string{"http://www.example.com"},
547 Host: "www.example.com",
551 req: &http.Request{},
555 func TestOrigin(t *testing.T) {
556 conf := newConfig(t, "/echo")
557 conf.Version = ProtocolVersionHybi13
558 for i, tt := range originTests {
559 origin, err := Origin(conf, tt.req)
564 if !reflect.DeepEqual(origin, tt.origin) {
565 t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
571 func TestCtrlAndData(t *testing.T) {
574 c, err := net.Dial("tcp", serverAddr)
578 ws, err := NewClient(newConfig(t, "/ctrldata"), c)
584 h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
587 b := make([]byte, 128)
588 for i := 0; i < 2; i++ {
589 data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
590 if _, err := ws.Write(data); err != nil {
591 t.Fatalf("#%d: %v", i, err)
594 if i%2 != 0 { // with or without payload
595 ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
597 if _, err := h.WritePing(ctrl); err != nil {
598 t.Fatalf("#%d: %v", i, err)
602 t.Fatalf("#%d: %v", i, err)
604 if !bytes.Equal(b[:n], data) {
605 t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
610 func TestCodec_ReceiveLimited(t *testing.T) {
612 var payloads [][]byte
613 for _, size := range []int{
616 4096, // receive of this message would be interrupted due to limit
617 2048, // this one is to make sure next receive recovers discarding leftovers
619 b := make([]byte, size)
621 payloads = append(payloads, b)
623 handlerDone := make(chan struct{})
624 limitedHandler := func(ws *Conn) {
625 defer close(handlerDone)
626 ws.MaxPayloadBytes = limit
628 for i, p := range payloads {
629 t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
631 err := Message.Receive(ws, &recv)
634 case ErrFrameTooLarge:
636 t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
640 t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
642 if len(recv) > limit {
643 t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
645 if !bytes.Equal(p, recv) {
646 t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
650 server := httptest.NewServer(Handler(limitedHandler))
651 defer server.CloseClientConnections()
653 addr := server.Listener.Addr().String()
654 ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
659 for i, p := range payloads {
660 if err := Message.Send(ws, p); err != nil {
661 t.Fatalf("payload #%d (size %d): %v", i, len(p), err)