@@ -79,10 +79,6 @@ def __init__(self):
                     "which could be very slow.")
                 self.forward = self.forward_native
             else:
-                logger.info(
-                    "Using approximate top-p optimized for TPU. Result may in "
-                    "theory differ from the exact algorithm if there are "
-                    "tokens with near-identical probabilities (< 1e-9 diff).")
                 self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
@@ -135,53 +131,35 @@ def apply_top_k_top_p_tpu(
     logits: torch.Tensor,
     k: torch.Tensor,
     p: torch.Tensor,
-) -> torch.Tensor:
-    if k is not None:
-        logits = apply_top_k_only(logits, k)
-
-    if p is not None:
-        logits = apply_approx_top_p(logits, p)
-
-    return logits
-
-
-def apply_approx_top_p(
-    logits: torch.Tensor,
-    p: torch.Tensor,
 ) -> torch.Tensor:
     """
-    Apply approximate top-p that is optimized for TPU.
+    Apply top-k and top-p optimized for TPU.
 
     This algorithm avoids using torch.scatter which is extremely slow on TPU.
     This is achieved by finding a "cut-off" element in the original logit, and
     after thresholding the logit using this cut-off, the remaining elements
     shall constitute the top-p set.
 
-    A caveat of the above approach is that ties are not correctly handled --
-    if there are duplicate cutoff elements present in the logit, then the
-    resulting top-p set will be incorrect. To address this problem, we
-    introduce a tiny perturbation to the probabilities (after softmax) to
-    break any potential ties. The added perturbation is tiny so it should
-    not alter the end results significantly, but it still makes this algorithm
-    approximate rather than an exact one.
+    Note: in the case of a tie (i.e. multiple cut-off elements present in the
+    logit), all tied elements are included in the top-p set. In other words,
+    this function does not break ties. Instead, the tied tokens have an equal
+    chance of being chosen during final sampling, so the tie is effectively
+    broken at that point.
167148 """
168- probs = logits .softmax (dim = - 1 )
169-
170- # Add a small, random perturbation to the probabilities, and re-normalize.
171- epsilon = torch .empty (probs .shape ,
172- device = logits .device ).uniform_ (- 1e-9 , 1e-9 )
173- probs += epsilon
174- probs /= probs .sum (dim = - 1 , keepdim = True )
175-
176- probs_sort , sorted_idx = probs .sort (dim = - 1 , descending = False )
177- cumprob = torch .cumsum (probs_sort , dim = - 1 )
178- top_p_mask = cumprob <= 1 - p .unsqueeze (dim = 1 )
179- top_p_mask [:, - 1 ] = False # at least one
180-
181- top_p_count = top_p_mask .sum (dim = - 1 ).unsqueeze (1 )
182- top_p_cutoff = probs_sort .gather (- 1 , top_p_count )
183- elements_to_discard = probs < top_p_cutoff
184- logits .masked_fill_ (elements_to_discard , - float ("inf" ))
149+ if k is not None :
150+ logits = apply_top_k_only (logits , k )
151+
152+ if p is not None :
153+ probs = logits .softmax (dim = - 1 )
154+ probs_sort , _ = probs .sort (dim = - 1 , descending = False )
155+ cumprob = torch .cumsum (probs_sort , dim = - 1 )
156+ top_p_mask = cumprob <= 1 - p .unsqueeze (dim = 1 )
157+ top_p_mask [:, - 1 ] = False # at least one
158+
159+ top_p_count = top_p_mask .sum (dim = - 1 ).unsqueeze (1 )
160+ top_p_cutoff = probs_sort .gather (- 1 , top_p_count )
161+ elements_to_discard = probs < top_p_cutoff
162+ logits .masked_fill_ (elements_to_discard , - float ("inf" ))
185163
186164 return logits
187165
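For readers who want to poke at the new masking logic without a TPU, below is a minimal, self-contained sketch of the cut-off based top-p step described in the updated docstring. The top-k half is omitted (apply_top_k_only lives elsewhere in the file), and the function name `topp_cutoff_sketch` plus the example tensors are illustrative assumptions, not part of the change itself.

```python
import torch


def topp_cutoff_sketch(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # Hypothetical standalone copy of the top-p portion of the diff above;
    # not the vLLM API. Top-k (apply_top_k_only) is intentionally left out.
    probs = logits.softmax(dim=-1)
    # Sort ascending: the cumulative sum is the probability mass that could be
    # dropped from the bottom while still keeping at least p at the top.
    probs_sort, _ = probs.sort(dim=-1, descending=False)
    cumprob = torch.cumsum(probs_sort, dim=-1)
    top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = False  # always keep at least one token
    # The first element outside the mask is the cut-off; only probabilities
    # strictly below it are discarded, so tied cut-off elements all survive.
    top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
    top_p_cutoff = probs_sort.gather(-1, top_p_count)
    return logits.masked_fill(probs < top_p_cutoff, -float("inf"))


logits = torch.tensor([[2.0, 2.0, 1.0, -1.0]])  # two tied top tokens
print(topp_cutoff_sketch(logits, p=torch.tensor([0.5])))
# tensor([[2., 2., -inf, -inf]]) -- both tied tokens survive the cut-off
```

With p=0.5, the two tied tokens are both kept rather than arbitrarily perturbed apart, which matches the docstring's note that ties are left to be resolved by the final sampling step.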