Skip to content

Commit 6a854c7

Browse files
authored
[V1][Sampler] Don't apply temp for greedy-only (#13311)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent e7eea5a commit 6a854c7

File tree

1 file changed

+15 -9 lines changed

vllm/v1/sample/sampler.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ def forward(
         logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
-        # Apply temperature.
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
         # Sample the next token.
         sampled = self.sample(logits, sampling_metadata)
 

@@ -82,9 +80,21 @@ def sample(
     ) -> torch.Tensor:
         assert not (sampling_metadata.all_greedy
                     and sampling_metadata.all_random)
-        if sampling_metadata.all_greedy:
-            return self.greedy_sample(logits)
+        if sampling_metadata.all_random:
+            greedy_sampled = None
+        else:
+            greedy_sampled = self.greedy_sample(logits)
+            if sampling_metadata.all_greedy:
+                return greedy_sampled
 
+        # Apply temperature.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+
+        # Apply min_p.
+        if not sampling_metadata.no_min_p:
+            logits = self.apply_min_p(logits, sampling_metadata.min_p)
+
+        # Apply top_k and/or top_p.
         random_sampled = self.topk_topp_sampler(
             logits,
             sampling_metadata.generators,
@@ -94,13 +104,9 @@ def sample(
             sampling_metadata.top_p,
         )
 
-        if not sampling_metadata.no_min_p:
-            logits = self.apply_min_p(logits, sampling_metadata.min_p)
-
-        if sampling_metadata.all_random:
+        if greedy_sampled is None:
             return random_sampled
 
-        greedy_sampled = self.greedy_sample(logits)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,

0 commit comments

Comments
 (0)