Merged
17 changes: 8 additions & 9 deletions tests/v1/sample/test_rejection_sampler.py
@@ -29,7 +29,6 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
temperature=torch.tensor([]),
all_greedy=True,
all_random=False,
spec_token_ids=spec_tokens,
top_p=None,
top_k=None,
min_p=torch.empty(batch_size, ),
@@ -55,7 +54,7 @@ def test_perfect_match(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)

output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 3, 4]],
dtype=torch.int,
device=logits.device)
@@ -70,7 +69,7 @@ def test_early_mismatch(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)

output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
@@ -85,7 +84,7 @@ def test_multiple_sequences(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)

output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
@@ -100,7 +99,7 @@ def test_single_token_sequence(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)

output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)

@@ -113,7 +112,7 @@ def test_empty_sequence(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)

output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)

@@ -126,7 +125,7 @@ def test_multiple_mismatches(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)

output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
[4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int,
@@ -147,7 +146,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)

output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected_tensor = torch.tensor(expected,
dtype=torch.int,
device=logits.device)
@@ -163,7 +162,7 @@ def test_logits_shape_handling(sampler):
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens, vocab_size)

output = sampler(logits, metadata)
output = sampler(spec_tokens, logits, metadata)
expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
assert logits.shape[-1] == vocab_size
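Taken together, these test changes reflect the new RejectionSampler interface: the draft (speculated) token IDs are passed to the sampler as an explicit first argument instead of travelling inside SamplingMetadata. A minimal sketch of the new invocation, assuming the create_sampling_metadata / create_logits_tensor helpers and the sampler fixture defined in this test file:

# Illustrative only; assumes the helpers and fixture from this test module.
spec_tokens = [[1, 2, 3]]          # draft tokens proposed by the speculator
output_tokens = [[1, 5, 3, 4]]     # greedy target tokens; mismatch at index 1

metadata = create_sampling_metadata(spec_tokens)   # no spec_token_ids field
logits = create_logits_tensor(output_tokens)

output = sampler(spec_tokens, logits, metadata)
# -> [[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]]: tokens are accepted until
#    the first mismatch, the target's token is emitted at that position, and
#    the remaining positions are padded.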
1 change: 0 additions & 1 deletion tests/v1/sample/test_sampler.py
@@ -105,7 +105,6 @@ def _create_default_sampling_metadata(
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
output_token_ids=output_token_ids,
spec_token_ids=None,
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
1 change: 0 additions & 1 deletion tests/v1/worker/test_gpu_input_batch.py
@@ -123,7 +123,6 @@ def _construct_expected_sampling_metadata(
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
spec_token_ids=None,
min_tokens=min_tokens,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
3 changes: 0 additions & 3 deletions vllm/v1/sample/metadata.py
@@ -13,9 +13,6 @@ class SamplingMetadata:
all_greedy: bool
all_random: bool

# None when there are no speculated tokens.
spec_token_ids: Optional[List[List[int]]]

top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
min_p: Optional[torch.Tensor]
134 changes: 68 additions & 66 deletions vllm/v1/sample/rejection_sampler.py
@@ -1,4 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
@@ -52,62 +54,62 @@ def __init__(self):
else:
self.forward_method = self.forward_native

def forward(self, logits: torch.Tensor,
def forward(self, draft_token_ids: List[List[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
if not sampling_metadata.all_greedy:
raise NotImplementedError(
"Currently, only greedy sampling is supported by "
"rejection sampler.")
return self.forward_method(logits, sampling_metadata)
return self.forward_method(draft_token_ids, target_probs,
sampling_metadata)

def flashinfer_sample(
self,
logits: torch.Tensor,
draft_token_ids: List[List[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
# NOTE: The following input preparationg can be moved
# to the model runner with a persistent manner for better
# performance.
assert sampling_metadata.spec_token_ids is not None
spec_token_ids = sampling_metadata.spec_token_ids
max_spec_len = max(len(s) for s in spec_token_ids)
batch_size = len(spec_token_ids)
draft_token_ids = torch.full((batch_size, max_spec_len),
INVALID_TOKEN_ID,
device="cpu",
dtype=torch.long)

target_token_ids = torch.full((batch_size, max_spec_len + 1),
fill_value=INVALID_TOKEN_ID,
device=logits.device,
dtype=torch.long)

# TODO: Vectorize the following loop for better performance.
start_loc = 0
for i in range(batch_size):
num_spec_tokens = len(spec_token_ids[i])
draft_token_ids[i, :num_spec_tokens] = torch.tensor(
spec_token_ids[i], device="cpu", dtype=torch.long)
end_loc = start_loc + num_spec_tokens + 1
# Assume greedy sampling.
target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
logits[start_loc:end_loc], dim=-1)
start_loc = end_loc

vocab_size = logits.size(-1)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids = draft_token_ids.to(logits.device)
draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
logits.device)
target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
logits.device)
uniform_samples = torch.zeros(batch_size,
max_spec_len + 1,
device=logits.device)
sample_lens = [len(x) + 1 for x in draft_token_ids]
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
draft_token_ids = [
torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
]
draft_token_ids_tensor = pad_sequence(draft_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)

if sampling_metadata.all_greedy:
target_token_ids = target_probs.argmax(dim=-1).view(-1)
target_token_ids = target_token_ids.split(sample_lens)
target_token_ids = pad_sequence(target_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)

vocab_size = target_probs.size(-1)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids_tensor = draft_token_ids_tensor.to(
target_probs.device)
draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
vocab_size,
target_probs.device)
target_probs = _create_greedy_token_probs(target_token_ids,
vocab_size,
target_probs.device)
uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
draft_token_ids_tensor.size(1) + 1,
device=target_probs.device)
else:
raise NotImplementedError(
"Currently, only greedy sampling is supported by "
"rejection sampler.")

sampled_token_ids, _, _ = fs.chain_speculative_sampling(
draft_probs,
draft_token_ids,
draft_token_ids_tensor,
uniform_samples,
target_probs,
)
@@ -117,35 +119,35 @@ def flashinfer_sample(
# TODO: The following method can be optimized for better performance.
def forward_native(
self,
logits: torch.Tensor,
draft_token_ids: List[List[int]],
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
assert sampling_metadata.spec_token_ids is not None
spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
# Add 1 to include the 'bonus' token.
sample_lens = [x + 1 for x in spec_lens]

output_token_ids = logits.argmax(dim=-1).view(-1)
output_token_ids = output_token_ids.split(sample_lens)
output_token_ids = pad_sequence(output_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)

# Convert spec token IDs to a tensor, split by sample_lens, then pad.
spec_token_ids = [
torch.tensor(x,
dtype=output_token_ids.dtype,
device=output_token_ids.device)
for x in sampling_metadata.spec_token_ids
sample_lens = [len(x) + 1 for x in draft_token_ids]
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
draft_token_ids = [
torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
]
spec_token_ids = pad_sequence(spec_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)

# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
dim=1)
draft_token_ids_tensor = pad_sequence(draft_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
# Add 1 to include the 'bonus' token.
if sampling_metadata.all_greedy:
output_token_ids = target_probs.argmax(dim=-1).view(-1)
output_token_ids = output_token_ids.split(sample_lens)
output_token_ids = pad_sequence(output_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
accept_mask = (
output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod(
dim=1)
else:
raise NotImplementedError(
"Currently, only greedy sampling is supported by "
"rejection sampler.")
# Identify valid positions (non-padding).
valid_mask = output_token_ids != INVALID_TOKEN_ID
# Generate mask with bonus token.
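The greedy path above accepts draft tokens until the first position where the greedy target token disagrees, then keeps the target's token at that position (or the bonus token when everything matches) and pads the rest of the row with INVALID_TOKEN_ID. A self-contained sketch of that accept-mask logic under the greedy-only assumption — an illustration, not the exact vLLM implementation:

from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence

INVALID_TOKEN_ID = -1


def greedy_rejection_sample(draft_token_ids: List[List[int]],
                            target_probs: torch.Tensor) -> torch.Tensor:
    # One extra ("bonus") position per sequence beyond its draft tokens.
    sample_lens = [len(x) + 1 for x in draft_token_ids]
    drafts = pad_sequence(
        [torch.tensor(x, dtype=torch.long) for x in draft_token_ids],
        batch_first=True, padding_value=INVALID_TOKEN_ID)

    # Greedy target tokens, regrouped per sequence and padded to equal length.
    target_ids = target_probs.argmax(dim=-1).split(sample_lens)
    target_ids = pad_sequence(list(target_ids), batch_first=True,
                              padding_value=INVALID_TOKEN_ID)

    # 1 until the first draft/target mismatch, 0 afterwards.
    accept_mask = (target_ids[:, :-1] == drafts).int().cumprod(dim=1)
    # Keep one more target token than the number of accepted draft tokens
    # (the recovery token at the mismatch, or the bonus token if none).
    keep_mask = torch.cat(
        [accept_mask.new_ones(accept_mask.size(0), 1), accept_mask], dim=1)
    keep_mask = keep_mask.bool() & (target_ids != INVALID_TOKEN_ID)
    return target_ids.masked_fill(~keep_mask, INVALID_TOKEN_ID)


# Example: drafts [[1, 2, 3]] against greedy target tokens [1, 5, 3, 4]
# yields tensor([[ 1,  5, -1, -1]]): accept 1, recover with 5, pad the rest.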
19 changes: 8 additions & 11 deletions vllm/v1/sample/sampler.py
@@ -9,7 +9,6 @@
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.rejection_sampler import RejectionSampler

_SAMPLING_EPS = 1e-5

@@ -19,22 +18,12 @@ class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.topk_topp_sampler = TopKTopPSampler()
self.rejection_sampler = RejectionSampler()

def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
if sampling_metadata.spec_token_ids:
if sampling_metadata.max_num_logprobs:
raise NotImplementedError(
"Rejection sampling does not support logprobs.")
return self.rejection_sampler(
logits,
sampling_metadata,
)

# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
@@ -127,6 +116,14 @@ def sample(
)
return sampled

def compute_probs(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
if sampling_metadata.all_greedy:
return logits
# Apply temperature. This is an in-place op changing logits.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
return logits.softmax(dim=-1, dtype=torch.float32)

def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)

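With the RejectionSampler no longer instantiated or called inside Sampler.forward, the caller is expected to compute target probabilities via the new compute_probs helper and invoke the rejection sampler itself. The model-runner wiring is not part of this diff, so the following is only a hypothetical sketch of that call sequence; logits, draft_token_ids, and sampling_metadata are assumed inputs:

from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler

sampler = Sampler()
rejection_sampler = RejectionSampler()


def sample_step(logits, draft_token_ids, sampling_metadata):
    if draft_token_ids:
        # In the all-greedy case compute_probs returns the raw logits, which
        # the rejection sampler only consumes through argmax; otherwise it
        # returns temperature-scaled softmax probabilities.
        target_probs = sampler.compute_probs(logits, sampling_metadata)
        return rejection_sampler(draft_token_ids, target_probs,
                                 sampling_metadata)
    return sampler(logits, sampling_metadata)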
11 changes: 0 additions & 11 deletions vllm/v1/worker/gpu_input_batch.py
@@ -490,23 +490,12 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(List[List[int]], self.req_output_token_ids),
spec_token_ids=None,
min_tokens=self.min_tokens,
no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:num_reqs],
allowed_token_ids_mask=allowed_token_ids_mask,
)

def get_sampling_metadata(
self,
req_id_to_spec_token_ids: Dict[str, List[int]],
) -> SamplingMetadata:
# Set the new spec token ids in the cached sampling metadata.
self.sampling_metadata.spec_token_ids = [
req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids
] if req_id_to_spec_token_ids else None
return self.sampling_metadata

def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(