OSDN Git Service

new repo
[bytom/vapor.git] / vendor / golang.org / x / net / websocket / websocket_test.go
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package websocket
6
7 import (
8         "bytes"
9         "crypto/rand"
10         "fmt"
11         "io"
12         "log"
13         "net"
14         "net/http"
15         "net/http/httptest"
16         "net/url"
17         "reflect"
18         "runtime"
19         "strings"
20         "sync"
21         "testing"
22         "time"
23 )
24
25 var serverAddr string
26 var once sync.Once
27
28 func echoServer(ws *Conn) {
29         defer ws.Close()
30         io.Copy(ws, ws)
31 }
32
33 type Count struct {
34         S string
35         N int
36 }
37
38 func countServer(ws *Conn) {
39         defer ws.Close()
40         for {
41                 var count Count
42                 err := JSON.Receive(ws, &count)
43                 if err != nil {
44                         return
45                 }
46                 count.N++
47                 count.S = strings.Repeat(count.S, count.N)
48                 err = JSON.Send(ws, count)
49                 if err != nil {
50                         return
51                 }
52         }
53 }
54
55 type testCtrlAndDataHandler struct {
56         hybiFrameHandler
57 }
58
59 func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
60         h.hybiFrameHandler.conn.wio.Lock()
61         defer h.hybiFrameHandler.conn.wio.Unlock()
62         w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
63         if err != nil {
64                 return 0, err
65         }
66         n, err := w.Write(b)
67         w.Close()
68         return n, err
69 }
70
71 func ctrlAndDataServer(ws *Conn) {
72         defer ws.Close()
73         h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
74         ws.frameHandler = h
75
76         go func() {
77                 for i := 0; ; i++ {
78                         var b []byte
79                         if i%2 != 0 { // with or without payload
80                                 b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
81                         }
82                         if _, err := h.WritePing(b); err != nil {
83                                 break
84                         }
85                         if _, err := h.WritePong(b); err != nil { // unsolicited pong
86                                 break
87                         }
88                         time.Sleep(10 * time.Millisecond)
89                 }
90         }()
91
92         b := make([]byte, 128)
93         for {
94                 n, err := ws.Read(b)
95                 if err != nil {
96                         break
97                 }
98                 if _, err := ws.Write(b[:n]); err != nil {
99                         break
100                 }
101         }
102 }
103
104 func subProtocolHandshake(config *Config, req *http.Request) error {
105         for _, proto := range config.Protocol {
106                 if proto == "chat" {
107                         config.Protocol = []string{proto}
108                         return nil
109                 }
110         }
111         return ErrBadWebSocketProtocol
112 }
113
114 func subProtoServer(ws *Conn) {
115         for _, proto := range ws.Config().Protocol {
116                 io.WriteString(ws, proto)
117         }
118 }
119
120 func startServer() {
121         http.Handle("/echo", Handler(echoServer))
122         http.Handle("/count", Handler(countServer))
123         http.Handle("/ctrldata", Handler(ctrlAndDataServer))
124         subproto := Server{
125                 Handshake: subProtocolHandshake,
126                 Handler:   Handler(subProtoServer),
127         }
128         http.Handle("/subproto", subproto)
129         server := httptest.NewServer(nil)
130         serverAddr = server.Listener.Addr().String()
131         log.Print("Test WebSocket server listening on ", serverAddr)
132 }
133
134 func newConfig(t *testing.T, path string) *Config {
135         config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
136         return config
137 }
138
139 func TestEcho(t *testing.T) {
140         once.Do(startServer)
141
142         // websocket.Dial()
143         client, err := net.Dial("tcp", serverAddr)
144         if err != nil {
145                 t.Fatal("dialing", err)
146         }
147         conn, err := NewClient(newConfig(t, "/echo"), client)
148         if err != nil {
149                 t.Errorf("WebSocket handshake error: %v", err)
150                 return
151         }
152
153         msg := []byte("hello, world\n")
154         if _, err := conn.Write(msg); err != nil {
155                 t.Errorf("Write: %v", err)
156         }
157         var actual_msg = make([]byte, 512)
158         n, err := conn.Read(actual_msg)
159         if err != nil {
160                 t.Errorf("Read: %v", err)
161         }
162         actual_msg = actual_msg[0:n]
163         if !bytes.Equal(msg, actual_msg) {
164                 t.Errorf("Echo: expected %q got %q", msg, actual_msg)
165         }
166         conn.Close()
167 }
168
169 func TestAddr(t *testing.T) {
170         once.Do(startServer)
171
172         // websocket.Dial()
173         client, err := net.Dial("tcp", serverAddr)
174         if err != nil {
175                 t.Fatal("dialing", err)
176         }
177         conn, err := NewClient(newConfig(t, "/echo"), client)
178         if err != nil {
179                 t.Errorf("WebSocket handshake error: %v", err)
180                 return
181         }
182
183         ra := conn.RemoteAddr().String()
184         if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
185                 t.Errorf("Bad remote addr: %v", ra)
186         }
187         la := conn.LocalAddr().String()
188         if !strings.HasPrefix(la, "http://") {
189                 t.Errorf("Bad local addr: %v", la)
190         }
191         conn.Close()
192 }
193
194 func TestCount(t *testing.T) {
195         once.Do(startServer)
196
197         // websocket.Dial()
198         client, err := net.Dial("tcp", serverAddr)
199         if err != nil {
200                 t.Fatal("dialing", err)
201         }
202         conn, err := NewClient(newConfig(t, "/count"), client)
203         if err != nil {
204                 t.Errorf("WebSocket handshake error: %v", err)
205                 return
206         }
207
208         var count Count
209         count.S = "hello"
210         if err := JSON.Send(conn, count); err != nil {
211                 t.Errorf("Write: %v", err)
212         }
213         if err := JSON.Receive(conn, &count); err != nil {
214                 t.Errorf("Read: %v", err)
215         }
216         if count.N != 1 {
217                 t.Errorf("count: expected %d got %d", 1, count.N)
218         }
219         if count.S != "hello" {
220                 t.Errorf("count: expected %q got %q", "hello", count.S)
221         }
222         if err := JSON.Send(conn, count); err != nil {
223                 t.Errorf("Write: %v", err)
224         }
225         if err := JSON.Receive(conn, &count); err != nil {
226                 t.Errorf("Read: %v", err)
227         }
228         if count.N != 2 {
229                 t.Errorf("count: expected %d got %d", 2, count.N)
230         }
231         if count.S != "hellohello" {
232                 t.Errorf("count: expected %q got %q", "hellohello", count.S)
233         }
234         conn.Close()
235 }
236
237 func TestWithQuery(t *testing.T) {
238         once.Do(startServer)
239
240         client, err := net.Dial("tcp", serverAddr)
241         if err != nil {
242                 t.Fatal("dialing", err)
243         }
244
245         config := newConfig(t, "/echo")
246         config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
247         if err != nil {
248                 t.Fatal("location url", err)
249         }
250
251         ws, err := NewClient(config, client)
252         if err != nil {
253                 t.Errorf("WebSocket handshake: %v", err)
254                 return
255         }
256         ws.Close()
257 }
258
259 func testWithProtocol(t *testing.T, subproto []string) (string, error) {
260         once.Do(startServer)
261
262         client, err := net.Dial("tcp", serverAddr)
263         if err != nil {
264                 t.Fatal("dialing", err)
265         }
266
267         config := newConfig(t, "/subproto")
268         config.Protocol = subproto
269
270         ws, err := NewClient(config, client)
271         if err != nil {
272                 return "", err
273         }
274         msg := make([]byte, 16)
275         n, err := ws.Read(msg)
276         if err != nil {
277                 return "", err
278         }
279         ws.Close()
280         return string(msg[:n]), nil
281 }
282
283 func TestWithProtocol(t *testing.T) {
284         proto, err := testWithProtocol(t, []string{"chat"})
285         if err != nil {
286                 t.Errorf("SubProto: unexpected error: %v", err)
287         }
288         if proto != "chat" {
289                 t.Errorf("SubProto: expected %q, got %q", "chat", proto)
290         }
291 }
292
293 func TestWithTwoProtocol(t *testing.T) {
294         proto, err := testWithProtocol(t, []string{"test", "chat"})
295         if err != nil {
296                 t.Errorf("SubProto: unexpected error: %v", err)
297         }
298         if proto != "chat" {
299                 t.Errorf("SubProto: expected %q, got %q", "chat", proto)
300         }
301 }
302
303 func TestWithBadProtocol(t *testing.T) {
304         _, err := testWithProtocol(t, []string{"test"})
305         if err != ErrBadStatus {
306                 t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
307         }
308 }
309
310 func TestHTTP(t *testing.T) {
311         once.Do(startServer)
312
313         // If the client did not send a handshake that matches the protocol
314         // specification, the server MUST return an HTTP response with an
315         // appropriate error code (such as 400 Bad Request)
316         resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
317         if err != nil {
318                 t.Errorf("Get: error %#v", err)
319                 return
320         }
321         if resp == nil {
322                 t.Error("Get: resp is null")
323                 return
324         }
325         if resp.StatusCode != http.StatusBadRequest {
326                 t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
327         }
328 }
329
330 func TestTrailingSpaces(t *testing.T) {
331         // http://code.google.com/p/go/issues/detail?id=955
332         // The last runs of this create keys with trailing spaces that should not be
333         // generated by the client.
334         once.Do(startServer)
335         config := newConfig(t, "/echo")
336         for i := 0; i < 30; i++ {
337                 // body
338                 ws, err := DialConfig(config)
339                 if err != nil {
340                         t.Errorf("Dial #%d failed: %v", i, err)
341                         break
342                 }
343                 ws.Close()
344         }
345 }
346
347 func TestDialConfigBadVersion(t *testing.T) {
348         once.Do(startServer)
349         config := newConfig(t, "/echo")
350         config.Version = 1234
351
352         _, err := DialConfig(config)
353
354         if dialerr, ok := err.(*DialError); ok {
355                 if dialerr.Err != ErrBadProtocolVersion {
356                         t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
357                 }
358         }
359 }
360
361 func TestDialConfigWithDialer(t *testing.T) {
362         once.Do(startServer)
363         config := newConfig(t, "/echo")
364         config.Dialer = &net.Dialer{
365                 Deadline: time.Now().Add(-time.Minute),
366         }
367         _, err := DialConfig(config)
368         dialerr, ok := err.(*DialError)
369         if !ok {
370                 t.Fatalf("DialError expected, got %#v", err)
371         }
372         neterr, ok := dialerr.Err.(*net.OpError)
373         if !ok {
374                 t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
375         }
376         if !neterr.Timeout() {
377                 t.Fatalf("expected timeout error, got %#v", neterr)
378         }
379 }
380
381 func TestSmallBuffer(t *testing.T) {
382         // http://code.google.com/p/go/issues/detail?id=1145
383         // Read should be able to handle reading a fragment of a frame.
384         once.Do(startServer)
385
386         // websocket.Dial()
387         client, err := net.Dial("tcp", serverAddr)
388         if err != nil {
389                 t.Fatal("dialing", err)
390         }
391         conn, err := NewClient(newConfig(t, "/echo"), client)
392         if err != nil {
393                 t.Errorf("WebSocket handshake error: %v", err)
394                 return
395         }
396
397         msg := []byte("hello, world\n")
398         if _, err := conn.Write(msg); err != nil {
399                 t.Errorf("Write: %v", err)
400         }
401         var small_msg = make([]byte, 8)
402         n, err := conn.Read(small_msg)
403         if err != nil {
404                 t.Errorf("Read: %v", err)
405         }
406         if !bytes.Equal(msg[:len(small_msg)], small_msg) {
407                 t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
408         }
409         var second_msg = make([]byte, len(msg))
410         n, err = conn.Read(second_msg)
411         if err != nil {
412                 t.Errorf("Read: %v", err)
413         }
414         second_msg = second_msg[0:n]
415         if !bytes.Equal(msg[len(small_msg):], second_msg) {
416                 t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
417         }
418         conn.Close()
419 }
420
421 var parseAuthorityTests = []struct {
422         in  *url.URL
423         out string
424 }{
425         {
426                 &url.URL{
427                         Scheme: "ws",
428                         Host:   "www.google.com",
429                 },
430                 "www.google.com:80",
431         },
432         {
433                 &url.URL{
434                         Scheme: "wss",
435                         Host:   "www.google.com",
436                 },
437                 "www.google.com:443",
438         },
439         {
440                 &url.URL{
441                         Scheme: "ws",
442                         Host:   "www.google.com:80",
443                 },
444                 "www.google.com:80",
445         },
446         {
447                 &url.URL{
448                         Scheme: "wss",
449                         Host:   "www.google.com:443",
450                 },
451                 "www.google.com:443",
452         },
453         // some invalid ones for parseAuthority. parseAuthority doesn't
454         // concern itself with the scheme unless it actually knows about it
455         {
456                 &url.URL{
457                         Scheme: "http",
458                         Host:   "www.google.com",
459                 },
460                 "www.google.com",
461         },
462         {
463                 &url.URL{
464                         Scheme: "http",
465                         Host:   "www.google.com:80",
466                 },
467                 "www.google.com:80",
468         },
469         {
470                 &url.URL{
471                         Scheme: "asdf",
472                         Host:   "127.0.0.1",
473                 },
474                 "127.0.0.1",
475         },
476         {
477                 &url.URL{
478                         Scheme: "asdf",
479                         Host:   "www.google.com",
480                 },
481                 "www.google.com",
482         },
483 }
484
485 func TestParseAuthority(t *testing.T) {
486         for _, tt := range parseAuthorityTests {
487                 out := parseAuthority(tt.in)
488                 if out != tt.out {
489                         t.Errorf("got %v; want %v", out, tt.out)
490                 }
491         }
492 }
493
494 type closerConn struct {
495         net.Conn
496         closed int // count of the number of times Close was called
497 }
498
499 func (c *closerConn) Close() error {
500         c.closed++
501         return c.Conn.Close()
502 }
503
504 func TestClose(t *testing.T) {
505         if runtime.GOOS == "plan9" {
506                 t.Skip("see golang.org/issue/11454")
507         }
508
509         once.Do(startServer)
510
511         conn, err := net.Dial("tcp", serverAddr)
512         if err != nil {
513                 t.Fatal("dialing", err)
514         }
515
516         cc := closerConn{Conn: conn}
517
518         client, err := NewClient(newConfig(t, "/echo"), &cc)
519         if err != nil {
520                 t.Fatalf("WebSocket handshake: %v", err)
521         }
522
523         // set the deadline to ten minutes ago, which will have expired by the time
524         // client.Close sends the close status frame.
525         conn.SetDeadline(time.Now().Add(-10 * time.Minute))
526
527         if err := client.Close(); err == nil {
528                 t.Errorf("ws.Close(): expected error, got %v", err)
529         }
530         if cc.closed < 1 {
531                 t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
532         }
533 }
534
535 var originTests = []struct {
536         req    *http.Request
537         origin *url.URL
538 }{
539         {
540                 req: &http.Request{
541                         Header: http.Header{
542                                 "Origin": []string{"http://www.example.com"},
543                         },
544                 },
545                 origin: &url.URL{
546                         Scheme: "http",
547                         Host:   "www.example.com",
548                 },
549         },
550         {
551                 req: &http.Request{},
552         },
553 }
554
555 func TestOrigin(t *testing.T) {
556         conf := newConfig(t, "/echo")
557         conf.Version = ProtocolVersionHybi13
558         for i, tt := range originTests {
559                 origin, err := Origin(conf, tt.req)
560                 if err != nil {
561                         t.Error(err)
562                         continue
563                 }
564                 if !reflect.DeepEqual(origin, tt.origin) {
565                         t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
566                         continue
567                 }
568         }
569 }
570
571 func TestCtrlAndData(t *testing.T) {
572         once.Do(startServer)
573
574         c, err := net.Dial("tcp", serverAddr)
575         if err != nil {
576                 t.Fatal(err)
577         }
578         ws, err := NewClient(newConfig(t, "/ctrldata"), c)
579         if err != nil {
580                 t.Fatal(err)
581         }
582         defer ws.Close()
583
584         h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
585         ws.frameHandler = h
586
587         b := make([]byte, 128)
588         for i := 0; i < 2; i++ {
589                 data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
590                 if _, err := ws.Write(data); err != nil {
591                         t.Fatalf("#%d: %v", i, err)
592                 }
593                 var ctrl []byte
594                 if i%2 != 0 { // with or without payload
595                         ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
596                 }
597                 if _, err := h.WritePing(ctrl); err != nil {
598                         t.Fatalf("#%d: %v", i, err)
599                 }
600                 n, err := ws.Read(b)
601                 if err != nil {
602                         t.Fatalf("#%d: %v", i, err)
603                 }
604                 if !bytes.Equal(b[:n], data) {
605                         t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
606                 }
607         }
608 }
609
610 func TestCodec_ReceiveLimited(t *testing.T) {
611         const limit = 2048
612         var payloads [][]byte
613         for _, size := range []int{
614                 1024,
615                 2048,
616                 4096, // receive of this message would be interrupted due to limit
617                 2048, // this one is to make sure next receive recovers discarding leftovers
618         } {
619                 b := make([]byte, size)
620                 rand.Read(b)
621                 payloads = append(payloads, b)
622         }
623         handlerDone := make(chan struct{})
624         limitedHandler := func(ws *Conn) {
625                 defer close(handlerDone)
626                 ws.MaxPayloadBytes = limit
627                 defer ws.Close()
628                 for i, p := range payloads {
629                         t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
630                         var recv []byte
631                         err := Message.Receive(ws, &recv)
632                         switch err {
633                         case nil:
634                         case ErrFrameTooLarge:
635                                 if len(p) <= limit {
636                                         t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
637                                 }
638                                 continue
639                         default:
640                                 t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
641                         }
642                         if len(recv) > limit {
643                                 t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
644                         }
645                         if !bytes.Equal(p, recv) {
646                                 t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
647                         }
648                 }
649         }
650         server := httptest.NewServer(Handler(limitedHandler))
651         defer server.CloseClientConnections()
652         defer server.Close()
653         addr := server.Listener.Addr().String()
654         ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
655         if err != nil {
656                 t.Fatal(err)
657         }
658         defer ws.Close()
659         for i, p := range payloads {
660                 if err := Message.Send(ws, p); err != nil {
661                         t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
662                 }
663         }
664         <-handlerDone
665 }