Optimizing away the fill-mask pipeline. #12113

Merged · 5 commits · Jun 23, 2021
src/transformers/pipelines/fill_mask.py — 78 changes: 52 additions & 26 deletions

@@ -98,9 +98,9 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
             args (:obj:`str` or :obj:`List[str]`):
                 One or several texts (or one list of prompts) with masked tokens.
             targets (:obj:`str` or :obj:`List[str]`, `optional`):
-                When passed, the model will return the scores for the passed token or tokens rather than the top k
-                predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will be
-                tokenized and the first resulting token will be used (with a warning).
+                When passed, the model will limit the scores to the passed targets instead of looking up in the whole
+                vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
+                resulting token will be used (with a warning, and that might be slower).
             top_k (:obj:`int`, `optional`):
                 When passed, overrides the number of predictions to return.
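To make the documented behavior concrete, here is a toy, framework-free sketch (hand-rolled softmax over an invented four-word vocabulary — not the actual pipeline code): restricting predictions to `targets` amounts to letting only the target ids compete for the top-k slots, with `top_k` capped at the number of targets.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a plain list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_predictions(probs, top_k, target_ids=None):
    # Without targets, every vocabulary id competes; with targets,
    # only the target ids do, and top_k is capped to their count.
    ids = range(len(probs)) if target_ids is None else target_ids
    ranked = sorted(ids, key=lambda i: probs[i], reverse=True)
    return ranked[: min(top_k, len(ranked))]

# Toy mask-position logits over the vocabulary ["Paris", "Lyon", "Rome", "Berlin"].
probs = softmax([4.0, 2.0, 1.0, 3.0])

print(top_k_predictions(probs, top_k=2))                     # → [0, 3]
print(top_k_predictions(probs, top_k=5, target_ids=[1, 2]))  # → [1, 2]
```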

@@ -115,25 +115,56 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
         inputs = self._parse_and_tokenize(*args, **kwargs)
         outputs = self._forward(inputs, return_tensors=True)

+        # top_k must be defined
+        if top_k is None:
+            top_k = self.top_k

         results = []
         batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)

         if targets is not None:
-            if len(targets) == 0 or len(targets[0]) == 0:
-                raise ValueError("At least one target must be provided when passed.")
             if isinstance(targets, str):
                 targets = [targets]

-            targets_proc = []
+            try:
+                vocab = self.tokenizer.get_vocab()
+            except Exception:
[Review — Member]
Do you think we could have a better exception than this?

[Review — Contributor Author]
It depends on whether we want to be permissive or restrictive in the type of error we catch. Ideally there would be no try/except, but I was not sure it was mandatory for tokenizers to implement such a method.

If we consider that this could fail for valid reasons (not implemented), I don't think we should make assumptions about the exception type: NotImplementedError or AttributeError, but it could just as well be a signature issue or something else.

I think we should either keep it like this (very permissive in the types of errors caught) or remove the try altogether and be very restrictive (tokenizers MUST implement a valid get_vocab).

[Review — Member]
Alright, works for me!

+                vocab = {}
+            target_ids = []
             for target in targets:
-                target_enc = self.tokenizer.tokenize(target)
-                if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
+                id_ = vocab.get(target, None)
[Review — Member]
Here, instead of using vocab.get, which isn't available on all tokenizers, you could encode the target without special tokens and ensure it returns a list with a single value that is not an unknown token. I can't think of a simpler way that is tokenizer-agnostic.

[Review — Member]
If you still want to keep the speed of vocab.get, maybe putting it behind a check such as `if hasattr(self.tokenizer, "vocab")` would help.

[Review — Contributor Author]
That's another option, but I don't think we can beat the lookup.

The tokenizer call is actually quite slow here (compared to a Python dict lookup), especially when called sequentially. When called batched, the cost was reduced heavily but was still non-negligible (something like 30 ms for 10k targets). I would have to dive in to make sure, but I think it's linked to the strings, which all have to be copied and parsed at the Python/Rust interface.

There are other implementations I tried that still speed things up compared to master, but I had the impression that doing anything more than a raw lookup in the best-case scenario was wrong. That's actually what the doc says too: the targets should be in the vocab.
+                if id_ is None:
+                    input_ids = self.tokenizer(
+                        target,
+                        add_special_tokens=False,
+                        return_attention_mask=False,
+                        return_token_type_ids=False,
+                        max_length=1,
+                        truncation=True,
+                    )["input_ids"]
+                    if len(input_ids) == 0:
+                        logger.warning(
+                            f"The specified target token `{target}` does not exist in the model vocabulary. "
+                            f"We cannot replace it with anything meaningful, ignoring it"
+                        )
+                        continue
+                    id_ = input_ids[0]
+                    # XXX: If users encounter this pass
+                    # it becomes pretty slow, so let's make sure
+                    # The warning enables them to fix the input to
+                    # get faster performance.
                     logger.warning(
                         f"The specified target token `{target}` does not exist in the model vocabulary. "
-                        f"Replacing with `{target_enc[0]}`."
+                        f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
                     )
-                targets_proc.append(target_enc[0])
-            target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))
+                target_ids.append(id_)
+            target_ids = list(set(target_ids))
+            if len(target_ids) == 0:
+                raise ValueError("At least one target must be provided when passed.")
+            target_ids = np.array(target_ids)
+            # Cap top_k if there are targets
+            if top_k > target_ids.shape[0]:
+                top_k = target_ids.shape[0]
[Review — Member, on lines +165 to +167]
Nice

         for i in range(batch_size):
             input_ids = inputs["input_ids"][i]

@@ -147,14 +178,11 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):

                 logits = outputs[i, masked_index.item(), :]
                 probs = tf.nn.softmax(logits)
-                if targets is None:
-                    topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
-                    values, predictions = topk.values.numpy(), topk.indices.numpy()
-                else:
-                    values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
-                    sort_inds = tf.reverse(tf.argsort(values), [0])
-                    values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
-                    predictions = target_inds[sort_inds.numpy()]
+                if targets is not None:
+                    probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
+
+                topk = tf.math.top_k(probs, k=top_k)
+                values, predictions = topk.values.numpy(), topk.indices.numpy()
             else:
                 masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)

@@ -163,13 +191,11 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):

                 logits = outputs[i, masked_index.item(), :]
                 probs = logits.softmax(dim=0)
-                if targets is None:
-                    values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
-                else:
-                    values = probs[..., target_inds]
-                    sort_inds = list(reversed(values.argsort(dim=-1)))
-                    values = values[..., sort_inds]
-                    predictions = target_inds[sort_inds]
+
+                if targets is not None:
+                    probs = probs[..., target_ids]
+
+                values, predictions = probs.topk(top_k)

             for v, p in zip(values.tolist(), predictions.tolist()):
                 tokens = input_ids.numpy()
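The target-resolution logic in the diff above can be sketched without a real tokenizer. Everything here is a stand-in: `vocab` is a toy dict, and `tokenize_to_ids` is an assumed fallback callable playing the role of the `self.tokenizer(..., add_special_tokens=False, max_length=1)` slow path.

```python
import logging

logger = logging.getLogger(__name__)

def resolve_target_ids(targets, vocab, tokenize_to_ids):
    """Map target strings to vocabulary ids, mirroring the PR's two paths:
    a fast dict lookup, then a slower tokenize-and-take-first-id fallback."""
    target_ids = []
    for target in targets:
        id_ = vocab.get(target)  # fast path: plain dict lookup
        if id_ is None:
            ids = tokenize_to_ids(target)  # slow path (stand-in)
            if not ids:
                logger.warning("Target `%s` is not in the vocabulary, ignoring it", target)
                continue
            id_ = ids[0]  # keep only the first resulting token, as the docstring says
        target_ids.append(id_)
    target_ids = list(set(target_ids))  # duplicate targets collapse to one id
    if not target_ids:
        raise ValueError("At least one target must be provided when passed.")
    return target_ids

toy_vocab = {"ĠParis": 7, "ĠLyon": 9}
ids = resolve_target_ids(
    [" Paris", "ĠLyon", "ĠLyon"],
    toy_vocab,
    tokenize_to_ids=lambda s: [7] if "Paris" in s else [],
)
print(sorted(ids))  # → [7, 9]
```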
tests/test_pipelines_fill_mask.py — 33 changes: 29 additions & 4 deletions

@@ -78,7 +78,8 @@ def test_torch_fill_mask(self):
     @require_torch
     def test_torch_fill_mask_with_targets(self):
         valid_inputs = ["My name is <mask>"]
-        valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
+        # ' Sam' will yield a warning but work
+        valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
         invalid_targets = [[], [""], ""]
         for model_name in self.small_models:
             unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")

@@ -89,10 +90,34 @@ def test_torch_fill_mask_with_targets(self):
             for targets in invalid_targets:
                 self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)

+    @require_torch
+    def test_torch_fill_mask_with_targets_and_topk(self):
+        model_name = self.small_models[0]
+        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
+        targets = [" Teven", "ĠPatrick", "ĠClara"]
+        top_k = 2
+        outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
+
+        self.assertEqual(len(outputs), 2)
+
+    @require_torch
+    def test_torch_fill_mask_with_duplicate_targets_and_topk(self):
+        model_name = self.small_models[0]
+        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
+        # String duplicates + id duplicates
+        targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
+        top_k = 10
+        outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
+
+        # The target list contains duplicates, so we can't output more
+        # than them
+        self.assertEqual(len(outputs), 3)
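The duplicate test above works because distinct target strings can resolve to the same token id: with a GPT-2-style byte-level vocabulary, a leading space is stored as the `Ġ` prefix, so `' Clara'` and `'ĠClara'` land on one id. A toy sketch of that collapse (the normalization rule and vocabulary are illustrative assumptions, not a real tokenizer):

```python
def normalize(target):
    # Illustrative GPT-2-style convention: a leading space becomes 'Ġ'.
    return "Ġ" + target[1:] if target.startswith(" ") else target

toy_vocab = {"ĠTeven": 0, "ĠPatrick": 1, "ĠClara": 2}
targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]

# Five target strings, but only three distinct ids survive deduplication,
# so the pipeline can return at most three predictions even with top_k=10.
unique_ids = {toy_vocab[normalize(t)] for t in targets}
print(len(unique_ids))  # → 3
```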

     @require_tf
     def test_tf_fill_mask_with_targets(self):
         valid_inputs = ["My name is <mask>"]
-        valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
+        # ' Sam' will yield a warning but work
+        valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
         invalid_targets = [[], [""], ""]
         for model_name in self.small_models:
             unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf")
@@ -111,7 +136,7 @@ def test_torch_fill_mask_results(self):
             "My name is <mask>",
             "The largest city in France is <mask>",
         ]
-        valid_targets = [" Patrick", " Clara"]
+        valid_targets = ["ĠPatrick", "ĠClara"]
[Review — Contributor Author]
Those tests could stay the same, but the target is actually missing from the vocabulary, so it's hitting the slow path and triggering warnings.

The two small models don't share the same vocabulary, so if we want tests that avoid the slow path all the time, we should have two versions of the targets IMO.

         for model_name in self.large_models:
             unmasker = pipeline(
                 task="fill-mask",
@@ -184,7 +209,7 @@ def test_tf_fill_mask_results(self):
             "My name is <mask>",
             "The largest city in France is <mask>",
         ]
-        valid_targets = [" Patrick", " Clara"]
+        valid_targets = ["ĠPatrick", "ĠClara"]
         for model_name in self.large_models:
             unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)