diff --git a/tests/torchtune/data/test_messages.py b/tests/torchtune/data/test_messages.py index adc64bfdec..a46cfd9349 100644 --- a/tests/torchtune/data/test_messages.py +++ b/tests/torchtune/data/test_messages.py @@ -86,6 +86,17 @@ def test_text_content(self, text_message, image_message): assert text_message.text_content == "hello world" assert image_message.text_content == "hello world" + def test_repr_text(self, text_message): + expected_repr = "Message(role='user', content=['hello world'])" + assert str(text_message) == expected_repr + assert repr(text_message) == expected_repr + + def test_repr_image(self, image_message, test_image): + img_repr = str(test_image) + expected_repr = f"Message(role='user', content=['hello', {img_repr}, ' world'])" + assert str(image_message) == expected_repr + assert repr(image_message) == expected_repr + class TestInputOutputToMessages: @pytest.fixture diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index 2d0eddce5d..fb24d9fbe7 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -136,6 +136,10 @@ def _validate_message(self) -> None: f"Only assistant messages can be tool calls. Found role {self.role} in message: {self.text_content}" ) + def __repr__(self) -> str: + content_only = [content["content"] for content in self.content] + return f"Message(role='{self.role}', content={content_only!r})" + class InputOutputToMessages(Transform): """