diff --git a/taskiq/kicker.py b/taskiq/kicker.py index bdc62dae..e6b93a83 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -69,7 +69,10 @@ def with_labels( self.labels.update(labels) return self - def with_task_id(self, task_id: str) -> "AsyncKicker[_FuncParams, _ReturnType]": + def with_task_id( + self, + task_id: Optional[str], + ) -> "AsyncKicker[_FuncParams, _ReturnType]": """ Set task_id for current execution. @@ -208,6 +211,7 @@ async def schedule_by_cron( labels=message.labels, args=message.args, kwargs=message.kwargs, + task_id=self.custom_task_id, cron=cron_str, cron_offset=cron_offset, ) @@ -239,6 +243,7 @@ async def schedule_by_time( labels=message.labels, args=message.args, kwargs=message.kwargs, + task_id=self.custom_task_id, time=time, ) await source.add_schedule(scheduled) diff --git a/taskiq/scheduler/scheduled_task/v1.py b/taskiq/scheduler/scheduled_task/v1.py index 5209f61e..96b5c1e1 100644 --- a/taskiq/scheduler/scheduled_task/v1.py +++ b/taskiq/scheduler/scheduled_task/v1.py @@ -12,6 +12,7 @@ class ScheduledTask(BaseModel): labels: Dict[str, Any] args: List[Any] kwargs: Dict[str, Any] + task_id: Optional[str] = None schedule_id: str = Field(default_factory=lambda: uuid.uuid4().hex) cron: Optional[str] = None cron_offset: Optional[Union[str, timedelta]] = None diff --git a/taskiq/scheduler/scheduled_task/v2.py b/taskiq/scheduler/scheduled_task/v2.py index 332dce5d..ce28c123 100644 --- a/taskiq/scheduler/scheduled_task/v2.py +++ b/taskiq/scheduler/scheduled_task/v2.py @@ -13,6 +13,7 @@ class ScheduledTask(BaseModel): labels: Dict[str, Any] args: List[Any] kwargs: Dict[str, Any] + task_id: Optional[str] = None schedule_id: str = Field(default_factory=lambda: uuid.uuid4().hex) cron: Optional[str] = None cron_offset: Optional[Union[str, timedelta]] = None diff --git a/taskiq/scheduler/scheduler.py b/taskiq/scheduler/scheduler.py index b2484243..7ad842cd 100644 --- a/taskiq/scheduler/scheduler.py +++ b/taskiq/scheduler/scheduler.py @@ -51,6 +51,7 @@ async def on_ready(self, source: "ScheduleSource", task: ScheduledTask) -> None: .with_labels( schedule_id=task.schedule_id, ) + .with_task_id(task_id=task.task_id) .kiq( *task.args, **task.kwargs, diff --git a/tests/test_retry_task.py b/tests/test_retry_task.py new file mode 100644 index 00000000..95eaf4dd --- /dev/null +++ b/tests/test_retry_task.py @@ -0,0 +1,43 @@ +import pytest + +from taskiq import ( + Context, + InMemoryBroker, + SmartRetryMiddleware, + TaskiqDepends, + TaskiqScheduler, +) +from taskiq.schedule_sources import LabelScheduleSource + + +@pytest.mark.parametrize( + "retry_count", + range(5), +) +@pytest.mark.anyio +async def test_save_task_id_for_retry(retry_count: int) -> None: + broker = InMemoryBroker().with_middlewares( + SmartRetryMiddleware( + default_retry_count=retry_count + 1, + default_delay=0.1, + ), + ) + scheduler = TaskiqScheduler(broker, [LabelScheduleSource(broker)]) + + check_interval = 0.5 + + @broker.task("exc_task", retry_on_error=True) + async def exc_task(count: int = 0, context: "Context" = TaskiqDepends()) -> int: + retry = int(context.message.labels.get("_retries", 0)) + if retry < count: + raise Exception("test") + return retry + + await broker.startup() + await scheduler.startup() + + task_with_retry = await exc_task.kiq(retry_count) + task_with_retry_result = await task_with_retry.wait_result( + check_interval=check_interval, + ) + assert task_with_retry_result.return_value == retry_count