
Trainer iterable dataset (#11254)
* IterableDatasetShard

* Test and integration in Trainer

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Style

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
2 people authored and Rocketknight1 committed Apr 21, 2021
1 parent 7c92cd6 commit 942ff8f
Showing 4 changed files with 185 additions and 40 deletions.
25 changes: 22 additions & 3 deletions src/transformers/trainer.py
@@ -81,6 +81,7 @@
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
@@ -493,9 +494,7 @@ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optio
dataset.set_format(type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"])

def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
self.train_dataset, collections.abc.Sized
):
if not isinstance(self.train_dataset, collections.abc.Sized):
return None

# Build the sampler.
@@ -553,6 +552,26 @@ def get_train_dataloader(self) -> DataLoader:
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")

if isinstance(self.train_dataset, torch.utils.data.dataset.IterableDataset):
if self.args.world_size > 1:
train_dataset = IterableDatasetShard(
self.train_dataset,
batch_size=self.args.train_batch_size,
drop_last=self.args.dataloader_drop_last,
num_processes=self.args.world_size,
process_index=self.args.process_index,
)
else:
train_dataset = self.train_dataset
return DataLoader(
train_dataset,
batch_size=self.args.train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)

train_sampler = self._get_train_sampler()

return DataLoader(
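
As a rough usage sketch (not part of the diff), this change lets a plain torch IterableDataset with no __len__ be passed straight to Trainer; max_steps has to be set because the number of samples is unknown, mirroring test_trainer_iterable_dataset below. The tiny GPT-2 config and the ToyStream class are illustrative assumptions, not code from this commit:

import torch
from torch.utils.data import IterableDataset

from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments


class ToyStream(IterableDataset):
    # Hypothetical streaming dataset: 64 random sequences, intentionally without __len__.
    def __iter__(self):
        for _ in range(64):
            ids = torch.randint(0, 32, (8,))
            yield {"input_ids": ids, "labels": ids}


model = GPT2LMHeadModel(GPT2Config(vocab_size=32, n_embd=8, n_layer=1, n_head=1))
args = TrainingArguments(output_dir="toy_output", max_steps=2)
trainer = Trainer(model=model, args=args, train_dataset=ToyStream())
trainer.train()  # world_size == 1, so get_train_dataloader takes the plain DataLoader branch
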
92 changes: 91 additions & 1 deletion src/transformers/trainer_pt_utils.py
@@ -28,7 +28,7 @@
import numpy as np
import torch
from packaging import version
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler

@@ -576,6 +576,96 @@ def __iter__(self) -> Iterator:
return iter(indices)


class IterableDatasetShard(IterableDataset):
"""
Wraps a PyTorch :obj:`IterableDataset` to generate samples for one of the processes only. Instances of this class
will always yield a number of samples that is a round multiple of the actual batch size (which is :obj:`batch_size
x num_processes`). Depending on the value of the :obj:`drop_last` attribute, it will either stop the iteration at
the first batch that would be too small or complete the last batch with samples from the beginning.
On two processes with an iterable dataset yielding the elements of :obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]` with a batch
size of 2:
- the shard on process 0 will yield :obj:`[0, 1, 4, 5, 8, 9]` so will see batches :obj:`[0, 1]`, :obj:`[4, 5]`,
:obj:`[8, 9]`
- the shard on process 1 will yield :obj:`[2, 3, 6, 7, 10, 11]` so will see batches :obj:`[2, 3]`, :obj:`[6, 7]`,
:obj:`[10, 11]`
.. warning::
If your IterableDataset implements some randomization that needs to be applied the same way on all processes
(for instance, a shuffling), you should use a :obj:`torch.Generator` in a :obj:`generator` attribute of the
:obj:`dataset` to generate your random numbers and call the
:meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch` method of this object. It will set the
seed of this :obj:`generator` to :obj:`seed + epoch` on all processes before starting the iteration.
Alternatively, you can also subclass this class and override the :meth:`__iter__` method with your custom
logic.
Args:
dataset (:obj:`torch.utils.data.dataset.IterableDataset`):
The dataset to split into shards.
batch_size (:obj:`int`, `optional`, defaults to 1):
The size of the batches per shard.
drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
beginning.
num_processes (:obj:`int`, `optional`, defaults to 1):
The number of processes running concurrently.
process_index (:obj:`int`, `optional`, defaults to 0):
The index of the current process.
seed (:obj:`int`, `optional`, defaults to 0):
A random seed that will be used for the random number generation in
:meth:`~transformers.trainer_pt_utils.IterableDatasetShard.set_epoch`.
"""

def __init__(
self,
dataset: IterableDataset,
batch_size: int = 1,
drop_last: bool = False,
num_processes: int = 1,
process_index: int = 0,
seed: int = 0,
):
self.dataset = dataset
self.batch_size = batch_size
self.drop_last = drop_last
self.num_processes = num_processes
self.process_index = process_index
self.seed = seed
self.epoch = 0

def set_epoch(self, epoch):
self.epoch = epoch

def __iter__(self):
if hasattr(self.dataset, "generator") and isinstance(self.dataset.generator, torch.Generator):
self.dataset.generator.manual_seed(self.seed + self.epoch)
real_batch_size = self.batch_size * self.num_processes
process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)

first_batch = None
current_batch = []
for element in self.dataset:
current_batch.append(element)
# Wait to have a full batch before yielding elements.
if len(current_batch) == real_batch_size:
for i in process_slice:
yield current_batch[i]
if first_batch is None:
first_batch = current_batch.copy()
current_batch = []

# Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
if not self.drop_last and len(current_batch) > 0:
if first_batch is None:
first_batch = current_batch.copy()
while len(current_batch) < real_batch_size:
current_batch += first_batch
for i in process_slice:
yield current_batch[i]


# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
# helper methods here

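
For reference, a minimal sketch (not part of the commit) of the sharding behavior described in the docstring above; ToyIterable is a hypothetical stand-in for any IterableDataset:

from torch.utils.data import IterableDataset

from transformers.trainer_pt_utils import IterableDatasetShard


class ToyIterable(IterableDataset):
    # Yields 0..n-1; with n=12 this matches the example in the IterableDatasetShard docstring.
    def __init__(self, n=12):
        self.n = n

    def __iter__(self):
        yield from range(self.n)


shards = [
    IterableDatasetShard(ToyIterable(12), batch_size=2, num_processes=2, process_index=i)
    for i in range(2)
]
print(list(shards[0]))  # [0, 1, 4, 5, 8, 9]
print(list(shards[1]))  # [2, 3, 6, 7, 10, 11]

# With 10 elements and drop_last=False, the last incomplete global batch is topped up with
# samples from the first batch, so each shard still yields a round multiple of the batch size:
shards = [
    IterableDatasetShard(ToyIterable(10), batch_size=2, drop_last=False, num_processes=2, process_index=i)
    for i in range(2)
]
print(list(shards[0]))  # [0, 1, 4, 5, 8, 9]
print(list(shards[1]))  # [2, 3, 6, 7, 0, 1]
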
48 changes: 12 additions & 36 deletions tests/test_trainer.py
@@ -44,17 +44,14 @@
from torch.utils.data import IterableDataset

from transformers import (
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
DataCollatorForLanguageModeling,
EarlyStoppingCallback,
GlueDataset,
GlueDataTrainingArguments,
GPT2Config,
GPT2LMHeadModel,
LineByLineTextDataset,
PreTrainedModel,
TextDataset,
Trainer,
TrainerState,
)
@@ -138,16 +135,12 @@ def __init__(self, a=0, b=0, double_output=False, **kwargs):
if is_torch_available():

class SampleIterableDataset(IterableDataset):
"""
Criteria is not whether it is IterableDataset or not, criteria is whether __len__ is implemented
"""

def __init__(self, file_path, tokenizer):
self.ds = TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=64)
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
self.dataset = RegressionDataset(a=a, b=b, length=length, seed=seed, label_names=label_names)

def __iter__(self):
for i in range(len(self.ds)):
yield self.ds[i]
for i in range(len(self.dataset)):
yield self.dataset[i]

class RegressionModel(torch.nn.Module):
def __init__(self, a=0, b=0, double_output=False):
@@ -827,18 +820,12 @@ def test_trainer_eval_lm(self):
self.assertEqual(len(dataset), 31)

def test_trainer_iterable_dataset(self):
# Simulate Language Modeling with an IterableDataset, with no __len__ method
# Pick-up a tiny model, so it works on CPU
# See Issue #5990: https://github.com/huggingface/transformers/issues/5990
MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
train_dataset = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
config = RegressionModelConfig()
model = RegressionPreTrainedModel(config)
train_dataset = SampleIterableDataset()

training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()

loader = trainer.get_train_dataloader()
@@ -847,30 +834,19 @@ def test_trainer_iterable_dataset(self):

# Exception if giving iterable dataset and no max_steps
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
_ = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
args1 = RegressionTrainingArguments(output_dir="./examples")
_ = Trainer(model=model, args=args1, train_dataset=train_dataset)

# Exception if eval_dataset is iterable in __init__
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
_ = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=train_dataset,
data_collator=data_collator,
)
_ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)

# Exception if predicting with iterable dataset
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
trainer.predict(train_dataset)

# Exception if evaluating with iterable dataset
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
trainer.evaluate(train_dataset)

def test_num_train_epochs_in_training(self):
60 changes: 60 additions & 0 deletions tests/test_trainer_utils.py
@@ -23,12 +23,14 @@

if is_torch_available():
import torch
from torch.utils.data import IterableDataset

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
@@ -49,6 +51,22 @@ def forward(self, x):
h = torch.nn.functional.relu(self.linear2(x))
return self.ln2(x + h + self.bias)

class RandomIterableDataset(IterableDataset):
# For testing, an iterable dataset of random length
def __init__(self, p_stop=0.01, max_length=1000):
self.p_stop = p_stop
self.max_length = max_length
self.generator = torch.Generator()

def __iter__(self):
count = 0
stop = False
while not stop and count < self.max_length:
yield count
count += 1
number = torch.rand(1, generator=self.generator).item()
stop = number < self.p_stop


@require_torch
class TrainerUtilsTest(unittest.TestCase):
@@ -243,3 +261,45 @@ def test_sequential_distributed_sampler(self):

self.assertListEqual(total[:length], dataset)
self.assertListEqual(total[length:], dataset[: (len(total) - length)])

def check_iterable_dataset_shard(self, dataset, batch_size, drop_last, num_processes=2, epoch=0):
# Set the seed for the base dataset to get the proper reference.
dataset.generator.manual_seed(epoch)
reference = list(dataset)

shards = [
IterableDatasetShard(
dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
)
for i in range(num_processes)
]
for shard in shards:
shard.set_epoch(epoch)
shard_lists = [list(shard) for shard in shards]

for shard in shard_lists:
# All shards have a number of samples that is a round multiple of batch size
self.assertTrue(len(shard) % batch_size == 0)
# All shards have the same number of samples
self.assertEqual(len(shard), len(shard_lists[0]))

observed = []
for idx in range(0, len(shard_lists[0]), batch_size):
for shard in shard_lists:
observed += shard[idx : idx + batch_size]

# If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
# batch_size
if not drop_last:
while len(reference) < len(observed):
reference += reference
self.assertListEqual(observed, reference[: len(observed)])

def test_iterable_dataset_shard(self):
dataset = RandomIterableDataset()

self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0)

self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)
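
The set_epoch / generator handshake that check_iterable_dataset_shard exercises can also be seen in isolation; a hedged sketch (not from the commit), with ShuffledRange as an illustrative dataset that shuffles through its own torch.Generator:

import torch
from torch.utils.data import IterableDataset

from transformers.trainer_pt_utils import IterableDatasetShard


class ShuffledRange(IterableDataset):
    # Hypothetical dataset exposing a `generator` attribute, as the docstring warning suggests.
    def __init__(self, n=8):
        self.n = n
        self.generator = torch.Generator()

    def __iter__(self):
        yield from torch.randperm(self.n, generator=self.generator).tolist()


dataset = ShuffledRange(8)
shards = [
    IterableDatasetShard(dataset, batch_size=2, num_processes=2, process_index=i, seed=7)
    for i in range(2)
]
for shard in shards:
    shard.set_epoch(3)  # each shard will seed dataset.generator to 7 + 3 when iteration starts
# Both shards then see the same permutation, so together they cover every element exactly once.
assert sorted(list(shards[0]) + list(shards[1])) == list(range(8))
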
