Commit b942cf1

Change tie-breaking behavior
* Do not break ties. Instead, include all tied tokens in the return set and leave tie breaking to the final sampling stage (since all tied tokens have an equal probability of being chosen).
* Removed the random perturbation.
* Removed the warning about the algorithm being approximate.
* Edited the tests.

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
1 parent a4bc75e commit b942cf1
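
The rationale for deferring tie breaking: tokens with identical logits receive identical probabilities after softmax, so the final multinomial sampling step already picks among them uniformly. A minimal standalone sketch (not part of this commit; the values are illustrative):

```python
import torch

# Three tied tokens plus one clearly smaller one.
logits = torch.tensor([[1.0, 1.0, 1.0, -5.0]])

# Keep every tied token in the filtered set and mask only the smaller one,
# mirroring the new behavior of returning all ties.
filtered = logits.masked_fill(logits < 1.0, float("-inf"))

# Final sampling stage: -inf logits get probability 0, each tie gets ~1/3.
probs = filtered.softmax(dim=-1)
samples = torch.multinomial(probs.squeeze(0), num_samples=30000, replacement=True)
print(torch.bincount(samples, minlength=4))  # counts for tokens 0-2 are roughly equal
```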

File tree

2 files changed: 27 additions, 46 deletions

tests/v1/tpu/test_topk_topp_sampler.py

Lines changed: 7 additions & 4 deletions

@@ -13,7 +13,7 @@
 
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
-TOLERANCE = 1e-4
+TOLERANCE = 1e-6
 
 
 def test_topp_result_sums_past_p():
@@ -89,9 +89,12 @@ def test_topp_with_ties():
                                    k=torch.tensor([4]),
                                    p=torch.tensor([0.2]))
 
-    # Expect math.log(0.3) to be the only selected element.
-    expected_result = torch.tensor([math.log(0.3)])
-    assert torch.allclose(expected_result, result[result.isfinite()])
+    # All tie values are included in the top-p set. Tie breaking is left
+    # to be done during final sampling (all tie tokens have equal
+    # probability of being chosen).
+    expected_result = logits.clone()
+    expected_result[0, 3] = float("-inf")
+    assert torch.allclose(expected_result, result)
 
 
 def test_both_topk_topp():
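
For intuition on why only index 3 is expected to be masked, here is a worked sketch of the cut-off computation used by apply_top_k_top_p_tpu. The probabilities [0.3, 0.3, 0.3, 0.1] are an assumed reconstruction of the test's input (the actual logits are defined earlier in the test and are not shown in this hunk):

```python
import torch

probs = torch.tensor([[0.3, 0.3, 0.3, 0.1]])  # assumed test distribution
p = torch.tensor([0.2])

probs_sort, _ = probs.sort(dim=-1, descending=False)  # [0.1, 0.3, 0.3, 0.3]
cumprob = torch.cumsum(probs_sort, dim=-1)            # [0.1, 0.4, 0.7, 1.0]
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)        # [True, True, True, False]
top_p_mask[:, -1] = False                             # always keep at least one token

top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)     # 3
top_p_cutoff = probs_sort.gather(-1, top_p_count)     # 0.3
print(probs < top_p_cutoff)                           # [[False, False, False, True]]
```

An exact top-p with p=0.2 would keep a single 0.3 token, but because all three 0.3 entries equal the cut-off, none of them falls below it: the whole tie survives and only the 0.1 token (index 3) is set to -inf, which is what the updated assertion checks.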

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 20 additions & 42 deletions

@@ -79,10 +79,6 @@ def __init__(self):
                     "which could be very slow.")
                 self.forward = self.forward_native
             else:
-                logger.info(
-                    "Using approximate top-p optimized for TPU. Result may in "
-                    "theory differ from the exact algorithm if there are "
-                    "tokens with near-identical probabilities (< 1e-9 diff).")
                 self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
@@ -135,53 +131,35 @@ def apply_top_k_top_p_tpu(
     logits: torch.Tensor,
     k: torch.Tensor,
     p: torch.Tensor,
-) -> torch.Tensor:
-    if k is not None:
-        logits = apply_top_k_only(logits, k)
-
-    if p is not None:
-        logits = apply_approx_top_p(logits, p)
-
-    return logits
-
-
-def apply_approx_top_p(
-    logits: torch.Tensor,
-    p: torch.Tensor,
 ) -> torch.Tensor:
     """
-    Apply approximate top-p that is optimized for TPU.
+    Apply top-k and top-p optimized for TPU.
 
     This algorithm avoids using torch.scatter which is extremely slow on TPU.
     This is achieved by finding a "cut-off" element in the original logit, and
     after thresholding the logit using this cut-off, the remaining elements
     shall constitute the top-p set.
 
-    A caveat of the above approach is that ties are not correctly handled --
-    if there are duplicate cutoff elements present in the logit, then the
-    resulting top-p set will be incorrect. To address this problem, we
-    introduce a tiny perturbation to the probabilities (after softmax) to
-    break any potential ties. The added perturbation is tiny so it should
-    not alter the end results significantly, but it still makes this algorithm
-    approximate rather than an exact one.
+    Note: in the case of a tie (i.e. multiple cut-off elements present in the
+    logit), all tied elements are included in the top-p set. In other words,
+    this function does not break ties. Instead, the tied tokens have an equal
+    chance of being chosen during final sampling, so the tie can be considered
+    broken at that stage.
     """
-    probs = logits.softmax(dim=-1)
-
-    # Add a small, random perturbation to the probabilities, and re-normalize.
-    epsilon = torch.empty(probs.shape,
-                          device=logits.device).uniform_(-1e-9, 1e-9)
-    probs += epsilon
-    probs /= probs.sum(dim=-1, keepdim=True)
-
-    probs_sort, sorted_idx = probs.sort(dim=-1, descending=False)
-    cumprob = torch.cumsum(probs_sort, dim=-1)
-    top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
-    top_p_mask[:, -1] = False  # at least one
-
-    top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
-    top_p_cutoff = probs_sort.gather(-1, top_p_count)
-    elements_to_discard = probs < top_p_cutoff
-    logits.masked_fill_(elements_to_discard, -float("inf"))
+    if k is not None:
+        logits = apply_top_k_only(logits, k)
+
+    if p is not None:
+        probs = logits.softmax(dim=-1)
+        probs_sort, _ = probs.sort(dim=-1, descending=False)
+        cumprob = torch.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # at least one
+
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
 
     return logits
 
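
A hypothetical usage sketch of the rewritten function, assuming the module path matches the file shown above and the signature from the diff (per-request k and p tensors); the values are illustrative and the ops are plain torch, so this also runs off-TPU:

```python
import torch

from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu

# Two requests over a four-token vocabulary; k and p are per-request tensors.
logits = torch.tensor([[2.0, 2.0, 1.0, -3.0],
                       [4.0, 1.0, 0.5, 0.0]])
k = torch.tensor([3, 3])
p = torch.tensor([0.5, 0.5])

filtered = apply_top_k_top_p_tpu(logits=logits, k=k, p=p)

# Discarded tokens are -inf, so they get probability 0 here; any tokens tied
# at the cut-off remain and are resolved uniformly by the final sampling step.
next_token = torch.multinomial(filtered.softmax(dim=-1), num_samples=1)
```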
