Add finetuning streaming dataset conversion #933

Merged
merged 16 commits on Feb 6, 2024
24 changes: 14 additions & 10 deletions llmfoundry/data/finetuning/tasks.py
@@ -35,6 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
cast)
@@ -228,6 +229,17 @@ def _tokenize_formatted_example(
raise ValueError(f'Unknown conversation type {example_format=}')


def _filter_long_or_empty_examples(pad_token_id: int, max_seq_len: int,
                                   example: Dict) -> bool:
    less_than_max_seq_len = len(example['input_ids']) < max_seq_len
    non_empty_input = len(example['input_ids']) > 0
    non_empty_labels = len(example['labels']) > 0
    non_padding_response = any(
        token_id != pad_token_id for token_id in example['labels'])
    return (less_than_max_seq_len and non_empty_input and non_empty_labels and
            non_padding_response)


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.

@@ -567,17 +579,9 @@ def dataset_mapper(example: Dict):

pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
    less_than_max_seq_len = len(example['input_ids']) < max_seq_len
    non_empty_input = len(example['input_ids']) > 0
    non_empty_labels = len(example['labels']) > 0
    non_padding_response = any(
        token_id != pad_token_id for token_id in example['labels'])
    return (less_than_max_seq_len and non_empty_input and
            non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
    partial(_filter_long_or_empty_examples, pad_token_id,
            max_seq_len),
    num_proc=num_cpus_to_use,
    desc='Filtering out long prompts',
)
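Note on the change above: the filter predicate was hoisted out of the nested closure in `tasks.py` into the module-level `_filter_long_or_empty_examples`, so the conversion script can import it and `Dataset.filter` can bind its first two arguments with `functools.partial`. A minimal sketch of the predicate's behavior; the pad token id and example dicts below are illustrative, not taken from the PR:

from functools import partial

from llmfoundry.data.finetuning.tasks import _filter_long_or_empty_examples

pad_token_id = 0    # hypothetical pad token id for this sketch
max_seq_len = 2048  # same as the conversion script's default --max_seq_len

keep_example = partial(_filter_long_or_empty_examples, pad_token_id, max_seq_len)

# Kept: non-empty input and labels, shorter than max_seq_len, labels not all padding.
assert keep_example({'input_ids': [5, 6, 7], 'labels': [6, 7]})
# Dropped: the labels (tokenized response) consist entirely of padding tokens.
assert not keep_example({'input_ids': [5, 6, 7], 'labels': [0, 0]})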
70 changes: 60 additions & 10 deletions scripts/data_prep/convert_finetuning_dataset.py
@@ -1,18 +1,24 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import json
import os
import platform
import warnings
from argparse import ArgumentParser, Namespace
from typing import Dict, Iterable, List, Optional, Union

import datasets as hf_datasets
import numpy as np
import psutil
from streaming import MDSWriter
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.finetuning.tasks import (_filter_long_or_empty_examples,
                                              _tokenize_formatted_example,
                                              dataset_constructor)
from llmfoundry.utils.builders import build_tokenizer


def parse_args() -> Namespace:
@@ -23,7 +29,7 @@ def parse_args() -> Namespace:
type=str,
required=True,
help=
'Name/path of the dataset (e.g., first argument to `datasets.load_dataset`)'
'Name of the dataset (e.g., the first argument to `datasets.load_dataset`; for jsonl data, use `json`)'
)
parser.add_argument('--data_subset',
type=str,
@@ -38,6 +44,10 @@
default=None,
help='Name or import path of function used to preprocess (reformat) the dataset. ' +\
'See README for additional details.')
parser.add_argument('--data_files',
                    nargs='+',
                    default=[],
                    help='Data file for each split')
parser.add_argument(
'--skip-preprocessing',
action='store_true',
@@ -63,6 +73,9 @@
default=None,
help='(Optional) name of compression algorithm to use.')
parser.add_argument('--num_workers', type=int, required=False, default=None)
parser.add_argument('--tokenizer', type=str, required=False, default=None)
parser.add_argument('--tokenizer_kwargs', type=str, required=False)
parser.add_argument('--max_seq_len', type=int, default=2048)

parsed = parser.parse_args()

@@ -73,6 +86,11 @@
f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.'
)

if parsed.tokenizer_kwargs is not None:
    parsed.tokenizer_kwargs = json.loads(parsed.tokenizer_kwargs)
else:
    parsed.tokenizer_kwargs = {}

return parsed
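
As a quick illustration of the new tokenizer flags added above, here is a hedged sketch of how `--tokenizer_kwargs` (a JSON string) and `--max_seq_len` flow through `parse_args` and `main`; the tokenizer name and kwargs are placeholders, not values from this PR:

import json

# Hypothetical invocation:
#   --tokenizer mosaicml/mpt-7b --tokenizer_kwargs '{"padding_side": "left"}' \
#   --max_seq_len 2048 --data_files train.jsonl val.jsonl
tokenizer_kwargs = json.loads('{"padding_side": "left"}')    # parse_args decodes the JSON string
tokenizer_kwargs.update({'model_max_length': 2048})          # main() adds args.max_seq_len
# main() then calls build_tokenizer('mosaicml/mpt-7b', tokenizer_kwargs) and writes
# tokenized 'input_ids'/'labels' columns instead of raw 'prompt'/'response' strings.
# --data_files pairs one file with each entry of --splits, in order.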


@@ -170,12 +188,22 @@ def main(args: Namespace) -> None:
'include the "--skip-preprocessing" flag to avoid this error.'
)

columns = ['prompt', 'response']
tokenizer = None
args.tokenizer_kwargs.update({'model_max_length': args.max_seq_len})
if args.tokenizer:
    tokenizer = build_tokenizer(args.tokenizer, args.tokenizer_kwargs)
    columns = {'input_ids': 'bytes', 'labels': 'bytes'}
else:
    columns = {'prompt': 'str', 'response': 'str'}

for split_name in args.splits:
for i, split_name in enumerate(args.splits):
    data_file = None
    if len(args.data_files) > 0:
        data_file = args.data_files[i]
    dataset = hf_datasets.load_dataset(path=args.dataset,
                                       name=args.data_subset,
                                       split=split_name,
                                       data_files=data_file,
                                       streaming=True)
loader = build_dataloader(dataset=dataset,
batch_size=512,
@@ -190,12 +218,14 @@ def main(args: Namespace) -> None:
keep_local = True
else:
keep_local = False
with MDSWriter(columns={key: 'str' for key in columns},
with MDSWriter(columns=columns,
out=out,
compression=args.compression,
keep_local=keep_local) as out:
examples_removed = 0
for sample in tqdm(samples, desc=split_name):
formatted_sample = preprocessing_fn(sample)

if ('prompt'
not in formatted_sample) or ('response'
not in formatted_sample):
@@ -204,11 +234,31 @@ def main(args: Namespace) -> None:
'"prompt" and "response" are required keys but at least one was missing ' +\
f'from {formatted_sample=}.'
)
encoded_sample = {
key: formatted_sample[key].encode('utf-8')
for key in columns
}
out.write(encoded_sample)
if tokenizer is not None:
    sample = _tokenize_formatted_example(sample, tokenizer=tokenizer)
    if not _filter_long_or_empty_examples(tokenizer.pad_token_id,
                                          args.max_seq_len, sample):
        examples_removed += 1
        continue

    sample_to_write = {}
    # convert to bytes
    for key in columns.keys():
        sample_to_write[key] = np.asarray(sample[key]).tobytes()
    out.write(sample_to_write)
else:
    encoded_sample = {
        key: formatted_sample[key].encode('utf-8')
        for key in columns.keys()
    }
    out.write(encoded_sample)

if tokenizer is not None and examples_removed > 0:
    warnings.warn(
        f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, '
        + 'the prompt or response was empty, or the response was all padding tokens.')


if __name__ == '__main__':
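Since the tokenized path stores token ids as raw bytes, reading them back requires the matching dtype. A minimal round-trip sketch, assuming the writer ran on a 64-bit platform where `np.asarray` on a list of Python ints defaults to int64 (the reader-side dtype is an assumption here, not something this diff specifies):

import numpy as np

input_ids = [101, 2023, 2003, 102]          # illustrative token ids
encoded = np.asarray(input_ids).tobytes()   # what the script writes into the MDS 'bytes' column

# Decoding must use the same dtype the writer's np.asarray produced.
decoded = np.frombuffer(encoded, dtype=np.int64).tolist()
assert decoded == input_ids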