Accepts BatchEncoding in LengthSampler (#11431)
tma15 authored Apr 30, 2021
1 parent 30ede89 commit c2cd02a
Showing 2 changed files with 40 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/transformers/trainer_pt_utils.py
@@ -33,6 +33,7 @@
 from torch.utils.data.sampler import RandomSampler, Sampler
 
 from .file_utils import is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_tpu_available
+from .tokenization_utils_base import BatchEncoding
 from .utils import logging
 
 
@@ -514,7 +515,10 @@ def __init__(
         self.batch_size = batch_size
         self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
         if lengths is None:
-            if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
+            if (
+                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
+                or self.model_input_name not in dataset[0]
+            ):
                 raise ValueError(
                     "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                     f"'{self.model_input_name}' key."
@@ -575,7 +579,10 @@ def __init__(
         self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
 
         if lengths is None:
-            if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
+            if (
+                not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
+                or self.model_input_name not in dataset[0]
+            ):
                 raise ValueError(
                     "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                     f"'{self.model_input_name}' key."
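A minor stylistic note, not part of this commit: isinstance also accepts a tuple of types, so the two calls could be collapsed into one. A sketch using a hypothetical helper (_has_length_key is illustrative, not from the library):

    from transformers.tokenization_utils_base import BatchEncoding

    def _has_length_key(example, model_input_name="input_ids"):
        # Hypothetical helper: accept plain dicts and BatchEncoding alike
        return isinstance(example, (dict, BatchEncoding)) and model_input_name in example

    print(_has_length_key({"input_ids": [1, 2, 3]}))                 # True
    print(_has_length_key(BatchEncoding({"input_ids": [1, 2, 3]})))  # True
    print(_has_length_key([1, 2, 3]))                                # False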
31 changes: 31 additions & 0 deletions tests/test_trainer_utils.py
@@ -27,6 +27,7 @@
 from torch.utils.data import IterableDataset
 
 from transformers.modeling_outputs import SequenceClassifierOutput
+from transformers.tokenization_utils_base import BatchEncoding
 from transformers.trainer_pt_utils import (
     DistributedLengthGroupedSampler,
     DistributedSamplerWithLoop,
@@ -185,6 +186,36 @@ def test_group_by_length(self):
         # The indices should be a permutation of range(100)
         self.assertEqual(list(sorted(indices)), list(range(100)))
 
+    def test_group_by_length_with_dict(self):
+        # Get some inputs of random lengths
+        data = []
+        for _ in range(6):
+            input_ids = torch.randint(0, 25, (100,)).tolist()
+            data.append({"input_ids": input_ids})
+        # Put one bigger than the others to check it ends up in first position
+        data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
+
+        indices = list(LengthGroupedSampler(data, 4))
+        # The biggest element should be first
+        self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
+        # The indices should be a permutation of range(6)
+        self.assertEqual(list(sorted(indices)), list(range(6)))
+
+    def test_group_by_length_with_batch_encoding(self):
+        # Get some inputs of random lengths
+        data = []
+        for _ in range(6):
+            input_ids = torch.randint(0, 25, (100,)).tolist()
+            data.append(BatchEncoding({"input_ids": input_ids}))
+        # Put one bigger than the others to check it ends up in first position
+        data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()
+
+        indices = list(LengthGroupedSampler(data, 4))
+        # The biggest element should be first
+        self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
+        # The indices should be a permutation of range(6)
+        self.assertEqual(list(sorted(indices)), list(range(6)))
+
     def test_distributed_length_grouped(self):
         # Get some inputs of random lengths
         lengths = torch.randint(0, 25, (100,)).tolist()
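For context, items like the ones in the second test typically come straight from a tokenizer call, which returns a BatchEncoding; with this change the sampler can infer lengths from such a dataset directly. A rough usage sketch against the sampler API as of this commit (model name and sentences are illustrative):

    from transformers import AutoTokenizer
    from transformers.trainer_pt_utils import LengthGroupedSampler

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    sentences = [
        "a short one",
        "a noticeably longer sentence than all of the others in this list",
        "medium length text",
    ]
    dataset = [tokenizer(s) for s in sentences]  # each item is a BatchEncoding, not a plain dict

    sampler = LengthGroupedSampler(dataset, batch_size=2)
    print(list(sampler))  # indices ordered so that similarly sized items land in the same batch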
