Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

opentelemetry-instrumentation-aws-lambda: Adding option to disable context propagation #1466

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#685](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/685))
- Add metric instrumentation for tornado
([#1252](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1252))
- `opentelemetry-instrumentation-aws-lambda` Add option to disable aws context propagation
([#1466](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1466))

### Added

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ def _default_event_context_extractor(lambda_event: Any) -> Context:


def _determine_parent_context(
lambda_event: Any, event_context_extractor: Callable[[Any], Context]
lambda_event: Any,
event_context_extractor: Callable[[Any], Context],
disable_aws_context_propagation: bool = False,
) -> Context:
"""Determine the parent context for the current Lambda invocation.
Expand All @@ -144,17 +146,25 @@ def _determine_parent_context(
Args:
lambda_event: user-defined, so it could be anything, but this
method counts it being a map with a 'headers' key
event_context_extractor: a method which takes the Lambda
Event as input and extracts an OTel Context from it. By default,
the context is extracted from the HTTP headers of an API Gateway
request.
disable_aws_context_propagation: By default, this instrumentation
will try to read the context from the `_X_AMZN_TRACE_ID` environment
variable set by Lambda, set this to `True` to disable this behavior.
Returns:
A Context with configuration found in the carrier.
"""
parent_context = None

xray_env_var = os.environ.get(_X_AMZN_TRACE_ID)
if not disable_aws_context_propagation:
xray_env_var = os.environ.get(_X_AMZN_TRACE_ID)

if xray_env_var:
parent_context = AwsXRayPropagator().extract(
{TRACE_HEADER_KEY: xray_env_var}
)
if xray_env_var:
parent_context = AwsXRayPropagator().extract(
{TRACE_HEADER_KEY: xray_env_var}
)

if (
parent_context
Expand Down Expand Up @@ -258,6 +268,7 @@ def _instrument(
flush_timeout,
event_context_extractor: Callable[[Any], Context],
tracer_provider: TracerProvider = None,
disable_aws_context_propagation: bool = False,
):
def _instrumented_lambda_handler_call(
call_wrapped, instance, args, kwargs
Expand All @@ -269,7 +280,9 @@ def _instrumented_lambda_handler_call(
lambda_event = args[0]

parent_context = _determine_parent_context(
lambda_event, event_context_extractor
lambda_event,
event_context_extractor,
disable_aws_context_propagation,
)

span_kind = None
Expand Down Expand Up @@ -368,6 +381,9 @@ def _instrument(self, **kwargs):
Event as input and extracts an OTel Context from it. By default,
the context is extracted from the HTTP headers of an API Gateway
request.
``disable_aws_context_propagation``: By default, this instrumentation
will try to read the context from the `_X_AMZN_TRACE_ID` environment
variable set by Lambda, set this to `True` to disable this behavior.
"""
lambda_handler = os.environ.get(ORIG_HANDLER, os.environ.get(_HANDLER))
# pylint: disable=attribute-defined-outside-init
Expand All @@ -377,11 +393,12 @@ def _instrument(self, **kwargs):
) = lambda_handler.rsplit(".", 1)

flush_timeout_env = os.environ.get(
OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT, ""
OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT, None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't .get without a default None already anyways?

i.e. shouldn't this just be:

os.environ.get(
            OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT
)

)
flush_timeout = 30000
try:
flush_timeout = int(flush_timeout_env)
if flush_timeout_env is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not if flush_timeout_env? Maybe you want to guard against the "" environment variable. But perhaps users would prefer a warning log if they set this variable (which should be an int) to an empty string?

flush_timeout = int(flush_timeout_env)
except ValueError:
logger.warning(
"Could not convert OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT value %s to int",
Expand All @@ -396,6 +413,9 @@ def _instrument(self, **kwargs):
"event_context_extractor", _default_event_context_extractor
),
tracer_provider=kwargs.get("tracer_provider"),
disable_aws_context_propagation=kwargs.get(
"disable_aws_context_propagation", False
),
)

def _uninstrument(self, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from importlib import import_module
from typing import Any, Callable, Dict
from unittest import mock

from mocks.api_gateway_http_api_event import (
Expand Down Expand Up @@ -155,103 +157,129 @@ def test_active_tracing(self):
test_env_patch.stop()

def test_parent_context_from_lambda_event(self):
test_env_patch = mock.patch.dict(
"os.environ",
{
**os.environ,
# NOT Active Tracing
_X_AMZN_TRACE_ID: MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
# NOT using the X-Ray Propagator
OTEL_PROPAGATORS: "tracecontext",
},
)
test_env_patch.start()

AwsLambdaInstrumentor().instrument()

mock_execute_lambda(
{
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
}
)

spans = self.memory_exporter.get_finished_spans()
@dataclass
class TestCase:
name: str
custom_extractor: Callable[[Any], None]
context: Dict
expected_traceid: int
expected_parentid: int
xray_traceid: str
expected_state_value: str = None
expected_trace_state_len: int = 0
disable_aws_context_propagation: bool = False

assert spans

self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(span.get_span_context().trace_id, MOCK_W3C_TRACE_ID)

parent_context = span.parent
self.assertEqual(
parent_context.trace_id, span.get_span_context().trace_id
)
self.assertEqual(parent_context.span_id, MOCK_W3C_PARENT_SPAN_ID)
self.assertEqual(len(parent_context.trace_state), 3)
self.assertEqual(
parent_context.trace_state.get(MOCK_W3C_TRACE_STATE_KEY),
MOCK_W3C_TRACE_STATE_VALUE,
)
self.assertTrue(parent_context.is_remote)

test_env_patch.stop()

def test_using_custom_extractor(self):
def custom_event_context_extractor(lambda_event):
return get_global_textmap().extract(lambda_event["foo"]["headers"])

test_env_patch = mock.patch.dict(
"os.environ",
{
**os.environ,
# NOT Active Tracing
_X_AMZN_TRACE_ID: MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
# NOT using the X-Ray Propagator
OTEL_PROPAGATORS: "tracecontext",
},
)
test_env_patch.start()

AwsLambdaInstrumentor().instrument(
event_context_extractor=custom_event_context_extractor,
)

mock_execute_lambda(
{
"foo": {
tests = [
TestCase(
name="no_custom_extractor",
custom_extractor=None,
context={
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
}
}
)

spans = self.memory_exporter.get_finished_spans()

assert spans

self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(span.get_span_context().trace_id, MOCK_W3C_TRACE_ID)

parent_context = span.parent
self.assertEqual(
parent_context.trace_id, span.get_span_context().trace_id
)
self.assertEqual(parent_context.span_id, MOCK_W3C_PARENT_SPAN_ID)
self.assertEqual(len(parent_context.trace_state), 3)
self.assertEqual(
parent_context.trace_state.get(MOCK_W3C_TRACE_STATE_KEY),
MOCK_W3C_TRACE_STATE_VALUE,
)
self.assertTrue(parent_context.is_remote)

test_env_patch.stop()
},
expected_traceid=MOCK_W3C_TRACE_ID,
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
expected_trace_state_len=3,
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
),
TestCase(
name="custom_extractor_not_sampled_xray",
custom_extractor=custom_event_context_extractor,
context={
"foo": {
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
}
},
expected_traceid=MOCK_W3C_TRACE_ID,
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
expected_trace_state_len=3,
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
),
TestCase(
name="custom_extractor_sampled_xray",
custom_extractor=custom_event_context_extractor,
context={
"foo": {
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
}
},
expected_traceid=MOCK_XRAY_TRACE_ID,
expected_parentid=MOCK_XRAY_PARENT_SPAN_ID,
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
),
TestCase(
name="custom_extractor_sampled_xray_disable_aws_propagation",
custom_extractor=custom_event_context_extractor,
context={
"foo": {
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
}
},
disable_aws_context_propagation=True,
expected_traceid=MOCK_W3C_TRACE_ID,
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
expected_trace_state_len=3,
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
),
]
for test in tests:
test_env_patch = mock.patch.dict(
"os.environ",
{
**os.environ,
# NOT Active Tracing
_X_AMZN_TRACE_ID: test.xray_traceid,
# NOT using the X-Ray Propagator
OTEL_PROPAGATORS: "tracecontext",
},
)
test_env_patch.start()
AwsLambdaInstrumentor().instrument(
event_context_extractor=test.custom_extractor,
disable_aws_context_propagation=test.disable_aws_context_propagation,
)
mock_execute_lambda(test.context)
spans = self.memory_exporter.get_finished_spans()
assert spans
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(
span.get_span_context().trace_id, test.expected_traceid
)

parent_context = span.parent
self.assertEqual(
parent_context.trace_id, span.get_span_context().trace_id
)
self.assertEqual(parent_context.span_id, test.expected_parentid)
self.assertEqual(
len(parent_context.trace_state), test.expected_trace_state_len
)
self.assertEqual(
parent_context.trace_state.get(MOCK_W3C_TRACE_STATE_KEY),
test.expected_state_value,
)
self.assertTrue(parent_context.is_remote)
self.memory_exporter.clear()
AwsLambdaInstrumentor().uninstrument()
test_env_patch.stop()

def test_lambda_no_error_with_invalid_flush_timeout(self):

Expand Down