
Commit ed10021

[generate] PromptLookupCandidateGenerator won't generate forbidden tokens (#40726)
* no longer flaky :)
* PR comments
* any token-blocking logits processor works
* ?
* default
* -_-
* create fake tensors once
1 parent 82d66e5 commit ed10021
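In short: prompt lookup decoding could previously propose candidate tokens that a logits processor had already banned (for example image or audio placeholder tokens), which is why several multimodal models skipped the prompt-lookup test below. After this commit, any token-blocking logits processor also filters the candidates. A minimal usage sketch, assuming a tiny placeholder checkpoint and an arbitrary suppressed token id (neither comes from this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

# placeholder checkpoint; any causal LM works
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

inputs = tokenizer("the cat sat on the mat, the cat", return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=20,
    prompt_lookup_num_tokens=3,  # enables PromptLookupCandidateGenerator
    suppress_tokens=[0],  # token-blocking processor: id 0 is no longer proposed as a candidate either
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))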

File tree

7 files changed: +60 -58 lines changed

src/transformers/generation/candidate_generator.py

Lines changed: 54 additions & 11 deletions
@@ -1004,26 +1004,39 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
     Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding

     Args:
-        max_matching_ngram_size (`int`):
-            The maximum ngram size to be considered for matching in the prompt
-        num_output_tokens (`int`):
+        eos_token_id (`torch.Tensor`, *optional*):
+            The token id of the end of sequence token.
+        num_output_tokens (`int`, *optional*, defaults to 10):
             The number of tokens to be output as candidate tokens.
-        max_length (`int`):
-            The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
-            Defaults to 20, which is the max length used as default in generation config.
+        max_matching_ngram_size (`int`, *optional*, defaults to 2):
+            The maximum ngram size to be considered for matching in the prompt
+        max_length (`int`, *optional*, defaults to 20):
+            The number of total maximum tokens that can be generated. For decoder-only models that includes the
+            prompt length. Defaults to 20, which is the max length used as default in generation config.
+        logits_processor (`LogitsProcessorList`, *optional*):
+            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+            used to modify the prediction scores of the language modeling head applied at each generation step. In
+            prompt lookup assisted generation, they are not used to manipulate probabilities, but rather to find
+            forbidden tokens (p = -inf) and block them from being valid candidates.
+        vocab_size (`int`, *optional*):
+            The size of the vocabulary. Required if `logits_processor` is provided.
     """

     def __init__(
         self,
         eos_token_id: Optional[torch.Tensor] = None,
         num_output_tokens: int = 10,
-        max_matching_ngram_size: Optional[int] = None,
+        max_matching_ngram_size: int = 2,
         max_length: int = 20,
+        logits_processor: Optional["LogitsProcessorList"] = None,
+        vocab_size: Optional[int] = None,
     ):
         self.num_output_tokens = num_output_tokens
-        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
+        self.max_matching_ngram_size = max_matching_ngram_size
         self.max_length = max_length
         self.eos_token_id = eos_token_id
+        self.logits_processor = logits_processor
+        self.vocab_size = vocab_size

         if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
             raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
@@ -1039,7 +1052,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
         Return:
             `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
         """
-        input_length = input_ids.size(1)
+        bsz, input_length = input_ids.shape

         # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
         if self.max_length == input_length + 1:
@@ -1061,13 +1074,43 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
             match_indices = matches.nonzero(as_tuple=True)[1]

             # Iterate through match indices to find a valid continuation
+            # TODO (joao): this finds the first valid candidates (left to right), but perhaps we should find the
+            # longest valid candidates?
             for idx in match_indices:
                 start_idx = idx + ngram_size
                 end_idx = start_idx + self.num_output_tokens
                 end_idx = min(end_idx, input_length, self.max_length)

                 if start_idx < end_idx:
                     chosen_ids = input_ids[0, start_idx:end_idx]
+
+                    # Check if each new candidate token is forbidden according to the logits processor. If all
+                    # tokens are allowed, we keep `chosen_ids` as is.
+                    # 1. create random logits.
+                    # 2. apply the logits processor to get output logits for the next token, using the arbitrary
+                    #    logits as input.
+                    # 3. compare the output logits with the next candidate token. If they are -inf, then the next
+                    #    candidate token is forbidden and we don't want to generate it.
+                    if self.logits_processor is not None:
+                        sequence_with_candidate = input_ids
+                        fake_input_logits = torch.ones(
+                            (bsz, self.vocab_size), device=input_ids.device, dtype=torch.float32
+                        )
+                        for candidate_idx, new_candidate_token in enumerate(chosen_ids):
+                            fake_output_logits = self.logits_processor(sequence_with_candidate, fake_input_logits)
+                            fake_candidate_logits = fake_output_logits[0, new_candidate_token]
+                            # next candidate token is forbidden -> crop chosen_ids accordingly
+                            if fake_candidate_logits in (-float("Inf"), torch.finfo(fake_candidate_logits.dtype).min):
+                                chosen_ids = chosen_ids[:candidate_idx]
+                                break
+                            else:
+                                sequence_with_candidate = torch.cat(
+                                    (input_ids, chosen_ids[: candidate_idx + 1].unsqueeze(0)), dim=1
+                                )
+                        # no valid candidate tokens -> look for a different match
+                        if chosen_ids.shape[0] == 0:
+                            continue
+
                     match_found = True

                     # remove remaining candidate ids if an "eos" token is found, otherwise the target model may
@@ -1082,8 +1125,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
             if match_found:
                 break

-        if chosen_ids is None or len(chosen_ids) == 0:
-            # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
+        # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
+        if not match_found or len(chosen_ids) == 0:
             return input_ids, None

         # Now need extend input_ids with chosen_ids
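The new block in `get_candidates` is the heart of the change: instead of trusting the ngram match blindly, the generator feeds uniform fake logits through the logits processors and crops the candidate sequence at the first token they push to -inf (or to the dtype minimum). A standalone sketch of that check with toy values, using `SuppressTokensLogitsProcessor` to stand in for any token-blocking processor (the real code runs inside the class with its own `logits_processor` and `vocab_size`):

import torch
from transformers import LogitsProcessorList, SuppressTokensLogitsProcessor

vocab_size = 10
logits_processor = LogitsProcessorList([SuppressTokensLogitsProcessor([7])])  # token 7 is forbidden

input_ids = torch.tensor([[1, 2, 3]])  # prompt so far
chosen_ids = torch.tensor([4, 7, 5])   # candidate continuation found by ngram matching

sequence_with_candidate = input_ids
fake_input_logits = torch.ones((1, vocab_size), dtype=torch.float32)  # arbitrary, uniform logits
for candidate_idx, new_candidate_token in enumerate(chosen_ids):
    fake_output_logits = logits_processor(sequence_with_candidate, fake_input_logits)
    fake_candidate_logits = fake_output_logits[0, new_candidate_token]
    # a blocked token comes back as -inf (or the dtype minimum) after the processors run
    if fake_candidate_logits in (-float("inf"), torch.finfo(fake_candidate_logits.dtype).min):
        chosen_ids = chosen_ids[:candidate_idx]  # crop at the first forbidden token
        break
    sequence_with_candidate = torch.cat((input_ids, chosen_ids[: candidate_idx + 1].unsqueeze(0)), dim=1)

print(chosen_ids)  # tensor([4]): token 7 was blocked, so only the first candidate survives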

src/transformers/generation/utils.py

Lines changed: 3 additions & 1 deletion
@@ -1036,8 +1036,10 @@ def _get_candidate_generator(
             candidate_generator = PromptLookupCandidateGenerator(
                 eos_token_id=generation_config._eos_token_tensor,
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
-                max_matching_ngram_size=generation_config.max_matching_ngram_size,
+                max_matching_ngram_size=generation_config.max_matching_ngram_size or 2,
                 max_length=generation_config.max_length,
+                logits_processor=logits_processor,
+                vocab_size=self.config.get_text_config().vocab_size,
             )
         elif different_tokenizers:
             if generation_config.do_sample is True:
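`_get_candidate_generator` now forwards the already-built `logits_processor` plus the text config's vocabulary size. A hedged sketch of the equivalent manual wiring (the checkpoint and the suppressed token id are placeholders, not part of the commit):

from transformers import AutoConfig, LogitsProcessorList, SuppressTokensLogitsProcessor
from transformers.generation.candidate_generator import PromptLookupCandidateGenerator

config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-gpt2")  # placeholder checkpoint
logits_processor = LogitsProcessorList([SuppressTokensLogitsProcessor([0])])

candidate_generator = PromptLookupCandidateGenerator(
    eos_token_id=None,
    num_output_tokens=10,
    max_matching_ngram_size=2,
    max_length=20,
    logits_processor=logits_processor,  # only consulted to detect forbidden (-inf) tokens
    vocab_size=config.get_text_config().vocab_size,  # required whenever logits_processor is passed
)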

tests/generation/test_utils.py

Lines changed: 3 additions & 23 deletions
@@ -779,27 +779,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
                     "blip2",  # overridden `generate()` for all BLIP models
                     "instructblip",
                     "instructblipvideo",
-                    # TODO: The list is growing huge 🙃! Let's try to check if the config has any of audio/image/video token id and skip the test!
-                    # All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
-                    "llava",
-                    "idefics2",
-                    "idefics3",
-                    "mllama",
-                    "paligemma",
-                    "emu3",
-                    "gotocr2",
-                    "qwen2vl",
-                    "qwen2_5_vl",
-                    "ayavision",
-                    "janus",
-                    "gemma3",
-                    "mistral3",
-                    "chameleon",
-                    "internvl",
-                    "qwen2_5omni",  # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`,
-                    # All models below: shouldn't suggest audio tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
-                    "voxtral",
-                    "qwen2audio",
                 ]
             ):
                 self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -835,11 +814,12 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
                 "return_dict_in_generate": True,
                 "use_cache": True,
             }
+            logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)

-            output_greedy = model.generate(**generation_kwargs, **inputs_dict)
+            output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)

             generation_kwargs.update({"prompt_lookup_num_tokens": 2})  # see b)
-            output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
+            output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)

             # The two outputs must match and their shape must be as expected
             self.assertTrue(has_similar_generate_outputs(output_greedy, output_prompt_lookup))
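The test now builds the default logits-processor kwargs once and passes them to both the greedy and the prompt-lookup run, so the candidate filtering above is actually exercised. A compact sketch of the property being asserted, with a placeholder checkpoint and arbitrary banned token ids:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")  # placeholder
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
inputs = tokenizer("the cat sat on the mat, the cat", return_tensors="pt")

common = {"do_sample": False, "max_new_tokens": 10, "bad_words_ids": [[3], [7]]}  # arbitrary banned ids
output_greedy = model.generate(**inputs, **common)
output_prompt_lookup = model.generate(**inputs, **common, prompt_lookup_num_tokens=2)
assert torch.equal(output_greedy, output_prompt_lookup)  # same tokens with or without prompt lookup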

tests/models/idefics2/test_modeling_idefics2.py

Lines changed: 0 additions & 6 deletions
@@ -390,12 +390,6 @@ def test_flash_attn_2_generate_padding_right(self):
     def test_flash_attn_2_inference_padding_right(self):
         pass

-    @unittest.skip(
-        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
-    )
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        pass
-
     @pytest.mark.generate
     @slow
     @unittest.skip(

tests/models/idefics3/test_modeling_idefics3.py

Lines changed: 0 additions & 6 deletions
@@ -351,12 +351,6 @@ def test_inputs_embeds():
     def test_flash_attn_2_inference_padding_right(self):
         pass

-    @unittest.skip(
-        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
-    )
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        pass
-
     @pytest.mark.generate
     @slow
     @unittest.skip(

tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py

Lines changed: 0 additions & 5 deletions
@@ -30,7 +30,6 @@
 from transformers.testing_utils import (
     Expectations,
     cleanup,
-    is_flaky,
     require_cv2,
     require_flash_attn,
     require_torch,
@@ -446,10 +445,6 @@ def test_multi_gpu_data_parallel_forward(self):
     def test_model_is_small(self):
         pass

-    @is_flaky()  # TODO (joao/raushan): Investigate why this test is flaky on this model
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        super().test_prompt_lookup_decoding_matches_greedy_search()
-

 @require_torch
 class Qwen2_5_VLIntegrationTest(unittest.TestCase):

tests/models/smolvlm/test_modeling_smolvlm.py

Lines changed: 0 additions & 6 deletions
@@ -345,12 +345,6 @@ def setUp(self):
     def test_flash_attn_2_inference_padding_right(self):
         pass

-    @unittest.skip(
-        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
-    )
-    def test_prompt_lookup_decoding_matches_greedy_search(self):
-        pass
-
     @pytest.mark.generate
     @is_flaky(description="TODO: check why flaky")
     def test_generate_methods_with_logits_to_keep(self):
