diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 92bbac561d..ae0da6d09b 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -198,7 +198,8 @@ def build_finetuning_dataloader( allowed_dataset_config_keys = set( dataset_constructor_keys, ).union(_ALLOWED_DATASET_KEYS) - _validate_config( + + extraneous_keys = _validate_config( **dataset_cfg, allowed_dataset_keys=allowed_dataset_config_keys, ) @@ -253,13 +254,13 @@ def build_finetuning_dataloader( streams_cfg, ) if streams_cfg is not None else None - # Take the constructor args from above, minus args that have been created separately dataset_constructor_args = { k: v for k, v in dataset_cfg.items() - if k in dataset_constructor_keys and + if k in set(dataset_constructor_keys).union(extraneous_keys) and k not in {'streams', 'packing_ratio'} } + streaming_dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, streams=streams, @@ -378,7 +379,7 @@ def _validate_config( target_responses: Optional[str] = None, allowed_dataset_keys: set[str] = _ALLOWED_DATASET_KEYS, **kwargs: dict[str, Any], -) -> None: +) -> set[str]: """Validates the dataset configuration. Makes sure that the dataset is properly configured for either @@ -434,11 +435,16 @@ def _validate_config( Raises: ValueError: If the dataset configuration does not meet the requirements. + + Returns: + set[str]: Return the extraneous keys. """ + extraneous_keys = set() if not set(kwargs.keys()).issubset(allowed_dataset_keys): - raise ValueError( + extraneous_keys = set(kwargs.keys()) - allowed_dataset_keys + log.warning( 'The dataset config contains the following extraneous keys: ' +\ - ', '.join(set(kwargs.keys()) - allowed_dataset_keys), + ', '.join(extraneous_keys), ) if hf_name is not None: @@ -533,6 +539,8 @@ def _validate_config( decoder_only_format, ) + return extraneous_keys + def _download_remote_hf_dataset(remote_path: str, split: str) -> str: """Downloads a dataset from a remote object store. diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 179f017fd9..915267786f 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -613,11 +613,6 @@ def __init__( **kwargs: Any, ): - if len(kwargs) > 0: - raise ValueError( - f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}', - ) - if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES: raise ValueError( f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}', @@ -658,6 +653,7 @@ def __init__( batching_method=batching_method, allow_unsafe_types=allow_unsafe_types, replication=replication, + **kwargs, ) self.tokenizer = tokenizer diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 3ce248e69f..37d4c32b23 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -138,11 +138,6 @@ def __init__( **kwargs: Any, ): - if len(kwargs) > 0: - raise ValueError( - f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}', - ) - if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES: raise ValueError( f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}', @@ -188,6 +183,7 @@ def __init__( batching_method=batching_method, allow_unsafe_types=allow_unsafe_types, replication=replication, + **kwargs, ) self.tokenizer = tokenizer self.max_seq_len = max_seq_len @@ -332,10 +328,13 @@ def build_text_dataloader( StreamingTextDataset, ).parameters + valid_base_dataset_params = inspect.signature(StreamingDataset,).parameters + dataset_config_subset_for_streaming_text_dataset = { k: v for k, v in dataset_cfg.items() - if k in valid_streaming_text_dataset_parameters + if k in valid_streaming_text_dataset_parameters or + k in valid_base_dataset_params } # build dataset potentially with streams diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index d7f979713a..5f16c86eb9 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -8,7 +8,7 @@ from contextlib import nullcontext as does_not_raise from pathlib import Path from typing import Any, Callable, ContextManager, Literal, Optional, Union -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, mock_open, patch import catalogue import numpy as np @@ -1423,3 +1423,160 @@ def test_sharegpt_format( device_batch_size=device_batch_size, **cfg, ).dataloader + +def test_ft_dataloader_with_extra_keys(): + max_seq_len = 2 + cfg = { + 'dataset': { + 'remote': '/remote', + 'local': '/local', + 'split': 'train', + 'max_seq_len': 2048, + 'decoder_only_format': True, + 'shuffle': True, + 'num_canonical_nodes': 472, + 'target_responses': 'last', + 'target_prompts': 'none', + 'extra_key_1': 'extra_key_1', + 'extra_key_2': 'extra_key_2', + 'extra_key_3': 'extra_key_3', + }, + 'drop_last': False, + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0, + } + + cfg = om.create(cfg) + + tokenizer = build_tokenizer( + tokenizer_name='gpt2', + tokenizer_kwargs={'model_max_length': max_seq_len}, + ) + + device_batch_size = 2 + + mock_stat = MagicMock() + mock_stat.st_size = 1024 # Mock st_size with a desired value + mock_stat.st_mode = 33188 # Regular file mode for Unix-based systems + + #with patch('streaming.base.stream.get_shards', return_value=None): + with patch('os.makedirs'), \ + patch('builtins.open', new_callable=mock_open, read_data='{"version": 2, "shards": []}'), \ + patch('json.load') as mock_json_load, \ + patch('os.stat', return_value=mock_stat), \ + patch('torch.distributed.is_available', return_value=True), \ + patch('torch.distributed.is_initialized', return_value=True), \ + patch('torch.distributed.broadcast_object_list'), \ + patch('torch.distributed.init_process_group'), \ + patch('torch.distributed.destroy_process_group'), \ + patch('torch.distributed.barrier'), \ + patch('streaming.base.dataset.StreamingDataset.get_item'): + + mock_json_load.return_value = { + 'version': + 2, + 'shards': [{ + 'column_names': ['column1', 'column2'], + 'column_encodings': ['int', 'float'], + 'column_sizes': [4, 8], + 'compression': None, + 'format': 'mds', + 'hashes': [], + 'raw_data': { + 'basename': 'shard.00000.mds', + 'bytes': 1024, + 'hashes': {}, + }, + 'samples': 1000, + 'size_limit': 67108864, + 'version': 2, + 'zip_data': None, + }], + } + + with pytest.raises(TypeError, match=f'.*got an unexpected keyword argument.*'): + _ = build_finetuning_dataloader( + **cfg, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + ).dataloader + +@pytest.mark.xfail +def test_text_dataloader_with_extra_keys(): + max_seq_len = 1024 + cfg = { + 'dataset': { + 'remote': '/remote', + 'local': '/local', + 'split': 'train', + 'max_seq_len': max_seq_len, + 'shuffle': True, + 'num_canonical_nodes': 472, + 'extra_key_1': 'extra_key_1', + 'extra_key_2': 'extra_key_2', + 'extra_key_3': 'extra_key_3', + }, + 'drop_last': False, + 'num_workers': 0, + 'pin_memory': False, + 'prefetch_factor': None, + 'persistent_workers': False, + 'timeout': 0, + } + + cfg = om.create(cfg) + + tokenizer = build_tokenizer( + tokenizer_name='gpt2', + tokenizer_kwargs={'model_max_length': max_seq_len}, + ) + + device_batch_size = 2 + + mock_stat = MagicMock() + mock_stat.st_size = 1024 # Mock st_size with a desired value + mock_stat.st_mode = 33188 # Regular file mode for Unix-based systems + + #with patch('streaming.base.stream.get_shards', return_value=None): + with patch('os.makedirs'), \ + patch('builtins.open', new_callable=mock_open, read_data='{"version": 2, "shards": []}'), \ + patch('json.load') as mock_json_load, \ + patch('os.stat', return_value=mock_stat), \ + patch('torch.distributed.is_available', return_value=True), \ + patch('torch.distributed.is_initialized', return_value=True), \ + patch('torch.distributed.broadcast_object_list'), \ + patch('torch.distributed.init_process_group'), \ + patch('torch.distributed.destroy_process_group'), \ + patch('torch.distributed.barrier'), \ + patch('streaming.base.dataset.StreamingDataset.get_item'): + + mock_json_load.return_value = { + 'version': + 2, + 'shards': [{ + 'column_names': ['column1', 'column2'], + 'column_encodings': ['int', 'float'], + 'column_sizes': [4, 8], + 'compression': None, + 'format': 'mds', + 'hashes': [], + 'raw_data': { + 'basename': 'shard.00000.mds', + 'bytes': 1024, + 'hashes': {}, + }, + 'samples': 1000, + 'size_limit': 67108864, + 'version': 2, + 'zip_data': None, + }], + } + with pytest.raises(TypeError, match=f'.*got an unexpected keyword argument.*'): + _ = build_text_dataloader( + **cfg, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + ).dataloader