Add streams support (#946)
* add convert

* fix

* fix convert

* add jsonl

* revert setup

* test precommit

* pre-commit

* test pre-commit

* v0

* review comments

* temporarily trigger test

* test

* add convert

* fix

* v0

* fix

* fix MDS write

* streams support

* fake commit

* fix setup

* format

* add back arxiv

* trigger test

* review comments

* temporarily trigger test

* test

* add convert

* fix

* fix

* fix MDS write

* format

* trigger test

* fix

* format

* resolve conflicts

* add back jsonl

* fix yaml

* comments

* format

* comments

* comments

* add unit test

* comments

* comments

* merge

* format

* typo

* Update llmfoundry/data/finetuning/dataloader.py

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
bigning and dakinggg authored Feb 9, 2024
1 parent 2f64a14 commit aa0ea6e
Showing 5 changed files with 103 additions and 41 deletions.
42 changes: 36 additions & 6 deletions llmfoundry/data/finetuning/dataloader.py
@@ -16,7 +16,7 @@
                                               SUPPORTED_EXTENSIONS,
                                               dataset_constructor)
 from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
-from llmfoundry.data.text_data import get_tokens_per_batch_func
+from llmfoundry.data.text_data import build_streams, get_tokens_per_batch_func
 
 log = logging.getLogger(__name__)
 
@@ -128,11 +128,14 @@ def build_finetuning_dataloader(cfg: DictConfig,
 
     dataset = None  # for pyright
     sampler = None
-    if cfg.dataset.get('remote') is not None:
+    if cfg.dataset.get('remote') is not None or cfg.dataset.get(
+            'streams') is not None:
         # Build streaming dataloader
+        streams = build_streams(cfg.dataset)
         dataset = dataset_constructor.build_from_streaming(
             tokenizer=tokenizer,
-            local=cfg.dataset.local,
+            streams=streams,
+            local=cfg.dataset.get('local', None),
             remote=cfg.dataset.get('remote', None),
             split=cfg.dataset.get('split', None),
             download_retry=cfg.dataset.get('download_retry', 2),
@@ -279,11 +282,38 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
                 'Using a streaming dataset requires setting both `remote` and `local`, ' +\
                 'but dataset.local is None.'
             )
+    elif dataset_cfg.get('streams') is not None:
+        # Using the streaming dataset codepath
+        illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
+        discovered_illegal_keys = []
+        for key in illegal_keys:
+            if dataset_cfg.get(key) is not None:
+                discovered_illegal_keys.append('`' + key + '`')
+        if discovered_illegal_keys:
+            raise ValueError(
+                'The dataset config sets a value for `streams` as well as the ' +\
+                f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
+                'Those keys are used when building from a HuggingFace dataset, but ' +\
+                'setting `streams` instructs the dataset to build from a streaming dataset.'
+            )
+        illegal_keys = ['remote', 'local']
+        discovered_illegal_keys = []
+        for key in illegal_keys:
+            if dataset_cfg.get(key) is not None:
+                discovered_illegal_keys.append('`' + key + '`')
+        if discovered_illegal_keys:
+            raise ValueError(
+                'The dataset config sets a value for `streams` as well as the ' +\
+                f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
+                'Please either use single stream (set remote/local only) ' +\
+                'or put remote/local under streams'
+            )
+
     else:
         raise ValueError(
-            'In the dataset config, you must set either `hf_name` to use a ' +\
-            'HuggingFace dataset or set `remote` to use a streaming ' +\
-            'dataset, but both were None.'
+            'In the dataset config, you must set `hf_name` to use a HuggingFace ' +\
+            'dataset, or set `remote` to use a streaming dataset, or set ' +\
+            '`streams` to use multiple streaming datasets, but all were None.'
         )
     if dataset_cfg.get('max_seq_len') is None:
         raise ValueError(
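For reference, a minimal sketch of how the new multi-stream finetuning config can be driven end to end. The stream labels, S3 URIs, and local paths below are placeholders (not part of this diff), and the surrounding keys mirror the unit test added in tests/data/test_dataloader.py. Per the validation above, a dataset config may set either top-level `remote`/`local` (single stream) or a `streams` mapping, but not both, and not together with `hf_name`.

from omegaconf import OmegaConf as om

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.utils.builders import build_tokenizer

# Placeholder locations: real MDS shards must already exist at these paths.
cfg = om.create({
    'name': 'finetuning',
    'dataset': {
        'streams': {
            'stream_a': {
                'remote': 's3://my-bucket/ift/a',
                'local': '/tmp/ift/a',
                'split': 'train',
            },
            'stream_b': {
                'remote': 's3://my-bucket/ift/b',
                'local': '/tmp/ift/b',
                'split': 'train',
            },
        },
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'shuffle': True,
    },
    'drop_last': False,
    'num_workers': 4,
    'pin_memory': False,
    'prefetch_factor': 2,
    'persistent_workers': False,
    'timeout': 0,
})

tokenizer = build_tokenizer(tokenizer_name='gpt2',
                            tokenizer_kwargs={'model_max_length': 2048})
dataloader = build_finetuning_dataloader(cfg, tokenizer, 2)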
36 changes: 25 additions & 11 deletions llmfoundry/data/finetuning/tasks.py
@@ -37,14 +37,14 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 import warnings
 from functools import partial
 from pathlib import Path
-from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
-                    cast)
+from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence,
+                    Tuple, Union, cast)
 
 import datasets as hf_datasets
 import huggingface_hub as hf_hub
 import numpy as np
 from composer.utils import dist
-from streaming import StreamingDataset
+from streaming import Stream, StreamingDataset
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.utils.logging_utils import SpecificWarningFilter
@@ -257,12 +257,25 @@ def is_valid_ift_example(pad_token_id: int, max_seq_len: int,
             non_padding_response)
 
 
+def _stream_remote_local_validate(remote: Optional[str], local: Optional[str],
+                                  split: Optional[str]):
+    if remote is None or (local == remote):
+        if local is not None and os.path.isdir(local):
+            contents = set(os.listdir(local))
+            if split is not None and split not in contents:
+                raise ValueError(
+                    f'local directory {local} does not contain split {split}')
+
+
 class StreamingFinetuningDataset(StreamingDataset):
     """Finetuning dataset with flexible tokenization using StreamingDataset.
 
     Args:
         tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
             tokenize samples.
+        streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
+            which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
+            ``remote``/``local``. Defaults to ``None``.
         local (str): Local dataset directory where shards are cached by split.
         remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
             its data must exist locally. StreamingDataset uses either ``streams`` or
@@ -313,7 +326,8 @@ class StreamingFinetuningDataset(StreamingDataset):
 
     def __init__(self,
                  tokenizer: PreTrainedTokenizerBase,
-                 local: str,
+                 streams: Optional[Sequence[Stream]] = None,
+                 local: Optional[str] = None,
                  remote: Optional[str] = None,
                  split: Optional[str] = None,
                  download_retry: int = 2,
@@ -341,15 +355,15 @@ def __init__(self,
                 f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}'
             )
 
-        if remote is None or (local == remote):
-            if os.path.isdir(local):
-                contents = set(os.listdir(local))
-                if split not in contents:
-                    raise ValueError(
-                        f'local directory {local} does not contain split {split}'
-                    )
+        if streams is None:
+            _stream_remote_local_validate(remote, local, split)
+        else:
+            for stream in streams:
+                _stream_remote_local_validate(stream.remote, stream.local,
+                                              split)
 
         super().__init__(
+            streams=streams,
             local=local,
             remote=remote,
             split=split,
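With the new `streams` argument, the dataset class can also be built directly from explicit Stream objects, and `_stream_remote_local_validate` is applied to each stream's remote/local pair. A hedged sketch follows: the paths are placeholders, only keyword arguments visible in this diff are used, and `proportion` is standard streaming.Stream weighting rather than something added by this commit.

from streaming import Stream
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Two weighted sources; the dataset-level `remote`/`local` stay None when `streams` is given.
streams = [
    Stream(remote='s3://my-bucket/ift/a', local='/tmp/ift/a', proportion=0.7),
    Stream(remote='s3://my-bucket/ift/b', local='/tmp/ift/b', proportion=0.3),
]

dataset = StreamingFinetuningDataset(
    tokenizer=tokenizer,
    streams=streams,
    split='train',
)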
23 changes: 14 additions & 9 deletions llmfoundry/data/text_data.py
@@ -232,6 +232,19 @@ def get_sequence_id_from_batch(
     return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)
 
 
+def build_streams(dataset_cfg: DictConfig):
+    streams_dict = dataset_cfg.pop('streams', None)
+    # build streams
+    streams = None
+    if streams_dict is not None:
+        streams = []
+        for _, stream in streams_dict.items():
+            # stream is the streams kwargs
+            # fwd all kwargs with **stream allows streaming to check args
+            streams.append(Stream(**stream))
+    return streams
+
+
 def build_text_dataloader(
     cfg: DictConfig,
     tokenizer: PreTrainedTokenizerBase,
@@ -240,19 +253,11 @@ def build_text_dataloader(
     assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'
 
     # get kwargs
-    streams_dict = cfg.dataset.pop('streams', None)
     mlm_probability = cfg.dataset.pop('mlm_probability', None)
     eos_token_id = cfg.dataset.pop('eos_token_id', None)
     bos_token_id = cfg.dataset.pop('bos_token_id', None)
 
-    # build streams
-    streams = None
-    if streams_dict is not None:
-        streams = []
-        for _, stream in streams_dict.items():
-            # stream is the streams kwargs
-            # fwd all kwargs with **stream allows streaming to check args
-            streams.append(Stream(**stream))
+    streams = build_streams(cfg.dataset)
 
     # build dataset potentially with streams
     dataset = StreamingTextDataset(
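Since the text and finetuning dataloaders now share `build_streams`, its behavior is worth spelling out: it pops the `streams` sub-config and forwards each entry's kwargs to `Stream`, returning None when no `streams` key is present. A small illustrative sketch (the config values below are made up):

from omegaconf import OmegaConf as om

from llmfoundry.data.text_data import build_streams

dataset_cfg = om.create({
    'streams': {
        'code': {'remote': 's3://my-bucket/code', 'local': '/tmp/code', 'proportion': 0.5},
        'web': {'remote': 's3://my-bucket/web', 'local': '/tmp/web', 'proportion': 0.5},
    },
})

streams = build_streams(dataset_cfg)   # -> list of two streaming.Stream objects
assert 'streams' not in dataset_cfg    # the key is popped so it is not passed twice
assert streams is not None and len(streams) == 2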
@@ -24,9 +24,11 @@ train_loader:
   name: finetuning
   dataset:
     ############
-    remote: ${data_remote}
-    local: ${data_local}
-    split: train
+    streams:
+      my_data:
+        remote: ${data_remote}
+        local: ${data_local}
+        split: train
     ############
     shuffle: true
     max_seq_len: ${max_seq_len}
35 changes: 23 additions & 12 deletions tests/data/test_dataloader.py
@@ -548,31 +548,38 @@ def test_finetuning_dataloader_custom_split_remote(split: str):
 
 
 @pytest.mark.parametrize('pretokenize', [True, False])
+@pytest.mark.parametrize('use_multiple_streams', [True, False])
 @pytest.mark.parametrize('use_bytes', [True, False])
-def test_finetuning_dataloader_streaming(pretokenize: bool, use_bytes: bool,
+def test_finetuning_dataloader_streaming(pretokenize: bool,
+                                         use_multiple_streams: bool,
+                                         use_bytes: bool,
                                          tmp_path: pathlib.Path):
     max_seq_len = 2048
 
-    remote_path = os.path.join(tmp_path, 'remote')
-    local_path = os.path.join(tmp_path, 'local')
-
     tokenizer = build_tokenizer(
         tokenizer_name='gpt2',
        tokenizer_kwargs={'model_max_length': max_seq_len},
     )
 
-    build_mock_ft_streaming_dataset(remote_path,
-                                    'train',
-                                    pretokenize,
-                                    use_bytes=use_bytes,
-                                    tokenizer=tokenizer)
+    streams_config = {'streams': {}}
+    num_streams = 2
+    for i in range(num_streams):
+        remote_path = os.path.join(tmp_path, f'remote_{i}')
+        local_path = os.path.join(tmp_path, f'local_{i}')
+        build_mock_ft_streaming_dataset(remote_path,
+                                        'train',
+                                        pretokenize,
+                                        use_bytes=use_bytes,
+                                        tokenizer=tokenizer)
+        streams_config['streams'][f'stream_{i}'] = {
+            'remote': remote_path,
+            'local': local_path,
+            'split': 'train'
+        }
 
     cfg = {
         'name': 'finetuning',
         'dataset': {
-            'remote': remote_path,
-            'local': local_path,
-            'split': 'train',
             'max_seq_len': 2048,
             'decoder_only_format': True,
             'allow_pad_trimming': False,
@@ -586,6 +593,10 @@ def test_finetuning_dataloader_streaming(pretokenize: bool, use_bytes: bool,
         'persistent_workers': False,
         'timeout': 0
     }
+    if use_multiple_streams:
+        cfg['dataset'].update(streams_config)
+    else:
+        cfg['dataset'].update(streams_config['streams']['stream_0'])
 
     cfg = om.create(cfg)
 
