OSDN Git Service

new repo
[bytom/vapor.git] / vendor / google.golang.org / grpc / stats / stats_test.go
1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 package stats_test
20
21 import (
22         "fmt"
23         "io"
24         "net"
25         "reflect"
26         "sync"
27         "testing"
28         "time"
29
30         "github.com/golang/protobuf/proto"
31         "golang.org/x/net/context"
32         "google.golang.org/grpc"
33         "google.golang.org/grpc/metadata"
34         "google.golang.org/grpc/stats"
35         testpb "google.golang.org/grpc/stats/grpc_testing"
36 )
37
38 func init() {
39         grpc.EnableTracing = false
40 }
41
42 type connCtxKey struct{}
43 type rpcCtxKey struct{}
44
45 var (
46         // For headers:
47         testMetadata = metadata.MD{
48                 "key1": []string{"value1"},
49                 "key2": []string{"value2"},
50         }
51         // For trailers:
52         testTrailerMetadata = metadata.MD{
53                 "tkey1": []string{"trailerValue1"},
54                 "tkey2": []string{"trailerValue2"},
55         }
56         // The id for which the service handler should return error.
57         errorID int32 = 32202
58 )
59
60 type testServer struct{}
61
62 func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
63         md, ok := metadata.FromIncomingContext(ctx)
64         if ok {
65                 if err := grpc.SendHeader(ctx, md); err != nil {
66                         return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
67                 }
68                 if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
69                         return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
70                 }
71         }
72
73         if in.Id == errorID {
74                 return nil, fmt.Errorf("got error id: %v", in.Id)
75         }
76
77         return &testpb.SimpleResponse{Id: in.Id}, nil
78 }
79
80 func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
81         md, ok := metadata.FromIncomingContext(stream.Context())
82         if ok {
83                 if err := stream.SendHeader(md); err != nil {
84                         return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
85                 }
86                 stream.SetTrailer(testTrailerMetadata)
87         }
88         for {
89                 in, err := stream.Recv()
90                 if err == io.EOF {
91                         // read done.
92                         return nil
93                 }
94                 if err != nil {
95                         return err
96                 }
97
98                 if in.Id == errorID {
99                         return fmt.Errorf("got error id: %v", in.Id)
100                 }
101
102                 if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil {
103                         return err
104                 }
105         }
106 }
107
108 func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCallServer) error {
109         md, ok := metadata.FromIncomingContext(stream.Context())
110         if ok {
111                 if err := stream.SendHeader(md); err != nil {
112                         return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
113                 }
114                 stream.SetTrailer(testTrailerMetadata)
115         }
116         for {
117                 in, err := stream.Recv()
118                 if err == io.EOF {
119                         // read done.
120                         return stream.SendAndClose(&testpb.SimpleResponse{Id: int32(0)})
121                 }
122                 if err != nil {
123                         return err
124                 }
125
126                 if in.Id == errorID {
127                         return fmt.Errorf("got error id: %v", in.Id)
128                 }
129         }
130 }
131
132 func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.TestService_ServerStreamCallServer) error {
133         md, ok := metadata.FromIncomingContext(stream.Context())
134         if ok {
135                 if err := stream.SendHeader(md); err != nil {
136                         return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
137                 }
138                 stream.SetTrailer(testTrailerMetadata)
139         }
140
141         if in.Id == errorID {
142                 return fmt.Errorf("got error id: %v", in.Id)
143         }
144
145         for i := 0; i < 5; i++ {
146                 if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil {
147                         return err
148                 }
149         }
150         return nil
151 }
152
153 // test is an end-to-end test. It should be created with the newTest
154 // func, modified as needed, and then started with its startServer method.
155 // It should be cleaned up with the tearDown method.
156 type test struct {
157         t                  *testing.T
158         compress           string
159         clientStatsHandler stats.Handler
160         serverStatsHandler stats.Handler
161
162         testServer testpb.TestServiceServer // nil means none
163         // srv and srvAddr are set once startServer is called.
164         srv     *grpc.Server
165         srvAddr string
166
167         cc *grpc.ClientConn // nil until requested via clientConn
168 }
169
170 func (te *test) tearDown() {
171         if te.cc != nil {
172                 te.cc.Close()
173                 te.cc = nil
174         }
175         te.srv.Stop()
176 }
177
178 type testConfig struct {
179         compress string
180 }
181
182 // newTest returns a new test using the provided testing.T and
183 // environment.  It is returned with default values. Tests should
184 // modify it before calling its startServer and clientConn methods.
185 func newTest(t *testing.T, tc *testConfig, ch stats.Handler, sh stats.Handler) *test {
186         te := &test{
187                 t:                  t,
188                 compress:           tc.compress,
189                 clientStatsHandler: ch,
190                 serverStatsHandler: sh,
191         }
192         return te
193 }
194
195 // startServer starts a gRPC server listening. Callers should defer a
196 // call to te.tearDown to clean up.
197 func (te *test) startServer(ts testpb.TestServiceServer) {
198         te.testServer = ts
199         lis, err := net.Listen("tcp", "localhost:0")
200         if err != nil {
201                 te.t.Fatalf("Failed to listen: %v", err)
202         }
203         var opts []grpc.ServerOption
204         if te.compress == "gzip" {
205                 opts = append(opts,
206                         grpc.RPCCompressor(grpc.NewGZIPCompressor()),
207                         grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
208                 )
209         }
210         if te.serverStatsHandler != nil {
211                 opts = append(opts, grpc.StatsHandler(te.serverStatsHandler))
212         }
213         s := grpc.NewServer(opts...)
214         te.srv = s
215         if te.testServer != nil {
216                 testpb.RegisterTestServiceServer(s, te.testServer)
217         }
218
219         go s.Serve(lis)
220         te.srvAddr = lis.Addr().String()
221 }
222
223 func (te *test) clientConn() *grpc.ClientConn {
224         if te.cc != nil {
225                 return te.cc
226         }
227         opts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithBlock()}
228         if te.compress == "gzip" {
229                 opts = append(opts,
230                         grpc.WithCompressor(grpc.NewGZIPCompressor()),
231                         grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
232                 )
233         }
234         if te.clientStatsHandler != nil {
235                 opts = append(opts, grpc.WithStatsHandler(te.clientStatsHandler))
236         }
237
238         var err error
239         te.cc, err = grpc.Dial(te.srvAddr, opts...)
240         if err != nil {
241                 te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err)
242         }
243         return te.cc
244 }
245
246 type rpcType int
247
248 const (
249         unaryRPC rpcType = iota
250         clientStreamRPC
251         serverStreamRPC
252         fullDuplexStreamRPC
253 )
254
255 type rpcConfig struct {
256         count      int  // Number of requests and responses for streaming RPCs.
257         success    bool // Whether the RPC should succeed or return error.
258         failfast   bool
259         callType   rpcType // Type of RPC.
260         noLastRecv bool    // Whether to call recv for io.EOF. When true, last recv won't be called. Only valid for streaming RPCs.
261 }
262
263 func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
264         var (
265                 resp *testpb.SimpleResponse
266                 req  *testpb.SimpleRequest
267                 err  error
268         )
269         tc := testpb.NewTestServiceClient(te.clientConn())
270         if c.success {
271                 req = &testpb.SimpleRequest{Id: errorID + 1}
272         } else {
273                 req = &testpb.SimpleRequest{Id: errorID}
274         }
275         ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
276
277         resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(c.failfast))
278         return req, resp, err
279 }
280
281 func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest, []*testpb.SimpleResponse, error) {
282         var (
283                 reqs  []*testpb.SimpleRequest
284                 resps []*testpb.SimpleResponse
285                 err   error
286         )
287         tc := testpb.NewTestServiceClient(te.clientConn())
288         stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.FailFast(c.failfast))
289         if err != nil {
290                 return reqs, resps, err
291         }
292         var startID int32
293         if !c.success {
294                 startID = errorID
295         }
296         for i := 0; i < c.count; i++ {
297                 req := &testpb.SimpleRequest{
298                         Id: int32(i) + startID,
299                 }
300                 reqs = append(reqs, req)
301                 if err = stream.Send(req); err != nil {
302                         return reqs, resps, err
303                 }
304                 var resp *testpb.SimpleResponse
305                 if resp, err = stream.Recv(); err != nil {
306                         return reqs, resps, err
307                 }
308                 resps = append(resps, resp)
309         }
310         if err = stream.CloseSend(); err != nil && err != io.EOF {
311                 return reqs, resps, err
312         }
313         if !c.noLastRecv {
314                 if _, err = stream.Recv(); err != io.EOF {
315                         return reqs, resps, err
316                 }
317         } else {
318                 // In the case of not calling the last recv, sleep to avoid
319                 // returning too fast to miss the remaining stats (InTrailer and End).
320                 time.Sleep(time.Second)
321         }
322
323         return reqs, resps, nil
324 }
325
326 func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
327         var (
328                 reqs []*testpb.SimpleRequest
329                 resp *testpb.SimpleResponse
330                 err  error
331         )
332         tc := testpb.NewTestServiceClient(te.clientConn())
333         stream, err := tc.ClientStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.FailFast(c.failfast))
334         if err != nil {
335                 return reqs, resp, err
336         }
337         var startID int32
338         if !c.success {
339                 startID = errorID
340         }
341         for i := 0; i < c.count; i++ {
342                 req := &testpb.SimpleRequest{
343                         Id: int32(i) + startID,
344                 }
345                 reqs = append(reqs, req)
346                 if err = stream.Send(req); err != nil {
347                         return reqs, resp, err
348                 }
349         }
350         resp, err = stream.CloseAndRecv()
351         return reqs, resp, err
352 }
353
354 func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.SimpleRequest, []*testpb.SimpleResponse, error) {
355         var (
356                 req   *testpb.SimpleRequest
357                 resps []*testpb.SimpleResponse
358                 err   error
359         )
360
361         tc := testpb.NewTestServiceClient(te.clientConn())
362
363         var startID int32
364         if !c.success {
365                 startID = errorID
366         }
367         req = &testpb.SimpleRequest{Id: startID}
368         stream, err := tc.ServerStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), req, grpc.FailFast(c.failfast))
369         if err != nil {
370                 return req, resps, err
371         }
372         for {
373                 var resp *testpb.SimpleResponse
374                 resp, err := stream.Recv()
375                 if err == io.EOF {
376                         return req, resps, nil
377                 } else if err != nil {
378                         return req, resps, err
379                 }
380                 resps = append(resps, resp)
381         }
382 }
383
384 type expectedData struct {
385         method      string
386         serverAddr  string
387         compression string
388         reqIdx      int
389         requests    []*testpb.SimpleRequest
390         respIdx     int
391         responses   []*testpb.SimpleResponse
392         err         error
393         failfast    bool
394 }
395
396 type gotData struct {
397         ctx    context.Context
398         client bool
399         s      interface{} // This could be RPCStats or ConnStats.
400 }
401
402 const (
403         begin int = iota
404         end
405         inPayload
406         inHeader
407         inTrailer
408         outPayload
409         outHeader
410         outTrailer
411         connbegin
412         connend
413 )
414
415 func checkBegin(t *testing.T, d *gotData, e *expectedData) {
416         var (
417                 ok bool
418                 st *stats.Begin
419         )
420         if st, ok = d.s.(*stats.Begin); !ok {
421                 t.Fatalf("got %T, want Begin", d.s)
422         }
423         if d.ctx == nil {
424                 t.Fatalf("d.ctx = nil, want <non-nil>")
425         }
426         if st.BeginTime.IsZero() {
427                 t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
428         }
429         if d.client {
430                 if st.FailFast != e.failfast {
431                         t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast)
432                 }
433         }
434 }
435
436 func checkInHeader(t *testing.T, d *gotData, e *expectedData) {
437         var (
438                 ok bool
439                 st *stats.InHeader
440         )
441         if st, ok = d.s.(*stats.InHeader); !ok {
442                 t.Fatalf("got %T, want InHeader", d.s)
443         }
444         if d.ctx == nil {
445                 t.Fatalf("d.ctx = nil, want <non-nil>")
446         }
447         if !d.client {
448                 if st.FullMethod != e.method {
449                         t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
450                 }
451                 if st.LocalAddr.String() != e.serverAddr {
452                         t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr)
453                 }
454                 if st.Compression != e.compression {
455                         t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
456                 }
457
458                 if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok {
459                         if connInfo.RemoteAddr != st.RemoteAddr {
460                                 t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr)
461                         }
462                         if connInfo.LocalAddr != st.LocalAddr {
463                                 t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr)
464                         }
465                 } else {
466                         t.Fatalf("got context %v, want one with connCtxKey", d.ctx)
467                 }
468                 if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
469                         if rpcInfo.FullMethodName != st.FullMethod {
470                                 t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
471                         }
472                 } else {
473                         t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
474                 }
475         }
476 }
477
478 func checkInPayload(t *testing.T, d *gotData, e *expectedData) {
479         var (
480                 ok bool
481                 st *stats.InPayload
482         )
483         if st, ok = d.s.(*stats.InPayload); !ok {
484                 t.Fatalf("got %T, want InPayload", d.s)
485         }
486         if d.ctx == nil {
487                 t.Fatalf("d.ctx = nil, want <non-nil>")
488         }
489         if d.client {
490                 b, err := proto.Marshal(e.responses[e.respIdx])
491                 if err != nil {
492                         t.Fatalf("failed to marshal message: %v", err)
493                 }
494                 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
495                         t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
496                 }
497                 e.respIdx++
498                 if string(st.Data) != string(b) {
499                         t.Fatalf("st.Data = %v, want %v", st.Data, b)
500                 }
501                 if st.Length != len(b) {
502                         t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
503                 }
504         } else {
505                 b, err := proto.Marshal(e.requests[e.reqIdx])
506                 if err != nil {
507                         t.Fatalf("failed to marshal message: %v", err)
508                 }
509                 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
510                         t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
511                 }
512                 e.reqIdx++
513                 if string(st.Data) != string(b) {
514                         t.Fatalf("st.Data = %v, want %v", st.Data, b)
515                 }
516                 if st.Length != len(b) {
517                         t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
518                 }
519         }
520         // TODO check WireLength and ReceivedTime.
521         if st.RecvTime.IsZero() {
522                 t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime)
523         }
524 }
525
526 func checkInTrailer(t *testing.T, d *gotData, e *expectedData) {
527         var (
528                 ok bool
529         )
530         if _, ok = d.s.(*stats.InTrailer); !ok {
531                 t.Fatalf("got %T, want InTrailer", d.s)
532         }
533         if d.ctx == nil {
534                 t.Fatalf("d.ctx = nil, want <non-nil>")
535         }
536 }
537
538 func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
539         var (
540                 ok bool
541                 st *stats.OutHeader
542         )
543         if st, ok = d.s.(*stats.OutHeader); !ok {
544                 t.Fatalf("got %T, want OutHeader", d.s)
545         }
546         if d.ctx == nil {
547                 t.Fatalf("d.ctx = nil, want <non-nil>")
548         }
549         if d.client {
550                 if st.FullMethod != e.method {
551                         t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
552                 }
553                 if st.RemoteAddr.String() != e.serverAddr {
554                         t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr)
555                 }
556                 if st.Compression != e.compression {
557                         t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
558                 }
559
560                 if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
561                         if rpcInfo.FullMethodName != st.FullMethod {
562                                 t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
563                         }
564                 } else {
565                         t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
566                 }
567         }
568 }
569
570 func checkOutPayload(t *testing.T, d *gotData, e *expectedData) {
571         var (
572                 ok bool
573                 st *stats.OutPayload
574         )
575         if st, ok = d.s.(*stats.OutPayload); !ok {
576                 t.Fatalf("got %T, want OutPayload", d.s)
577         }
578         if d.ctx == nil {
579                 t.Fatalf("d.ctx = nil, want <non-nil>")
580         }
581         if d.client {
582                 b, err := proto.Marshal(e.requests[e.reqIdx])
583                 if err != nil {
584                         t.Fatalf("failed to marshal message: %v", err)
585                 }
586                 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
587                         t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
588                 }
589                 e.reqIdx++
590                 if string(st.Data) != string(b) {
591                         t.Fatalf("st.Data = %v, want %v", st.Data, b)
592                 }
593                 if st.Length != len(b) {
594                         t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
595                 }
596         } else {
597                 b, err := proto.Marshal(e.responses[e.respIdx])
598                 if err != nil {
599                         t.Fatalf("failed to marshal message: %v", err)
600                 }
601                 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
602                         t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
603                 }
604                 e.respIdx++
605                 if string(st.Data) != string(b) {
606                         t.Fatalf("st.Data = %v, want %v", st.Data, b)
607                 }
608                 if st.Length != len(b) {
609                         t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
610                 }
611         }
612         // TODO check WireLength and ReceivedTime.
613         if st.SentTime.IsZero() {
614                 t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime)
615         }
616 }
617
618 func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) {
619         var (
620                 ok bool
621                 st *stats.OutTrailer
622         )
623         if st, ok = d.s.(*stats.OutTrailer); !ok {
624                 t.Fatalf("got %T, want OutTrailer", d.s)
625         }
626         if d.ctx == nil {
627                 t.Fatalf("d.ctx = nil, want <non-nil>")
628         }
629         if st.Client {
630                 t.Fatalf("st IsClient = true, want false")
631         }
632 }
633
634 func checkEnd(t *testing.T, d *gotData, e *expectedData) {
635         var (
636                 ok bool
637                 st *stats.End
638         )
639         if st, ok = d.s.(*stats.End); !ok {
640                 t.Fatalf("got %T, want End", d.s)
641         }
642         if d.ctx == nil {
643                 t.Fatalf("d.ctx = nil, want <non-nil>")
644         }
645         if st.EndTime.IsZero() {
646                 t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime)
647         }
648         if grpc.Code(st.Error) != grpc.Code(e.err) || grpc.ErrorDesc(st.Error) != grpc.ErrorDesc(e.err) {
649                 t.Fatalf("st.Error = %v, want %v", st.Error, e.err)
650         }
651 }
652
653 func checkConnBegin(t *testing.T, d *gotData, e *expectedData) {
654         var (
655                 ok bool
656                 st *stats.ConnBegin
657         )
658         if st, ok = d.s.(*stats.ConnBegin); !ok {
659                 t.Fatalf("got %T, want ConnBegin", d.s)
660         }
661         if d.ctx == nil {
662                 t.Fatalf("d.ctx = nil, want <non-nil>")
663         }
664         st.IsClient() // TODO remove this.
665 }
666
667 func checkConnEnd(t *testing.T, d *gotData, e *expectedData) {
668         var (
669                 ok bool
670                 st *stats.ConnEnd
671         )
672         if st, ok = d.s.(*stats.ConnEnd); !ok {
673                 t.Fatalf("got %T, want ConnEnd", d.s)
674         }
675         if d.ctx == nil {
676                 t.Fatalf("d.ctx = nil, want <non-nil>")
677         }
678         st.IsClient() // TODO remove this.
679 }
680
681 type statshandler struct {
682         mu      sync.Mutex
683         gotRPC  []*gotData
684         gotConn []*gotData
685 }
686
687 func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
688         return context.WithValue(ctx, connCtxKey{}, info)
689 }
690
691 func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
692         return context.WithValue(ctx, rpcCtxKey{}, info)
693 }
694
695 func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) {
696         h.mu.Lock()
697         defer h.mu.Unlock()
698         h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s})
699 }
700
701 func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
702         h.mu.Lock()
703         defer h.mu.Unlock()
704         h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s})
705 }
706
707 func checkConnStats(t *testing.T, got []*gotData) {
708         if len(got) <= 0 || len(got)%2 != 0 {
709                 for i, g := range got {
710                         t.Errorf(" - %v, %T = %+v, ctx: %v", i, g.s, g.s, g.ctx)
711                 }
712                 t.Fatalf("got %v stats, want even positive number", len(got))
713         }
714         // The first conn stats must be a ConnBegin.
715         checkConnBegin(t, got[0], nil)
716         // The last conn stats must be a ConnEnd.
717         checkConnEnd(t, got[len(got)-1], nil)
718 }
719
720 func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
721         if len(got) != len(checkFuncs) {
722                 for i, g := range got {
723                         t.Errorf(" - %v, %T", i, g.s)
724                 }
725                 t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
726         }
727
728         var rpcctx context.Context
729         for i := 0; i < len(got); i++ {
730                 if _, ok := got[i].s.(stats.RPCStats); ok {
731                         if rpcctx != nil && got[i].ctx != rpcctx {
732                                 t.Fatalf("got different contexts with stats %T", got[i].s)
733                         }
734                         rpcctx = got[i].ctx
735                 }
736         }
737
738         for i, f := range checkFuncs {
739                 f(t, got[i], expect)
740         }
741 }
742
743 func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
744         h := &statshandler{}
745         te := newTest(t, tc, nil, h)
746         te.startServer(&testServer{})
747         defer te.tearDown()
748
749         var (
750                 reqs   []*testpb.SimpleRequest
751                 resps  []*testpb.SimpleResponse
752                 err    error
753                 method string
754
755                 req  *testpb.SimpleRequest
756                 resp *testpb.SimpleResponse
757                 e    error
758         )
759
760         switch cc.callType {
761         case unaryRPC:
762                 method = "/grpc.testing.TestService/UnaryCall"
763                 req, resp, e = te.doUnaryCall(cc)
764                 reqs = []*testpb.SimpleRequest{req}
765                 resps = []*testpb.SimpleResponse{resp}
766                 err = e
767         case clientStreamRPC:
768                 method = "/grpc.testing.TestService/ClientStreamCall"
769                 reqs, resp, e = te.doClientStreamCall(cc)
770                 resps = []*testpb.SimpleResponse{resp}
771                 err = e
772         case serverStreamRPC:
773                 method = "/grpc.testing.TestService/ServerStreamCall"
774                 req, resps, e = te.doServerStreamCall(cc)
775                 reqs = []*testpb.SimpleRequest{req}
776                 err = e
777         case fullDuplexStreamRPC:
778                 method = "/grpc.testing.TestService/FullDuplexCall"
779                 reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
780         }
781         if cc.success != (err == nil) {
782                 t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
783         }
784         te.cc.Close()
785         te.srv.GracefulStop() // Wait for the server to stop.
786
787         for {
788                 h.mu.Lock()
789                 if len(h.gotRPC) >= len(checkFuncs) {
790                         h.mu.Unlock()
791                         break
792                 }
793                 h.mu.Unlock()
794                 time.Sleep(10 * time.Millisecond)
795         }
796
797         for {
798                 h.mu.Lock()
799                 if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
800                         h.mu.Unlock()
801                         break
802                 }
803                 h.mu.Unlock()
804                 time.Sleep(10 * time.Millisecond)
805         }
806
807         expect := &expectedData{
808                 serverAddr:  te.srvAddr,
809                 compression: tc.compress,
810                 method:      method,
811                 requests:    reqs,
812                 responses:   resps,
813                 err:         err,
814         }
815
816         h.mu.Lock()
817         checkConnStats(t, h.gotConn)
818         h.mu.Unlock()
819         checkServerStats(t, h.gotRPC, expect, checkFuncs)
820 }
821
822 func TestServerStatsUnaryRPC(t *testing.T) {
823         testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
824                 checkInHeader,
825                 checkBegin,
826                 checkInPayload,
827                 checkOutHeader,
828                 checkOutPayload,
829                 checkOutTrailer,
830                 checkEnd,
831         })
832 }
833
834 func TestServerStatsUnaryRPCError(t *testing.T) {
835         testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
836                 checkInHeader,
837                 checkBegin,
838                 checkInPayload,
839                 checkOutHeader,
840                 checkOutTrailer,
841                 checkEnd,
842         })
843 }
844
845 func TestServerStatsClientStreamRPC(t *testing.T) {
846         count := 5
847         checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
848                 checkInHeader,
849                 checkBegin,
850                 checkOutHeader,
851         }
852         ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
853                 checkInPayload,
854         }
855         for i := 0; i < count; i++ {
856                 checkFuncs = append(checkFuncs, ioPayFuncs...)
857         }
858         checkFuncs = append(checkFuncs,
859                 checkOutPayload,
860                 checkOutTrailer,
861                 checkEnd,
862         )
863         testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs)
864 }
865
866 func TestServerStatsClientStreamRPCError(t *testing.T) {
867         count := 1
868         testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
869                 checkInHeader,
870                 checkBegin,
871                 checkOutHeader,
872                 checkInPayload,
873                 checkOutTrailer,
874                 checkEnd,
875         })
876 }
877
878 func TestServerStatsServerStreamRPC(t *testing.T) {
879         count := 5
880         checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
881                 checkInHeader,
882                 checkBegin,
883                 checkInPayload,
884                 checkOutHeader,
885         }
886         ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
887                 checkOutPayload,
888         }
889         for i := 0; i < count; i++ {
890                 checkFuncs = append(checkFuncs, ioPayFuncs...)
891         }
892         checkFuncs = append(checkFuncs,
893                 checkOutTrailer,
894                 checkEnd,
895         )
896         testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs)
897 }
898
899 func TestServerStatsServerStreamRPCError(t *testing.T) {
900         count := 5
901         testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
902                 checkInHeader,
903                 checkBegin,
904                 checkInPayload,
905                 checkOutHeader,
906                 checkOutTrailer,
907                 checkEnd,
908         })
909 }
910
911 func TestServerStatsFullDuplexRPC(t *testing.T) {
912         count := 5
913         checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
914                 checkInHeader,
915                 checkBegin,
916                 checkOutHeader,
917         }
918         ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
919                 checkInPayload,
920                 checkOutPayload,
921         }
922         for i := 0; i < count; i++ {
923                 checkFuncs = append(checkFuncs, ioPayFuncs...)
924         }
925         checkFuncs = append(checkFuncs,
926                 checkOutTrailer,
927                 checkEnd,
928         )
929         testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs)
930 }
931
932 func TestServerStatsFullDuplexRPCError(t *testing.T) {
933         count := 5
934         testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
935                 checkInHeader,
936                 checkBegin,
937                 checkOutHeader,
938                 checkInPayload,
939                 checkOutTrailer,
940                 checkEnd,
941         })
942 }
943
944 type checkFuncWithCount struct {
945         f func(t *testing.T, d *gotData, e *expectedData)
946         c int // expected count
947 }
948
949 func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) {
950         var expectLen int
951         for _, v := range checkFuncs {
952                 expectLen += v.c
953         }
954         if len(got) != expectLen {
955                 for i, g := range got {
956                         t.Errorf(" - %v, %T", i, g.s)
957                 }
958                 t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
959         }
960
961         var tagInfoInCtx *stats.RPCTagInfo
962         for i := 0; i < len(got); i++ {
963                 if _, ok := got[i].s.(stats.RPCStats); ok {
964                         tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo)
965                         if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew {
966                                 t.Fatalf("got context containing different tagInfo with stats %T", got[i].s)
967                         }
968                         tagInfoInCtx = tagInfoInCtxNew
969                 }
970         }
971
972         for _, s := range got {
973                 switch s.s.(type) {
974                 case *stats.Begin:
975                         if checkFuncs[begin].c <= 0 {
976                                 t.Fatalf("unexpected stats: %T", s.s)
977                         }
978                         checkFuncs[begin].f(t, s, expect)
979                         checkFuncs[begin].c--
980                 case *stats.OutHeader:
981                         if checkFuncs[outHeader].c <= 0 {
982                                 t.Fatalf("unexpected stats: %T", s.s)
983                         }
984                         checkFuncs[outHeader].f(t, s, expect)
985                         checkFuncs[outHeader].c--
986                 case *stats.OutPayload:
987                         if checkFuncs[outPayload].c <= 0 {
988                                 t.Fatalf("unexpected stats: %T", s.s)
989                         }
990                         checkFuncs[outPayload].f(t, s, expect)
991                         checkFuncs[outPayload].c--
992                 case *stats.InHeader:
993                         if checkFuncs[inHeader].c <= 0 {
994                                 t.Fatalf("unexpected stats: %T", s.s)
995                         }
996                         checkFuncs[inHeader].f(t, s, expect)
997                         checkFuncs[inHeader].c--
998                 case *stats.InPayload:
999                         if checkFuncs[inPayload].c <= 0 {
1000                                 t.Fatalf("unexpected stats: %T", s.s)
1001                         }
1002                         checkFuncs[inPayload].f(t, s, expect)
1003                         checkFuncs[inPayload].c--
1004                 case *stats.InTrailer:
1005                         if checkFuncs[inTrailer].c <= 0 {
1006                                 t.Fatalf("unexpected stats: %T", s.s)
1007                         }
1008                         checkFuncs[inTrailer].f(t, s, expect)
1009                         checkFuncs[inTrailer].c--
1010                 case *stats.End:
1011                         if checkFuncs[end].c <= 0 {
1012                                 t.Fatalf("unexpected stats: %T", s.s)
1013                         }
1014                         checkFuncs[end].f(t, s, expect)
1015                         checkFuncs[end].c--
1016                 case *stats.ConnBegin:
1017                         if checkFuncs[connbegin].c <= 0 {
1018                                 t.Fatalf("unexpected stats: %T", s.s)
1019                         }
1020                         checkFuncs[connbegin].f(t, s, expect)
1021                         checkFuncs[connbegin].c--
1022                 case *stats.ConnEnd:
1023                         if checkFuncs[connend].c <= 0 {
1024                                 t.Fatalf("unexpected stats: %T", s.s)
1025                         }
1026                         checkFuncs[connend].f(t, s, expect)
1027                         checkFuncs[connend].c--
1028                 default:
1029                         t.Fatalf("unexpected stats: %T", s.s)
1030                 }
1031         }
1032 }
1033
1034 func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
1035         h := &statshandler{}
1036         te := newTest(t, tc, h, nil)
1037         te.startServer(&testServer{})
1038         defer te.tearDown()
1039
1040         var (
1041                 reqs   []*testpb.SimpleRequest
1042                 resps  []*testpb.SimpleResponse
1043                 method string
1044                 err    error
1045
1046                 req  *testpb.SimpleRequest
1047                 resp *testpb.SimpleResponse
1048                 e    error
1049         )
1050         switch cc.callType {
1051         case unaryRPC:
1052                 method = "/grpc.testing.TestService/UnaryCall"
1053                 req, resp, e = te.doUnaryCall(cc)
1054                 reqs = []*testpb.SimpleRequest{req}
1055                 resps = []*testpb.SimpleResponse{resp}
1056                 err = e
1057         case clientStreamRPC:
1058                 method = "/grpc.testing.TestService/ClientStreamCall"
1059                 reqs, resp, e = te.doClientStreamCall(cc)
1060                 resps = []*testpb.SimpleResponse{resp}
1061                 err = e
1062         case serverStreamRPC:
1063                 method = "/grpc.testing.TestService/ServerStreamCall"
1064                 req, resps, e = te.doServerStreamCall(cc)
1065                 reqs = []*testpb.SimpleRequest{req}
1066                 err = e
1067         case fullDuplexStreamRPC:
1068                 method = "/grpc.testing.TestService/FullDuplexCall"
1069                 reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
1070         }
1071         if cc.success != (err == nil) {
1072                 t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
1073         }
1074         te.cc.Close()
1075         te.srv.GracefulStop() // Wait for the server to stop.
1076
1077         lenRPCStats := 0
1078         for _, v := range checkFuncs {
1079                 lenRPCStats += v.c
1080         }
1081         for {
1082                 h.mu.Lock()
1083                 if len(h.gotRPC) >= lenRPCStats {
1084                         h.mu.Unlock()
1085                         break
1086                 }
1087                 h.mu.Unlock()
1088                 time.Sleep(10 * time.Millisecond)
1089         }
1090
1091         for {
1092                 h.mu.Lock()
1093                 if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
1094                         h.mu.Unlock()
1095                         break
1096                 }
1097                 h.mu.Unlock()
1098                 time.Sleep(10 * time.Millisecond)
1099         }
1100
1101         expect := &expectedData{
1102                 serverAddr:  te.srvAddr,
1103                 compression: tc.compress,
1104                 method:      method,
1105                 requests:    reqs,
1106                 responses:   resps,
1107                 failfast:    cc.failfast,
1108                 err:         err,
1109         }
1110
1111         h.mu.Lock()
1112         checkConnStats(t, h.gotConn)
1113         h.mu.Unlock()
1114         checkClientStats(t, h.gotRPC, expect, checkFuncs)
1115 }
1116
1117 func TestClientStatsUnaryRPC(t *testing.T) {
1118         testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
1119                 begin:      {checkBegin, 1},
1120                 outHeader:  {checkOutHeader, 1},
1121                 outPayload: {checkOutPayload, 1},
1122                 inHeader:   {checkInHeader, 1},
1123                 inPayload:  {checkInPayload, 1},
1124                 inTrailer:  {checkInTrailer, 1},
1125                 end:        {checkEnd, 1},
1126         })
1127 }
1128
1129 func TestClientStatsUnaryRPCError(t *testing.T) {
1130         testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
1131                 begin:      {checkBegin, 1},
1132                 outHeader:  {checkOutHeader, 1},
1133                 outPayload: {checkOutPayload, 1},
1134                 inHeader:   {checkInHeader, 1},
1135                 inTrailer:  {checkInTrailer, 1},
1136                 end:        {checkEnd, 1},
1137         })
1138 }
1139
1140 func TestClientStatsClientStreamRPC(t *testing.T) {
1141         count := 5
1142         testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
1143                 begin:      {checkBegin, 1},
1144                 outHeader:  {checkOutHeader, 1},
1145                 inHeader:   {checkInHeader, 1},
1146                 outPayload: {checkOutPayload, count},
1147                 inTrailer:  {checkInTrailer, 1},
1148                 inPayload:  {checkInPayload, 1},
1149                 end:        {checkEnd, 1},
1150         })
1151 }
1152
1153 func TestClientStatsClientStreamRPCError(t *testing.T) {
1154         count := 1
1155         testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
1156                 begin:      {checkBegin, 1},
1157                 outHeader:  {checkOutHeader, 1},
1158                 inHeader:   {checkInHeader, 1},
1159                 outPayload: {checkOutPayload, 1},
1160                 inTrailer:  {checkInTrailer, 1},
1161                 end:        {checkEnd, 1},
1162         })
1163 }
1164
1165 func TestClientStatsServerStreamRPC(t *testing.T) {
1166         count := 5
1167         testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
1168                 begin:      {checkBegin, 1},
1169                 outHeader:  {checkOutHeader, 1},
1170                 outPayload: {checkOutPayload, 1},
1171                 inHeader:   {checkInHeader, 1},
1172                 inPayload:  {checkInPayload, count},
1173                 inTrailer:  {checkInTrailer, 1},
1174                 end:        {checkEnd, 1},
1175         })
1176 }
1177
1178 func TestClientStatsServerStreamRPCError(t *testing.T) {
1179         count := 5
1180         testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
1181                 begin:      {checkBegin, 1},
1182                 outHeader:  {checkOutHeader, 1},
1183                 outPayload: {checkOutPayload, 1},
1184                 inHeader:   {checkInHeader, 1},
1185                 inTrailer:  {checkInTrailer, 1},
1186                 end:        {checkEnd, 1},
1187         })
1188 }
1189
1190 func TestClientStatsFullDuplexRPC(t *testing.T) {
1191         count := 5
1192         testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
1193                 begin:      {checkBegin, 1},
1194                 outHeader:  {checkOutHeader, 1},
1195                 outPayload: {checkOutPayload, count},
1196                 inHeader:   {checkInHeader, 1},
1197                 inPayload:  {checkInPayload, count},
1198                 inTrailer:  {checkInTrailer, 1},
1199                 end:        {checkEnd, 1},
1200         })
1201 }
1202
1203 func TestClientStatsFullDuplexRPCError(t *testing.T) {
1204         count := 5
1205         testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
1206                 begin:      {checkBegin, 1},
1207                 outHeader:  {checkOutHeader, 1},
1208                 outPayload: {checkOutPayload, 1},
1209                 inHeader:   {checkInHeader, 1},
1210                 inTrailer:  {checkInTrailer, 1},
1211                 end:        {checkEnd, 1},
1212         })
1213 }
1214
1215 // If the user doesn't call the last recv() on clientStream.
1216 func TestClientStatsFullDuplexRPCNotCallingLastRecv(t *testing.T) {
1217         count := 1
1218         testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC, noLastRecv: true}, map[int]*checkFuncWithCount{
1219                 begin:      {checkBegin, 1},
1220                 outHeader:  {checkOutHeader, 1},
1221                 outPayload: {checkOutPayload, count},
1222                 inHeader:   {checkInHeader, 1},
1223                 inPayload:  {checkInPayload, count},
1224                 inTrailer:  {checkInTrailer, 1},
1225                 end:        {checkEnd, 1},
1226         })
1227 }
1228
1229 func TestTags(t *testing.T) {
1230         b := []byte{5, 2, 4, 3, 1}
1231         ctx := stats.SetTags(context.Background(), b)
1232         if tg := stats.OutgoingTags(ctx); !reflect.DeepEqual(tg, b) {
1233                 t.Errorf("OutgoingTags(%v) = %v; want %v", ctx, tg, b)
1234         }
1235         if tg := stats.Tags(ctx); tg != nil {
1236                 t.Errorf("Tags(%v) = %v; want nil", ctx, tg)
1237         }
1238
1239         ctx = stats.SetIncomingTags(context.Background(), b)
1240         if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) {
1241                 t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b)
1242         }
1243         if tg := stats.OutgoingTags(ctx); tg != nil {
1244                 t.Errorf("OutgoingTags(%v) = %v; want nil", ctx, tg)
1245         }
1246 }
1247
1248 func TestTrace(t *testing.T) {
1249         b := []byte{5, 2, 4, 3, 1}
1250         ctx := stats.SetTrace(context.Background(), b)
1251         if tr := stats.OutgoingTrace(ctx); !reflect.DeepEqual(tr, b) {
1252                 t.Errorf("OutgoingTrace(%v) = %v; want %v", ctx, tr, b)
1253         }
1254         if tr := stats.Trace(ctx); tr != nil {
1255                 t.Errorf("Trace(%v) = %v; want nil", ctx, tr)
1256         }
1257
1258         ctx = stats.SetIncomingTrace(context.Background(), b)
1259         if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) {
1260                 t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b)
1261         }
1262         if tr := stats.OutgoingTrace(ctx); tr != nil {
1263                 t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr)
1264         }
1265 }