diff --git a/ddtrace/contrib/futures/threading.py b/ddtrace/contrib/futures/threading.py index ab67e215555..deea68e2c17 100644 --- a/ddtrace/contrib/futures/threading.py +++ b/ddtrace/contrib/futures/threading.py @@ -1,4 +1,7 @@ +from typing import Optional + import ddtrace +from ddtrace._trace.context import Context def _wrap_submit(func, args, kwargs): @@ -7,19 +10,8 @@ def _wrap_submit(func, args, kwargs): thread. This wrapper ensures that a new `Context` is created and properly propagated using an intermediate function. """ - # If there isn't a currently active context, then do not create one - # DEV: Calling `.active()` when there isn't an active context will create a new context - # DEV: We need to do this in case they are either: - # - Starting nested futures - # - Starting futures from outside of an existing context - # - # In either of these cases we essentially will propagate the wrong context between futures - # - # The resolution is to not create/propagate a new context if one does not exist, but let the - # future's thread create the context instead. 
- current_ctx = None - if ddtrace.tracer.context_provider._has_active_context(): - current_ctx = ddtrace.tracer.context_provider.active() + # DEV: Be sure to propagate a Context and not a Span since we are crossing thread boundaries + current_ctx: Optional[Context] = ddtrace.tracer.current_trace_context() # The target function can be provided as a kwarg argument "fn" or the first positional argument self = args[0] @@ -31,7 +23,7 @@ def _wrap_submit(func, args, kwargs): return func(self, _wrap_execution, current_ctx, fn, fn_args, kwargs) -def _wrap_execution(ctx, fn, args, kwargs): +def _wrap_execution(ctx: Optional[Context], fn, args, kwargs): """ Intermediate target function that is executed in a new thread; it receives the original function with arguments and keyword diff --git a/releasenotes/notes/fix-futures-propagation-f5b33579e0fdafc3.yaml b/releasenotes/notes/fix-futures-propagation-f5b33579e0fdafc3.yaml new file mode 100644 index 00000000000..b53ffd752de --- /dev/null +++ b/releasenotes/notes/fix-futures-propagation-f5b33579e0fdafc3.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + futures: Fixes inconsistent behavior with ``concurrent.futures.ThreadPoolExecutor`` context propagation by passing the current trace context instead of the currently active span to tasks. This prevents edge cases of disconnected spans when the task executes after the parent span has finished. 
diff --git a/tests/contrib/futures/test_propagation.py b/tests/contrib/futures/test_propagation.py index fbbebcd95fe..d4d5beb8946 100644 --- a/tests/contrib/futures/test_propagation.py +++ b/tests/contrib/futures/test_propagation.py @@ -1,4 +1,4 @@ -import concurrent +import concurrent.futures import time import pytest @@ -6,9 +6,19 @@ from ddtrace.contrib.futures import patch from ddtrace.contrib.futures import unpatch from tests.opentracer.utils import init_tracer +from tests.utils import DummyTracer from tests.utils import TracerTestCase +@pytest.fixture(autouse=True) +def patch_futures(): + patch() + try: + yield + finally: + unpatch() + + class PropagationTestCase(TracerTestCase): """Ensures the Context Propagation works between threads when the ``futures`` library is used, or when the @@ -43,10 +53,15 @@ def fn(): self.assertEqual(result, 42) # the trace must be completed - self.assert_structure( - dict(name="main.thread"), - (dict(name="executor.thread"),), - ) + roots = self.get_root_spans() + assert len(roots) == 1 + root = roots[0] + assert root.name == "main.thread" + spans = root.get_spans() + assert len(spans) == 1 + assert spans[0].name == "executor.thread" + assert spans[0].trace_id == root.trace_id + assert spans[0].parent_id == root.span_id def test_propagation_with_params(self): # instrumentation must proxy arguments if available @@ -65,10 +80,15 @@ def fn(value, key=None): self.assertEqual(key, "CheeseShop") # the trace must be completed - self.assert_structure( - dict(name="main.thread"), - (dict(name="executor.thread"),), - ) + roots = self.get_root_spans() + assert len(roots) == 1 + root = roots[0] + assert root.name == "main.thread" + spans = root.get_spans() + assert len(spans) == 1 + assert spans[0].name == "executor.thread" + assert spans[0].trace_id == root.trace_id + assert spans[0].parent_id == root.span_id def test_propagation_with_kwargs(self): # instrumentation must work if only kwargs are provided @@ -87,10 +107,15 @@ def fn(value, 
key=None): self.assertEqual(key, "CheeseShop") # the trace must be completed - self.assert_structure( - dict(name="main.thread"), - (dict(name="executor.thread"),), - ) + roots = self.get_root_spans() + assert len(roots) == 1 + root = roots[0] + assert root.name == "main.thread" + spans = root.get_spans() + assert len(spans) == 1 + assert spans[0].name == "executor.thread" + assert spans[0].trace_id == root.trace_id + assert spans[0].parent_id == root.span_id def test_disabled_instrumentation(self): # it must not propagate if the module is disabled @@ -116,8 +141,10 @@ def fn(): traces = self.get_root_spans() self.assertEqual(len(traces), 2) - traces[0].assert_structure(dict(name="main.thread")) - traces[1].assert_structure(dict(name="executor.thread")) + assert traces[0].name == "main.thread" + assert traces[1].name == "executor.thread" + assert traces[1].trace_id != traces[0].trace_id + assert traces[1].parent_id is None def test_double_instrumentation(self): # double instrumentation must not happen @@ -136,10 +163,15 @@ def fn(): self.assertEqual(result, 42) # the trace must be completed - self.assert_structure( - dict(name="main.thread"), - (dict(name="executor.thread"),), - ) + root_spans = self.get_root_spans() + self.assertEqual(len(root_spans), 1) + root = root_spans[0] + assert root.name == "main.thread" + spans = root.get_spans() + assert len(spans) == 1 + assert spans[0].name == "executor.thread" + assert spans[0].trace_id == root.trace_id + assert spans[0].parent_id == root.span_id def test_no_parent_span(self): def fn(): @@ -154,7 +186,10 @@ def fn(): self.assertEqual(result, 42) # the trace must be completed - self.assert_structure(dict(name="executor.thread")) + spans = self.get_spans() + assert len(spans) == 1 + assert spans[0].name == "executor.thread" + assert spans[0].parent_id is None def test_multiple_futures(self): def fn(): @@ -171,15 +206,17 @@ def fn(): self.assertEqual(result, 42) # the trace must be completed - self.assert_structure( - 
dict(name="main.thread"), - ( - dict(name="executor.thread"), - dict(name="executor.thread"), - dict(name="executor.thread"), - dict(name="executor.thread"), - ), - ) + roots = self.get_root_spans() + assert len(roots) == 1 + root = roots[0] + assert root.name == "main.thread" + + spans = root.get_spans() + assert len(spans) == 4 + for span in spans: + assert span.name == "executor.thread" + assert span.trace_id == root.trace_id + assert span.parent_id == root.span_id def test_multiple_futures_no_parent(self): def fn(): @@ -196,10 +233,11 @@ def fn(): # the trace must be completed self.assert_span_count(4) - traces = self.get_root_spans() - self.assertEqual(len(traces), 4) - for trace in traces: - trace.assert_structure(dict(name="executor.thread")) + root_spans = self.get_root_spans() + self.assertEqual(len(root_spans), 4) + for root in root_spans: + assert root.name == "executor.thread" + assert root.parent_id is None def test_nested_futures(self): def fn2(): @@ -224,15 +262,14 @@ def fn(): # the trace must be completed self.assert_span_count(3) - self.assert_structure( - dict(name="main.thread"), - ( - ( - dict(name="executor.thread"), - (dict(name="nested.thread"),), - ), - ), - ) + spans = self.get_spans() + assert spans[0].name == "main.thread" + assert spans[1].name == "executor.thread" + assert spans[1].trace_id == spans[0].trace_id + assert spans[1].parent_id == spans[0].span_id + assert spans[2].name == "nested.thread" + assert spans[2].trace_id == spans[0].trace_id + assert spans[2].parent_id == spans[1].span_id def test_multiple_nested_futures(self): def fn2(): @@ -258,16 +295,25 @@ def fn(): self.assertEqual(result, 42) # the trace must be completed - self.assert_structure( - dict(name="main.thread"), - ( - ( - dict(name="executor.thread"), - (dict(name="nested.thread"),) * 4, - ), - ) - * 4, - ) + traces = self.get_root_spans() + self.assertEqual(len(traces), 1) + + for root in traces: + assert root.name == "main.thread" + + exec_spans = 
root.get_spans() + assert len(exec_spans) == 4 + for exec_span in exec_spans: + assert exec_span.name == "executor.thread" + assert exec_span.trace_id == root.trace_id + assert exec_span.parent_id == root.span_id + + spans = exec_span.get_spans() + assert len(spans) == 4 + for i in range(4): + assert spans[i].name == "nested.thread" + assert spans[i].trace_id == exec_span.trace_id + assert spans[i].parent_id == exec_span.span_id def test_multiple_nested_futures_no_parent(self): def fn2(): @@ -295,11 +341,15 @@ def fn(): traces = self.get_root_spans() self.assertEqual(len(traces), 4) - for trace in traces: - trace.assert_structure( - dict(name="executor.thread"), - (dict(name="nested.thread"),) * 4, - ) + for root in traces: + assert root.name == "executor.thread" + + spans = root.get_spans() + assert len(spans) == 4 + for i in range(4): + assert spans[i].name == "nested.thread" + assert spans[i].trace_id == root.trace_id + assert spans[i].parent_id == root.span_id def test_send_trace_when_finished(self): # it must send the trace only when all threads are finished @@ -322,10 +372,11 @@ def fn(): self.assertEqual(result, 42) self.assert_span_count(2) - self.assert_structure( - dict(name="main.thread"), - (dict(name="executor.thread"),), - ) + spans = self.get_spans() + assert spans[0].name == "main.thread" + assert spans[1].name == "executor.thread" + assert spans[1].trace_id == spans[0].trace_id + assert spans[1].parent_id == spans[0].span_id def test_propagation_ot(self): """OpenTracing version of test_propagation.""" @@ -347,10 +398,12 @@ def fn(): self.assertEqual(result, 42) # the trace must be completed - self.assert_structure( - dict(name="main.thread"), - (dict(name="executor.thread"),), - ) + self.assert_span_count(2) + spans = self.get_spans() + assert spans[0].name == "main.thread" + assert spans[1].name == "executor.thread" + assert spans[1].trace_id == spans[0].trace_id + assert spans[1].parent_id == spans[0].span_id 
@pytest.mark.subprocess(ddtrace_run=True, timeout=5) @@ -384,3 +437,58 @@ def test_concurrent_futures_with_gevent(): assert result == 42 sys.exit(0) os.waitpid(pid, 0) + + +def test_submit_no_wait(tracer: DummyTracer): + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + futures = [] + + def work(): + # This is our expected scenario + assert tracer.current_trace_context() is not None + assert tracer.current_span() is None + + # DEV: This is the regression case that was raising + # tracer.current_span().set_tag("work", "done") + + with tracer.trace("work"): + pass + + def task(): + with tracer.trace("task"): + for _ in range(4): + futures.append(executor.submit(work)) + + with tracer.trace("main"): + task() + + # Make sure all background tasks are done + executor.shutdown(wait=True) + + # Make sure no exceptions were raised in the tasks + for future in futures: + assert future.done() + assert future.exception() is None + assert future.result() is None + + traces = tracer.pop_traces() + assert len(traces) == 4 + + assert len(traces[0]) == 3 + root_span, task_span, work_span = traces[0] + assert root_span.name == "main" + + assert task_span.name == "task" + assert task_span.parent_id == root_span.span_id + assert task_span.trace_id == root_span.trace_id + + assert work_span.name == "work" + assert work_span.parent_id == task_span.span_id + assert work_span.trace_id == root_span.trace_id + + for work_spans in traces[1:]: + (work_span,) = work_spans + assert work_span.name == "work" + assert work_span.parent_id == task_span.span_id + assert work_span.trace_id == root_span.trace_id