Skip to content

Commit

Permalink
fix(futures): fix incorrect context propgation with ThreadPoolExecuto…
Browse files Browse the repository at this point in the history
…r [backport 2.9] (#9603)

Backport 109ba08 from #9588 to 2.9.

There is a bug when scheduling work onto a `ThreadPoolExecutor` and not
waiting for the response (e.g. `pool.submit(work)`, and ignoring the
future) we not properly associate the spans created in the task with the
trace that was active when submitting the task.

The reason for this bug is because we propagate the currently active
span (parent) to the child task, however, if the parent span finishes
before the child task can create it's first span, we no longer consider
the parent span active/available to inherit from. This is because our
context management code does not work if passing spans between thread or
process boundaries.

The solution is to instead pass the active span's Context to the child
task. This is a similar process as passing context between two
services/processes via HTTP headers (for example). This change will
allow the child task's spans to be properly associated with the parent
span regardless of the execution order.


This issue can be highlighted by the following example:

```python

pool = ThreadPoolExecutor(max_workers=1)

def task():
    parent_span = tracer.current_span()
    assert parent_span is not None
    time.sleep(1)


with tracer.trace("parent"):
    for _ in range(10):
        pool.submit(task)
```

The first execution of `task` will (probably) succeed without any issues
because the parent span is likely still active at that time. However,
when each additional task executes the assertion will fail because the
parent span is no longer an active span so `tracer.current_span()` will
return `None`.

This example shows that only the first execution of `task` will be
properly associated with the parent span/trace, the other calls to
`task` will be disconnected traces.

This fix will resolve this inconsistent and unexpected behavior to
ensure that the spans created in `task` will always be properly
associated with the parent span/trace.

This change may impact people who were expecting to access the current
span in the child task, but before creating any spans in the child task
(the code sample above), as the span will no longer be available via
`tracer.current_span()`.


## Checklist

- [x] Change(s) are motivated and described in the PR description
- [x] Testing strategy is described if automated tests are not included
in the PR
- [x] Risks are described (performance impact, potential for breakage,
maintainability)
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] [Library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
are followed or label `changelog/no-changelog` is set
- [x] Documentation is included (in-code, generated user docs, [public
corp docs](https://github.com/DataDog/documentation/))
- [x] Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))
- [x] If this PR changes the public interface, I've notified
`@DataDog/apm-tees`.

## Reviewer Checklist

- [x] Title is accurate
- [x] All changes are related to the pull request's stated goal
- [x] Description motivates each change
- [x] Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- [x] Testing strategy adequately addresses listed risks
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] Release note makes sense to a user of the library
- [x] Author has acknowledged and discussed the performance implications
of this PR as reported in the benchmarks PR comment
- [x] Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)

Co-authored-by: Brett Langdon <brett.langdon@datadoghq.com>
  • Loading branch information
github-actions[bot] and brettlangdon authored Jun 21, 2024
1 parent 8c934bb commit 944c395
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 79 deletions.
20 changes: 6 additions & 14 deletions ddtrace/contrib/futures/threading.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Optional

import ddtrace
from ddtrace._trace.context import Context


def _wrap_submit(func, args, kwargs):
Expand All @@ -7,19 +10,8 @@ def _wrap_submit(func, args, kwargs):
thread. This wrapper ensures that a new `Context` is created and
properly propagated using an intermediate function.
"""
# If there isn't a currently active context, then do not create one
# DEV: Calling `.active()` when there isn't an active context will create a new context
# DEV: We need to do this in case they are either:
# - Starting nested futures
# - Starting futures from outside of an existing context
#
# In either of these cases we essentially will propagate the wrong context between futures
#
# The resolution is to not create/propagate a new context if one does not exist, but let the
# future's thread create the context instead.
current_ctx = None
if ddtrace.tracer.context_provider._has_active_context():
current_ctx = ddtrace.tracer.context_provider.active()
# DEV: Be sure to propagate a Context and not a Span since we are crossing thread boundaries
current_ctx: Optional[Context] = ddtrace.tracer.current_trace_context()

# The target function can be provided as a kwarg argument "fn" or the first positional argument
self = args[0]
Expand All @@ -31,7 +23,7 @@ def _wrap_submit(func, args, kwargs):
return func(self, _wrap_execution, current_ctx, fn, fn_args, kwargs)


def _wrap_execution(ctx, fn, args, kwargs):
def _wrap_execution(ctx: Optional[Context], fn, args, kwargs):
"""
Intermediate target function that is executed in a new thread;
it receives the original function with arguments and keyword
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
futures: Fixes inconsistent behavior with ``concurrent.futures.ThreadPoolExecutor`` context propagation by passing the current trace context instead of the currently active span to tasks. This prevents edge cases of disconnected spans when the task executes after the parent span has finished.
238 changes: 173 additions & 65 deletions tests/contrib/futures/test_propagation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
import concurrent
import concurrent.futures
import time

import pytest

from ddtrace.contrib.futures import patch
from ddtrace.contrib.futures import unpatch
from tests.opentracer.utils import init_tracer
from tests.utils import DummyTracer
from tests.utils import TracerTestCase


@pytest.fixture(autouse=True)
def patch_futures():
patch()
try:
yield
finally:
unpatch()


class PropagationTestCase(TracerTestCase):
"""Ensures the Context Propagation works between threads
when the ``futures`` library is used, or when the
Expand Down Expand Up @@ -43,10 +53,15 @@ def fn():
self.assertEqual(result, 42)

# the trace must be completed
self.assert_structure(
dict(name="main.thread"),
(dict(name="executor.thread"),),
)
roots = self.get_root_spans()
assert len(roots) == 1
root = roots[0]
assert root.name == "main.thread"
spans = root.get_spans()
assert len(spans) == 1
assert spans[0].name == "executor.thread"
assert spans[0].trace_id == root.trace_id
assert spans[0].parent_id == root.span_id

def test_propagation_with_params(self):
# instrumentation must proxy arguments if available
Expand All @@ -65,10 +80,15 @@ def fn(value, key=None):
self.assertEqual(key, "CheeseShop")

# the trace must be completed
self.assert_structure(
dict(name="main.thread"),
(dict(name="executor.thread"),),
)
roots = self.get_root_spans()
assert len(roots) == 1
root = roots[0]
assert root.name == "main.thread"
spans = root.get_spans()
assert len(spans) == 1
assert spans[0].name == "executor.thread"
assert spans[0].trace_id == root.trace_id
assert spans[0].parent_id == root.span_id

def test_propagation_with_kwargs(self):
# instrumentation must work if only kwargs are provided
Expand All @@ -87,10 +107,15 @@ def fn(value, key=None):
self.assertEqual(key, "CheeseShop")

# the trace must be completed
self.assert_structure(
dict(name="main.thread"),
(dict(name="executor.thread"),),
)
roots = self.get_root_spans()
assert len(roots) == 1
root = roots[0]
assert root.name == "main.thread"
spans = root.get_spans()
assert len(spans) == 1
assert spans[0].name == "executor.thread"
assert spans[0].trace_id == root.trace_id
assert spans[0].parent_id == root.span_id

def test_disabled_instrumentation(self):
# it must not propagate if the module is disabled
Expand All @@ -116,8 +141,10 @@ def fn():
traces = self.get_root_spans()
self.assertEqual(len(traces), 2)

traces[0].assert_structure(dict(name="main.thread"))
traces[1].assert_structure(dict(name="executor.thread"))
assert traces[0].name == "main.thread"
assert traces[1].name == "executor.thread"
assert traces[1].trace_id != traces[0].trace_id
assert traces[1].parent_id is None

def test_double_instrumentation(self):
# double instrumentation must not happen
Expand All @@ -136,10 +163,15 @@ def fn():
self.assertEqual(result, 42)

# the trace must be completed
self.assert_structure(
dict(name="main.thread"),
(dict(name="executor.thread"),),
)
root_spans = self.get_root_spans()
self.assertEqual(len(root_spans), 1)
root = root_spans[0]
assert root.name == "main.thread"
spans = root.get_spans()
assert len(spans) == 1
assert spans[0].name == "executor.thread"
assert spans[0].trace_id == root.trace_id
assert spans[0].parent_id == root.span_id

def test_no_parent_span(self):
def fn():
Expand All @@ -154,7 +186,10 @@ def fn():
self.assertEqual(result, 42)

# the trace must be completed
self.assert_structure(dict(name="executor.thread"))
spans = self.get_spans()
assert len(spans) == 1
assert spans[0].name == "executor.thread"
assert spans[0].parent_id is None

def test_multiple_futures(self):
def fn():
Expand All @@ -171,15 +206,17 @@ def fn():
self.assertEqual(result, 42)

# the trace must be completed
self.assert_structure(
dict(name="main.thread"),
(
dict(name="executor.thread"),
dict(name="executor.thread"),
dict(name="executor.thread"),
dict(name="executor.thread"),
),
)
roots = self.get_root_spans()
assert len(roots) == 1
root = roots[0]
assert root.name == "main.thread"

spans = root.get_spans()
assert len(spans) == 4
for span in spans:
assert span.name == "executor.thread"
assert span.trace_id == root.trace_id
assert span.parent_id == root.span_id

def test_multiple_futures_no_parent(self):
def fn():
Expand All @@ -196,10 +233,11 @@ def fn():

# the trace must be completed
self.assert_span_count(4)
traces = self.get_root_spans()
self.assertEqual(len(traces), 4)
for trace in traces:
trace.assert_structure(dict(name="executor.thread"))
root_spans = self.get_root_spans()
self.assertEqual(len(root_spans), 4)
for root in root_spans:
assert root.name == "executor.thread"
assert root.parent_id is None

def test_nested_futures(self):
def fn2():
Expand All @@ -224,15 +262,14 @@ def fn():

# the trace must be completed
self.assert_span_count(3)
self.assert_structure(
dict(name="main.thread"),
(
(
dict(name="executor.thread"),
(dict(name="nested.thread"),),
),
),
)
spans = self.get_spans()
assert spans[0].name == "main.thread"
assert spans[1].name == "executor.thread"
assert spans[1].trace_id == spans[0].trace_id
assert spans[1].parent_id == spans[0].span_id
assert spans[2].name == "nested.thread"
assert spans[2].trace_id == spans[0].trace_id
assert spans[2].parent_id == spans[1].span_id

def test_multiple_nested_futures(self):
def fn2():
Expand All @@ -258,16 +295,25 @@ def fn():
self.assertEqual(result, 42)

# the trace must be completed
self.assert_structure(
dict(name="main.thread"),
(
(
dict(name="executor.thread"),
(dict(name="nested.thread"),) * 4,
),
)
* 4,
)
traces = self.get_root_spans()
self.assertEqual(len(traces), 1)

for root in traces:
assert root.name == "main.thread"

exec_spans = root.get_spans()
assert len(exec_spans) == 4
for exec_span in exec_spans:
assert exec_span.name == "executor.thread"
assert exec_span.trace_id == root.trace_id
assert exec_span.parent_id == root.span_id

spans = exec_span.get_spans()
assert len(spans) == 4
for i in range(4):
assert spans[i].name == "nested.thread"
assert spans[i].trace_id == exec_span.trace_id
assert spans[i].parent_id == exec_span.span_id

def test_multiple_nested_futures_no_parent(self):
def fn2():
Expand Down Expand Up @@ -295,11 +341,15 @@ def fn():
traces = self.get_root_spans()
self.assertEqual(len(traces), 4)

for trace in traces:
trace.assert_structure(
dict(name="executor.thread"),
(dict(name="nested.thread"),) * 4,
)
for root in traces:
assert root.name == "executor.thread"

spans = root.get_spans()
assert len(spans) == 4
for i in range(4):
assert spans[i].name == "nested.thread"
assert spans[i].trace_id == root.trace_id
assert spans[i].parent_id == root.span_id

def test_send_trace_when_finished(self):
# it must send the trace only when all threads are finished
Expand All @@ -322,10 +372,11 @@ def fn():
self.assertEqual(result, 42)

self.assert_span_count(2)
self.assert_structure(
dict(name="main.thread"),
(dict(name="executor.thread"),),
)
spans = self.get_spans()
assert spans[0].name == "main.thread"
assert spans[1].name == "executor.thread"
assert spans[1].trace_id == spans[0].trace_id
assert spans[1].parent_id == spans[0].span_id

def test_propagation_ot(self):
"""OpenTracing version of test_propagation."""
Expand All @@ -347,10 +398,12 @@ def fn():
self.assertEqual(result, 42)

# the trace must be completed
self.assert_structure(
dict(name="main.thread"),
(dict(name="executor.thread"),),
)
self.assert_span_count(2)
spans = self.get_spans()
assert spans[0].name == "main.thread"
assert spans[1].name == "executor.thread"
assert spans[1].trace_id == spans[0].trace_id
assert spans[1].parent_id == spans[0].span_id


@pytest.mark.subprocess(ddtrace_run=True, timeout=5)
Expand Down Expand Up @@ -384,3 +437,58 @@ def test_concurrent_futures_with_gevent():
assert result == 42
sys.exit(0)
os.waitpid(pid, 0)


def test_submit_no_wait(tracer: DummyTracer):
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

futures = []

def work():
# This is our expected scenario
assert tracer.current_trace_context() is not None
assert tracer.current_span() is None

# DEV: This is the regression case that was raising
# tracer.current_span().set_tag("work", "done")

with tracer.trace("work"):
pass

def task():
with tracer.trace("task"):
for _ in range(4):
futures.append(executor.submit(work))

with tracer.trace("main"):
task()

# Make sure all background tasks are done
executor.shutdown(wait=True)

# Make sure no exceptions were raised in the tasks
for future in futures:
assert future.done()
assert future.exception() is None
assert future.result() is None

traces = tracer.pop_traces()
assert len(traces) == 4

assert len(traces[0]) == 3
root_span, task_span, work_span = traces[0]
assert root_span.name == "main"

assert task_span.name == "task"
assert task_span.parent_id == root_span.span_id
assert task_span.trace_id == root_span.trace_id

assert work_span.name == "work"
assert work_span.parent_id == task_span.span_id
assert work_span.trace_id == root_span.trace_id

for work_spans in traces[2:]:
(work_span,) = work_spans
assert work_span.name == "work"
assert work_span.parent_id == task_span.span_id
assert work_span.trace_id == root_span.trace_id

0 comments on commit 944c395

Please sign in to comment.