From e198ca3d2490334bad21c24b96bee331f2b10a84 Mon Sep 17 00:00:00 2001 From: Shijie Sheng Date: Thu, 7 Mar 2024 14:07:56 -0800 Subject: [PATCH] [history] fix generated timeout wrapper (#5737) What changed? #5728 code generated history client should exclude some endpoints. Why? Some replication endpoints don't have timeout before How did you test it? unit test --- client/history/interface.go | 2 +- client/templates/timeout.tmpl | 11 ++-- client/wrappers/timeout/history_generated.go | 64 ------------------- .../timeout/history_generated_test.go | 25 +++++++- 4 files changed, 31 insertions(+), 71 deletions(-) diff --git a/client/history/interface.go b/client/history/interface.go index f1329059222..7e5a6a13345 100644 --- a/client/history/interface.go +++ b/client/history/interface.go @@ -34,7 +34,7 @@ import ( //go:generate gowrap gen -g -p . -i Client -t ../templates/errorinjectors.tmpl -o ../wrappers/errorinjectors/history_generated.go -v client=History //go:generate gowrap gen -g -p . -i Client -t ../templates/grpc.tmpl -o ../wrappers/grpc/history_generated.go -v client=History -v package=historyv1 -v path=github.com/uber/cadence/.gen/proto/history/v1 -v prefix=History //go:generate gowrap gen -g -p . -i Client -t ../templates/thrift.tmpl -o ../wrappers/thrift/history_generated.go -v client=History -v prefix=History -//go:generate gowrap gen -g -p . -i Client -t ../templates/timeout.tmpl -o ../wrappers/timeout/history_generated.go -v client=History +//go:generate gowrap gen -g -p . -i Client -t ../templates/timeout.tmpl -o ../wrappers/timeout/history_generated.go -v client=History -v exclude=GetReplicationMessages|GetDLQReplicationMessages|CountDLQMessages|ReadDLQMessages|PurgeDLQMessages|MergeDLQMessages|GetCrossClusterTasks|GetFailoverInfo // Client is the interface exposed by history service client type Client interface { diff --git a/client/templates/timeout.tmpl b/client/templates/timeout.tmpl index a9ba1384294..95e58a7b17b 100644 --- a/client/templates/timeout.tmpl +++ b/client/templates/timeout.tmpl @@ -3,6 +3,7 @@ import ( "time" ) +{{$exclude := splitList "|" (index .Vars "exclude")}} {{$clientName := (index .Vars "client")}} {{ $decorator := (printf "%s%s" (down $clientName) .Interface.Name) }} {{ $Decorator := (printf "%s%s" $clientName .Interface.Name) }} @@ -22,8 +23,8 @@ func New{{$Decorator}}(client {{.Interface.Type}}, timeout time.Duration) {{.Int } {{range $method := .Interface.Methods}} - {{if $method.AcceptsContext }} - func (c *{{$decorator}}) {{$method.Declaration}} { + func (c *{{$decorator}}) {{$method.Declaration}} { + {{- if and $method.AcceptsContext (not (has $method.Name $exclude)) }} if ctx == nil { ctx = context.Background() } @@ -32,7 +33,7 @@ func New{{$Decorator}}(client {{.Interface.Type}}, timeout time.Duration) {{.Int ctx, cancelFunc = context.WithTimeout(ctx, c.timeout) defer cancelFunc() } - {{$method.Pass ("c.client.") }} - } - {{end}} + {{- end}} + {{$method.Pass ("c.client.") }} + } {{end}} diff --git a/client/wrappers/timeout/history_generated.go b/client/wrappers/timeout/history_generated.go index 96d85264b8f..ecbe048ca89 100644 --- a/client/wrappers/timeout/history_generated.go +++ b/client/wrappers/timeout/history_generated.go @@ -63,14 +63,6 @@ func (c *historyClient) CloseShard(ctx context.Context, cp1 *types.CloseShardReq } func (c *historyClient) CountDLQMessages(ctx context.Context, cp1 *types.CountDLQMessagesRequest, p1 ...yarpc.CallOption) (hp1 *types.HistoryCountDLQMessagesResponse, err error) { - if ctx == nil { - ctx = context.Background() - } - var cancelFunc func() - if c.timeout > 0 { - ctx, cancelFunc = context.WithTimeout(ctx, c.timeout) - defer cancelFunc() - } return c.client.CountDLQMessages(ctx, cp1, p1...) } @@ -123,38 +115,14 @@ func (c *historyClient) DescribeWorkflowExecution(ctx context.Context, hp1 *type } func (c *historyClient) GetCrossClusterTasks(ctx context.Context, gp1 *types.GetCrossClusterTasksRequest, p1 ...yarpc.CallOption) (gp2 *types.GetCrossClusterTasksResponse, err error) { - if ctx == nil { - ctx = context.Background() - } - var cancelFunc func() - if c.timeout > 0 { - ctx, cancelFunc = context.WithTimeout(ctx, c.timeout) - defer cancelFunc() - } return c.client.GetCrossClusterTasks(ctx, gp1, p1...) } func (c *historyClient) GetDLQReplicationMessages(ctx context.Context, gp1 *types.GetDLQReplicationMessagesRequest, p1 ...yarpc.CallOption) (gp2 *types.GetDLQReplicationMessagesResponse, err error) { - if ctx == nil { - ctx = context.Background() - } - var cancelFunc func() - if c.timeout > 0 { - ctx, cancelFunc = context.WithTimeout(ctx, c.timeout) - defer cancelFunc() - } return c.client.GetDLQReplicationMessages(ctx, gp1, p1...) } func (c *historyClient) GetFailoverInfo(ctx context.Context, gp1 *types.GetFailoverInfoRequest, p1 ...yarpc.CallOption) (gp2 *types.GetFailoverInfoResponse, err error) { - if ctx == nil { - ctx = context.Background() - } - var cancelFunc func() - if c.timeout > 0 { - ctx, cancelFunc = context.WithTimeout(ctx, c.timeout) - defer cancelFunc() - } return c.client.GetFailoverInfo(ctx, gp1, p1...) } @@ -171,26 +139,10 @@ func (c *historyClient) GetMutableState(ctx context.Context, gp1 *types.GetMutab } func (c *historyClient) GetReplicationMessages(ctx context.Context, gp1 *types.GetReplicationMessagesRequest, p1 ...yarpc.CallOption) (gp2 *types.GetReplicationMessagesResponse, err error) { - if ctx == nil { - ctx = context.Background() - } - var cancelFunc func() - if c.timeout > 0 { - ctx, cancelFunc = context.WithTimeout(ctx, c.timeout) - defer cancelFunc() - } return c.client.GetReplicationMessages(ctx, gp1, p1...) } func (c *historyClient) MergeDLQMessages(ctx context.Context, mp1 *types.MergeDLQMessagesRequest, p1 ...yarpc.CallOption) (mp2 *types.MergeDLQMessagesResponse, err error) { - if ctx == nil { - ctx = context.Background() - } - var cancelFunc func() - if c.timeout > 0 { - ctx, cancelFunc = context.WithTimeout(ctx, c.timeout) - defer cancelFunc() - } return c.client.MergeDLQMessages(ctx, mp1, p1...) } @@ -219,14 +171,6 @@ func (c *historyClient) PollMutableState(ctx context.Context, pp1 *types.PollMut } func (c *historyClient) PurgeDLQMessages(ctx context.Context, pp1 *types.PurgeDLQMessagesRequest, p1 ...yarpc.CallOption) (err error) { - if ctx == nil { - ctx = context.Background() - } - var cancelFunc func() - if c.timeout > 0 { - ctx, cancelFunc = context.WithTimeout(ctx, c.timeout) - defer cancelFunc() - } return c.client.PurgeDLQMessages(ctx, pp1, p1...) } @@ -243,14 +187,6 @@ func (c *historyClient) QueryWorkflow(ctx context.Context, hp1 *types.HistoryQue } func (c *historyClient) ReadDLQMessages(ctx context.Context, rp1 *types.ReadDLQMessagesRequest, p1 ...yarpc.CallOption) (rp2 *types.ReadDLQMessagesResponse, err error) { - if ctx == nil { - ctx = context.Background() - } - var cancelFunc func() - if c.timeout > 0 { - ctx, cancelFunc = context.WithTimeout(ctx, c.timeout) - defer cancelFunc() - } return c.client.ReadDLQMessages(ctx, rp1, p1...) } diff --git a/client/wrappers/timeout/history_generated_test.go b/client/wrappers/timeout/history_generated_test.go index 6a589130c82..90554f00e8b 100644 --- a/client/wrappers/timeout/history_generated_test.go +++ b/client/wrappers/timeout/history_generated_test.go @@ -28,6 +28,7 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" "go.uber.org/yarpc" "github.com/uber/cadence/client/history" @@ -53,7 +54,7 @@ func Test_historyClient_CloseShard(t *testing.T) { { name: "nil context success", fields: fields{ - timeout: time.Millisecond * 100, + timeout: time.Millisecond * 150, }, args: args{ ctx: nil, @@ -107,3 +108,25 @@ func Test_historyClient_CloseShard(t *testing.T) { }) } } + +func Test_historyClient_GetReplicationMessages(t *testing.T) { + t.Run("no timeout", func(t *testing.T) { + m := history.NewMockClient(gomock.NewController(t)) + m.EXPECT().GetReplicationMessages(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, r *types.GetReplicationMessagesRequest, opt ...yarpc.CallOption) (*types.GetReplicationMessagesResponse, error) { + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Millisecond * 100): + return &types.GetReplicationMessagesResponse{}, nil + } + } + }) + c := historyClient{ + client: m, + timeout: time.Millisecond * 10, + } + _, err := c.GetReplicationMessages(context.Background(), &types.GetReplicationMessagesRequest{}) + assert.NoError(t, err) + }) +}