diff --git a/pyproject.toml b/pyproject.toml index 48e33ff7..15b2b406 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,6 +159,7 @@ ignore = [ "ANN401", # typing.Any are disallowed in `**kwargs "PLR0913", # Too many arguments for function call "D106", # Missing docstring in public nested class + "D205", # 1 blank line required between summary line and description ] exclude = [".venv/"] mccabe = { max-complexity = 10 } diff --git a/taskiq/__init__.py b/taskiq/__init__.py index 42779b34..280257ee 100644 --- a/taskiq/__init__.py +++ b/taskiq/__init__.py @@ -1,4 +1,5 @@ """Distributed task manager.""" + from importlib.metadata import version from taskiq_dependencies import Depends as TaskiqDepends @@ -8,7 +9,7 @@ from taskiq.abc.middleware import TaskiqMiddleware from taskiq.abc.result_backend import AsyncResultBackend from taskiq.abc.schedule_source import ScheduleSource -from taskiq.acks import AckableMessage +from taskiq.acks import AckableMessage, AckableMessageWithDeliveryCount from taskiq.brokers.inmemory_broker import InMemoryBroker from taskiq.brokers.shared_broker import async_shared_broker from taskiq.brokers.zmq_broker import ZeroMQBroker @@ -24,7 +25,7 @@ TaskiqResultTimeoutError, ) from taskiq.funcs import gather -from taskiq.message import BrokerMessage, TaskiqMessage +from taskiq.message import BrokerMessage, DeliveryCountMessage, TaskiqMessage from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware from taskiq.result import TaskiqResult @@ -53,6 +54,8 @@ "NoResultError", "SendTaskError", "AckableMessage", + "DeliveryCountMessage", + "AckableMessageWithDeliveryCount", "InMemoryBroker", "ScheduleSource", "TaskiqScheduler", diff --git a/taskiq/abc/__init__.py b/taskiq/abc/__init__.py index c15fe965..76b5f07b 100644 --- a/taskiq/abc/__init__.py +++ b/taskiq/abc/__init__.py @@ -1,4 +1,5 @@ """Abstract classes for taskiq.""" + from taskiq.abc.broker import AsyncBroker from taskiq.abc.result_backend import AsyncResultBackend diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index ed05ee46..fd682694 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -77,6 +77,7 @@ def __init__( self, result_backend: "Optional[AsyncResultBackend[_T]]" = None, task_id_generator: Optional[Callable[[], str]] = None, + max_attempts_at_message: Optional[int] = None, ) -> None: if result_backend is None: result_backend = DummyResultBackend() @@ -113,6 +114,7 @@ def __init__( self.state = TaskiqState() self.custom_dependency_context: Dict[Any, Any] = {} self.dependency_overrides: Dict[Any, Any] = {} + self.max_attempts_at_message = max_attempts_at_message # True only if broker runs in worker process. self.is_worker_process: bool = False # True only if broker runs in scheduler process. diff --git a/taskiq/acks.py b/taskiq/acks.py index c3b3fe77..53058939 100644 --- a/taskiq/acks.py +++ b/taskiq/acks.py @@ -1,7 +1,7 @@ import enum from typing import Awaitable, Callable, Union -from pydantic import BaseModel +from taskiq.message import DeliveryCountMessage, WrappedMessage @enum.unique @@ -20,7 +20,7 @@ class AcknowledgeType(str, enum.Enum): WHEN_SAVED = "when_saved" -class AckableMessage(BaseModel): +class AckableMessage(WrappedMessage): """ Message that can be acknowledged. @@ -33,5 +33,8 @@ class AckableMessage(BaseModel): as a whole. """ - data: bytes ack: Callable[[], Union[None, Awaitable[None]]] + + +class AckableMessageWithDeliveryCount(AckableMessage, DeliveryCountMessage): + """Message that can be acknowledged and has a delivery count.""" diff --git a/taskiq/cli/worker/run.py b/taskiq/cli/worker/run.py index 727a02a6..89de14d5 100644 --- a/taskiq/cli/worker/run.py +++ b/taskiq/cli/worker/run.py @@ -143,6 +143,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None: ack_type=args.ack_type, max_tasks_to_execute=args.max_tasks_per_child, wait_tasks_timeout=args.wait_tasks_timeout, + max_attempts_at_message=broker.max_attempts_at_message, **receiver_kwargs, # type: ignore ) loop.run_until_complete(receiver.listen()) diff --git a/taskiq/message.py b/taskiq/message.py index 675f7cf3..ad055c68 100644 --- a/taskiq/message.py +++ b/taskiq/message.py @@ -42,3 +42,15 @@ class BrokerMessage(BaseModel): task_name: str message: bytes labels: Dict[str, Any] + + +class WrappedMessage(BaseModel): + """Abstraction for an incoming message in a wrapper.""" + + data: bytes + + +class DeliveryCountMessage(WrappedMessage): + """Message with a present delivery count.""" + + delivery_count: int | None = None diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 7d5a4035..9d0a10b8 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -13,7 +13,7 @@ from taskiq.acks import AcknowledgeType from taskiq.context import Context from taskiq.exceptions import NoResultError -from taskiq.message import TaskiqMessage +from taskiq.message import DeliveryCountMessage, TaskiqMessage, WrappedMessage from taskiq.receiver.params_parser import parse_params from taskiq.result import TaskiqResult from taskiq.state import TaskiqState @@ -58,6 +58,7 @@ def __init__( on_exit: Optional[Callable[["Receiver"], None]] = None, max_tasks_to_execute: Optional[int] = None, wait_tasks_timeout: Optional[float] = None, + max_attempts_at_message: Optional[int] = None, ) -> None: self.broker = broker self.executor = executor @@ -72,6 +73,7 @@ def __init__( self.known_tasks: Set[str] = set() self.max_tasks_to_execute = max_tasks_to_execute self.wait_tasks_timeout = wait_tasks_timeout + self.max_attempts_at_message = max_attempts_at_message for task in self.broker.get_all_tasks().values(): self._prepare_task(task.task_name, task.original_func) self.sem: "Optional[asyncio.Semaphore]" = None @@ -86,7 +88,7 @@ def __init__( async def callback( # noqa: C901, PLR0912 self, - message: Union[bytes, AckableMessage], + message: Union[bytes, WrappedMessage], raise_err: bool = False, ) -> None: """ @@ -101,7 +103,31 @@ async def callback( # noqa: C901, PLR0912 :param raise_err: raise an error if cannot save result in result_backend. """ - message_data = message.data if isinstance(message, AckableMessage) else message + message_data = message.data if isinstance(message, WrappedMessage) else message + + delivery_count = ( + message.delivery_count + if isinstance(message, DeliveryCountMessage) + else None + ) + if ( + delivery_count + and self.max_attempts_at_message + and delivery_count >= self.max_attempts_at_message + ): + logger.error( + "Permitted number of attempts at processing message %s " + "has been exhausted after %s attempts.", + message_data, + self.max_attempts_at_message, + ) + if isinstance( + message, + AckableMessage, + ): + await maybe_awaitable(message.ack()) + return + try: taskiq_msg = self.broker.formatter.loads(message=message_data) taskiq_msg.parse_labels() diff --git a/taskiq/schedule_sources/__init__.py b/taskiq/schedule_sources/__init__.py index 1ad5a6fd..7e2dfb16 100644 --- a/taskiq/schedule_sources/__init__.py +++ b/taskiq/schedule_sources/__init__.py @@ -1,4 +1,5 @@ """Package for schedule sources.""" + from taskiq.schedule_sources.label_based import LabelScheduleSource __all__ = [ diff --git a/taskiq/scheduler/created_schedule.py b/taskiq/scheduler/created_schedule.py index 8e870834..fd37e4b0 100644 --- a/taskiq/scheduler/created_schedule.py +++ b/taskiq/scheduler/created_schedule.py @@ -32,7 +32,9 @@ async def kiq( ... @overload - async def kiq(self: "CreatedSchedule[_ReturnType]") -> AsyncTaskiqTask[_ReturnType]: + async def kiq( + self: "CreatedSchedule[_ReturnType]", + ) -> AsyncTaskiqTask[_ReturnType]: ... async def kiq(self) -> Any: diff --git a/taskiq/serializers/__init__.py b/taskiq/serializers/__init__.py index 26cd430a..ba4ccd88 100644 --- a/taskiq/serializers/__init__.py +++ b/taskiq/serializers/__init__.py @@ -1,4 +1,5 @@ """Taskiq serializers.""" + from .cbor_serializer import CBORSerializer from .json_serializer import JSONSerializer from .msgpack_serializer import MSGPackSerializer diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index 6b79e325..cc4862f9 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -9,9 +9,10 @@ from taskiq.abc.broker import AckableMessage, AsyncBroker from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.acks import AckableMessageWithDeliveryCount from taskiq.brokers.inmemory_broker import InMemoryBroker from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError -from taskiq.message import TaskiqMessage +from taskiq.message import DeliveryCountMessage, TaskiqMessage from taskiq.receiver import Receiver from taskiq.result import TaskiqResult from tests.utils import AsyncQueueBroker @@ -359,6 +360,127 @@ async def test_callback_unknown_task() -> None: await receiver.callback(broker_message.message) +@pytest.mark.anyio +@pytest.mark.parametrize("delivery_count", [2, None]) +async def test_callback_max_attempts_at_message_not_exceeded( + delivery_count: Optional[int], +) -> None: + """ + Test that callback function calls the task if `max_attempts_at_message` + is not exceeded. + """ + broker = InMemoryBroker() + called_times = 0 + + @broker.task + async def my_task() -> int: + nonlocal called_times + called_times += 1 + return 1 + + receiver = get_receiver(broker) + receiver.max_attempts_at_message = 3 + + broker_message = broker.formatter.dumps( + TaskiqMessage( + task_id="task_id", + task_name=my_task.task_name, + labels={}, + args=[], + kwargs={}, + ), + ) + + await receiver.callback( + DeliveryCountMessage( + data=broker_message.message, + delivery_count=delivery_count, + ), + ) + assert called_times == 1 + + +@pytest.mark.anyio +async def test_callback_max_attempts_at_message_exceeded() -> None: + """ + Test that callback function does not call the task if `max_attempts_at_message` + is exceeded. + """ + broker = InMemoryBroker() + called_times = 0 + + @broker.task + async def my_task() -> int: + nonlocal called_times + called_times += 1 + return 1 + + receiver = get_receiver(broker) + receiver.max_attempts_at_message = 3 + + broker_message = broker.formatter.dumps( + TaskiqMessage( + task_id="task_id", + task_name=my_task.task_name, + labels={}, + args=[], + kwargs={}, + ), + ) + + await receiver.callback( + DeliveryCountMessage( + data=broker_message.message, + delivery_count=3, + ), + ) + assert called_times == 0 + + +@pytest.mark.anyio +async def test_callback_max_attempts_at_message_exceeded_ackable() -> None: + """ + Test that callback function does not call the task if `max_attempts_at_message` + is exceeded and acks the message. + """ + broker = InMemoryBroker() + called_times = 0 + acked = False + + @broker.task + async def my_task() -> int: + nonlocal called_times + called_times += 1 + return 1 + + async def ack_callback() -> None: + nonlocal acked + acked = True + + receiver = get_receiver(broker) + receiver.max_attempts_at_message = 3 + + broker_message = broker.formatter.dumps( + TaskiqMessage( + task_id="task_id", + task_name=my_task.task_name, + labels={}, + args=[], + kwargs={}, + ), + ) + + await receiver.callback( + AckableMessageWithDeliveryCount( + data=broker_message.message, + delivery_count=3, + ack=ack_callback, + ), + ) + assert called_times == 0 + assert acked + + @pytest.mark.anyio async def test_custom_ctx() -> None: """Tests that run_task can run sync tasks."""