diff --git a/internal/internal.go b/internal/internal.go index 2699223a27f1..3e1e2ec34617 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -244,6 +244,8 @@ var ( // When set, the function will be called before the stream enters // the blocking state. NewStreamWaitingForResolver = func() {} + + ActiveStreamTracker = func(created, deleted int) {} ) // HealthChecker defines the signature of the client-side LB channel health diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 9f725e15a812..be1e56463b56 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -615,6 +615,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade } t.activeStreams[streamID] = s + internal.ActiveStreamTracker(1, 0) if len(t.activeStreams) == 1 { t.idle = time.Time{} } @@ -1310,6 +1311,7 @@ func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) { if len(t.activeStreams) == 0 { t.idle = time.Now() } + internal.ActiveStreamTracker(0, 1) } t.mu.Unlock() diff --git a/test/end2end_test.go b/test/end2end_test.go index b2f503990bc1..ee1866316e41 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -3946,6 +3946,77 @@ func (s) TestServerStreaming_ServerRecvZeroRequests(t *testing.T) { } } +func (s) TestServerStreaming_ClientTimeoutWithoutContextCancellation(t *testing.T) { + activeStreams := atomic.Int64{} + + internal.ActiveStreamTracker = func(created, deleted int) { + activeStreams.Add(int64(created)) + activeStreams.Add(-int64(deleted)) + } + + ss := &stubserver.StubServer{ + StreamingOutputCallF: func(req *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error { + // lets keep this busy until error + for { + if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil { + return err + } + } + }, + } + + if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(10)}); err != nil { + t.Fatalf("Starting stubServer: %v", err) + } + defer ss.Stop() + + const numStreams = 100 + + wg := sync.WaitGroup{} + wg.Add(numStreams) + for j := 0; j < numStreams; j++ { + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + stream, err := ss.Client.StreamingOutputCall(ctx, &testpb.StreamingOutputCallRequest{}) + if err != nil { + return + } + defer stream.CloseSend() + + //let keep receiving the streams until timeout + for { + _, _ = stream.Recv() + select { + case <-ctx.Done(): + return + default: + time.Sleep(time.Second) + } + } + }() + } + wg.Wait() + + const ( + sleepEachLoop = 100 * time.Millisecond + loopCount = int(5 * time.Second / sleepEachLoop) + ) + + for i := 0; i < loopCount; i++ { + time.Sleep(sleepEachLoop) + if activeStreams.Load() == 0 { + break + } + } + + if activeStreams.Load() != 0 { + t.Fatalf("leak streams: %d", activeStreams.Load()) + } +} + // Tests the behavior of client for server-side streaming RPC when client sends zero request messages. func (s) TestServerStreaming_ClientSendsZeroRequests(t *testing.T) { t.Skip("blocked on i/7286")