Skip to content

Commit

Permalink
fix(opentelemetry-instrumentation-celery): attach incoming context on… (
Browse files Browse the repository at this point in the history
open-telemetry#2385)

* fix(opentelemetry-instrumentation-celery): attach incoming context on _trace_prerun

* docs(CHANGELOG): add entry for fix open-telemetry#2385

* fix(opentelemetry-instrumentation-celery): detach context after task is run

* test(opentelemetry-instrumentation-celery): add context utils tests

* fix(opentelemetry-instrumentation-celery): remove duplicated signal registration

* refactor(opentelemetry-instrumentation-celery): fix lint issues

* refactor(opentelemetry-instrumentation-celery): fix types and tests for python 3.8

* refactor(opentelemetry-instrumentation-celery): fix lint issues

* refactor(opentelemetry-instrumentation-celery): fix lint issues

* fix(opentelemetry-instrumentation-celery): attach context only if it is not None

* refactor(opentelemetry-instrumentation-celery): fix lint issues
  • Loading branch information
malcolmrebughini authored and xrmx committed Jan 24, 2025
1 parent 898b630 commit f478267
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 49 deletions.
8 changes: 4 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#2756](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2756))
- `opentelemetry-instrumentation-aws-lambda` Fixing w3c baggage support
([#2589](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2589))
- `opentelemetry-instrumentation-celery` propagates baggage
([#2385](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2385))

## Version 1.26.0/0.47b0 (2024-07-23)

Expand Down Expand Up @@ -119,10 +121,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#2610](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2610))
- `opentelemetry-instrumentation-asgi` Bugfix: Middleware did not set status code attribute on duration metrics for non-recording spans.
([#2627](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2627))
<<<<<<< HEAD
- `opentelemetry-instrumentation-mysql` Add support for `mysql-connector-python` v9 ([#2751](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2751))
=======
>>>>>>> 5a623233 (Changelog update)
- `opentelemetry-instrumentation-mysql` Add support for `mysql-connector-python` v9
([#2751](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2751))

## Version 1.25.0/0.46b0 (2024-05-31)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def add(x, y):
from billiard.einfo import ExceptionInfo
from celery import signals # pylint: disable=no-name-in-module

from opentelemetry import context as context_api
from opentelemetry import trace
from opentelemetry.instrumentation.celery import utils
from opentelemetry.instrumentation.celery.package import _instruments
Expand Down Expand Up @@ -169,6 +170,7 @@ def _trace_prerun(self, *args, **kwargs):
self.update_task_duration_time(task_id)
request = task.request
tracectx = extract(request, getter=celery_getter) or None
token = context_api.attach(tracectx) if tracectx is not None else None

logger.debug("prerun signal start task_id=%s", task_id)

Expand All @@ -179,7 +181,7 @@ def _trace_prerun(self, *args, **kwargs):

activation = trace.use_span(span, end_on_exit=True)
activation.__enter__() # pylint: disable=E1101
utils.attach_span(task, task_id, (span, activation))
utils.attach_context(task, task_id, span, activation, token)

def _trace_postrun(self, *args, **kwargs):
task = utils.retrieve_task(kwargs)
Expand All @@ -191,11 +193,14 @@ def _trace_postrun(self, *args, **kwargs):
logger.debug("postrun signal task_id=%s", task_id)

# retrieve and finish the Span
span, activation = utils.retrieve_span(task, task_id)
if span is None:
ctx = utils.retrieve_context(task, task_id)

if ctx is None:
logger.warning("no existing span found for task_id=%s", task_id)
return

span, activation, token = ctx

# request context tags
if span.is_recording():
span.set_attribute(_TASK_TAG_KEY, _TASK_RUN)
Expand All @@ -204,10 +209,11 @@ def _trace_postrun(self, *args, **kwargs):
span.set_attribute(_TASK_NAME_KEY, task.name)

activation.__exit__(None, None, None)
utils.detach_span(task, task_id)
utils.detach_context(task, task_id)
self.update_task_duration_time(task_id)
labels = {"task": task.name, "worker": task.request.hostname}
self._record_histograms(task_id, labels)
context_api.detach(token)

def _trace_before_publish(self, *args, **kwargs):
task = utils.retrieve_task_from_sender(kwargs)
Expand Down Expand Up @@ -238,7 +244,9 @@ def _trace_before_publish(self, *args, **kwargs):
activation = trace.use_span(span, end_on_exit=True)
activation.__enter__() # pylint: disable=E1101

utils.attach_span(task, task_id, (span, activation), is_publish=True)
utils.attach_context(
task, task_id, span, activation, None, is_publish=True
)

headers = kwargs.get("headers")
if headers:
Expand All @@ -253,13 +261,16 @@ def _trace_after_publish(*args, **kwargs):
return

# retrieve and finish the Span
_, activation = utils.retrieve_span(task, task_id, is_publish=True)
if activation is None:
ctx = utils.retrieve_context(task, task_id, is_publish=True)

if ctx is None:
logger.warning("no existing span found for task_id=%s", task_id)
return

_, activation, _ = ctx

activation.__exit__(None, None, None) # pylint: disable=E1101
utils.detach_span(task, task_id, is_publish=True)
utils.detach_context(task, task_id, is_publish=True)

@staticmethod
def _trace_failure(*args, **kwargs):
Expand All @@ -269,9 +280,14 @@ def _trace_failure(*args, **kwargs):
if task is None or task_id is None:
return

# retrieve and pass exception info to activation
span, _ = utils.retrieve_span(task, task_id)
if span is None or not span.is_recording():
ctx = utils.retrieve_context(task, task_id)

if ctx is None:
return

span, _, _ = ctx

if not span.is_recording():
return

status_kwargs = {"status_code": StatusCode.ERROR}
Expand Down Expand Up @@ -311,8 +327,14 @@ def _trace_retry(*args, **kwargs):
if task is None or task_id is None or reason is None:
return

span, _ = utils.retrieve_span(task, task_id)
if span is None or not span.is_recording():
ctx = utils.retrieve_context(task, task_id)

if ctx is None:
return

span, _, _ = ctx

if not span.is_recording():
return

# Add retry reason metadata to span
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
# limitations under the License.

import logging
from typing import ContextManager, Optional, Tuple

from celery import registry # pylint: disable=no-name-in-module
from celery.app.task import Task

from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -81,10 +84,12 @@ def set_attributes_from_context(span, context):
elif key == "delivery_info":
# Get also destination from this
routing_key = value.get("routing_key")

if routing_key is not None:
span.set_attribute(
SpanAttributes.MESSAGING_DESTINATION, routing_key
)

value = str(value)

elif key == "id":
Expand Down Expand Up @@ -114,11 +119,18 @@ def set_attributes_from_context(span, context):
span.set_attribute(attribute_name, value)


def attach_span(task, task_id, span, is_publish=False):
"""Helper to propagate a `Span` for the given `Task` instance. This
function uses a `dict` that stores the Span using the
`(task_id, is_publish)` as a key. This is useful when information must be
propagated from one Celery signal to another.
def attach_context(
task: Optional[Task],
task_id: str,
span: Span,
activation: ContextManager[Span],
token: Optional[object],
is_publish: bool = False,
) -> None:
"""Helper to propagate a `Span`, `ContextManager` and context token
for the given `Task` instance. This function uses a `dict` that stores
the `(span, activation, token)` tuple using `(task_id, is_publish)` as
the key. This is useful when information must be propagated from one
Celery signal to another.
We use (task_id, is_publish) for the key to ensure that publishing a
task from within another task does not cause any conflicts.
Expand All @@ -134,36 +146,41 @@ def attach_span(task, task_id, span, is_publish=False):
"""
if task is None:
return
span_dict = getattr(task, CTX_KEY, None)
if span_dict is None:
span_dict = {}
setattr(task, CTX_KEY, span_dict)

span_dict[(task_id, is_publish)] = span
ctx_dict = getattr(task, CTX_KEY, None)

if ctx_dict is None:
ctx_dict = {}
setattr(task, CTX_KEY, ctx_dict)

ctx_dict[(task_id, is_publish)] = (span, activation, token)


def detach_span(task, task_id, is_publish=False):
"""Helper to remove a `Span` in a Celery task when it's propagated.
This function handles tasks where the `Span` is not attached.
def detach_context(task, task_id, is_publish=False) -> None:
"""Helper to remove `Span`, `ContextManager` and context token in a
Celery task when it's propagated.
This function handles tasks where no values are attached to the `Task`.
"""
span_dict = getattr(task, CTX_KEY, None)
if span_dict is None:
return

# See note in `attach_span` for key info
span_dict.pop((task_id, is_publish), (None, None))
# See note in `attach_context` for key info
span_dict.pop((task_id, is_publish), None)


def retrieve_span(task, task_id, is_publish=False):
"""Helper to retrieve an active `Span` stored in a `Task`
instance
def retrieve_context(
task, task_id, is_publish=False
) -> Optional[Tuple[Span, ContextManager[Span], Optional[object]]]:
"""Helper to retrieve an active `Span`, `ContextManager` and context token
stored in a `Task` instance
"""
span_dict = getattr(task, CTX_KEY, None)
if span_dict is None:
return (None, None)
return None

# See note in `attach_span` for key info
return span_dict.get((task_id, is_publish), (None, None))
# See note in `attach_context` for key info
return span_dict.get((task_id, is_publish), None)


def retrieve_task(kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from celery import Celery

from opentelemetry import baggage


class Config:
result_backend = "rpc"
Expand All @@ -36,3 +38,8 @@ def task_add(num_a, num_b):
@app.task
def task_raises():
raise CustomError("The task failed!")


@app.task
def task_returns_baggage():
    """Return the entries reported by `baggage.get_all()` as a plain dict."""
    entries = baggage.get_all()
    return dict(entries)
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import threading
import time

from opentelemetry import baggage, context
from opentelemetry.instrumentation.celery import CeleryInstrumentor
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind, StatusCode

from .celery_test_tasks import app, task_add, task_raises
from .celery_test_tasks import app, task_add, task_raises, task_returns_baggage


class TestCeleryInstrumentation(TestBase):
Expand Down Expand Up @@ -168,6 +169,22 @@ def test_uninstrument(self):
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)

def test_baggage(self):
    """Baggage attached before publishing must be returned by the task."""
    CeleryInstrumentor().instrument()

    ctx = baggage.set_baggage("key", "value")
    # Keep the token: without a matching detach the baggage context would
    # remain attached to the current context after this test finishes.
    token = context.attach(ctx)

    try:
        task = task_returns_baggage.delay()

        timeout = time.time() + 60  # 1 minute from now
        while not task.ready():
            if time.time() > timeout:
                break
            time.sleep(0.05)

        self.assertEqual(task.result, {"key": "value"})
    finally:
        # Always restore the previous context, even if the assertion fails.
        context.detach(token)


class TestCelerySignatureTask(TestBase):
def setUp(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,10 @@ def fn_task():
# propagate and retrieve a Span
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
utils.attach_span(fn_task, task_id, span)
span_after = utils.retrieve_span(fn_task, task_id)
utils.attach_context(fn_task, task_id, span, mock.Mock(), "")
ctx = utils.retrieve_context(fn_task, task_id)
self.assertIsNotNone(ctx)
span_after, _, _ = ctx
self.assertIs(span, span_after)

def test_span_delete(self):
Expand All @@ -180,17 +182,19 @@ def fn_task():
# propagate a Span
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))
utils.attach_span(fn_task, task_id, span)
utils.attach_context(fn_task, task_id, span, mock.Mock(), "")
# delete the Span
utils.detach_span(fn_task, task_id)
self.assertEqual(utils.retrieve_span(fn_task, task_id), (None, None))
utils.detach_context(fn_task, task_id)
self.assertEqual(utils.retrieve_context(fn_task, task_id), None)

def test_optional_task_span_attach(self):
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
span = trace._Span("name", mock.Mock(spec=trace_api.SpanContext))

# assert this is a no-op
self.assertIsNone(utils.attach_span(None, task_id, span))
self.assertIsNone(
utils.attach_context(None, task_id, span, mock.Mock(), "")
)

def test_span_delete_empty(self):
# ensure detach_context doesn't raise an exception if no context is present
Expand All @@ -201,10 +205,8 @@ def fn_task():
# delete the Span
task_id = "7c6731af-9533-40c3-83a9-25b58f0d837f"
try:
utils.detach_span(fn_task, task_id)
self.assertEqual(
utils.retrieve_span(fn_task, task_id), (None, None)
)
utils.detach_context(fn_task, task_id)
self.assertEqual(utils.retrieve_context(fn_task, task_id), None)
except Exception as ex: # pylint: disable=broad-except
self.fail(f"Exception was raised: {ex}")

Expand Down

0 comments on commit f478267

Please sign in to comment.