
Commit

revert data split fallbacks
kylesayrs committed Dec 19, 2024
1 parent 4932ec5 commit 6b7c11f
Showing 2 changed files with 25 additions and 106 deletions.
90 changes: 22 additions & 68 deletions src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -1,15 +1,12 @@
 import logging
 import os
-import warnings
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from datasets import Dataset, load_dataset
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.data import default_data_collator
 
-from llmcompressor.typing import DatasetType
-
 LOGGER = logging.getLogger(__name__)
 LABELS_MASK_VALUE = -100
 
@@ -18,7 +15,6 @@
     "get_raw_dataset",
     "make_dataset_splits",
     "get_custom_datasets_from_path",
-    "LABELS_MASK_VALUE",
 ]
 
 
@@ -28,7 +24,7 @@ def format_calibration_data(
     do_shuffle: bool = True,
     collate_fn: Callable = default_data_collator,
     accelerator: Optional[Any] = None,
-) -> torch.utils.data.DataLoader:
+) -> List[torch.Tensor]:
     """
     Creates a dataloader out of the calibration dataset split, trimming it to
     the desired number of calibration samples
@@ -96,17 +92,17 @@ def get_raw_dataset(
 
 
 def make_dataset_splits(
-    datasets: Dict[str, DatasetType],
+    tokenized_datasets: Dict[str, Any],
     do_train: bool = False,
     do_eval: bool = False,
     do_predict: bool = False,
     do_oneshot: bool = False,
-) -> Dict[str, DatasetType]:
+) -> Dict[str, Dataset]:
     """
     Restructures the datasets dictionary based on what tasks will be run
     (train, eval, predict)
 
-    :param datasets: dictionary of processed datasets
+    :param tokenized_datasets: dictionary of processed datasets
     :param do_train: Whether to store the train dataset
     :param do_eval: Whether to store the validation dataset
     :param do_predict: Whether to store the test dataset
@@ -115,40 +111,31 @@
     """
 
     # handles case where all splits are contained in a single dataset
-    if "all" in datasets and len(datasets) == 1:
-        datasets = datasets.get("all")
-        if isinstance(datasets, Dataset):
-            datasets = {"train": datasets, "calibration": datasets}  # shallow copy
+    if "all" in tokenized_datasets and len(tokenized_datasets) == 1:
+        tokenized_datasets = tokenized_datasets.get("all")
+        if isinstance(tokenized_datasets, Dataset):
+            tokenized_datasets = {"train": tokenized_datasets}
 
     train_split = eval_split = predict_split = calib_split = None
 
     if do_train:
-        train_split = _get_split_with_fallbacks(
-            datasets, "train", ["train"], strict=True
-        )
+        if "train" not in tokenized_datasets:
+            raise ValueError("--do_train requires a train dataset")
+        train_split = tokenized_datasets["train"]
     if do_eval:
-        eval_split = _get_split_with_fallbacks(
-            datasets, "evaluation", ["validation"], ["test"], strict=True
-        )
+        if "validation" not in tokenized_datasets:
+            raise ValueError("--do_eval requires a validation dataset")
+        eval_split = tokenized_datasets["validation"]
     if do_predict:
-        predict_split = _get_split_with_fallbacks(
-            datasets, "prediction", ["test"], ["validation"], strict=True
-        )
+        if "test" not in tokenized_datasets:
+            raise ValueError("--do_predict requires a test dataset")
+        predict_split = tokenized_datasets["test"]
     if do_oneshot:
-        calib_split = _get_split_with_fallbacks(
-            datasets,
-            "oneshot",
-            ["calibration", "train"],
-            ["test", "validation"],
-            strict=False,
-        )
-
-        # remove labels from calibration dataset
-        column_names = calib_split.column_names
-        if isinstance(column_names, dict):
-            column_names = sum(column_names.values(), [])
-        if "labels" in column_names:
-            calib_split = calib_split.remove_columns("labels")
+        calib_split = tokenized_datasets.get("calibration")
+        if calib_split is None:
+            if "train" not in tokenized_datasets:
+                raise ValueError("--do_oneshot requires a calibration dataset")
+            calib_split = tokenized_datasets["train"]
 
     split_datasets = {
         "train": train_split,
@@ -256,36 +243,3 @@ def do_transform(candidate: str) -> bool:
             transform_dataset_key(dataset_key)
 
     return data_files
-
-
-def _get_split_with_fallbacks(
-    datasets: Dict[str, DatasetType],
-    task: str,
-    preferred: List[str],
-    fallbacks: List[str] = [],
-    strict: bool = True,
-) -> DatasetType:
-    assert len(preferred) > 0
-    if len(datasets) <= 0:
-        raise ValueError("Cannot get retrieve data from dataset with no splits")
-
-    # check preferred names (without warning)
-    for pref in preferred:
-        if pref in datasets:
-            return datasets[pref]
-
-    # fallback to the first available dataset if all else fails
-    if not strict:
-        fallbacks.append(next(iter(datasets.keys())))
-
-    # check fallbacks (with warning)
-    for fallback in fallbacks:
-        if fallback in datasets:
-            warnings.warn(
-                f"{task} expects one of {preferred} dataset split, "
-                f"falling back to {fallback}. Use "
-                f'`splits={{"{preferred[0]}": "{fallback}"}}` to silence this warning'
-            )
-            return datasets[fallback]
-
-    raise ValueError(f"{task} expects at least one of {fallbacks} dataset splits")
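For readers skimming the diff, here is a minimal sketch (not part of this commit) of the strict behavior the revert restores: make_dataset_splits now requires exact split names, and the only remaining fallback is from "calibration" to "train" for oneshot runs. The import path is assumed from the file path shown above.

# Sketch of the post-revert make_dataset_splits behavior; import path assumed
# from src/llmcompressor/transformers/finetune/data/data_helpers.py.
from datasets import Dataset

from llmcompressor.transformers.finetune.data.data_helpers import make_dataset_splits

ds = Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]})

# oneshot still falls back from "calibration" to "train"
splits = make_dataset_splits({"train": ds}, do_oneshot=True)
assert splits["calibration"] is ds

# evaluation no longer falls back to "test"; this raises ValueError
try:
    make_dataset_splits({"test": ds}, do_eval=True)
except ValueError as err:
    print(err)  # --do_eval requires a validation dataset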
41 changes: 3 additions & 38 deletions tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
@@ -1,5 +1,3 @@
-from unittest.mock import Mock
-
 import pytest
 
 from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
@@ -50,41 +48,8 @@ def test_separate_datasets():
     assert split_datasets.get("validation") is not None
     assert split_datasets.get("test") is None
 
-
-@pytest.mark.unit
-def test_datasets_fallbacks():
-    # strict splits
-    mock_datasets = {"calibration": Mock(ds_name="calibration_ds", column_names=[])}
-    with pytest.raises(ValueError):
-        _ = make_dataset_splits(mock_datasets, do_train=True)
-    with pytest.raises(ValueError):
-        _ = make_dataset_splits(mock_datasets, do_eval=True)
     with pytest.raises(ValueError):
-        _ = make_dataset_splits(mock_datasets, do_predict=True)
-
-    # validation, predict, and oneshot fallbacks
-    mock_datasets = {"test": Mock(ds_name="test_ds", column_names=[])}
-    with pytest.warns(UserWarning):
-        split_ds = make_dataset_splits(
-            mock_datasets, do_eval=True, do_predict=True, do_oneshot=True
+        # fails due to no test split specified
+        split_datasets = make_dataset_splits(
+            datasets, do_train=True, do_eval=True, do_predict=True
         )
-    assert split_ds.get("validation").ds_name == "test_ds"
-    assert split_ds.get("test").ds_name == "test_ds"
-    assert split_ds.get("calibration").ds_name == "test_ds"
-
-    # oneshot takes train without warning
-    mock_datasets = {"train": Mock(ds_name="train_ds", column_names=[])}
-    split_ds = make_dataset_splits(mock_datasets, do_oneshot=True)
-    assert split_ds.get("calibration").ds_name == "train_ds"
-
-    # oneshot takes test with warning
-    mock_datasets = {"test": Mock(ds_name="test_ds", column_names=[])}
-    with pytest.warns(UserWarning):
-        split_ds = make_dataset_splits(mock_datasets, do_oneshot=True)
-    assert split_ds.get("calibration").ds_name == "test_ds"
-
-    # oneshot takes custom splits with warning
-    mock_datasets = {"custom_split": Mock(ds_name="custom_ds", column_names=[])}
-    with pytest.warns(UserWarning):
-        split_ds = make_dataset_splits(mock_datasets, do_oneshot=True)
-    assert split_ds.get("calibration").ds_name == "custom_ds"