diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 29d7bf43a2d7..7aadff233fc3 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -4,6 +4,7 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence
 
+from ..tokenization_utils_base import BatchEncoding
 from ..tokenization_utils import PreTrainedTokenizer
 
 
@@ -33,7 +34,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
     # have the same attributes.
     # So we will look at the first element as a proxy for what attributes exist
     # on the whole batch.
-    if not isinstance(features[0], dict):
+    if not isinstance(features[0], (dict, BatchEncoding)):
         features = [vars(f) for f in features]
 
     first = features[0]
@@ -78,7 +79,7 @@ class DataCollatorForLanguageModeling:
     mlm_probability: float = 0.15
 
     def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
-        if isinstance(examples[0], dict):
+        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        batch = self._tensorize_batch(examples)
        if self.mlm:
@@ -151,7 +152,7 @@ class DataCollatorForPermutationLanguageModeling:
     max_span_length: int = 5  # maximum length of a span of masked tokens
 
     def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
-        if isinstance(examples[0], dict):
+        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        batch = self._tensorize_batch(examples)
        inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
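
For context on why the `isinstance` checks are widened: fast tokenizers return a `BatchEncoding`, which derives from `collections.UserDict` rather than `dict`, so the previous `isinstance(..., dict)` checks would not match it. Below is a minimal sketch illustrating the difference; the tokenizer class and checkpoint name are only examples and are not part of this patch.

```python
# Illustration only -- not part of the patch. Assumes a transformers version
# where BatchEncoding derives from collections.UserDict, not dict.
from transformers import BatchEncoding, BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")  # example checkpoint
encoding = tokenizer("Hello world")  # fast tokenizers return a BatchEncoding

print(isinstance(encoding, dict))                   # False -> the old check skipped it
print(isinstance(encoding, (dict, BatchEncoding)))  # True  -> the widened check handles it
```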