diff --git a/CHANGELOG.md b/CHANGELOG.md index f359c6b633f..9b01102db18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#3524](https://github.com/open-telemetry/opentelemetry-python/pull/3524)) - Handle `taskName` `logrecord` attribute ([#3557](https://github.com/open-telemetry/opentelemetry-python/pull/3557)) +- Add `span_id` to `Sampler.should_sample` + ([#3574](https://github.com/open-telemetry/opentelemetry-python/pull/3574)) ## Version 1.21.0/0.42b0 (2023-11-01) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py index 6dae70b2f6b..3758a6cc1fd 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py @@ -1074,6 +1074,8 @@ def start_span( # pylint: disable=too-many-locals trace_id = self.id_generator.generate_trace_id() else: trace_id = parent_span_context.trace_id + + span_id = self.id_generator.generate_span_id() # The sampler decides whether to create a real or no-op span at the # time of span creation. No-op spans do not record events, and are not @@ -1082,7 +1084,7 @@ def start_span( # pylint: disable=too-many-locals # to include information about the sampling result. # The sampler may also modify the parent span context's tracestate sampling_result = self.sampler.should_sample( - context, trace_id, name, kind, attributes, links + context, trace_id, name, kind, attributes, links, span_id=span_id, ) trace_flags = ( @@ -1092,7 +1094,7 @@ def start_span( # pylint: disable=too-many-locals ) span_context = trace_api.SpanContext( trace_id, - self.id_generator.generate_span_id(), + span_id, is_remote=False, trace_flags=trace_flags, trace_state=sampling_result.trace_state, diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py index 0236fac6b62..d7dd1b3cd58 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py @@ -205,6 +205,7 @@ def should_sample( attributes: Attributes = None, links: Sequence["Link"] = None, trace_state: "TraceState" = None, + span_id: Optional[int] = None, ) -> "SamplingResult": pass @@ -228,6 +229,7 @@ def should_sample( attributes: Attributes = None, links: Sequence["Link"] = None, trace_state: "TraceState" = None, + span_id: Optional[int] = None, ) -> "SamplingResult": if self._decision is Decision.DROP: attributes = None @@ -289,6 +291,7 @@ def should_sample( attributes: Attributes = None, links: Sequence["Link"] = None, trace_state: "TraceState" = None, + span_id: Optional[int] = None, ) -> "SamplingResult": decision = Decision.DROP if trace_id & self.TRACE_ID_LIMIT < self.bound: @@ -344,6 +347,7 @@ def should_sample( attributes: Attributes = None, links: Sequence["Link"] = None, trace_state: "TraceState" = None, + span_id: Optional[int] = None, ) -> "SamplingResult": parent_span_context = get_current_span( parent_context diff --git a/opentelemetry-sdk/tests/trace/test_sampling.py b/opentelemetry-sdk/tests/trace/test_sampling.py index e976b0f551e..e5e43c71270 100644 --- a/opentelemetry-sdk/tests/trace/test_sampling.py +++ b/opentelemetry-sdk/tests/trace/test_sampling.py @@ -20,6 +20,7 @@ from opentelemetry import context as context_api from opentelemetry import trace from opentelemetry.sdk.trace import sampling +from opentelemetry.util.types import Attributes TO_DEFAULT = trace.TraceFlags(trace.TraceFlags.DEFAULT) TO_SAMPLED = trace.TraceFlags(trace.TraceFlags.SAMPLED) @@ -537,3 +538,98 @@ def implicit_parent_context(span: trace.Span): context_api.detach(token) self.exec_parent_based(implicit_parent_context) + + def test_sample_using_span_id(self): + + max_span_id = (2 ** 32) - 1 # span ids are 32 bit integers + + class AttributeBasedSampler(sampling.Sampler): + def __init__(self) -> None: + super().__init__() + + def should_sample( + self, + parent_context: typing.Optional[context_api.Context], + trace_id: int, + name: str, + kind: trace.SpanKind = None, + attributes: Attributes = None, + links: typing.Sequence[trace.Link] = None, + trace_state: trace.TraceState = None, + span_id: typing.Optional[int] = None, + ) -> sampling.SamplingResult: + sample_rate = typing.cast( + typing.Optional[float], (attributes or {}).get("sample_rate") + ) + if sample_rate is not None: + assert span_id is not None # or maybe always sample? + bound = round(sample_rate * (max_span_id + 1)) + if span_id & max_span_id < bound: + decision = sampling.Decision.RECORD_AND_SAMPLE + else: + decision = sampling.Decision.DROP + attributes = None + return sampling.SamplingResult( + decision, + attributes, + ) + return sampling.SamplingResult( + sampling.Decision.RECORD_AND_SAMPLE, + attributes, + ) + + def get_description(self) -> str: + return "AttributeBasedSampler" + + # sample rate of 0 is never sampled + trace_state = trace.TraceState([]) + context = self._create_parent(TO_SAMPLED, False, trace_state) + sample_result = AttributeBasedSampler().should_sample( + context, + 1, + "child", + trace.SpanKind.INTERNAL, + attributes={"sample_rate": 0, "foo": "bar"}, + span_id=1, + ) + self.assertFalse(sample_result.decision.is_sampled()) + self.assertEqual(sample_result.attributes, {}) + + # sample rate of 1 is always sampled + trace_state = trace.TraceState([]) + context = self._create_parent(TO_SAMPLED, False, trace_state) + sample_result = AttributeBasedSampler().should_sample( + context, + 1, + "child", + trace.SpanKind.INTERNAL, + attributes={"sample_rate": 1, "foo": "bar"}, + span_id=max_span_id - 1, + ) + self.assertTrue(sample_result.decision.is_sampled()) + self.assertEqual(sample_result.attributes, {"sample_rate": 1, "foo": "bar"}) + + # sample rate of 0.5 is sampled only if the span id is less than 2^64 * 0.5 + trace_state = trace.TraceState([]) + context = self._create_parent(TO_SAMPLED, False, trace_state) + sample_result = AttributeBasedSampler().should_sample( + context, + 1, + "child", + trace.SpanKind.INTERNAL, + attributes={"sample_rate": 0.5, "foo": "bar"}, + span_id=int(max_span_id * 0.4), + ) + self.assertTrue(sample_result.decision.is_sampled()) + self.assertEqual(sample_result.attributes, {"sample_rate": 0.5, "foo": "bar"}) + + sample_result = AttributeBasedSampler().should_sample( + context, + 1, + "child", + trace.SpanKind.INTERNAL, + attributes={"sample_rate": 0.5, "foo": "bar"}, + span_id=int(max_span_id * 0.6), + ) + self.assertFalse(sample_result.decision.is_sampled()) + self.assertEqual(sample_result.attributes, {})