From a41f3fcff50c4f34e9e72b29da2abff6dd063443 Mon Sep 17 00:00:00 2001 From: Gram Date: Wed, 28 Dec 2022 09:36:41 +0100 Subject: [PATCH] Fix tests broken by moving Message to dataclass (#516) * Fix tests broken by moving Message to dataclass * Undo breaking change to Message.asdict --- dramatiq/composition.py | 9 ++++++++- dramatiq/message.py | 14 ++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/dramatiq/composition.py b/dramatiq/composition.py index 531aa184..c5c3be03 100644 --- a/dramatiq/composition.py +++ b/dramatiq/composition.py @@ -14,14 +14,19 @@ # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . +from __future__ import annotations import time +from typing import TYPE_CHECKING, Iterable from uuid import uuid4 from .broker import get_broker from .rate_limits import Barrier from .results import ResultMissing +if TYPE_CHECKING: + from .message import Message + class pipeline: """Chain actors together, passing the result of one actor to the @@ -34,9 +39,11 @@ class pipeline: broker(Broker): The broker to run the pipeline on. Defaults to the current global broker. """ + messages: list[Message] - def __init__(self, children, *, broker=None): + def __init__(self, children: Iterable[Message | pipeline], *, broker=None): self.broker = broker or get_broker() + messages: list[Message] self.messages = messages = [] for child in children: diff --git a/dramatiq/message.py b/dramatiq/message.py index a68401ba..1647d183 100644 --- a/dramatiq/message.py +++ b/dramatiq/message.py @@ -89,7 +89,12 @@ def __or__(self, other) -> pipeline: def asdict(self) -> Dict[str, Any]: """Convert this message to a dictionary. """ - return dataclasses.asdict(self) + # For backward compatibility, we can't use `dataclasses.asdict` + # because it creates a copy of all values, including `options`. + result = {} + for field in dataclasses.fields(self): + result[field.name] = getattr(self, field.name) + return result @classmethod def decode(cls, data: bytes) -> "Message": @@ -100,7 +105,9 @@ def decode(cls, data: bytes) -> "Message": decoding `data`. """ try: - return cls(**global_encoder.decode(data)) + fields = global_encoder.decode(data) + fields["args"] = tuple(fields["args"]) + return cls(**fields) except Exception as e: raise DecodeError("Failed to decode message.", data, e) from e @@ -165,3 +172,6 @@ def __str__(self) -> str: params += ", ".join("%s=%r" % (name, value) for name, value in self.kwargs.items()) return "%s(%s)" % (self.actor_name, params) + + def __lt__(self, other: "Message") -> bool: + return dataclasses.astuple(self) < dataclasses.astuple(other)