@@ -3088,68 +3088,6 @@ def contrastive_search_body_fn(
             return generated


-def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
-    """
-    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-
-    Args:
-        logits: logits distribution shape (batch size, vocabulary size)
-        top_k (`int`, *optional*, defaults to 0):
-            If > 0, only keep the top k tokens with highest probability (top-k filtering)
-        top_p (`float`, *optional*, defaults to 1.0):
-            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
-            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-        min_tokens_to_keep (`int`, *optional*, defaults to 1):
-            Minimum number of tokens we keep per batch example in the output.
-
-    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-
-    warnings.warn(
-        "`tf_top_k_top_p_filtering` is scheduled for deletion in v4.39. Use `TFTopKLogitsWarper` and "
-        "`TFTopPLogitsWarper` instead.",
-        DeprecationWarning,
-    )
-
-    logits_shape = shape_list(logits)
-
-    if top_k > 0:
-        top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1])  # Safety check
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
-        logits = tf.where(indices_to_remove, filter_value, logits)
-    if top_p < 1.0:
-        sorted_indices = tf.argsort(logits, direction="DESCENDING")
-        sorted_logits = tf.gather(
-            logits, sorted_indices, axis=-1, batch_dims=1
-        )  # expects logits to be of dim (batch_size, vocab_size)
-
-        cumulative_probs = tf.math.cumsum(stable_softmax(sorted_logits, axis=-1), axis=-1)
-
-        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs > top_p
-
-        if min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
-            sorted_indices_to_remove = tf.concat(
-                [
-                    tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
-                    sorted_indices_to_remove[:, min_tokens_to_keep:],
-                ],
-                -1,
-            )
-
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove = tf.concat(
-            [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, :-1]],
-            -1,
-        )
-        # scatter sorted tensors to original indexing
-        indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
-        logits = tf.where(indices_to_remove, filter_value, logits)
-    return logits
-
-
 def scatter_values_on_batch_indices(values, batch_indices):
     shape = shape_list(batch_indices)
     # broadcast batch dim to shape
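
The deprecation message in the removed helper points callers to `TFTopKLogitsWarper` and `TFTopPLogitsWarper`. For anyone migrating a standalone call to `tf_top_k_top_p_filtering`, the sketch below shows roughly equivalent filtering with those warpers; the `(input_ids, scores, cur_len)` call convention and the dummy tensors are assumptions made for illustration, not part of this diff.

```python
import tensorflow as tf

from transformers import TFLogitsProcessorList, TFTopKLogitsWarper, TFTopPLogitsWarper

batch_size, vocab_size = 2, 50257
logits = tf.random.normal((batch_size, vocab_size))    # stand-in for next-token scores
input_ids = tf.zeros((batch_size, 5), dtype=tf.int32)  # dummy prompt, only its shape is used here
cur_len = input_ids.shape[-1]

# Chain the two warpers, mirroring the old helper's top_k=... plus top_p=... behavior.
warpers = TFLogitsProcessorList(
    [
        TFTopKLogitsWarper(top_k=50, filter_value=-float("Inf"), min_tokens_to_keep=1),
        TFTopPLogitsWarper(top_p=0.9, filter_value=-float("Inf"), min_tokens_to_keep=1),
    ]
)

# Scores outside the top-k / nucleus set are pushed to -inf; the shape is unchanged.
filtered_logits = warpers(input_ids, logits, cur_len)
next_tokens = tf.random.categorical(filtered_logits, num_samples=1)
```

In practice most callers never touch the warpers directly: passing `do_sample=True`, `top_k=...`, and `top_p=...` to `generate()` builds an equivalent chain internally.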