Skip to content

Commit

Permalink
Use correct signatures for Celery Task Hooks (#791)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Swatinem authored Oct 21, 2024
1 parent ba012ed commit e7432b4
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 105 deletions.
21 changes: 6 additions & 15 deletions helpers/telemetry.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -68,39 +68,30 @@ def __init__(
sentry_sdk.set_tag("owner_id", owner_id)
sentry_sdk.set_tag("repo_id", repo_id)
sentry_sdk.set_tag("commit_sha", commit_sha)
transaction = sentry_sdk.get_current_scope().transaction
if transaction is not None:
transaction.set_tag("owner_id", owner_id)
transaction.set_tag("repo_id", repo_id)
transaction.set_tag("commit_sha", commit_sha)

def populate(self):
if self.populated:
return

repo = None
commit = None
dbsession = get_db_session()

if self.repo_id:
if not self.owner_id:
repo = (
dbsession.query(Repository)
self.owner_id = (
dbsession.query(Repository.ownerid)
.filter(Repository.repoid == self.repo_id)
.first()
.first()[0]
)
self.owner_id = repo.ownerid

if self.commit_sha and not self.commit_id:
commit = (
dbsession.query(Commit)
self.commit_id = (
dbsession.query(Commit.id_)
.filter(
Commit.repoid == self.repo_id,
Commit.commitid == self.commit_sha,
)
.first()
.first()[0]
)
self.commit_id = commit.id_

self.populated = True

Expand Down
84 changes: 32 additions & 52 deletions tasks/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,7 +24,6 @@
log_set_task_id,
log_set_task_name,
)
from helpers.metrics import metrics
from helpers.telemetry import MetricContext, TimeseriesTimer
from helpers.timeseries import timeseries_enabled

Expand All @@ -51,9 +50,7 @@ def on_timeout(self, soft: bool, timeout: int):
res = super().on_timeout(soft, timeout)
if not soft:
REQUEST_HARD_TIMEOUT_COUNTER.labels(task=self.name).inc()
metrics.incr(f"{self.metrics_prefix}.hardtimeout")
REQUEST_TIMEOUT_COUNTER.labels(task=self.name).inc()
metrics.incr(f"{self.metrics_prefix}.timeout")
return res


Expand Down Expand Up @@ -245,20 +242,13 @@ def _emit_queue_metrics(self):
enqueued_time = datetime.fromisoformat(created_timestamp)
now = datetime.now()
delta = now - enqueued_time
metrics.timing(f"{self.metrics_prefix}.time_in_queue", delta)

queue_name = self.request.get("delivery_info", {}).get("routing_key", None)
time_in_queue_timer = TASK_TIME_IN_QUEUE.labels(
task=self.name, queue=queue_name
) # TODO is None a valid label value
time_in_queue_timer.observe(delta.total_seconds())

if queue_name:
metrics.timing(f"worker.queues.{queue_name}.time_in_queue", delta)
metrics.timing(
f"{self.metrics_prefix}.{queue_name}.time_in_queue", delta
)

def run(self, *args, **kwargs):
task = get_current_task()

Expand All @@ -279,39 +269,32 @@ def run(self, *args, **kwargs):
owner_id=kwargs.get("ownerid"),
)

with TimeseriesTimer(
metric_context, f"{self.metrics_prefix}.full_runtime", sync=True
):
with self.task_full_runtime.time(): # Timer isn't tested
with metrics.timer(f"{self.metrics_prefix}.full"):
db_session = get_db_session()
try:
with TimeseriesTimer(
metric_context,
f"{self.metrics_prefix}.core_runtime",
sync=True,
):
with self.task_core_runtime.time(): # Timer isn't tested
with metrics.timer(f"{self.metrics_prefix}.run"):
return self.run_impl(db_session, *args, **kwargs)
except (DataError, IntegrityError):
log.exception(
"Errors related to the constraints of database happened",
extra=dict(task_args=args, task_kwargs=kwargs),
)
db_session.rollback()
self._rollback_django()
self.retry()
except SQLAlchemyError as ex:
self._analyse_error(ex, args, kwargs)
db_session.rollback()
self._rollback_django()
self.retry()
finally:
log_set_task_name(None)
log_set_task_id(None)
self.wrap_up_dbsession(db_session)
self._commit_django()
with self.task_full_runtime.time(): # Timer isn't tested
db_session = get_db_session()
try:
with TimeseriesTimer(
metric_context, f"{self.metrics_prefix}.core_runtime", sync=True
):
with self.task_core_runtime.time(): # Timer isn't tested
return self.run_impl(db_session, *args, **kwargs)
except (DataError, IntegrityError):
log.exception(
"Errors related to the constraints of database happened",
extra=dict(task_args=args, task_kwargs=kwargs),
)
db_session.rollback()
self._rollback_django()
self.retry()
except SQLAlchemyError as ex:
self._analyse_error(ex, args, kwargs)
db_session.rollback()
self._rollback_django()
self.retry()
finally:
log_set_task_name(None)
log_set_task_id(None)
self.wrap_up_dbsession(db_session)
self._commit_django()

def wrap_up_dbsession(self, db_session):
"""
Expand Down Expand Up @@ -352,10 +335,9 @@ def wrap_up_dbsession(self, db_session):
)
get_db_session.remove()

def on_retry(self, *args, **kwargs):
res = super().on_retry(*args, **kwargs)
def on_retry(self, exc, task_id, args, kwargs, einfo):
res = super().on_retry(exc, task_id, args, kwargs, einfo)
self.task_retry_counter.inc()
metrics.incr(f"{self.metrics_prefix}.retries")
metric_context = MetricContext(
commit_sha=kwargs.get("commitid"),
repo_id=kwargs.get("repoid"),
Expand All @@ -364,10 +346,9 @@ def on_retry(self, *args, **kwargs):
metric_context.log_simple_metric(f"{self.metrics_prefix}.retry", 1.0)
return res

def on_success(self, *args, **kwargs):
res = super().on_success(*args, **kwargs)
def on_success(self, retval, task_id, args, kwargs):
res = super().on_success(retval, task_id, args, kwargs)
self.task_success_counter.inc()
metrics.incr(f"{self.metrics_prefix}.successes")
metric_context = MetricContext(
commit_sha=kwargs.get("commitid"),
repo_id=kwargs.get("repoid"),
Expand All @@ -376,13 +357,12 @@ def on_success(self, *args, **kwargs):
metric_context.log_simple_metric(f"{self.metrics_prefix}.success", 1.0)
return res

def on_failure(self, *args, **kwargs):
def on_failure(self, exc, task_id, args, kwargs, einfo):
"""
Includes SoftTimeoutLimitException, for example
"""
res = super().on_failure(*args, **kwargs)
res = super().on_failure(exc, task_id, args, kwargs, einfo)
self.task_failure_counter.inc()
metrics.incr(f"{self.metrics_prefix}.failures")
metric_context = MetricContext(
commit_sha=kwargs.get("commitid"),
repo_id=kwargs.get("repoid"),
Expand Down
44 changes: 6 additions & 38 deletions tasks/tests/unit/test_base.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime, timedelta
from datetime import datetime
from pathlib import Path
from unittest.mock import patch

Expand Down Expand Up @@ -104,7 +104,6 @@ def test_hard_time_limit_task_from_default_app(self, mocker):
@patch("helpers.telemetry.MetricContext.log_simple_metric")
def test_sample_run(self, mock_simple_metric, mocker, dbsession):
mocked_get_db_session = mocker.patch("tasks.base.get_db_session")
mocked_metrics = mocker.patch("tasks.base.metrics")
mock_task_request = mocker.patch("tasks.base.BaseCodecovTask.request")
fake_request_values = dict(
created_timestamp="2023-06-13 10:00:00.000000",
Expand All @@ -117,23 +116,6 @@ def test_sample_run(self, mock_simple_metric, mocker, dbsession):
task_instance = SampleTask()
result = task_instance.run()
assert result == {"unusual": "return", "value": ["There"]}
assert mocked_metrics.timing.call_count == 3
mocked_metrics.timing.assert_has_calls(
[
call(
"worker.task.test.SampleTask.time_in_queue",
timedelta(seconds=61, microseconds=123),
),
call(
"worker.queues.my-queue.time_in_queue",
timedelta(seconds=61, microseconds=123),
),
call(
"worker.task.test.SampleTask.my-queue.time_in_queue",
timedelta(seconds=61, microseconds=123),
),
]
)
assert (
REGISTRY.get_sample_value(
"worker_tasks_timers_time_in_queue_seconds_sum",
Expand All @@ -142,10 +124,7 @@ def test_sample_run(self, mock_simple_metric, mocker, dbsession):
== 61.000123
)
mock_simple_metric.assert_has_calls(
[
call("worker.task.test.SampleTask.core_runtime", ANY),
call("worker.task.test.SampleTask.full_runtime", ANY),
]
[call("worker.task.test.SampleTask.core_runtime", ANY)]
)

@patch("tasks.base.BaseCodecovTask._emit_queue_metrics")
Expand Down Expand Up @@ -329,12 +308,11 @@ def test_run_sqlalchemy_error_rollback(self, mocker, dbsession, celery_app):

@pytest.mark.django_db(databases={"default", "timeseries"})
class TestBaseCodecovTaskHooks(object):
def test_sample_task_success(self, celery_app, mocker):
def test_sample_task_success(self, celery_app):
class SampleTask(BaseCodecovTask, name="test.SampleTask"):
def run_impl(self, dbsession):
return {"unusual": "return", "value": ["There"]}

mock_metrics = mocker.patch("tasks.base.metrics.incr")
DTask = celery_app.register_task(SampleTask())
task = celery_app.tasks[DTask.name]

Expand All @@ -354,16 +332,14 @@ def run_impl(self, dbsession):

res = k.get()
assert res == {"unusual": "return", "value": ["There"]}
mock_metrics.assert_called_with("worker.task.test.SampleTask.successes")
assert prom_run_counter_after - prom_run_counter_before == 1
assert prom_success_counter_after - prom_success_counter_before == 1

def test_sample_task_failure(self, celery_app, mocker):
def test_sample_task_failure(self, celery_app):
class FailureSampleTask(BaseCodecovTask, name="test.FailureSampleTask"):
def run_impl(self, *args, **kwargs):
raise Exception("Whhhhyyyyyyy")

mock_metrics = mocker.patch("tasks.base.metrics.incr")
DTask = celery_app.register_task(FailureSampleTask())
task = celery_app.tasks[DTask.name]
with pytest.raises(Exception) as exc:
Expand All @@ -383,24 +359,21 @@ def run_impl(self, *args, **kwargs):
assert prom_run_counter_after - prom_run_counter_before == 1
assert prom_failure_counter_after - prom_failure_counter_before == 1
assert exc.value.args == ("Whhhhyyyyyyy",)
mock_metrics.assert_called_with("worker.task.test.FailureSampleTask.failures")

def test_sample_task_retry(self, celery_app, mocker):
def test_sample_task_retry(self):
# Unfortunately we can't really call the task with apply().get().
# Something happens inside Celery as of version 4.3 that makes it
# not call on_retry at all.
# The best we can do is to call on_retry ourselves and ensure that this
# increments the retry metric.
mock_metrics = mocker.patch("tasks.base.metrics.incr")
task = RetrySampleTask()
prom_retry_counter_before = REGISTRY.get_sample_value(
"worker_task_counts_retries_total", labels={"task": task.name}
)
task.on_retry("exc", "task_id", "args", "kwargs", "einfo")
task.on_retry("exc", "task_id", ("args",), {"kwargs": "foo"}, "einfo")
prom_retry_counter_after = REGISTRY.get_sample_value(
"worker_task_counts_retries_total", labels={"task": task.name}
)
mock_metrics.assert_called_with("worker.task.test.RetrySampleTask.retries")
assert prom_retry_counter_after - prom_retry_counter_before == 1


Expand Down Expand Up @@ -435,7 +408,6 @@ def test_sample_task_timeout(self, celery_app, mocker):
class SampleTask(BaseCodecovTask, name="test.SampleTask"):
pass

mock_metrics = mocker.patch("tasks.base.metrics.incr")
DTask = celery_app.register_task(SampleTask())
request = self.xRequest(mocker, DTask.name, celery_app)
prom_timeout_counter_before = (
Expand All @@ -448,14 +420,12 @@ class SampleTask(BaseCodecovTask, name="test.SampleTask"):
prom_timeout_counter_after = REGISTRY.get_sample_value(
"worker_task_counts_timeouts_total", labels={"task": DTask.name}
)
mock_metrics.assert_called_with("worker.task.test.SampleTask.timeout")
assert prom_timeout_counter_after - prom_timeout_counter_before == 1

def test_sample_task_hard_timeout(self, celery_app, mocker):
class SampleTask(BaseCodecovTask, name="test.SampleTask"):
pass

mock_metrics = mocker.patch("tasks.base.metrics.incr")
DTask = celery_app.register_task(SampleTask())
request = self.xRequest(mocker, DTask.name, celery_app)
prom_timeout_counter_before = (
Expand All @@ -477,8 +447,6 @@ class SampleTask(BaseCodecovTask, name="test.SampleTask"):
prom_hard_timeout_counter_after = REGISTRY.get_sample_value(
"worker_task_counts_hard_timeouts_total", labels={"task": DTask.name}
)
mock_metrics.assert_any_call("worker.task.test.SampleTask.hardtimeout")
mock_metrics.assert_any_call("worker.task.test.SampleTask.timeout")
assert prom_timeout_counter_after - prom_timeout_counter_before == 1
assert prom_hard_timeout_counter_after - prom_hard_timeout_counter_before == 1

Expand Down

0 comments on commit e7432b4

Please sign in to comment.