diff --git a/client/client.go b/client/client.go index 272d6c597b5..684b2933a5f 100644 --- a/client/client.go +++ b/client/client.go @@ -25,8 +25,6 @@ import ( "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -42,7 +40,6 @@ import ( "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/pkg/caller" - cb "github.com/tikv/pd/client/pkg/circuitbreaker" "github.com/tikv/pd/client/pkg/utils/tlsutil" sd "github.com/tikv/pd/client/servicediscovery" ) @@ -461,12 +458,6 @@ func (c *client) UpdateOption(option opt.DynamicOption, value any) error { return errors.New("[pd] invalid value type for TSOClientRPCConcurrency option, it should be int") } c.inner.option.SetTSOClientRPCConcurrency(value) - case opt.RegionMetadataCircuitBreakerSettings: - applySettingsChange, ok := value.(func(config *cb.Settings)) - if !ok { - return errors.New("[pd] invalid value type for RegionMetadataCircuitBreakerSettings option, it should be pd.Settings") - } - c.inner.regionMetaCircuitBreaker.ChangeSettings(applySettingsChange) default: return errors.New("[pd] unsupported client option") } @@ -661,13 +652,7 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { - region, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req) - failpoint.Inject("triggerCircuitBreaker", func() { - err = status.Error(codes.ResourceExhausted, "resource exhausted") - }) - return region, isOverloaded(err), err - }) + resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { @@ -707,10 +692,7 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { - resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req) - return resp, isOverloaded(err), err - }) + resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { @@ -750,10 +732,8 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { - resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req) - return resp, isOverloaded(err), err - }) + + resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { diff --git a/client/opt/option.go b/client/opt/option.go index af95a225fab..8d4f5955384 100644 --- a/client/opt/option.go +++ b/client/opt/option.go @@ -50,8 +50,6 @@ const ( EnableFollowerHandle // TSOClientRPCConcurrency controls the amount of ongoing TSO RPC requests at the same time in a single TSO client. TSOClientRPCConcurrency - // RegionMetadataCircuitBreakerSettings controls settings for circuit breaker for region metadata requests. - RegionMetadataCircuitBreakerSettings dynamicOptionCount ) @@ -154,11 +152,6 @@ func (o *Option) GetTSOClientRPCConcurrency() int { return o.dynamicOptions[TSOClientRPCConcurrency].Load().(int) } -// GetRegionMetadataCircuitBreakerSettings gets circuit breaker settings for PD region metadata calls. -func (o *Option) GetRegionMetadataCircuitBreakerSettings() cb.Settings { - return o.dynamicOptions[RegionMetadataCircuitBreakerSettings].Load().(cb.Settings) -} - // ClientOption configures client. type ClientOption func(*Option) diff --git a/client/pkg/circuitbreaker/circuit_breaker.go b/client/pkg/circuitbreaker/circuit_breaker.go index 2c65f4f1965..5229da197ec 100644 --- a/client/pkg/circuitbreaker/circuit_breaker.go +++ b/client/pkg/circuitbreaker/circuit_breaker.go @@ -14,6 +14,7 @@ package circuitbreaker import ( + "context" "fmt" "strings" "sync" @@ -309,3 +310,25 @@ func (s *State[T]) onResult(overloaded Overloading) { panic("unknown state") } } + +// Define context key type +type cbCtxKey struct{} + +// Key used to store circuit breaker +var CircuitBreakerKey = cbCtxKey{} + +// FromContext retrieves the circuit breaker from the context +func FromContext[T any](ctx context.Context) *CircuitBreaker[T] { + if ctx == nil { + return nil + } + if cb, ok := ctx.Value(CircuitBreakerKey).(*CircuitBreaker[T]); ok { + return cb + } + return nil +} + +// WithCircuitBreaker stores the circuit breaker into a new context +func WithCircuitBreaker[T any](ctx context.Context, cb *CircuitBreaker[T]) context.Context { + return context.WithValue(ctx, CircuitBreakerKey, cb) +} diff --git a/client/pkg/utils/grpcutil/grpcutil.go b/client/pkg/utils/grpcutil/grpcutil.go index b73d117fe84..637d24b4eeb 100644 --- a/client/pkg/utils/grpcutil/grpcutil.go +++ b/client/pkg/utils/grpcutil/grpcutil.go @@ -24,15 +24,19 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/backoff" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" "github.com/tikv/pd/client/errs" + "github.com/tikv/pd/client/pkg/circuitbreaker" "github.com/tikv/pd/client/pkg/retry" ) @@ -71,6 +75,30 @@ func UnaryBackofferInterceptor() grpc.UnaryClientInterceptor { } } +func UnaryCircuitBreakerInterceptor[T any]() grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + cb := circuitbreaker.FromContext[*pdpb.GetRegionResponse](ctx) + if cb == nil { + return invoker(ctx, method, req, reply, cc, opts...) + } + _, err := cb.Execute(func() (*pdpb.GetRegionResponse, circuitbreaker.Overloading, error) { + err := invoker(ctx, method, req, reply, cc, opts...) + failpoint.Inject("triggerCircuitBreaker", func() { + err = status.Error(codes.ResourceExhausted, "resource exhausted") + }) + var zero *pdpb.GetRegionResponse + if err != nil { + return zero, circuitbreaker.Yes, err + } + return zero, circuitbreaker.No, nil + }) + if err != nil { + return err + } + return nil + } +} + // GetClientConn returns a gRPC client connection. // creates a client connection to the given target. By default, it's // a non-blocking dial (the function won't wait for connections to be @@ -96,7 +124,10 @@ func GetClientConn(ctx context.Context, addr string, tlsCfg *tls.Config, do ...g } // Add backoffer interceptor - retryOpt := grpc.WithUnaryInterceptor(UnaryBackofferInterceptor()) + retryOpt := grpc.WithChainUnaryInterceptor(UnaryBackofferInterceptor()) + + // Add circuit breaker interceptor + cbOpt := grpc.WithChainUnaryInterceptor(UnaryCircuitBreakerInterceptor[any]()) // Add retry related connection parameters backoffOpts := grpc.WithConnectParams(grpc.ConnectParams{ @@ -108,7 +139,7 @@ func GetClientConn(ctx context.Context, addr string, tlsCfg *tls.Config, do ...g }, }) - do = append(do, opt, retryOpt, backoffOpts) + do = append(do, opt, retryOpt, cbOpt, backoffOpts) cc, err := grpc.DialContext(ctx, u.Host, do...) if err != nil { return nil, errs.ErrGRPCDial.Wrap(err).GenWithStackByCause() diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index fadfb952e4c..cdab9e81991 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -2070,28 +2070,30 @@ func TestCircuitBreaker(t *testing.T) { } endpoints := runServer(re, cluster) - cli := setupCli(ctx, re, endpoints, opt.WithRegionMetaCircuitBreaker(circuitBreakerSettings)) + cli := setupCli(ctx, re, endpoints) defer cli.Close() + circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings) + ctx = cb.WithCircuitBreaker(ctx, circuitBreaker) for range 10 { - region, err := cli.GetRegion(context.TODO(), []byte("a")) + region, err := cli.GetRegion(ctx, []byte("a")) re.NoError(err) re.NotNil(region) } - re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)")) for range 100 { - _, err := cli.GetRegion(context.TODO(), []byte("a")) + _, err := cli.GetRegion(ctx, []byte("a")) re.Error(err) } - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.Error(err) re.Contains(err.Error(), "circuit breaker is open") - re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker")) - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.Error(err) re.Contains(err.Error(), "circuit breaker is open") @@ -2099,7 +2101,7 @@ func TestCircuitBreaker(t *testing.T) { time.Sleep(time.Second) for range 10 { - region, err := cli.GetRegion(context.TODO(), []byte("a")) + region, err := cli.GetRegion(ctx, []byte("a")) re.NoError(err) re.NotNil(region) } @@ -2123,34 +2125,35 @@ func TestCircuitBreakerOpenAndChangeSettings(t *testing.T) { } endpoints := runServer(re, cluster) - cli := setupCli(ctx, re, endpoints, opt.WithRegionMetaCircuitBreaker(circuitBreakerSettings)) + cli := setupCli(ctx, re, endpoints) defer cli.Close() + circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings) + ctx = cb.WithCircuitBreaker(ctx, circuitBreaker) for range 10 { - region, err := cli.GetRegion(context.TODO(), []byte("a")) + region, err := cli.GetRegion(ctx, []byte("a")) re.NoError(err) re.NotNil(region) } - re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)")) for range 100 { - _, err := cli.GetRegion(context.TODO(), []byte("a")) + _, err := cli.GetRegion(ctx, []byte("a")) re.Error(err) } - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.Error(err) re.Contains(err.Error(), "circuit breaker is open") - cli.UpdateOption(opt.RegionMetadataCircuitBreakerSettings, func(config *cb.Settings) { + circuitBreaker.ChangeSettings(func(config *cb.Settings) { *config = cb.AlwaysClosedSettings }) - - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.Error(err) re.Contains(err.Error(), "ResourceExhausted") - re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker")) } func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) { @@ -2171,23 +2174,26 @@ func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) { } endpoints := runServer(re, cluster) - cli := setupCli(ctx, re, endpoints, opt.WithRegionMetaCircuitBreaker(circuitBreakerSettings)) + + cli := setupCli(ctx, re, endpoints) defer cli.Close() + circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings) + ctx = cb.WithCircuitBreaker(ctx, circuitBreaker) for range 10 { - region, err := cli.GetRegion(context.TODO(), []byte("a")) + region, err := cli.GetRegion(ctx, []byte("a")) re.NoError(err) re.NotNil(region) } - re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)")) for range 100 { - _, err := cli.GetRegion(context.TODO(), []byte("a")) + _, err := cli.GetRegion(ctx, []byte("a")) re.Error(err) } - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.Error(err) re.Contains(err.Error(), "circuit breaker is open") @@ -2195,9 +2201,9 @@ func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) { defer os.RemoveAll(fname) // wait for cooldown time.Sleep(time.Second) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker")) // trigger circuit breaker state to be half open - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.NoError(err) testutil.Eventually(re, func() bool { b, _ := os.ReadFile(fname) @@ -2207,17 +2213,16 @@ func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) { }) // The state is half open - re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)")) // change settings to always closed - cli.UpdateOption(opt.RegionMetadataCircuitBreakerSettings, func(config *cb.Settings) { + circuitBreaker.ChangeSettings(func(config *cb.Settings) { *config = cb.AlwaysClosedSettings }) - // It won't be changed to open state. for range 100 { - _, err := cli.GetRegion(context.TODO(), []byte("a")) + _, err := cli.GetRegion(ctx, []byte("a")) re.Error(err) re.NotContains(err.Error(), "circuit breaker is open") } - re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker")) }