diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py index c1f71d4d87e..901a5772f83 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py @@ -13,9 +13,11 @@ # limitations under the License. import typing +from re import compile as re_compile import opentelemetry.trace as trace from opentelemetry.context import Context +from opentelemetry.sdk.trace import generate_span_id, generate_trace_id from opentelemetry.trace.propagation.httptextformat import ( Getter, HTTPTextFormat, @@ -37,6 +39,8 @@ class B3Format(HTTPTextFormat): SAMPLED_KEY = "x-b3-sampled" FLAGS_KEY = "x-b3-flags" _SAMPLE_PROPAGATE_VALUES = set(["1", "True", "true", "d"]) + _trace_id_regex = re_compile(r"[\da-fA-F]{16}|[\da-fA-F]{32}") + _span_id_regex = re_compile(r"[\da-fA-F]{16}") def extract( self, @@ -95,6 +99,18 @@ def extract( or flags ) + if ( + self._trace_id_regex.fullmatch(trace_id) is None + or self._span_id_regex.fullmatch(span_id) is None + ): + trace_id = generate_trace_id() + span_id = generate_span_id() + sampled = "0" + + else: + trace_id = int(trace_id, 16) + span_id = int(span_id, 16) + options = 0 # The b3 spec provides no defined behavior for both sample and # flag values set. Since the setting of at least one implies @@ -102,12 +118,13 @@ def extract( # header is set to allow. if sampled in self._SAMPLE_PROPAGATE_VALUES or flags == "1": options |= trace.TraceFlags.SAMPLED + return trace.set_span_in_context( trace.DefaultSpan( trace.SpanContext( # trace an span ids are encoded in hex, so must be converted - trace_id=int(trace_id, 16), - span_id=int(span_id, 16), + trace_id=trace_id, + span_id=span_id, is_remote=True, trace_flags=trace.TraceFlags(options), trace_state=trace.TraceState(), diff --git a/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py b/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py index a5bd1baaa48..bc508f3fd91 100644 --- a/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py +++ b/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +from unittest.mock import patch import opentelemetry.sdk.trace as trace import opentelemetry.sdk.trace.propagation.b3_format as b3_format @@ -245,6 +246,50 @@ def test_missing_trace_id(self): span_context = trace_api.get_current_span(ctx).get_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) + @patch("opentelemetry.sdk.trace.propagation.b3_format.generate_trace_id") + @patch("opentelemetry.sdk.trace.propagation.b3_format.generate_span_id") + def test_invalid_trace_id( + self, mock_generate_span_id, mock_generate_trace_id + ): + """If a trace id is invalid, generate a trace id.""" + + mock_generate_trace_id.configure_mock(return_value=1) + mock_generate_span_id.configure_mock(return_value=2) + + carrier = { + FORMAT.TRACE_ID_KEY: "abc123", + FORMAT.SPAN_ID_KEY: self.serialized_span_id, + FORMAT.FLAGS_KEY: "1", + } + + ctx = FORMAT.extract(get_as_list, carrier) + span_context = trace_api.get_current_span(ctx).get_context() + + self.assertEqual(span_context.trace_id, 1) + self.assertEqual(span_context.span_id, 2) + + @patch("opentelemetry.sdk.trace.propagation.b3_format.generate_trace_id") + @patch("opentelemetry.sdk.trace.propagation.b3_format.generate_span_id") + def test_invalid_span_id( + self, mock_generate_span_id, mock_generate_trace_id + ): + """If a span id is invalid, generate a trace id.""" + + mock_generate_trace_id.configure_mock(return_value=1) + mock_generate_span_id.configure_mock(return_value=2) + + carrier = { + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, + FORMAT.SPAN_ID_KEY: "abc123", + FORMAT.FLAGS_KEY: "1", + } + + ctx = FORMAT.extract(get_as_list, carrier) + span_context = trace_api.get_current_span(ctx).get_context() + + self.assertEqual(span_context.trace_id, 1) + self.assertEqual(span_context.span_id, 2) + def test_missing_span_id(self): """If a trace id is missing, populate an invalid trace id.""" carrier = {