Skip to content

Commit ac3dbfa

Browse files
[V1][spec decode] return logprobs for spec decoding
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
1 parent f6b3bcb commit ac3dbfa

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

vllm/v1/sample/rejection_sampler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ class RejectionSampler(nn.Module):
5151
def __init__(self, sampler: Sampler):
5252
super().__init__()
5353
self.sampler = sampler
54-
self.return_processed_logprobs = self.sampler.logprobs_mode.startswith(
55-
"processed"
56-
)
54+
logprobs_mode = self.sampler.logprobs_mode
55+
self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
56+
self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
5757

5858
def forward(
5959
self,
@@ -107,7 +107,7 @@ def forward(
107107
# Override the logprobs mode to return logits because they are
108108
# needed later to compute the accepted token logprobs.
109109
logprobs_mode_override="processed_logits"
110-
if self.return_processed_logprobs
110+
if self.is_processed_logprobs_mode
111111
else "raw_logits",
112112
)
113113
bonus_logits = bonus_sampler_output.logprobs_tensors.logprobs
@@ -150,7 +150,7 @@ def forward(
150150
sampling_metadata,
151151
metadata,
152152
logits,
153-
target_logits if self.return_processed_logprobs else raw_target_logits,
153+
target_logits if self.is_processed_logprobs_mode else raw_target_logits,
154154
bonus_logits,
155155
output_token_ids,
156156
),
@@ -190,7 +190,7 @@ def _get_logprobs_tensors(
190190
accepted_logits = final_logits[accepted_logit_indices]
191191
accepted_logprobs = (
192192
accepted_logits
193-
if self.logprobs_mode.endswith("logits")
193+
if self.is_logits_logprobs_mode
194194
else self.sampler.compute_logprobs(accepted_logits)
195195
)
196196
accepted_tokens = sampled_token_ids[accepted_mask]

0 commit comments

Comments
 (0)