1 // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
30 var cstUpgrader = Upgrader{
31 Subprotocols: []string{"p0", "p1"},
33 WriteBufferSize: 1024,
34 EnableCompression: true,
35 Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
36 http.Error(w, reason.Error(), status)
40 var cstDialer = Dialer{
41 Subprotocols: []string{"p1", "p2"},
43 WriteBufferSize: 1024,
44 HandshakeTimeout: 30 * time.Second,
// cstHandler is the http.Handler that newServer and newTLSServer mount on the
// test servers. It embeds *testing.T so ServeHTTP can report problems via
// t.Logf.
47 type cstHandler struct{ *testing.T }
49 type cstServer struct {
58 cstRequestURI = cstPath + "?" + cstRawQuery
61 func newServer(t *testing.T) *cstServer {
63 s.Server = httptest.NewServer(cstHandler{t})
64 s.Server.URL += cstRequestURI
65 s.URL = makeWsProto(s.Server.URL)
69 func newTLSServer(t *testing.T) *cstServer {
71 s.Server = httptest.NewTLSServer(cstHandler{t})
72 s.Server.URL += cstRequestURI
73 s.URL = makeWsProto(s.Server.URL)
77 func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
78 if r.URL.Path != cstPath {
79 t.Logf("path=%v, want %v", r.URL.Path, cstPath)
80 http.Error(w, "bad path", http.StatusBadRequest)
83 if r.URL.RawQuery != cstRawQuery {
84 t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
85 http.Error(w, "bad path", http.StatusBadRequest)
88 subprotos := Subprotocols(r)
89 if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
90 t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
91 http.Error(w, "bad protocol", http.StatusBadRequest)
94 ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
96 t.Logf("Upgrade: %v", err)
101 if ws.Subprotocol() != "p1" {
102 t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
106 op, rd, err := ws.NextReader()
108 t.Logf("NextReader: %v", err)
111 wr, err := ws.NextWriter(op)
113 t.Logf("NextWriter: %v", err)
116 if _, err = io.Copy(wr, rd); err != nil {
117 t.Logf("NextWriter: %v", err)
120 if err := wr.Close(); err != nil {
121 t.Logf("Close: %v", err)
126 func makeWsProto(s string) string {
127 return "ws" + strings.TrimPrefix(s, "http")
130 func sendRecv(t *testing.T, ws *Conn) {
131 const message = "Hello World!"
132 if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
133 t.Fatalf("SetWriteDeadline: %v", err)
135 if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
136 t.Fatalf("WriteMessage: %v", err)
138 if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
139 t.Fatalf("SetReadDeadline: %v", err)
141 _, p, err := ws.ReadMessage()
143 t.Fatalf("ReadMessage: %v", err)
145 if string(p) != message {
146 t.Fatalf("message=%s, want %s", p, message)
150 func TestProxyDial(t *testing.T) {
155 surl, _ := url.Parse(s.Server.URL)
157 cstDialer := cstDialer // make local copy for modification on next line.
158 cstDialer.Proxy = http.ProxyURL(surl)
161 origHandler := s.Server.Config.Handler
163 // Capture the request Host header.
164 s.Server.Config.Handler = http.HandlerFunc(
165 func(w http.ResponseWriter, r *http.Request) {
166 if r.Method == "CONNECT" {
168 w.WriteHeader(http.StatusOK)
173 t.Log("connect not received")
174 http.Error(w, "connect not received", http.StatusMethodNotAllowed)
177 origHandler.ServeHTTP(w, r)
180 ws, _, err := cstDialer.Dial(s.URL, nil)
182 t.Fatalf("Dial: %v", err)
188 func TestProxyAuthorizationDial(t *testing.T) {
192 surl, _ := url.Parse(s.Server.URL)
193 surl.User = url.UserPassword("username", "password")
195 cstDialer := cstDialer // make local copy for modification on next line.
196 cstDialer.Proxy = http.ProxyURL(surl)
199 origHandler := s.Server.Config.Handler
201 // Capture the request Host header.
202 s.Server.Config.Handler = http.HandlerFunc(
203 func(w http.ResponseWriter, r *http.Request) {
204 proxyAuth := r.Header.Get("Proxy-Authorization")
205 expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
206 if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
208 w.WriteHeader(http.StatusOK)
213 t.Log("connect with proxy authorization not received")
214 http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed)
217 origHandler.ServeHTTP(w, r)
220 ws, _, err := cstDialer.Dial(s.URL, nil)
222 t.Fatalf("Dial: %v", err)
228 func TestDial(t *testing.T) {
232 ws, _, err := cstDialer.Dial(s.URL, nil)
234 t.Fatalf("Dial: %v", err)
240 func TestDialCookieJar(t *testing.T) {
244 jar, _ := cookiejar.New(nil)
248 u, _ := url.Parse(s.URL)
257 cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}}
258 d.Jar.SetCookies(u, cookies)
260 ws, _, err := d.Dial(s.URL, nil)
262 t.Fatalf("Dial: %v", err)
268 for _, c := range d.Jar.Cookies(u) {
269 if c.Name == "gorilla" {
273 if c.Name == "sessionID" {
278 t.Error("Cookie not present in jar.")
281 if sessionID != "1234" {
282 t.Error("Set-Cookie not received from the server.")
288 func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
289 certs := x509.NewCertPool()
290 for _, c := range s.TLS.Certificates {
291 roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
293 t.Fatalf("error parsing server's root cert: %v", err)
295 for _, root := range roots {
302 func TestDialTLS(t *testing.T) {
307 d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
308 ws, _, err := d.Dial(s.URL, nil)
310 t.Fatalf("Dial: %v", err)
316 func TestDialTimeout(t *testing.T) {
321 d.HandshakeTimeout = -1
322 ws, _, err := d.Dial(s.URL, nil)
325 t.Fatalf("Dial: nil")
329 // requireDeadlineNetConn fails the current test when Read or Write are called
331 type requireDeadlineNetConn struct {
334 readDeadlineIsSet bool
335 writeDeadlineIsSet bool
338 func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error {
339 c.writeDeadlineIsSet = !t.Equal(time.Time{})
340 c.readDeadlineIsSet = c.writeDeadlineIsSet
341 return c.c.SetDeadline(t)
344 func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error {
345 c.readDeadlineIsSet = !t.Equal(time.Time{})
346 return c.c.SetDeadline(t)
349 func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error {
350 c.writeDeadlineIsSet = !t.Equal(time.Time{})
351 return c.c.SetDeadline(t)
354 func (c *requireDeadlineNetConn) Write(p []byte) (int, error) {
355 if !c.writeDeadlineIsSet {
356 c.t.Fatalf("write with no deadline")
361 func (c *requireDeadlineNetConn) Read(p []byte) (int, error) {
362 if !c.readDeadlineIsSet {
363 c.t.Fatalf("read with no deadline")
// Close, LocalAddr and RemoteAddr forward directly to the wrapped connection
// c.c. Unlike Read and Write, they do not check whether a deadline was set.
368 func (c *requireDeadlineNetConn) Close() error { return c.c.Close() }
369 func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() }
370 func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
372 func TestHandshakeTimeout(t *testing.T) {
377 d.NetDial = func(n, a string) (net.Conn, error) {
378 c, err := net.Dial(n, a)
379 return &requireDeadlineNetConn{c: c, t: t}, err
381 ws, _, err := d.Dial(s.URL, nil)
383 t.Fatal("Dial:", err)
388 func TestHandshakeTimeoutInContext(t *testing.T) {
393 d.HandshakeTimeout = 0
394 d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
395 netDialer := &net.Dialer{}
396 c, err := netDialer.DialContext(ctx, n, a)
397 return &requireDeadlineNetConn{c: c, t: t}, err
400 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
402 ws, _, err := d.DialContext(ctx, s.URL, nil)
404 t.Fatal("Dial:", err)
409 func TestDialBadScheme(t *testing.T) {
413 ws, _, err := cstDialer.Dial(s.Server.URL, nil)
416 t.Fatalf("Dial: nil")
420 func TestDialBadOrigin(t *testing.T) {
424 ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
427 t.Fatalf("Dial: nil")
430 t.Fatalf("resp=nil, err=%v", err)
432 if resp.StatusCode != http.StatusForbidden {
433 t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
437 func TestDialBadHeader(t *testing.T) {
441 for _, k := range []string{"Upgrade",
444 "Sec-Websocket-Version",
445 "Sec-Websocket-Protocol"} {
448 ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
451 t.Errorf("Dial with header %s returned nil", k)
456 func TestBadMethod(t *testing.T) {
457 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
458 ws, err := cstUpgrader.Upgrade(w, r, nil)
460 t.Errorf("handshake succeeded, expect fail")
466 req, err := http.NewRequest("POST", s.URL, strings.NewReader(""))
468 t.Fatalf("NewRequest returned error %v", err)
470 req.Header.Set("Connection", "upgrade")
471 req.Header.Set("Upgrade", "websocket")
472 req.Header.Set("Sec-Websocket-Version", "13")
474 resp, err := http.DefaultClient.Do(req)
476 t.Fatalf("Do returned error %v", err)
479 if resp.StatusCode != http.StatusMethodNotAllowed {
480 t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
484 func TestHandshake(t *testing.T) {
488 ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
490 t.Fatalf("Dial: %v", err)
495 for _, c := range resp.Cookies() {
496 if c.Name == "sessionID" {
500 if sessionID != "1234" {
501 t.Error("Set-Cookie not received from the server.")
504 if ws.Subprotocol() != "p1" {
505 t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
510 func TestRespOnBadHandshake(t *testing.T) {
511 const expectedStatus = http.StatusGone
512 const expectedBody = "This is the response body."
514 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
515 w.WriteHeader(expectedStatus)
516 io.WriteString(w, expectedBody)
520 ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
523 t.Fatalf("Dial: nil")
527 t.Fatalf("resp=nil, err=%v", err)
530 if resp.StatusCode != expectedStatus {
531 t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
534 p, err := ioutil.ReadAll(resp.Body)
536 t.Fatalf("ReadFull(resp.Body) returned error %v", err)
539 if string(p) != expectedBody {
540 t.Errorf("resp.Body=%s, want %s", p, expectedBody)
544 type testLogWriter struct {
548 func (w testLogWriter) Write(p []byte) (int, error) {
553 // TestHost tests handling of host names and confirms that it matches net/http.
554 func TestHost(t *testing.T) {
556 upgrader := Upgrader{}
557 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
558 if IsWebSocketUpgrade(r) {
559 c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
565 w.Header().Set("X-Test-Host", r.Host)
569 server := httptest.NewServer(handler)
572 tlsServer := httptest.NewTLSServer(handler)
573 defer tlsServer.Close()
575 addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
576 wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
577 httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
579 // Avoid log noise from net/http server by logging to testing.T
580 server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
581 tlsServer.Config.ErrorLog = server.Config.ErrorLog
583 cas := rootCAs(t, tlsServer)
586 fail bool // true if dial / get should fail
587 server *httptest.Server // server to use
588 url string // host for request URI
589 header string // optional request host header
590 tls string // optiona host for tls ServerName
591 wantAddr string // expected host for dial
592 wantHeader string // expected request header on server
593 insecureSkipVerify bool
598 wantAddr: addrs[server],
599 wantHeader: addrs[server],
603 url: addrs[tlsServer],
604 wantAddr: addrs[tlsServer],
605 wantHeader: addrs[tlsServer],
611 header: "badhost.com",
612 wantAddr: addrs[server],
613 wantHeader: "badhost.com",
617 url: addrs[tlsServer],
618 header: "badhost.com",
619 wantAddr: addrs[tlsServer],
620 wantHeader: "badhost.com",
626 header: "badhost.com",
627 wantAddr: "example.com:80",
628 wantHeader: "badhost.com",
633 header: "badhost.com",
634 wantAddr: "example.com:443",
635 wantHeader: "badhost.com",
641 header: "example.com",
642 wantAddr: "badhost.com:80",
643 wantHeader: "example.com",
649 header: "example.com",
650 wantAddr: "badhost.com:443",
655 insecureSkipVerify: true,
656 wantAddr: "badhost.com:443",
657 wantHeader: "badhost.com",
663 wantAddr: "badhost.com:443",
664 wantHeader: "badhost.com",
668 for i, tt := range tests {
673 InsecureSkipVerify: tt.insecureSkipVerify,
678 NetDial: func(network, addr string) (net.Conn, error) {
680 return net.Dial(network, addrs[tt.server])
682 TLSClientConfig: tls,
685 // Test websocket dial
689 h.Set("Host", tt.header)
691 c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
696 check := func(protos map[*httptest.Server]string) {
697 name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
698 if gotAddr != tt.wantAddr {
699 t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
702 case tt.fail && err == nil:
703 t.Errorf("%s: unexpected success", name)
704 case !tt.fail && err != nil:
705 t.Errorf("%s: unexpected error %v", name, err)
706 case !tt.fail && err == nil:
707 if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
708 t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
715 // Confirm that net/http has same result
717 transport := &http.Transport{
718 Dial: dialer.NetDial,
719 TLSClientConfig: dialer.TLSClientConfig,
721 req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
725 client := &http.Client{Transport: transport}
726 resp, err = client.Do(req)
730 transport.CloseIdleConnections()
735 func TestDialCompression(t *testing.T) {
740 dialer.EnableCompression = true
741 ws, _, err := dialer.Dial(s.URL, nil)
743 t.Fatalf("Dial: %v", err)
749 func TestSocksProxyDial(t *testing.T) {
753 proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
755 t.Fatalf("listen failed: %v", err)
757 defer proxyListener.Close()
759 c1, err := proxyListener.Accept()
761 t.Errorf("proxy accept failed: %v", err)
766 c1.SetDeadline(time.Now().Add(30 * time.Second))
768 buf := make([]byte, 32)
769 if _, err := io.ReadFull(c1, buf[:3]); err != nil {
770 t.Errorf("read failed: %v", err)
773 if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) {
774 t.Errorf("read %x, want %x", buf[:len(want)], want)
776 if _, err := c1.Write([]byte{5, 0}); err != nil {
777 t.Errorf("write failed: %v", err)
780 if _, err := io.ReadFull(c1, buf[:10]); err != nil {
781 t.Errorf("read failed: %v", err)
784 if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) {
785 t.Errorf("read %x, want %x", buf[:len(want)], want)
789 if _, err := c1.Write(buf[:10]); err != nil {
790 t.Errorf("write failed: %v", err)
794 ip := net.IP(buf[4:8])
795 port := binary.BigEndian.Uint16(buf[8:10])
797 c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)})
799 t.Errorf("dial failed; %v", err)
803 done := make(chan struct{})
812 purl, err := url.Parse("socks5://" + proxyListener.Addr().String())
814 t.Fatalf("parse failed: %v", err)
817 cstDialer := cstDialer // make local copy for modification on next line.
818 cstDialer.Proxy = http.ProxyURL(purl)
820 ws, _, err := cstDialer.Dial(s.URL, nil)
822 t.Fatalf("Dial: %v", err)
828 func TestTracingDialWithContext(t *testing.T) {
830 var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
831 trace := &httptrace.ClientTrace{
832 WroteHeaders: func() {
835 WroteRequest: func(httptrace.WroteRequestInfo) {
838 GetConn: func(hostPort string) {
841 GotConn: func(info httptrace.GotConnInfo) {
844 ConnectDone: func(network, addr string, err error) {
847 GotFirstResponseByte: func() {
848 gotFirstResponseByte = true
851 ctx := httptrace.WithClientTrace(context.Background(), trace)
857 d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
859 ws, _, err := d.DialContext(ctx, s.URL, nil)
861 t.Fatalf("Dial: %v", err)
865 t.Fatal("Headers was not written")
868 t.Fatal("Request was not written")
871 t.Fatal("getConn was not called")
874 t.Fatal("gotConn was not called")
877 t.Fatal("connectDone was not called")
879 if !gotFirstResponseByte {
880 t.Fatal("GotFirstResponseByte was not called")
887 func TestEmptyTracingDialWithContext(t *testing.T) {
889 trace := &httptrace.ClientTrace{}
890 ctx := httptrace.WithClientTrace(context.Background(), trace)
896 d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
898 ws, _, err := d.DialContext(ctx, s.URL, nil)
900 t.Fatalf("Dial: %v", err)