OSDN Git Service

new repo
[bytom/vapor.git] / vendor / golang.org / x / net / websocket / hybi.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package websocket
6
7 // This file implements a protocol of hybi draft.
8 // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
9
10 import (
11         "bufio"
12         "bytes"
13         "crypto/rand"
14         "crypto/sha1"
15         "encoding/base64"
16         "encoding/binary"
17         "fmt"
18         "io"
19         "io/ioutil"
20         "net/http"
21         "net/url"
22         "strings"
23 )
24
25 const (
26         websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
27
28         closeStatusNormal            = 1000
29         closeStatusGoingAway         = 1001
30         closeStatusProtocolError     = 1002
31         closeStatusUnsupportedData   = 1003
32         closeStatusFrameTooLarge     = 1004
33         closeStatusNoStatusRcvd      = 1005
34         closeStatusAbnormalClosure   = 1006
35         closeStatusBadMessageData    = 1007
36         closeStatusPolicyViolation   = 1008
37         closeStatusTooBigData        = 1009
38         closeStatusExtensionMismatch = 1010
39
40         maxControlFramePayloadLength = 125
41 )
42
43 var (
44         ErrBadMaskingKey         = &ProtocolError{"bad masking key"}
45         ErrBadPongMessage        = &ProtocolError{"bad pong message"}
46         ErrBadClosingStatus      = &ProtocolError{"bad closing status"}
47         ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
48         ErrNotImplemented        = &ProtocolError{"not implemented"}
49
50         handshakeHeader = map[string]bool{
51                 "Host":                   true,
52                 "Upgrade":                true,
53                 "Connection":             true,
54                 "Sec-Websocket-Key":      true,
55                 "Sec-Websocket-Origin":   true,
56                 "Sec-Websocket-Version":  true,
57                 "Sec-Websocket-Protocol": true,
58                 "Sec-Websocket-Accept":   true,
59         }
60 )
61
62 // A hybiFrameHeader is a frame header as defined in hybi draft.
63 type hybiFrameHeader struct {
64         Fin        bool
65         Rsv        [3]bool
66         OpCode     byte
67         Length     int64
68         MaskingKey []byte
69
70         data *bytes.Buffer
71 }
72
73 // A hybiFrameReader is a reader for hybi frame.
74 type hybiFrameReader struct {
75         reader io.Reader
76
77         header hybiFrameHeader
78         pos    int64
79         length int
80 }
81
82 func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
83         n, err = frame.reader.Read(msg)
84         if frame.header.MaskingKey != nil {
85                 for i := 0; i < n; i++ {
86                         msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
87                         frame.pos++
88                 }
89         }
90         return n, err
91 }
92
93 func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
94
95 func (frame *hybiFrameReader) HeaderReader() io.Reader {
96         if frame.header.data == nil {
97                 return nil
98         }
99         if frame.header.data.Len() == 0 {
100                 return nil
101         }
102         return frame.header.data
103 }
104
105 func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
106
107 func (frame *hybiFrameReader) Len() (n int) { return frame.length }
108
109 // A hybiFrameReaderFactory creates new frame reader based on its frame type.
110 type hybiFrameReaderFactory struct {
111         *bufio.Reader
112 }
113
114 // NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
115 // See Section 5.2 Base Framing protocol for detail.
116 // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
117 func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
118         hybiFrame := new(hybiFrameReader)
119         frame = hybiFrame
120         var header []byte
121         var b byte
122         // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
123         b, err = buf.ReadByte()
124         if err != nil {
125                 return
126         }
127         header = append(header, b)
128         hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
129         for i := 0; i < 3; i++ {
130                 j := uint(6 - i)
131                 hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
132         }
133         hybiFrame.header.OpCode = header[0] & 0x0f
134
135         // Second byte. Mask/Payload len(7bits)
136         b, err = buf.ReadByte()
137         if err != nil {
138                 return
139         }
140         header = append(header, b)
141         mask := (b & 0x80) != 0
142         b &= 0x7f
143         lengthFields := 0
144         switch {
145         case b <= 125: // Payload length 7bits.
146                 hybiFrame.header.Length = int64(b)
147         case b == 126: // Payload length 7+16bits
148                 lengthFields = 2
149         case b == 127: // Payload length 7+64bits
150                 lengthFields = 8
151         }
152         for i := 0; i < lengthFields; i++ {
153                 b, err = buf.ReadByte()
154                 if err != nil {
155                         return
156                 }
157                 if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits
158                         b &= 0x7f
159                 }
160                 header = append(header, b)
161                 hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
162         }
163         if mask {
164                 // Masking key. 4 bytes.
165                 for i := 0; i < 4; i++ {
166                         b, err = buf.ReadByte()
167                         if err != nil {
168                                 return
169                         }
170                         header = append(header, b)
171                         hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
172                 }
173         }
174         hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
175         hybiFrame.header.data = bytes.NewBuffer(header)
176         hybiFrame.length = len(header) + int(hybiFrame.header.Length)
177         return
178 }
179
180 // A HybiFrameWriter is a writer for hybi frame.
181 type hybiFrameWriter struct {
182         writer *bufio.Writer
183
184         header *hybiFrameHeader
185 }
186
187 func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
188         var header []byte
189         var b byte
190         if frame.header.Fin {
191                 b |= 0x80
192         }
193         for i := 0; i < 3; i++ {
194                 if frame.header.Rsv[i] {
195                         j := uint(6 - i)
196                         b |= 1 << j
197                 }
198         }
199         b |= frame.header.OpCode
200         header = append(header, b)
201         if frame.header.MaskingKey != nil {
202                 b = 0x80
203         } else {
204                 b = 0
205         }
206         lengthFields := 0
207         length := len(msg)
208         switch {
209         case length <= 125:
210                 b |= byte(length)
211         case length < 65536:
212                 b |= 126
213                 lengthFields = 2
214         default:
215                 b |= 127
216                 lengthFields = 8
217         }
218         header = append(header, b)
219         for i := 0; i < lengthFields; i++ {
220                 j := uint((lengthFields - i - 1) * 8)
221                 b = byte((length >> j) & 0xff)
222                 header = append(header, b)
223         }
224         if frame.header.MaskingKey != nil {
225                 if len(frame.header.MaskingKey) != 4 {
226                         return 0, ErrBadMaskingKey
227                 }
228                 header = append(header, frame.header.MaskingKey...)
229                 frame.writer.Write(header)
230                 data := make([]byte, length)
231                 for i := range data {
232                         data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
233                 }
234                 frame.writer.Write(data)
235                 err = frame.writer.Flush()
236                 return length, err
237         }
238         frame.writer.Write(header)
239         frame.writer.Write(msg)
240         err = frame.writer.Flush()
241         return length, err
242 }
243
244 func (frame *hybiFrameWriter) Close() error { return nil }
245
246 type hybiFrameWriterFactory struct {
247         *bufio.Writer
248         needMaskingKey bool
249 }
250
251 func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
252         frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
253         if buf.needMaskingKey {
254                 frameHeader.MaskingKey, err = generateMaskingKey()
255                 if err != nil {
256                         return nil, err
257                 }
258         }
259         return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
260 }
261
262 type hybiFrameHandler struct {
263         conn        *Conn
264         payloadType byte
265 }
266
267 func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
268         if handler.conn.IsServerConn() {
269                 // The client MUST mask all frames sent to the server.
270                 if frame.(*hybiFrameReader).header.MaskingKey == nil {
271                         handler.WriteClose(closeStatusProtocolError)
272                         return nil, io.EOF
273                 }
274         } else {
275                 // The server MUST NOT mask all frames.
276                 if frame.(*hybiFrameReader).header.MaskingKey != nil {
277                         handler.WriteClose(closeStatusProtocolError)
278                         return nil, io.EOF
279                 }
280         }
281         if header := frame.HeaderReader(); header != nil {
282                 io.Copy(ioutil.Discard, header)
283         }
284         switch frame.PayloadType() {
285         case ContinuationFrame:
286                 frame.(*hybiFrameReader).header.OpCode = handler.payloadType
287         case TextFrame, BinaryFrame:
288                 handler.payloadType = frame.PayloadType()
289         case CloseFrame:
290                 return nil, io.EOF
291         case PingFrame, PongFrame:
292                 b := make([]byte, maxControlFramePayloadLength)
293                 n, err := io.ReadFull(frame, b)
294                 if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
295                         return nil, err
296                 }
297                 io.Copy(ioutil.Discard, frame)
298                 if frame.PayloadType() == PingFrame {
299                         if _, err := handler.WritePong(b[:n]); err != nil {
300                                 return nil, err
301                         }
302                 }
303                 return nil, nil
304         }
305         return frame, nil
306 }
307
308 func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
309         handler.conn.wio.Lock()
310         defer handler.conn.wio.Unlock()
311         w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
312         if err != nil {
313                 return err
314         }
315         msg := make([]byte, 2)
316         binary.BigEndian.PutUint16(msg, uint16(status))
317         _, err = w.Write(msg)
318         w.Close()
319         return err
320 }
321
322 func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
323         handler.conn.wio.Lock()
324         defer handler.conn.wio.Unlock()
325         w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
326         if err != nil {
327                 return 0, err
328         }
329         n, err = w.Write(msg)
330         w.Close()
331         return n, err
332 }
333
334 // newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
335 func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
336         if buf == nil {
337                 br := bufio.NewReader(rwc)
338                 bw := bufio.NewWriter(rwc)
339                 buf = bufio.NewReadWriter(br, bw)
340         }
341         ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
342                 frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
343                 frameWriterFactory: hybiFrameWriterFactory{
344                         buf.Writer, request == nil},
345                 PayloadType:        TextFrame,
346                 defaultCloseStatus: closeStatusNormal}
347         ws.frameHandler = &hybiFrameHandler{conn: ws}
348         return ws
349 }
350
351 // generateMaskingKey generates a masking key for a frame.
352 func generateMaskingKey() (maskingKey []byte, err error) {
353         maskingKey = make([]byte, 4)
354         if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
355                 return
356         }
357         return
358 }
359
360 // generateNonce generates a nonce consisting of a randomly selected 16-byte
361 // value that has been base64-encoded.
362 func generateNonce() (nonce []byte) {
363         key := make([]byte, 16)
364         if _, err := io.ReadFull(rand.Reader, key); err != nil {
365                 panic(err)
366         }
367         nonce = make([]byte, 24)
368         base64.StdEncoding.Encode(nonce, key)
369         return
370 }
371
372 // removeZone removes IPv6 zone identifer from host.
373 // E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080"
374 func removeZone(host string) string {
375         if !strings.HasPrefix(host, "[") {
376                 return host
377         }
378         i := strings.LastIndex(host, "]")
379         if i < 0 {
380                 return host
381         }
382         j := strings.LastIndex(host[:i], "%")
383         if j < 0 {
384                 return host
385         }
386         return host[:j] + host[i:]
387 }
388
389 // getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
390 // the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
391 func getNonceAccept(nonce []byte) (expected []byte, err error) {
392         h := sha1.New()
393         if _, err = h.Write(nonce); err != nil {
394                 return
395         }
396         if _, err = h.Write([]byte(websocketGUID)); err != nil {
397                 return
398         }
399         expected = make([]byte, 28)
400         base64.StdEncoding.Encode(expected, h.Sum(nil))
401         return
402 }
403
404 // Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
405 func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
406         bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
407
408         // According to RFC 6874, an HTTP client, proxy, or other
409         // intermediary must remove any IPv6 zone identifier attached
410         // to an outgoing URI.
411         bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n")
412         bw.WriteString("Upgrade: websocket\r\n")
413         bw.WriteString("Connection: Upgrade\r\n")
414         nonce := generateNonce()
415         if config.handshakeData != nil {
416                 nonce = []byte(config.handshakeData["key"])
417         }
418         bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
419         bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
420
421         if config.Version != ProtocolVersionHybi13 {
422                 return ErrBadProtocolVersion
423         }
424
425         bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
426         if len(config.Protocol) > 0 {
427                 bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
428         }
429         // TODO(ukai): send Sec-WebSocket-Extensions.
430         err = config.Header.WriteSubset(bw, handshakeHeader)
431         if err != nil {
432                 return err
433         }
434
435         bw.WriteString("\r\n")
436         if err = bw.Flush(); err != nil {
437                 return err
438         }
439
440         resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
441         if err != nil {
442                 return err
443         }
444         if resp.StatusCode != 101 {
445                 return ErrBadStatus
446         }
447         if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
448                 strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
449                 return ErrBadUpgrade
450         }
451         expectedAccept, err := getNonceAccept(nonce)
452         if err != nil {
453                 return err
454         }
455         if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
456                 return ErrChallengeResponse
457         }
458         if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
459                 return ErrUnsupportedExtensions
460         }
461         offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
462         if offeredProtocol != "" {
463                 protocolMatched := false
464                 for i := 0; i < len(config.Protocol); i++ {
465                         if config.Protocol[i] == offeredProtocol {
466                                 protocolMatched = true
467                                 break
468                         }
469                 }
470                 if !protocolMatched {
471                         return ErrBadWebSocketProtocol
472                 }
473                 config.Protocol = []string{offeredProtocol}
474         }
475
476         return nil
477 }
478
479 // newHybiClientConn creates a client WebSocket connection after handshake.
480 func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
481         return newHybiConn(config, buf, rwc, nil)
482 }
483
484 // A HybiServerHandshaker performs a server handshake using hybi draft protocol.
485 type hybiServerHandshaker struct {
486         *Config
487         accept []byte
488 }
489
490 func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
491         c.Version = ProtocolVersionHybi13
492         if req.Method != "GET" {
493                 return http.StatusMethodNotAllowed, ErrBadRequestMethod
494         }
495         // HTTP version can be safely ignored.
496
497         if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
498                 !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
499                 return http.StatusBadRequest, ErrNotWebSocket
500         }
501
502         key := req.Header.Get("Sec-Websocket-Key")
503         if key == "" {
504                 return http.StatusBadRequest, ErrChallengeResponse
505         }
506         version := req.Header.Get("Sec-Websocket-Version")
507         switch version {
508         case "13":
509                 c.Version = ProtocolVersionHybi13
510         default:
511                 return http.StatusBadRequest, ErrBadWebSocketVersion
512         }
513         var scheme string
514         if req.TLS != nil {
515                 scheme = "wss"
516         } else {
517                 scheme = "ws"
518         }
519         c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
520         if err != nil {
521                 return http.StatusBadRequest, err
522         }
523         protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
524         if protocol != "" {
525                 protocols := strings.Split(protocol, ",")
526                 for i := 0; i < len(protocols); i++ {
527                         c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
528                 }
529         }
530         c.accept, err = getNonceAccept([]byte(key))
531         if err != nil {
532                 return http.StatusInternalServerError, err
533         }
534         return http.StatusSwitchingProtocols, nil
535 }
536
537 // Origin parses the Origin header in req.
538 // If the Origin header is not set, it returns nil and nil.
539 func Origin(config *Config, req *http.Request) (*url.URL, error) {
540         var origin string
541         switch config.Version {
542         case ProtocolVersionHybi13:
543                 origin = req.Header.Get("Origin")
544         }
545         if origin == "" {
546                 return nil, nil
547         }
548         return url.ParseRequestURI(origin)
549 }
550
551 func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
552         if len(c.Protocol) > 0 {
553                 if len(c.Protocol) != 1 {
554                         // You need choose a Protocol in Handshake func in Server.
555                         return ErrBadWebSocketProtocol
556                 }
557         }
558         buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
559         buf.WriteString("Upgrade: websocket\r\n")
560         buf.WriteString("Connection: Upgrade\r\n")
561         buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
562         if len(c.Protocol) > 0 {
563                 buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
564         }
565         // TODO(ukai): send Sec-WebSocket-Extensions.
566         if c.Header != nil {
567                 err := c.Header.WriteSubset(buf, handshakeHeader)
568                 if err != nil {
569                         return err
570                 }
571         }
572         buf.WriteString("\r\n")
573         return buf.Flush()
574 }
575
576 func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
577         return newHybiServerConn(c.Config, buf, rwc, request)
578 }
579
580 // newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
581 func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
582         return newHybiConn(config, buf, rwc, request)
583 }