Fail early with packed=True on MM datasets. (#2080)

SalmanMohammadi authored Nov 27, 2024
1 parent b5d2e63 commit d923234
Showing 8 changed files with 77 additions and 4 deletions.
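For readers skimming the diff: the commit applies the same fail-early guard to each multimodal dataset builder (``llava_instruct_dataset``, ``multimodal_chat_dataset``, ``the_cauldron_dataset``, ``vqa_dataset``), so ``packed=True`` is rejected before any message transform is built or any data is loaded. A minimal sketch of the pattern, using an illustrative builder name rather than the real torchtune signatures:

```python
from typing import Any, Dict


def build_multimodal_dataset(
    model_transform: Any,
    *,
    packed: bool = False,
    **load_dataset_kwargs: Dict[str, Any],
) -> Any:
    # Fail early: reject packing before any message transform is constructed
    # or any Hugging Face data is fetched.
    if packed:
        raise ValueError("Multimodal datasets don't support packing yet.")
    # The real builders create a message transform and load the dataset here;
    # this sketch just echoes its inputs.
    return {"model_transform": model_transform, "kwargs": load_dataset_kwargs}
```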
@@ -86,3 +86,12 @@ def test_get_item(self, load_image, load_dataset, tokenizer, test_image_pil):
assert Counter(input) == expected_count
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 11
assert images == [test_image_pil]

def test_dataset_fails_with_packed(self, tokenizer):
with pytest.raises(
ValueError, match="Multimodal datasets don't support packing yet."
):
llava_instruct_dataset(
model_transform=tokenizer,
packed=True,
)
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from tests.test_utils import DummyTokenizer

from torchtune.datasets.multimodal import multimodal_chat_dataset


class TestMultimodalChatDataset:
@pytest.fixture
def tokenizer(self):
return DummyTokenizer()

def test_dataset_fails_with_packed(self, tokenizer):
with pytest.raises(
ValueError, match="Multimodal datasets don't support packing yet."
):
multimodal_chat_dataset(
model_transform=tokenizer, source="json", packed=True
)
10 changes: 10 additions & 0 deletions tests/torchtune/datasets/multimodal/test_the_cauldron_dataset.py
@@ -79,3 +79,13 @@ def test_get_item(self, load_dataset, tokenizer, test_image_pil):
]
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 24
assert images == [test_image_pil]

def test_dataset_fails_with_packed(self, tokenizer):
with pytest.raises(
ValueError, match="Multimodal datasets don't support packing yet."
):
the_cauldron_dataset(
model_transform=tokenizer,
subset="dummy",
packed=True,
)
10 changes: 10 additions & 0 deletions tests/torchtune/datasets/multimodal/test_vqa_dataset.py
@@ -47,3 +47,13 @@ def test_get_item(self, tokenizer):
assert prompt == expected_tokens[i]
assert label == expected_labels[i]
assert isinstance(image[0], PngImageFile)

def test_dataset_fails_with_packed(self, tokenizer):
with pytest.raises(
ValueError, match="Multimodal datasets don't support packing yet."
):
vqa_dataset(
model_transform=tokenizer,
source="json",
packed=True,
)
5 changes: 3 additions & 2 deletions torchtune/datasets/multimodal/_llava_instruct.py
@@ -118,6 +118,8 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
>>> print(f"Batch size: {len(batch)}")
>>> Batch size: 8
"""
if packed:
raise ValueError("Multimodal datasets don't support packing yet.")

message_transform = ShareGPTToMessages(
train_on_input=False,
@@ -136,6 +138,5 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
data_files=data_files,
**load_dataset_kwargs,
)
if packed:
raise ValueError("Multimodal datasets don't support packing yet.")

return ds
9 changes: 9 additions & 0 deletions torchtune/datasets/multimodal/_multimodal.py
@@ -18,6 +18,7 @@ def multimodal_chat_dataset(
source: str,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
packed: bool = False,
image_tag: Optional[str] = None,
image_dir: Optional[str] = None,
filter_fn: Optional[Callable] = None,
@@ -79,6 +80,7 @@ def multimodal_chat_dataset(
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
image_tag (Optional[str]): placeholder tags in the text content of each message to be replaced by dictionaries
indicating to the tokenizer where to place image tokens. If images are present and this is None,
then will prepend image tokens to the first user message in the sample by default. If text-only, leave
@@ -169,7 +171,14 @@ def multimodal_chat_dataset(
Returns:
SFTDataset: the configured :class:`~torchtune.datasets.SFTDataset`
Raises:
ValueError: If ``packed`` is True; packing is not yet supported for multimodal datasets.
"""
if packed:
raise ValueError("Multimodal datasets don't support packing yet.")

message_transform = ShareGPTToMessages(
train_on_input=False,
column_map=column_map,
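As a quick illustration of the user-facing behavior, and of why the guard sits at the top of ``multimodal_chat_dataset`` rather than after the dataset is built: the error now surfaces before ``load_dataset`` runs, so nothing is downloaded. A rough usage sketch, assuming the import path shown in the new test above and using a stand-in model transform:

```python
import pytest

from torchtune.datasets.multimodal import multimodal_chat_dataset


class _StubTransform:
    """Placeholder model transform; the guard raises before it is ever called."""

    def __call__(self, sample):
        return sample


# Raises immediately, before any Hugging Face data is fetched.
with pytest.raises(ValueError, match="Multimodal datasets don't support packing yet."):
    multimodal_chat_dataset(
        model_transform=_StubTransform(), source="json", packed=True
    )
```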
5 changes: 3 additions & 2 deletions torchtune/datasets/multimodal/_the_cauldron.py
@@ -216,6 +216,8 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
>>> print(f"Batch size: {len(batch)}")
>>> Batch size: 8
"""
if packed:
raise ValueError("Multimodal datasets don't support packing yet.")

message_transform = TheCauldronToMessages(
column_map=column_map,
@@ -231,6 +233,5 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
split=split,
**load_dataset_kwargs,
)
if packed:
raise ValueError("Multimodal datasets don't support packing yet.")

return ds
9 changes: 9 additions & 0 deletions torchtune/datasets/multimodal/_vqa.py
@@ -18,6 +18,7 @@ def vqa_dataset(
image_dir: str = None,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
@@ -63,6 +64,7 @@ def vqa_dataset(
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
@@ -122,7 +124,14 @@ def vqa_dataset(
Returns:
SFTDataset: the configured :class:`~torchtune.datasets.SFTDataset`
Raises:
ValueError: If ``packed`` is True; packing is not yet supported for multimodal datasets.
"""
if packed:
raise ValueError("Multimodal datasets don't support packing yet.")

message_transform = InputOutputToMessages(
column_map=column_map, new_system_prompt=new_system_prompt, image_dir=image_dir
)
