From cf892a834d2f5ea440d91159dd13124f6acc8a37 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 6 Oct 2024 16:09:56 +0000 Subject: [PATCH 1/2] Adds repr to Message --- tests/torchtune/data/test_messages.py | 11 +++++++++++ torchtune/data/_messages.py | 4 ++++ 2 files changed, 15 insertions(+) diff --git a/tests/torchtune/data/test_messages.py b/tests/torchtune/data/test_messages.py index adc64bfdec..a46cfd9349 100644 --- a/tests/torchtune/data/test_messages.py +++ b/tests/torchtune/data/test_messages.py @@ -86,6 +86,17 @@ def test_text_content(self, text_message, image_message): assert text_message.text_content == "hello world" assert image_message.text_content == "hello world" + def test_repr_text(self, text_message): + expected_repr = "Message(role='user', content=['hello world'])" + assert str(text_message) == expected_repr + assert repr(text_message) == expected_repr + + def test_repr_image(self, image_message, test_image): + img_repr = str(test_image) + expected_repr = f"Message(role='user', content=['hello', {img_repr}, ' world'])" + assert str(image_message) == expected_repr + assert repr(image_message) == expected_repr + class TestInputOutputToMessages: @pytest.fixture diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index 2d0eddce5d..6655d170d5 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -136,6 +136,10 @@ def _validate_message(self) -> None: f"Only assistant messages can be tool calls. Found role {self.role} in message: {self.text_content}" ) + def __repr__(self) -> str: + content_only = [content["content"] for content in self.content] + return f"Message(role={self.role!r}, content={content_only!r})" + class InputOutputToMessages(Transform): """ From c4616d6f4dad14e5c45d3244d8b5c527e71d81a8 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 8 Oct 2024 01:24:29 +0000 Subject: [PATCH 2/2] Remove repr from role --- torchtune/data/_messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index 6655d170d5..fb24d9fbe7 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -138,7 +138,7 @@ def _validate_message(self) -> None: def __repr__(self) -> str: content_only = [content["content"] for content in self.content] - return f"Message(role={self.role!r}, content={content_only!r})" + return f"Message(role='{self.role}', content={content_only!r})" class InputOutputToMessages(Transform):