OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / github.com / gorilla / websocket / client_server_test.go
1 // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package websocket
6
7 import (
8         "bytes"
9         "context"
10         "crypto/tls"
11         "crypto/x509"
12         "encoding/base64"
13         "encoding/binary"
14         "fmt"
15         "io"
16         "io/ioutil"
17         "log"
18         "net"
19         "net/http"
20         "net/http/cookiejar"
21         "net/http/httptest"
22         "net/http/httptrace"
23         "net/url"
24         "reflect"
25         "strings"
26         "testing"
27         "time"
28 )
29
30 var cstUpgrader = Upgrader{
31         Subprotocols:      []string{"p0", "p1"},
32         ReadBufferSize:    1024,
33         WriteBufferSize:   1024,
34         EnableCompression: true,
35         Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
36                 http.Error(w, reason.Error(), status)
37         },
38 }
39
40 var cstDialer = Dialer{
41         Subprotocols:     []string{"p1", "p2"},
42         ReadBufferSize:   1024,
43         WriteBufferSize:  1024,
44         HandshakeTimeout: 30 * time.Second,
45 }
46
47 type cstHandler struct{ *testing.T }
48
49 type cstServer struct {
50         *httptest.Server
51         URL string
52         t   *testing.T
53 }
54
55 const (
56         cstPath       = "/a/b"
57         cstRawQuery   = "x=y"
58         cstRequestURI = cstPath + "?" + cstRawQuery
59 )
60
61 func newServer(t *testing.T) *cstServer {
62         var s cstServer
63         s.Server = httptest.NewServer(cstHandler{t})
64         s.Server.URL += cstRequestURI
65         s.URL = makeWsProto(s.Server.URL)
66         return &s
67 }
68
69 func newTLSServer(t *testing.T) *cstServer {
70         var s cstServer
71         s.Server = httptest.NewTLSServer(cstHandler{t})
72         s.Server.URL += cstRequestURI
73         s.URL = makeWsProto(s.Server.URL)
74         return &s
75 }
76
77 func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
78         if r.URL.Path != cstPath {
79                 t.Logf("path=%v, want %v", r.URL.Path, cstPath)
80                 http.Error(w, "bad path", http.StatusBadRequest)
81                 return
82         }
83         if r.URL.RawQuery != cstRawQuery {
84                 t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
85                 http.Error(w, "bad path", http.StatusBadRequest)
86                 return
87         }
88         subprotos := Subprotocols(r)
89         if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
90                 t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
91                 http.Error(w, "bad protocol", http.StatusBadRequest)
92                 return
93         }
94         ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
95         if err != nil {
96                 t.Logf("Upgrade: %v", err)
97                 return
98         }
99         defer ws.Close()
100
101         if ws.Subprotocol() != "p1" {
102                 t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
103                 ws.Close()
104                 return
105         }
106         op, rd, err := ws.NextReader()
107         if err != nil {
108                 t.Logf("NextReader: %v", err)
109                 return
110         }
111         wr, err := ws.NextWriter(op)
112         if err != nil {
113                 t.Logf("NextWriter: %v", err)
114                 return
115         }
116         if _, err = io.Copy(wr, rd); err != nil {
117                 t.Logf("NextWriter: %v", err)
118                 return
119         }
120         if err := wr.Close(); err != nil {
121                 t.Logf("Close: %v", err)
122                 return
123         }
124 }
125
126 func makeWsProto(s string) string {
127         return "ws" + strings.TrimPrefix(s, "http")
128 }
129
130 func sendRecv(t *testing.T, ws *Conn) {
131         const message = "Hello World!"
132         if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
133                 t.Fatalf("SetWriteDeadline: %v", err)
134         }
135         if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
136                 t.Fatalf("WriteMessage: %v", err)
137         }
138         if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
139                 t.Fatalf("SetReadDeadline: %v", err)
140         }
141         _, p, err := ws.ReadMessage()
142         if err != nil {
143                 t.Fatalf("ReadMessage: %v", err)
144         }
145         if string(p) != message {
146                 t.Fatalf("message=%s, want %s", p, message)
147         }
148 }
149
150 func TestProxyDial(t *testing.T) {
151
152         s := newServer(t)
153         defer s.Close()
154
155         surl, _ := url.Parse(s.Server.URL)
156
157         cstDialer := cstDialer // make local copy for modification on next line.
158         cstDialer.Proxy = http.ProxyURL(surl)
159
160         connect := false
161         origHandler := s.Server.Config.Handler
162
163         // Capture the request Host header.
164         s.Server.Config.Handler = http.HandlerFunc(
165                 func(w http.ResponseWriter, r *http.Request) {
166                         if r.Method == "CONNECT" {
167                                 connect = true
168                                 w.WriteHeader(http.StatusOK)
169                                 return
170                         }
171
172                         if !connect {
173                                 t.Log("connect not received")
174                                 http.Error(w, "connect not received", http.StatusMethodNotAllowed)
175                                 return
176                         }
177                         origHandler.ServeHTTP(w, r)
178                 })
179
180         ws, _, err := cstDialer.Dial(s.URL, nil)
181         if err != nil {
182                 t.Fatalf("Dial: %v", err)
183         }
184         defer ws.Close()
185         sendRecv(t, ws)
186 }
187
188 func TestProxyAuthorizationDial(t *testing.T) {
189         s := newServer(t)
190         defer s.Close()
191
192         surl, _ := url.Parse(s.Server.URL)
193         surl.User = url.UserPassword("username", "password")
194
195         cstDialer := cstDialer // make local copy for modification on next line.
196         cstDialer.Proxy = http.ProxyURL(surl)
197
198         connect := false
199         origHandler := s.Server.Config.Handler
200
201         // Capture the request Host header.
202         s.Server.Config.Handler = http.HandlerFunc(
203                 func(w http.ResponseWriter, r *http.Request) {
204                         proxyAuth := r.Header.Get("Proxy-Authorization")
205                         expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
206                         if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
207                                 connect = true
208                                 w.WriteHeader(http.StatusOK)
209                                 return
210                         }
211
212                         if !connect {
213                                 t.Log("connect with proxy authorization not received")
214                                 http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed)
215                                 return
216                         }
217                         origHandler.ServeHTTP(w, r)
218                 })
219
220         ws, _, err := cstDialer.Dial(s.URL, nil)
221         if err != nil {
222                 t.Fatalf("Dial: %v", err)
223         }
224         defer ws.Close()
225         sendRecv(t, ws)
226 }
227
228 func TestDial(t *testing.T) {
229         s := newServer(t)
230         defer s.Close()
231
232         ws, _, err := cstDialer.Dial(s.URL, nil)
233         if err != nil {
234                 t.Fatalf("Dial: %v", err)
235         }
236         defer ws.Close()
237         sendRecv(t, ws)
238 }
239
240 func TestDialCookieJar(t *testing.T) {
241         s := newServer(t)
242         defer s.Close()
243
244         jar, _ := cookiejar.New(nil)
245         d := cstDialer
246         d.Jar = jar
247
248         u, _ := url.Parse(s.URL)
249
250         switch u.Scheme {
251         case "ws":
252                 u.Scheme = "http"
253         case "wss":
254                 u.Scheme = "https"
255         }
256
257         cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}}
258         d.Jar.SetCookies(u, cookies)
259
260         ws, _, err := d.Dial(s.URL, nil)
261         if err != nil {
262                 t.Fatalf("Dial: %v", err)
263         }
264         defer ws.Close()
265
266         var gorilla string
267         var sessionID string
268         for _, c := range d.Jar.Cookies(u) {
269                 if c.Name == "gorilla" {
270                         gorilla = c.Value
271                 }
272
273                 if c.Name == "sessionID" {
274                         sessionID = c.Value
275                 }
276         }
277         if gorilla != "ws" {
278                 t.Error("Cookie not present in jar.")
279         }
280
281         if sessionID != "1234" {
282                 t.Error("Set-Cookie not received from the server.")
283         }
284
285         sendRecv(t, ws)
286 }
287
288 func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
289         certs := x509.NewCertPool()
290         for _, c := range s.TLS.Certificates {
291                 roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
292                 if err != nil {
293                         t.Fatalf("error parsing server's root cert: %v", err)
294                 }
295                 for _, root := range roots {
296                         certs.AddCert(root)
297                 }
298         }
299         return certs
300 }
301
302 func TestDialTLS(t *testing.T) {
303         s := newTLSServer(t)
304         defer s.Close()
305
306         d := cstDialer
307         d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
308         ws, _, err := d.Dial(s.URL, nil)
309         if err != nil {
310                 t.Fatalf("Dial: %v", err)
311         }
312         defer ws.Close()
313         sendRecv(t, ws)
314 }
315
316 func TestDialTimeout(t *testing.T) {
317         s := newServer(t)
318         defer s.Close()
319
320         d := cstDialer
321         d.HandshakeTimeout = -1
322         ws, _, err := d.Dial(s.URL, nil)
323         if err == nil {
324                 ws.Close()
325                 t.Fatalf("Dial: nil")
326         }
327 }
328
329 // requireDeadlineNetConn fails the current test when Read or Write are called
330 // with no deadline.
331 type requireDeadlineNetConn struct {
332         t                  *testing.T
333         c                  net.Conn
334         readDeadlineIsSet  bool
335         writeDeadlineIsSet bool
336 }
337
338 func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error {
339         c.writeDeadlineIsSet = !t.Equal(time.Time{})
340         c.readDeadlineIsSet = c.writeDeadlineIsSet
341         return c.c.SetDeadline(t)
342 }
343
344 func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error {
345         c.readDeadlineIsSet = !t.Equal(time.Time{})
346         return c.c.SetDeadline(t)
347 }
348
349 func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error {
350         c.writeDeadlineIsSet = !t.Equal(time.Time{})
351         return c.c.SetDeadline(t)
352 }
353
354 func (c *requireDeadlineNetConn) Write(p []byte) (int, error) {
355         if !c.writeDeadlineIsSet {
356                 c.t.Fatalf("write with no deadline")
357         }
358         return c.c.Write(p)
359 }
360
361 func (c *requireDeadlineNetConn) Read(p []byte) (int, error) {
362         if !c.readDeadlineIsSet {
363                 c.t.Fatalf("read with no deadline")
364         }
365         return c.c.Read(p)
366 }
367
368 func (c *requireDeadlineNetConn) Close() error         { return c.c.Close() }
369 func (c *requireDeadlineNetConn) LocalAddr() net.Addr  { return c.c.LocalAddr() }
370 func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
371
372 func TestHandshakeTimeout(t *testing.T) {
373         s := newServer(t)
374         defer s.Close()
375
376         d := cstDialer
377         d.NetDial = func(n, a string) (net.Conn, error) {
378                 c, err := net.Dial(n, a)
379                 return &requireDeadlineNetConn{c: c, t: t}, err
380         }
381         ws, _, err := d.Dial(s.URL, nil)
382         if err != nil {
383                 t.Fatal("Dial:", err)
384         }
385         ws.Close()
386 }
387
388 func TestHandshakeTimeoutInContext(t *testing.T) {
389         s := newServer(t)
390         defer s.Close()
391
392         d := cstDialer
393         d.HandshakeTimeout = 0
394         d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
395                 netDialer := &net.Dialer{}
396                 c, err := netDialer.DialContext(ctx, n, a)
397                 return &requireDeadlineNetConn{c: c, t: t}, err
398         }
399
400         ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
401         defer cancel()
402         ws, _, err := d.DialContext(ctx, s.URL, nil)
403         if err != nil {
404                 t.Fatal("Dial:", err)
405         }
406         ws.Close()
407 }
408
409 func TestDialBadScheme(t *testing.T) {
410         s := newServer(t)
411         defer s.Close()
412
413         ws, _, err := cstDialer.Dial(s.Server.URL, nil)
414         if err == nil {
415                 ws.Close()
416                 t.Fatalf("Dial: nil")
417         }
418 }
419
420 func TestDialBadOrigin(t *testing.T) {
421         s := newServer(t)
422         defer s.Close()
423
424         ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
425         if err == nil {
426                 ws.Close()
427                 t.Fatalf("Dial: nil")
428         }
429         if resp == nil {
430                 t.Fatalf("resp=nil, err=%v", err)
431         }
432         if resp.StatusCode != http.StatusForbidden {
433                 t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
434         }
435 }
436
437 func TestDialBadHeader(t *testing.T) {
438         s := newServer(t)
439         defer s.Close()
440
441         for _, k := range []string{"Upgrade",
442                 "Connection",
443                 "Sec-Websocket-Key",
444                 "Sec-Websocket-Version",
445                 "Sec-Websocket-Protocol"} {
446                 h := http.Header{}
447                 h.Set(k, "bad")
448                 ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
449                 if err == nil {
450                         ws.Close()
451                         t.Errorf("Dial with header %s returned nil", k)
452                 }
453         }
454 }
455
456 func TestBadMethod(t *testing.T) {
457         s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
458                 ws, err := cstUpgrader.Upgrade(w, r, nil)
459                 if err == nil {
460                         t.Errorf("handshake succeeded, expect fail")
461                         ws.Close()
462                 }
463         }))
464         defer s.Close()
465
466         req, err := http.NewRequest("POST", s.URL, strings.NewReader(""))
467         if err != nil {
468                 t.Fatalf("NewRequest returned error %v", err)
469         }
470         req.Header.Set("Connection", "upgrade")
471         req.Header.Set("Upgrade", "websocket")
472         req.Header.Set("Sec-Websocket-Version", "13")
473
474         resp, err := http.DefaultClient.Do(req)
475         if err != nil {
476                 t.Fatalf("Do returned error %v", err)
477         }
478         resp.Body.Close()
479         if resp.StatusCode != http.StatusMethodNotAllowed {
480                 t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
481         }
482 }
483
484 func TestHandshake(t *testing.T) {
485         s := newServer(t)
486         defer s.Close()
487
488         ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
489         if err != nil {
490                 t.Fatalf("Dial: %v", err)
491         }
492         defer ws.Close()
493
494         var sessionID string
495         for _, c := range resp.Cookies() {
496                 if c.Name == "sessionID" {
497                         sessionID = c.Value
498                 }
499         }
500         if sessionID != "1234" {
501                 t.Error("Set-Cookie not received from the server.")
502         }
503
504         if ws.Subprotocol() != "p1" {
505                 t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
506         }
507         sendRecv(t, ws)
508 }
509
510 func TestRespOnBadHandshake(t *testing.T) {
511         const expectedStatus = http.StatusGone
512         const expectedBody = "This is the response body."
513
514         s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
515                 w.WriteHeader(expectedStatus)
516                 io.WriteString(w, expectedBody)
517         }))
518         defer s.Close()
519
520         ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
521         if err == nil {
522                 ws.Close()
523                 t.Fatalf("Dial: nil")
524         }
525
526         if resp == nil {
527                 t.Fatalf("resp=nil, err=%v", err)
528         }
529
530         if resp.StatusCode != expectedStatus {
531                 t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
532         }
533
534         p, err := ioutil.ReadAll(resp.Body)
535         if err != nil {
536                 t.Fatalf("ReadFull(resp.Body) returned error %v", err)
537         }
538
539         if string(p) != expectedBody {
540                 t.Errorf("resp.Body=%s, want %s", p, expectedBody)
541         }
542 }
543
544 type testLogWriter struct {
545         t *testing.T
546 }
547
548 func (w testLogWriter) Write(p []byte) (int, error) {
549         w.t.Logf("%s", p)
550         return len(p), nil
551 }
552
553 // TestHost tests handling of host names and confirms that it matches net/http.
554 func TestHost(t *testing.T) {
555
556         upgrader := Upgrader{}
557         handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
558                 if IsWebSocketUpgrade(r) {
559                         c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
560                         if err != nil {
561                                 t.Fatal(err)
562                         }
563                         c.Close()
564                 } else {
565                         w.Header().Set("X-Test-Host", r.Host)
566                 }
567         })
568
569         server := httptest.NewServer(handler)
570         defer server.Close()
571
572         tlsServer := httptest.NewTLSServer(handler)
573         defer tlsServer.Close()
574
575         addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
576         wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
577         httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
578
579         // Avoid log noise from net/http server by logging to testing.T
580         server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
581         tlsServer.Config.ErrorLog = server.Config.ErrorLog
582
583         cas := rootCAs(t, tlsServer)
584
585         tests := []struct {
586                 fail               bool             // true if dial / get should fail
587                 server             *httptest.Server // server to use
588                 url                string           // host for request URI
589                 header             string           // optional request host header
590                 tls                string           // optiona host for tls ServerName
591                 wantAddr           string           // expected host for dial
592                 wantHeader         string           // expected request header on server
593                 insecureSkipVerify bool
594         }{
595                 {
596                         server:     server,
597                         url:        addrs[server],
598                         wantAddr:   addrs[server],
599                         wantHeader: addrs[server],
600                 },
601                 {
602                         server:     tlsServer,
603                         url:        addrs[tlsServer],
604                         wantAddr:   addrs[tlsServer],
605                         wantHeader: addrs[tlsServer],
606                 },
607
608                 {
609                         server:     server,
610                         url:        addrs[server],
611                         header:     "badhost.com",
612                         wantAddr:   addrs[server],
613                         wantHeader: "badhost.com",
614                 },
615                 {
616                         server:     tlsServer,
617                         url:        addrs[tlsServer],
618                         header:     "badhost.com",
619                         wantAddr:   addrs[tlsServer],
620                         wantHeader: "badhost.com",
621                 },
622
623                 {
624                         server:     server,
625                         url:        "example.com",
626                         header:     "badhost.com",
627                         wantAddr:   "example.com:80",
628                         wantHeader: "badhost.com",
629                 },
630                 {
631                         server:     tlsServer,
632                         url:        "example.com",
633                         header:     "badhost.com",
634                         wantAddr:   "example.com:443",
635                         wantHeader: "badhost.com",
636                 },
637
638                 {
639                         server:     server,
640                         url:        "badhost.com",
641                         header:     "example.com",
642                         wantAddr:   "badhost.com:80",
643                         wantHeader: "example.com",
644                 },
645                 {
646                         fail:     true,
647                         server:   tlsServer,
648                         url:      "badhost.com",
649                         header:   "example.com",
650                         wantAddr: "badhost.com:443",
651                 },
652                 {
653                         server:             tlsServer,
654                         url:                "badhost.com",
655                         insecureSkipVerify: true,
656                         wantAddr:           "badhost.com:443",
657                         wantHeader:         "badhost.com",
658                 },
659                 {
660                         server:     tlsServer,
661                         url:        "badhost.com",
662                         tls:        "example.com",
663                         wantAddr:   "badhost.com:443",
664                         wantHeader: "badhost.com",
665                 },
666         }
667
668         for i, tt := range tests {
669
670                 tls := &tls.Config{
671                         RootCAs:            cas,
672                         ServerName:         tt.tls,
673                         InsecureSkipVerify: tt.insecureSkipVerify,
674                 }
675
676                 var gotAddr string
677                 dialer := Dialer{
678                         NetDial: func(network, addr string) (net.Conn, error) {
679                                 gotAddr = addr
680                                 return net.Dial(network, addrs[tt.server])
681                         },
682                         TLSClientConfig: tls,
683                 }
684
685                 // Test websocket dial
686
687                 h := http.Header{}
688                 if tt.header != "" {
689                         h.Set("Host", tt.header)
690                 }
691                 c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
692                 if err == nil {
693                         c.Close()
694                 }
695
696                 check := func(protos map[*httptest.Server]string) {
697                         name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
698                         if gotAddr != tt.wantAddr {
699                                 t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
700                         }
701                         switch {
702                         case tt.fail && err == nil:
703                                 t.Errorf("%s: unexpected success", name)
704                         case !tt.fail && err != nil:
705                                 t.Errorf("%s: unexpected error %v", name, err)
706                         case !tt.fail && err == nil:
707                                 if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
708                                         t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
709                                 }
710                         }
711                 }
712
713                 check(wsProtos)
714
715                 // Confirm that net/http has same result
716
717                 transport := &http.Transport{
718                         Dial:            dialer.NetDial,
719                         TLSClientConfig: dialer.TLSClientConfig,
720                 }
721                 req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
722                 if tt.header != "" {
723                         req.Host = tt.header
724                 }
725                 client := &http.Client{Transport: transport}
726                 resp, err = client.Do(req)
727                 if err == nil {
728                         resp.Body.Close()
729                 }
730                 transport.CloseIdleConnections()
731                 check(httpProtos)
732         }
733 }
734
735 func TestDialCompression(t *testing.T) {
736         s := newServer(t)
737         defer s.Close()
738
739         dialer := cstDialer
740         dialer.EnableCompression = true
741         ws, _, err := dialer.Dial(s.URL, nil)
742         if err != nil {
743                 t.Fatalf("Dial: %v", err)
744         }
745         defer ws.Close()
746         sendRecv(t, ws)
747 }
748
749 func TestSocksProxyDial(t *testing.T) {
750         s := newServer(t)
751         defer s.Close()
752
753         proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
754         if err != nil {
755                 t.Fatalf("listen failed: %v", err)
756         }
757         defer proxyListener.Close()
758         go func() {
759                 c1, err := proxyListener.Accept()
760                 if err != nil {
761                         t.Errorf("proxy accept failed: %v", err)
762                         return
763                 }
764                 defer c1.Close()
765
766                 c1.SetDeadline(time.Now().Add(30 * time.Second))
767
768                 buf := make([]byte, 32)
769                 if _, err := io.ReadFull(c1, buf[:3]); err != nil {
770                         t.Errorf("read failed: %v", err)
771                         return
772                 }
773                 if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) {
774                         t.Errorf("read %x, want %x", buf[:len(want)], want)
775                 }
776                 if _, err := c1.Write([]byte{5, 0}); err != nil {
777                         t.Errorf("write failed: %v", err)
778                         return
779                 }
780                 if _, err := io.ReadFull(c1, buf[:10]); err != nil {
781                         t.Errorf("read failed: %v", err)
782                         return
783                 }
784                 if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) {
785                         t.Errorf("read %x, want %x", buf[:len(want)], want)
786                         return
787                 }
788                 buf[1] = 0
789                 if _, err := c1.Write(buf[:10]); err != nil {
790                         t.Errorf("write failed: %v", err)
791                         return
792                 }
793
794                 ip := net.IP(buf[4:8])
795                 port := binary.BigEndian.Uint16(buf[8:10])
796
797                 c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)})
798                 if err != nil {
799                         t.Errorf("dial failed; %v", err)
800                         return
801                 }
802                 defer c2.Close()
803                 done := make(chan struct{})
804                 go func() {
805                         io.Copy(c1, c2)
806                         close(done)
807                 }()
808                 io.Copy(c2, c1)
809                 <-done
810         }()
811
812         purl, err := url.Parse("socks5://" + proxyListener.Addr().String())
813         if err != nil {
814                 t.Fatalf("parse failed: %v", err)
815         }
816
817         cstDialer := cstDialer // make local copy for modification on next line.
818         cstDialer.Proxy = http.ProxyURL(purl)
819
820         ws, _, err := cstDialer.Dial(s.URL, nil)
821         if err != nil {
822                 t.Fatalf("Dial: %v", err)
823         }
824         defer ws.Close()
825         sendRecv(t, ws)
826 }
827
828 func TestTracingDialWithContext(t *testing.T) {
829
830         var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
831         trace := &httptrace.ClientTrace{
832                 WroteHeaders: func() {
833                         headersWrote = true
834                 },
835                 WroteRequest: func(httptrace.WroteRequestInfo) {
836                         requestWrote = true
837                 },
838                 GetConn: func(hostPort string) {
839                         getConn = true
840                 },
841                 GotConn: func(info httptrace.GotConnInfo) {
842                         gotConn = true
843                 },
844                 ConnectDone: func(network, addr string, err error) {
845                         connectDone = true
846                 },
847                 GotFirstResponseByte: func() {
848                         gotFirstResponseByte = true
849                 },
850         }
851         ctx := httptrace.WithClientTrace(context.Background(), trace)
852
853         s := newTLSServer(t)
854         defer s.Close()
855
856         d := cstDialer
857         d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
858
859         ws, _, err := d.DialContext(ctx, s.URL, nil)
860         if err != nil {
861                 t.Fatalf("Dial: %v", err)
862         }
863
864         if !headersWrote {
865                 t.Fatal("Headers was not written")
866         }
867         if !requestWrote {
868                 t.Fatal("Request was not written")
869         }
870         if !getConn {
871                 t.Fatal("getConn was not called")
872         }
873         if !gotConn {
874                 t.Fatal("gotConn was not called")
875         }
876         if !connectDone {
877                 t.Fatal("connectDone was not called")
878         }
879         if !gotFirstResponseByte {
880                 t.Fatal("GotFirstResponseByte was not called")
881         }
882
883         defer ws.Close()
884         sendRecv(t, ws)
885 }
886
887 func TestEmptyTracingDialWithContext(t *testing.T) {
888
889         trace := &httptrace.ClientTrace{}
890         ctx := httptrace.WithClientTrace(context.Background(), trace)
891
892         s := newTLSServer(t)
893         defer s.Close()
894
895         d := cstDialer
896         d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
897
898         ws, _, err := d.DialContext(ctx, s.URL, nil)
899         if err != nil {
900                 t.Fatalf("Dial: %v", err)
901         }
902
903         defer ws.Close()
904         sendRecv(t, ws)
905 }