OSDN Git Service

new repo
[bytom/vapor.git] / vendor / google.golang.org / grpc / server.go
1 /*
2  *
3  * Copyright 2014 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 package grpc
20
21 import (
22         "bytes"
23         "errors"
24         "fmt"
25         "io"
26         "math"
27         "net"
28         "net/http"
29         "reflect"
30         "runtime"
31         "strings"
32         "sync"
33         "time"
34
35         "golang.org/x/net/context"
36         "golang.org/x/net/http2"
37         "golang.org/x/net/trace"
38         "google.golang.org/grpc/codes"
39         "google.golang.org/grpc/credentials"
40         "google.golang.org/grpc/grpclog"
41         "google.golang.org/grpc/internal"
42         "google.golang.org/grpc/keepalive"
43         "google.golang.org/grpc/metadata"
44         "google.golang.org/grpc/stats"
45         "google.golang.org/grpc/status"
46         "google.golang.org/grpc/tap"
47         "google.golang.org/grpc/transport"
48 )
49
50 const (
51         defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4
52         defaultServerMaxSendMessageSize    = math.MaxInt32
53 )
54
55 type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error)
56
57 // MethodDesc represents an RPC service's method specification.
58 type MethodDesc struct {
59         MethodName string
60         Handler    methodHandler
61 }
62
63 // ServiceDesc represents an RPC service's specification.
64 type ServiceDesc struct {
65         ServiceName string
66         // The pointer to the service interface. Used to check whether the user
67         // provided implementation satisfies the interface requirements.
68         HandlerType interface{}
69         Methods     []MethodDesc
70         Streams     []StreamDesc
71         Metadata    interface{}
72 }
73
74 // service consists of the information of the server serving this service and
75 // the methods in this service.
76 type service struct {
77         server interface{} // the server for service methods
78         md     map[string]*MethodDesc
79         sd     map[string]*StreamDesc
80         mdata  interface{}
81 }
82
83 // Server is a gRPC server to serve RPC requests.
84 type Server struct {
85         opts options
86
87         mu     sync.Mutex // guards following
88         lis    map[net.Listener]bool
89         conns  map[io.Closer]bool
90         serve  bool
91         drain  bool
92         ctx    context.Context
93         cancel context.CancelFunc
94         // A CondVar to let GracefulStop() blocks until all the pending RPCs are finished
95         // and all the transport goes away.
96         cv     *sync.Cond
97         m      map[string]*service // service name -> service info
98         events trace.EventLog
99
100         quit     chan struct{}
101         done     chan struct{}
102         quitOnce sync.Once
103         doneOnce sync.Once
104 }
105
106 type options struct {
107         creds                 credentials.TransportCredentials
108         codec                 Codec
109         cp                    Compressor
110         dc                    Decompressor
111         unaryInt              UnaryServerInterceptor
112         streamInt             StreamServerInterceptor
113         inTapHandle           tap.ServerInHandle
114         statsHandler          stats.Handler
115         maxConcurrentStreams  uint32
116         maxReceiveMessageSize int
117         maxSendMessageSize    int
118         useHandlerImpl        bool // use http.Handler-based server
119         unknownStreamDesc     *StreamDesc
120         keepaliveParams       keepalive.ServerParameters
121         keepalivePolicy       keepalive.EnforcementPolicy
122         initialWindowSize     int32
123         initialConnWindowSize int32
124         writeBufferSize       int
125         readBufferSize        int
126 }
127
128 var defaultServerOptions = options{
129         maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
130         maxSendMessageSize:    defaultServerMaxSendMessageSize,
131 }
132
133 // A ServerOption sets options such as credentials, codec and keepalive parameters, etc.
134 type ServerOption func(*options)
135
136 // WriteBufferSize lets you set the size of write buffer, this determines how much data can be batched
137 // before doing a write on the wire.
138 func WriteBufferSize(s int) ServerOption {
139         return func(o *options) {
140                 o.writeBufferSize = s
141         }
142 }
143
144 // ReadBufferSize lets you set the size of read buffer, this determines how much data can be read at most
145 // for one read syscall.
146 func ReadBufferSize(s int) ServerOption {
147         return func(o *options) {
148                 o.readBufferSize = s
149         }
150 }
151
152 // InitialWindowSize returns a ServerOption that sets window size for stream.
153 // The lower bound for window size is 64K and any value smaller than that will be ignored.
154 func InitialWindowSize(s int32) ServerOption {
155         return func(o *options) {
156                 o.initialWindowSize = s
157         }
158 }
159
160 // InitialConnWindowSize returns a ServerOption that sets window size for a connection.
161 // The lower bound for window size is 64K and any value smaller than that will be ignored.
162 func InitialConnWindowSize(s int32) ServerOption {
163         return func(o *options) {
164                 o.initialConnWindowSize = s
165         }
166 }
167
168 // KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server.
169 func KeepaliveParams(kp keepalive.ServerParameters) ServerOption {
170         return func(o *options) {
171                 o.keepaliveParams = kp
172         }
173 }
174
175 // KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server.
176 func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
177         return func(o *options) {
178                 o.keepalivePolicy = kep
179         }
180 }
181
182 // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
183 func CustomCodec(codec Codec) ServerOption {
184         return func(o *options) {
185                 o.codec = codec
186         }
187 }
188
189 // RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
190 func RPCCompressor(cp Compressor) ServerOption {
191         return func(o *options) {
192                 o.cp = cp
193         }
194 }
195
196 // RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
197 func RPCDecompressor(dc Decompressor) ServerOption {
198         return func(o *options) {
199                 o.dc = dc
200         }
201 }
202
203 // MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
204 // If this is not set, gRPC uses the default limit. Deprecated: use MaxRecvMsgSize instead.
205 func MaxMsgSize(m int) ServerOption {
206         return MaxRecvMsgSize(m)
207 }
208
209 // MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
210 // If this is not set, gRPC uses the default 4MB.
211 func MaxRecvMsgSize(m int) ServerOption {
212         return func(o *options) {
213                 o.maxReceiveMessageSize = m
214         }
215 }
216
217 // MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send.
218 // If this is not set, gRPC uses the default 4MB.
219 func MaxSendMsgSize(m int) ServerOption {
220         return func(o *options) {
221                 o.maxSendMessageSize = m
222         }
223 }
224
225 // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
226 // of concurrent streams to each ServerTransport.
227 func MaxConcurrentStreams(n uint32) ServerOption {
228         return func(o *options) {
229                 o.maxConcurrentStreams = n
230         }
231 }
232
233 // Creds returns a ServerOption that sets credentials for server connections.
234 func Creds(c credentials.TransportCredentials) ServerOption {
235         return func(o *options) {
236                 o.creds = c
237         }
238 }
239
240 // UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
241 // server. Only one unary interceptor can be installed. The construction of multiple
242 // interceptors (e.g., chaining) can be implemented at the caller.
243 func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
244         return func(o *options) {
245                 if o.unaryInt != nil {
246                         panic("The unary server interceptor was already set and may not be reset.")
247                 }
248                 o.unaryInt = i
249         }
250 }
251
252 // StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
253 // server. Only one stream interceptor can be installed.
254 func StreamInterceptor(i StreamServerInterceptor) ServerOption {
255         return func(o *options) {
256                 if o.streamInt != nil {
257                         panic("The stream server interceptor was already set and may not be reset.")
258                 }
259                 o.streamInt = i
260         }
261 }
262
263 // InTapHandle returns a ServerOption that sets the tap handle for all the server
264 // transport to be created. Only one can be installed.
265 func InTapHandle(h tap.ServerInHandle) ServerOption {
266         return func(o *options) {
267                 if o.inTapHandle != nil {
268                         panic("The tap handle was already set and may not be reset.")
269                 }
270                 o.inTapHandle = h
271         }
272 }
273
274 // StatsHandler returns a ServerOption that sets the stats handler for the server.
275 func StatsHandler(h stats.Handler) ServerOption {
276         return func(o *options) {
277                 o.statsHandler = h
278         }
279 }
280
281 // UnknownServiceHandler returns a ServerOption that allows for adding a custom
282 // unknown service handler. The provided method is a bidi-streaming RPC service
283 // handler that will be invoked instead of returning the "unimplemented" gRPC
284 // error whenever a request is received for an unregistered service or method.
285 // The handling function has full access to the Context of the request and the
286 // stream, and the invocation bypasses interceptors.
287 func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
288         return func(o *options) {
289                 o.unknownStreamDesc = &StreamDesc{
290                         StreamName: "unknown_service_handler",
291                         Handler:    streamHandler,
292                         // We need to assume that the users of the streamHandler will want to use both.
293                         ClientStreams: true,
294                         ServerStreams: true,
295                 }
296         }
297 }
298
299 // NewServer creates a gRPC server which has no service registered and has not
300 // started to accept requests yet.
301 func NewServer(opt ...ServerOption) *Server {
302         opts := defaultServerOptions
303         for _, o := range opt {
304                 o(&opts)
305         }
306         if opts.codec == nil {
307                 // Set the default codec.
308                 opts.codec = protoCodec{}
309         }
310         s := &Server{
311                 lis:   make(map[net.Listener]bool),
312                 opts:  opts,
313                 conns: make(map[io.Closer]bool),
314                 m:     make(map[string]*service),
315                 quit:  make(chan struct{}),
316                 done:  make(chan struct{}),
317         }
318         s.cv = sync.NewCond(&s.mu)
319         s.ctx, s.cancel = context.WithCancel(context.Background())
320         if EnableTracing {
321                 _, file, line, _ := runtime.Caller(1)
322                 s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
323         }
324         return s
325 }
326
327 // printf records an event in s's event log, unless s has been stopped.
328 // REQUIRES s.mu is held.
329 func (s *Server) printf(format string, a ...interface{}) {
330         if s.events != nil {
331                 s.events.Printf(format, a...)
332         }
333 }
334
335 // errorf records an error in s's event log, unless s has been stopped.
336 // REQUIRES s.mu is held.
337 func (s *Server) errorf(format string, a ...interface{}) {
338         if s.events != nil {
339                 s.events.Errorf(format, a...)
340         }
341 }
342
343 // RegisterService registers a service and its implementation to the gRPC
344 // server. It is called from the IDL generated code. This must be called before
345 // invoking Serve.
346 func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
347         ht := reflect.TypeOf(sd.HandlerType).Elem()
348         st := reflect.TypeOf(ss)
349         if !st.Implements(ht) {
350                 grpclog.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
351         }
352         s.register(sd, ss)
353 }
354
355 func (s *Server) register(sd *ServiceDesc, ss interface{}) {
356         s.mu.Lock()
357         defer s.mu.Unlock()
358         s.printf("RegisterService(%q)", sd.ServiceName)
359         if s.serve {
360                 grpclog.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
361         }
362         if _, ok := s.m[sd.ServiceName]; ok {
363                 grpclog.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
364         }
365         srv := &service{
366                 server: ss,
367                 md:     make(map[string]*MethodDesc),
368                 sd:     make(map[string]*StreamDesc),
369                 mdata:  sd.Metadata,
370         }
371         for i := range sd.Methods {
372                 d := &sd.Methods[i]
373                 srv.md[d.MethodName] = d
374         }
375         for i := range sd.Streams {
376                 d := &sd.Streams[i]
377                 srv.sd[d.StreamName] = d
378         }
379         s.m[sd.ServiceName] = srv
380 }
381
382 // MethodInfo contains the information of an RPC including its method name and type.
383 type MethodInfo struct {
384         // Name is the method name only, without the service name or package name.
385         Name string
386         // IsClientStream indicates whether the RPC is a client streaming RPC.
387         IsClientStream bool
388         // IsServerStream indicates whether the RPC is a server streaming RPC.
389         IsServerStream bool
390 }
391
392 // ServiceInfo contains unary RPC method info, streaming RPC method info and metadata for a service.
393 type ServiceInfo struct {
394         Methods []MethodInfo
395         // Metadata is the metadata specified in ServiceDesc when registering service.
396         Metadata interface{}
397 }
398
399 // GetServiceInfo returns a map from service names to ServiceInfo.
400 // Service names include the package names, in the form of <package>.<service>.
401 func (s *Server) GetServiceInfo() map[string]ServiceInfo {
402         ret := make(map[string]ServiceInfo)
403         for n, srv := range s.m {
404                 methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
405                 for m := range srv.md {
406                         methods = append(methods, MethodInfo{
407                                 Name:           m,
408                                 IsClientStream: false,
409                                 IsServerStream: false,
410                         })
411                 }
412                 for m, d := range srv.sd {
413                         methods = append(methods, MethodInfo{
414                                 Name:           m,
415                                 IsClientStream: d.ClientStreams,
416                                 IsServerStream: d.ServerStreams,
417                         })
418                 }
419
420                 ret[n] = ServiceInfo{
421                         Methods:  methods,
422                         Metadata: srv.mdata,
423                 }
424         }
425         return ret
426 }
427
428 // ErrServerStopped indicates that the operation is now illegal because of
429 // the server being stopped.
430 var ErrServerStopped = errors.New("grpc: the server has been stopped")
431
432 func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
433         if s.opts.creds == nil {
434                 return rawConn, nil, nil
435         }
436         return s.opts.creds.ServerHandshake(rawConn)
437 }
438
439 // Serve accepts incoming connections on the listener lis, creating a new
440 // ServerTransport and service goroutine for each. The service goroutines
441 // read gRPC requests and then call the registered handlers to reply to them.
442 // Serve returns when lis.Accept fails with fatal errors.  lis will be closed when
443 // this method returns.
444 // Serve always returns non-nil error.
445 func (s *Server) Serve(lis net.Listener) error {
446         s.mu.Lock()
447         s.printf("serving")
448         s.serve = true
449         if s.lis == nil {
450                 s.mu.Unlock()
451                 lis.Close()
452                 return ErrServerStopped
453         }
454         s.lis[lis] = true
455         s.mu.Unlock()
456         defer func() {
457                 s.mu.Lock()
458                 if s.lis != nil && s.lis[lis] {
459                         lis.Close()
460                         delete(s.lis, lis)
461                 }
462                 s.mu.Unlock()
463         }()
464
465         var tempDelay time.Duration // how long to sleep on accept failure
466
467         for {
468                 rawConn, err := lis.Accept()
469                 if err != nil {
470                         if ne, ok := err.(interface {
471                                 Temporary() bool
472                         }); ok && ne.Temporary() {
473                                 if tempDelay == 0 {
474                                         tempDelay = 5 * time.Millisecond
475                                 } else {
476                                         tempDelay *= 2
477                                 }
478                                 if max := 1 * time.Second; tempDelay > max {
479                                         tempDelay = max
480                                 }
481                                 s.mu.Lock()
482                                 s.printf("Accept error: %v; retrying in %v", err, tempDelay)
483                                 s.mu.Unlock()
484                                 timer := time.NewTimer(tempDelay)
485                                 select {
486                                 case <-timer.C:
487                                 case <-s.ctx.Done():
488                                 }
489                                 timer.Stop()
490                                 continue
491                         }
492                         s.mu.Lock()
493                         s.printf("done serving; Accept = %v", err)
494                         s.mu.Unlock()
495
496                         // If Stop or GracefulStop is called, block until they are done and return nil
497                         select {
498                         case <-s.quit:
499                                 <-s.done
500                                 return nil
501                         default:
502                         }
503                         return err
504                 }
505                 tempDelay = 0
506                 // Start a new goroutine to deal with rawConn
507                 // so we don't stall this Accept loop goroutine.
508                 go s.handleRawConn(rawConn)
509         }
510 }
511
512 // handleRawConn is run in its own goroutine and handles a just-accepted
513 // connection that has not had any I/O performed on it yet.
514 func (s *Server) handleRawConn(rawConn net.Conn) {
515         conn, authInfo, err := s.useTransportAuthenticator(rawConn)
516         if err != nil {
517                 s.mu.Lock()
518                 s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
519                 s.mu.Unlock()
520                 grpclog.Warningf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
521                 // If serverHandShake returns ErrConnDispatched, keep rawConn open.
522                 if err != credentials.ErrConnDispatched {
523                         rawConn.Close()
524                 }
525                 return
526         }
527
528         s.mu.Lock()
529         if s.conns == nil {
530                 s.mu.Unlock()
531                 conn.Close()
532                 return
533         }
534         s.mu.Unlock()
535
536         if s.opts.useHandlerImpl {
537                 s.serveUsingHandler(conn)
538         } else {
539                 s.serveHTTP2Transport(conn, authInfo)
540         }
541 }
542
543 // serveHTTP2Transport sets up a http/2 transport (using the
544 // gRPC http2 server transport in transport/http2_server.go) and
545 // serves streams on it.
546 // This is run in its own goroutine (it does network I/O in
547 // transport.NewServerTransport).
548 func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
549         config := &transport.ServerConfig{
550                 MaxStreams:            s.opts.maxConcurrentStreams,
551                 AuthInfo:              authInfo,
552                 InTapHandle:           s.opts.inTapHandle,
553                 StatsHandler:          s.opts.statsHandler,
554                 KeepaliveParams:       s.opts.keepaliveParams,
555                 KeepalivePolicy:       s.opts.keepalivePolicy,
556                 InitialWindowSize:     s.opts.initialWindowSize,
557                 InitialConnWindowSize: s.opts.initialConnWindowSize,
558                 WriteBufferSize:       s.opts.writeBufferSize,
559                 ReadBufferSize:        s.opts.readBufferSize,
560         }
561         st, err := transport.NewServerTransport("http2", c, config)
562         if err != nil {
563                 s.mu.Lock()
564                 s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
565                 s.mu.Unlock()
566                 c.Close()
567                 grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err)
568                 return
569         }
570         if !s.addConn(st) {
571                 st.Close()
572                 return
573         }
574         s.serveStreams(st)
575 }
576
577 func (s *Server) serveStreams(st transport.ServerTransport) {
578         defer s.removeConn(st)
579         defer st.Close()
580         var wg sync.WaitGroup
581         st.HandleStreams(func(stream *transport.Stream) {
582                 wg.Add(1)
583                 go func() {
584                         defer wg.Done()
585                         s.handleStream(st, stream, s.traceInfo(st, stream))
586                 }()
587         }, func(ctx context.Context, method string) context.Context {
588                 if !EnableTracing {
589                         return ctx
590                 }
591                 tr := trace.New("grpc.Recv."+methodFamily(method), method)
592                 return trace.NewContext(ctx, tr)
593         })
594         wg.Wait()
595 }
596
597 var _ http.Handler = (*Server)(nil)
598
599 // serveUsingHandler is called from handleRawConn when s is configured
600 // to handle requests via the http.Handler interface. It sets up a
601 // net/http.Server to handle the just-accepted conn. The http.Server
602 // is configured to route all incoming requests (all HTTP/2 streams)
603 // to ServeHTTP, which creates a new ServerTransport for each stream.
604 // serveUsingHandler blocks until conn closes.
605 //
606 // This codepath is only used when Server.TestingUseHandlerImpl has
607 // been configured. This lets the end2end tests exercise the ServeHTTP
608 // method as one of the environment types.
609 //
610 // conn is the *tls.Conn that's already been authenticated.
611 func (s *Server) serveUsingHandler(conn net.Conn) {
612         if !s.addConn(conn) {
613                 conn.Close()
614                 return
615         }
616         defer s.removeConn(conn)
617         h2s := &http2.Server{
618                 MaxConcurrentStreams: s.opts.maxConcurrentStreams,
619         }
620         h2s.ServeConn(conn, &http2.ServeConnOpts{
621                 Handler: s,
622         })
623 }
624
625 // ServeHTTP implements the Go standard library's http.Handler
626 // interface by responding to the gRPC request r, by looking up
627 // the requested gRPC method in the gRPC server s.
628 //
629 // The provided HTTP request must have arrived on an HTTP/2
630 // connection. When using the Go standard library's server,
631 // practically this means that the Request must also have arrived
632 // over TLS.
633 //
634 // To share one port (such as 443 for https) between gRPC and an
635 // existing http.Handler, use a root http.Handler such as:
636 //
637 //   if r.ProtoMajor == 2 && strings.HasPrefix(
638 //      r.Header.Get("Content-Type"), "application/grpc") {
639 //      grpcServer.ServeHTTP(w, r)
640 //   } else {
641 //      yourMux.ServeHTTP(w, r)
642 //   }
643 //
644 // Note that ServeHTTP uses Go's HTTP/2 server implementation which is totally
645 // separate from grpc-go's HTTP/2 server. Performance and features may vary
646 // between the two paths. ServeHTTP does not support some gRPC features
647 // available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL
648 // and subject to change.
649 func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
650         st, err := transport.NewServerHandlerTransport(w, r)
651         if err != nil {
652                 http.Error(w, err.Error(), http.StatusInternalServerError)
653                 return
654         }
655         if !s.addConn(st) {
656                 st.Close()
657                 return
658         }
659         defer s.removeConn(st)
660         s.serveStreams(st)
661 }
662
663 // traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
664 // If tracing is not enabled, it returns nil.
665 func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
666         tr, ok := trace.FromContext(stream.Context())
667         if !ok {
668                 return nil
669         }
670
671         trInfo = &traceInfo{
672                 tr: tr,
673         }
674         trInfo.firstLine.client = false
675         trInfo.firstLine.remoteAddr = st.RemoteAddr()
676
677         if dl, ok := stream.Context().Deadline(); ok {
678                 trInfo.firstLine.deadline = dl.Sub(time.Now())
679         }
680         return trInfo
681 }
682
683 func (s *Server) addConn(c io.Closer) bool {
684         s.mu.Lock()
685         defer s.mu.Unlock()
686         if s.conns == nil || s.drain {
687                 return false
688         }
689         s.conns[c] = true
690         return true
691 }
692
693 func (s *Server) removeConn(c io.Closer) {
694         s.mu.Lock()
695         defer s.mu.Unlock()
696         if s.conns != nil {
697                 delete(s.conns, c)
698                 s.cv.Broadcast()
699         }
700 }
701
702 func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error {
703         var (
704                 cbuf       *bytes.Buffer
705                 outPayload *stats.OutPayload
706         )
707         if cp != nil {
708                 cbuf = new(bytes.Buffer)
709         }
710         if s.opts.statsHandler != nil {
711                 outPayload = &stats.OutPayload{}
712         }
713         hdr, data, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
714         if err != nil {
715                 grpclog.Errorln("grpc: server failed to encode response: ", err)
716                 return err
717         }
718         if len(data) > s.opts.maxSendMessageSize {
719                 return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
720         }
721         err = t.Write(stream, hdr, data, opts)
722         if err == nil && outPayload != nil {
723                 outPayload.SentTime = time.Now()
724                 s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
725         }
726         return err
727 }
728
729 func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
730         sh := s.opts.statsHandler
731         if sh != nil {
732                 begin := &stats.Begin{
733                         BeginTime: time.Now(),
734                 }
735                 sh.HandleRPC(stream.Context(), begin)
736                 defer func() {
737                         end := &stats.End{
738                                 EndTime: time.Now(),
739                         }
740                         if err != nil && err != io.EOF {
741                                 end.Error = toRPCErr(err)
742                         }
743                         sh.HandleRPC(stream.Context(), end)
744                 }()
745         }
746         if trInfo != nil {
747                 defer trInfo.tr.Finish()
748                 trInfo.firstLine.client = false
749                 trInfo.tr.LazyLog(&trInfo.firstLine, false)
750                 defer func() {
751                         if err != nil && err != io.EOF {
752                                 trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
753                                 trInfo.tr.SetError()
754                         }
755                 }()
756         }
757         if s.opts.cp != nil {
758                 // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
759                 stream.SetSendCompress(s.opts.cp.Type())
760         }
761         p := &parser{r: stream}
762         pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize)
763         if err == io.EOF {
764                 // The entire stream is done (for unary RPC only).
765                 return err
766         }
767         if err == io.ErrUnexpectedEOF {
768                 err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error())
769         }
770         if err != nil {
771                 if st, ok := status.FromError(err); ok {
772                         if e := t.WriteStatus(stream, st); e != nil {
773                                 grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
774                         }
775                 } else {
776                         switch st := err.(type) {
777                         case transport.ConnectionError:
778                                 // Nothing to do here.
779                         case transport.StreamError:
780                                 if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
781                                         grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
782                                 }
783                         default:
784                                 panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st))
785                         }
786                 }
787                 return err
788         }
789
790         if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
791                 if st, ok := status.FromError(err); ok {
792                         if e := t.WriteStatus(stream, st); e != nil {
793                                 grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
794                         }
795                         return err
796                 }
797                 if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil {
798                         grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
799                 }
800
801                 // TODO checkRecvPayload always return RPC error. Add a return here if necessary.
802         }
803         var inPayload *stats.InPayload
804         if sh != nil {
805                 inPayload = &stats.InPayload{
806                         RecvTime: time.Now(),
807                 }
808         }
809         df := func(v interface{}) error {
810                 if inPayload != nil {
811                         inPayload.WireLength = len(req)
812                 }
813                 if pf == compressionMade {
814                         var err error
815                         req, err = s.opts.dc.Do(bytes.NewReader(req))
816                         if err != nil {
817                                 return Errorf(codes.Internal, err.Error())
818                         }
819                 }
820                 if len(req) > s.opts.maxReceiveMessageSize {
821                         // TODO: Revisit the error code. Currently keep it consistent with
822                         // java implementation.
823                         return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize)
824                 }
825                 if err := s.opts.codec.Unmarshal(req, v); err != nil {
826                         return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
827                 }
828                 if inPayload != nil {
829                         inPayload.Payload = v
830                         inPayload.Data = req
831                         inPayload.Length = len(req)
832                         sh.HandleRPC(stream.Context(), inPayload)
833                 }
834                 if trInfo != nil {
835                         trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
836                 }
837                 return nil
838         }
839         reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
840         if appErr != nil {
841                 appStatus, ok := status.FromError(appErr)
842                 if !ok {
843                         // Convert appErr if it is not a grpc status error.
844                         appErr = status.Error(convertCode(appErr), appErr.Error())
845                         appStatus, _ = status.FromError(appErr)
846                 }
847                 if trInfo != nil {
848                         trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
849                         trInfo.tr.SetError()
850                 }
851                 if e := t.WriteStatus(stream, appStatus); e != nil {
852                         grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status: %v", e)
853                 }
854                 return appErr
855         }
856         if trInfo != nil {
857                 trInfo.tr.LazyLog(stringer("OK"), false)
858         }
859         opts := &transport.Options{
860                 Last:  true,
861                 Delay: false,
862         }
863         if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
864                 if err == io.EOF {
865                         // The entire stream is done (for unary RPC only).
866                         return err
867                 }
868                 if s, ok := status.FromError(err); ok {
869                         if e := t.WriteStatus(stream, s); e != nil {
870                                 grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status: %v", e)
871                         }
872                 } else {
873                         switch st := err.(type) {
874                         case transport.ConnectionError:
875                                 // Nothing to do here.
876                         case transport.StreamError:
877                                 if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
878                                         grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
879                                 }
880                         default:
881                                 panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st))
882                         }
883                 }
884                 return err
885         }
886         if trInfo != nil {
887                 trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
888         }
889         // TODO: Should we be logging if writing status failed here, like above?
890         // Should the logging be in WriteStatus?  Should we ignore the WriteStatus
891         // error or allow the stats handler to see it?
892         return t.WriteStatus(stream, status.New(codes.OK, ""))
893 }
894
895 func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
896         sh := s.opts.statsHandler
897         if sh != nil {
898                 begin := &stats.Begin{
899                         BeginTime: time.Now(),
900                 }
901                 sh.HandleRPC(stream.Context(), begin)
902                 defer func() {
903                         end := &stats.End{
904                                 EndTime: time.Now(),
905                         }
906                         if err != nil && err != io.EOF {
907                                 end.Error = toRPCErr(err)
908                         }
909                         sh.HandleRPC(stream.Context(), end)
910                 }()
911         }
912         if s.opts.cp != nil {
913                 stream.SetSendCompress(s.opts.cp.Type())
914         }
915         ss := &serverStream{
916                 t:     t,
917                 s:     stream,
918                 p:     &parser{r: stream},
919                 codec: s.opts.codec,
920                 cp:    s.opts.cp,
921                 dc:    s.opts.dc,
922                 maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
923                 maxSendMessageSize:    s.opts.maxSendMessageSize,
924                 trInfo:                trInfo,
925                 statsHandler:          sh,
926         }
927         if trInfo != nil {
928                 trInfo.tr.LazyLog(&trInfo.firstLine, false)
929                 defer func() {
930                         ss.mu.Lock()
931                         if err != nil && err != io.EOF {
932                                 ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
933                                 ss.trInfo.tr.SetError()
934                         }
935                         ss.trInfo.tr.Finish()
936                         ss.trInfo.tr = nil
937                         ss.mu.Unlock()
938                 }()
939         }
940         var appErr error
941         var server interface{}
942         if srv != nil {
943                 server = srv.server
944         }
945         if s.opts.streamInt == nil {
946                 appErr = sd.Handler(server, ss)
947         } else {
948                 info := &StreamServerInfo{
949                         FullMethod:     stream.Method(),
950                         IsClientStream: sd.ClientStreams,
951                         IsServerStream: sd.ServerStreams,
952                 }
953                 appErr = s.opts.streamInt(server, ss, info, sd.Handler)
954         }
955         if appErr != nil {
956                 appStatus, ok := status.FromError(appErr)
957                 if !ok {
958                         switch err := appErr.(type) {
959                         case transport.StreamError:
960                                 appStatus = status.New(err.Code, err.Desc)
961                         default:
962                                 appStatus = status.New(convertCode(appErr), appErr.Error())
963                         }
964                         appErr = appStatus.Err()
965                 }
966                 if trInfo != nil {
967                         ss.mu.Lock()
968                         ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
969                         ss.trInfo.tr.SetError()
970                         ss.mu.Unlock()
971                 }
972                 t.WriteStatus(ss.s, appStatus)
973                 // TODO: Should we log an error from WriteStatus here and below?
974                 return appErr
975         }
976         if trInfo != nil {
977                 ss.mu.Lock()
978                 ss.trInfo.tr.LazyLog(stringer("OK"), false)
979                 ss.mu.Unlock()
980         }
981         return t.WriteStatus(ss.s, status.New(codes.OK, ""))
982
983 }
984
985 func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
986         sm := stream.Method()
987         if sm != "" && sm[0] == '/' {
988                 sm = sm[1:]
989         }
990         pos := strings.LastIndex(sm, "/")
991         if pos == -1 {
992                 if trInfo != nil {
993                         trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true)
994                         trInfo.tr.SetError()
995                 }
996                 errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
997                 if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil {
998                         if trInfo != nil {
999                                 trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
1000                                 trInfo.tr.SetError()
1001                         }
1002                         grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err)
1003                 }
1004                 if trInfo != nil {
1005                         trInfo.tr.Finish()
1006                 }
1007                 return
1008         }
1009         service := sm[:pos]
1010         method := sm[pos+1:]
1011         srv, ok := s.m[service]
1012         if !ok {
1013                 if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
1014                         s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
1015                         return
1016                 }
1017                 if trInfo != nil {
1018                         trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true)
1019                         trInfo.tr.SetError()
1020                 }
1021                 errDesc := fmt.Sprintf("unknown service %v", service)
1022                 if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
1023                         if trInfo != nil {
1024                                 trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
1025                                 trInfo.tr.SetError()
1026                         }
1027                         grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err)
1028                 }
1029                 if trInfo != nil {
1030                         trInfo.tr.Finish()
1031                 }
1032                 return
1033         }
1034         // Unary RPC or Streaming RPC?
1035         if md, ok := srv.md[method]; ok {
1036                 s.processUnaryRPC(t, stream, srv, md, trInfo)
1037                 return
1038         }
1039         if sd, ok := srv.sd[method]; ok {
1040                 s.processStreamingRPC(t, stream, srv, sd, trInfo)
1041                 return
1042         }
1043         if trInfo != nil {
1044                 trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true)
1045                 trInfo.tr.SetError()
1046         }
1047         if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
1048                 s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
1049                 return
1050         }
1051         errDesc := fmt.Sprintf("unknown method %v", method)
1052         if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
1053                 if trInfo != nil {
1054                         trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
1055                         trInfo.tr.SetError()
1056                 }
1057                 grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err)
1058         }
1059         if trInfo != nil {
1060                 trInfo.tr.Finish()
1061         }
1062 }
1063
1064 // Stop stops the gRPC server. It immediately closes all open
1065 // connections and listeners.
1066 // It cancels all active RPCs on the server side and the corresponding
1067 // pending RPCs on the client side will get notified by connection
1068 // errors.
1069 func (s *Server) Stop() {
1070         s.quitOnce.Do(func() {
1071                 close(s.quit)
1072         })
1073
1074         defer func() {
1075                 s.doneOnce.Do(func() {
1076                         close(s.done)
1077                 })
1078         }()
1079
1080         s.mu.Lock()
1081         listeners := s.lis
1082         s.lis = nil
1083         st := s.conns
1084         s.conns = nil
1085         // interrupt GracefulStop if Stop and GracefulStop are called concurrently.
1086         s.cv.Broadcast()
1087         s.mu.Unlock()
1088
1089         for lis := range listeners {
1090                 lis.Close()
1091         }
1092         for c := range st {
1093                 c.Close()
1094         }
1095
1096         s.mu.Lock()
1097         s.cancel()
1098         if s.events != nil {
1099                 s.events.Finish()
1100                 s.events = nil
1101         }
1102         s.mu.Unlock()
1103 }
1104
1105 // GracefulStop stops the gRPC server gracefully. It stops the server from
1106 // accepting new connections and RPCs and blocks until all the pending RPCs are
1107 // finished.
1108 func (s *Server) GracefulStop() {
1109         s.quitOnce.Do(func() {
1110                 close(s.quit)
1111         })
1112
1113         defer func() {
1114                 s.doneOnce.Do(func() {
1115                         close(s.done)
1116                 })
1117         }()
1118
1119         s.mu.Lock()
1120         defer s.mu.Unlock()
1121         if s.conns == nil {
1122                 return
1123         }
1124         for lis := range s.lis {
1125                 lis.Close()
1126         }
1127         s.lis = nil
1128         s.cancel()
1129         if !s.drain {
1130                 for c := range s.conns {
1131                         c.(transport.ServerTransport).Drain()
1132                 }
1133                 s.drain = true
1134         }
1135         for len(s.conns) != 0 {
1136                 s.cv.Wait()
1137         }
1138         s.conns = nil
1139         if s.events != nil {
1140                 s.events.Finish()
1141                 s.events = nil
1142         }
1143 }
1144
1145 func init() {
1146         internal.TestingCloseConns = func(arg interface{}) {
1147                 arg.(*Server).testingCloseConns()
1148         }
1149         internal.TestingUseHandlerImpl = func(arg interface{}) {
1150                 arg.(*Server).opts.useHandlerImpl = true
1151         }
1152 }
1153
1154 // testingCloseConns closes all existing transports but keeps s.lis
1155 // accepting new connections.
1156 func (s *Server) testingCloseConns() {
1157         s.mu.Lock()
1158         for c := range s.conns {
1159                 c.Close()
1160                 delete(s.conns, c)
1161         }
1162         s.mu.Unlock()
1163 }
1164
1165 // SetHeader sets the header metadata.
1166 // When called multiple times, all the provided metadata will be merged.
1167 // All the metadata will be sent out when one of the following happens:
1168 //  - grpc.SendHeader() is called;
1169 //  - The first response is sent out;
1170 //  - An RPC status is sent out (error or success).
1171 func SetHeader(ctx context.Context, md metadata.MD) error {
1172         if md.Len() == 0 {
1173                 return nil
1174         }
1175         stream, ok := transport.StreamFromContext(ctx)
1176         if !ok {
1177                 return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
1178         }
1179         return stream.SetHeader(md)
1180 }
1181
1182 // SendHeader sends header metadata. It may be called at most once.
1183 // The provided md and headers set by SetHeader() will be sent.
1184 func SendHeader(ctx context.Context, md metadata.MD) error {
1185         stream, ok := transport.StreamFromContext(ctx)
1186         if !ok {
1187                 return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
1188         }
1189         t := stream.ServerTransport()
1190         if t == nil {
1191                 grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream)
1192         }
1193         if err := t.WriteHeader(stream, md); err != nil {
1194                 return toRPCErr(err)
1195         }
1196         return nil
1197 }
1198
1199 // SetTrailer sets the trailer metadata that will be sent when an RPC returns.
1200 // When called more than once, all the provided metadata will be merged.
1201 func SetTrailer(ctx context.Context, md metadata.MD) error {
1202         if md.Len() == 0 {
1203                 return nil
1204         }
1205         stream, ok := transport.StreamFromContext(ctx)
1206         if !ok {
1207                 return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
1208         }
1209         return stream.SetTrailer(md)
1210 }