
Raise DatasetTooSmall exception if canonical nodes is less than num samples #1518

Merged · 9 commits · Sep 12, 2024
4 changes: 3 additions & 1 deletion llmfoundry/command_utils/data_prep/convert_text_to_mds.py
@@ -478,7 +478,9 @@ def convert_text_to_mds(
     index_path = os.path.join(local_output_folder, 'index.json')
     with open(index_path, 'r') as index_file:
         if not json.load(index_file)['shards']:
-            raise DatasetTooSmallError()
+            raise DatasetTooSmallError(
+                reason='No shards were created when converting text to MDS.',
+            )

     # Write a done file with the args and object names
     write_done_file(local_output_folder, args_str, object_names)
20 changes: 19 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
@@ -73,6 +73,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
     ALLOWED_RESPONSE_KEYS,
     ChatTemplateError,
     ConsecutiveRepeatedChatRolesError,
+    DatasetTooSmallError,
     IncorrectMessageKeyQuantityError,
     InvalidContentTypeError,
     InvalidExampleTypeError,
@@ -1033,7 +1034,24 @@ def build_from_streaming(
         *args: Any,
         **kwargs: Any,
     ) -> StreamingFinetuningDataset:
-        return self.streaming_dataset_class(*args, **kwargs)
+        dataset = self.streaming_dataset_class(*args, **kwargs)
+        num_canonical_nodes = dataset.num_canonical_nodes
+        num_samples = dataset.num_samples
+        if num_canonical_nodes is None:
+            num_physical_nodes = dist.get_world_size(
+            ) // dist.get_local_world_size()
+            if num_samples < num_physical_nodes:
+                raise DatasetTooSmallError(
+                    f'{num_samples=} is less than {num_physical_nodes=}, the number of physical nodes.',
+                )
+
+        if num_canonical_nodes is not None and num_samples < num_canonical_nodes:
+            raise DatasetTooSmallError(
+                f'{num_samples=} is less than {num_canonical_nodes=}. ' +
+                'Please check your index.json file and ensure that your dataset has been written out correctly. ' +
+                'If this was intended, reduce num_canonical_nodes.',
+            )
+        return dataset


 dataset_constructor = DatasetConstructor()
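For context, a quick worked example of the node arithmetic in build_from_streaming, using the same illustrative values as the test added below (32 total ranks, 8 ranks per node):

    world_size = 32          # dist.get_world_size()
    local_world_size = 8     # dist.get_local_world_size()
    num_physical_nodes = world_size // local_world_size  # 4
    num_samples = 2
    # num_canonical_nodes unset: 2 < 4 physical nodes -> DatasetTooSmallError.
    # num_canonical_nodes == 8: 2 < 8 -> DatasetTooSmallError.
    # num_canonical_nodes == 2: 2 >= 2 -> dataset is accepted.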
6 changes: 3 additions & 3 deletions llmfoundry/utils/exceptions.py
@@ -376,9 +376,9 @@ def __init__(self, dataset_name: str, split: str) -> None:
 class DatasetTooSmallError(UserError):
     """Error thrown when the dataset is too small to be processed."""
 
-    def __init__(self) -> None:
-        message = f'Your dataset is too small and produced no complete samples during preprocessing. Please provide more data.'
-        super().__init__(message)
+    def __init__(self, reason: str) -> None:
+        message = f'Your dataset is too small and produced no complete samples or too few samples. Please provide more data. {reason}'
+        super().__init__(message, reason=reason)


class RunTimeoutError(InternalError):
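As a usage sketch, every call site now has to pass a reason; the call added in convert_text_to_mds.py above, for example, becomes:

    raise DatasetTooSmallError(
        reason='No shards were created when converting text to MDS.',
    )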
44 changes: 44 additions & 0 deletions tests/data/test_dataset.py
@@ -0,0 +1,44 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+from contextlib import nullcontext
+from typing import Optional
+from unittest import mock
+
+import pytest
+
+from llmfoundry.data.finetuning.tasks import dataset_constructor
+from llmfoundry.utils.exceptions import DatasetTooSmallError
+
+
+@pytest.mark.parametrize('num_canonical_nodes', [None, 8, 2])
+def test_finetuning_streaming_dataset_too_small(
+    num_canonical_nodes: Optional[int],
+):
+    num_samples = 2
+
+    class MockDataset:
+
+        def __init__(self):
+            self.num_canonical_nodes = num_canonical_nodes
+            self.num_samples = num_samples
+
+    class MockDist:
+
+        def get_world_size(self):
+            return 32
+
+        def get_local_world_size(self):
+            return 8
+
+    result_context = nullcontext(
+    ) if num_canonical_nodes == 2 else pytest.raises(DatasetTooSmallError)
+    with result_context:
+        with mock.patch(
+            'llmfoundry.data.finetuning.tasks.dist',
+            new=MockDist(),
+        ):
+            with mock.patch(
+                'llmfoundry.data.finetuning.tasks.DatasetConstructor.streaming_dataset_class',
+                new=MockDataset,
+            ):
+                dataset_constructor.build_from_streaming()
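The new test runs without GPUs or real data, since both dist and the streaming dataset class are mocked; a plain pytest invocation such as pytest tests/data/test_dataset.py is enough to exercise all three parametrized cases.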