OSDN Git Service

new repo
[bytom/vapor.git] / vendor / google.golang.org / grpc / stress / client / main.go
1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 //go:generate protoc -I ../grpc_testing --go_out=plugins=grpc:../grpc_testing ../grpc_testing/metrics.proto
20
21 // client starts an interop client to do stress test and a metrics server to report qps.
22 package main
23
24 import (
25         "flag"
26         "fmt"
27         "math/rand"
28         "net"
29         "strconv"
30         "strings"
31         "sync"
32         "time"
33
34         "golang.org/x/net/context"
35         "google.golang.org/grpc"
36         "google.golang.org/grpc/codes"
37         "google.golang.org/grpc/credentials"
38         "google.golang.org/grpc/grpclog"
39         "google.golang.org/grpc/interop"
40         testpb "google.golang.org/grpc/interop/grpc_testing"
41         metricspb "google.golang.org/grpc/stress/grpc_testing"
42         "google.golang.org/grpc/testdata"
43 )
44
// Command-line flags for the stress client.
var (
	serverAddresses      = flag.String("server_addresses", "localhost:8080", "a comma-separated list of server addresses")
	testCases            = flag.String("test_cases", "", "a comma-separated list of test_case:weight pairs, e.g. empty_unary:20,large_unary:10")
	testDurationSecs     = flag.Int("test_duration_secs", -1, "test duration in seconds; if not positive, the test runs until the process is killed")
	numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
	numStubsPerChannel   = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
	metricsPort          = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
	useTLS               = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
	testCA               = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
	tlsServerName        = flag.String("server_host_override", "foo.test.google.fr", "The server name used to verify the hostname returned by TLS handshake if it is not empty. Otherwise, the host from the server address is used.")
	caFile               = flag.String("ca_file", "", "The file containing the CA root cert; used with use_test_ca, defaults to the bundled test ca.pem")
)
57
// testCaseWithWeight pairs an interop test case name with its relative
// selection weight.
type testCaseWithWeight struct {
	name   string
	weight int
}

// parseTestCases converts a comma-separated list of "name:weight" pairs
// (e.g. "empty_unary:20,large_unary:10") into a slice of testCaseWithWeight,
// preserving input order.
//
// Whitespace around each name and weight is ignored. It panics if an entry is
// not exactly one "name:weight" pair, names an unsupported test case, or has a
// weight that is not a non-negative integer.
func parseTestCases(testCaseString string) []testCaseWithWeight {
	entries := strings.Split(testCaseString, ",")
	parsed := make([]testCaseWithWeight, len(entries))
	for i, str := range entries {
		parts := strings.Split(str, ":")
		if len(parts) != 2 {
			panic(fmt.Sprintf("invalid test case with weight: %s", str))
		}
		name := strings.TrimSpace(parts[0])
		// Check if test case is supported.
		switch name {
		case
			"empty_unary",
			"large_unary",
			"client_streaming",
			"server_streaming",
			"ping_pong",
			"empty_stream",
			"timeout_on_sleeping_server",
			"cancel_after_begin",
			"cancel_after_first_response",
			"status_code_and_message",
			"custom_metadata":
		default:
			panic(fmt.Sprintf("unknown test type: %s", name))
		}
		w, err := strconv.Atoi(strings.TrimSpace(parts[1]))
		if err != nil {
			panic(fmt.Sprintf("invalid weight in test case %q: %v", str, err))
		}
		// A negative weight would silently distort the weighted random
		// selection, so reject it here with a clear message.
		if w < 0 {
			panic(fmt.Sprintf("negative weight in test case %q", str))
		}
		parsed[i] = testCaseWithWeight{name: name, weight: w}
	}
	return parsed
}
99
100 // weightedRandomTestSelector defines a weighted random selector for test case types.
101 type weightedRandomTestSelector struct {
102         tests       []testCaseWithWeight
103         totalWeight int
104 }
105
106 // newWeightedRandomTestSelector constructs a weightedRandomTestSelector with the given list of testCaseWithWeight.
107 func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector {
108         var totalWeight int
109         for _, t := range tests {
110                 totalWeight += t.weight
111         }
112         rand.Seed(time.Now().UnixNano())
113         return &weightedRandomTestSelector{tests, totalWeight}
114 }
115
116 func (selector weightedRandomTestSelector) getNextTest() string {
117         random := rand.Intn(selector.totalWeight)
118         var weightSofar int
119         for _, test := range selector.tests {
120                 weightSofar += test.weight
121                 if random < weightSofar {
122                         return test.name
123                 }
124         }
125         panic("no test case selected by weightedRandomTestSelector")
126 }
127
// gauge records the latest QPS value reported by one interop client stub.
// A worker goroutine writes it through set while metrics RPC handlers read
// it through get; the RWMutex guards val between them.
type gauge struct {
	mutex sync.RWMutex
	val   int64
}

// set overwrites the gauge's current value with v.
func (g *gauge) set(v int64) {
	g.mutex.Lock()
	g.val = v
	g.mutex.Unlock()
}

// get reports the gauge's current value.
func (g *gauge) get() int64 {
	g.mutex.RLock()
	current := g.val
	g.mutex.RUnlock()
	return current
}
145
146 // server implements metrics server functions.
147 type server struct {
148         mutex sync.RWMutex
149         // gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge.
150         gauges map[string]*gauge
151 }
152
153 // newMetricsServer returns a new metrics server.
154 func newMetricsServer() *server {
155         return &server{gauges: make(map[string]*gauge)}
156 }
157
158 // GetAllGauges returns all gauges.
159 func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error {
160         s.mutex.RLock()
161         defer s.mutex.RUnlock()
162
163         for name, gauge := range s.gauges {
164                 if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
165                         return err
166                 }
167         }
168         return nil
169 }
170
171 // GetGauge returns the gauge for the given name.
172 func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) {
173         s.mutex.RLock()
174         defer s.mutex.RUnlock()
175
176         if g, ok := s.gauges[in.Name]; ok {
177                 return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
178         }
179         return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
180 }
181
182 // createGauge creates a gauge using the given name in metrics server.
183 func (s *server) createGauge(name string) *gauge {
184         s.mutex.Lock()
185         defer s.mutex.Unlock()
186
187         if _, ok := s.gauges[name]; ok {
188                 // gauge already exists.
189                 panic(fmt.Sprintf("gauge %s already exists", name))
190         }
191         var g gauge
192         s.gauges[name] = &g
193         return &g
194 }
195
196 func startServer(server *server, port int) {
197         lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
198         if err != nil {
199                 grpclog.Fatalf("failed to listen: %v", err)
200         }
201
202         s := grpc.NewServer()
203         metricspb.RegisterMetricsServiceServer(s, server)
204         s.Serve(lis)
205
206 }
207
208 // performRPCs uses weightedRandomTestSelector to select test case and runs the tests.
209 func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
210         client := testpb.NewTestServiceClient(conn)
211         var numCalls int64
212         startTime := time.Now()
213         for {
214                 test := selector.getNextTest()
215                 switch test {
216                 case "empty_unary":
217                         interop.DoEmptyUnaryCall(client, grpc.FailFast(false))
218                 case "large_unary":
219                         interop.DoLargeUnaryCall(client, grpc.FailFast(false))
220                 case "client_streaming":
221                         interop.DoClientStreaming(client, grpc.FailFast(false))
222                 case "server_streaming":
223                         interop.DoServerStreaming(client, grpc.FailFast(false))
224                 case "ping_pong":
225                         interop.DoPingPong(client, grpc.FailFast(false))
226                 case "empty_stream":
227                         interop.DoEmptyStream(client, grpc.FailFast(false))
228                 case "timeout_on_sleeping_server":
229                         interop.DoTimeoutOnSleepingServer(client, grpc.FailFast(false))
230                 case "cancel_after_begin":
231                         interop.DoCancelAfterBegin(client, grpc.FailFast(false))
232                 case "cancel_after_first_response":
233                         interop.DoCancelAfterFirstResponse(client, grpc.FailFast(false))
234                 case "status_code_and_message":
235                         interop.DoStatusCodeAndMessage(client, grpc.FailFast(false))
236                 case "custom_metadata":
237                         interop.DoCustomMetadata(client, grpc.FailFast(false))
238                 }
239                 numCalls++
240                 gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))
241
242                 select {
243                 case <-stop:
244                         return
245                 default:
246                 }
247         }
248 }
249
250 func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
251         grpclog.Printf("server_addresses: %s", *serverAddresses)
252         grpclog.Printf("test_cases: %s", *testCases)
253         grpclog.Printf("test_duration_secs: %d", *testDurationSecs)
254         grpclog.Printf("num_channels_per_server: %d", *numChannelsPerServer)
255         grpclog.Printf("num_stubs_per_channel: %d", *numStubsPerChannel)
256         grpclog.Printf("metrics_port: %d", *metricsPort)
257         grpclog.Printf("use_tls: %t", *useTLS)
258         grpclog.Printf("use_test_ca: %t", *testCA)
259         grpclog.Printf("server_host_override: %s", *tlsServerName)
260
261         grpclog.Println("addresses:")
262         for i, addr := range addresses {
263                 grpclog.Printf("%d. %s\n", i+1, addr)
264         }
265         grpclog.Println("tests:")
266         for i, test := range tests {
267                 grpclog.Printf("%d. %v\n", i+1, test)
268         }
269 }
270
271 func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
272         var opts []grpc.DialOption
273         if useTLS {
274                 var sn string
275                 if tlsServerName != "" {
276                         sn = tlsServerName
277                 }
278                 var creds credentials.TransportCredentials
279                 if testCA {
280                         var err error
281                         if *caFile == "" {
282                                 *caFile = testdata.Path("ca.pem")
283                         }
284                         creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
285                         if err != nil {
286                                 grpclog.Fatalf("Failed to create TLS credentials %v", err)
287                         }
288                 } else {
289                         creds = credentials.NewClientTLSFromCert(nil, sn)
290                 }
291                 opts = append(opts, grpc.WithTransportCredentials(creds))
292         } else {
293                 opts = append(opts, grpc.WithInsecure())
294         }
295         return grpc.Dial(address, opts...)
296 }
297
298 func main() {
299         flag.Parse()
300         addresses := strings.Split(*serverAddresses, ",")
301         tests := parseTestCases(*testCases)
302         logParameterInfo(addresses, tests)
303         testSelector := newWeightedRandomTestSelector(tests)
304         metricsServer := newMetricsServer()
305
306         var wg sync.WaitGroup
307         wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel)
308         stop := make(chan bool)
309
310         for serverIndex, address := range addresses {
311                 for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
312                         conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
313                         if err != nil {
314                                 grpclog.Fatalf("Fail to dial: %v", err)
315                         }
316                         defer conn.Close()
317                         for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ {
318                                 name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1)
319                                 go func() {
320                                         defer wg.Done()
321                                         g := metricsServer.createGauge(name)
322                                         performRPCs(g, conn, testSelector, stop)
323                                 }()
324                         }
325
326                 }
327         }
328         go startServer(metricsServer, *metricsPort)
329         if *testDurationSecs > 0 {
330                 time.Sleep(time.Duration(*testDurationSecs) * time.Second)
331                 close(stop)
332         }
333         wg.Wait()
334         grpclog.Printf(" ===== ALL DONE ===== ")
335
336 }