OSDN Git Service

add package
[bytom/vapor.git] / vendor / github.com / hashicorp / yamux / stream.go
1 package yamux
2
3 import (
4         "bytes"
5         "io"
6         "sync"
7         "sync/atomic"
8         "time"
9 )
10
11 type streamState int
12
13 const (
14         streamInit streamState = iota
15         streamSYNSent
16         streamSYNReceived
17         streamEstablished
18         streamLocalClose
19         streamRemoteClose
20         streamClosed
21         streamReset
22 )
23
24 // Stream is used to represent a logical stream
25 // within a session.
26 type Stream struct {
27         recvWindow uint32
28         sendWindow uint32
29
30         id      uint32
31         session *Session
32
33         state     streamState
34         stateLock sync.Mutex
35
36         recvBuf  *bytes.Buffer
37         recvLock sync.Mutex
38
39         controlHdr     header
40         controlErr     chan error
41         controlHdrLock sync.Mutex
42
43         sendHdr  header
44         sendErr  chan error
45         sendLock sync.Mutex
46
47         recvNotifyCh chan struct{}
48         sendNotifyCh chan struct{}
49
50         readDeadline  atomic.Value // time.Time
51         writeDeadline atomic.Value // time.Time
52 }
53
54 // newStream is used to construct a new stream within
55 // a given session for an ID
56 func newStream(session *Session, id uint32, state streamState) *Stream {
57         s := &Stream{
58                 id:           id,
59                 session:      session,
60                 state:        state,
61                 controlHdr:   header(make([]byte, headerSize)),
62                 controlErr:   make(chan error, 1),
63                 sendHdr:      header(make([]byte, headerSize)),
64                 sendErr:      make(chan error, 1),
65                 recvWindow:   initialStreamWindow,
66                 sendWindow:   initialStreamWindow,
67                 recvNotifyCh: make(chan struct{}, 1),
68                 sendNotifyCh: make(chan struct{}, 1),
69         }
70         s.readDeadline.Store(time.Time{})
71         s.writeDeadline.Store(time.Time{})
72         return s
73 }
74
75 // Session returns the associated stream session
76 func (s *Stream) Session() *Session {
77         return s.session
78 }
79
80 // StreamID returns the ID of this stream
81 func (s *Stream) StreamID() uint32 {
82         return s.id
83 }
84
85 // Read is used to read from the stream
86 func (s *Stream) Read(b []byte) (n int, err error) {
87         defer asyncNotify(s.recvNotifyCh)
88 START:
89         s.stateLock.Lock()
90         switch s.state {
91         case streamLocalClose:
92                 fallthrough
93         case streamRemoteClose:
94                 fallthrough
95         case streamClosed:
96                 s.recvLock.Lock()
97                 if s.recvBuf == nil || s.recvBuf.Len() == 0 {
98                         s.recvLock.Unlock()
99                         s.stateLock.Unlock()
100                         return 0, io.EOF
101                 }
102                 s.recvLock.Unlock()
103         case streamReset:
104                 s.stateLock.Unlock()
105                 return 0, ErrConnectionReset
106         }
107         s.stateLock.Unlock()
108
109         // If there is no data available, block
110         s.recvLock.Lock()
111         if s.recvBuf == nil || s.recvBuf.Len() == 0 {
112                 s.recvLock.Unlock()
113                 goto WAIT
114         }
115
116         // Read any bytes
117         n, _ = s.recvBuf.Read(b)
118         s.recvLock.Unlock()
119
120         // Send a window update potentially
121         err = s.sendWindowUpdate()
122         return n, err
123
124 WAIT:
125         var timeout <-chan time.Time
126         var timer *time.Timer
127         readDeadline := s.readDeadline.Load().(time.Time)
128         if !readDeadline.IsZero() {
129                 delay := readDeadline.Sub(time.Now())
130                 timer = time.NewTimer(delay)
131                 timeout = timer.C
132         }
133         select {
134         case <-s.recvNotifyCh:
135                 if timer != nil {
136                         timer.Stop()
137                 }
138                 goto START
139         case <-timeout:
140                 return 0, ErrTimeout
141         }
142 }
143
144 // Write is used to write to the stream
145 func (s *Stream) Write(b []byte) (n int, err error) {
146         s.sendLock.Lock()
147         defer s.sendLock.Unlock()
148         total := 0
149         for total < len(b) {
150                 n, err := s.write(b[total:])
151                 total += n
152                 if err != nil {
153                         return total, err
154                 }
155         }
156         return total, nil
157 }
158
159 // write is used to write to the stream, may return on
160 // a short write.
161 func (s *Stream) write(b []byte) (n int, err error) {
162         var flags uint16
163         var max uint32
164         var body io.Reader
165 START:
166         s.stateLock.Lock()
167         switch s.state {
168         case streamLocalClose:
169                 fallthrough
170         case streamClosed:
171                 s.stateLock.Unlock()
172                 return 0, ErrStreamClosed
173         case streamReset:
174                 s.stateLock.Unlock()
175                 return 0, ErrConnectionReset
176         }
177         s.stateLock.Unlock()
178
179         // If there is no data available, block
180         window := atomic.LoadUint32(&s.sendWindow)
181         if window == 0 {
182                 goto WAIT
183         }
184
185         // Determine the flags if any
186         flags = s.sendFlags()
187
188         // Send up to our send window
189         max = min(window, uint32(len(b)))
190         body = bytes.NewReader(b[:max])
191
192         // Send the header
193         s.sendHdr.encode(typeData, flags, s.id, max)
194         if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
195                 return 0, err
196         }
197
198         // Reduce our send window
199         atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
200
201         // Unlock
202         return int(max), err
203
204 WAIT:
205         var timeout <-chan time.Time
206         writeDeadline := s.writeDeadline.Load().(time.Time)
207         if !writeDeadline.IsZero() {
208                 delay := writeDeadline.Sub(time.Now())
209                 timeout = time.After(delay)
210         }
211         select {
212         case <-s.sendNotifyCh:
213                 goto START
214         case <-timeout:
215                 return 0, ErrTimeout
216         }
217         return 0, nil
218 }
219
220 // sendFlags determines any flags that are appropriate
221 // based on the current stream state
222 func (s *Stream) sendFlags() uint16 {
223         s.stateLock.Lock()
224         defer s.stateLock.Unlock()
225         var flags uint16
226         switch s.state {
227         case streamInit:
228                 flags |= flagSYN
229                 s.state = streamSYNSent
230         case streamSYNReceived:
231                 flags |= flagACK
232                 s.state = streamEstablished
233         }
234         return flags
235 }
236
237 // sendWindowUpdate potentially sends a window update enabling
238 // further writes to take place. Must be invoked with the lock.
239 func (s *Stream) sendWindowUpdate() error {
240         s.controlHdrLock.Lock()
241         defer s.controlHdrLock.Unlock()
242
243         // Determine the delta update
244         max := s.session.config.MaxStreamWindowSize
245         var bufLen uint32
246         s.recvLock.Lock()
247         if s.recvBuf != nil {
248                 bufLen = uint32(s.recvBuf.Len())
249         }
250         delta := (max - bufLen) - s.recvWindow
251
252         // Determine the flags if any
253         flags := s.sendFlags()
254
255         // Check if we can omit the update
256         if delta < (max/2) && flags == 0 {
257                 s.recvLock.Unlock()
258                 return nil
259         }
260
261         // Update our window
262         s.recvWindow += delta
263         s.recvLock.Unlock()
264
265         // Send the header
266         s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
267         if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
268                 return err
269         }
270         return nil
271 }
272
273 // sendClose is used to send a FIN
274 func (s *Stream) sendClose() error {
275         s.controlHdrLock.Lock()
276         defer s.controlHdrLock.Unlock()
277
278         flags := s.sendFlags()
279         flags |= flagFIN
280         s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
281         if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
282                 return err
283         }
284         return nil
285 }
286
287 // Close is used to close the stream
288 func (s *Stream) Close() error {
289         closeStream := false
290         s.stateLock.Lock()
291         switch s.state {
292         // Opened means we need to signal a close
293         case streamSYNSent:
294                 fallthrough
295         case streamSYNReceived:
296                 fallthrough
297         case streamEstablished:
298                 s.state = streamLocalClose
299                 goto SEND_CLOSE
300
301         case streamLocalClose:
302         case streamRemoteClose:
303                 s.state = streamClosed
304                 closeStream = true
305                 goto SEND_CLOSE
306
307         case streamClosed:
308         case streamReset:
309         default:
310                 panic("unhandled state")
311         }
312         s.stateLock.Unlock()
313         return nil
314 SEND_CLOSE:
315         s.stateLock.Unlock()
316         s.sendClose()
317         s.notifyWaiting()
318         if closeStream {
319                 s.session.closeStream(s.id)
320         }
321         return nil
322 }
323
324 // forceClose is used for when the session is exiting
325 func (s *Stream) forceClose() {
326         s.stateLock.Lock()
327         s.state = streamClosed
328         s.stateLock.Unlock()
329         s.notifyWaiting()
330 }
331
332 // processFlags is used to update the state of the stream
333 // based on set flags, if any. Lock must be held
334 func (s *Stream) processFlags(flags uint16) error {
335         // Close the stream without holding the state lock
336         closeStream := false
337         defer func() {
338                 if closeStream {
339                         s.session.closeStream(s.id)
340                 }
341         }()
342
343         s.stateLock.Lock()
344         defer s.stateLock.Unlock()
345         if flags&flagACK == flagACK {
346                 if s.state == streamSYNSent {
347                         s.state = streamEstablished
348                 }
349                 s.session.establishStream(s.id)
350         }
351         if flags&flagFIN == flagFIN {
352                 switch s.state {
353                 case streamSYNSent:
354                         fallthrough
355                 case streamSYNReceived:
356                         fallthrough
357                 case streamEstablished:
358                         s.state = streamRemoteClose
359                         s.notifyWaiting()
360                 case streamLocalClose:
361                         s.state = streamClosed
362                         closeStream = true
363                         s.notifyWaiting()
364                 default:
365                         s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
366                         return ErrUnexpectedFlag
367                 }
368         }
369         if flags&flagRST == flagRST {
370                 s.state = streamReset
371                 closeStream = true
372                 s.notifyWaiting()
373         }
374         return nil
375 }
376
377 // notifyWaiting notifies all the waiting channels
378 func (s *Stream) notifyWaiting() {
379         asyncNotify(s.recvNotifyCh)
380         asyncNotify(s.sendNotifyCh)
381 }
382
383 // incrSendWindow updates the size of our send window
384 func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
385         if err := s.processFlags(flags); err != nil {
386                 return err
387         }
388
389         // Increase window, unblock a sender
390         atomic.AddUint32(&s.sendWindow, hdr.Length())
391         asyncNotify(s.sendNotifyCh)
392         return nil
393 }
394
395 // readData is used to handle a data frame
396 func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
397         if err := s.processFlags(flags); err != nil {
398                 return err
399         }
400
401         // Check that our recv window is not exceeded
402         length := hdr.Length()
403         if length == 0 {
404                 return nil
405         }
406
407         // Wrap in a limited reader
408         conn = &io.LimitedReader{R: conn, N: int64(length)}
409
410         // Copy into buffer
411         s.recvLock.Lock()
412
413         if length > s.recvWindow {
414                 s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
415                 return ErrRecvWindowExceeded
416         }
417
418         if s.recvBuf == nil {
419                 // Allocate the receive buffer just-in-time to fit the full data frame.
420                 // This way we can read in the whole packet without further allocations.
421                 s.recvBuf = bytes.NewBuffer(make([]byte, 0, length))
422         }
423         if _, err := io.Copy(s.recvBuf, conn); err != nil {
424                 s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
425                 s.recvLock.Unlock()
426                 return err
427         }
428
429         // Decrement the receive window
430         s.recvWindow -= length
431         s.recvLock.Unlock()
432
433         // Unblock any readers
434         asyncNotify(s.recvNotifyCh)
435         return nil
436 }
437
438 // SetDeadline sets the read and write deadlines
439 func (s *Stream) SetDeadline(t time.Time) error {
440         if err := s.SetReadDeadline(t); err != nil {
441                 return err
442         }
443         if err := s.SetWriteDeadline(t); err != nil {
444                 return err
445         }
446         return nil
447 }
448
449 // SetReadDeadline sets the deadline for future Read calls.
450 func (s *Stream) SetReadDeadline(t time.Time) error {
451         s.readDeadline.Store(t)
452         return nil
453 }
454
455 // SetWriteDeadline sets the deadline for future Write calls
456 func (s *Stream) SetWriteDeadline(t time.Time) error {
457         s.writeDeadline.Store(t)
458         return nil
459 }
460
461 // Shrink is used to compact the amount of buffers utilized
462 // This is useful when using Yamux in a connection pool to reduce
463 // the idle memory utilization.
464 func (s *Stream) Shrink() {
465         s.recvLock.Lock()
466         if s.recvBuf != nil && s.recvBuf.Len() == 0 {
467                 s.recvBuf = nil
468         }
469         s.recvLock.Unlock()
470 }