From 942ff8fae246ebd114480e74529c2cd75c40b9c5 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Wed, 14 Apr 2021 17:02:26 -0400
Subject: [PATCH] Trainer iterable dataset (#11254)

* IterableDatasetShard

* Test and integration in Trainer

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Lysandre Debut

* Style

Co-authored-by: Lysandre Debut
---
 src/transformers/trainer.py          | 25 +++++++-
 src/transformers/trainer_pt_utils.py | 92 +++++++++++++++++++++++++++-
 tests/test_trainer.py                | 48 ++++-----------
 tests/test_trainer_utils.py          | 60 ++++++++++++++++++
 4 files changed, 185 insertions(+), 40 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 41800b7fd3a32c..02f6a29dc57446 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -81,6 +81,7 @@
     DistributedLengthGroupedSampler,
     DistributedSamplerWithLoop,
     DistributedTensorGatherer,
+    IterableDatasetShard,
     LabelSmoother,
     LengthGroupedSampler,
     SequentialDistributedSampler,
@@ -493,9 +494,7 @@ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optio
         dataset.set_format(type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"])
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
-        if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
-            self.train_dataset, collections.abc.Sized
-        ):
+        if not isinstance(self.train_dataset, collections.abc.Sized):
             return None
 
         # Build the sampler.
@@ -553,6 +552,26 @@ def get_train_dataloader(self) -> DataLoader:
         """
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
+
+        if isinstance(self.train_dataset, torch.utils.data.dataset.IterableDataset):
+            if self.args.world_size > 1:
+                train_dataset = IterableDatasetShard(
+                    self.train_dataset,
+                    batch_size=self.args.train_batch_size,
+                    drop_last=self.args.dataloader_drop_last,
+                    num_processes=self.args.world_size,
+                    process_index=self.args.process_index,
+                )
+            else:
+                train_dataset = self.train_dataset
+            return DataLoader(
+                train_dataset,
+                batch_size=self.args.train_batch_size,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+                pin_memory=self.args.dataloader_pin_memory,
+            )
+
         train_sampler = self._get_train_sampler()
 
         return DataLoader(
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index ebcb7d05572322..e048cd8d94162e 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -28,7 +28,7 @@
 import numpy as np
 import torch
 from packaging import version
-from torch.utils.data.dataset import Dataset
+from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler
 
@@ -576,6 +576,96 @@ def __iter__(self) -> Iterator:
         return iter(indices)
 
 
+class IterableDatasetShard(IterableDataset):
+    """
+    Wraps a PyTorch :obj:`IterableDataset` to generate samples for one of the processes only. Instances of this class
+    will always yield a number of samples that is a round multiple of the actual batch size (which is :obj:`batch_size
+    x num_processes`). Depending on the value of the :obj:`drop_last` attribute, it will either stop the iteration at
+    the first batch that would be too small or loop with indices from the beginning.
+
+    On two processes with an iterable dataset yielding :obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]` with a batch
+    size of 2:
+
+    - the shard on process 0 will yield :obj:`[0, 1, 4, 5, 8, 9]` so will see batches :obj:`[0, 1]`, :obj:`[4, 5]`,
+      :obj:`[8, 9]`
+    - the shard on process 1 will yield :obj:`[2, 3, 6, 7, 10, 11]` so will see batches :obj:`[2, 3]`, :obj:`[6, 7]`,
+      :obj:`[10, 11]`
+
+    .. warning::
+
+        If your IterableDataset implements some randomization that needs to be applied the same way on all processes
+        (for instance, a shuffling), you should use a :obj:`torch.Generator` in a :obj:`generator` attribute of the
+        :obj:`dataset` to generate your random numbers and call the
+        :meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch` method of this object. It will set the
+        seed of this :obj:`generator` to :obj:`seed + epoch` on all processes before starting the iteration.
+        Alternatively, you can also subclass this class and override the :meth:`__iter__` method with your custom
+        logic.
+
+    Args:
+        dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
+            The dataset to split in several shards.
+        batch_size (:obj:`int`, `optional`, defaults to 1):
+            The size of the batches per shard.
+        drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to drop the last incomplete batch or complete the last batch by using samples from the
+            beginning.
+        num_processes (:obj:`int`, `optional`, defaults to 1):
+            The number of processes running concurrently.
+        process_index (:obj:`int`, `optional`, defaults to 0):
+            The index of the current process.
+        seed (:obj:`int`, `optional`, defaults to 0):
+            A random seed that will be used for the random number generation in
+            :meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch`.
+    """
+
+    def __init__(
+        self,
+        dataset: IterableDataset,
+        batch_size: int = 1,
+        drop_last: bool = False,
+        num_processes: int = 1,
+        process_index: int = 0,
+        seed: int = 0,
+    ):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.num_processes = num_processes
+        self.process_index = process_index
+        self.seed = seed
+        self.epoch = 0
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+    def __iter__(self):
+        if hasattr(self.dataset, "generator") and isinstance(self.dataset.generator, torch.Generator):
+            self.dataset.generator.manual_seed(self.seed + self.epoch)
+        real_batch_size = self.batch_size * self.num_processes
+        process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)
+
+        first_batch = None
+        current_batch = []
+        for element in self.dataset:
+            current_batch.append(element)
+            # Wait to have a full batch before yielding elements.
+            if len(current_batch) == real_batch_size:
+                for i in process_slice:
+                    yield current_batch[i]
+                if first_batch is None:
+                    first_batch = current_batch.copy()
+                current_batch = []
+
+        # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
+        if not self.drop_last and len(current_batch) > 0:
+            if first_batch is None:
+                first_batch = current_batch.copy()
+            while len(current_batch) < real_batch_size:
+                current_batch += first_batch
+            for i in process_slice:
+                yield current_batch[i]
+
+
 # In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
 # helper methods here
 
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 914e6f5bf2503b..53f5f0b1ca0c69 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -44,9 +44,7 @@
     from torch.utils.data import IterableDataset
 
     from transformers import (
-        AutoModelForMaskedLM,
         AutoModelForSequenceClassification,
-        DataCollatorForLanguageModeling,
         EarlyStoppingCallback,
         GlueDataset,
         GlueDataTrainingArguments,
@@ -54,7 +52,6 @@
         GPT2LMHeadModel,
         LineByLineTextDataset,
         PreTrainedModel,
-        TextDataset,
         Trainer,
         TrainerState,
     )
@@ -138,16 +135,12 @@ def __init__(self, a=0, b=0, double_output=False, **kwargs):
 if is_torch_available():
 
     class SampleIterableDataset(IterableDataset):
-        """
-        Criteria is not whether it is IterableDataset or not, criteria is whether __len__ is implemented
-        """
-
-        def __init__(self, file_path, tokenizer):
-            self.ds = TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=64)
+        def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
+            self.dataset = RegressionDataset(a=a, b=b, length=length, seed=seed, label_names=label_names)
 
         def __iter__(self):
-            for i in range(len(self.ds)):
-                yield self.ds[i]
+            for i in range(len(self.dataset)):
+                yield self.dataset[i]
 
     class RegressionModel(torch.nn.Module):
         def __init__(self, a=0, b=0, double_output=False):
@@ -827,18 +820,12 @@ def test_trainer_eval_lm(self):
         self.assertEqual(len(dataset), 31)
 
     def test_trainer_iterable_dataset(self):
-        # Simulate Language Modeling with an IterableDataset, with no __len__ method
-        # Pick-up a tiny model, so it works on CPU
-        # See Issue #5990: https://github.com/huggingface/transformers/issues/5990
-        MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
-        model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-        train_dataset = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
-        training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
-        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
+        config = RegressionModelConfig()
+        model = RegressionPreTrainedModel(config)
+        train_dataset = SampleIterableDataset()
 
-        training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
-        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
+        args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
+        trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
         trainer.train()
 
         loader = trainer.get_train_dataloader()
@@ -847,30 +834,19 @@ def test_trainer_iterable_dataset(self):
 
         # Exception if giving iterable dataset and no max_steps
         with self.assertRaises(ValueError):
-            training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
-            _ = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
+            args1 = RegressionTrainingArguments(output_dir="./examples")
+            _ = Trainer(model=model, args=args1, train_dataset=train_dataset)
 
         # Exception if eval_dataset is iterable in __init__
         with self.assertRaises(ValueError):
-            training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
-            _ = Trainer(
-                model=model,
-                args=training_args,
-                train_dataset=train_dataset,
-                eval_dataset=train_dataset,
-                data_collator=data_collator,
-            )
+            _ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
 
         # Exception if predicting with iterable dataset
         with self.assertRaises(ValueError):
-            training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
-            trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
             trainer.predict(train_dataset)
 
         # Exception if evaluating with iterable dataset
         with self.assertRaises(ValueError):
-            training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
-            trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
             trainer.evaluate(train_dataset)
 
     def test_num_train_epochs_in_training(self):
diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py
index be1037ffc651a7..8657a9e640966c 100644
--- a/tests/test_trainer_utils.py
+++ b/tests/test_trainer_utils.py
@@ -23,12 +23,14 @@
 
 if is_torch_available():
     import torch
+    from torch.utils.data import IterableDataset
 
     from transformers.modeling_outputs import SequenceClassifierOutput
     from transformers.trainer_pt_utils import (
         DistributedLengthGroupedSampler,
         DistributedSamplerWithLoop,
         DistributedTensorGatherer,
+        IterableDatasetShard,
         LabelSmoother,
         LengthGroupedSampler,
         SequentialDistributedSampler,
@@ -49,6 +51,22 @@ def forward(self, x):
             h = torch.nn.functional.relu(self.linear2(x))
             return self.ln2(x + h + self.bias)
 
+    class RandomIterableDataset(IterableDataset):
+        # For testing, an iterable dataset of random length
+        def __init__(self, p_stop=0.01, max_length=1000):
+            self.p_stop = p_stop
+            self.max_length = max_length
+            self.generator = torch.Generator()
+
+        def __iter__(self):
+            count = 0
+            stop = False
+            while not stop and count < self.max_length:
+                yield count
+                count += 1
+                number = torch.rand(1, generator=self.generator).item()
+                stop = number < self.p_stop
+
 
 @require_torch
 class TrainerUtilsTest(unittest.TestCase):
@@ -243,3 +261,45 @@ def test_sequential_distributed_sampler(self):
 
         self.assertListEqual(total[:length], dataset)
         self.assertListEqual(total[length:], dataset[: (len(total) - length)])
+
+    def check_iterable_dataset_shard(self, dataset, batch_size, drop_last, num_processes=2, epoch=0):
+        # Set the seed for the base dataset to get the proper reference.
+        dataset.generator.manual_seed(epoch)
+        reference = list(dataset)
+
+        shards = [
+            IterableDatasetShard(
+                dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
+            )
+            for i in range(num_processes)
+        ]
+        for shard in shards:
+            shard.set_epoch(epoch)
+        shard_lists = [list(shard) for shard in shards]
+
+        for shard in shard_lists:
+            # All shards have a number of samples that is a round multiple of batch size
+            self.assertTrue(len(shard) % batch_size == 0)
+            # All shards have the same number of samples
+            self.assertEqual(len(shard), len(shard_lists[0]))
+
+        observed = []
+        for idx in range(0, len(shard_lists[0]), batch_size):
+            for shard in shard_lists:
+                observed += shard[idx : idx + batch_size]
+
+        # If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
+        # batch_size
+        if not drop_last:
+            while len(reference) < len(observed):
+                reference += reference
+        self.assertListEqual(observed, reference[: len(observed)])
+
+    def test_iterable_dataset_shard(self):
+        dataset = RandomIterableDataset()
+
+        self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
+        self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0)
+
+        self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
+        self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)
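
As a quick sanity check of the sharding behaviour described in the IterableDatasetShard docstring, a minimal sketch along the following lines reproduces the two-process example. It assumes the patch above is applied so the class is importable from transformers.trainer_pt_utils; the TinyIterable dataset is a made-up helper for illustration only.

from torch.utils.data import IterableDataset

from transformers.trainer_pt_utils import IterableDatasetShard


class TinyIterable(IterableDataset):
    # Yields the fixed sequence used in the docstring example.
    def __iter__(self):
        return iter(range(12))


dataset = TinyIterable()
# One shard per simulated process; with batch_size=2 and 2 processes, each "real" batch is 4 samples.
shards = [
    IterableDatasetShard(dataset, batch_size=2, num_processes=2, process_index=i) for i in range(2)
]
print([list(shard) for shard in shards])
# Expected, per the docstring: [[0, 1, 4, 5, 8, 9], [2, 3, 6, 7, 10, 11]]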