Commit 524b661: Fix code formatting
ionicsolutions committed Jun 26, 2021
1 parent 31a813c commit 524b661
Showing 1 changed file with 8 additions and 7 deletions.
src/transformers/data/data_collator.py (8 additions, 7 deletions)
@@ -22,8 +22,8 @@

 from ..file_utils import PaddingStrategy
 from ..modeling_utils import PreTrainedModel
-from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
 from ..models.bert import BertTokenizer, BertTokenizerFast
+from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase


 InputDataClass = NewType("InputDataClass", Any)
@@ -403,10 +403,9 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):

     .. note::

-        This collator relies on details of the implementation of subword tokenization
-        by :class:`~transformers.BertTokenizer`, specifically that subword tokens are
-        prefixed with `##`. For tokenizers that do not adhere to this scheme, this
-        collator will produce an output that is roughly equivalent to
+        This collator relies on details of the implementation of subword tokenization by
+        :class:`~transformers.BertTokenizer`, specifically that subword tokens are prefixed with `##`. For tokenizers
+        that do not adhere to this scheme, this collator will produce an output that is roughly equivalent to
         :class:`.DataCollatorForLanguageModeling`.
     """

@@ -445,8 +444,10 @@ def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
         Get 0/1 labels for masked tokens with whole word mask proxy
         """
         if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
-            warnings.warn("DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers."
-                          "Please refer to the documentation for more information.")
+            warnings.warn(
+                "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
+                "Please refer to the documentation for more information."
+            )

         cand_indexes = []
         for (i, token) in enumerate(input_tokens):
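For context, a usage sketch under stated assumptions (the checkpoint names are illustrative examples, not part of the commit): the isinstance check above fires when the collator builds a batch with a tokenizer that is not BertTokenizer-like.

```python
# Sketch only: shows when the warning guarded by the isinstance check is
# emitted. Checkpoint names are examples, not part of this commit.
from transformers import (
    BertTokenizerFast,
    DataCollatorForWholeWordMask,
    RobertaTokenizerFast,
)

bert_tok = BertTokenizerFast.from_pretrained("bert-base-uncased")
bert_collator = DataCollatorForWholeWordMask(tokenizer=bert_tok, mlm_probability=0.15)
# No warning: BertTokenizerFast marks subwords with "##", as the collator expects.

roberta_tok = RobertaTokenizerFast.from_pretrained("roberta-base")
roberta_collator = DataCollatorForWholeWordMask(tokenizer=roberta_tok)
# Building a batch with this collator emits the UserWarning above, since
# RoBERTa's BPE tokenizer does not prefix subword continuations with "##".
```

The check deliberately warns rather than raises, so existing pipelines keep running; they simply fall back to output roughly equivalent to DataCollatorForLanguageModeling.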
