mcs: fix panic of forwarding request (#7220)
ref #5839

Signed-off-by: Ryan Leung <rleungx@gmail.com>

Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
rleungx and ti-chi-bot[bot] authored Oct 19, 2023
1 parent 43ff408 commit 445319f
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions server/grpc_service.go
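
Note on the change, as read from the diff below: the old forwarding paths checked s.schedulingClient.Load() != nil and then dereferenced a second Load() to reach the gRPC client. The error path stores an empty &schedulingClient{} as a reset, and when no scheduling primary address is available the wrapper is never refreshed, so the nil check could pass while the wrapped client was still nil and the next call through getClient() could panic. With this commit, updateSchedulingClient hands back the exact client it stored (or ErrNotFoundSchedulingAddr), and callers use that returned value instead of reading the atomic holder again. The Go sketch below is only an illustration of that hazard and of the returned-client pattern; the names client, realClient, wrapper, holder, oldForward, newForward, and updateClient are invented for the example and are not part of the PD code.

package main

import (
	"fmt"
	"sync/atomic"
)

// client stands in for schedulingpb.SchedulingClient in this sketch.
type client interface{ Send() string }

type realClient struct{}

func (realClient) Send() string { return "ok" }

// wrapper plays the role of schedulingClient; an error-path "reset" stores an
// empty wrapper whose inner client is nil.
type wrapper struct{ inner client }

func (w *wrapper) getClient() client { return w.inner }

var holder atomic.Value // always holds a *wrapper

// oldForward mirrors the removed pattern: the nil check passes because the holder
// contains a (reset) wrapper, but the wrapped client is nil, so Send() panics.
func oldForward() string {
	if holder.Load() != nil {
		return holder.Load().(*wrapper).getClient().Send()
	}
	return "no client"
}

// updateClient mirrors the reworked updateSchedulingClient: it returns exactly the
// client it stored, so callers never re-read the shared holder.
func updateClient() (client, error) {
	w := &wrapper{inner: realClient{}}
	holder.Store(w)
	return w.getClient(), nil
}

// newForward mirrors the new call sites: use the returned client or handle the error.
func newForward() string {
	cli, err := updateClient()
	if err != nil {
		return "error: " + err.Error()
	}
	return cli.Send()
}

func main() {
	holder.Store(&wrapper{}) // simulate the error-path reset: Store(&schedulingClient{})
	defer func() {
		if r := recover(); r != nil {
			fmt.Println("old pattern panicked:", r)
			fmt.Println("new pattern:", newForward())
		}
	}()
	fmt.Println(oldForward()) // panics: the nil check passed, but the inner client is nil
}

Returning the concrete client from the updater means the value that passed the check is the value that gets used, so a reset of the shared holder between check and use can no longer turn into a call on a nil client; callers that cannot forward now return an error header instead of panicking.
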
@@ -69,6 +69,7 @@ var (
 	ErrNotStarted = status.Errorf(codes.Unavailable, "server not started")
 	ErrSendHeartbeatTimeout = status.Errorf(codes.DeadlineExceeded, "send heartbeat timeout")
 	ErrNotFoundTSOAddr = status.Errorf(codes.NotFound, "not found tso address")
+	ErrNotFoundSchedulingAddr = status.Errorf(codes.NotFound, "not found scheduling address")
 	ErrForwardTSOTimeout = status.Errorf(codes.DeadlineExceeded, "forward tso request timeout")
 	ErrMaxCountTSOProxyRoutinesExceeded = status.Errorf(codes.ResourceExhausted, "max count of concurrent tso proxy routines exceeded")
 	ErrTSOProxyRecvFromClientTimeout = status.Errorf(codes.DeadlineExceeded, "tso proxy timeout when receiving from client; stream closed by server")
@@ -1002,16 +1003,16 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear
 	s.handleDamagedStore(request.GetStats())
 	storeHeartbeatHandleDuration.WithLabelValues(storeAddress, storeLabel).Observe(time.Since(start).Seconds())
 	if s.IsAPIServiceMode() {
-		s.updateSchedulingClient(ctx)
-		if s.schedulingClient.Load() != nil {
+		forwardCli, _ := s.updateSchedulingClient(ctx)
+		if forwardCli != nil {
 			req := &schedulingpb.StoreHeartbeatRequest{
 				Header: &schedulingpb.RequestHeader{
 					ClusterId: request.GetHeader().GetClusterId(),
 					SenderId: request.GetHeader().GetSenderId(),
 				},
 				Stats: request.GetStats(),
 			}
-			if _, err := s.schedulingClient.Load().(*schedulingClient).getClient().StoreHeartbeat(ctx, req); err != nil {
+			if _, err := forwardCli.StoreHeartbeat(ctx, req); err != nil {
 				// reset to let it be updated in the next request
 				s.schedulingClient.Store(&schedulingClient{})
 			}
@@ -1030,19 +1031,22 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear
 	return resp, nil
 }
 
-func (s *GrpcServer) updateSchedulingClient(ctx context.Context) {
+func (s *GrpcServer) updateSchedulingClient(ctx context.Context) (schedulingpb.SchedulingClient, error) {
 	forwardedHost, _ := s.GetServicePrimaryAddr(ctx, utils.SchedulingServiceName)
 	pre := s.schedulingClient.Load()
 	if forwardedHost != "" && ((pre == nil) || (pre != nil && forwardedHost != pre.(*schedulingClient).getPrimaryAddr())) {
 		client, err := s.getDelegateClient(ctx, forwardedHost)
 		if err != nil {
 			log.Error("get delegate client failed", zap.Error(err))
 		}
-		s.schedulingClient.Store(&schedulingClient{
+		forwardCli := &schedulingClient{
 			client: schedulingpb.NewSchedulingClient(client),
 			lastPrimary: forwardedHost,
-		})
+		}
+		s.schedulingClient.Store(forwardCli)
+		return forwardCli.getClient(), nil
 	}
+	return nil, ErrNotFoundSchedulingAddr
 }
 
 // bucketHeartbeatServer wraps PD_ReportBucketsServer to ensure when any error
@@ -1791,8 +1795,13 @@ func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClus
 // ScatterRegion implements gRPC PDServer.
 func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterRegionRequest) (*pdpb.ScatterRegionResponse, error) {
 	if s.IsAPIServiceMode() {
-		s.updateSchedulingClient(ctx)
-		if s.schedulingClient.Load() != nil {
+		forwardCli, err := s.updateSchedulingClient(ctx)
+		if err != nil {
+			return &pdpb.ScatterRegionResponse{
+				Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()),
+			}, nil
+		}
+		if forwardCli != nil {
 			regionsID := request.GetRegionsId()
 			if len(regionsID) == 0 {
 				return &pdpb.ScatterRegionResponse{
@@ -1809,7 +1818,7 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg
 				RetryLimit: request.GetRetryLimit(),
 				SkipStoreLimit: request.GetSkipStoreLimit(),
 			}
-			resp, err := s.schedulingClient.Load().(*schedulingClient).getClient().ScatterRegions(ctx, req)
+			resp, err := forwardCli.ScatterRegions(ctx, req)
 			if err != nil {
 				// reset to let it be updated in the next request
 				s.schedulingClient.Store(&schedulingClient{})
@@ -2010,16 +2019,21 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb
 // GetOperator gets information about the operator belonging to the specify region.
 func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorRequest) (*pdpb.GetOperatorResponse, error) {
 	if s.IsAPIServiceMode() {
-		s.updateSchedulingClient(ctx)
-		if s.schedulingClient.Load() != nil {
+		forwardCli, err := s.updateSchedulingClient(ctx)
+		if err != nil {
+			return &pdpb.GetOperatorResponse{
+				Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()),
+			}, nil
+		}
+		if forwardCli != nil {
 			req := &schedulingpb.GetOperatorRequest{
 				Header: &schedulingpb.RequestHeader{
 					ClusterId: request.GetHeader().GetClusterId(),
 					SenderId: request.GetHeader().GetSenderId(),
 				},
 				RegionId: request.GetRegionId(),
 			}
-			resp, err := s.schedulingClient.Load().(*schedulingClient).getClient().GetOperator(ctx, req)
+			resp, err := forwardCli.GetOperator(ctx, req)
 			if err != nil {
 				// reset to let it be updated in the next request
 				s.schedulingClient.Store(&schedulingClient{})
@@ -2268,8 +2282,13 @@ func (s *GrpcServer) SyncMaxTS(_ context.Context, request *pdpb.SyncMaxTSRequest
 // SplitRegions split regions by the given split keys
 func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegionsRequest) (*pdpb.SplitRegionsResponse, error) {
 	if s.IsAPIServiceMode() {
-		s.updateSchedulingClient(ctx)
-		if s.schedulingClient.Load() != nil {
+		forwardCli, err := s.updateSchedulingClient(ctx)
+		if err != nil {
+			return &pdpb.SplitRegionsResponse{
+				Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()),
+			}, nil
+		}
+		if forwardCli != nil {
 			req := &schedulingpb.SplitRegionsRequest{
 				Header: &schedulingpb.RequestHeader{
 					ClusterId: request.GetHeader().GetClusterId(),
@@ -2278,7 +2297,7 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion
 				SplitKeys: request.GetSplitKeys(),
 				RetryLimit: request.GetRetryLimit(),
 			}
-			resp, err := s.schedulingClient.Load().(*schedulingClient).getClient().SplitRegions(ctx, req)
+			resp, err := forwardCli.SplitRegions(ctx, req)
 			if err != nil {
 				// reset to let it be updated in the next request
 				s.schedulingClient.Store(&schedulingClient{})