Trainer support for IterableDataset for evaluation and predict #11286

Merged 7 commits on Apr 16, 2021
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -7,7 +7,7 @@
    "cookiecutter": "cookiecutter==1.7.2",
    "dataclasses": "dataclasses",
    "datasets": "datasets",
-    "deepspeed": "deepspeed>0.3.13",
+    "deepspeed": "deepspeed>=0.3.14",
    "docutils": "docutils==0.16.0",
    "fairscale": "fairscale>0.3",
    "faiss-cpu": "faiss-cpu",
412 changes: 336 additions & 76 deletions src/transformers/trainer.py

Large diffs are not rendered by default.
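The trainer.py hunk carries the core of the PR: evaluate and predict now go through a new evaluation loop that no longer requires the dataset to have a usable len(). Since the diff is not rendered here, below is a minimal sketch of the sample-counting idea, built on the find_batch_size utility added in trainer_pt_utils.py further down. The helper name and loop structure are illustrative assumptions, not the PR's actual code:

from transformers.trainer_pt_utils import find_batch_size

def count_examples(dataloader):
    # An IterableDataset has no trustworthy __len__, so count samples as they stream by.
    observed_num_examples = 0
    for inputs in dataloader:
        batch_size = find_batch_size(inputs)
        if batch_size is not None:
            observed_num_examples += batch_size
    return observed_num_examples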

4 changes: 2 additions & 2 deletions src/transformers/trainer_callback.py
@@ -15,7 +15,7 @@
"""
Callbacks to use with the Trainer class and customize the training loop.
"""

+import collections
import dataclasses
import json
from dataclasses import dataclass
@@ -469,7 +469,7 @@ def on_step_end(self, args, state, control, **kwargs):
        self.current_step = state.global_step

    def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
-        if state.is_local_process_zero:
+        if state.is_local_process_zero and isinstance(eval_dataloader.dataset, collections.abc.Sized):
            if self.prediction_bar is None:
                self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
            self.prediction_bar.update(1)
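The collections.abc.Sized guard above matters because a plain IterableDataset implements no __len__, so len(eval_dataloader) would raise and tqdm could not be given a total. A small illustration (the Stream class is made up for this example):

import collections.abc

from torch.utils.data import IterableDataset

class Stream(IterableDataset):
    # No __len__ defined, so instances are not Sized.
    def __iter__(self):
        return iter(range(10))

print(isinstance(Stream(), collections.abc.Sized))         # False -> skip the progress bar
print(isinstance(list(range(10)), collections.abc.Sized))  # True  -> len() is safe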
80 changes: 80 additions & 0 deletions src/transformers/trainer_pt_utils.py
@@ -102,6 +102,26 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
        raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")


def find_batch_size(tensors):
    """
    Find the first dimension of a tensor in a nested list/tuple/dict of tensors.
    """
    if isinstance(tensors, (list, tuple)):
        for t in tensors:
            result = find_batch_size(t)
            if result is not None:
                return result
    elif isinstance(tensors, dict):
        for key, value in tensors.items():
            result = find_batch_size(value)
            if result is not None:
                return result
    elif isinstance(tensors, torch.Tensor):
        return tensors.shape[0] if len(tensors.shape) >= 1 else None
    elif isinstance(tensors, np.ndarray):
        return tensors.shape[0] if len(tensors.shape) >= 1 else None
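A quick illustration of the traversal (the nested structure and shapes are arbitrary): the first tensor reached supplies the batch dimension.

import torch

from transformers.trainer_pt_utils import find_batch_size

batch = {"input_ids": torch.zeros(8, 128), "labels": [torch.zeros(8)]}
print(find_batch_size(batch))          # 8, from the first tensor encountered
print(find_batch_size({"empty": []}))  # None: no tensor with a leading dimension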


def nested_numpify(tensors):
    "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
    if isinstance(tensors, (list, tuple)):
@@ -222,6 +242,10 @@ class SequentialDistributedSampler(Sampler):
    """

    def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
+        warnings.warn(
+            "SequentialDistributedSampler is deprecated and will be removed in v5 of Transformers.",
+            FutureWarning,
+        )
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
@@ -338,6 +362,10 @@ class DistributedTensorGatherer:
    """

    def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
+        warnings.warn(
+            "DistributedTensorGatherer is deprecated and will be removed in v5 of Transformers.",
+            FutureWarning,
+        )
        self.world_size = world_size
        self.num_samples = num_samples
        total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
@@ -576,6 +604,55 @@ def __iter__(self) -> Iterator:
        return iter(indices)


class ShardSampler(Sampler):
    """
    Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with
    batch size 4, the first two batches are :obj:`[0, 1, 2, 3, 4, 5, 6, 7]` and :obj:`[8, 9, 10, 11, 12, 13, 14, 15]`,
    which shard into :obj:`[0, 1, 2, 3]` and :obj:`[8, 9, 10, 11]` for GPU-0 and :obj:`[4, 5, 6, 7]` and
    :obj:`[12, 13, 14, 15]` for GPU-1.

    The sampler thus yields :obj:`[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and :obj:`[4, 5, 6, 7, 12, 13, 14, 15]` on
    GPU-1.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index

        self.total_batch_size = total_batch_size = batch_size * num_processes

        num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size)
        self.total_num_samples = num_batches * total_batch_size

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # Add extra samples to make the total evenly divisible. The while loop covers the edge case of a tiny
        # dataset that needs to be repeated several times.
        while len(indices) < self.total_num_samples:
            indices += indices[: (self.total_num_samples - len(indices))]

        result = []
        for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size):
            result += indices[batch_start : batch_start + self.batch_size]

        return iter(result)

    def __len__(self):
        # Each shard only sees a fraction of total_num_samples.
        return self.total_num_samples // self.num_processes
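A usage sketch reproducing the dispatch described in the docstring, with a plain list standing in for the dataset:

from transformers.trainer_pt_utils import ShardSampler

sampler_gpu0 = ShardSampler(list(range(16)), batch_size=4, num_processes=2, process_index=0)
sampler_gpu1 = ShardSampler(list(range(16)), batch_size=4, num_processes=2, process_index=1)
print(list(sampler_gpu0))  # [0, 1, 2, 3, 8, 9, 10, 11]
print(list(sampler_gpu1))  # [4, 5, 6, 7, 12, 13, 14, 15]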


class IterableDatasetShard(IterableDataset):
    """
    Wraps a PyTorch :obj:`IterableDataset` to generate samples for one of the processes only. Instances of this class
@@ -634,13 +711,15 @@ def __init__(
        self.process_index = process_index
        self.seed = seed
        self.epoch = 0
+        self.num_examples = 0

    def set_epoch(self, epoch):
        self.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __iter__(self):
+        self.num_examples = 0
        if (
            not hasattr(self.dataset, "set_epoch")
            and hasattr(self.dataset, "generator")
@@ -653,6 +732,7 @@ def __iter__(self):
        first_batch = None
        current_batch = []
        for element in self.dataset:
+            self.num_examples += 1
            current_batch.append(element)
            # Wait to have a full batch before yielding elements.
            if len(current_batch) == real_batch_size:
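The rest of the hunk is truncated, but the intent mirrors ShardSampler for streaming data: every process iterates the full stream, buffers real_batch_size = batch_size * num_processes elements, and keeps only its own slice, counting everything it sees in num_examples. A hypothetical sketch (Counter is made up, and the constructor arguments are assumed to mirror ShardSampler's):

from torch.utils.data import IterableDataset

from transformers.trainer_pt_utils import IterableDatasetShard

class Counter(IterableDataset):
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

shard = IterableDatasetShard(Counter(16), batch_size=4, num_processes=2, process_index=0)
print(list(shard))         # [0, 1, 2, 3, 8, 9, 10, 11], same slices as ShardSampler
print(shard.num_examples)  # 16: the wrapper iterated the whole stream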
7 changes: 7 additions & 0 deletions src/transformers/trainer_utils.py
@@ -77,6 +77,13 @@ class EvalPrediction(NamedTuple):
    label_ids: np.ndarray


+class EvalLoopOutput(NamedTuple):
+    predictions: Union[np.ndarray, Tuple[np.ndarray]]
+    label_ids: Optional[np.ndarray]
+    metrics: Optional[Dict[str, float]]
+    num_samples: Optional[int]


class PredictionOutput(NamedTuple):
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Optional[np.ndarray]
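num_samples is the new field: it records how many examples the loop actually consumed, which is the only reliable count when the dataset is iterable. A minimal construction sketch (all values arbitrary):

import numpy as np

from transformers.trainer_utils import EvalLoopOutput

output = EvalLoopOutput(
    predictions=np.zeros((66, 1)),
    label_ids=np.zeros(66),
    metrics={"eval_loss": 0.12},
    num_samples=66,  # what the loop actually saw, not len(dataset)
)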
3 changes: 3 additions & 0 deletions src/transformers/training_args.py
@@ -524,6 +524,9 @@ class TrainingArguments:
    skip_memory_metrics: bool = field(
        default=False, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
    )
+    use_legacy_prediction_loop: bool = field(
+        default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."}
+    )
Review comment on lines +527 to +529 (Member): Thank you for adding this!

    _n_gpu: int = field(init=False, repr=False, default=-1)
    mp_parameters: str = field(
        default="",
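In practice the flag is an escape hatch back to the old loop, e.g. (output_dir is arbitrary):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", use_legacy_prediction_loop=True)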
3 changes: 3 additions & 0 deletions src/transformers/utils/notebook.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import collections
import time
from typing import Optional

@@ -286,6 +287,8 @@ def on_step_end(self, args, state, control, **kwargs):
        self._force_next_update = False

    def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
+        if not isinstance(eval_dataloader.dataset, collections.abc.Sized):
+            return
        if self.prediction_bar is None:
            if self.training_tracker is not None:
                self.prediction_bar = self.training_tracker.add_child(len(eval_dataloader))
59 changes: 44 additions & 15 deletions tests/test_trainer.py
@@ -819,35 +819,64 @@ def test_trainer_eval_lm(self):
        )
        self.assertEqual(len(dataset), 31)

-    def test_trainer_iterable_dataset(self):
+    def test_training_iterable_dataset(self):
        config = RegressionModelConfig()
        model = RegressionPreTrainedModel(config)
        train_dataset = SampleIterableDataset()

-        args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
+        args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
        trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
        trainer.train()
+        self.assertEqual(trainer.state.global_step, 4)

        loader = trainer.get_train_dataloader()
        self.assertIsInstance(loader, torch.utils.data.DataLoader)
        self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)

-        # Exception if giving iterable dataset and no max_steps
-        with self.assertRaises(ValueError):
-            args1 = RegressionTrainingArguments(output_dir="./examples")
-            _ = Trainer(model=model, args=args1, train_dataset=train_dataset)
+    def test_evaluation_iterable_dataset(self):
+        config = RegressionModelConfig(a=1.5, b=2.5)
+        model = RegressionPreTrainedModel(config)
+        eval_dataset = SampleIterableDataset()

+        args = RegressionTrainingArguments(output_dir="./examples")
+        trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
+        results = trainer.evaluate()

-        # Exception if eval_dataset is iterable in __init__
-        with self.assertRaises(ValueError):
-            _ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
+        x, y = trainer.eval_dataset.dataset.x, trainer.eval_dataset.dataset.ys[0]
+        pred = 1.5 * x + 2.5
+        expected_loss = ((pred - y) ** 2).mean()
+        self.assertAlmostEqual(results["eval_loss"], expected_loss)
+        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

-        # Exception if predicting with iterable dataset
-        with self.assertRaises(ValueError):
-            trainer.predict(train_dataset)
+        # With a number of elements not a round multiple of the batch size
+        eval_dataset = SampleIterableDataset(length=66)
+        results = trainer.evaluate(eval_dataset)

-        # Exception if evaluating with iterable dataset
-        with self.assertRaises(ValueError):
-            trainer.evaluate(train_dataset)
+        x, y = eval_dataset.dataset.x, eval_dataset.dataset.ys[0]
+        pred = 1.5 * x + 2.5
+        expected_loss = ((pred - y) ** 2).mean()
+        self.assertAlmostEqual(results["eval_loss"], expected_loss)
+        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
+        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

+    def test_predict_iterable_dataset(self):
+        config = RegressionModelConfig(a=1.5, b=2.5)
+        model = RegressionPreTrainedModel(config)
+        eval_dataset = SampleIterableDataset()

+        args = RegressionTrainingArguments(output_dir="./examples")
+        trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())

+        preds = trainer.predict(trainer.eval_dataset).predictions
+        x = eval_dataset.dataset.x
+        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

+        # With a number of elements not a round multiple of the batch size
+        test_dataset = SampleIterableDataset(length=66)
+        preds = trainer.predict(test_dataset).predictions
+        x = test_dataset.dataset.x
+        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

    def test_num_train_epochs_in_training(self):
        # len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
61 changes: 59 additions & 2 deletions tests/test_trainer_utils.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import copy
import unittest

import numpy as np
@@ -34,6 +35,7 @@
    LabelSmoother,
    LengthGroupedSampler,
    SequentialDistributedSampler,
+    ShardSampler,
    get_parameter_names,
)

@@ -283,6 +285,10 @@ def check_iterable_dataset_shard(self, dataset, batch_size, drop_last, num_proce
            # All shards have the same number of samples
            self.assertEqual(len(shard), len(shard_lists[0]))

+        for shard in shards:
+            # All shards know the total number of samples
+            self.assertEqual(shard.num_examples, len(reference))

        observed = []
        for idx in range(0, len(shard_lists[0]), batch_size):
            for shard in shard_lists:
@@ -295,11 +301,62 @@ def check_iterable_dataset_shard(self, dataset, batch_size, drop_last, num_proce
                reference += reference
        self.assertListEqual(observed, reference[: len(observed)])

+        # Check equivalence between IterableDataset and ShardSampler
+        dataset.generator.manual_seed(epoch)
+        reference = list(dataset)
+
+        sampler_shards = [
+            ShardSampler(
+                reference, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
+            )
+            for i in range(num_processes)
+        ]
+        for shard, sampler_shard in zip(shard_lists, sampler_shards):
+            self.assertListEqual(shard, list(sampler_shard))

    def test_iterable_dataset_shard(self):
        dataset = RandomIterableDataset()

        self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
        self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0)

        self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
        self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)

    def check_shard_sampler(self, dataset, batch_size, drop_last, num_processes=2):
        shards = [
            ShardSampler(
                dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
            )
            for i in range(num_processes)
        ]
        shard_lists = [list(shard) for shard in shards]

        for shard in shard_lists:
            # All shards have a number of samples that is a round multiple of batch size
            self.assertTrue(len(shard) % batch_size == 0)
            # All shards have the same number of samples
            self.assertEqual(len(shard), len(shard_lists[0]))

        observed = []
        for idx in range(0, len(shard_lists[0]), batch_size):
            for shard in shard_lists:
                observed += shard[idx : idx + batch_size]

        # If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
        # batch_size
        reference = copy.copy(dataset)
        if not drop_last:
            while len(reference) < len(observed):
                reference += reference
        self.assertListEqual(observed, reference[: len(observed)])

    def test_shard_sampler(self):
        for n_elements in [64, 123]:
            dataset = list(range(n_elements))

            self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=2)
            self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=2)

            self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
            self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
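Taken together, the user-facing change is that evaluation and prediction now work end to end on datasets without __len__; the deleted assertRaises(ValueError) blocks in test_trainer.py above were guarding exactly that. A hypothetical minimal setup (model and dataset invented for illustration):

import torch
from torch import nn
from torch.utils.data import IterableDataset

from transformers import Trainer, TrainingArguments

class StreamingRegression(IterableDataset):
    # A made-up streaming dataset with no __len__.
    def __iter__(self):
        g = torch.Generator().manual_seed(0)
        for _ in range(100):
            x = torch.rand(1, generator=g)
            yield {"input_x": x, "labels": 1.5 * x + 2.5}

class TinyRegressor(nn.Module):
    # Any nn.Module whose forward returns a dict with a "loss" key works with Trainer.
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(1.0))
        self.b = nn.Parameter(torch.tensor(0.0))

    def forward(self, input_x=None, labels=None):
        pred = self.a * input_x + self.b
        return {"loss": nn.functional.mse_loss(pred, labels), "logits": pred}

args = TrainingArguments(output_dir="out")
trainer = Trainer(model=TinyRegressor(), args=args, eval_dataset=StreamingRegression())
print(trainer.evaluate())  # previously raised ValueError for an iterable eval_dataset
print(trainer.predict(StreamingRegression()).predictions[:4])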