/*
 *
 * Copyright 2016 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
//go:generate protoc -I ../grpc_testing --go_out=plugins=grpc:../grpc_testing ../grpc_testing/metrics.proto

// client starts interop clients to run stress tests against one or more
// servers, and a metrics server that reports the QPS of each client stub.
package main
import (
	"flag"
	"fmt"
	"math/rand"
	"net"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/interop"
	testpb "google.golang.org/grpc/interop/grpc_testing"
	metricspb "google.golang.org/grpc/stress/grpc_testing"
	"google.golang.org/grpc/testdata"
)
// Command-line flags controlling the stress test.
var (
	serverAddresses      = flag.String("server_addresses", "localhost:8080", "a comma-separated list of server addresses")
	testCases            = flag.String("test_cases", "", "a comma-separated list of test_case:weight pairs, e.g. empty_unary:1,large_unary:2")
	testDurationSecs     = flag.Int("test_duration_secs", -1, "test duration in seconds")
	numChannelsPerServer = flag.Int("num_channels_per_server", 1, "Number of channels (i.e connections) to each server")
	numStubsPerChannel   = flag.Int("num_stubs_per_channel", 1, "Number of client stubs per each connection to server")
	metricsPort          = flag.Int("metrics_port", 8081, "The port at which the stress client exposes QPS metrics")
	useTLS               = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
	testCA               = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
	tlsServerName        = flag.String("server_host_override", "foo.test.google.fr", "The server name used to verify the hostname returned by the TLS handshake if it is not empty. Otherwise, the host from --server_addresses is used.")
	caFile               = flag.String("ca_file", "", "The file containing the CA root cert file")
)
// testCaseWithWeight pairs an interop test case name with its relative
// selection weight.
type testCaseWithWeight struct {
	name   string
	weight int
}

// isSupportedTestCase reports whether name is an interop test case this
// client knows how to run.
func isSupportedTestCase(name string) bool {
	switch name {
	case "empty_unary",
		"large_unary",
		"client_streaming",
		"server_streaming",
		"ping_pong",
		"empty_stream",
		"timeout_on_sleeping_server",
		"cancel_after_begin",
		"cancel_after_first_response",
		"status_code_and_message",
		"custom_metadata":
		return true
	}
	return false
}

// parseTestCases converts a comma-separated list of "name:weight" pairs
// (e.g. "empty_unary:1,large_unary:2") into a slice of testCaseWithWeight,
// in input order. Whitespace around names and weights is ignored.
//
// It panics if an entry is not exactly one "name:weight" pair, names an
// unsupported test case, or has a weight that is not a non-negative integer.
// An empty input string is a single malformed entry and therefore panics.
func parseTestCases(testCaseString string) []testCaseWithWeight {
	testCaseStrings := strings.Split(testCaseString, ",")
	testCases := make([]testCaseWithWeight, len(testCaseStrings))
	for i, str := range testCaseStrings {
		testCase := strings.Split(str, ":")
		if len(testCase) != 2 {
			panic(fmt.Sprintf("invalid test case with weight: %s", str))
		}
		name := strings.TrimSpace(testCase[0])
		if !isSupportedTestCase(name) {
			panic(fmt.Sprintf("unknown test type: %s", name))
		}
		w, err := strconv.Atoi(strings.TrimSpace(testCase[1]))
		if err != nil {
			panic(fmt.Sprintf("invalid weight in test case %q: %v", str, err))
		}
		// A negative weight would shrink the total and skew (or break)
		// weighted selection, so reject it up front.
		if w < 0 {
			panic(fmt.Sprintf("negative weight in test case %q", str))
		}
		testCases[i] = testCaseWithWeight{name: name, weight: w}
	}
	return testCases
}

// weightedRandomTestSelector picks test case names at random, each with
// probability proportional to its weight.
type weightedRandomTestSelector struct {
	tests       []testCaseWithWeight
	totalWeight int
}

// newWeightedRandomTestSelector constructs a weightedRandomTestSelector with
// the given list of testCaseWithWeight and seeds math/rand's global source.
//
// It panics if the weights sum to zero or less. In that case no test could
// ever be chosen, and getNextTest would otherwise panic inside rand.Intn
// with a much less helpful message.
func newWeightedRandomTestSelector(tests []testCaseWithWeight) *weightedRandomTestSelector {
	totalWeight := 0
	for _, t := range tests {
		totalWeight += t.weight
	}
	if totalWeight <= 0 {
		panic("weightedRandomTestSelector: total weight of test cases must be positive")
	}
	rand.Seed(time.Now().UnixNano())
	return &weightedRandomTestSelector{tests, totalWeight}
}

// getNextTest returns the name of a randomly chosen test case. Test cases
// with weight 0 are never returned. It uses math/rand's package-level
// functions, which are safe to call from multiple goroutines.
func (selector weightedRandomTestSelector) getNextTest() string {
	random := rand.Intn(selector.totalWeight)
	weightSofar := 0
	for _, test := range selector.tests {
		weightSofar += test.weight
		if random < weightSofar {
			return test.name
		}
	}
	panic("no test case selected by weightedRandomTestSelector")
}
// gauge stores the qps of one interop client (one stub). Access is guarded
// by an RWMutex: set is called from the stub's RPC loop while get is called
// from the metrics service handlers.
type gauge struct {
	mutex sync.RWMutex
	val   int64
}

// set records v as the gauge's current value.
func (g *gauge) set(v int64) {
	g.mutex.Lock()
	defer g.mutex.Unlock()
	g.val = v
}

// get returns the most recently set value, or 0 if set was never called.
func (g *gauge) get() int64 {
	g.mutex.RLock()
	defer g.mutex.RUnlock()
	return g.val
}
146 // server implements metrics server functions.
149 // gauges is a map from /stress_test/server_<n>/channel_<n>/stub_<n>/qps to its qps gauge.
150 gauges map[string]*gauge
153 // newMetricsServer returns a new metrics server.
154 func newMetricsServer() *server {
155 return &server{gauges: make(map[string]*gauge)}
158 // GetAllGauges returns all gauges.
159 func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.MetricsService_GetAllGaugesServer) error {
161 defer s.mutex.RUnlock()
163 for name, gauge := range s.gauges {
164 if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
171 // GetGauge returns the gauge for the given name.
172 func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*metricspb.GaugeResponse, error) {
174 defer s.mutex.RUnlock()
176 if g, ok := s.gauges[in.Name]; ok {
177 return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
179 return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
182 // createGauge creates a gauge using the given name in metrics server.
183 func (s *server) createGauge(name string) *gauge {
185 defer s.mutex.Unlock()
187 if _, ok := s.gauges[name]; ok {
188 // gauge already exists.
189 panic(fmt.Sprintf("gauge %s already exists", name))
196 func startServer(server *server, port int) {
197 lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
199 grpclog.Fatalf("failed to listen: %v", err)
202 s := grpc.NewServer()
203 metricspb.RegisterMetricsServiceServer(s, server)
208 // performRPCs uses weightedRandomTestSelector to select test case and runs the tests.
209 func performRPCs(gauge *gauge, conn *grpc.ClientConn, selector *weightedRandomTestSelector, stop <-chan bool) {
210 client := testpb.NewTestServiceClient(conn)
212 startTime := time.Now()
214 test := selector.getNextTest()
217 interop.DoEmptyUnaryCall(client, grpc.FailFast(false))
219 interop.DoLargeUnaryCall(client, grpc.FailFast(false))
220 case "client_streaming":
221 interop.DoClientStreaming(client, grpc.FailFast(false))
222 case "server_streaming":
223 interop.DoServerStreaming(client, grpc.FailFast(false))
225 interop.DoPingPong(client, grpc.FailFast(false))
227 interop.DoEmptyStream(client, grpc.FailFast(false))
228 case "timeout_on_sleeping_server":
229 interop.DoTimeoutOnSleepingServer(client, grpc.FailFast(false))
230 case "cancel_after_begin":
231 interop.DoCancelAfterBegin(client, grpc.FailFast(false))
232 case "cancel_after_first_response":
233 interop.DoCancelAfterFirstResponse(client, grpc.FailFast(false))
234 case "status_code_and_message":
235 interop.DoStatusCodeAndMessage(client, grpc.FailFast(false))
236 case "custom_metadata":
237 interop.DoCustomMetadata(client, grpc.FailFast(false))
240 gauge.set(int64(float64(numCalls) / time.Since(startTime).Seconds()))
250 func logParameterInfo(addresses []string, tests []testCaseWithWeight) {
251 grpclog.Printf("server_addresses: %s", *serverAddresses)
252 grpclog.Printf("test_cases: %s", *testCases)
253 grpclog.Printf("test_duration_secs: %d", *testDurationSecs)
254 grpclog.Printf("num_channels_per_server: %d", *numChannelsPerServer)
255 grpclog.Printf("num_stubs_per_channel: %d", *numStubsPerChannel)
256 grpclog.Printf("metrics_port: %d", *metricsPort)
257 grpclog.Printf("use_tls: %t", *useTLS)
258 grpclog.Printf("use_test_ca: %t", *testCA)
259 grpclog.Printf("server_host_override: %s", *tlsServerName)
261 grpclog.Println("addresses:")
262 for i, addr := range addresses {
263 grpclog.Printf("%d. %s\n", i+1, addr)
265 grpclog.Println("tests:")
266 for i, test := range tests {
267 grpclog.Printf("%d. %v\n", i+1, test)
271 func newConn(address string, useTLS, testCA bool, tlsServerName string) (*grpc.ClientConn, error) {
272 var opts []grpc.DialOption
275 if tlsServerName != "" {
278 var creds credentials.TransportCredentials
282 *caFile = testdata.Path("ca.pem")
284 creds, err = credentials.NewClientTLSFromFile(*caFile, sn)
286 grpclog.Fatalf("Failed to create TLS credentials %v", err)
289 creds = credentials.NewClientTLSFromCert(nil, sn)
291 opts = append(opts, grpc.WithTransportCredentials(creds))
293 opts = append(opts, grpc.WithInsecure())
295 return grpc.Dial(address, opts...)
300 addresses := strings.Split(*serverAddresses, ",")
301 tests := parseTestCases(*testCases)
302 logParameterInfo(addresses, tests)
303 testSelector := newWeightedRandomTestSelector(tests)
304 metricsServer := newMetricsServer()
306 var wg sync.WaitGroup
307 wg.Add(len(addresses) * *numChannelsPerServer * *numStubsPerChannel)
308 stop := make(chan bool)
310 for serverIndex, address := range addresses {
311 for connIndex := 0; connIndex < *numChannelsPerServer; connIndex++ {
312 conn, err := newConn(address, *useTLS, *testCA, *tlsServerName)
314 grpclog.Fatalf("Fail to dial: %v", err)
317 for clientIndex := 0; clientIndex < *numStubsPerChannel; clientIndex++ {
318 name := fmt.Sprintf("/stress_test/server_%d/channel_%d/stub_%d/qps", serverIndex+1, connIndex+1, clientIndex+1)
321 g := metricsServer.createGauge(name)
322 performRPCs(g, conn, testSelector, stop)
328 go startServer(metricsServer, *metricsPort)
329 if *testDurationSecs > 0 {
330 time.Sleep(time.Duration(*testDurationSecs) * time.Second)
334 grpclog.Printf(" ===== ALL DONE ===== ")