Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FinetuningStreamingDataset parameters more flexible #1580

Merged
merged 38 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ba90b6d
update
XiaohanZhangCMU Sep 14, 2024
9bc080e
update
XiaohanZhangCMU Sep 14, 2024
4795051
update
XiaohanZhangCMU Sep 14, 2024
c97fe6e
update
XiaohanZhangCMU Sep 14, 2024
a8c58a6
update
XiaohanZhangCMU Sep 14, 2024
8da5211
update
XiaohanZhangCMU Sep 14, 2024
56f47ab
update
XiaohanZhangCMU Sep 14, 2024
4c78a64
update
XiaohanZhangCMU Sep 18, 2024
46105f0
update
XiaohanZhangCMU Sep 14, 2024
9a75ccd
update
XiaohanZhangCMU Sep 14, 2024
9ee6159
update
XiaohanZhangCMU Sep 14, 2024
6a3be7b
update
XiaohanZhangCMU Sep 14, 2024
e08082d
update
XiaohanZhangCMU Sep 14, 2024
01d9ddd
update
XiaohanZhangCMU Sep 14, 2024
2c1f9d8
update
XiaohanZhangCMU Sep 14, 2024
a042327
update
XiaohanZhangCMU Sep 18, 2024
5ac6557
update
XiaohanZhangCMU Oct 4, 2024
7c7f736
merge
XiaohanZhangCMU Oct 4, 2024
d031ad1
update
XiaohanZhangCMU Oct 4, 2024
f31cfec
update
XiaohanZhangCMU Oct 4, 2024
5d08a20
update
XiaohanZhangCMU Oct 9, 2024
a1264e4
update
XiaohanZhangCMU Oct 10, 2024
55dcd67
update
XiaohanZhangCMU Oct 10, 2024
fc1743d
update
XiaohanZhangCMU Oct 11, 2024
f034a86
update
XiaohanZhangCMU Oct 11, 2024
9a2e7b4
update
XiaohanZhangCMU Oct 11, 2024
1a4f8ff
update
XiaohanZhangCMU Oct 11, 2024
9e830b4
update
XiaohanZhangCMU Oct 11, 2024
a574b04
update
XiaohanZhangCMU Oct 11, 2024
58b6a29
Merge branch 'main' into xiaohan/delta-streaming-test
XiaohanZhangCMU Oct 11, 2024
0478d49
update
XiaohanZhangCMU Oct 11, 2024
3e50fbf
update
XiaohanZhangCMU Oct 11, 2024
adf6d95
update
XiaohanZhangCMU Oct 11, 2024
6bbdee7
update
XiaohanZhangCMU Oct 11, 2024
ce12512
Merge branch 'main' into xiaohan/delta-streaming-test
XiaohanZhangCMU Oct 11, 2024
966153e
xfail
XiaohanZhangCMU Oct 12, 2024
62463cb
update
XiaohanZhangCMU Oct 12, 2024
abce3b0
lint
XiaohanZhangCMU Oct 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def build_finetuning_dataloader(
allowed_dataset_config_keys = set(
dataset_constructor_keys,
).union(_ALLOWED_DATASET_KEYS)
_validate_config(

extraneous_keys = _validate_config(
**dataset_cfg,
allowed_dataset_keys=allowed_dataset_config_keys,
)
Expand Down Expand Up @@ -253,13 +254,13 @@ def build_finetuning_dataloader(
streams_cfg,
) if streams_cfg is not None else None

# Take the constructor args from above, minus args that have been created separately
dataset_constructor_args = {
k: v
for k, v in dataset_cfg.items()
if k in dataset_constructor_keys and
if k in set(dataset_constructor_keys).union(extraneous_keys) and
k not in {'streams', 'packing_ratio'}
}

streaming_dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
streams=streams,
Expand Down Expand Up @@ -378,7 +379,7 @@ def _validate_config(
target_responses: Optional[str] = None,
allowed_dataset_keys: set[str] = _ALLOWED_DATASET_KEYS,
**kwargs: dict[str, Any],
) -> None:
) -> set[str]:
"""Validates the dataset configuration.

Makes sure that the dataset is properly configured for either
Expand Down Expand Up @@ -434,11 +435,16 @@ def _validate_config(

Raises:
ValueError: If the dataset configuration does not meet the requirements.

Returns:
set[str]: Return the extraneous keys.
"""
extraneous_keys = set()
if not set(kwargs.keys()).issubset(allowed_dataset_keys):
raise ValueError(
extraneous_keys = set(kwargs.keys()) - allowed_dataset_keys
log.warning(
'The dataset config contains the following extraneous keys: ' +\
', '.join(set(kwargs.keys()) - allowed_dataset_keys),
', '.join(extraneous_keys),
)

if hf_name is not None:
Expand Down Expand Up @@ -533,6 +539,8 @@ def _validate_config(
decoder_only_format,
)

return extraneous_keys


def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
"""Downloads a dataset from a remote object store.
Expand Down
6 changes: 1 addition & 5 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,11 +613,6 @@ def __init__(
**kwargs: Any,
):

if len(kwargs) > 0:
raise ValueError(
f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}',
)

if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES:
raise ValueError(
f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}',
Expand Down Expand Up @@ -658,6 +653,7 @@ def __init__(
batching_method=batching_method,
allow_unsafe_types=allow_unsafe_types,
replication=replication,
**kwargs,
)

self.tokenizer = tokenizer
Expand Down
11 changes: 5 additions & 6 deletions llmfoundry/data/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,6 @@ def __init__(
**kwargs: Any,
):

if len(kwargs) > 0:
raise ValueError(
f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}',
)

if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES:
raise ValueError(
f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}',
Expand Down Expand Up @@ -188,6 +183,7 @@ def __init__(
batching_method=batching_method,
allow_unsafe_types=allow_unsafe_types,
replication=replication,
**kwargs,
)
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
Expand Down Expand Up @@ -332,10 +328,13 @@ def build_text_dataloader(
StreamingTextDataset,
).parameters

valid_base_dataset_params = inspect.signature(StreamingDataset,).parameters

dataset_config_subset_for_streaming_text_dataset = {
k: v
for k, v in dataset_cfg.items()
if k in valid_streaming_text_dataset_parameters
if k in valid_streaming_text_dataset_parameters or
k in valid_base_dataset_params
}

# build dataset potentially with streams
Expand Down
161 changes: 159 additions & 2 deletions tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from typing import Any, Callable, ContextManager, Literal, Optional, Union
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, mock_open, patch

import catalogue
import numpy as np
Expand Down Expand Up @@ -686,7 +686,6 @@ def test_finetuning_dataloader_streaming(
'dataset': {
'max_seq_len': 2048,
'decoder_only_format': True,
'allow_pad_trimming': False,
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
'packing_ratio': None,
'shuffle': True,
},
Expand Down Expand Up @@ -1423,3 +1422,161 @@ def test_sharegpt_format(
device_batch_size=device_batch_size,
**cfg,
).dataloader

def test_ft_dataloader_with_extra_keys():
max_seq_len = 2
cfg = {
'dataset': {
'remote': '/remote',
'local': '/local',
'split': 'train',
'max_seq_len': 2048,
'decoder_only_format': True,
'shuffle': True,
'num_canonical_nodes': 472,
'target_responses': 'last',
'target_prompts': 'none',
'extra_key_1': 'extra_key_1',
'extra_key_2': 'extra_key_2',
'extra_key_3': 'extra_key_3',
},
'drop_last': False,
'num_workers': 0,
'pin_memory': False,
'prefetch_factor': None,
'persistent_workers': False,
'timeout': 0,
}

cfg = om.create(cfg)

tokenizer = build_tokenizer(
tokenizer_name='gpt2',
tokenizer_kwargs={'model_max_length': max_seq_len},
)

device_batch_size = 2

mock_stat = MagicMock()
mock_stat.st_size = 1024 # Mock st_size with a desired value
mock_stat.st_mode = 33188 # Regular file mode for Unix-based systems

#with patch('streaming.base.stream.get_shards', return_value=None):
with patch('os.makedirs'), \
patch('builtins.open', new_callable=mock_open, read_data='{"version": 2, "shards": []}'), \
patch('json.load') as mock_json_load, \
patch('os.stat', return_value=mock_stat), \
patch('torch.distributed.is_available', return_value=True), \
patch('torch.distributed.is_initialized', return_value=True), \
patch('torch.distributed.broadcast_object_list'), \
patch('torch.distributed.init_process_group'), \
patch('torch.distributed.destroy_process_group'), \
patch('torch.distributed.barrier'), \
patch('streaming.base.dataset.StreamingDataset.get_item'):

mock_json_load.return_value = {
'version':
2,
'shards': [{
'column_names': ['column1', 'column2'],
'column_encodings': ['int', 'float'],
'column_sizes': [4, 8],
'compression': None,
'format': 'mds',
'hashes': [],
'raw_data': {
'basename': 'shard.00000.mds',
'bytes': 1024,
'hashes': {},
},
'samples': 1000,
'size_limit': 67108864,
'version': 2,
'zip_data': None,
}],
}

with pytest.raises(TypeError, match=f'.*got an unexpected keyword argument.*'):
_ = build_finetuning_dataloader(
**cfg,
tokenizer=tokenizer,
device_batch_size=device_batch_size,
).dataloader

@pytest.mark.xfail
def test_text_dataloader_with_extra_keys():
max_seq_len = 1024
cfg = {
'dataset': {
'remote': '/remote',
'local': '/local',
'split': 'train',
'max_seq_len': max_seq_len,
'shuffle': True,
'num_canonical_nodes': 472,
'extra_key_1': 'extra_key_1',
'extra_key_2': 'extra_key_2',
'extra_key_3': 'extra_key_3',
},
'drop_last': False,
'num_workers': 0,
'pin_memory': False,
'prefetch_factor': None,
'persistent_workers': False,
'timeout': 0,
}

cfg = om.create(cfg)

tokenizer = build_tokenizer(
tokenizer_name='gpt2',
tokenizer_kwargs={'model_max_length': max_seq_len},
)

device_batch_size = 2

mock_stat = MagicMock()
mock_stat.st_size = 1024 # Mock st_size with a desired value
mock_stat.st_mode = 33188 # Regular file mode for Unix-based systems

#with patch('streaming.base.stream.get_shards', return_value=None):
with patch('os.makedirs'), \
patch('builtins.open', new_callable=mock_open, read_data='{"version": 2, "shards": []}'), \
patch('json.load') as mock_json_load, \
patch('os.stat', return_value=mock_stat), \
patch('torch.distributed.is_available', return_value=True), \
patch('torch.distributed.is_initialized', return_value=True), \
patch('torch.distributed.broadcast_object_list'), \
patch('torch.distributed.init_process_group'), \
patch('torch.distributed.destroy_process_group'), \
patch('torch.distributed.barrier'), \
patch('streaming.base.dataset.StreamingDataset.get_item'):

mock_json_load.return_value = {
'version':
2,
'shards': [{
'column_names': ['column1', 'column2'],
'column_encodings': ['int', 'float'],
'column_sizes': [4, 8],
'compression': None,
'format': 'mds',
'hashes': [],
'raw_data': {
'basename': 'shard.00000.mds',
'bytes': 1024,
'hashes': {},
},
'samples': 1000,
'size_limit': 67108864,
'version': 2,
'zip_data': None,
}],
}
with pytest.raises(TypeError, match=f'.*got an unexpected keyword argument.*'):
_ = build_text_dataloader(
**cfg,
tokenizer=tokenizer,
device_batch_size=device_batch_size,
).dataloader

Loading