@@ -1004,26 +1004,39 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
10041004 Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding
10051005
10061006 Args:
1007- max_matching_ngram_size (`int` ):
1008- The maximum ngram size to be considered for matching in the prompt
1009- num_output_tokens (`int`):
1007+ eos_token_id (`torch.Tensor`, *optional*):
1008+ The token id of the end of sequence token.
1009+ num_output_tokens (`int`, *optional*, defaults to 10):
10101010 The number of tokens to be output as candidate tokens.
1011- max_length (`int`):
1012- The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt length.
1013- Defaults to 20, which is the max length used as default in generation config.
1011+ max_matching_ngram_size (`int`, *optional*, defaults to 2):
1012+ The maximum ngram size to be considered for matching in the prompt.
1013+ max_length (`int`, *optional*, defaults to 20):
1014+ The number of total maximum tokens that can be generated. For decoder-only models that includes the
1015+ prompt length. Defaults to 20, which is the max length used as default in generation config.
1016+ logits_processor (`LogitsProcessorList`, *optional*):
1017+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
1018+ used to modify the prediction scores of the language modeling head applied at each generation step. In
1019+ prompt lookup assisted generation, they are not used to manipulate probabilities, but rather to find
1020+ forbidden tokens (p = -inf) and block them from being valid candidates.
1021+ vocab_size (`int`, *optional*):
1022+ The size of the vocabulary. Required if `logits_processor` is provided.
10141023 """
10151024
10161025 def __init__ (
10171026 self ,
10181027 eos_token_id : Optional [torch .Tensor ] = None ,
10191028 num_output_tokens : int = 10 ,
1020- max_matching_ngram_size : Optional [ int ] = None ,
1029+ max_matching_ngram_size : int = 2 ,
10211030 max_length : int = 20 ,
1031+ logits_processor : Optional ["LogitsProcessorList" ] = None ,
1032+ vocab_size : Optional [int ] = None ,
10221033 ):
10231034 self .num_output_tokens = num_output_tokens
1024- self .max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
1035+ self .max_matching_ngram_size = max_matching_ngram_size
10251036 self .max_length = max_length
10261037 self .eos_token_id = eos_token_id
1038+ self .logits_processor = logits_processor
1039+ self .vocab_size = vocab_size
10271040
10281041 if self .max_matching_ngram_size <= 0 or self .num_output_tokens <= 0 :
10291042 raise ValueError ("Invalid max_matching_ngram_size or num_output_tokens" )
@@ -1039,7 +1052,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
10391052 Return:
10401053 `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
10411054 """
1042- input_length = input_ids .size ( 1 )
1055+ bsz , input_length = input_ids .shape
10431056
10441057 # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
10451058 if self .max_length == input_length + 1 :
@@ -1061,13 +1074,43 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
10611074 match_indices = matches .nonzero (as_tuple = True )[1 ]
10621075
10631076 # Iterate through match indices to find a valid continuation
1077+ # TODO (joao): this finds the first valid candidates (left to right), but perhaps we should find the
1078+ # longest valid candidates?
10641079 for idx in match_indices :
10651080 start_idx = idx + ngram_size
10661081 end_idx = start_idx + self .num_output_tokens
10671082 end_idx = min (end_idx , input_length , self .max_length )
10681083
10691084 if start_idx < end_idx :
10701085 chosen_ids = input_ids [0 , start_idx :end_idx ]
1086+
1087+ # Check if each new candidate token is forbidden according to the logits processor. If all
1088+ # tokens are allowed, we keep `chosen_ids` as is.
1089+ # 1. create random logits.
1090+ # 2. apply the logits processor to get output logits for the next token, using the arbitrary
1091+ # logits as input.
1092+ # 3. compare the output logits with the next candidate token. If they are -inf, then the next
1093+ # candidate token is forbidden and we don't want to generate it.
1094+ if self .logits_processor is not None :
1095+ sequence_with_candidate = input_ids
1096+ fake_input_logits = torch .ones (
1097+ (bsz , self .vocab_size ), device = input_ids .device , dtype = torch .float32
1098+ )
1099+ for candidate_idx , new_candidate_token in enumerate (chosen_ids ):
1100+ fake_output_logits = self .logits_processor (sequence_with_candidate , fake_input_logits )
1101+ fake_candidate_logits = fake_output_logits [0 , new_candidate_token ]
1102+ # next candidate token is forbidden -> crop chosen_ids accordingly
1103+ if fake_candidate_logits in (- float ("Inf" ), torch .finfo (fake_candidate_logits .dtype ).min ):
1104+ chosen_ids = chosen_ids [:candidate_idx ]
1105+ break
1106+ else :
1107+ sequence_with_candidate = torch .cat (
1108+ (input_ids , chosen_ids [: candidate_idx + 1 ].unsqueeze (0 )), dim = 1
1109+ )
1110+ # no valid candidate tokens -> look for a different match
1111+ if chosen_ids .shape [0 ] == 0 :
1112+ continue
1113+
10711114 match_found = True
10721115
10731116 # remove remaining candidate ids if an "eos" token is found, otherwise the target model may
@@ -1082,8 +1125,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor,
10821125 if match_found :
10831126 break
10841127
1085- if chosen_ids is None or len ( chosen_ids ) == 0 :
1086- # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
1128+ # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
1129+ if not match_found or len ( chosen_ids ) == 0 :
10871130 return input_ids , None
10881131
10891132 # Now need extend input_ids with chosen_ids
0 commit comments