diff --git a/internal/extension/extension.go b/internal/extension/extension.go index 3f7179c..4670400 100644 --- a/internal/extension/extension.go +++ b/internal/extension/extension.go @@ -164,8 +164,17 @@ func (em *ExtensionManager) SendEndInvocationRequest(ctx context.Context, functi req.Header.Set(string(DdSamplingPriority), samplingPriority) } } else { - req.Header.Set(string(DdTraceId), fmt.Sprint(functionExecutionSpan.Context().TraceID())) - req.Header.Set(string(DdSpanId), fmt.Sprint(functionExecutionSpan.Context().SpanID())) + spanContext := functionExecutionSpan.Context() + req.Header.Set(string(DdTraceId), fmt.Sprint(spanContext.TraceID())) + req.Header.Set(string(DdSpanId), fmt.Sprint(spanContext.SpanID())) + + // Try to get sampling priority + // Check if the context implements SamplingPriority method + if pc, ok := spanContext.(interface{ SamplingPriority() (int, bool) }); ok && pc != nil { + if priority, ok := pc.SamplingPriority(); ok { + req.Header.Set(string(DdSamplingPriority), fmt.Sprint(priority)) + } + } } resp, err := em.httpClient.Do(req) diff --git a/internal/extension/extension_test.go b/internal/extension/extension_test.go index b4da119..0ad93fe 100644 --- a/internal/extension/extension_test.go +++ b/internal/extension/extension_test.go @@ -223,6 +223,34 @@ func TestExtensionEndInvocationError(t *testing.T) { assert.Contains(t, logOutput, "could not send end invocation payload to the extension") } +type mockSpanContext struct { + ddtrace.SpanContext +} + +func (m mockSpanContext) TraceID() uint64 { return 123 } +func (m mockSpanContext) SpanID() uint64 { return 456 } +func (m mockSpanContext) SamplingPriority() (int, bool) { return -1, true } + +type mockSpan struct{ ddtrace.Span } + +func (m mockSpan) Context() ddtrace.SpanContext { return mockSpanContext{} } + +func TestExtensionEndInvocationSamplingPriority(t *testing.T) { + headers := http.Header{} + em := &ExtensionManager{httpClient: capturingClient{hdr: headers}} + span := &mockSpan{} + + // When priority in context, use that value + ctx := context.WithValue(context.Background(), DdTraceId, "123") + ctx = context.WithValue(ctx, DdSamplingPriority, "2") + em.SendEndInvocationRequest(ctx, span, ddtrace.FinishConfig{}) + assert.Equal(t, "2", headers.Get("X-Datadog-Sampling-Priority")) + + // When no context, get priority from span + em.SendEndInvocationRequest(context.Background(), span, ddtrace.FinishConfig{}) + assert.Equal(t, "-1", headers.Get("X-Datadog-Sampling-Priority")) +} + type capturingClient struct { hdr http.Header }