Add finetuning streaming dataset conversion #933

Merged · 16 commits · Feb 6, 2024
45 changes: 32 additions & 13 deletions llmfoundry/data/finetuning/tasks.py
@@ -35,6 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
cast)
@@ -199,7 +200,7 @@ def _tokenize_prompt_response_formatted_example(
return tokenizer(text=prompt, text_target=response)


def _tokenize_formatted_example(
def tokenize_formatted_example(
example: Example,
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
"""Tokenizes a formatted example using the provided tokenizer.
Expand Down Expand Up @@ -228,6 +229,33 @@ def _tokenize_formatted_example(
raise ValueError(f'Unknown conversation type {example_format=}')


def is_valid_ift_example(pad_token_id: int, max_seq_len: int,
example: Dict) -> bool:
"""Check if it's an valid ift example.
bigning marked this conversation as resolved.
Show resolved Hide resolved

This functions does the following check:
a. Length of input_ids should less than max_seq_len
bigning marked this conversation as resolved.
Show resolved Hide resolved
b. Both input_ids and labels should not be empty
c. Labels should has at least 1 non-padding token.
bigning marked this conversation as resolved.
Show resolved Hide resolved

Args:
pad_token_id (int): The id of the padding token.
max_seq_len (int): Maximum sequence length.
example (Dict): The input example after tokenization, which has
``input_ids`` and ``labels`` fields.

Returns:
bool: Whether the input example is valid.
"""
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and non_empty_labels and
non_padding_response)


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.

@@ -347,7 +375,7 @@ def __init__(self,
# How to process a sample
def __getitem__(self, idx: int) -> Dict[str, Any]:
sample = super().__getitem__(idx)
return _tokenize_formatted_example(sample, tokenizer=self.tokenizer)
return tokenize_formatted_example(sample, tokenizer=self.tokenizer)


class DatasetConstructor:
@@ -550,7 +578,7 @@ def build_from_hf(
def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
example = preprocessing_fn(example)
return _tokenize_formatted_example(example, tokenizer)
return tokenize_formatted_example(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
detected_cpus_with_margin = detected_cpu_count - 8
@@ -567,17 +595,8 @@ def dataset_mapper(example: Dict):

pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
partial(is_valid_ift_example, pad_token_id, max_seq_len),
num_proc=num_cpus_to_use,
desc='Filtering out long or empty examples',
)
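
The rename from `_tokenize_formatted_example` to `tokenize_formatted_example` is what lets the conversion script import these helpers. A minimal sketch of using the two public helpers directly, where the `gpt2` tokenizer and the toy prompt/response pair are illustrative assumptions, not part of this PR:

```python
# Sketch only: 'gpt2' and the toy example are assumptions for illustration.
from functools import partial

from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import (is_valid_ift_example,
                                              tokenize_formatted_example)

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

example = {'prompt': 'What is 2 + 2? ', 'response': '4'}
tokenized = tokenize_formatted_example(example, tokenizer=tokenizer)

# The same predicate build_from_hf hands to datasets.filter via partial.
is_valid = partial(is_valid_ift_example, tokenizer.pad_token_id, 2048)
print(is_valid(tokenized))  # True: short, non-empty, non-padding response
```
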
80 changes: 70 additions & 10 deletions scripts/data_prep/convert_finetuning_dataset.py
@@ -1,18 +1,24 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import json
import os
import platform
import warnings
from argparse import ArgumentParser, Namespace
from typing import Dict, Iterable, List, Optional, Union

import datasets as hf_datasets
import numpy as np
import psutil
from streaming import MDSWriter
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.finetuning.tasks import (dataset_constructor,
is_valid_ift_example,
tokenize_formatted_example)
from llmfoundry.utils.builders import build_tokenizer


def parse_args() -> Namespace:
@@ -23,7 +29,7 @@ def parse_args() -> Namespace:
type=str,
required=True,
help=
'Name/path of the dataset (e.g., first argument to `datasets.load_dataset`)'
'Name of the dataset (i.e., the first argument to `datasets.load_dataset`; for the JSONL data format, use `json`)'
)
parser.add_argument('--data_subset',
type=str,
@@ -38,6 +44,13 @@
default=None,
help='Name or import path of function used to preprocess (reformat) the dataset. ' +\
'See README for additional details.')
parser.add_argument(
'--data_files',
nargs='+',
default=[],
help=
'Data file for each split. If set, its length must exactly match the number of splits.'
)
parser.add_argument(
'--skip-preprocessing',
action='store_true',
@@ -63,6 +76,9 @@
default=None,
help='(Optional) name of compression algorithm to use.')
parser.add_argument('--num_workers', type=int, required=False, default=None)
parser.add_argument('--tokenizer', type=str, required=False, default=None)
parser.add_argument('--tokenizer_kwargs', type=str, required=False)
parser.add_argument('--max_seq_len', type=int, default=2048)

parsed = parser.parse_args()

@@ -73,6 +89,17 @@
f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.'
)

if parsed.tokenizer_kwargs is not None:
parsed.tokenizer_kwargs = json.loads(parsed.tokenizer_kwargs)
else:
parsed.tokenizer_kwargs = {}

if len(parsed.data_files) > 0 and len(parsed.data_files) != len(
parsed.splits):
raise ValueError(
f'If data_files is set, it must map 1:1 onto splits. Got {len(parsed.data_files)=} while {len(parsed.splits)=}'
)

return parsed
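
The new `--tokenizer_kwargs` flag takes a JSON string, which `parse_args` decodes and then extends with `--max_seq_len`, as above. A small sketch of that flow, where the `padding_side` value is just a made-up example:

```python
import json

# Hypothetical value passed on the command line as --tokenizer_kwargs.
raw_kwargs = '{"padding_side": "left"}'
tokenizer_kwargs = json.loads(raw_kwargs)

# parse_args always merges --max_seq_len into the kwargs (default 2048).
tokenizer_kwargs.update({'model_max_length': 2048})
print(tokenizer_kwargs)  # {'padding_side': 'left', 'model_max_length': 2048}
```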


@@ -170,12 +197,23 @@ def main(args: Namespace) -> None:
'include the "--skip-preprocessing" flag to avoid this error.'
)

columns = ['prompt', 'response']
tokenizer = None
tokenizer_kwargs = args.tokenizer_kwargs
tokenizer_kwargs.update({'model_max_length': args.max_seq_len})
if args.tokenizer:
tokenizer = build_tokenizer(args.tokenizer, tokenizer_kwargs)
columns = {'input_ids': 'bytes', 'labels': 'bytes'}
else:
columns = {'prompt': 'str', 'response': 'str'}

for split_name in args.splits:
for i, split_name in enumerate(args.splits):
data_file = None
if len(args.data_files) > 0:
data_file = args.data_files[i]
dataset = hf_datasets.load_dataset(path=args.dataset,
name=args.data_subset,
split=split_name,
data_files=data_file,
streaming=True)
loader = build_dataloader(dataset=dataset,
batch_size=512,
@@ -190,12 +228,14 @@
keep_local = True
else:
keep_local = False
with MDSWriter(columns={key: 'str' for key in columns},
with MDSWriter(columns=columns,
out=out,
compression=args.compression,
keep_local=keep_local) as out:
examples_removed = 0
for sample in tqdm(samples, desc=split_name):
formatted_sample = preprocessing_fn(sample)

if ('prompt'
not in formatted_sample) or ('response'
not in formatted_sample):
@@ -204,11 +244,31 @@
'"prompt" and "response" are required keys but at least one was missing ' +\
f'from {formatted_sample=}.'
)
encoded_sample = {
key: formatted_sample[key].encode('utf-8')
for key in columns
}
out.write(encoded_sample)
if tokenizer is not None:
sample = tokenize_formatted_example(formatted_sample,
tokenizer=tokenizer)
if not is_valid_ift_example(tokenizer.pad_token_id,
args.max_seq_len, sample):
examples_removed += 1
continue

sample_to_write = {}
# convert to bytes
for key in columns.keys():
sample_to_write[key] = np.asarray(sample[key]).tobytes()
out.write(sample_to_write)
else:
encoded_sample = {
key: formatted_sample[key].encode('utf-8')
for key in columns.keys()
}
out.write(encoded_sample)
if tokenizer is not None and examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, '
+
'the prompt or response was empty, or the response was all padding tokens.'
)


if __name__ == '__main__':
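
Since tokenized samples are serialized with `np.asarray(...).tobytes()`, a reader has to decode the raw bytes back into token ids. A minimal sketch of reading the converted output, where the local path is a placeholder and the int64 dtype is an assumption (the platform default `np.asarray` produces for lists of Python ints):

```python
# Sketch only: the path is a placeholder and int64 is an assumption.
import numpy as np
from streaming import StreamingDataset

ds = StreamingDataset(local='/path/to/out_root/train', shuffle=False)
sample = ds[0]

input_ids = np.frombuffer(sample['input_ids'], dtype=np.int64)
labels = np.frombuffer(sample['labels'], dtype=np.int64)
print(input_ids[:8], labels[:8])  # first few token ids of each field
```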