Commit cfb690d

LiuXiaoxuanPKU authored and Akshat-Tripathi committed
[V1][Spec Decode] Change Spec Decode Rejection Sampling API (vllm-project#13729)
1 parent 1b1e51d commit cfb690d
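
Summary of the change: spec_token_ids is removed from SamplingMetadata (and from the test fixtures and InputBatch code that populated it). RejectionSampler now receives the speculative draft tokens and the target distribution explicitly, via forward(draft_token_ids: List[List[int]], target_probs: torch.Tensor, sampling_metadata). Sampler no longer embeds a RejectionSampler; instead it gains a compute_probs() helper that returns the raw logits for all-greedy batches and a temperature-scaled softmax otherwise, so callers can hand those probabilities straight to the rejection sampler.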

File tree: 8 files changed, +104 -111 lines changed


tests/v1/sample/test_rejection_sampler.py

Lines changed: 8 additions & 9 deletions
@@ -29,7 +29,6 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
         temperature=torch.tensor([]),
         all_greedy=True,
         all_random=False,
-        spec_token_ids=spec_tokens,
         top_p=None,
         top_k=None,
         min_p=torch.empty(batch_size, ),
@@ -55,7 +54,7 @@ def test_perfect_match(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)

-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 2, 3, 4]],
                             dtype=torch.int,
                             device=logits.device)
@@ -70,7 +69,7 @@ def test_early_mismatch(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)

-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
                             dtype=torch.int,
                             device=logits.device)
@@ -85,7 +84,7 @@ def test_multiple_sequences(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)

-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
                             dtype=torch.int,
                             device=logits.device)
@@ -100,7 +99,7 @@ def test_single_token_sequence(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)

-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
     assert torch.equal(output.sampled_token_ids, expected)

@@ -113,7 +112,7 @@ def test_empty_sequence(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)

-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
     assert torch.equal(output.sampled_token_ids, expected)

@@ -126,7 +125,7 @@ def test_multiple_mismatches(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)

-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
                              [4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
                             dtype=torch.int,
@@ -147,7 +146,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens)

-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected_tensor = torch.tensor(expected,
                                    dtype=torch.int,
                                    device=logits.device)
@@ -163,7 +162,7 @@ def test_logits_shape_handling(sampler):
     metadata = create_sampling_metadata(spec_tokens)
     logits = create_logits_tensor(output_tokens, vocab_size)

-    output = sampler(logits, metadata)
+    output = sampler(spec_tokens, logits, metadata)
     expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device)
     assert torch.equal(output.sampled_token_ids, expected)
     assert logits.shape[-1] == vocab_size
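
The tests above all exercise the same greedy accept-until-first-mismatch contract: target tokens are emitted up to and including the first position where the draft disagrees with the greedy target, and the remaining slots are padded with INVALID_TOKEN_ID. A standalone sketch of that contract in plain Python (illustrative only, not the vLLM sampler; INVALID_TOKEN_ID = -1 and the concrete draft/target values are assumptions chosen to mirror the expected tensors above):

from typing import List

INVALID_TOKEN_ID = -1  # assumed padding value


def greedy_rejection(draft: List[int], target: List[int]) -> List[int]:
    """Accept draft tokens while they match the greedy target tokens,
    emit the target token at the first mismatch, then pad the rest."""
    out: List[int] = []
    for i, token in enumerate(target[:-1]):
        out.append(token)
        if draft[i] != token:
            break
    else:
        out.append(target[-1])  # every draft token accepted -> keep bonus token
    return out + [INVALID_TOKEN_ID] * (len(target) - len(out))


print(greedy_rejection([1, 2, 3], [1, 2, 3, 4]))  # [1, 2, 3, 4]   (perfect match)
print(greedy_rejection([1, 2, 3], [1, 5, 3, 4]))  # [1, 5, -1, -1] (early mismatch)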

tests/v1/sample/test_sampler.py

Lines changed: 0 additions & 1 deletion
@@ -105,7 +105,6 @@ def _create_default_sampling_metadata(
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                       vocab_size, device),
         output_token_ids=output_token_ids,
-        spec_token_ids=None,
         frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 0 additions & 1 deletion
@@ -123,7 +123,6 @@ def _construct_expected_sampling_metadata(
                                  dtype=torch.float,
                                  device=device),
         output_token_ids=output_token_ids,
-        spec_token_ids=None,
         min_tokens=min_tokens,
         no_penalties=(all(x == 0 for x in presence_penalties)
                       and all(x == 0 for x in frequency_penalties)

vllm/v1/sample/metadata.py

Lines changed: 0 additions & 3 deletions
@@ -13,9 +13,6 @@ class SamplingMetadata:
     all_greedy: bool
     all_random: bool

-    # None when there are no speculated tokens.
-    spec_token_ids: Optional[List[List[int]]]
-
     top_p: Optional[torch.Tensor]
     top_k: Optional[torch.Tensor]
     min_p: Optional[torch.Tensor]

vllm/v1/sample/rejection_sampler.py

Lines changed: 68 additions & 66 deletions
@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+from typing import List
+
 import torch
 import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence
@@ -52,62 +54,62 @@ def __init__(self):
         else:
             self.forward_method = self.forward_native

-    def forward(self, logits: torch.Tensor,
+    def forward(self, draft_token_ids: List[List[int]],
+                target_probs: torch.Tensor,
                 sampling_metadata: SamplingMetadata) -> SamplerOutput:
         if not sampling_metadata.all_greedy:
             raise NotImplementedError(
                 "Currently, only greedy sampling is supported by "
                 "rejection sampler.")
-        return self.forward_method(logits, sampling_metadata)
+        return self.forward_method(draft_token_ids, target_probs,
+                                   sampling_metadata)

     def flashinfer_sample(
         self,
-        logits: torch.Tensor,
+        draft_token_ids: List[List[int]],
+        target_probs: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
         # NOTE: The following input preparationg can be moved
         # to the model runner with a persistent manner for better
         # performance.
-        assert sampling_metadata.spec_token_ids is not None
-        spec_token_ids = sampling_metadata.spec_token_ids
-        max_spec_len = max(len(s) for s in spec_token_ids)
-        batch_size = len(spec_token_ids)
-        draft_token_ids = torch.full((batch_size, max_spec_len),
-                                     INVALID_TOKEN_ID,
-                                     device="cpu",
-                                     dtype=torch.long)
-
-        target_token_ids = torch.full((batch_size, max_spec_len + 1),
-                                      fill_value=INVALID_TOKEN_ID,
-                                      device=logits.device,
-                                      dtype=torch.long)
-
-        # TODO: Vectorize the following loop for better performance.
-        start_loc = 0
-        for i in range(batch_size):
-            num_spec_tokens = len(spec_token_ids[i])
-            draft_token_ids[i, :num_spec_tokens] = torch.tensor(
-                spec_token_ids[i], device="cpu", dtype=torch.long)
-            end_loc = start_loc + num_spec_tokens + 1
-            # Assume greedy sampling.
-            target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
-                logits[start_loc:end_loc], dim=-1)
-            start_loc = end_loc
-
-        vocab_size = logits.size(-1)
-        # NOTE: CPU <-> GPU synchronization happens here.
-        draft_token_ids = draft_token_ids.to(logits.device)
-        draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
-                                                 logits.device)
-        target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
-                                                  logits.device)
-        uniform_samples = torch.zeros(batch_size,
-                                      max_spec_len + 1,
-                                      device=logits.device)
+        sample_lens = [len(x) + 1 for x in draft_token_ids]
+        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
+        draft_token_ids = [
+            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
+        ]
+        draft_token_ids_tensor = pad_sequence(draft_token_ids,
+                                              batch_first=True,
+                                              padding_value=INVALID_TOKEN_ID)
+
+        if sampling_metadata.all_greedy:
+            target_token_ids = target_probs.argmax(dim=-1).view(-1)
+            target_token_ids = target_token_ids.split(sample_lens)
+            target_token_ids = pad_sequence(target_token_ids,
+                                            batch_first=True,
+                                            padding_value=INVALID_TOKEN_ID)
+
+            vocab_size = target_probs.size(-1)
+            # NOTE: CPU <-> GPU synchronization happens here.
+            draft_token_ids_tensor = draft_token_ids_tensor.to(
+                target_probs.device)
+            draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
+                                                     vocab_size,
+                                                     target_probs.device)
+            target_probs = _create_greedy_token_probs(target_token_ids,
+                                                      vocab_size,
+                                                      target_probs.device)
+            uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
+                                          draft_token_ids_tensor.size(1) + 1,
+                                          device=target_probs.device)
+        else:
+            raise NotImplementedError(
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")

         sampled_token_ids, _, _ = fs.chain_speculative_sampling(
             draft_probs,
-            draft_token_ids,
+            draft_token_ids_tensor,
             uniform_samples,
             target_probs,
         )
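
_create_greedy_token_probs is not part of this diff, so the sketch below is a hypothetical stand-in showing the kind of tensor the greedy path needs: a one-hot distribution per position, so that rejection accepts a draft token exactly when it equals the target argmax. The helper name, the INVALID_TOKEN_ID value of -1, and the zeroing of padded positions are all assumptions for illustration:

import torch

INVALID_TOKEN_ID = -1  # assumed padding value


def one_hot_token_probs(token_ids: torch.Tensor, vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """Hypothetical stand-in: map (batch, seq) token ids to (batch, seq, vocab)
    one-hot 'probabilities', leaving padded positions as all zeros."""
    token_ids = token_ids.to(device)
    probs = torch.zeros(*token_ids.shape, vocab_size, device=device)
    valid = token_ids != INVALID_TOKEN_ID
    probs.scatter_(-1, token_ids.clamp(min=0).unsqueeze(-1), 1.0)
    return probs * valid.unsqueeze(-1)


probs = one_hot_token_probs(torch.tensor([[3, 7, INVALID_TOKEN_ID]]),
                            vocab_size=10, device=torch.device("cpu"))
print(probs[0, 0].argmax().item())  # 3
print(probs[0, 2].sum().item())     # 0.0 (padded position contributes nothing)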
@@ -117,35 +119,35 @@ def flashinfer_sample(
     # TODO: The following method can be optimized for better performance.
     def forward_native(
         self,
-        logits: torch.Tensor,
+        draft_token_ids: List[List[int]],
+        target_probs: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        assert sampling_metadata.spec_token_ids is not None
-        spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
-        # Add 1 to include the 'bonus' token.
-        sample_lens = [x + 1 for x in spec_lens]
-
-        output_token_ids = logits.argmax(dim=-1).view(-1)
-        output_token_ids = output_token_ids.split(sample_lens)
-        output_token_ids = pad_sequence(output_token_ids,
-                                        batch_first=True,
-                                        padding_value=INVALID_TOKEN_ID)
-
-        # Convert spec token IDs to a tensor, split by sample_lens, then pad.
-        spec_token_ids = [
-            torch.tensor(x,
-                         dtype=output_token_ids.dtype,
-                         device=output_token_ids.device)
-            for x in sampling_metadata.spec_token_ids
+        sample_lens = [len(x) + 1 for x in draft_token_ids]
+        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
+        draft_token_ids = [
+            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
         ]
-        spec_token_ids = pad_sequence(spec_token_ids,
-                                      batch_first=True,
-                                      padding_value=INVALID_TOKEN_ID)
-
-        # Produce a mask that remains 1 (True) until the first
-        # mismatch (cumprod turns 0 after a mismatch).
-        accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
-            dim=1)
+        draft_token_ids_tensor = pad_sequence(draft_token_ids,
+                                              batch_first=True,
+                                              padding_value=INVALID_TOKEN_ID)
+        draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
+        # Add 1 to include the 'bonus' token.
+        if sampling_metadata.all_greedy:
+            output_token_ids = target_probs.argmax(dim=-1).view(-1)
+            output_token_ids = output_token_ids.split(sample_lens)
+            output_token_ids = pad_sequence(output_token_ids,
+                                            batch_first=True,
+                                            padding_value=INVALID_TOKEN_ID)
+            # Produce a mask that remains 1 (True) until the first
+            # mismatch (cumprod turns 0 after a mismatch).
+            accept_mask = (
+                output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod(
+                    dim=1)
+        else:
+            raise NotImplementedError(
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
         # Identify valid positions (non-padding).
         valid_mask = output_token_ids != INVALID_TOKEN_ID
         # Generate mask with bonus token.
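
A toy illustration of the pad-then-cumprod bookkeeping used by forward_native above, runnable on its own (INVALID_TOKEN_ID = -1 and the token values are assumptions; output_token_ids stands in for target_probs.argmax(dim=-1).view(-1)):

import torch
from torch.nn.utils.rnn import pad_sequence

INVALID_TOKEN_ID = -1  # assumed padding value

draft_token_ids = [[1, 2, 3], [4, 5]]                # per-request draft tokens
sample_lens = [len(x) + 1 for x in draft_token_ids]  # +1 for the bonus token

# Greedy target tokens flattened across the batch: 4 positions for request 0,
# 3 positions for request 1.
output_token_ids = torch.tensor([1, 2, 7, 9, 4, 8, 6])

drafts = pad_sequence([torch.tensor(x) for x in draft_token_ids],
                      batch_first=True, padding_value=INVALID_TOKEN_ID)
targets = pad_sequence(list(output_token_ids.split(sample_lens)),
                       batch_first=True, padding_value=INVALID_TOKEN_ID)

# 1 until the first mismatch in each row, 0 afterwards.
accept_mask = (targets[:, :-1] == drafts).long().cumprod(dim=1)
print(accept_mask)
# tensor([[1, 1, 0],
#         [1, 0, 0]])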

vllm/v1/sample/sampler.py

Lines changed: 8 additions & 11 deletions
@@ -9,7 +9,6 @@
 from vllm.v1.sample.ops.penalties import (apply_all_penalties,
                                           apply_min_token_penalties)
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
-from vllm.v1.sample.rejection_sampler import RejectionSampler

 _SAMPLING_EPS = 1e-5

@@ -19,22 +18,12 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler()
-        self.rejection_sampler = RejectionSampler()

     def forward(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        if sampling_metadata.spec_token_ids:
-            if sampling_metadata.max_num_logprobs:
-                raise NotImplementedError(
-                    "Rejection sampling does not support logprobs.")
-            return self.rejection_sampler(
-                logits,
-                sampling_metadata,
-            )
-
         # NOTE(woosuk): Use the original logits (before any penalties or
         # temperature scaling) for the top-k logprobs.
         # This is different from the V0 sampler, which uses the logits that
@@ -127,6 +116,14 @@ def sample(
         )
         return sampled

+    def compute_probs(self, logits: torch.Tensor,
+                      sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        if sampling_metadata.all_greedy:
+            return logits
+        # Apply temperature. This is an in-place op changing logits.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        return logits.softmax(dim=-1, dtype=torch.float32)
+
     def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
         return logits.log_softmax(dim=-1, dtype=torch.float32)

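The new compute_probs helper above determines what the rejection sampler later receives as target_probs: the raw logits when the whole batch is greedy (softmax would not change the argmax), otherwise a temperature-scaled softmax. A toy re-implementation outside vLLM (SimpleNamespace stands in for SamplingMetadata, and dividing logits by temperature is an assumption about apply_temperature's behaviour):

import torch
from types import SimpleNamespace


def compute_probs(logits: torch.Tensor, sampling_metadata) -> torch.Tensor:
    if sampling_metadata.all_greedy:
        return logits  # greedy rejection only needs the argmax ordering
    logits = logits / sampling_metadata.temperature.unsqueeze(-1)
    return logits.softmax(dim=-1, dtype=torch.float32)


logits = torch.tensor([[2.0, 0.5, 0.1]])
greedy = SimpleNamespace(all_greedy=True, temperature=None)
sampled = SimpleNamespace(all_greedy=False, temperature=torch.tensor([0.7]))

print(compute_probs(logits, greedy))   # raw logits, unchanged
print(compute_probs(logits, sampled))  # each row sums to 1.0
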
vllm/v1/worker/gpu_input_batch.py

Lines changed: 0 additions & 11 deletions
@@ -490,23 +490,12 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             presence_penalties=self.presence_penalties[:num_reqs],
             repetition_penalties=self.repetition_penalties[:num_reqs],
             output_token_ids=cast(List[List[int]], self.req_output_token_ids),
-            spec_token_ids=None,
             min_tokens=self.min_tokens,
             no_penalties=self.no_penalties,
             logit_bias=self.logit_bias[:num_reqs],
             allowed_token_ids_mask=allowed_token_ids_mask,
         )

-    def get_sampling_metadata(
-        self,
-        req_id_to_spec_token_ids: Dict[str, List[int]],
-    ) -> SamplingMetadata:
-        # Set the new spec token ids in the cached sampling metadata.
-        self.sampling_metadata.spec_token_ids = [
-            req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids
-        ] if req_id_to_spec_token_ids else None
-        return self.sampling_metadata
-
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
         max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
         prompt_token_ids_cpu_tensor = torch.empty(
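
With get_sampling_metadata removed, InputBatch only exposes the spec-token-free _make_sampling_metadata, and draft token IDs no longer travel through SamplingMetadata at all. The call flow implied by the new APIs (the model-runner side lives in the eighth changed file, which is not shown above) is to build the target distribution with Sampler.compute_probs(logits, sampling_metadata) and pass it, together with the per-request draft token IDs, to the rejection sampler as rejection_sampler(draft_token_ids, target_probs, sampling_metadata).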
