Optimizing away the fill-mask pipeline. #12113
Changes from all commits: 38143a3, f56b32a, 190b35d, 6e9dca7, 68729e5
@@ -98,9 +98,9 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
             args (:obj:`str` or :obj:`List[str]`):
                 One or several texts (or one list of prompts) with masked tokens.
             targets (:obj:`str` or :obj:`List[str]`, `optional`):
-                When passed, the model will return the scores for the passed token or tokens rather than the top k
-                predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will be
-                tokenized and the first resulting token will be used (with a warning).
+                When passed, the model will limit the scores to the passed targets instead of looking up in the whole
+                vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
+                resulting token will be used (with a warning, and that might be slower).
             top_k (:obj:`int`, `optional`):
                 When passed, overrides the number of predictions to return.
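For readers skimming the docstring change, here is an illustrative call showing the behaviour it describes. The checkpoint name and target strings are examples chosen here, not part of the diff:

```python
from transformers import pipeline

# Any masked-LM checkpoint works similarly; distilroberta-base is just an illustration.
unmasker = pipeline(task="fill-mask", model="distilroberta-base")

# Scores are limited to the given targets instead of being ranked over the whole vocabulary.
# With a byte-level BPE vocab, a leading space (" Paris") maps cleanly onto a single token.
outputs = unmasker(
    "The largest city in France is <mask>.",
    targets=[" Paris", " Lyon"],
    top_k=2,
)
for prediction in outputs:
    print(prediction["token_str"], prediction["score"])
```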
@@ -115,25 +115,56 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
         inputs = self._parse_and_tokenize(*args, **kwargs)
         outputs = self._forward(inputs, return_tensors=True)

+        # top_k must be defined
+        if top_k is None:
+            top_k = self.top_k
+
         results = []
         batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)

         if targets is not None:
             if len(targets) == 0 or len(targets[0]) == 0:
                 raise ValueError("At least one target must be provided when passed.")
             if isinstance(targets, str):
                 targets = [targets]

-            targets_proc = []
+            try:
+                vocab = self.tokenizer.get_vocab()
+            except Exception:
+                vocab = {}
+            target_ids = []
             for target in targets:
-                target_enc = self.tokenizer.tokenize(target)
-                if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
+                id_ = vocab.get(target, None)
Review comment: Here, instead of using […], I can't think of a simpler way that is tokenizer-agnostic.

Reply: If you still want to keep the speed of […]

Reply: That's another option, but I don't think we can beat the lookup. The […] There are other implementations which I tried that still speed up compared to master, but I had the impression that doing anything more than a raw lookup for the best-case scenario was wrong. That's actually what the doc says too: the […]
+                if id_ is None:
+                    input_ids = self.tokenizer(
+                        target,
+                        add_special_tokens=False,
+                        return_attention_mask=False,
+                        return_token_type_ids=False,
+                        max_length=1,
+                        truncation=True,
+                    )["input_ids"]
+                    if len(input_ids) == 0:
+                        logger.warning(
+                            f"The specified target token `{target}` does not exist in the model vocabulary. "
+                            f"We cannot replace it with anything meaningful, ignoring it"
+                        )
+                        continue
+                    id_ = input_ids[0]
+                    # XXX: If users encounter this pass
+                    # it becomes pretty slow, so let's make sure
+                    # The warning enables them to fix the input to
+                    # get faster performance.
                     logger.warning(
                         f"The specified target token `{target}` does not exist in the model vocabulary. "
-                        f"Replacing with `{target_enc[0]}`."
+                        f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
                     )
-                targets_proc.append(target_enc[0])
-            target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))
+                target_ids.append(id_)
+            target_ids = list(set(target_ids))
+            if len(target_ids) == 0:
+                raise ValueError("At least one target must be provided when passed.")
+            target_ids = np.array(target_ids)
+            # Cap top_k if there are targets
+            if top_k > target_ids.shape[0]:
+                top_k = target_ids.shape[0]
Review comment on lines +165 to +167: Nice

         for i in range(batch_size):
             input_ids = inputs["input_ids"][i]
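To make the lookup-first strategy from the thread above easier to read in isolation, here is a standalone sketch adapted from the diff. The helper name `resolve_target_ids` and the function boundary are editorial; the pipeline does this inline:

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)


def resolve_target_ids(tokenizer, targets, top_k):
    """Map target strings to vocab ids: O(1) dict lookup first, slow tokenization only as a fallback."""
    try:
        vocab = tokenizer.get_vocab()  # token string -> id
    except Exception:
        vocab = {}

    target_ids = []
    for target in targets:
        id_ = vocab.get(target)
        if id_ is None:
            # Slow path: run the tokenizer and keep only the first resulting token.
            ids = tokenizer(target, add_special_tokens=False, max_length=1, truncation=True)["input_ids"]
            if not ids:
                logger.warning("Target `%s` does not exist in the vocabulary; ignoring it.", target)
                continue
            id_ = ids[0]
            logger.warning("Target `%s` replaced with `%s` (slow path).", target, tokenizer.convert_ids_to_tokens(id_))
        target_ids.append(id_)

    target_ids = list(set(target_ids))  # deduplicate so top_k is not inflated by repeated targets
    if not target_ids:
        raise ValueError("At least one target must be provided when passed.")
    return np.array(target_ids), min(top_k, len(target_ids))
```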
@@ -147,14 +178,11 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):

                 logits = outputs[i, masked_index.item(), :]
                 probs = tf.nn.softmax(logits)
-                if targets is None:
-                    topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
-                    values, predictions = topk.values.numpy(), topk.indices.numpy()
-                else:
-                    values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
-                    sort_inds = tf.reverse(tf.argsort(values), [0])
-                    values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
-                    predictions = target_inds[sort_inds.numpy()]
+                if targets is not None:
+                    probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
+
+                topk = tf.math.top_k(probs, k=top_k)
+                values, predictions = topk.values.numpy(), topk.indices.numpy()
             else:
                 masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
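For anyone unfamiliar with `tf.gather_nd` fed a column of indices, here is a tiny toy example of the restrict-then-top-k pattern used in the TF branch above. The numbers are made up:

```python
import numpy as np
import tensorflow as tf

probs = tf.constant([0.05, 0.40, 0.10, 0.30, 0.15])  # pretend softmax over a 5-token vocab
target_ids = np.array([1, 3])                         # only score these vocab ids

# Gather the probabilities of the targets, then rank within that subset.
restricted = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
topk = tf.math.top_k(restricted, k=2)

# topk.indices are positions inside `target_ids`; indexing back recovers vocab ids.
print(target_ids[topk.indices.numpy()], topk.values.numpy())  # [1 3] [0.4 0.3] (approximately)
```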
@@ -163,13 +191,11 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):

                 logits = outputs[i, masked_index.item(), :]
                 probs = logits.softmax(dim=0)
-                if targets is None:
-                    values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
-                else:
-                    values = probs[..., target_inds]
-                    sort_inds = list(reversed(values.argsort(dim=-1)))
-                    values = values[..., sort_inds]
-                    predictions = target_inds[sort_inds]
+
+                if targets is not None:
+                    probs = probs[..., target_ids]
+
+                values, predictions = probs.topk(top_k)

             for v, p in zip(values.tolist(), predictions.tolist()):
                 tokens = input_ids.numpy()
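The PyTorch branch does the same restriction with fancy indexing. A matching toy example, using the same made-up numbers as the TF sketch above:

```python
import numpy as np
import torch

probs = torch.tensor([0.05, 0.40, 0.10, 0.30, 0.15])  # pretend softmax over a 5-token vocab
target_ids = np.array([1, 3])

# Indexing with the id array restricts the distribution; topk then ranks that subset.
restricted = probs[..., target_ids]
values, predictions = restricted.topk(2)

# `predictions` index into `target_ids`, so map back to recover vocab ids.
print(target_ids[predictions.numpy()], values.tolist())  # [1 3] [~0.4, ~0.3]
```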
Second changed file (the fill-mask pipeline tests):
@@ -78,7 +78,8 @@ def test_torch_fill_mask(self):
    @require_torch
    def test_torch_fill_mask_with_targets(self):
        valid_inputs = ["My name is <mask>"]
-        valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
+        # ' Sam' will yield a warning but work
+        valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
        invalid_targets = [[], [""], ""]
        for model_name in self.small_models:
            unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
@@ -89,10 +90,34 @@ def test_torch_fill_mask_with_targets(self):
            for targets in invalid_targets:
                self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)

+    @require_torch
+    def test_torch_fill_mask_with_targets_and_topk(self):
+        model_name = self.small_models[0]
+        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
+        targets = [" Teven", "ĠPatrick", "ĠClara"]
+        top_k = 2
+        outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
+
+        self.assertEqual(len(outputs), 2)
+
+    @require_torch
+    def test_torch_fill_mask_with_duplicate_targets_and_topk(self):
+        model_name = self.small_models[0]
+        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
+        # String duplicates + id duplicates
+        targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
+        top_k = 10
+        outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
+
+        # The target list contains duplicates, so we can't output more
+        # than them
+        self.assertEqual(len(outputs), 3)
+
    @require_tf
    def test_tf_fill_mask_with_targets(self):
        valid_inputs = ["My name is <mask>"]
-        valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
+        # ' Sam' will yield a warning but work
+        valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
        invalid_targets = [[], [""], ""]
        for model_name in self.small_models:
            unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf")
@@ -111,7 +136,7 @@ def test_torch_fill_mask_results(self):
            "My name is <mask>",
            "The largest city in France is <mask>",
        ]
-        valid_targets = [" Patrick", " Clara"]
+        valid_targets = ["ĠPatrick", "ĠClara"]
Review comment: Those tests could stay the same, but it's actually missing from the vocabulary, so it's hitting the slow path and triggering warnings. The two small models don't share the same vocabulary, so if we want tests that avoid the slow path all the time, we should have two versions of the targets IMO.
        for model_name in self.large_models:
            unmasker = pipeline(
                task="fill-mask",
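For readers wondering about the Ġ prefix: byte-level BPE vocabularies (GPT-2/RoBERTa style) store a leading space as the Ġ marker, so the vocab key is `ĠPatrick` rather than the raw string `" Patrick"`. A quick hedged check, assuming a RoBERTa-style checkpoint such as distilroberta-base:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")  # illustrative byte-level BPE checkpoint
vocab = tokenizer.get_vocab()

print("ĠPatrick" in vocab)  # expected True: the leading space is encoded as Ġ
print(" Patrick" in vocab)  # expected False: the raw spaced string is not a vocab key,
                            # so such a target falls back to the slow tokenization path and warns
```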
@@ -184,7 +209,7 @@ def test_tf_fill_mask_results(self):
            "My name is <mask>",
            "The largest city in France is <mask>",
        ]
-        valid_targets = [" Patrick", " Clara"]
+        valid_targets = ["ĠPatrick", "ĠClara"]
        for model_name in self.large_models:
            unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)
Review comment: Do you think we could have a better exception than this?

Reply: Depends if we want to be permissive or restrictive in the type of error we catch. Ideally there would not be a try/except, but I was not sure it was mandatory to have such a method. If we consider this could fail for valid reasons (not implemented), I don't really think we should make assumptions about the type of exception: NotImplementedError, AttributeError, but it could very well be a signature issue or something else. I think we should either keep it like this (very permissive in the types of errors) or remove the try altogether and be very restrictive (tokenizers MUST implement a valid `get_vocab`).

Reply: Alright, works for me!
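To make the two options weighed above concrete, a short sketch. The helper names are editorial, and only the first variant reflects what the diff actually does:

```python
def get_vocab_permissive(tokenizer):
    # Option kept in the PR: swallow any failure and fall back to the slow per-target tokenization path.
    try:
        return tokenizer.get_vocab()
    except Exception:
        return {}


def get_vocab_strict(tokenizer):
    # Alternative discussed but not adopted: require tokenizers to implement a working get_vocab().
    if not hasattr(tokenizer, "get_vocab"):
        raise ValueError(f"{tokenizer.__class__.__name__} must implement `get_vocab` to use `targets`.")
    return tokenizer.get_vocab()
```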