Skip to content

Commit

Permalink
use interceptor for circuit breaker
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Leung <rleungx@gmail.com>
  • Loading branch information
rleungx committed Dec 18, 2024
1 parent ecb31de commit f1aeab1
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 62 deletions.
28 changes: 4 additions & 24 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import (
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
Expand All @@ -42,7 +40,6 @@ import (
"github.com/tikv/pd/client/metrics"
"github.com/tikv/pd/client/opt"
"github.com/tikv/pd/client/pkg/caller"
cb "github.com/tikv/pd/client/pkg/circuitbreaker"
"github.com/tikv/pd/client/pkg/utils/tlsutil"
sd "github.com/tikv/pd/client/servicediscovery"
)
Expand Down Expand Up @@ -461,12 +458,6 @@ func (c *client) UpdateOption(option opt.DynamicOption, value any) error {
return errors.New("[pd] invalid value type for TSOClientRPCConcurrency option, it should be int")
}
c.inner.option.SetTSOClientRPCConcurrency(value)
case opt.RegionMetadataCircuitBreakerSettings:
applySettingsChange, ok := value.(func(config *cb.Settings))
if !ok {
return errors.New("[pd] invalid value type for RegionMetadataCircuitBreakerSettings option, it should be pd.Settings")
}
c.inner.regionMetaCircuitBreaker.ChangeSettings(applySettingsChange)
default:
return errors.New("[pd] unsupported client option")
}
Expand Down Expand Up @@ -661,13 +652,7 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio
if serviceClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) {
region, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req)
failpoint.Inject("triggerCircuitBreaker", func() {
err = status.Error(codes.ResourceExhausted, "resource exhausted")
})
return region, isOverloaded(err), err
})
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req)
if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) {
protoClient, cctx := c.getClientAndContext(ctx)
if protoClient == nil {
Expand Down Expand Up @@ -707,10 +692,7 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR
if serviceClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) {
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req)
return resp, isOverloaded(err), err
})
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req)
if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) {
protoClient, cctx := c.getClientAndContext(ctx)
if protoClient == nil {
Expand Down Expand Up @@ -750,10 +732,8 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt
if serviceClient == nil {
return nil, errs.ErrClientGetProtoClient
}
resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) {
resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req)
return resp, isOverloaded(err), err
})

resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req)
if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) {
protoClient, cctx := c.getClientAndContext(ctx)
if protoClient == nil {
Expand Down
7 changes: 0 additions & 7 deletions client/opt/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ const (
EnableFollowerHandle
// TSOClientRPCConcurrency controls the amount of ongoing TSO RPC requests at the same time in a single TSO client.
TSOClientRPCConcurrency
// RegionMetadataCircuitBreakerSettings controls settings for circuit breaker for region metadata requests.
RegionMetadataCircuitBreakerSettings

dynamicOptionCount
)
Expand Down Expand Up @@ -154,11 +152,6 @@ func (o *Option) GetTSOClientRPCConcurrency() int {
return o.dynamicOptions[TSOClientRPCConcurrency].Load().(int)
}

// GetRegionMetadataCircuitBreakerSettings gets circuit breaker settings for PD region metadata calls.
func (o *Option) GetRegionMetadataCircuitBreakerSettings() cb.Settings {
return o.dynamicOptions[RegionMetadataCircuitBreakerSettings].Load().(cb.Settings)
}

// ClientOption configures client.
type ClientOption func(*Option)

Expand Down
23 changes: 23 additions & 0 deletions client/pkg/circuitbreaker/circuit_breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package circuitbreaker

import (
"context"
"fmt"
"strings"
"sync"
Expand Down Expand Up @@ -309,3 +310,25 @@ func (s *State[T]) onResult(overloaded Overloading) {
panic("unknown state")
}
}

// Define context key type
type cbCtxKey struct{}

// Key used to store circuit breaker
var CircuitBreakerKey = cbCtxKey{}

// FromContext retrieves the circuit breaker from the context
func FromContext[T any](ctx context.Context) *CircuitBreaker[T] {
if ctx == nil {
return nil
}
if cb, ok := ctx.Value(CircuitBreakerKey).(*CircuitBreaker[T]); ok {
return cb
}
return nil
}

// WithCircuitBreaker stores the circuit breaker into a new context
func WithCircuitBreaker[T any](ctx context.Context, cb *CircuitBreaker[T]) context.Context {
return context.WithValue(ctx, CircuitBreakerKey, cb)
}
35 changes: 33 additions & 2 deletions client/pkg/utils/grpcutil/grpcutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,19 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/pdpb"
"github.com/pingcap/log"

"github.com/tikv/pd/client/errs"
"github.com/tikv/pd/client/pkg/circuitbreaker"
"github.com/tikv/pd/client/pkg/retry"
)

Expand Down Expand Up @@ -71,6 +75,30 @@ func UnaryBackofferInterceptor() grpc.UnaryClientInterceptor {
}
}

func UnaryCircuitBreakerInterceptor[T any]() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
cb := circuitbreaker.FromContext[*pdpb.GetRegionResponse](ctx)
if cb == nil {
return invoker(ctx, method, req, reply, cc, opts...)
}
_, err := cb.Execute(func() (*pdpb.GetRegionResponse, circuitbreaker.Overloading, error) {
err := invoker(ctx, method, req, reply, cc, opts...)
failpoint.Inject("triggerCircuitBreaker", func() {
err = status.Error(codes.ResourceExhausted, "resource exhausted")
})
var zero *pdpb.GetRegionResponse
if err != nil {
return zero, circuitbreaker.Yes, err
}
return zero, circuitbreaker.No, nil
})
if err != nil {
return err
}
return nil
}
}

// GetClientConn returns a gRPC client connection.
// creates a client connection to the given target. By default, it's
// a non-blocking dial (the function won't wait for connections to be
Expand All @@ -96,7 +124,10 @@ func GetClientConn(ctx context.Context, addr string, tlsCfg *tls.Config, do ...g
}

// Add backoffer interceptor
retryOpt := grpc.WithUnaryInterceptor(UnaryBackofferInterceptor())
retryOpt := grpc.WithChainUnaryInterceptor(UnaryBackofferInterceptor())

// Add circuit breaker interceptor
cbOpt := grpc.WithChainUnaryInterceptor(UnaryCircuitBreakerInterceptor[any]())

// Add retry related connection parameters
backoffOpts := grpc.WithConnectParams(grpc.ConnectParams{
Expand All @@ -108,7 +139,7 @@ func GetClientConn(ctx context.Context, addr string, tlsCfg *tls.Config, do ...g
},
})

do = append(do, opt, retryOpt, backoffOpts)
do = append(do, opt, retryOpt, cbOpt, backoffOpts)
cc, err := grpc.DialContext(ctx, u.Host, do...)
if err != nil {
return nil, errs.ErrGRPCDial.Wrap(err).GenWithStackByCause()
Expand Down
63 changes: 34 additions & 29 deletions tests/integrations/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2070,36 +2070,38 @@ func TestCircuitBreaker(t *testing.T) {
}

endpoints := runServer(re, cluster)
cli := setupCli(ctx, re, endpoints, opt.WithRegionMetaCircuitBreaker(circuitBreakerSettings))
cli := setupCli(ctx, re, endpoints)
defer cli.Close()

circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings)
ctx = cb.WithCircuitBreaker(ctx, circuitBreaker)
for range 10 {
region, err := cli.GetRegion(context.TODO(), []byte("a"))
region, err := cli.GetRegion(ctx, []byte("a"))
re.NoError(err)
re.NotNil(region)
}

re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)"))
re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)"))

for range 100 {
_, err := cli.GetRegion(context.TODO(), []byte("a"))
_, err := cli.GetRegion(ctx, []byte("a"))
re.Error(err)
}

_, err = cli.GetRegion(context.TODO(), []byte("a"))
_, err = cli.GetRegion(ctx, []byte("a"))
re.Error(err)
re.Contains(err.Error(), "circuit breaker is open")
re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker"))
re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker"))

_, err = cli.GetRegion(context.TODO(), []byte("a"))
_, err = cli.GetRegion(ctx, []byte("a"))
re.Error(err)
re.Contains(err.Error(), "circuit breaker is open")

// wait cooldown
time.Sleep(time.Second)

for range 10 {
region, err := cli.GetRegion(context.TODO(), []byte("a"))
region, err := cli.GetRegion(ctx, []byte("a"))
re.NoError(err)
re.NotNil(region)
}
Expand All @@ -2123,34 +2125,35 @@ func TestCircuitBreakerOpenAndChangeSettings(t *testing.T) {
}

endpoints := runServer(re, cluster)
cli := setupCli(ctx, re, endpoints, opt.WithRegionMetaCircuitBreaker(circuitBreakerSettings))
cli := setupCli(ctx, re, endpoints)
defer cli.Close()

circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings)
ctx = cb.WithCircuitBreaker(ctx, circuitBreaker)
for range 10 {
region, err := cli.GetRegion(context.TODO(), []byte("a"))
region, err := cli.GetRegion(ctx, []byte("a"))
re.NoError(err)
re.NotNil(region)
}

re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)"))
re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)"))

for range 100 {
_, err := cli.GetRegion(context.TODO(), []byte("a"))
_, err := cli.GetRegion(ctx, []byte("a"))
re.Error(err)
}

_, err = cli.GetRegion(context.TODO(), []byte("a"))
_, err = cli.GetRegion(ctx, []byte("a"))
re.Error(err)
re.Contains(err.Error(), "circuit breaker is open")

cli.UpdateOption(opt.RegionMetadataCircuitBreakerSettings, func(config *cb.Settings) {
circuitBreaker.ChangeSettings(func(config *cb.Settings) {
*config = cb.AlwaysClosedSettings
})

_, err = cli.GetRegion(context.TODO(), []byte("a"))
_, err = cli.GetRegion(ctx, []byte("a"))
re.Error(err)
re.Contains(err.Error(), "ResourceExhausted")
re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker"))
re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker"))
}

func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) {
Expand All @@ -2171,33 +2174,36 @@ func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) {
}

endpoints := runServer(re, cluster)
cli := setupCli(ctx, re, endpoints, opt.WithRegionMetaCircuitBreaker(circuitBreakerSettings))

cli := setupCli(ctx, re, endpoints)
defer cli.Close()

circuitBreaker := cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", circuitBreakerSettings)
ctx = cb.WithCircuitBreaker(ctx, circuitBreaker)
for range 10 {
region, err := cli.GetRegion(context.TODO(), []byte("a"))
region, err := cli.GetRegion(ctx, []byte("a"))
re.NoError(err)
re.NotNil(region)
}

re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)"))
re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)"))

for range 100 {
_, err := cli.GetRegion(context.TODO(), []byte("a"))
_, err := cli.GetRegion(ctx, []byte("a"))
re.Error(err)
}

_, err = cli.GetRegion(context.TODO(), []byte("a"))
_, err = cli.GetRegion(ctx, []byte("a"))
re.Error(err)
re.Contains(err.Error(), "circuit breaker is open")

fname := testutil.InitTempFileLogger("info")
defer os.RemoveAll(fname)
// wait for cooldown
time.Sleep(time.Second)
re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker"))
re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker"))
// trigger circuit breaker state to be half open
_, err = cli.GetRegion(context.TODO(), []byte("a"))
_, err = cli.GetRegion(ctx, []byte("a"))
re.NoError(err)
testutil.Eventually(re, func() bool {
b, _ := os.ReadFile(fname)
Expand All @@ -2207,17 +2213,16 @@ func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) {
})

// The state is half open
re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)"))
re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)"))
// change settings to always closed
cli.UpdateOption(opt.RegionMetadataCircuitBreakerSettings, func(config *cb.Settings) {
circuitBreaker.ChangeSettings(func(config *cb.Settings) {
*config = cb.AlwaysClosedSettings
})

// It won't be changed to open state.
for range 100 {
_, err := cli.GetRegion(context.TODO(), []byte("a"))
_, err := cli.GetRegion(ctx, []byte("a"))
re.Error(err)
re.NotContains(err.Error(), "circuit breaker is open")
}
re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker"))
re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker"))
}

0 comments on commit f1aeab1

Please sign in to comment.