OSDN Git Service

new repo
[bytom/vapor.git] / vendor / google.golang.org / grpc / stream.go
1 /*
2  *
3  * Copyright 2014 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 package grpc
20
21 import (
22         "bytes"
23         "errors"
24         "io"
25         "sync"
26         "time"
27
28         "golang.org/x/net/context"
29         "golang.org/x/net/trace"
30         "google.golang.org/grpc/balancer"
31         "google.golang.org/grpc/codes"
32         "google.golang.org/grpc/metadata"
33         "google.golang.org/grpc/peer"
34         "google.golang.org/grpc/stats"
35         "google.golang.org/grpc/status"
36         "google.golang.org/grpc/transport"
37 )
38
39 // StreamHandler defines the handler called by gRPC server to complete the
40 // execution of a streaming RPC.
41 type StreamHandler func(srv interface{}, stream ServerStream) error
42
43 // StreamDesc represents a streaming RPC service's method specification.
44 type StreamDesc struct {
45         StreamName string
46         Handler    StreamHandler
47
48         // At least one of these is true.
49         ServerStreams bool
50         ClientStreams bool
51 }
52
53 // Stream defines the common interface a client or server stream has to satisfy.
54 type Stream interface {
55         // Context returns the context for this stream.
56         Context() context.Context
57         // SendMsg blocks until it sends m, the stream is done or the stream
58         // breaks.
59         // On error, it aborts the stream and returns an RPC status on client
60         // side. On server side, it simply returns the error to the caller.
61         // SendMsg is called by generated code. Also Users can call SendMsg
62         // directly when it is really needed in their use cases.
63         // It's safe to have a goroutine calling SendMsg and another goroutine calling
64         // recvMsg on the same stream at the same time.
65         // But it is not safe to call SendMsg on the same stream in different goroutines.
66         SendMsg(m interface{}) error
67         // RecvMsg blocks until it receives a message or the stream is
68         // done. On client side, it returns io.EOF when the stream is done. On
69         // any other error, it aborts the stream and returns an RPC status. On
70         // server side, it simply returns the error to the caller.
71         // It's safe to have a goroutine calling SendMsg and another goroutine calling
72         // recvMsg on the same stream at the same time.
73         // But it is not safe to call RecvMsg on the same stream in different goroutines.
74         RecvMsg(m interface{}) error
75 }
76
77 // ClientStream defines the interface a client stream has to satisfy.
78 type ClientStream interface {
79         // Header returns the header metadata received from the server if there
80         // is any. It blocks if the metadata is not ready to read.
81         Header() (metadata.MD, error)
82         // Trailer returns the trailer metadata from the server, if there is any.
83         // It must only be called after stream.CloseAndRecv has returned, or
84         // stream.Recv has returned a non-nil error (including io.EOF).
85         Trailer() metadata.MD
86         // CloseSend closes the send direction of the stream. It closes the stream
87         // when non-nil error is met.
88         CloseSend() error
89         // Stream.SendMsg() may return a non-nil error when something wrong happens sending
90         // the request. The returned error indicates the status of this sending, not the final
91         // status of the RPC.
92         // Always call Stream.RecvMsg() to get the final status if you care about the status of
93         // the RPC.
94         Stream
95 }
96
97 // NewStream creates a new Stream for the client side. This is typically
98 // called by generated code.
99 func (cc *ClientConn) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) {
100         if cc.dopts.streamInt != nil {
101                 return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
102         }
103         return newClientStream(ctx, desc, cc, method, opts...)
104 }
105
106 // NewClientStream creates a new Stream for the client side. This is typically
107 // called by generated code.
108 //
109 // DEPRECATED: Use ClientConn.NewStream instead.
110 func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
111         return cc.NewStream(ctx, desc, method, opts...)
112 }
113
114 func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
115         var (
116                 t      transport.ClientTransport
117                 s      *transport.Stream
118                 done   func(balancer.DoneInfo)
119                 cancel context.CancelFunc
120         )
121         c := defaultCallInfo()
122         mc := cc.GetMethodConfig(method)
123         if mc.WaitForReady != nil {
124                 c.failFast = !*mc.WaitForReady
125         }
126
127         if mc.Timeout != nil {
128                 ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
129                 defer func() {
130                         if err != nil {
131                                 cancel()
132                         }
133                 }()
134         }
135
136         opts = append(cc.dopts.callOptions, opts...)
137         for _, o := range opts {
138                 if err := o.before(c); err != nil {
139                         return nil, toRPCErr(err)
140                 }
141         }
142         c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize)
143         c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
144
145         callHdr := &transport.CallHdr{
146                 Host:   cc.authority,
147                 Method: method,
148                 // If it's not client streaming, we should already have the request to be sent,
149                 // so we don't flush the header.
150                 // If it's client streaming, the user may never send a request or send it any
151                 // time soon, so we ask the transport to flush the header.
152                 Flush: desc.ClientStreams,
153         }
154         if cc.dopts.cp != nil {
155                 callHdr.SendCompress = cc.dopts.cp.Type()
156         }
157         if c.creds != nil {
158                 callHdr.Creds = c.creds
159         }
160         var trInfo traceInfo
161         if EnableTracing {
162                 trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
163                 trInfo.firstLine.client = true
164                 if deadline, ok := ctx.Deadline(); ok {
165                         trInfo.firstLine.deadline = deadline.Sub(time.Now())
166                 }
167                 trInfo.tr.LazyLog(&trInfo.firstLine, false)
168                 ctx = trace.NewContext(ctx, trInfo.tr)
169                 defer func() {
170                         if err != nil {
171                                 // Need to call tr.finish() if error is returned.
172                                 // Because tr will not be returned to caller.
173                                 trInfo.tr.LazyPrintf("RPC: [%v]", err)
174                                 trInfo.tr.SetError()
175                                 trInfo.tr.Finish()
176                         }
177                 }()
178         }
179         ctx = newContextWithRPCInfo(ctx, c.failFast)
180         sh := cc.dopts.copts.StatsHandler
181         if sh != nil {
182                 ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast})
183                 begin := &stats.Begin{
184                         Client:    true,
185                         BeginTime: time.Now(),
186                         FailFast:  c.failFast,
187                 }
188                 sh.HandleRPC(ctx, begin)
189                 defer func() {
190                         if err != nil {
191                                 // Only handle end stats if err != nil.
192                                 end := &stats.End{
193                                         Client: true,
194                                         Error:  err,
195                                 }
196                                 sh.HandleRPC(ctx, end)
197                         }
198                 }()
199         }
200         for {
201                 t, done, err = cc.getTransport(ctx, c.failFast)
202                 if err != nil {
203                         // TODO(zhaoq): Probably revisit the error handling.
204                         if _, ok := status.FromError(err); ok {
205                                 return nil, err
206                         }
207                         if err == errConnClosing || err == errConnUnavailable {
208                                 if c.failFast {
209                                         return nil, Errorf(codes.Unavailable, "%v", err)
210                                 }
211                                 continue
212                         }
213                         // All the other errors are treated as Internal errors.
214                         return nil, Errorf(codes.Internal, "%v", err)
215                 }
216
217                 s, err = t.NewStream(ctx, callHdr)
218                 if err != nil {
219                         if _, ok := err.(transport.ConnectionError); ok && done != nil {
220                                 // If error is connection error, transport was sending data on wire,
221                                 // and we are not sure if anything has been sent on wire.
222                                 // If error is not connection error, we are sure nothing has been sent.
223                                 updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
224                         }
225                         if done != nil {
226                                 done(balancer.DoneInfo{Err: err})
227                                 done = nil
228                         }
229                         if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
230                                 continue
231                         }
232                         return nil, toRPCErr(err)
233                 }
234                 break
235         }
236         // Set callInfo.peer object from stream's context.
237         if peer, ok := peer.FromContext(s.Context()); ok {
238                 c.peer = peer
239         }
240         cs := &clientStream{
241                 opts:   opts,
242                 c:      c,
243                 desc:   desc,
244                 codec:  cc.dopts.codec,
245                 cp:     cc.dopts.cp,
246                 dc:     cc.dopts.dc,
247                 cancel: cancel,
248
249                 done: done,
250                 t:    t,
251                 s:    s,
252                 p:    &parser{r: s},
253
254                 tracing: EnableTracing,
255                 trInfo:  trInfo,
256
257                 statsCtx:     ctx,
258                 statsHandler: cc.dopts.copts.StatsHandler,
259         }
260         // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
261         // when there is no pending I/O operations on this stream.
262         go func() {
263                 select {
264                 case <-t.Error():
265                         // Incur transport error, simply exit.
266                 case <-cc.ctx.Done():
267                         cs.finish(ErrClientConnClosing)
268                         cs.closeTransportStream(ErrClientConnClosing)
269                 case <-s.Done():
270                         // TODO: The trace of the RPC is terminated here when there is no pending
271                         // I/O, which is probably not the optimal solution.
272                         cs.finish(s.Status().Err())
273                         cs.closeTransportStream(nil)
274                 case <-s.GoAway():
275                         cs.finish(errConnDrain)
276                         cs.closeTransportStream(errConnDrain)
277                 case <-s.Context().Done():
278                         err := s.Context().Err()
279                         cs.finish(err)
280                         cs.closeTransportStream(transport.ContextErr(err))
281                 }
282         }()
283         return cs, nil
284 }
285
286 // clientStream implements a client side Stream.
287 type clientStream struct {
288         opts   []CallOption
289         c      *callInfo
290         t      transport.ClientTransport
291         s      *transport.Stream
292         p      *parser
293         desc   *StreamDesc
294         codec  Codec
295         cp     Compressor
296         dc     Decompressor
297         cancel context.CancelFunc
298
299         tracing bool // set to EnableTracing when the clientStream is created.
300
301         mu       sync.Mutex
302         done     func(balancer.DoneInfo)
303         closed   bool
304         finished bool
305         // trInfo.tr is set when the clientStream is created (if EnableTracing is true),
306         // and is set to nil when the clientStream's finish method is called.
307         trInfo traceInfo
308
309         // statsCtx keeps the user context for stats handling.
310         // All stats collection should use the statsCtx (instead of the stream context)
311         // so that all the generated stats for a particular RPC can be associated in the processing phase.
312         statsCtx     context.Context
313         statsHandler stats.Handler
314 }
315
316 func (cs *clientStream) Context() context.Context {
317         return cs.s.Context()
318 }
319
320 func (cs *clientStream) Header() (metadata.MD, error) {
321         m, err := cs.s.Header()
322         if err != nil {
323                 if _, ok := err.(transport.ConnectionError); !ok {
324                         cs.closeTransportStream(err)
325                 }
326         }
327         return m, err
328 }
329
330 func (cs *clientStream) Trailer() metadata.MD {
331         return cs.s.Trailer()
332 }
333
334 func (cs *clientStream) SendMsg(m interface{}) (err error) {
335         if cs.tracing {
336                 cs.mu.Lock()
337                 if cs.trInfo.tr != nil {
338                         cs.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
339                 }
340                 cs.mu.Unlock()
341         }
342         // TODO Investigate how to signal the stats handling party.
343         // generate error stats if err != nil && err != io.EOF?
344         defer func() {
345                 if err != nil {
346                         cs.finish(err)
347                 }
348                 if err == nil {
349                         return
350                 }
351                 if err == io.EOF {
352                         // Specialize the process for server streaming. SendMsg is only called
353                         // once when creating the stream object. io.EOF needs to be skipped when
354                         // the rpc is early finished (before the stream object is created.).
355                         // TODO: It is probably better to move this into the generated code.
356                         if !cs.desc.ClientStreams && cs.desc.ServerStreams {
357                                 err = nil
358                         }
359                         return
360                 }
361                 if _, ok := err.(transport.ConnectionError); !ok {
362                         cs.closeTransportStream(err)
363                 }
364                 err = toRPCErr(err)
365         }()
366         var outPayload *stats.OutPayload
367         if cs.statsHandler != nil {
368                 outPayload = &stats.OutPayload{
369                         Client: true,
370                 }
371         }
372         hdr, data, err := encode(cs.codec, m, cs.cp, bytes.NewBuffer([]byte{}), outPayload)
373         if err != nil {
374                 return err
375         }
376         if cs.c.maxSendMessageSize == nil {
377                 return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
378         }
379         if len(data) > *cs.c.maxSendMessageSize {
380                 return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
381         }
382         err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false})
383         if err == nil && outPayload != nil {
384                 outPayload.SentTime = time.Now()
385                 cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
386         }
387         return err
388 }
389
390 func (cs *clientStream) RecvMsg(m interface{}) (err error) {
391         var inPayload *stats.InPayload
392         if cs.statsHandler != nil {
393                 inPayload = &stats.InPayload{
394                         Client: true,
395                 }
396         }
397         if cs.c.maxReceiveMessageSize == nil {
398                 return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
399         }
400         err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload)
401         defer func() {
402                 // err != nil indicates the termination of the stream.
403                 if err != nil {
404                         cs.finish(err)
405                 }
406         }()
407         if err == nil {
408                 if cs.tracing {
409                         cs.mu.Lock()
410                         if cs.trInfo.tr != nil {
411                                 cs.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
412                         }
413                         cs.mu.Unlock()
414                 }
415                 if inPayload != nil {
416                         cs.statsHandler.HandleRPC(cs.statsCtx, inPayload)
417                 }
418                 if !cs.desc.ClientStreams || cs.desc.ServerStreams {
419                         return
420                 }
421                 // Special handling for client streaming rpc.
422                 // This recv expects EOF or errors, so we don't collect inPayload.
423                 if cs.c.maxReceiveMessageSize == nil {
424                         return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
425                 }
426                 err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil)
427                 cs.closeTransportStream(err)
428                 if err == nil {
429                         return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
430                 }
431                 if err == io.EOF {
432                         if se := cs.s.Status().Err(); se != nil {
433                                 return se
434                         }
435                         cs.finish(err)
436                         return nil
437                 }
438                 return toRPCErr(err)
439         }
440         if _, ok := err.(transport.ConnectionError); !ok {
441                 cs.closeTransportStream(err)
442         }
443         if err == io.EOF {
444                 if statusErr := cs.s.Status().Err(); statusErr != nil {
445                         return statusErr
446                 }
447                 // Returns io.EOF to indicate the end of the stream.
448                 return
449         }
450         return toRPCErr(err)
451 }
452
453 func (cs *clientStream) CloseSend() (err error) {
454         err = cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
455         defer func() {
456                 if err != nil {
457                         cs.finish(err)
458                 }
459         }()
460         if err == nil || err == io.EOF {
461                 return nil
462         }
463         if _, ok := err.(transport.ConnectionError); !ok {
464                 cs.closeTransportStream(err)
465         }
466         err = toRPCErr(err)
467         return
468 }
469
470 func (cs *clientStream) closeTransportStream(err error) {
471         cs.mu.Lock()
472         if cs.closed {
473                 cs.mu.Unlock()
474                 return
475         }
476         cs.closed = true
477         cs.mu.Unlock()
478         cs.t.CloseStream(cs.s, err)
479 }
480
481 func (cs *clientStream) finish(err error) {
482         cs.mu.Lock()
483         defer cs.mu.Unlock()
484         if cs.finished {
485                 return
486         }
487         cs.finished = true
488         defer func() {
489                 if cs.cancel != nil {
490                         cs.cancel()
491                 }
492         }()
493         for _, o := range cs.opts {
494                 o.after(cs.c)
495         }
496         if cs.done != nil {
497                 updateRPCInfoInContext(cs.s.Context(), rpcInfo{
498                         bytesSent:     cs.s.BytesSent(),
499                         bytesReceived: cs.s.BytesReceived(),
500                 })
501                 cs.done(balancer.DoneInfo{Err: err})
502                 cs.done = nil
503         }
504         if cs.statsHandler != nil {
505                 end := &stats.End{
506                         Client:  true,
507                         EndTime: time.Now(),
508                 }
509                 if err != io.EOF {
510                         // end.Error is nil if the RPC finished successfully.
511                         end.Error = toRPCErr(err)
512                 }
513                 cs.statsHandler.HandleRPC(cs.statsCtx, end)
514         }
515         if !cs.tracing {
516                 return
517         }
518         if cs.trInfo.tr != nil {
519                 if err == nil || err == io.EOF {
520                         cs.trInfo.tr.LazyPrintf("RPC: [OK]")
521                 } else {
522                         cs.trInfo.tr.LazyPrintf("RPC: [%v]", err)
523                         cs.trInfo.tr.SetError()
524                 }
525                 cs.trInfo.tr.Finish()
526                 cs.trInfo.tr = nil
527         }
528 }
529
530 // ServerStream defines the interface a server stream has to satisfy.
531 type ServerStream interface {
532         // SetHeader sets the header metadata. It may be called multiple times.
533         // When call multiple times, all the provided metadata will be merged.
534         // All the metadata will be sent out when one of the following happens:
535         //  - ServerStream.SendHeader() is called;
536         //  - The first response is sent out;
537         //  - An RPC status is sent out (error or success).
538         SetHeader(metadata.MD) error
539         // SendHeader sends the header metadata.
540         // The provided md and headers set by SetHeader() will be sent.
541         // It fails if called multiple times.
542         SendHeader(metadata.MD) error
543         // SetTrailer sets the trailer metadata which will be sent with the RPC status.
544         // When called more than once, all the provided metadata will be merged.
545         SetTrailer(metadata.MD)
546         Stream
547 }
548
549 // serverStream implements a server side Stream.
550 type serverStream struct {
551         t                     transport.ServerTransport
552         s                     *transport.Stream
553         p                     *parser
554         codec                 Codec
555         cp                    Compressor
556         dc                    Decompressor
557         maxReceiveMessageSize int
558         maxSendMessageSize    int
559         trInfo                *traceInfo
560
561         statsHandler stats.Handler
562
563         mu sync.Mutex // protects trInfo.tr after the service handler runs.
564 }
565
566 func (ss *serverStream) Context() context.Context {
567         return ss.s.Context()
568 }
569
570 func (ss *serverStream) SetHeader(md metadata.MD) error {
571         if md.Len() == 0 {
572                 return nil
573         }
574         return ss.s.SetHeader(md)
575 }
576
577 func (ss *serverStream) SendHeader(md metadata.MD) error {
578         return ss.t.WriteHeader(ss.s, md)
579 }
580
581 func (ss *serverStream) SetTrailer(md metadata.MD) {
582         if md.Len() == 0 {
583                 return
584         }
585         ss.s.SetTrailer(md)
586         return
587 }
588
589 func (ss *serverStream) SendMsg(m interface{}) (err error) {
590         defer func() {
591                 if ss.trInfo != nil {
592                         ss.mu.Lock()
593                         if ss.trInfo.tr != nil {
594                                 if err == nil {
595                                         ss.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
596                                 } else {
597                                         ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
598                                         ss.trInfo.tr.SetError()
599                                 }
600                         }
601                         ss.mu.Unlock()
602                 }
603                 if err != nil && err != io.EOF {
604                         st, _ := status.FromError(toRPCErr(err))
605                         ss.t.WriteStatus(ss.s, st)
606                 }
607         }()
608         var outPayload *stats.OutPayload
609         if ss.statsHandler != nil {
610                 outPayload = &stats.OutPayload{}
611         }
612         hdr, data, err := encode(ss.codec, m, ss.cp, bytes.NewBuffer([]byte{}), outPayload)
613         if err != nil {
614                 return err
615         }
616         if len(data) > ss.maxSendMessageSize {
617                 return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
618         }
619         if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil {
620                 return toRPCErr(err)
621         }
622         if outPayload != nil {
623                 outPayload.SentTime = time.Now()
624                 ss.statsHandler.HandleRPC(ss.s.Context(), outPayload)
625         }
626         return nil
627 }
628
629 func (ss *serverStream) RecvMsg(m interface{}) (err error) {
630         defer func() {
631                 if ss.trInfo != nil {
632                         ss.mu.Lock()
633                         if ss.trInfo.tr != nil {
634                                 if err == nil {
635                                         ss.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true)
636                                 } else if err != io.EOF {
637                                         ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
638                                         ss.trInfo.tr.SetError()
639                                 }
640                         }
641                         ss.mu.Unlock()
642                 }
643                 if err != nil && err != io.EOF {
644                         st, _ := status.FromError(toRPCErr(err))
645                         ss.t.WriteStatus(ss.s, st)
646                 }
647         }()
648         var inPayload *stats.InPayload
649         if ss.statsHandler != nil {
650                 inPayload = &stats.InPayload{}
651         }
652         if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload); err != nil {
653                 if err == io.EOF {
654                         return err
655                 }
656                 if err == io.ErrUnexpectedEOF {
657                         err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
658                 }
659                 return toRPCErr(err)
660         }
661         if inPayload != nil {
662                 ss.statsHandler.HandleRPC(ss.s.Context(), inPayload)
663         }
664         return nil
665 }
666
667 // MethodFromServerStream returns the method string for the input stream.
668 // The returned string is in the format of "/service/method".
669 func MethodFromServerStream(stream ServerStream) (string, bool) {
670         s, ok := transport.StreamFromContext(stream.Context())
671         if !ok {
672                 return "", ok
673         }
674         return s.Method(), ok
675 }