Skip to content

Commit

Permalink
Fix tests broken by moving Message to dataclass (#516)
Browse files Browse the repository at this point in the history
* Fix tests broken by moving Message to dataclass
* Undo breaking change to Message.asdict
  • Loading branch information
orsinium authored Dec 28, 2022
1 parent 425ba43 commit a41f3fc
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
9 changes: 8 additions & 1 deletion dramatiq/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

import time
from typing import TYPE_CHECKING, Iterable
from uuid import uuid4

from .broker import get_broker
from .rate_limits import Barrier
from .results import ResultMissing

if TYPE_CHECKING:
from .message import Message


class pipeline:
"""Chain actors together, passing the result of one actor to the
Expand All @@ -34,9 +39,11 @@ class pipeline:
broker(Broker): The broker to run the pipeline on. Defaults to
the current global broker.
"""
messages: list[Message]

def __init__(self, children, *, broker=None):
def __init__(self, children: Iterable[Message | pipeline], *, broker=None):
self.broker = broker or get_broker()
messages: list[Message]
self.messages = messages = []

for child in children:
Expand Down
14 changes: 12 additions & 2 deletions dramatiq/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ def __or__(self, other) -> pipeline:
def asdict(self) -> Dict[str, Any]:
"""Convert this message to a dictionary.
"""
return dataclasses.asdict(self)
# For backward compatibility, we can't use `dataclasses.asdict`
# because it creates a copy of all values, including `options`.
result = {}
for field in dataclasses.fields(self):
result[field.name] = getattr(self, field.name)
return result

@classmethod
def decode(cls, data: bytes) -> "Message":
Expand All @@ -100,7 +105,9 @@ def decode(cls, data: bytes) -> "Message":
decoding `data`.
"""
try:
return cls(**global_encoder.decode(data))
fields = global_encoder.decode(data)
fields["args"] = tuple(fields["args"])
return cls(**fields)
except Exception as e:
raise DecodeError("Failed to decode message.", data, e) from e

Expand Down Expand Up @@ -165,3 +172,6 @@ def __str__(self) -> str:
params += ", ".join("%s=%r" % (name, value) for name, value in self.kwargs.items())

return "%s(%s)" % (self.actor_name, params)

def __lt__(self, other: "Message") -> bool:
return dataclasses.astuple(self) < dataclasses.astuple(other)

0 comments on commit a41f3fc

Please sign in to comment.