Support for set_epoch in IterableDataset #11258

Merged 1 commit · Apr 15, 2021
src/transformers/trainer.py (10 changes: 9 additions & 1 deletion)

@@ -191,9 +191,15 @@ class Trainer:
The function to use to form a batch from a list of elements of :obj:`train_dataset` or :obj:`eval_dataset`.
Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is provided, an instance of
:func:`~transformers.DataCollatorWithPadding` otherwise.
train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
train_dataset (:obj:`torch.utils.data.dataset.Dataset` or :obj:`torch.utils.data.dataset.IterableDataset`, `optional`):
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.

Note that if it's a :obj:`torch.utils.data.dataset.IterableDataset` with some randomization and you are
training in a distributed fashion, your iterable dataset should either use an internal attribute
:obj:`generator` that is a :obj:`torch.Generator` for the randomization that must be identical on all
processes (and the Trainer will manually set the seed of this :obj:`generator` at each epoch) or have a
:obj:`set_epoch()` method that internally sets the seed of the RNGs used.
Comment on lines +198 to +202
Member


Very nice docstring!

eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
@@ -1095,6 +1101,8 @@ def train(
for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
elif isinstance(train_dataloader.dataset, IterableDatasetShard):
train_dataloader.dataset.set_epoch(epoch)

if is_torch_tpu_available():
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
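The docstring above says an iterable dataset used in distributed training should either expose a :obj:`generator` attribute or implement a :obj:`set_epoch()` method, and the train loop above calls :obj:`set_epoch(epoch)` on the dataset each epoch. The following is a minimal sketch of the second option; it is not part of this PR, and the class name and in-memory data are purely illustrative.

import torch
from torch.utils.data import IterableDataset


class ShuffledStreamDataset(IterableDataset):
    """Illustrative iterable dataset that reshuffles identically on every process."""

    def __init__(self, data, seed=42):
        self.data = list(data)
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # Called once per epoch; the Trainer reaches this through the dataset wrapper,
        # as the second diff below shows.
        self.epoch = epoch

    def __iter__(self):
        # Seed a fresh generator from seed + epoch so the shuffle order is the same
        # on all processes but changes from one epoch to the next.
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        for idx in torch.randperm(len(self.data), generator=g).tolist():
            yield self.data[idx]

Deriving the per-epoch seed as seed + epoch keeps each epoch deterministic across processes while still varying between epochs, which mirrors what the docstring below describes for the generator path.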
src/transformers/trainer_pt_utils.py (12 changes: 9 additions & 3 deletions)

@@ -598,8 +598,8 @@ class IterableDatasetShard(IterableDataset):
:obj:`dataset` to generate your random numbers and call the
:meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch` method of this object. It will set the
seed of this :obj:`generator` to :obj:`seed + epoch` on all processes before starting the iteration.
Alternatively, you can also subclass this class and override the :meth:`__iter__` method with your custom
logic.
Alternatively, you can also implement a :obj:`set_epoch()` method in your iterable dataset to deal with this.


Args:
dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
@@ -637,9 +637,15 @@ def __init__(

def set_epoch(self, epoch):
self.epoch = epoch
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)

def __iter__(self):
if hasattr(self.dataset, "generator") and isinstance(self.dataset.generator, torch.Generator):
if (
not hasattr(self.dataset, "set_epoch")
and hasattr(self.dataset, "generator")
and isinstance(self.dataset.generator, torch.Generator)
):
self.dataset.generator.manual_seed(self.seed + self.epoch)
real_batch_size = self.batch_size * self.num_processes
process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)
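For the other path described in the docstring above, a dataset that defines no :obj:`set_epoch()` but exposes a :obj:`generator` attribute gets that generator seeded to :obj:`seed + epoch` by :obj:`__iter__`. Below is a hedged usage sketch, not part of this PR; the class name and data are illustrative, and the keyword arguments follow the attributes referenced in the :obj:`__iter__` code above (check the class signature in trainer_pt_utils.py).

import torch
from torch.utils.data import IterableDataset

from transformers.trainer_pt_utils import IterableDatasetShard


class GeneratorBackedDataset(IterableDataset):
    """Illustrative dataset whose randomization comes from an internal torch.Generator."""

    def __init__(self, data):
        self.data = list(data)
        # Seeded by IterableDatasetShard.__iter__ because no set_epoch() is defined.
        self.generator = torch.Generator()

    def __iter__(self):
        for idx in torch.randperm(len(self.data), generator=self.generator).tolist():
            yield self.data[idx]


shard = IterableDatasetShard(
    GeneratorBackedDataset(range(100)),
    batch_size=8,
    num_processes=2,
    process_index=0,
    seed=42,
)
shard.set_epoch(1)
items = list(shard)  # the dataset's generator is seeded to 42 + 1 when iteration starts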