Notify users that DataCollatorForWholeWordMask is limited to BertTokenizer-like tokenizers

ionicsolutions committed Jun 26, 2021
1 parent 9a75459 commit 31a813c
Showing 1 changed file with 13 additions and 1 deletion.

src/transformers/data/data_collator.py
@@ -23,6 +23,7 @@
 from ..file_utils import PaddingStrategy
 from ..modeling_utils import PreTrainedModel
 from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
+from ..models.bert import BertTokenizer, BertTokenizerFast
 
 
 InputDataClass = NewType("InputDataClass", Any)
@@ -395,10 +396,18 @@ def mask_tokens(
 @dataclass
 class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
     """
-    Data collator used for language modeling.
+    Data collator used for language modeling that masks entire words.
 
     - collates batches of tensors, honoring their tokenizer's pad_token
     - preprocesses batches for masked language modeling
+
+    .. note::
+
+        This collator relies on details of the implementation of subword tokenization
+        by :class:`~transformers.BertTokenizer`, specifically that subword tokens are
+        prefixed with `##`. For tokenizers that do not adhere to this scheme, this
+        collator will produce an output that is roughly equivalent to
+        :class:`.DataCollatorForLanguageModeling`.
     """
 
     def __call__(
@@ -435,6 +444,9 @@ def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
         """
         Get 0/1 labels for masked tokens with whole word mask proxy
         """
+        if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
+            warnings.warn("DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
+                          "Please refer to the documentation for more information.")
 
         cand_indexes = []
         for (i, token) in enumerate(input_tokens):
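To make the note above concrete, here is a minimal usage sketch. It is not part of the commit: the checkpoint name, the example sentences, and the mlm_probability value are illustrative, and it assumes the transformers library (with this change) and torch are installed.

    from transformers import BertTokenizerFast, DataCollatorForWholeWordMask

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    # WordPiece marks word-internal subwords with "##"; _whole_word_mask relies on
    # this prefix to group subword tokens back into whole words before masking.
    print(tokenizer.tokenize("unaffable"))  # word-internal pieces carry the "##" prefix

    collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.15)

    # Collate a small batch: every subword piece of a selected word is masked together.
    batch = collator([tokenizer(text) for text in ["an unaffable neighbour", "a lazy dog"]])
    print(batch["input_ids"].shape, batch["labels"].shape)

With a tokenizer that marks subwords differently (for example GPT-2's byte-level BPE, which flags word boundaries with a leading "Ġ" rather than marking word-internal pieces with "##"), _whole_word_mask finds no subword groups to merge, masking degenerates to roughly per-token masking, and, after this commit, the warning above is emitted.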
