Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor zrpc timeout #3671

Merged
merged 2 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions zrpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func SetClientSlowThreshold(threshold time.Duration) {
clientinterceptors.SetSlowThreshold(threshold)
}

// WithTimeoutCallOption return a call option with given timeout.
func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
return clientinterceptors.WithTimeoutCallOption(timeout)
// WithCallTimeout return a call option with given timeout to make a method call.
func WithCallTimeout(timeout time.Duration) grpc.CallOption {
return clientinterceptors.WithCallTimeout(timeout)
}
28 changes: 14 additions & 14 deletions zrpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ func dialer() func(context.Context, string) (net.Conn, error) {

func TestDepositServer_Deposit(t *testing.T) {
tests := []struct {
name string
amount float32
timeoutCallOption time.Duration
res *mock.DepositResponse
errCode codes.Code
errMsg string
name string
amount float32
timeout time.Duration
res *mock.DepositResponse
errCode codes.Code
errMsg string
}{
{
name: "invalid request with negative amount",
Expand All @@ -66,12 +66,12 @@ func TestDepositServer_Deposit(t *testing.T) {
errMsg: "context deadline exceeded",
},
{
name: "valid request with timeout call option",
amount: 2000.00,
timeoutCallOption: time.Second * 3,
res: &mock.DepositResponse{Ok: true},
errCode: codes.OK,
errMsg: "",
name: "valid request with timeout call option",
amount: 2000.00,
timeout: time.Second * 3,
res: &mock.DepositResponse{Ok: true},
errCode: codes.OK,
errMsg: "",
},
}

Expand Down Expand Up @@ -171,8 +171,8 @@ func TestDepositServer_Deposit(t *testing.T) {
err error
)

if tt.timeoutCallOption > 0 {
response, err = cli.Deposit(ctx, request, WithTimeoutCallOption(tt.timeoutCallOption))
if tt.timeout > 0 {
response, err = cli.Deposit(ctx, request, WithCallTimeout(tt.timeout))
} else {
response, err = cli.Deposit(ctx, request)
}
Expand Down
6 changes: 3 additions & 3 deletions zrpc/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ type (
ServerMiddlewaresConf = internal.ServerMiddlewaresConf
// StatConf defines the stat config.
StatConf = internal.StatConf
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
ServerSpecifiedTimeoutConf = internal.ServerSpecifiedTimeoutConf
// MethodTimeoutConf defines specified timeout for gRPC method.
MethodTimeoutConf = internal.MethodTimeoutConf

// A RpcClientConf is a rpc client config.
RpcClientConf struct {
Expand Down Expand Up @@ -48,7 +48,7 @@ type (
Health bool `json:",default=true"`
Middlewares ServerMiddlewaresConf
// setting specified timeout for gRPC method
SpecifiedTimeouts []ServerSpecifiedTimeoutConf `json:",optional"`
MethodTimeouts []MethodTimeoutConf `json:",optional"`
}
)

Expand Down
33 changes: 17 additions & 16 deletions zrpc/internal/clientinterceptors/timeoutinterceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@ import (
"google.golang.org/grpc"
)

// TimeoutCallOption is a call option that controls timeout.
type TimeoutCallOption struct {
grpc.EmptyCallOption
timeout time.Duration
}

// TimeoutInterceptor is an interceptor that controls timeout.
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
t := getTimeoutByCallOptions(opts, timeout)
t := getTimeoutFromCallOptions(opts, timeout)
if t <= 0 {
return invoker(ctx, method, req, reply, cc, opts...)
}
Expand All @@ -23,24 +29,19 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
}
}

func getTimeoutByCallOptions(callOptions []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
for _, callOption := range callOptions {
if o, ok := callOption.(TimeoutCallOption); ok {
// WithCallTimeout returns a call option that controls method call timeout.
func WithCallTimeout(timeout time.Duration) grpc.CallOption {
return TimeoutCallOption{
timeout: timeout,
}
}

func getTimeoutFromCallOptions(opts []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
for _, opt := range opts {
if o, ok := opt.(TimeoutCallOption); ok {
return o.timeout
}
}

return defaultTimeout
}

type TimeoutCallOption struct {
grpc.EmptyCallOption

timeout time.Duration
}

func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
return TimeoutCallOption{
timeout: timeout,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func TestTimeoutInterceptor_TimeoutCallOption(t *testing.T) {
cc := new(grpc.ClientConn)
var co []grpc.CallOption
if tt.args.callOptionTimeout > 0 {
co = append(co, WithTimeoutCallOption(tt.args.callOptionTimeout))
co = append(co, WithCallTimeout(tt.args.callOptionTimeout))
}

err := interceptor(context.Background(), "/foo", nil, nil, cc,
Expand Down
3 changes: 2 additions & 1 deletion zrpc/internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ type (
Breaker bool `json:",default=true"`
}

ServerSpecifiedTimeoutConf = serverinterceptors.ServerSpecifiedTimeoutConf
// MethodTimeoutConf defines specified timeout for gRPC methods.
MethodTimeoutConf = serverinterceptors.MethodTimeoutConf
)
34 changes: 15 additions & 19 deletions zrpc/internal/serverinterceptors/timeoutinterceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,22 @@ import (
)

type (
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
ServerSpecifiedTimeoutConf struct {
// MethodTimeoutConf defines specified timeout for gRPC method.
MethodTimeoutConf struct {
FullMethod string
Timeout time.Duration
}

specifiedTimeoutCache map[string]time.Duration
methodTimeouts map[string]time.Duration
)

// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor {
cache := cacheSpecifiedTimeout(specifiedTimeouts)
func UnaryTimeoutInterceptor(timeout time.Duration,
methodTimeouts ...MethodTimeoutConf) grpc.UnaryServerInterceptor {
timeouts := buildMethodTimeouts(methodTimeouts)
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (any, error) {
t := getTimeoutByUnaryServerInfo(info, timeout, cache)
t := getTimeoutByUnaryServerInfo(info.FullMethod, timeouts, timeout)
ctx, cancel := context.WithTimeout(ctx, t)
defer cancel()

Expand Down Expand Up @@ -72,27 +73,22 @@ func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerS
}
}

func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache {
cache := make(specifiedTimeoutCache, len(specifiedTimeouts))
for _, st := range specifiedTimeouts {
func buildMethodTimeouts(timeouts []MethodTimeoutConf) methodTimeouts {
mt := make(methodTimeouts, len(timeouts))
for _, st := range timeouts {
if st.FullMethod != "" {
cache[st.FullMethod] = st.Timeout
mt[st.FullMethod] = st.Timeout
}
}

return cache
return mt
}

func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration {
if ts, ok := info.Server.(TimeoutStrategy); ok {
return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout)
} else if v, ok := specifiedTimeout[info.FullMethod]; ok {
func getTimeoutByUnaryServerInfo(method string, timeouts methodTimeouts,
defaultTimeout time.Duration) time.Duration {
if v, ok := timeouts[method]; ok {
return v
}

return defaultTimeout
}

type TimeoutStrategy interface {
GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration
}
22 changes: 2 additions & 20 deletions zrpc/internal/serverinterceptors/timeoutinterceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,6 @@ type tempServer struct {
func (s *tempServer) run(duration time.Duration) {
time.Sleep(duration)
}
func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration {
if fullMethod == "/" {
return defaultTimeout
}

return s.timeout
}

func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
type args struct {
Expand All @@ -136,17 +129,6 @@ func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
},
wantErr: nil,
},
{
name: "do not timeout with timeout strategy",
args: args{
interceptorTimeout: time.Second,
contextTimeout: time.Second * 5,
serverTimeout: time.Second * 3,
runTime: time.Second * 2,
fullMethod: "/2s",
},
wantErr: nil,
},
{
name: "timeout with interceptor timeout",
args: args{
Expand Down Expand Up @@ -235,9 +217,9 @@ func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

var specifiedTimeouts []ServerSpecifiedTimeoutConf
var specifiedTimeouts []MethodTimeoutConf
if tt.args.methodTimeout > 0 {
specifiedTimeouts = []ServerSpecifiedTimeoutConf{
specifiedTimeouts = []MethodTimeoutConf{
{
FullMethod: tt.args.method,
Timeout: tt.args.methodTimeout,
Expand Down
8 changes: 2 additions & 6 deletions zrpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,8 @@ func setupInterceptors(svr internal.Server, c RpcServerConf, metrics *stat.Metri
}

if c.Timeout > 0 {
svr.AddUnaryInterceptors(
serverinterceptors.UnaryTimeoutInterceptor(
time.Duration(c.Timeout)*time.Millisecond,
c.SpecifiedTimeouts...,
),
)
svr.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor(
time.Duration(c.Timeout)*time.Millisecond, c.MethodTimeouts...))
}

if c.Auth {
Expand Down
8 changes: 4 additions & 4 deletions zrpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestServer_setupInterceptors(t *testing.T) {
Prometheus: true,
Breaker: true,
},
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{
MethodTimeouts: []MethodTimeoutConf{
{
FullMethod: "/foo",
Timeout: 5 * time.Second,
Expand Down Expand Up @@ -81,7 +81,7 @@ func TestServer(t *testing.T) {
Prometheus: true,
Breaker: true,
},
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{
MethodTimeouts: []MethodTimeoutConf{
{
FullMethod: "/foo",
Timeout: time.Second,
Expand Down Expand Up @@ -117,7 +117,7 @@ func TestServerError(t *testing.T) {
Prometheus: true,
Breaker: true,
},
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{},
MethodTimeouts: []MethodTimeoutConf{},
}, func(server *grpc.Server) {
})
assert.NotNil(t, err)
Expand All @@ -144,7 +144,7 @@ func TestServer_HasEtcd(t *testing.T) {
Prometheus: true,
Breaker: true,
},
SpecifiedTimeouts: []ServerSpecifiedTimeoutConf{},
MethodTimeouts: []MethodTimeoutConf{},
}, func(server *grpc.Server) {
})
svr.AddOptions(grpc.ConnectionTimeout(time.Hour))
Expand Down