Merge pull request #2 from lucidworks/AL-160-support-fast-tokenizer
AL-160: support fast tokenizer
bdalal authored Aug 27, 2020
2 parents 73a7a3f + 54cbfb1 commit 4f11dd2
Showing 2 changed files with 44 additions and 5 deletions.
45 changes: 40 additions & 5 deletions src/transformers/data/processors/squad.py
@@ -8,7 +8,8 @@
 
 from ...file_utils import is_tf_available, is_torch_available
 from ...tokenization_bert import whitespace_tokenize
-from ...tokenization_utils_base import TruncationStrategy
+from ...tokenization_utils_base import PaddingStrategy, TruncationStrategy
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
 from .utils import DataProcessor
 
@@ -107,6 +108,14 @@ def squad_convert_example_to_features(
     tok_to_orig_index = []
     orig_to_tok_index = []
     all_doc_tokens = []
+    if isinstance(tokenizer, PreTrainedTokenizerFast):
+        tokenizer.set_truncation_and_padding(
+            padding_strategy=PaddingStrategy.DO_NOT_PAD,
+            truncation_strategy=TruncationStrategy.LONGEST_FIRST,
+            max_length=64,
+            stride=0,
+            pad_to_multiple_of=None,
+        )
     for (i, token) in enumerate(example.doc_tokens):
         orig_to_tok_index.append(len(all_doc_tokens))
         sub_tokens = tokenizer.tokenize(token)
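For context on the block added above: `set_truncation_and_padding` is how `PreTrainedTokenizerFast` configures its Rust backend, and resetting it to DO_NOT_PAD before the per-word `tokenize(token)` loop keeps those calls from truncating or padding sub-tokens. A minimal standalone sketch, assuming a transformers version with this API; the checkpoint name and sample word are illustrative, not from the commit:

    # A sketch, not from the commit: reset the fast tokenizer's backend state,
    # then tokenize a single word the way the loop above does.
    from transformers import BertTokenizerFast
    from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy
    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")  # illustrative checkpoint
    if isinstance(tokenizer, PreTrainedTokenizerFast):
        tokenizer.set_truncation_and_padding(
            padding_strategy=PaddingStrategy.DO_NOT_PAD,
            truncation_strategy=TruncationStrategy.LONGEST_FIRST,
            max_length=64,
            stride=0,
            pad_to_multiple_of=None,
        )
    print(tokenizer.tokenize("unaffable"))  # WordPiece pieces, e.g. ['una', '##ffa', '##ble']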
@@ -131,6 +140,12 @@
         example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
     )
 
+    # Handle case where tokenized query is empty, since the fast tokenizer doesn't do so
+    if len(truncated_query) == 0:
+        raise ValueError(
+            f"Input {truncated_query} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`."
+        )
+
     # Tokenizers who insert 2 SEP tokens in-between <context> & <question> need to have special handling
     # in the way they compute mask of added tokens.
     tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
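The double-SEP comment above refers to models such as RoBERTa, whose pair encoding looks like `<s> question </s></s> context </s>`. The count of special tokens a tokenizer adds around a pair can be checked directly; a quick sketch with illustrative checkpoints:

    # Sketch: how many special tokens get added around paired inputs.
    from transformers import BertTokenizer, RobertaTokenizer

    bert = BertTokenizer.from_pretrained("bert-base-uncased")
    roberta = RobertaTokenizer.from_pretrained("roberta-base")
    print(bert.num_special_tokens_to_add(pair=True))     # 3 -> [CLS] A [SEP] B [SEP]
    print(roberta.num_special_tokens_to_add(pair=True))  # 4 -> <s> A </s></s> B </s>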
@@ -146,12 +161,28 @@
 
         # Define the side we want to truncate / pad and the text/pair sorting
         if tokenizer.padding_side == "right":
-            texts = truncated_query
-            pairs = span_doc_tokens
+            texts = (
+                truncated_query
+                if not isinstance(tokenizer, PreTrainedTokenizerFast)
+                else tokenizer.decode(truncated_query)
+            )
+            # Needed because some tokenizers seem to produce actual tokens,
+            # while others produce token_ids for overflow tokens
+            if isinstance(span_doc_tokens[0], str):
+                pairs = " ".join(span_doc_tokens).replace(" ##", "").strip()
+            else:
+                pairs = span_doc_tokens
             truncation = TruncationStrategy.ONLY_SECOND.value
         else:
-            texts = span_doc_tokens
-            pairs = truncated_query
+            if isinstance(span_doc_tokens[0], str):
+                texts = " ".join(span_doc_tokens).replace(" ##", "").strip()
+            else:
+                texts = span_doc_tokens
+            pairs = (
+                truncated_query
+                if not isinstance(tokenizer, PreTrainedTokenizerFast)
+                else tokenizer.decode(truncated_query)
+            )
             truncation = TruncationStrategy.ONLY_FIRST.value
 
         encoded_dict = tokenizer.encode_plus(  # TODO(thom) update this logic
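The string branch above undoes WordPiece splitting so a fast tokenizer can re-encode the remaining span from plain text, while the id branch passes overflow token ids through untouched. A self-contained sketch of that detokenization, with made-up tokens:

    # Sketch: overflow tokens arrive as WordPiece strings on one path
    # and as token ids on the other; only the strings need re-joining.
    span_doc_tokens = ["the", "una", "##ffa", "##ble", "cat"]  # hypothetical overflow
    if isinstance(span_doc_tokens[0], str):
        pairs = " ".join(span_doc_tokens).replace(" ##", "").strip()
    else:
        pairs = span_doc_tokens  # already ids; pass through unchanged
    print(pairs)  # "the unaffable cat"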
@@ -165,6 +196,10 @@
             return_token_type_ids=True,
         )
 
+        # Handle case where fast tokenizer returns list[list[int]]
+        if isinstance(encoded_dict["input_ids"][0], list):
+            encoded_dict = {k: v[0] for k, v in encoded_dict.items()}
+
         paragraph_len = min(
             len(all_doc_tokens) - len(spans) * doc_stride,
             max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
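On the last squad.py hunk: with `return_overflowing_tokens=True` a fast tokenizer returns every field batched, one inner list per span, so the code keeps only the first element. A toy illustration with made-up ids:

    # Sketch: unwrap the extra batch dimension a fast tokenizer adds.
    encoded_dict = {"input_ids": [[101, 2054, 102]], "attention_mask": [[1, 1, 1]]}
    if isinstance(encoded_dict["input_ids"][0], list):
        encoded_dict = {k: v[0] for k, v in encoded_dict.items()}
    print(encoded_dict)  # {'input_ids': [101, 2054, 102], 'attention_mask': [1, 1, 1]}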
4 changes: 4 additions & 0 deletions tests/test_pipelines.py
@@ -672,6 +672,8 @@ def test_torch_question_answering(self):
         for model_name in QA_FINETUNED_MODELS:
             nlp = pipeline(task="question-answering", model=model_name, tokenizer=model_name)
             self._test_qa_pipeline(nlp)
+            nlp = pipeline(task="question-answering", model=model_name, tokenizer=(model_name, {"use_fast": True}))
+            self._test_qa_pipeline(nlp)
 
         # Uncomment when onnx model available
         # model_name = "deepset/bert-base-cased-squad2"
@@ -686,6 +688,8 @@ def test_tf_question_answering(self):
         for model_name in QA_FINETUNED_MODELS:
             nlp = pipeline(task="question-answering", model=model_name, tokenizer=model_name, framework="tf")
             self._test_qa_pipeline(nlp)
+            nlp = pipeline(task="question-answering", model=model_name, tokenizer=(model_name, {"use_fast": True}))
+            self._test_qa_pipeline(nlp)
 
 
 class NerPipelineTests(unittest.TestCase):
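The new test lines rely on `pipeline()` accepting a `(name, kwargs)` tuple for its `tokenizer` argument, with the kwargs forwarded to the tokenizer loader, so `use_fast=True` selects the Rust-backed tokenizer. A usage sketch; the checkpoint is an illustrative assumption, not one of QA_FINETUNED_MODELS from the test file:

    # Sketch: loading a QA pipeline with a fast tokenizer via the tuple form.
    from transformers import pipeline

    nlp = pipeline(
        task="question-answering",
        model="distilbert-base-cased-distilled-squad",  # illustrative checkpoint
        tokenizer=("distilbert-base-cased-distilled-squad", {"use_fast": True}),
    )
    print(nlp(question="Who wrote the report?", context="The report was written by Ada."))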