OSDN Git Service

new repo
[bytom/vapor.git] / vendor / google.golang.org / grpc / grpclb / grpclb_test.go
1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 //go:generate protoc --go_out=plugins=:$GOPATH grpc_lb_v1/messages/messages.proto
20 //go:generate protoc --go_out=plugins=grpc:$GOPATH grpc_lb_v1/service/service.proto
21
22 // Package grpclb_test is currently used only for grpclb testing.
23 package grpclb_test
24
25 import (
26         "errors"
27         "fmt"
28         "io"
29         "net"
30         "strings"
31         "sync"
32         "testing"
33         "time"
34
35         "github.com/golang/protobuf/proto"
36         "golang.org/x/net/context"
37         "google.golang.org/grpc"
38         "google.golang.org/grpc/codes"
39         "google.golang.org/grpc/credentials"
40         lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
41         lbspb "google.golang.org/grpc/grpclb/grpc_lb_v1/service"
42         _ "google.golang.org/grpc/grpclog/glogger"
43         "google.golang.org/grpc/metadata"
44         "google.golang.org/grpc/naming"
45         testpb "google.golang.org/grpc/test/grpc_testing"
46         "google.golang.org/grpc/test/leakcheck"
47 )
48
49 var (
50         lbsn    = "bar.com"
51         besn    = "foo.com"
52         lbToken = "iamatoken"
53
54         // Resolver replaces localhost with fakeName in Next().
55         // Dialer replaces fakeName with localhost when dialing.
56         // This will test that custom dialer is passed from Dial to grpclb.
57         fakeName = "fake.Name"
58 )
59
60 type testWatcher struct {
61         // the channel to receives name resolution updates
62         update chan *naming.Update
63         // the side channel to get to know how many updates in a batch
64         side chan int
65         // the channel to notifiy update injector that the update reading is done
66         readDone chan int
67 }
68
69 func (w *testWatcher) Next() (updates []*naming.Update, err error) {
70         n, ok := <-w.side
71         if !ok {
72                 return nil, fmt.Errorf("w.side is closed")
73         }
74         for i := 0; i < n; i++ {
75                 u, ok := <-w.update
76                 if !ok {
77                         break
78                 }
79                 if u != nil {
80                         // Resolver replaces localhost with fakeName in Next().
81                         // Custom dialer will replace fakeName with localhost when dialing.
82                         u.Addr = strings.Replace(u.Addr, "localhost", fakeName, 1)
83                         updates = append(updates, u)
84                 }
85         }
86         w.readDone <- 0
87         return
88 }
89
90 func (w *testWatcher) Close() {
91         close(w.side)
92 }
93
94 // Inject naming resolution updates to the testWatcher.
95 func (w *testWatcher) inject(updates []*naming.Update) {
96         w.side <- len(updates)
97         for _, u := range updates {
98                 w.update <- u
99         }
100         <-w.readDone
101 }
102
103 type testNameResolver struct {
104         w     *testWatcher
105         addrs []string
106 }
107
108 func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
109         r.w = &testWatcher{
110                 update:   make(chan *naming.Update, len(r.addrs)),
111                 side:     make(chan int, 1),
112                 readDone: make(chan int),
113         }
114         r.w.side <- len(r.addrs)
115         for _, addr := range r.addrs {
116                 r.w.update <- &naming.Update{
117                         Op:   naming.Add,
118                         Addr: addr,
119                         Metadata: &naming.AddrMetadataGRPCLB{
120                                 AddrType:   naming.GRPCLB,
121                                 ServerName: lbsn,
122                         },
123                 }
124         }
125         go func() {
126                 <-r.w.readDone
127         }()
128         return r.w, nil
129 }
130
131 func (r *testNameResolver) inject(updates []*naming.Update) {
132         if r.w != nil {
133                 r.w.inject(updates)
134         }
135 }
136
137 type serverNameCheckCreds struct {
138         mu       sync.Mutex
139         sn       string
140         expected string
141 }
142
143 func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
144         if _, err := io.WriteString(rawConn, c.sn); err != nil {
145                 fmt.Printf("Failed to write the server name %s to the client %v", c.sn, err)
146                 return nil, nil, err
147         }
148         return rawConn, nil, nil
149 }
150 func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
151         c.mu.Lock()
152         defer c.mu.Unlock()
153         b := make([]byte, len(c.expected))
154         errCh := make(chan error, 1)
155         go func() {
156                 _, err := rawConn.Read(b)
157                 errCh <- err
158         }()
159         select {
160         case err := <-errCh:
161                 if err != nil {
162                         fmt.Printf("Failed to read the server name from the server %v", err)
163                         return nil, nil, err
164                 }
165         case <-ctx.Done():
166                 return nil, nil, ctx.Err()
167         }
168         if c.expected != string(b) {
169                 fmt.Printf("Read the server name %s want %s", string(b), c.expected)
170                 return nil, nil, errors.New("received unexpected server name")
171         }
172         return rawConn, nil, nil
173 }
174 func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo {
175         c.mu.Lock()
176         defer c.mu.Unlock()
177         return credentials.ProtocolInfo{}
178 }
179 func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials {
180         c.mu.Lock()
181         defer c.mu.Unlock()
182         return &serverNameCheckCreds{
183                 expected: c.expected,
184         }
185 }
186 func (c *serverNameCheckCreds) OverrideServerName(s string) error {
187         c.mu.Lock()
188         defer c.mu.Unlock()
189         c.expected = s
190         return nil
191 }
192
193 // fakeNameDialer replaces fakeName with localhost when dialing.
194 // This will test that custom dialer is passed from Dial to grpclb.
195 func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) {
196         addr = strings.Replace(addr, fakeName, "localhost", 1)
197         return net.DialTimeout("tcp", addr, timeout)
198 }
199
200 type remoteBalancer struct {
201         sls       []*lbmpb.ServerList
202         intervals []time.Duration
203         statsDura time.Duration
204         done      chan struct{}
205         mu        sync.Mutex
206         stats     lbmpb.ClientStats
207 }
208
209 func newRemoteBalancer(sls []*lbmpb.ServerList, intervals []time.Duration) *remoteBalancer {
210         return &remoteBalancer{
211                 sls:       sls,
212                 intervals: intervals,
213                 done:      make(chan struct{}),
214         }
215 }
216
217 func (b *remoteBalancer) stop() {
218         close(b.done)
219 }
220
221 func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer) error {
222         req, err := stream.Recv()
223         if err != nil {
224                 return err
225         }
226         initReq := req.GetInitialRequest()
227         if initReq.Name != besn {
228                 return grpc.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name)
229         }
230         resp := &lbmpb.LoadBalanceResponse{
231                 LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_InitialResponse{
232                         InitialResponse: &lbmpb.InitialLoadBalanceResponse{
233                                 ClientStatsReportInterval: &lbmpb.Duration{
234                                         Seconds: int64(b.statsDura.Seconds()),
235                                         Nanos:   int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
236                                 },
237                         },
238                 },
239         }
240         if err := stream.Send(resp); err != nil {
241                 return err
242         }
243         go func() {
244                 for {
245                         var (
246                                 req *lbmpb.LoadBalanceRequest
247                                 err error
248                         )
249                         if req, err = stream.Recv(); err != nil {
250                                 return
251                         }
252                         b.mu.Lock()
253                         b.stats.NumCallsStarted += req.GetClientStats().NumCallsStarted
254                         b.stats.NumCallsFinished += req.GetClientStats().NumCallsFinished
255                         b.stats.NumCallsFinishedWithDropForRateLimiting += req.GetClientStats().NumCallsFinishedWithDropForRateLimiting
256                         b.stats.NumCallsFinishedWithDropForLoadBalancing += req.GetClientStats().NumCallsFinishedWithDropForLoadBalancing
257                         b.stats.NumCallsFinishedWithClientFailedToSend += req.GetClientStats().NumCallsFinishedWithClientFailedToSend
258                         b.stats.NumCallsFinishedKnownReceived += req.GetClientStats().NumCallsFinishedKnownReceived
259                         b.mu.Unlock()
260                 }
261         }()
262         for k, v := range b.sls {
263                 time.Sleep(b.intervals[k])
264                 resp = &lbmpb.LoadBalanceResponse{
265                         LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_ServerList{
266                                 ServerList: v,
267                         },
268                 }
269                 if err := stream.Send(resp); err != nil {
270                         return err
271                 }
272         }
273         <-b.done
274         return nil
275 }
276
277 type testServer struct {
278         testpb.TestServiceServer
279
280         addr string
281 }
282
283 const testmdkey = "testmd"
284
285 func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
286         md, ok := metadata.FromIncomingContext(ctx)
287         if !ok {
288                 return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")
289         }
290         if md == nil || md["lb-token"][0] != lbToken {
291                 return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md)
292         }
293         grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
294         return &testpb.Empty{}, nil
295 }
296
297 func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
298         return nil
299 }
300
301 func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) {
302         for _, l := range lis {
303                 creds := &serverNameCheckCreds{
304                         sn: sn,
305                 }
306                 s := grpc.NewServer(grpc.Creds(creds))
307                 testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String()})
308                 servers = append(servers, s)
309                 go func(s *grpc.Server, l net.Listener) {
310                         s.Serve(l)
311                 }(s, l)
312         }
313         return
314 }
315
316 func stopBackends(servers []*grpc.Server) {
317         for _, s := range servers {
318                 s.Stop()
319         }
320 }
321
322 type testServers struct {
323         lbAddr  string
324         ls      *remoteBalancer
325         lb      *grpc.Server
326         beIPs   []net.IP
327         bePorts []int
328 }
329
330 func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), err error) {
331         var (
332                 beListeners []net.Listener
333                 ls          *remoteBalancer
334                 lb          *grpc.Server
335                 beIPs       []net.IP
336                 bePorts     []int
337         )
338         for i := 0; i < numberOfBackends; i++ {
339                 // Start a backend.
340                 beLis, e := net.Listen("tcp", "localhost:0")
341                 if e != nil {
342                         err = fmt.Errorf("Failed to listen %v", err)
343                         return
344                 }
345                 beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP)
346                 bePorts = append(bePorts, beLis.Addr().(*net.TCPAddr).Port)
347
348                 beListeners = append(beListeners, beLis)
349         }
350         backends := startBackends(besn, beListeners...)
351
352         // Start a load balancer.
353         lbLis, err := net.Listen("tcp", "localhost:0")
354         if err != nil {
355                 err = fmt.Errorf("Failed to create the listener for the load balancer %v", err)
356                 return
357         }
358         lbCreds := &serverNameCheckCreds{
359                 sn: lbsn,
360         }
361         lb = grpc.NewServer(grpc.Creds(lbCreds))
362         if err != nil {
363                 err = fmt.Errorf("Failed to generate the port number %v", err)
364                 return
365         }
366         ls = newRemoteBalancer(nil, nil)
367         lbspb.RegisterLoadBalancerServer(lb, ls)
368         go func() {
369                 lb.Serve(lbLis)
370         }()
371
372         tss = &testServers{
373                 lbAddr:  lbLis.Addr().String(),
374                 ls:      ls,
375                 lb:      lb,
376                 beIPs:   beIPs,
377                 bePorts: bePorts,
378         }
379         cleanup = func() {
380                 defer stopBackends(backends)
381                 defer func() {
382                         ls.stop()
383                         lb.Stop()
384                 }()
385         }
386         return
387 }
388
389 func TestGRPCLB(t *testing.T) {
390         defer leakcheck.Check(t)
391         tss, cleanup, err := newLoadBalancer(1)
392         if err != nil {
393                 t.Fatalf("failed to create new load balancer: %v", err)
394         }
395         defer cleanup()
396
397         be := &lbmpb.Server{
398                 IpAddress:        tss.beIPs[0],
399                 Port:             int32(tss.bePorts[0]),
400                 LoadBalanceToken: lbToken,
401         }
402         var bes []*lbmpb.Server
403         bes = append(bes, be)
404         sl := &lbmpb.ServerList{
405                 Servers: bes,
406         }
407         tss.ls.sls = []*lbmpb.ServerList{sl}
408         tss.ls.intervals = []time.Duration{0}
409         creds := serverNameCheckCreds{
410                 expected: besn,
411         }
412         ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
413         defer cancel()
414         cc, err := grpc.DialContext(ctx, besn,
415                 grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
416                 grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
417         if err != nil {
418                 t.Fatalf("Failed to dial to the backend %v", err)
419         }
420         defer cc.Close()
421         testC := testpb.NewTestServiceClient(cc)
422         if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
423                 t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
424         }
425 }
426
427 func TestDropRequest(t *testing.T) {
428         defer leakcheck.Check(t)
429         tss, cleanup, err := newLoadBalancer(2)
430         if err != nil {
431                 t.Fatalf("failed to create new load balancer: %v", err)
432         }
433         defer cleanup()
434         tss.ls.sls = []*lbmpb.ServerList{{
435                 Servers: []*lbmpb.Server{{
436                         IpAddress:            tss.beIPs[0],
437                         Port:                 int32(tss.bePorts[0]),
438                         LoadBalanceToken:     lbToken,
439                         DropForLoadBalancing: true,
440                 }, {
441                         IpAddress:            tss.beIPs[1],
442                         Port:                 int32(tss.bePorts[1]),
443                         LoadBalanceToken:     lbToken,
444                         DropForLoadBalancing: false,
445                 }},
446         }}
447         tss.ls.intervals = []time.Duration{0}
448         creds := serverNameCheckCreds{
449                 expected: besn,
450         }
451         ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
452         defer cancel()
453         cc, err := grpc.DialContext(ctx, besn,
454                 grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
455                 grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
456         if err != nil {
457                 t.Fatalf("Failed to dial to the backend %v", err)
458         }
459         defer cc.Close()
460         testC := testpb.NewTestServiceClient(cc)
461         // Wait until the first connection is up.
462         // The first one has Drop set to true, error should contain "drop requests".
463         for {
464                 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
465                         if strings.Contains(err.Error(), "drops requests") {
466                                 break
467                         }
468                 }
469         }
470         // The 1st, non-fail-fast RPC should succeed.  This ensures both server
471         // connections are made, because the first one has DropForLoadBalancing set to true.
472         if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
473                 t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
474         }
475         for i := 0; i < 3; i++ {
476                 // Odd fail-fast RPCs should fail, because the 1st backend has DropForLoadBalancing
477                 // set to true.
478                 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
479                         t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
480                 }
481                 // Even fail-fast RPCs should succeed since they choose the
482                 // non-drop-request backend according to the round robin policy.
483                 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
484                         t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
485                 }
486         }
487 }
488
489 func TestDropRequestFailedNonFailFast(t *testing.T) {
490         defer leakcheck.Check(t)
491         tss, cleanup, err := newLoadBalancer(1)
492         if err != nil {
493                 t.Fatalf("failed to create new load balancer: %v", err)
494         }
495         defer cleanup()
496         be := &lbmpb.Server{
497                 IpAddress:            tss.beIPs[0],
498                 Port:                 int32(tss.bePorts[0]),
499                 LoadBalanceToken:     lbToken,
500                 DropForLoadBalancing: true,
501         }
502         var bes []*lbmpb.Server
503         bes = append(bes, be)
504         sl := &lbmpb.ServerList{
505                 Servers: bes,
506         }
507         tss.ls.sls = []*lbmpb.ServerList{sl}
508         tss.ls.intervals = []time.Duration{0}
509         creds := serverNameCheckCreds{
510                 expected: besn,
511         }
512         ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
513         defer cancel()
514         cc, err := grpc.DialContext(ctx, besn,
515                 grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
516                 grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
517         if err != nil {
518                 t.Fatalf("Failed to dial to the backend %v", err)
519         }
520         defer cc.Close()
521         testC := testpb.NewTestServiceClient(cc)
522         ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond)
523         defer cancel()
524         if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
525                 t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded)
526         }
527 }
528
529 // When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list.
530 func TestBalancerDisconnects(t *testing.T) {
531         defer leakcheck.Check(t)
532         var (
533                 lbAddrs []string
534                 lbs     []*grpc.Server
535         )
536         for i := 0; i < 3; i++ {
537                 tss, cleanup, err := newLoadBalancer(1)
538                 if err != nil {
539                         t.Fatalf("failed to create new load balancer: %v", err)
540                 }
541                 defer cleanup()
542
543                 be := &lbmpb.Server{
544                         IpAddress:        tss.beIPs[0],
545                         Port:             int32(tss.bePorts[0]),
546                         LoadBalanceToken: lbToken,
547                 }
548                 var bes []*lbmpb.Server
549                 bes = append(bes, be)
550                 sl := &lbmpb.ServerList{
551                         Servers: bes,
552                 }
553                 tss.ls.sls = []*lbmpb.ServerList{sl}
554                 tss.ls.intervals = []time.Duration{0}
555
556                 lbAddrs = append(lbAddrs, tss.lbAddr)
557                 lbs = append(lbs, tss.lb)
558         }
559
560         creds := serverNameCheckCreds{
561                 expected: besn,
562         }
563         ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
564         defer cancel()
565         resolver := &testNameResolver{
566                 addrs: lbAddrs[:2],
567         }
568         cc, err := grpc.DialContext(ctx, besn,
569                 grpc.WithBalancer(grpc.NewGRPCLBBalancer(resolver)),
570                 grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
571         if err != nil {
572                 t.Fatalf("Failed to dial to the backend %v", err)
573         }
574         defer cc.Close()
575         testC := testpb.NewTestServiceClient(cc)
576         var previousTrailer string
577         trailer := metadata.MD{}
578         if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
579                 t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
580         } else {
581                 previousTrailer = trailer[testmdkey][0]
582         }
583         // The initial resolver update contains lbs[0] and lbs[1].
584         // When lbs[0] is stopped, lbs[1] should be used.
585         lbs[0].Stop()
586         for {
587                 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
588                         t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
589                 } else if trailer[testmdkey][0] != previousTrailer {
590                         // A new backend server should receive the request.
591                         // The trailer contains the backend address, so the trailer should be different from the previous one.
592                         previousTrailer = trailer[testmdkey][0]
593                         break
594                 }
595                 time.Sleep(100 * time.Millisecond)
596         }
597         // Inject a update to add lbs[2] to resolved addresses.
598         resolver.inject([]*naming.Update{
599                 {Op: naming.Add,
600                         Addr: lbAddrs[2],
601                         Metadata: &naming.AddrMetadataGRPCLB{
602                                 AddrType:   naming.GRPCLB,
603                                 ServerName: lbsn,
604                         },
605                 },
606         })
607         // Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used.
608         lbs[1].Stop()
609         for {
610                 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
611                         t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
612                 } else if trailer[testmdkey][0] != previousTrailer {
613                         // A new backend server should receive the request.
614                         // The trailer contains the backend address, so the trailer should be different from the previous one.
615                         break
616                 }
617                 time.Sleep(100 * time.Millisecond)
618         }
619 }
620
621 type failPreRPCCred struct{}
622
623 func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
624         if strings.Contains(uri[0], "failtosend") {
625                 return nil, fmt.Errorf("rpc should fail to send")
626         }
627         return nil, nil
628 }
629
630 func (failPreRPCCred) RequireTransportSecurity() bool {
631         return false
632 }
633
634 func checkStats(stats *lbmpb.ClientStats, expected *lbmpb.ClientStats) error {
635         if !proto.Equal(stats, expected) {
636                 return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
637         }
638         return nil
639 }
640
641 func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbmpb.ClientStats {
642         tss, cleanup, err := newLoadBalancer(3)
643         if err != nil {
644                 t.Fatalf("failed to create new load balancer: %v", err)
645         }
646         defer cleanup()
647         tss.ls.sls = []*lbmpb.ServerList{{
648                 Servers: []*lbmpb.Server{{
649                         IpAddress:            tss.beIPs[2],
650                         Port:                 int32(tss.bePorts[2]),
651                         LoadBalanceToken:     lbToken,
652                         DropForLoadBalancing: dropForLoadBalancing,
653                         DropForRateLimiting:  dropForRateLimiting,
654                 }},
655         }}
656         tss.ls.intervals = []time.Duration{0}
657         tss.ls.statsDura = 100 * time.Millisecond
658         creds := serverNameCheckCreds{expected: besn}
659
660         ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
661         defer cancel()
662         cc, err := grpc.DialContext(ctx, besn,
663                 grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})),
664                 grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}),
665                 grpc.WithBlock(), grpc.WithDialer(fakeNameDialer))
666         if err != nil {
667                 t.Fatalf("Failed to dial to the backend %v", err)
668         }
669         defer cc.Close()
670
671         runRPCs(cc)
672         time.Sleep(1 * time.Second)
673         tss.ls.mu.Lock()
674         stats := tss.ls.stats
675         tss.ls.mu.Unlock()
676         return stats
677 }
678
679 const countRPC = 40
680
681 func TestGRPCLBStatsUnarySuccess(t *testing.T) {
682         defer leakcheck.Check(t)
683         stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
684                 testC := testpb.NewTestServiceClient(cc)
685                 // The first non-failfast RPC succeeds, all connections are up.
686                 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
687                         t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
688                 }
689                 for i := 0; i < countRPC-1; i++ {
690                         testC.EmptyCall(context.Background(), &testpb.Empty{})
691                 }
692         })
693
694         if err := checkStats(&stats, &lbmpb.ClientStats{
695                 NumCallsStarted:               int64(countRPC),
696                 NumCallsFinished:              int64(countRPC),
697                 NumCallsFinishedKnownReceived: int64(countRPC),
698         }); err != nil {
699                 t.Fatal(err)
700         }
701 }
702
703 func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) {
704         defer leakcheck.Check(t)
705         c := 0
706         stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) {
707                 testC := testpb.NewTestServiceClient(cc)
708                 for {
709                         c++
710                         if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
711                                 if strings.Contains(err.Error(), "drops requests") {
712                                         break
713                                 }
714                         }
715                 }
716                 for i := 0; i < countRPC; i++ {
717                         testC.EmptyCall(context.Background(), &testpb.Empty{})
718                 }
719         })
720
721         if err := checkStats(&stats, &lbmpb.ClientStats{
722                 NumCallsStarted:                          int64(countRPC + c),
723                 NumCallsFinished:                         int64(countRPC + c),
724                 NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1),
725                 NumCallsFinishedWithClientFailedToSend:   int64(c - 1),
726         }); err != nil {
727                 t.Fatal(err)
728         }
729 }
730
731 func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) {
732         defer leakcheck.Check(t)
733         c := 0
734         stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) {
735                 testC := testpb.NewTestServiceClient(cc)
736                 for {
737                         c++
738                         if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
739                                 if strings.Contains(err.Error(), "drops requests") {
740                                         break
741                                 }
742                         }
743                 }
744                 for i := 0; i < countRPC; i++ {
745                         testC.EmptyCall(context.Background(), &testpb.Empty{})
746                 }
747         })
748
749         if err := checkStats(&stats, &lbmpb.ClientStats{
750                 NumCallsStarted:                         int64(countRPC + c),
751                 NumCallsFinished:                        int64(countRPC + c),
752                 NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1),
753                 NumCallsFinishedWithClientFailedToSend:  int64(c - 1),
754         }); err != nil {
755                 t.Fatal(err)
756         }
757 }
758
759 func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
760         defer leakcheck.Check(t)
761         stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
762                 testC := testpb.NewTestServiceClient(cc)
763                 // The first non-failfast RPC succeeds, all connections are up.
764                 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
765                         t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
766                 }
767                 for i := 0; i < countRPC-1; i++ {
768                         grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc)
769                 }
770         })
771
772         if err := checkStats(&stats, &lbmpb.ClientStats{
773                 NumCallsStarted:                        int64(countRPC),
774                 NumCallsFinished:                       int64(countRPC),
775                 NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
776                 NumCallsFinishedKnownReceived:          1,
777         }); err != nil {
778                 t.Fatal(err)
779         }
780 }
781
782 func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
783         defer leakcheck.Check(t)
784         stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
785                 testC := testpb.NewTestServiceClient(cc)
786                 // The first non-failfast RPC succeeds, all connections are up.
787                 stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
788                 if err != nil {
789                         t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
790                 }
791                 for {
792                         if _, err = stream.Recv(); err == io.EOF {
793                                 break
794                         }
795                 }
796                 for i := 0; i < countRPC-1; i++ {
797                         stream, err = testC.FullDuplexCall(context.Background())
798                         if err == nil {
799                                 // Wait for stream to end if err is nil.
800                                 for {
801                                         if _, err = stream.Recv(); err == io.EOF {
802                                                 break
803                                         }
804                                 }
805                         }
806                 }
807         })
808
809         if err := checkStats(&stats, &lbmpb.ClientStats{
810                 NumCallsStarted:               int64(countRPC),
811                 NumCallsFinished:              int64(countRPC),
812                 NumCallsFinishedKnownReceived: int64(countRPC),
813         }); err != nil {
814                 t.Fatal(err)
815         }
816 }
817
818 func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) {
819         defer leakcheck.Check(t)
820         c := 0
821         stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) {
822                 testC := testpb.NewTestServiceClient(cc)
823                 for {
824                         c++
825                         if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
826                                 if strings.Contains(err.Error(), "drops requests") {
827                                         break
828                                 }
829                         }
830                 }
831                 for i := 0; i < countRPC; i++ {
832                         testC.FullDuplexCall(context.Background())
833                 }
834         })
835
836         if err := checkStats(&stats, &lbmpb.ClientStats{
837                 NumCallsStarted:                          int64(countRPC + c),
838                 NumCallsFinished:                         int64(countRPC + c),
839                 NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1),
840                 NumCallsFinishedWithClientFailedToSend:   int64(c - 1),
841         }); err != nil {
842                 t.Fatal(err)
843         }
844 }
845
846 func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) {
847         defer leakcheck.Check(t)
848         c := 0
849         stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) {
850                 testC := testpb.NewTestServiceClient(cc)
851                 for {
852                         c++
853                         if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
854                                 if strings.Contains(err.Error(), "drops requests") {
855                                         break
856                                 }
857                         }
858                 }
859                 for i := 0; i < countRPC; i++ {
860                         testC.FullDuplexCall(context.Background())
861                 }
862         })
863
864         if err := checkStats(&stats, &lbmpb.ClientStats{
865                 NumCallsStarted:                         int64(countRPC + c),
866                 NumCallsFinished:                        int64(countRPC + c),
867                 NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1),
868                 NumCallsFinishedWithClientFailedToSend:  int64(c - 1),
869         }); err != nil {
870                 t.Fatal(err)
871         }
872 }
873
874 func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
875         defer leakcheck.Check(t)
876         stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
877                 testC := testpb.NewTestServiceClient(cc)
878                 // The first non-failfast RPC succeeds, all connections are up.
879                 stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
880                 if err != nil {
881                         t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
882                 }
883                 for {
884                         if _, err = stream.Recv(); err == io.EOF {
885                                 break
886                         }
887                 }
888                 for i := 0; i < countRPC-1; i++ {
889                         grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend")
890                 }
891         })
892
893         if err := checkStats(&stats, &lbmpb.ClientStats{
894                 NumCallsStarted:                        int64(countRPC),
895                 NumCallsFinished:                       int64(countRPC),
896                 NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
897                 NumCallsFinishedKnownReceived:          1,
898         }); err != nil {
899                 t.Fatal(err)
900         }
901 }