diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 02f6a29dc57446..cab3bbb246146f 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -191,9 +191,15 @@ class Trainer:
             The function to use to form a batch from a list of elements of :obj:`train_dataset` or
             :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
             provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
-        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+        train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`):
             The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
+
+            Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are
+            training in a distributed fashion, your iterable dataset should either use an internal attribute
+            :obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all
+            processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a
+            :obj:`set_epoch()` method that internally sets the seed of the RNGs used.
         eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
             The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
             ``model.forward()`` method are automatically removed.
@@ -1095,6 +1101,8 @@ def train(
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
+            elif isinstance(train_dataloader.dataset, IterableDatasetShard):
+                train_dataloader.dataset.set_epoch(epoch)
 
             if is_torch_tpu_available():
                 parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index e048cd8d94162e..c81f98c74454c3 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -598,8 +598,8 @@ class IterableDatasetShard(IterableDataset):
     :obj:`dataset` to generate your random numbers and call the
     :meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch` method of this object. It will set the
     seed of this :obj:`generator` to :obj:`seed + epoch` on all processes before starting the iteration.
-    Alternatively, you can also subclass this class and override the :meth:`__iter__` method with your custom
-    logic.
+    Alternatively, you can also implement a :obj:`set_epoch()` method in your iterable dataset to deal with this.
+
 
     Args:
         dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
@@ -637,9 +637,15 @@ def __init__(
 
     def set_epoch(self, epoch):
         self.epoch = epoch
+        if hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
 
     def __iter__(self):
-        if hasattr(self.dataset, "generator") and isinstance(self.dataset.generator, torch.Generator):
+        if (
+            not hasattr(self.dataset, "set_epoch")
+            and hasattr(self.dataset, "generator")
+            and isinstance(self.dataset.generator, torch.Generator)
+        ):
             self.dataset.generator.manual_seed(self.seed + self.epoch)
         real_batch_size = self.batch_size * self.num_processes
         process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)
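For illustration, here is a minimal sketch (not part of the diff) of a user-side iterable dataset that follows the second contract described in the docstring: it exposes a set_epoch() method, so IterableDatasetShard.set_epoch() can forward the epoch to it and every process reshuffles identically. The class name, data, and seeding scheme are hypothetical, not anything defined in the library.

    import torch
    from torch.utils.data import IterableDataset

    class MyShuffledIterableDataset(IterableDataset):  # hypothetical example dataset
        def __init__(self, data, seed: int = 0):
            self.data = list(data)
            self.seed = seed
            self.epoch = 0
            self.generator = torch.Generator()

        def set_epoch(self, epoch: int):
            # Called by IterableDatasetShard.set_epoch() before each epoch's iteration starts.
            self.epoch = epoch

        def __iter__(self):
            # Seed from (seed + epoch) so the shuffle order is identical on all processes
            # but still changes from one epoch to the next.
            self.generator.manual_seed(self.seed + self.epoch)
            for idx in torch.randperm(len(self.data), generator=self.generator).tolist():
                yield self.data[idx]

A dataset that instead only exposes a torch.Generator under a generator attribute (the first contract) keeps working as before: since it has no set_epoch() method, IterableDatasetShard.__iter__ still reseeds that generator with seed + epoch itself.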