Commit

Update PreTrainedTokenizerBase to check/handle batch length for `text_pair` parameter (#11486)

* Update tokenization_utils_base.py

* add assertion

* check batch len

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add error message

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
hamelsmu and sgugger authored Apr 28, 2021
1 parent 2d27900 commit c0eb218
Showing 1 changed file with 8 additions and 0 deletions.
src/transformers/tokenization_utils_base.py
@@ -2279,6 +2279,14 @@ def __call__(
             )
 
         if is_batched:
+            if isinstance(text_pair, str):
+                raise TypeError(
+                    "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as `text`."
+                )
+            if text_pair is not None and len(text) != len(text_pair):
+                raise ValueError(
+                    f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
+                )
             batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
             return self.batch_encode_plus(
                 batch_text_or_text_pairs=batch_text_or_text_pairs,
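For reference, a minimal sketch of the behavior introduced by this commit, assuming the `transformers` library is installed and a checkpoint such as `bert-base-uncased` is available (the model name is illustrative only):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    text = ["How many cats are there?", "How many dogs are there?"]

    # A bare string for `text_pair` alongside batched `text` now raises TypeError.
    try:
        tokenizer(text, text_pair="There are two.")
    except TypeError as err:
        print(err)

    # A `text_pair` batch whose length differs from `text` now raises ValueError.
    try:
        tokenizer(text, text_pair=["There are two."])
    except ValueError as err:
        print(err)

    # Matching batch lengths encode each (text, text_pair) pair as before.
    encoded = tokenizer(text, text_pair=["There are two.", "There are three."])

Previously, a mismatched `text_pair` would fail later inside `zip(text, text_pair)` by silently truncating pairs; the new checks surface the problem with an explicit error at call time.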
