Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public object OnDelegateBegin<TArg1>(object sender, ref TArg1 arg)
LambdaCommon.Log("DelegateWrapper Running OnDelegateBegin");

Scope scope;
object requestid = null;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[naming conventions]
We should use requestId instead of requestid (this is the nit-pickiest comment in this PR, I promise)

[type safety]
requestid can be string since ILambdaContext.AwsRequestId is declared as string. We want to avoid using object as much as possible.

var proxyInstance = arg.DuckCast<IInvocationRequest>();
if (proxyInstance == null)
{
Expand All @@ -80,11 +81,12 @@ public object OnDelegateBegin<TArg1>(object sender, ref TArg1 arg)
else
{
var jsonString = ConvertPayloadStream(proxyInstance.InputStream);
scope = LambdaCommon.SendStartInvocation(new LambdaRequestBuilder(), jsonString, proxyInstance.LambdaContext?.ClientContext?.Custom);
scope = LambdaCommon.SendStartInvocation(new LambdaRequestBuilder(), jsonString, proxyInstance.LambdaContext);
requestid = proxyInstance.LambdaContext?.AwsRequestId;
}

LambdaCommon.Log("DelegateWrapper FINISHED Running OnDelegateBegin");
return new CallTargetState(scope);
return new CallTargetState(scope, requestid);
}

public void OnException(object sender, Exception ex)
Expand All @@ -108,18 +110,18 @@ public async Task<TInnerReturn> OnDelegateEndAsync<TInnerReturn>(object sender,
if (proxyInstance == null)
{
LambdaCommon.Log("DuckCast.IInvocationResponse got null proxyInstance", debug: false);
await LambdaCommon.EndInvocationAsync(string.Empty, exception, ((CallTargetState)state!).Scope, RequestBuilder).ConfigureAwait(false);
await LambdaCommon.EndInvocationAsync(string.Empty, exception, state, RequestBuilder).ConfigureAwait(false);
}
else
{
var jsonString = ConvertPayloadStream(proxyInstance.OutputStream);
await LambdaCommon.EndInvocationAsync(jsonString, exception, ((CallTargetState)state!).Scope, RequestBuilder).ConfigureAwait(false);
await LambdaCommon.EndInvocationAsync(jsonString, exception, state, RequestBuilder).ConfigureAwait(false);
}
}
catch (Exception ex)
{
LambdaCommon.Log("OnDelegateEndAsync could not send payload to the extension", ex, false);
await LambdaCommon.EndInvocationAsync(string.Empty, ex, ((CallTargetState)state!).Scope, RequestBuilder).ConfigureAwait(false);
await LambdaCommon.EndInvocationAsync(string.Empty, ex, state, RequestBuilder).ConfigureAwait(false);
}

LambdaCommon.Log("DelegateWrapper FINISHED Running OnDelegateEndAsync");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ internal interface ILambdaExtensionRequest
/// Get the end invocation request
/// </summary>
/// <returns>The end invocation request</returns>
WebRequest GetEndInvocationRequest(Scope scope, bool isError);
WebRequest GetEndInvocationRequest(Scope scope, object state, bool isError);
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using System.Net;
using System.Text;
using System.Threading.Tasks;
using Datadog.Trace.ClrProfiler.AutoInstrumentation.AWS.SDK;
using Datadog.Trace.ClrProfiler.CallTarget;
using Datadog.Trace.Headers;
using Datadog.Trace.Propagators;
using Datadog.Trace.Telemetry;
Expand All @@ -22,6 +24,7 @@ internal abstract class LambdaCommon
private const string PlaceholderOperationName = "placeholder-operation";
private const double ServerlessMaxWaitingFlushTime = 3;
private const string LogLevelEnvName = "DD_LOG_LEVEL";
private const string LambdaRuntimeAwsRequestIdHeader = "lambda-runtime-aws-request-id";

internal static Scope CreatePlaceholderScope(Tracer tracer, NameValueHeadersCollection headers)
{
Expand All @@ -38,11 +41,16 @@ internal static Scope CreatePlaceholderScope(Tracer tracer, NameValueHeadersColl
return tracer.TracerManager.ScopeManager.Activate(span, false);
}

internal static Scope SendStartInvocation(ILambdaExtensionRequest requestBuilder, string data, IDictionary<string, string> context)
internal static Scope SendStartInvocation(ILambdaExtensionRequest requestBuilder, string data, ILambdaContext context)
{
var request = requestBuilder.GetStartInvocationRequest();
WriteRequestPayload(request, data);
WriteRequestHeaders(request, context);
WriteRequestHeaders(request, context?.ClientContext?.Custom);
Copy link
Member

@lucaspimentel lucaspimentel Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If either context or context.ClientContext or context.ClientContext.Custom are null, we will effectively call WriteRequestHeaders(request, null).

Does WriteRequestHeaders() handle null gracefully? Or should we check for null before calling?

if (context?.ClientContext?.Custom is { } c)
{
    WriteRequestHeaders(request, c);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WriteRequestHeaders starts by checking if context is null (if it is null it does nothing). This was how it was structured prior to my changes, so I am choosing to leave it that way.

if (context?.AwsRequestId != null)
{
request.Headers.Add(LambdaRuntimeAwsRequestIdHeader, context.AwsRequestId);
}

using var response = (HttpWebResponse)request.GetResponse();

var headers = response.Headers.Wrap();
Expand All @@ -55,9 +63,9 @@ internal static Scope SendStartInvocation(ILambdaExtensionRequest requestBuilder
return CreatePlaceholderScope(tracer, headers);
}

internal static void SendEndInvocation(ILambdaExtensionRequest requestBuilder, Scope scope, bool isError, string data)
internal static void SendEndInvocation(ILambdaExtensionRequest requestBuilder, Scope scope, object state, bool isError, string data)
Copy link
Member

@lucaspimentel lucaspimentel Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general we should try to avoid using the object type when passing arguments around. C# developers like the strong type safety that the language provides and treating anything as object disable that language feature.

To avoid that, we should cast the state to CallTargetState as early as possible and pass it around as a CallTargetState, not as object.

{
var request = requestBuilder.GetEndInvocationRequest(scope, isError);
var request = requestBuilder.GetEndInvocationRequest(scope, state, isError);
WriteRequestPayload(request, data);
using var response = (HttpWebResponse)request.GetResponse();

Expand All @@ -67,8 +75,10 @@ internal static void SendEndInvocation(ILambdaExtensionRequest requestBuilder, S
}
}

internal static async Task EndInvocationAsync(string returnValue, Exception exception, Scope scope, ILambdaExtensionRequest requestBuilder)
internal static async Task EndInvocationAsync(string returnValue, Exception exception, object stateObject, ILambdaExtensionRequest requestBuilder)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here about the using object as parameter type.

{
var state = (CallTargetState)stateObject!;
Copy link
Member

@lucaspimentel lucaspimentel Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a dangerous cast that will throw an exception if stateObject is null or otherwise not a CallTargetState. Since stateObject is declared as object, it could be anything, so we should always be defensive and check first:

if (state is CallTargetState)
{
    ...

Even better would be to use CallTargetState as the parameter type instead of object. Then the type is guaranteed and we won't need to either check or cast here.

var scope = state.Scope;
try
{
await Task.WhenAll(
Expand All @@ -90,7 +100,7 @@ await Task.WhenAll(
span.SetException(exception);
}

SendEndInvocation(requestBuilder, scope, exception != null, returnValue);
SendEndInvocation(requestBuilder, scope, state.State, exception != null, returnValue);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ WebRequest ILambdaExtensionRequest.GetStartInvocationRequest()
return request;
}

WebRequest ILambdaExtensionRequest.GetEndInvocationRequest(Scope scope, bool isError)
WebRequest ILambdaExtensionRequest.GetEndInvocationRequest(Scope scope, object state, bool isError)
Copy link
Member

@lucaspimentel lucaspimentel Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's a good example of why not to not use object in method signature if we can avoid it: in some methods state is a CallTargetState, but here it's a string. It's easy to get them confused and try to cast to the wrong type, which will throw an exception at run time.

Or if someone mistakenly checks if (state is CallTargetState) before casting, it will always be false (because this one is a string) and we'll never hit the code path below that adds the request header.

{
var request = WebRequest.Create(Uri + EndInvocationPath);
request.Method = "POST";
request.Headers.Set(HttpHeaderNames.TracingEnabled, "false");

if (state != null)
{
request.Headers.Set("lambda-runtime-aws-request-id", (string)state);
}
Comment on lines +44 to +47
Copy link
Member

@lucaspimentel lucaspimentel Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although it should never happen, if state is ever a different type from string, this casting operation will throw an exception. It would be safer to check that state is a string, not simply not null

The more modern way of doing this in C# is by combining the type check with a string variable declaration and removing the (string) casting operator:

if (state is string requestId)
{
    request.Headers.Set("lambda-runtime-aws-request-id", requestId);
}


if (scope is { Span: var span })
{
// TODO: add support for 128-bit trace ids in serverless
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ namespace Datadog.Trace.ClrProfiler.AutoInstrumentation.AWS.SDK;
/// </summary>
internal interface ILambdaContext
{
// /// <summary>
// /// Gets the AWS request ID associated with the request.
// /// This is the same ID returned to the client that called invoke().
// /// This ID is reused for retries on the same request.
// /// </summary>
// string AwsRequestId { get; }
/// <summary>
/// Gets the AWS request ID associated with the request.
/// This is the same ID returned to the client that called invoke().
/// This ID is reused for retries on the same request.
/// </summary>
string AwsRequestId { get; }

/// <summary>
/// Gets information about the client application and device when invoked
Expand Down
21 changes: 12 additions & 9 deletions tracer/test/Datadog.Trace.Tests/LambdaCommonTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public void TestSendStartInvocationThrow()

_lambdaRequestMock.Setup(lr => lr.GetStartInvocationRequest()).Returns(httpRequest.Object);

Assert.Throws<WebException>(() => LambdaCommon.SendStartInvocation(_lambdaRequestMock.Object, "{}", new Dictionary<string, string>()));
Assert.Throws<WebException>(() => LambdaCommon.SendStartInvocation(_lambdaRequestMock.Object, "{}", null));
}

[Fact]
Expand All @@ -110,7 +110,7 @@ public void TestSendStartInvocationNull()

_lambdaRequestMock.Setup(lr => lr.GetStartInvocationRequest()).Returns(httpRequest.Object);

LambdaCommon.SendStartInvocation(_lambdaRequestMock.Object, "{}", new Dictionary<string, string>()).Should().BeNull();
LambdaCommon.SendStartInvocation(_lambdaRequestMock.Object, "{}", null).Should().BeNull();
}

[Fact]
Expand All @@ -129,7 +129,7 @@ public void TestSendStartInvocationSuccess()

_lambdaRequestMock.Setup(lr => lr.GetStartInvocationRequest()).Returns(httpRequest.Object);

LambdaCommon.SendStartInvocation(_lambdaRequestMock.Object, "{}", new Dictionary<string, string>()).Should().NotBeNull();
LambdaCommon.SendStartInvocation(_lambdaRequestMock.Object, "{}", null).Should().NotBeNull();
}

[Fact]
Expand All @@ -139,6 +139,7 @@ public async Task TestSendEndInvocationFailure()
await using var tracer = TracerHelper.CreateWithFakeAgent();
var headers = new WebHeaderCollection { { HttpHeaderNames.TraceId, "1234" }, { HttpHeaderNames.SamplingPriority, "-1" } }.Wrap();
var scope = LambdaCommon.CreatePlaceholderScope(tracer, headers);
var state = "example-aws-request-id";

var response = new Mock<HttpWebResponse>(MockBehavior.Loose);
var responseStream = new Mock<Stream>(MockBehavior.Loose);
Expand All @@ -148,9 +149,9 @@ public async Task TestSendEndInvocationFailure()
httpRequest.Setup(h => h.GetResponse()).Throws(new WebException());
httpRequest.Setup(h => h.GetRequestStream()).Returns(responseStream.Object);

_lambdaRequestMock.Setup(lr => lr.GetEndInvocationRequest(scope, true)).Returns(httpRequest.Object);
_lambdaRequestMock.Setup(lr => lr.GetEndInvocationRequest(scope, state, true)).Returns(httpRequest.Object);

Assert.Throws<WebException>(() => LambdaCommon.SendEndInvocation(_lambdaRequestMock.Object, scope, true, "{}"));
Assert.Throws<WebException>(() => LambdaCommon.SendEndInvocation(_lambdaRequestMock.Object, scope, state, true, "{}"));
}

[Fact]
Expand All @@ -160,6 +161,7 @@ public async Task TestSendEndInvocationSuccess()
await using var tracer = TracerHelper.CreateWithFakeAgent();
var headers = new WebHeaderCollection { { HttpHeaderNames.TraceId, "1234" }, { HttpHeaderNames.SamplingPriority, "-1" } }.Wrap();
var scope = LambdaCommon.CreatePlaceholderScope(tracer, headers);
var state = "example-aws-request-id";

var response = new Mock<HttpWebResponse>(MockBehavior.Loose);
var responseStream = new Mock<Stream>(MockBehavior.Loose);
Expand All @@ -169,10 +171,10 @@ public async Task TestSendEndInvocationSuccess()
httpRequest.Setup(h => h.GetResponse()).Returns(response.Object);
httpRequest.Setup(h => h.GetRequestStream()).Returns(responseStream.Object);

_lambdaRequestMock.Setup(lr => lr.GetEndInvocationRequest(scope, true)).Returns(httpRequest.Object);
_lambdaRequestMock.Setup(lr => lr.GetEndInvocationRequest(scope, state, true)).Returns(httpRequest.Object);
var output = new StringWriter();
Console.SetOut(output);
LambdaCommon.SendEndInvocation(_lambdaRequestMock.Object, scope, true, "{}");
LambdaCommon.SendEndInvocation(_lambdaRequestMock.Object, scope, state, true, "{}");
httpRequest.Verify(r => r.GetResponse(), Times.Once);
Assert.Empty(output.ToString());
}
Expand All @@ -184,6 +186,7 @@ public async Task TestSendEndInvocationFalse()
await using var tracer = TracerHelper.CreateWithFakeAgent();
var headers = new WebHeaderCollection { { HttpHeaderNames.TraceId, "1234" }, { HttpHeaderNames.SamplingPriority, "-1" } }.Wrap();
var scope = LambdaCommon.CreatePlaceholderScope(tracer, headers);
var state = "example-aws-request-id";

var response = new Mock<HttpWebResponse>(MockBehavior.Loose);
var responseStream = new Mock<Stream>(MockBehavior.Loose);
Expand All @@ -193,10 +196,10 @@ public async Task TestSendEndInvocationFalse()
httpRequest.Setup(h => h.GetResponse()).Returns(response.Object);
httpRequest.Setup(h => h.GetRequestStream()).Returns(responseStream.Object);

_lambdaRequestMock.Setup(lr => lr.GetEndInvocationRequest(scope, true)).Returns(httpRequest.Object);
_lambdaRequestMock.Setup(lr => lr.GetEndInvocationRequest(scope, state, true)).Returns(httpRequest.Object);
var output = new StringWriter();
Console.SetOut(output);
LambdaCommon.SendEndInvocation(_lambdaRequestMock.Object, scope, true, "{}");
LambdaCommon.SendEndInvocation(_lambdaRequestMock.Object, scope, state, true, "{}");
httpRequest.Verify(r => r.GetResponse(), Times.Once);
Assert.Contains("Extension does not send a status 200 OK", output.ToString());
}
Expand Down
Loading
Loading