From 555a952e6e80f9dc8ae4df612523e88c9f6dd185 Mon Sep 17 00:00:00 2001 From: Quinn Klassen Date: Sat, 13 Jan 2024 09:04:05 -0800 Subject: [PATCH] Make dataConverterWithoutDeadlock context aware --- internal/workflow_deadlock.go | 4 +-- internal/workflow_deadlock_test.go | 44 ++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/internal/workflow_deadlock.go b/internal/workflow_deadlock.go index 9a517739b..6f3a16d85 100644 --- a/internal/workflow_deadlock.go +++ b/internal/workflow_deadlock.go @@ -211,9 +211,9 @@ func (d *dataConverterWithoutDeadlock) ToStrings(input *commonpb.Payloads) []str } func (d *dataConverterWithoutDeadlock) WithWorkflowContext(ctx Context) converter.DataConverter { - return &dataConverterWithoutDeadlock{context: ctx, underlying: d.underlying} + return &dataConverterWithoutDeadlock{context: ctx, underlying: WithWorkflowContext(ctx, d.underlying)} } func (d *dataConverterWithoutDeadlock) WithContext(ctx context.Context) converter.DataConverter { - return d + return &dataConverterWithoutDeadlock{context: d.context, underlying: WithContext(ctx, d.underlying)} } diff --git a/internal/workflow_deadlock_test.go b/internal/workflow_deadlock_test.go index 1cf1f5663..24fa8d0e3 100644 --- a/internal/workflow_deadlock_test.go +++ b/internal/workflow_deadlock_test.go @@ -96,3 +96,47 @@ func (s *slowToPayloadsConverter) ToPayloads(value ...interface{}) (*commonpb.Pa time.Sleep(600 * time.Millisecond) return s.DataConverter.ToPayloads(value...) } + +func TestDataConverterWithoutDeadlockDetectionContext(t *testing.T) { + contextAwareDataConverter := NewContextAwareDataConverter(converter.GetDefaultDataConverter()) + conv := DataConverterWithoutDeadlockDetection(contextAwareDataConverter) + + t.Parallel() + t.Run("default", func(t *testing.T) { + t.Parallel() + payload, _ := conv.ToPayload("test") + result := conv.ToString(payload) + + require.Equal(t, `"test"`, result) + }) + t.Run("implements ContextAware", func(t *testing.T) { + t.Parallel() + _, ok := conv.(ContextAware) + require.True(t, ok) + }) + t.Run("with activity context", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + ctx = context.WithValue(ctx, ContextAwareDataConverterContextKey, "e") + + dc := WithContext(ctx, conv) + + payload, _ := dc.ToPayload("test") + result := dc.ToString(payload) + + require.Equal(t, `"t?st"`, result) + }) + t.Run("with workflow context", func(t *testing.T) { + t.Parallel() + ctx := Background() + ctx = WithValue(ctx, ContextAwareDataConverterContextKey, "e") + + dc := WithWorkflowContext(ctx, conv) + + payload, _ := dc.ToPayload("test") + result := dc.ToString(payload) + + require.Equal(t, `"t?st"`, result) + }) + +}