@@ -48,7 +48,12 @@ def _prepare_decoder_ids_for_generation(
4848 return torch .ones ((batch_size , 1 ), dtype = torch .long , device = device ) * pad_idx
4949
5050 def greedy_search (
51- self , input_ids : torch .Tensor , max_length : int , eos_idx : int , pad_idx : int , ** model_kwargs
51+ self ,
52+ input_ids : torch .Tensor ,
53+ max_length : int ,
54+ eos_idx : int ,
55+ pad_idx : int ,
56+ ** model_kwargs ,
5257 ) -> torch .Tensor :
5358 """Greedy search decoding for text generation. Takes the most likely next token every time.
5459
@@ -117,7 +122,6 @@ def generate(
117122 max_length (int): Max length to generate responses.
118123 pad_idx (int): Padding index. Defaults to 0.
119124 eos_idx (int): End of sequence index. Defaults to 1.
120-
121125 Returns:
122126 Tensor of Tensors containing output sequences as ids.
123127
@@ -138,8 +142,70 @@ def generate(
138142 max_length = DEFAULT_MAX_SEQ_LEN
139143
140144 if num_beams == 1 or num_beams is None :
141- return self .greedy_search (inputs , max_length , eos_idx , pad_idx = pad_idx , ** model_kwargs )
145+ return self .greedy_search (
146+ inputs ,
147+ max_length ,
148+ eos_idx ,
149+ pad_idx = pad_idx ,
150+ ** model_kwargs ,
151+ )
142152 elif num_beams > 1 :
143153 return self .beam_search (inputs , num_beams , max_length )
144154 else :
145155 raise ValueError ("`num_beams` must be >= 1." )
156+
157+ def _get_top_k_restriction (self , scores : torch .Tensor , top_k : int ) -> torch .Tensor :
158+ """Returns a copy of `scores` restricted to its k highest values (meaning every other value is zeroed)
159+
160+ Args:
161+ scores (Tensor): typically the output logits or probabilities for a language model's vocabulary.
162+ top_k (int): the number of highest values to keep.
163+
164+ Returns:
165+ A copy of `scores` restricted to its k highest values
166+ """
167+ top_k = min (top_k , scores .size (- 1 ))
168+ if top_k <= 0 :
169+ raise ValueError (f"`top_k` is { top_k } but should be an int greater than 0" )
170+ indices_to_remove = scores < torch .topk (scores , top_k )[0 ][:, - 1 , None ]
171+ return scores .masked_fill (indices_to_remove , 0 )
172+
173+ def _get_top_p_restriction (self , probs : torch .Tensor , top_p : float , min_tokens_to_keep : int = 1 ) -> torch .Tensor :
174+ """Returns a copy of `probs` restricted to the top indices whose values sum up to `top_p`
175+ (meaning the value at any other index is zeroed)
176+
177+ Args:
178+ probs (Tensor): output probabilities for a language model vocabulary.
179+ top_p (float): the (cumulative) threshold for cutting off top-value indices; between 0 and 1.
180+
181+ Returns:
182+ A copy of `probs` restricted to the top indices whose values sum up to `top_p`
183+ """
184+ if top_p < 0 or top_p > 1 :
185+ raise ValueError (f"`top_p` is { top_p } but should be a float between 0 and 1" )
186+ sorted_probs , sorted_indices = torch .sort (probs , descending = True , dim = - 1 )
187+ cumulative_probs = sorted_probs .cumsum (dim = - 1 )
188+
189+ sorted_indices_to_keep = cumulative_probs <= top_p
190+ sorted_indices_to_keep [:, :min_tokens_to_keep ] = True
191+ indices_to_remove = ~ sorted_indices_to_keep .scatter (- 1 , sorted_indices , sorted_indices_to_keep )
192+
193+ return probs .masked_fill (indices_to_remove , 0 )
194+
195+ def _apply_temperature (self , probs : torch .Tensor , temperature : float ) -> torch .Tensor :
196+ """Applies temperature scaling to `probs`
197+ Args:
198+ probs (Tensor): output probabilities for a language model vocabulary.
199+ temperature (float): value of temperature applied to the distribution.
200+ Returns:
201+ A copy of `probs` with applied `temperature`
202+ """
203+ if not temperature > 0 :
204+ raise ValueError (f"`temperature` is { temperature } but should be positive" )
205+ return probs / temperature
206+
207+ def _remove_invalid_values (self , scores : torch .Tensor ) -> torch .Tensor :
208+ """Removes nan and inf values to prevent generation from failing when using sampling"""
209+ scores [scores != scores ] = 0.0
210+ scores [scores == float ("inf" )] = torch .finfo (scores .dtype ).max
211+ return scores
0 commit comments