Skip to content

Commit

Permalink
*: follower support to handle GetRegion and other region api (#7432)
Browse files Browse the repository at this point in the history
ref #7431

Signed-off-by: Cabinfever_B <cabinfeveroier@gmail.com>

Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
  • Loading branch information
CabinfeverB and ti-chi-bot[bot] authored Dec 29, 2023
1 parent 7af8b81 commit a67ccbb
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 54 deletions.
87 changes: 66 additions & 21 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (k *serviceModeKeeper) close() {
type client struct {
keyspaceID uint32
svrUrls []string
pdSvcDiscovery ServiceDiscovery
pdSvcDiscovery *pdServiceDiscovery
tokenDispatcher *tokenDispatcher

// For service mode switching.
Expand Down Expand Up @@ -503,7 +503,7 @@ func newClientWithKeyspaceName(
return err
}
// c.keyspaceID is the source of truth for keyspace id.
c.pdSvcDiscovery.(*pdServiceDiscovery).SetKeyspaceID(c.keyspaceID)
c.pdSvcDiscovery.SetKeyspaceID(c.keyspaceID)
return nil
}

Expand Down Expand Up @@ -733,6 +733,23 @@ func (c *client) getClientAndContext(ctx context.Context) (pdpb.PDClient, contex
return pdpb.NewPDClient(serviceClient.GetClientConn()), serviceClient.BuildGRPCTargetContext(ctx, true)
}

// getRegionAPIClientAndContext selects the service client to use for a region
// API call and builds the context the call should be issued with.
//
// When allowFollower is true, a client of kind regionAPIKind is requested from
// service discovery first and used if one is available. Otherwise, or when no
// such client exists, the client from c.pdSvcDiscovery.GetServiceClient() is
// used. If that is nil as well, (nil, ctx) is returned with ctx untouched.
//
// The returned context comes from the chosen client's BuildGRPCTargetContext,
// which is always called with !allowFollower as its second argument.
func (c *client) getRegionAPIClientAndContext(ctx context.Context, allowFollower bool) (ServiceClient, context.Context) {
	var picked ServiceClient
	if allowFollower {
		picked = c.pdSvcDiscovery.getServiceClientByKind(regionAPIKind)
	}
	if picked == nil {
		// Either follower handling is not allowed or no region API candidate
		// is currently available: fall back to the default service client.
		picked = c.pdSvcDiscovery.GetServiceClient()
	}
	if picked == nil {
		return nil, ctx
	}
	return picked, picked.BuildGRPCTargetContext(ctx, !allowFollower)
}

// GetTSAsync delegates to GetLocalTSAsync with globalDCLocation as the DC
// location and returns the resulting TSFuture.
func (c *client) GetTSAsync(ctx context.Context) TSFuture {
return c.GetLocalTSAsync(ctx, globalDCLocation)
}
Expand Down Expand Up @@ -885,6 +902,7 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...GetRegionOpt
start := time.Now()
defer func() { cmdDurationGetRegion.Observe(time.Since(start).Seconds()) }()
ctx, cancel := context.WithTimeout(ctx, c.option.timeout)
defer cancel()

options := &GetRegionOp{}
for _, opt := range opts {
Expand All @@ -895,13 +913,18 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...GetRegionOpt
RegionKey: key,
NeedBuckets: options.needBuckets,
}
protoClient, ctx := c.getClientAndContext(ctx)
if protoClient == nil {
cancel()
serviceClient, cctx := c.getRegionAPIClientAndContext(ctx, options.allowFollowerHandle && c.option.getEnableFollowerHandle())
if serviceClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err := protoClient.GetRegion(ctx, req)
cancel()
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req)
if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) {
protoClient, cctx := c.getClientAndContext(ctx)
if protoClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err = protoClient.GetRegion(cctx, req)
}

if err = c.respForErr(cmdFailDurationGetRegion, start, err, resp.GetHeader()); err != nil {
return nil, err
Expand All @@ -917,6 +940,7 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...GetRegio
start := time.Now()
defer func() { cmdDurationGetPrevRegion.Observe(time.Since(start).Seconds()) }()
ctx, cancel := context.WithTimeout(ctx, c.option.timeout)
defer cancel()

options := &GetRegionOp{}
for _, opt := range opts {
Expand All @@ -927,13 +951,18 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...GetRegio
RegionKey: key,
NeedBuckets: options.needBuckets,
}
protoClient, ctx := c.getClientAndContext(ctx)
if protoClient == nil {
cancel()
serviceClient, cctx := c.getRegionAPIClientAndContext(ctx, options.allowFollowerHandle && c.option.getEnableFollowerHandle())
if serviceClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err := protoClient.GetPrevRegion(ctx, req)
cancel()
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req)
if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) {
protoClient, cctx := c.getClientAndContext(ctx)
if protoClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err = protoClient.GetPrevRegion(cctx, req)
}

if err = c.respForErr(cmdFailDurationGetPrevRegion, start, err, resp.GetHeader()); err != nil {
return nil, err
Expand All @@ -949,6 +978,7 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...Get
start := time.Now()
defer func() { cmdDurationGetRegionByID.Observe(time.Since(start).Seconds()) }()
ctx, cancel := context.WithTimeout(ctx, c.option.timeout)
defer cancel()

options := &GetRegionOp{}
for _, opt := range opts {
Expand All @@ -959,13 +989,18 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...Get
RegionId: regionID,
NeedBuckets: options.needBuckets,
}
protoClient, ctx := c.getClientAndContext(ctx)
if protoClient == nil {
cancel()
serviceClient, cctx := c.getRegionAPIClientAndContext(ctx, options.allowFollowerHandle && c.option.getEnableFollowerHandle())
if serviceClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err := protoClient.GetRegionByID(ctx, req)
cancel()
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req)
if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) {
protoClient, cctx := c.getClientAndContext(ctx)
if protoClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err = protoClient.GetRegionByID(cctx, req)
}

if err = c.respForErr(cmdFailedDurationGetRegionByID, start, err, resp.GetHeader()); err != nil {
return nil, err
Expand All @@ -987,18 +1022,28 @@ func (c *client) ScanRegions(ctx context.Context, key, endKey []byte, limit int,
scanCtx, cancel = context.WithTimeout(ctx, c.option.timeout)
defer cancel()
}
options := &GetRegionOp{}
for _, opt := range opts {
opt(options)
}
req := &pdpb.ScanRegionsRequest{
Header: c.requestHeader(),
StartKey: key,
EndKey: endKey,
Limit: int32(limit),
}
protoClient, scanCtx := c.getClientAndContext(scanCtx)
if protoClient == nil {
cancel()
serviceClient, cctx := c.getRegionAPIClientAndContext(scanCtx, options.allowFollowerHandle && c.option.getEnableFollowerHandle())
if serviceClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err := protoClient.ScanRegions(scanCtx, req)
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).ScanRegions(cctx, req)
if !serviceClient.IsConnectedToLeader() && err != nil || resp.Header.GetError() != nil {
protoClient, cctx := c.getClientAndContext(scanCtx)
if protoClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err = protoClient.ScanRegions(cctx, req)
}

if err = c.respForErr(cmdFailedDurationScanRegions, start, err, resp.GetHeader()); err != nil {
return nil, err
Expand Down
4 changes: 3 additions & 1 deletion client/pd_service_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type apiKind int

// The apiKind values are consecutive integers starting at 0 (iota).
const (
// forwardAPIKind is the first kind (value 0).
forwardAPIKind apiKind = iota
// regionAPIKind is the kind used to look up a service client for region APIs
// via getServiceClientByKind.
regionAPIKind
// apiKindCount is the number of kinds declared above and is used as the
// length of the apiCandidateNodes array, so it must stay the last entry.
apiKindCount
)

Expand Down Expand Up @@ -445,7 +446,7 @@ func newPDServiceDiscovery(
ctx: ctx,
cancel: cancel,
wg: wg,
apiCandidateNodes: [apiKindCount]*pdServiceBalancer{newPDServiceBalancer(emptyErrorFn)},
apiCandidateNodes: [apiKindCount]*pdServiceBalancer{newPDServiceBalancer(emptyErrorFn), newPDServiceBalancer(regionAPIErrorFn)},
serviceModeUpdateCb: serviceModeUpdateCb,
updateKeyspaceIDCb: updateKeyspaceIDCb,
keyspaceID: keyspaceID,
Expand Down Expand Up @@ -563,6 +564,7 @@ func (c *pdServiceDiscovery) updateServiceModeLoop() {
}
}
}

func (c *pdServiceDiscovery) memberHealthCheckLoop() {
defer c.wg.Done()

Expand Down
15 changes: 14 additions & 1 deletion pkg/utils/grpcutil/grpcutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import (
const (
// ForwardMetadataKey is used to record the forwarded host of PD.
ForwardMetadataKey = "pd-forwarded-host"
// FollowerHandleMetadataKey is the metadata key whose presence marks that a
// request is permitted to be handled by a follower.
FollowerHandleMetadataKey = "pd-allow-follower-handle"
)

// TLSConfig is the configuration for supporting tls.
Expand Down Expand Up @@ -173,7 +175,7 @@ func ResetForwardContext(ctx context.Context) context.Context {
func GetForwardedHost(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Debug("failed to get forwarding metadata")
log.Debug("failed to get gRPC incoming metadata when getting forwarded host")
return ""
}
if t, ok := md[ForwardMetadataKey]; ok {
Expand All @@ -182,6 +184,17 @@ func GetForwardedHost(ctx context.Context) string {
return ""
}

// IsFollowerHandleEnabled reports whether the incoming gRPC metadata of ctx
// contains the FollowerHandleMetadataKey entry. Only the presence of the key is
// checked; its value is ignored. It returns false, after emitting a debug log,
// when ctx carries no incoming metadata at all.
func IsFollowerHandleEnabled(ctx context.Context) bool {
	md, hasMD := metadata.FromIncomingContext(ctx)
	if hasMD {
		_, enabled := md[FollowerHandleMetadataKey]
		return enabled
	}
	log.Debug("failed to get gRPC incoming metadata when checking follower handle is enabled")
	return false
}

func establish(ctx context.Context, addr string, tlsConfig *TLSConfig, do ...grpc.DialOption) (*grpc.ClientConn, error) {
tlsCfg, err := tlsConfig.ToTLSConfig()
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions server/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,16 +384,16 @@ func (s *GrpcServer) getForwardedHost(ctx, streamCtx context.Context, serviceNam
return forwardedHost, nil
}

func (s *GrpcServer) isLocalRequest(forwardedHost string) bool {
func (s *GrpcServer) isLocalRequest(host string) bool {
failpoint.Inject("useForwardRequest", func() {
failpoint.Return(false)
})
if forwardedHost == "" {
if host == "" {
return true
}
memberAddrs := s.GetMember().Member().GetClientUrls()
for _, addr := range memberAddrs {
if addr == forwardedHost {
if addr == host {
return true
}
}
Expand Down
Loading

0 comments on commit a67ccbb

Please sign in to comment.