Skip to content

Commit

Permalink
Fix #6092
Browse files (browse the repository at this point in the history)
  • Loading branch information
sgugger committed Jul 28, 2020
1 parent 54f49af commit a0a3e3f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/transformers/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from torch.nn.utils.rnn import pad_sequence

+from ..tokenization_utils_base import BatchEncoding
from ..tokenization_utils import PreTrainedTokenizer


Expand Down Expand Up @@ -33,7 +34,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
# have the same attributes.
# So we will look at the first element as a proxy for what attributes exist
# on the whole batch.
-if not isinstance(features[0], dict):
+if not isinstance(features[0], (dict, BatchEncoding)):
features = [vars(f) for f in features]

first = features[0]
Expand Down Expand Up @@ -78,7 +79,7 @@ class DataCollatorForLanguageModeling:
mlm_probability: float = 0.15

def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
-if isinstance(examples[0], dict):
+if isinstance(examples[0], (dict, BatchEncoding)):
examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples)
if self.mlm:
Expand Down Expand Up @@ -151,7 +152,7 @@ class DataCollatorForPermutationLanguageModeling:
max_span_length: int = 5 # maximum length of a span of masked tokens

def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
-if isinstance(examples[0], dict):
+if isinstance(examples[0], (dict, BatchEncoding)):
examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples)
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
Expand Down

0 comments on commit a0a3e3f

Please sign in to comment.