diff --git a/pkg/mcs/metastorage/server/grpc_service.go b/pkg/mcs/metastorage/server/grpc_service.go index e9d35fbf14b..3da079e6109 100644 --- a/pkg/mcs/metastorage/server/grpc_service.go +++ b/pkg/mcs/metastorage/server/grpc_service.go @@ -86,7 +86,7 @@ func (s *Service) Watch(req *meta_storagepb.WatchRequest, server meta_storagepb. if err := s.checkServing(); err != nil { return err } - ctx, cancel := context.WithCancel(s.ctx) + ctx, cancel := context.WithCancel(server.Context()) defer cancel() options := []clientv3.OpOption{} key := string(req.GetKey()) @@ -106,6 +106,8 @@ func (s *Service) Watch(req *meta_storagepb.WatchRequest, server meta_storagepb. select { case <-ctx.Done(): return nil + case <-s.ctx.Done(): + return nil case res := <-watchChan: if res.Err() != nil { var resp meta_storagepb.WatchResponse diff --git a/server/grpc_service.go b/server/grpc_service.go index 0dbdcc8532f..38796fc7ff5 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -2810,7 +2810,7 @@ func (s *GrpcServer) WatchGlobalConfig(req *pdpb.WatchGlobalConfigRequest, serve return err } } - ctx, cancel := context.WithCancel(s.Context()) + ctx, cancel := context.WithCancel(server.Context()) defer cancel() configPath := req.GetConfigPath() if configPath == "" { @@ -2826,6 +2826,8 @@ func (s *GrpcServer) WatchGlobalConfig(req *pdpb.WatchGlobalConfigRequest, serve select { case <-ctx.Done(): return nil + case <-s.Context().Done(): + return nil case res := <-watchChan: if res.Err() != nil { var resp pdpb.WatchGlobalConfigResponse diff --git a/tests/integrations/client/global_config_test.go b/tests/integrations/client/global_config_test.go index aeb704c3305..c52a35159b0 100644 --- a/tests/integrations/client/global_config_test.go +++ b/tests/integrations/client/global_config_test.go @@ -15,6 +15,7 @@ package client_test import ( + "context" "path" "strconv" "testing" @@ -37,7 +38,8 @@ import ( const globalConfigPath = "/global/config/" type testReceiver struct { - re *require.Assertions + re *require.Assertions + ctx context.Context grpc.ServerStream } @@ -49,6 +51,10 @@ func (s testReceiver) Send(m *pdpb.WatchGlobalConfigResponse) error { return nil } +func (s testReceiver) Context() context.Context { + return s.ctx +} + type globalConfigTestSuite struct { suite.Suite server *server.GrpcServer @@ -199,7 +205,9 @@ func (suite *globalConfigTestSuite) TestWatch() { re.NoError(err) } }() - server := testReceiver{re: suite.Require()} + ctx, cancel := context.WithCancel(suite.server.Context()) + defer cancel() + server := testReceiver{re: suite.Require(), ctx: ctx} go suite.server.WatchGlobalConfig(&pdpb.WatchGlobalConfigRequest{ ConfigPath: globalConfigPath, Revision: 0,