diff --git a/benchmark/benchmain/main.go b/benchmark/benchmain/main.go
index 443799f57871..2077149e5660 100644
--- a/benchmark/benchmain/main.go
+++ b/benchmark/benchmain/main.go
@@ -537,7 +537,7 @@ func prepareMessages(streams [][]testgrpc.BenchmarkService_StreamingCallClient,
 // Makes a UnaryCall gRPC request using the given BenchmarkServiceClient and
 // request and response sizes.
 func unaryCaller(client testgrpc.BenchmarkServiceClient, reqSize, respSize int) {
-	if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil {
+	if err := benchmark.DoUnaryCall(context.Background(), client, reqSize, respSize); err != nil {
 		logger.Fatalf("DoUnaryCall failed: %v", err)
 	}
 }
diff --git a/benchmark/benchmark.go b/benchmark/benchmark.go
index 5c3dcd51db9d..1238320d49f3 100644
--- a/benchmark/benchmark.go
+++ b/benchmark/benchmark.go
@@ -276,15 +276,15 @@ func StartServer(info ServerInfo, opts ...grpc.ServerOption) func() {
 	}
 }
 
-// DoUnaryCall performs a unary RPC with given stub and request and response sizes.
-func DoUnaryCall(tc testgrpc.BenchmarkServiceClient, reqSize, respSize int) error {
+// DoUnaryCall performs a unary RPC with given context, stub and request and response sizes.
+func DoUnaryCall(ctx context.Context, tc testgrpc.BenchmarkServiceClient, reqSize, respSize int) error {
 	pl := NewPayload(testpb.PayloadType_COMPRESSABLE, reqSize)
 	req := &testpb.SimpleRequest{
 		ResponseType: pl.Type,
 		ResponseSize: int32(respSize),
 		Payload:      pl,
 	}
-	if _, err := tc.UnaryCall(context.Background(), req); err != nil {
+	if _, err := tc.UnaryCall(ctx, req); err != nil {
 		return fmt.Errorf("/BenchmarkService/UnaryCall(_, _) = _, %v, want _, ", err)
 	}
 	return nil
diff --git a/benchmark/worker/benchmark_client.go b/benchmark/worker/benchmark_client.go
index c28312dd6aab..2566e31f1346 100644
--- a/benchmark/worker/benchmark_client.go
+++ b/benchmark/worker/benchmark_client.go
@@ -73,7 +73,6 @@ func (h *lockingHistogram) mergeInto(merged *stats.Histogram) {
 
 type benchmarkClient struct {
 	closeConns        func()
-	stop              chan bool
 	lastResetTime     time.Time
 	histogramOptions  stats.HistogramOptions
 	lockingHistograms []lockingHistogram
@@ -168,7 +167,7 @@ func createConns(config *testpb.ClientConfig) ([]*grpc.ClientConn, func(), error
 	}, nil
 }
 
-func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benchmarkClient) error {
+func performRPCs(ctx context.Context, config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benchmarkClient) error {
 	// Read payload size and type from config.
 	var (
 		payloadReqSize, payloadRespSize int
@@ -212,9 +211,9 @@ func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benc
 
 	switch config.RpcType {
 	case testpb.RpcType_UNARY:
-		bc.unaryLoop(conns, rpcCountPerConn, payloadReqSize, payloadRespSize, poissonLambda)
+		bc.unaryLoop(ctx, conns, rpcCountPerConn, payloadReqSize, payloadRespSize, poissonLambda)
 	case testpb.RpcType_STREAMING:
-		bc.streamingLoop(conns, rpcCountPerConn, payloadReqSize, payloadRespSize, payloadType, poissonLambda)
+		bc.streamingLoop(ctx, conns, rpcCountPerConn, payloadReqSize, payloadRespSize, payloadType, poissonLambda)
 	default:
 		return status.Errorf(codes.InvalidArgument, "unknown rpc type: %v", config.RpcType)
 	}
@@ -222,7 +221,7 @@ func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benc
 	return nil
 }
 
-func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) {
+func startBenchmarkClient(ctx context.Context, config *testpb.ClientConfig) (*benchmarkClient, error) {
 	printClientConfig(config)
 
 	// Set running environment like how many cores to use.
@@ -243,13 +242,12 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error)
 		},
 		lockingHistograms: make([]lockingHistogram, rpcCountPerConn*len(conns)),
 
-		stop:            make(chan bool),
 		lastResetTime:   time.Now(),
 		closeConns:      closeConns,
 		rusageLastReset: syscall.GetRusage(),
 	}
 
-	if err = performRPCs(config, conns, bc); err != nil {
+	if err = performRPCs(ctx, config, conns, bc); err != nil {
 		// Close all connections if performRPCs failed.
 		closeConns()
 		return nil, err
@@ -258,7 +256,7 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error)
 	return bc, nil
 }
 
-func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, poissonLambda *float64) {
+func (bc *benchmarkClient) unaryLoop(ctx context.Context, conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, poissonLambda *float64) {
 	for ic, conn := range conns {
 		client := testgrpc.NewBenchmarkServiceClient(conn)
 		// For each connection, create rpcCountPerConn goroutines to do rpc.
@@ -266,6 +264,9 @@ func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn i
 			// Create histogram for each goroutine.
 			idx := ic*rpcCountPerConn + j
 			bc.lockingHistograms[idx].histogram = stats.NewHistogram(bc.histogramOptions)
+			if ctx.Err() != nil {
+				return
+			}
 			// Start goroutine on the created mutex and histogram.
 			go func(idx int) {
 				// TODO: do warm up if necessary.
@@ -274,14 +275,9 @@ func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn i
 				// before starting benchmark.
 				if poissonLambda == nil { // Closed loop.
 					for {
-						select {
-						case <-bc.stop:
-							return
-						default:
-						}
 						start := time.Now()
-						if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil {
-							continue
+						if err := benchmark.DoUnaryCall(ctx, client, reqSize, respSize); err != nil {
+							return
 						}
 						elapse := time.Since(start)
 						bc.lockingHistograms[idx].add(int64(elapse))
@@ -289,16 +285,15 @@ func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn i
 				} else { // Open loop.
 					timeBetweenRPCs := time.Duration((rand.ExpFloat64() / *poissonLambda) * float64(time.Second))
 					time.AfterFunc(timeBetweenRPCs, func() {
-						bc.poissonUnary(client, idx, reqSize, respSize, *poissonLambda)
+						bc.poissonUnary(ctx, client, idx, reqSize, respSize, *poissonLambda)
 					})
 				}
-
 			}(idx)
 		}
 	}
 }
 
-func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, payloadType string, poissonLambda *float64) {
+func (bc *benchmarkClient) streamingLoop(ctx context.Context, conns []*grpc.ClientConn, rpcCountPerConn int, reqSize int, respSize int, payloadType string, poissonLambda *float64) {
 	var doRPC func(testgrpc.BenchmarkService_StreamingCallClient, int, int) error
 	if payloadType == "bytebuf" {
 		doRPC = benchmark.DoByteBufStreamingRoundTrip
@@ -309,13 +304,16 @@ func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerCo
 		// For each connection, create rpcCountPerConn goroutines to do rpc.
 		for j := 0; j < rpcCountPerConn; j++ {
 			c := testgrpc.NewBenchmarkServiceClient(conn)
-			stream, err := c.StreamingCall(context.Background())
+			stream, err := c.StreamingCall(ctx)
 			if err != nil {
 				logger.Fatalf("%v.StreamingCall(_) = _, %v", c, err)
 			}
 			idx := ic*rpcCountPerConn + j
 			bc.lockingHistograms[idx].histogram = stats.NewHistogram(bc.histogramOptions)
 			if poissonLambda == nil { // Closed loop.
+				if stream.Context().Err() != nil {
+					return
+				}
 				// Start goroutine on the created mutex and histogram.
 				go func(idx int) {
 					// TODO: do warm up if necessary.
@@ -329,11 +327,6 @@ func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerCo
 						}
 						elapse := time.Since(start)
 						bc.lockingHistograms[idx].add(int64(elapse))
-						select {
-						case <-bc.stop:
-							return
-						default:
-						}
 					}
 				}(idx)
 			} else { // Open loop.
@@ -346,10 +339,11 @@ func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerCo
 		}
 	}
 }
-func (bc *benchmarkClient) poissonUnary(client testgrpc.BenchmarkServiceClient, idx int, reqSize int, respSize int, lambda float64) {
+func (bc *benchmarkClient) poissonUnary(ctx context.Context, client testgrpc.BenchmarkServiceClient, idx int, reqSize int, respSize int, lambda float64) {
 	go func() {
 		start := time.Now()
-		if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil {
+
+		if err := benchmark.DoUnaryCall(ctx, client, reqSize, respSize); err != nil {
 			return
 		}
 		elapse := time.Since(start)
@@ -357,13 +351,14 @@ func (bc *benchmarkClient) poissonUnary(client testgrpc.BenchmarkServiceClient,
 	}()
 	timeBetweenRPCs := time.Duration((rand.ExpFloat64() / lambda) * float64(time.Second))
 	time.AfterFunc(timeBetweenRPCs, func() {
-		bc.poissonUnary(client, idx, reqSize, respSize, lambda)
+		bc.poissonUnary(ctx, client, idx, reqSize, respSize, lambda)
 	})
 }
 
 func (bc *benchmarkClient) poissonStreaming(stream testgrpc.BenchmarkService_StreamingCallClient, idx int, reqSize int, respSize int, lambda float64, doRPC func(testgrpc.BenchmarkService_StreamingCallClient, int, int) error) {
 	go func() {
 		start := time.Now()
+
 		if err := doRPC(stream, reqSize, respSize); err != nil {
 			return
 		}
@@ -430,6 +425,5 @@ func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats {
 }
 
 func (bc *benchmarkClient) shutdown() {
-	close(bc.stop)
 	bc.closeConns()
 }
diff --git a/benchmark/worker/main.go b/benchmark/worker/main.go
index 45893d7b15a2..785e8504e04d 100644
--- a/benchmark/worker/main.go
+++ b/benchmark/worker/main.go
@@ -163,7 +163,10 @@ func (s *workerServer) RunClient(stream testgrpc.WorkerService_RunClientServer)
 				logger.Infof("client setup received when client already exists, shutting down the existing client")
 				bc.shutdown()
 			}
-			bc, err = startBenchmarkClient(t.Setup)
+
+			ctx, cancel := context.WithCancel(stream.Context())
+			defer cancel()
+			bc, err = startBenchmarkClient(ctx, t.Setup)
 			if err != nil {
 				return err
 			}
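The diff above replaces the benchmark client's dedicated `stop chan bool` with a `context.Context` derived from the worker stream, so cancelling that context both stops the benchmarking loops and cancels in-flight RPCs. Below is a minimal, self-contained sketch (not part of the patch) of that cancellation pattern; `doCall` is a hypothetical stand-in for `benchmark.DoUnaryCall`:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// doCall stands in for a context-aware RPC such as benchmark.DoUnaryCall:
// it returns an error as soon as the context is cancelled.
func doCall(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-time.After(10 * time.Millisecond): // simulated RPC latency
		return nil
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup

	// Closed-loop workers: instead of polling a stop channel each iteration,
	// they simply exit when a call fails (including context cancellation).
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			for {
				if err := doCall(ctx); err != nil {
					fmt.Printf("worker %d stopping: %v\n", id, err)
					return
				}
			}
		}(i)
	}

	time.Sleep(50 * time.Millisecond)
	cancel() // plays the role that close(bc.stop) played in shutdown()
	wg.Wait()
}
```

Once the context is cancelled, every closed-loop worker returns on its next call, which is why the patch can drop both the `stop` field and the per-iteration `select` on it.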