Accepts BatchEncoding in LengthGroupedSampler #11431

Merged
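In short, this PR lets `LengthGroupedSampler` (and its distributed counterpart) infer example lengths when the dataset items are `BatchEncoding` objects, not only plain `dict`s. A minimal sketch of the new behavior, mirroring the call signature used in the tests below (`LengthGroupedSampler(dataset, batch_size)`):

```python
import torch

from transformers.tokenization_utils_base import BatchEncoding
from transformers.trainer_pt_utils import LengthGroupedSampler

# Toy dataset whose items are BatchEncoding objects; one example is
# deliberately longer so it should end up first after sampling.
data = [
    BatchEncoding({"input_ids": torch.randint(0, 25, (length,)).tolist()})
    for length in (100, 100, 100, 105, 100, 100)
]

# Before this change, the isinstance(dataset[0], dict) check raised a
# ValueError for BatchEncoding items; now lengths are inferred from "input_ids".
indices = list(LengthGroupedSampler(data, 4))
assert len(data[indices[0]]["input_ids"]) == 105
```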
11 changes: 9 additions & 2 deletions src/transformers/trainer_pt_utils.py
@@ -33,6 +33,7 @@
 from torch.utils.data.sampler import RandomSampler, Sampler

 from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
+from .tokenization_utils_base import BatchEncoding
 from .utils import logging


@@ -514,7 +515,10 @@ def __init__(
         self.batch_size = batch_size
         self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
         if lengths is None:
-            if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
+            if (
+                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
+                or self.model_input_name not in dataset[0]
+            ):
                 raise ValueError(
                     "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                     f"'{self.model_input_name}' key."
@@ -575,7 +579,10 @@ def __init__(
         self.model_input_name = model_input_name if model_input_name is not None else "input_ids"

         if lengths is None:
-            if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
+            if (
+                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
+                or self.model_input_name not in dataset[0]
+            ):
                 raise ValueError(
                     "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                     f"'{self.model_input_name}' key."
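For context, `BatchEncoding` subclasses `collections.UserDict` rather than `dict`, so the previous `isinstance(dataset[0], dict)` check rejected it even though it supports dict-style access; hence the extra `isinstance(..., BatchEncoding)` branch above. A small illustration, assuming a local `transformers` install:

```python
from collections import UserDict

from transformers.tokenization_utils_base import BatchEncoding

enc = BatchEncoding({"input_ids": [1, 2, 3]})

# BatchEncoding behaves like a mapping...
assert "input_ids" in enc
assert enc["input_ids"] == [1, 2, 3]

# ...but it is a UserDict subclass, not a dict subclass, so the old
# isinstance(item, dict) check raised the ValueError shown in the diff.
assert isinstance(enc, UserDict)
assert not isinstance(enc, dict)
```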
31 changes: 31 additions & 0 deletions tests/test_trainer_utils.py
@@ -27,6 +27,7 @@
 from torch.utils.data import IterableDataset

 from transformers.modeling_outputs import SequenceClassifierOutput
+from transformers.tokenization_utils_base import BatchEncoding
 from transformers.trainer_pt_utils import (
     DistributedLengthGroupedSampler,
     DistributedSamplerWithLoop,
@@ -185,6 +186,36 @@ def test_group_by_length(self):
         # The indices should be a permutation of range(100)
         self.assertEqual(list(sorted(indices)), list(range(100)))

+    def test_group_by_length_with_dict(self):
+        # Get some inputs of random lengths
+        data = []
+        for _ in range(6):
+            input_ids = torch.randint(0, 25, (100,)).tolist()
+            data.append({"input_ids": input_ids})
+        # Put one bigger than the others to check it ends up in first position
+        data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
+
+        indices = list(LengthGroupedSampler(data, 4))
+        # The biggest element should be first
+        self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
+        # The indices should be a permutation of range(6)
+        self.assertEqual(list(sorted(indices)), list(range(6)))
+
+    def test_group_by_length_with_batch_encoding(self):
+        # Get some inputs of random lengths
+        data = []
+        for _ in range(6):
+            input_ids = torch.randint(0, 25, (100,)).tolist()
+            data.append(BatchEncoding({"input_ids": input_ids}))
+        # Put one bigger than the others to check it ends up in first position
+        data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
+
+        indices = list(LengthGroupedSampler(data, 4))
+        # The biggest element should be first
+        self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
+        # The indices should be a permutation of range(6)
+        self.assertEqual(list(sorted(indices)), list(range(6)))
+
     def test_distributed_length_grouped(self):
         # Get some inputs of random lengths
         lengths = torch.randint(0, 25, (100,)).tolist()
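For end users, this code path is exercised when `group_by_length=True` is passed to `TrainingArguments`, which makes `Trainer` build a `LengthGroupedSampler` over the training dataset; with this PR, datasets whose items are `BatchEncoding` objects work there as well. A hedged sketch (the `output_dir` value is a placeholder):

```python
from transformers import TrainingArguments

# group_by_length=True routes train batching through LengthGroupedSampler,
# the sampler touched by this PR.
args = TrainingArguments(
    output_dir="out",  # placeholder path
    group_by_length=True,
    per_device_train_batch_size=4,
)
```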