Skip to content

Commit

Permalink
Add max_attempts_at_message
Browse files Browse the repository at this point in the history
  • Loading branch information
Arseniy-Popov committed Dec 29, 2024
1 parent 49c0408 commit 1a432e5
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 10 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ ignore = [
"ANN401", # typing.Any are disallowed in `**kwargs
"PLR0913", # Too many arguments for function call
"D106", # Missing docstring in public nested class
"D205", # 1 blank line required between summary line and description
]
exclude = [".venv/"]
mccabe = { max-complexity = 10 }
Expand Down
7 changes: 5 additions & 2 deletions taskiq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Distributed task manager."""

from importlib.metadata import version

from taskiq_dependencies import Depends as TaskiqDepends
Expand All @@ -8,7 +9,7 @@
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.abc.schedule_source import ScheduleSource
from taskiq.acks import AckableMessage
from taskiq.acks import AckableMessage, AckableMessageWithDeliveryCount
from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.brokers.shared_broker import async_shared_broker
from taskiq.brokers.zmq_broker import ZeroMQBroker
Expand All @@ -24,7 +25,7 @@
TaskiqResultTimeoutError,
)
from taskiq.funcs import gather
from taskiq.message import BrokerMessage, TaskiqMessage
from taskiq.message import BrokerMessage, DeliveryCountMessage, TaskiqMessage
from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware
from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware
from taskiq.result import TaskiqResult
Expand Down Expand Up @@ -53,6 +54,8 @@
"NoResultError",
"SendTaskError",
"AckableMessage",
"DeliveryCountMessage",
"AckableMessageWithDeliveryCount",
"InMemoryBroker",
"ScheduleSource",
"TaskiqScheduler",
Expand Down
1 change: 1 addition & 0 deletions taskiq/abc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Abstract classes for taskiq."""

from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend

Expand Down
2 changes: 2 additions & 0 deletions taskiq/abc/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
self,
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
task_id_generator: Optional[Callable[[], str]] = None,
max_attempts_at_message: Optional[int] = None,
) -> None:
if result_backend is None:
result_backend = DummyResultBackend()
Expand Down Expand Up @@ -113,6 +114,7 @@ def __init__(
self.state = TaskiqState()
self.custom_dependency_context: Dict[Any, Any] = {}
self.dependency_overrides: Dict[Any, Any] = {}
self.max_attempts_at_message = max_attempts_at_message
# True only if broker runs in worker process.
self.is_worker_process: bool = False
# True only if broker runs in scheduler process.
Expand Down
9 changes: 6 additions & 3 deletions taskiq/acks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
from typing import Awaitable, Callable, Union

from pydantic import BaseModel
from taskiq.message import DeliveryCountMessage, WrappedMessage


@enum.unique
Expand All @@ -20,7 +20,7 @@ class AcknowledgeType(str, enum.Enum):
WHEN_SAVED = "when_saved"


class AckableMessage(BaseModel):
class AckableMessage(WrappedMessage):
"""
Message that can be acknowledged.
Expand All @@ -33,5 +33,8 @@ class AckableMessage(BaseModel):
as a whole.
"""

data: bytes
ack: Callable[[], Union[None, Awaitable[None]]]


class AckableMessageWithDeliveryCount(AckableMessage, DeliveryCountMessage):
"""Message that can be acknowledged and has a delivery count."""
1 change: 1 addition & 0 deletions taskiq/cli/worker/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
ack_type=args.ack_type,
max_tasks_to_execute=args.max_tasks_per_child,
wait_tasks_timeout=args.wait_tasks_timeout,
max_attempts_at_message=broker.max_attempts_at_message,
**receiver_kwargs, # type: ignore
)
loop.run_until_complete(receiver.listen())
Expand Down
12 changes: 12 additions & 0 deletions taskiq/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,15 @@ class BrokerMessage(BaseModel):
task_name: str
message: bytes
labels: Dict[str, Any]


class WrappedMessage(BaseModel):
"""Abstraction for an incoming message in a wrapper."""

data: bytes


class DeliveryCountMessage(WrappedMessage):
"""Message with a present delivery count."""

delivery_count: int | None = None
32 changes: 29 additions & 3 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from taskiq.acks import AcknowledgeType
from taskiq.context import Context
from taskiq.exceptions import NoResultError
from taskiq.message import TaskiqMessage
from taskiq.message import DeliveryCountMessage, TaskiqMessage, WrappedMessage
from taskiq.receiver.params_parser import parse_params
from taskiq.result import TaskiqResult
from taskiq.state import TaskiqState
Expand Down Expand Up @@ -58,6 +58,7 @@ def __init__(
on_exit: Optional[Callable[["Receiver"], None]] = None,
max_tasks_to_execute: Optional[int] = None,
wait_tasks_timeout: Optional[float] = None,
max_attempts_at_message: Optional[int] = None,
) -> None:
self.broker = broker
self.executor = executor
Expand All @@ -72,6 +73,7 @@ def __init__(
self.known_tasks: Set[str] = set()
self.max_tasks_to_execute = max_tasks_to_execute
self.wait_tasks_timeout = wait_tasks_timeout
self.max_attempts_at_message = max_attempts_at_message
for task in self.broker.get_all_tasks().values():
self._prepare_task(task.task_name, task.original_func)
self.sem: "Optional[asyncio.Semaphore]" = None
Expand All @@ -86,7 +88,7 @@ def __init__(

async def callback( # noqa: C901, PLR0912
self,
message: Union[bytes, AckableMessage],
message: Union[bytes, WrappedMessage],
raise_err: bool = False,
) -> None:
"""
Expand All @@ -101,7 +103,31 @@ async def callback( # noqa: C901, PLR0912
:param raise_err: raise an error if cannot save result in
result_backend.
"""
message_data = message.data if isinstance(message, AckableMessage) else message
message_data = message.data if isinstance(message, WrappedMessage) else message

delivery_count = (
message.delivery_count
if isinstance(message, DeliveryCountMessage)
else None
)
if (
delivery_count
and self.max_attempts_at_message
and delivery_count >= self.max_attempts_at_message
):
logger.error(
"Permitted number of attempts at processing message %s "
"has been exhausted after %s attempts.",
message_data,
self.max_attempts_at_message,
)
if isinstance(
message,
AckableMessage,
):
await maybe_awaitable(message.ack())
return

try:
taskiq_msg = self.broker.formatter.loads(message=message_data)
taskiq_msg.parse_labels()
Expand Down
1 change: 1 addition & 0 deletions taskiq/schedule_sources/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Package for schedule sources."""

from taskiq.schedule_sources.label_based import LabelScheduleSource

__all__ = [
Expand Down
4 changes: 3 additions & 1 deletion taskiq/scheduler/created_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ async def kiq(
...

@overload
async def kiq(self: "CreatedSchedule[_ReturnType]") -> AsyncTaskiqTask[_ReturnType]:
async def kiq(
self: "CreatedSchedule[_ReturnType]",
) -> AsyncTaskiqTask[_ReturnType]:
...

async def kiq(self) -> Any:
Expand Down
1 change: 1 addition & 0 deletions taskiq/serializers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Taskiq serializers."""

from .cbor_serializer import CBORSerializer
from .json_serializer import JSONSerializer
from .msgpack_serializer import MSGPackSerializer
Expand Down
124 changes: 123 additions & 1 deletion tests/receiver/test_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

from taskiq.abc.broker import AckableMessage, AsyncBroker
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.acks import AckableMessageWithDeliveryCount
from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError
from taskiq.message import TaskiqMessage
from taskiq.message import DeliveryCountMessage, TaskiqMessage
from taskiq.receiver import Receiver
from taskiq.result import TaskiqResult
from tests.utils import AsyncQueueBroker
Expand Down Expand Up @@ -359,6 +360,127 @@ async def test_callback_unknown_task() -> None:
await receiver.callback(broker_message.message)


@pytest.mark.anyio
@pytest.mark.parametrize("delivery_count", [2, None])
async def test_callback_max_attempts_at_message_not_exceeded(
delivery_count: Optional[int],
) -> None:
"""
Test that callback function calls the task if `max_attempts_at_message`
is not exceeded.
"""
broker = InMemoryBroker()
called_times = 0

@broker.task
async def my_task() -> int:
nonlocal called_times
called_times += 1
return 1

receiver = get_receiver(broker)
receiver.max_attempts_at_message = 3

broker_message = broker.formatter.dumps(
TaskiqMessage(
task_id="task_id",
task_name=my_task.task_name,
labels={},
args=[],
kwargs={},
),
)

await receiver.callback(
DeliveryCountMessage(
data=broker_message.message,
delivery_count=delivery_count,
),
)
assert called_times == 1


@pytest.mark.anyio
async def test_callback_max_attempts_at_message_exceeded() -> None:
"""
Test that callback function does not call the task if `max_attempts_at_message`
is exceeded.
"""
broker = InMemoryBroker()
called_times = 0

@broker.task
async def my_task() -> int:
nonlocal called_times
called_times += 1
return 1

receiver = get_receiver(broker)
receiver.max_attempts_at_message = 3

broker_message = broker.formatter.dumps(
TaskiqMessage(
task_id="task_id",
task_name=my_task.task_name,
labels={},
args=[],
kwargs={},
),
)

await receiver.callback(
DeliveryCountMessage(
data=broker_message.message,
delivery_count=3,
),
)
assert called_times == 0


@pytest.mark.anyio
async def test_callback_max_attempts_at_message_exceeded_ackable() -> None:
"""
Test that callback function does not call the task if `max_attempts_at_message`
is exceeded and acks the message.
"""
broker = InMemoryBroker()
called_times = 0
acked = False

@broker.task
async def my_task() -> int:
nonlocal called_times
called_times += 1
return 1

async def ack_callback() -> None:
nonlocal acked
acked = True

receiver = get_receiver(broker)
receiver.max_attempts_at_message = 3

broker_message = broker.formatter.dumps(
TaskiqMessage(
task_id="task_id",
task_name=my_task.task_name,
labels={},
args=[],
kwargs={},
),
)

await receiver.callback(
AckableMessageWithDeliveryCount(
data=broker_message.message,
delivery_count=3,
ack=ack_callback,
),
)
assert called_times == 0
assert acked


@pytest.mark.anyio
async def test_custom_ctx() -> None:
"""Tests that run_task can run sync tasks."""
Expand Down

0 comments on commit 1a432e5

Please sign in to comment.