Messages and message transforms docs #1574

Merged
merged 7 commits on Sep 21, 2024
2 changes: 2 additions & 0 deletions docs/source/basics/chat_datasets.rst
@@ -22,6 +22,8 @@ The primary entry point for fine-tuning with chat datasets in torchtune is the :
builder. This lets you specify a local or Hugging Face dataset that follows the chat data format
directly from the config and train your LLM on it.

.. _example_chat:

Example chat dataset
--------------------

2 changes: 2 additions & 0 deletions docs/source/basics/instruct_datasets.rst
@@ -14,6 +14,8 @@ The primary entry point for fine-tuning with instruct datasets in torchtune is t
builder. This lets you specify a local or Hugging Face dataset that follows the instruct data format
directly from the config and train your LLM on it.

.. _example_instruct:

Example instruct dataset
------------------------

103 changes: 103 additions & 0 deletions docs/source/basics/message_transforms.rst
@@ -0,0 +1,103 @@
.. _message_transform_usage_label:

==================
Message Transforms
==================

Message transforms convert raw sample dictionaries from your dataset into torchtune's
:class:`~torchtune.data.Message` structure. Once your data is represented as Messages, torchtune handles
tokenization and prepares it for the model.

.. TODO (rafiayub): place an image here to depict overall pipeline


Configuring message transforms
------------------------------
Most of our built-in message transforms contain parameters for controlling input masking (``train_on_input``),
adding a system prompt (``new_system_prompt``), and changing the expected column names (``column_map``).
These are exposed in our dataset builders :func:`~torchtune.datasets.instruct_dataset` and :func:`~torchtune.datasets.chat_dataset`
so you don't have to worry about the message transform itself and can configure this directly from the config.
You can see :ref:`example_instruct` or :ref:`example_chat` for more details.
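For example, a config entry using :func:`~torchtune.datasets.instruct_dataset` might look like the following sketch. The ``data_files`` path, prompt text, and column names here are placeholders for your own dataset, not real values:

```yaml
dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: json
  data_files: data/my_data.json  # placeholder path
  split: train
  train_on_input: False
  new_system_prompt: You are a helpful assistant.
  # Map the expected "input"/"output" columns to your dataset's column names
  column_map:
    input: prompt
    output: response
```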


Custom message transforms
-------------------------
If our built-in message transforms do not fit your particular dataset well,
you can create your own class with full flexibility. Simply inherit from the :class:`~torchtune.modules.transforms.Transform`
class and add your code in the ``__call__`` method.

A simple contrived example would be to take one column from the dataset as the user message and another
column as the model response. Indeed, this is quite similar to :class:`~torchtune.data.InputOutputToMessages`.

.. code-block:: python

from torchtune.modules.transforms import Transform
from torchtune.data import Message
from typing import Any, Mapping

class MessageTransform(Transform):
def __call__(self, sample: Mapping[str, Any]) -> list[Message]:
return [
Message(
role="user",
content=sample["input"],
masked=True,
eot=True,
),
Message(
role="assistant",
content=sample["output"],
masked=False,
eot=True,
),
]

sample = {"input": "hello world", "output": "bye world"}
transform = MessageTransform()
messages = transform(sample)
print(messages)
# [<torchtune.data._messages.Message at 0x7fb0a10094e0>,
# <torchtune.data._messages.Message at 0x7fb0a100a290>]
for msg in messages:
print(msg.role, msg.text_content)
# user hello world
# assistant bye world

See :ref:`creating_messages` for more details on how to manipulate :class:`~torchtune.data.Message` objects.

To use this for your dataset, you must create a custom dataset builder that uses the underlying
dataset class, :class:`~torchtune.datasets.SFTDataset`.

.. code-block:: python

# In data/dataset.py
from torchtune.datasets import SFTDataset

def custom_dataset(tokenizer, **load_dataset_kwargs) -> SFTDataset:
message_transform = MessageTransform()
return SFTDataset(
source="json",
data_files="data/my_data.json",
split="train",
message_transform=message_transform,
model_transform=tokenizer,
**load_dataset_kwargs,
)

This can be used directly from the config.

.. code-block:: yaml

dataset:
_component_: data.dataset.custom_dataset


Example message transforms
--------------------------
- Instruct
- :class:`~torchtune.data.InputOutputToMessages`
- Chat
- :class:`~torchtune.data.ShareGPTToMessages`
- :class:`~torchtune.data.JSONToMessages`
- Preference
- :class:`~torchtune.data.ChosenRejectedToMessages`

240 changes: 240 additions & 0 deletions docs/source/basics/messages.rst
@@ -0,0 +1,240 @@
.. _messages_usage_label:

========
Messages
========

Messages are a core component in torchtune that governs how text and multimodal content is tokenized. They serve as the common interface
for all tokenizer and dataset APIs to operate on. Messages contain information about the text content, which role is sending the text
content, and other information relevant for special tokens in model tokenizers. For more information about the individual parameters
for Messages, see the API ref for :class:`~torchtune.data.Message`.

.. _creating_messages:

Creating Messages
-----------------

Messages can be created via the standard class constructor or directly from a dictionary.

.. code-block:: python

from torchtune.data import Message

msg = Message(
role="user",
content="Hello world!",
masked=True,
eot=True,
ipython=False,
)
# This is identical
msg = Message.from_dict(
{
"role": "user",
"content": "Hello world!",
"masked": True,
"eot": True,
"ipython": False,
},
)
print(msg.content)
# [{'type': 'text', 'content': 'Hello world!'}]

Content is formatted as a list of dictionaries. This is because Messages can also contain multimodal content, such as images.

Images in Messages
^^^^^^^^^^^^^^^^^^

For multimodal datasets, you need to add the image as a :class:`~PIL.Image.Image` to the corresponding :class:`~torchtune.data.Message`.
To add it to the beginning of the message, simply prepend it to the content list.

.. code-block:: python

import PIL
from torchtune.data import Message

img_msg = Message(
role="user",
content=[
{
"type": "image",
# Place your image here
"content": PIL.Image.new(mode="RGB", size=(4, 4)),
},
{"type": "text", "content": "What's in this image?"},
],
)

This indicates to the model tokenizer where to add the image special token, and the image will be processed
appropriately by the model transform.

In many cases, you will have an image path instead of a raw :class:`~PIL.Image.Image`. You can use the :func:`~torchtune.data.load_image`
utility for both local paths and remote paths.

.. code-block:: python

import PIL
from torchtune.data import Message, load_image

image_path = "path/to/image.jpg"
img_msg = Message(
role="user",
content=[
{
"type": "image",
# Place your image here
"content": load_image(image_path),
},
{"type": "text", "content": "What's in this image?"},
],
)

If your dataset contains image tags, or placeholder text indicating where in the text the image should be inserted,
you can use the :func:`~torchtune.data.format_content_with_images` utility to split the text into the correct content list
that you can pass into the content field of the Message.

.. code-block:: python

import PIL
from torchtune.data import format_content_with_images

content = format_content_with_images(
"<|image|>hello <|image|>world",
image_tag="<|image|>",
images=[PIL.Image.new(mode="RGB", size=(4, 4)), PIL.Image.new(mode="RGB", size=(4, 4))]
)
print(content)
# [
# {"type": "image", "content": <PIL.Image.Image>},
# {"type": "text", "content": "hello "},
# {"type": "image", "content": <PIL.Image.Image>},
# {"type": "text", "content": "world"}
# ]

Message transforms
^^^^^^^^^^^^^^^^^^
Message transforms are convenient utilities to format raw data into a list of torchtune :class:`~torchtune.data.Message`
objects.

.. code-block:: python

from torchtune.data import InputOutputToMessages

sample = {
"input": "What is your name?",
"output": "I am an AI assistant, I don't have a name."
}
transform = InputOutputToMessages()
output = transform(sample)
for message in output["messages"]:
print(message.role, message.text_content)
# user What is your name?
# assistant I am an AI assistant, I don't have a name.

See :ref:`message_transform_usage_label` for more discussion.


Formatting messages with prompt templates
-----------------------------------------

Prompt templates provide a way to format messages into a structured text template. You can simply call any class that inherits
from :class:`~torchtune.data.PromptTemplateInterface` on a list of Messages and it will add the appropriate text to the content
list.

.. code-block:: python

from torchtune.models.mistral import MistralChatTemplate
from torchtune.data import Message

msg = Message(
role="user",
content="Hello world!",
masked=True,
eot=True,
ipython=False,
)
template = MistralChatTemplate()
templated_msg = template([msg])
print(templated_msg[0].content)
# [{'type': 'text', 'content': '[INST] '},
# {'type': 'text', 'content': 'Hello world!'},
# {'type': 'text', 'content': ' [/INST] '}]

Accessing text content in messages
----------------------------------
The ``text_content`` property concatenates all text items in a message's content list into a single string.

.. code-block:: python

from torchtune.models.mistral import MistralChatTemplate
from torchtune.data import Message

msg = Message(
role="user",
content="Hello world!",
masked=True,
eot=True,
ipython=False,
)
template = MistralChatTemplate()
templated_msg = template([msg])
print(templated_msg[0].text_content)
# [INST] Hello world! [/INST]

Accessing images in messages
----------------------------
If a message contains images, ``contains_media`` will return True and ``get_media()`` will return the list of images in the content.

.. code-block:: python

from torchtune.data import Message
import PIL

msg = Message(
role="user",
content=[
{
"type": "image",
# Place your image here
"content": PIL.Image.new(mode="RGB", size=(4, 4)),
},
{"type": "text", "content": "What's in this image?"},
],
)
if msg.contains_media:
print(msg.get_media())
# [<PIL.Image.Image image mode=RGB size=4x4 at 0x7F8D27E72740>]

Tokenizing messages
-------------------
All model tokenizers have a ``tokenize_messages`` method that converts a list of
:class:`~torchtune.data.Message` objects into token IDs and a loss mask.

.. code-block:: python

from torchtune.models.mistral import mistral_tokenizer
from torchtune.data import Message

m_tokenizer = mistral_tokenizer(
path="/tmp/Mistral-7B-v0.1/tokenizer.model",
prompt_template="torchtune.models.mistral.MistralChatTemplate",
max_seq_len=8192,
)
msgs = [
Message(
role="user",
content="Hello world!",
masked=True,
eot=True,
ipython=False,
),
Message(
role="assistant",
content="Hi, I am an AI assistant.",
masked=False,
eot=True,
ipython=False,
)
]
tokens, mask = m_tokenizer.tokenize_messages(msgs)
print(tokens)
# [1, 733, 16289, 28793, 22557, 1526, 28808, 28705, 733, 28748, 16289, 28793, 15359, 28725, 315, 837, 396, 16107, 13892, 28723, 2]
print(mask) # User message is masked from the loss
# [True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False]
print(m_tokenizer.decode(tokens))
# [INST] Hello world! [/INST] Hi, I am an AI assistant.
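To sketch how such a loss mask is typically consumed downstream, masked positions can be replaced with an ignore index so they do not contribute to the loss. This is an illustration, not torchtune's exact implementation; ``-100`` is the common PyTorch cross-entropy convention, and the helper name is hypothetical:

```python
# Illustrative sketch: turn a boolean loss mask into training labels.
# Positions where the mask is True (the user prompt) are replaced with an
# ignore index so the loss skips them. -100 matches the default
# ignore_index of PyTorch's cross entropy; torchtune's constant may differ.
IGNORE_INDEX = -100

def mask_labels(tokens: list[int], mask: list[bool]) -> list[int]:
    # Replace masked token positions with IGNORE_INDEX, keep the rest.
    return [IGNORE_INDEX if masked else tok for tok, masked in zip(tokens, mask)]

tokens = [1, 733, 16289, 28793, 15359, 2]
mask = [True, True, True, True, False, False]
print(mask_labels(tokens, mask))
# [-100, -100, -100, -100, 15359, 2]
```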
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -113,6 +113,8 @@ torchtune tutorials.
:caption: Basics
:hidden:

basics/messages
basics/message_transforms
basics/instruct_datasets
basics/chat_datasets
basics/tokenizers