diff --git a/CHANGELOG.md b/CHANGELOG.md index 4921f8a632..db85c114f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#685](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/685)) - Add metric instrumentation for tornado ([#1252](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1252)) +- `opentelemetry-instrumentation-aws-lambda` Add option to disable aws context propagation + ([#1466](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1466)) ### Added diff --git a/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py b/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py index 115709bc83..11769c729d 100644 --- a/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py @@ -134,7 +134,9 @@ def _default_event_context_extractor(lambda_event: Any) -> Context: def _determine_parent_context( - lambda_event: Any, event_context_extractor: Callable[[Any], Context] + lambda_event: Any, + event_context_extractor: Callable[[Any], Context], + disable_aws_context_propagation: bool = False, ) -> Context: """Determine the parent context for the current Lambda invocation. @@ -144,17 +146,25 @@ def _determine_parent_context( Args: lambda_event: user-defined, so it could be anything, but this method counts it being a map with a 'headers' key + event_context_extractor: a method which takes the Lambda + Event as input and extracts an OTel Context from it. By default, + the context is extracted from the HTTP headers of an API Gateway + request. + disable_aws_context_propagation: By default, this instrumentation + will try to read the context from the `_X_AMZN_TRACE_ID` environment + variable set by Lambda, set this to `True` to disable this behavior. Returns: A Context with configuration found in the carrier. """ parent_context = None - xray_env_var = os.environ.get(_X_AMZN_TRACE_ID) + if not disable_aws_context_propagation: + xray_env_var = os.environ.get(_X_AMZN_TRACE_ID) - if xray_env_var: - parent_context = AwsXRayPropagator().extract( - {TRACE_HEADER_KEY: xray_env_var} - ) + if xray_env_var: + parent_context = AwsXRayPropagator().extract( + {TRACE_HEADER_KEY: xray_env_var} + ) if ( parent_context @@ -258,6 +268,7 @@ def _instrument( flush_timeout, event_context_extractor: Callable[[Any], Context], tracer_provider: TracerProvider = None, + disable_aws_context_propagation: bool = False, ): def _instrumented_lambda_handler_call( call_wrapped, instance, args, kwargs @@ -269,7 +280,9 @@ def _instrumented_lambda_handler_call( lambda_event = args[0] parent_context = _determine_parent_context( - lambda_event, event_context_extractor + lambda_event, + event_context_extractor, + disable_aws_context_propagation, ) span_kind = None @@ -368,6 +381,9 @@ def _instrument(self, **kwargs): Event as input and extracts an OTel Context from it. By default, the context is extracted from the HTTP headers of an API Gateway request. + ``disable_aws_context_propagation``: By default, this instrumentation + will try to read the context from the `_X_AMZN_TRACE_ID` environment + variable set by Lambda, set this to `True` to disable this behavior. """ lambda_handler = os.environ.get(ORIG_HANDLER, os.environ.get(_HANDLER)) # pylint: disable=attribute-defined-outside-init @@ -377,11 +393,12 @@ def _instrument(self, **kwargs): ) = lambda_handler.rsplit(".", 1) flush_timeout_env = os.environ.get( - OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT, "" + OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT, None ) flush_timeout = 30000 try: - flush_timeout = int(flush_timeout_env) + if flush_timeout_env is not None: + flush_timeout = int(flush_timeout_env) except ValueError: logger.warning( "Could not convert OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT value %s to int", @@ -396,6 +413,9 @@ def _instrument(self, **kwargs): "event_context_extractor", _default_event_context_extractor ), tracer_provider=kwargs.get("tracer_provider"), + disable_aws_context_propagation=kwargs.get( + "disable_aws_context_propagation", False + ), ) def _uninstrument(self, **kwargs): diff --git a/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py b/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py index 496829fe4e..2936f04718 100644 --- a/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py +++ b/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from dataclasses import dataclass from importlib import import_module +from typing import Any, Callable, Dict from unittest import mock from mocks.api_gateway_http_api_event import ( @@ -155,103 +157,129 @@ def test_active_tracing(self): test_env_patch.stop() def test_parent_context_from_lambda_event(self): - test_env_patch = mock.patch.dict( - "os.environ", - { - **os.environ, - # NOT Active Tracing - _X_AMZN_TRACE_ID: MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED, - # NOT using the X-Ray Propagator - OTEL_PROPAGATORS: "tracecontext", - }, - ) - test_env_patch.start() - - AwsLambdaInstrumentor().instrument() - - mock_execute_lambda( - { - "headers": { - TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED, - TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2", - } - } - ) - - spans = self.memory_exporter.get_finished_spans() + @dataclass + class TestCase: + name: str + custom_extractor: Callable[[Any], None] + context: Dict + expected_traceid: int + expected_parentid: int + xray_traceid: str + expected_state_value: str = None + expected_trace_state_len: int = 0 + disable_aws_context_propagation: bool = False - assert spans - - self.assertEqual(len(spans), 1) - span = spans[0] - self.assertEqual(span.get_span_context().trace_id, MOCK_W3C_TRACE_ID) - - parent_context = span.parent - self.assertEqual( - parent_context.trace_id, span.get_span_context().trace_id - ) - self.assertEqual(parent_context.span_id, MOCK_W3C_PARENT_SPAN_ID) - self.assertEqual(len(parent_context.trace_state), 3) - self.assertEqual( - parent_context.trace_state.get(MOCK_W3C_TRACE_STATE_KEY), - MOCK_W3C_TRACE_STATE_VALUE, - ) - self.assertTrue(parent_context.is_remote) - - test_env_patch.stop() - - def test_using_custom_extractor(self): def custom_event_context_extractor(lambda_event): return get_global_textmap().extract(lambda_event["foo"]["headers"]) - test_env_patch = mock.patch.dict( - "os.environ", - { - **os.environ, - # NOT Active Tracing - _X_AMZN_TRACE_ID: MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED, - # NOT using the X-Ray Propagator - OTEL_PROPAGATORS: "tracecontext", - }, - ) - test_env_patch.start() - - AwsLambdaInstrumentor().instrument( - event_context_extractor=custom_event_context_extractor, - ) - - mock_execute_lambda( - { - "foo": { + tests = [ + TestCase( + name="no_custom_extractor", + custom_extractor=None, + context={ "headers": { TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED, TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2", } - } - } - ) - - spans = self.memory_exporter.get_finished_spans() - - assert spans - - self.assertEqual(len(spans), 1) - span = spans[0] - self.assertEqual(span.get_span_context().trace_id, MOCK_W3C_TRACE_ID) - - parent_context = span.parent - self.assertEqual( - parent_context.trace_id, span.get_span_context().trace_id - ) - self.assertEqual(parent_context.span_id, MOCK_W3C_PARENT_SPAN_ID) - self.assertEqual(len(parent_context.trace_state), 3) - self.assertEqual( - parent_context.trace_state.get(MOCK_W3C_TRACE_STATE_KEY), - MOCK_W3C_TRACE_STATE_VALUE, - ) - self.assertTrue(parent_context.is_remote) - - test_env_patch.stop() + }, + expected_traceid=MOCK_W3C_TRACE_ID, + expected_parentid=MOCK_W3C_PARENT_SPAN_ID, + expected_trace_state_len=3, + expected_state_value=MOCK_W3C_TRACE_STATE_VALUE, + xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED, + ), + TestCase( + name="custom_extractor_not_sampled_xray", + custom_extractor=custom_event_context_extractor, + context={ + "foo": { + "headers": { + TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED, + TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2", + } + } + }, + expected_traceid=MOCK_W3C_TRACE_ID, + expected_parentid=MOCK_W3C_PARENT_SPAN_ID, + expected_trace_state_len=3, + expected_state_value=MOCK_W3C_TRACE_STATE_VALUE, + xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED, + ), + TestCase( + name="custom_extractor_sampled_xray", + custom_extractor=custom_event_context_extractor, + context={ + "foo": { + "headers": { + TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED, + TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2", + } + } + }, + expected_traceid=MOCK_XRAY_TRACE_ID, + expected_parentid=MOCK_XRAY_PARENT_SPAN_ID, + xray_traceid=MOCK_XRAY_TRACE_CONTEXT_SAMPLED, + ), + TestCase( + name="custom_extractor_sampled_xray_disable_aws_propagation", + custom_extractor=custom_event_context_extractor, + context={ + "foo": { + "headers": { + TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED, + TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2", + } + } + }, + disable_aws_context_propagation=True, + expected_traceid=MOCK_W3C_TRACE_ID, + expected_parentid=MOCK_W3C_PARENT_SPAN_ID, + expected_trace_state_len=3, + expected_state_value=MOCK_W3C_TRACE_STATE_VALUE, + xray_traceid=MOCK_XRAY_TRACE_CONTEXT_SAMPLED, + ), + ] + for test in tests: + test_env_patch = mock.patch.dict( + "os.environ", + { + **os.environ, + # NOT Active Tracing + _X_AMZN_TRACE_ID: test.xray_traceid, + # NOT using the X-Ray Propagator + OTEL_PROPAGATORS: "tracecontext", + }, + ) + test_env_patch.start() + AwsLambdaInstrumentor().instrument( + event_context_extractor=test.custom_extractor, + disable_aws_context_propagation=test.disable_aws_context_propagation, + ) + mock_execute_lambda(test.context) + spans = self.memory_exporter.get_finished_spans() + assert spans + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertEqual( + span.get_span_context().trace_id, test.expected_traceid + ) + + parent_context = span.parent + self.assertEqual( + parent_context.trace_id, span.get_span_context().trace_id + ) + self.assertEqual(parent_context.span_id, test.expected_parentid) + self.assertEqual( + len(parent_context.trace_state), test.expected_trace_state_len + ) + self.assertEqual( + parent_context.trace_state.get(MOCK_W3C_TRACE_STATE_KEY), + test.expected_state_value, + ) + self.assertTrue(parent_context.is_remote) + self.memory_exporter.clear() + AwsLambdaInstrumentor().uninstrument() + test_env_patch.stop() def test_lambda_no_error_with_invalid_flush_timeout(self):