31 changes: 20 additions & 11 deletions tests/e2e/singlecard/sample/test_rejection_sampler.py
@@ -77,8 +77,9 @@ def test_perfect_match(rejection_sampler):

metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
device=logits.device,
dtype=torch.int32)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)

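Every test below applies the same fix: the bonus token tensor is now built as a 2-D [batch_size, 1] int32 tensor rather than a flat 1-D tensor, matching the bonus_token_ids layout the rejection sampler expects. A minimal sketch of the two construction patterns used in these tests (token values are illustrative):

import torch

# Single-request case: nest the value once to get shape (1, 1).
bonus_single = torch.tensor([[42]], dtype=torch.int32)

# Multi-request case: build a flat tensor, then add the column dimension.
bonus_multi = torch.tensor([42, 7], dtype=torch.int32).unsqueeze(1)

assert bonus_single.shape == (1, 1)
assert bonus_multi.shape == (2, 1)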
@@ -102,8 +103,9 @@ def test_early_mismatch(rejection_sampler):

metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
device=logits.device,
dtype=torch.int32)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)

@@ -131,7 +133,9 @@ def test_multiple_sequences(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
[output_tokens[0][-1], output_tokens[1][-1]],
device=logits.device,
dtype=torch.int32).unsqueeze(1)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)

@@ -155,8 +159,9 @@ def test_single_token_sequence(rejection_sampler):

metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
device=logits.device,
dtype=torch.int32)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)

@@ -178,8 +183,9 @@ def test_empty_sequence(rejection_sampler):

metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
device=logits.device,
dtype=torch.int32)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)

@@ -203,7 +209,9 @@ def test_multiple_mismatches(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
[output_tokens[0][-1], output_tokens[1][-1]],
device=logits.device,
dtype=torch.int32).unsqueeze(1)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)

@@ -237,7 +245,8 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
device=logits.device)
device=logits.device,
dtype=torch.int32).unsqueeze(1)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)

6 changes: 4 additions & 2 deletions tests/ut/sample/test_rejection_sampler.py
@@ -32,11 +32,12 @@ class TestAscendRejectionSampler(TestBase):
def test_rejection_greedy_sample_pytorch(self):
"""Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
batch_size = 2
max_spec_len = 3
max_spec_len = 2
output_token_ids = torch.full((batch_size, max_spec_len + 1),
PLACEHOLDER_TOKEN_ID)

cu_num_draft_tokens = torch.tensor([2, 4])
num_draft_tokens = [2, 2]
draft_token_ids = torch.tensor([10, 11, 20, 21])
target_argmax = torch.tensor([10, 99, 20, 22])
bonus_token_ids = torch.tensor([[100], [200]])
@@ -49,8 +50,9 @@ def test_rejection_greedy_sample_pytorch(self):
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
num_draft_tokens,
max_spec_len,
is_greedy,
)

assert output_token_ids[0, 0].item() == 10
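For reference, under the greedy rejection semantics exercised here (copy target-argmax tokens up to and including the first mismatch, and append the bonus token only when every draft token matches), the inputs above should produce the rows sketched below; the values are worked out by hand from the tensors in this test, reusing the file's torch import and PLACEHOLDER_TOKEN_ID constant:

# Request 0: draft [10, 11] vs target argmax [10, 99]
#   -> position 0 accepted (10), position 1 mismatches (emit 99 and stop), no bonus.
# Request 1: draft [20, 21] vs target argmax [20, 22]
#   -> position 0 accepted (20), position 1 mismatches (emit 22 and stop), no bonus.
expected = torch.tensor([
    [10, 99, PLACEHOLDER_TOKEN_ID],
    [20, 22, PLACEHOLDER_TOKEN_ID],
])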
147 changes: 99 additions & 48 deletions vllm_ascend/sample/rejection_sampler.py
@@ -147,16 +147,25 @@ def rejection_sample(
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
rejection_greedy_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
# num_warps=1,
)
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
)
else:
rejection_greedy_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_draft_tokens,
max_spec_len,
is_greedy,
)
if sampling_metadata.all_greedy:
return output_token_ids

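The dispatch above takes the fully vectorized single-draft-token fast path only when every request proposed exactly one draft token and all requests use greedy sampling; checking both min and max of the list against 1 is a compact way of asserting that every entry equals 1. A minimal sketch of the condition with hypothetical inputs (all_greedy stands in for sampling_metadata.all_greedy):

num_draft_tokens = [1, 1, 1]
all_greedy = True

fast_path = (min(num_draft_tokens) == 1 and max(num_draft_tokens) == 1
             and all_greedy)
# Equivalent formulation of the same check.
assert fast_path == (all_greedy and all(n == 1 for n in num_draft_tokens))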
@@ -284,47 +293,89 @@ def sample_recovered_tokens(
return recovered_token_ids


def rejection_greedy_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
is_greedy=None, # [batch_size] or None
max_spec_len=None,
def rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
):
batch_size = output_token_ids.shape[0]

if is_greedy is None:
is_greedy = torch.ones(batch_size,
dtype=torch.bool,
device=output_token_ids.device)

for req_idx in range(batch_size):
if not is_greedy[req_idx]:
continue

if req_idx == 0:
start_idx = 0
else:
start_idx = cu_num_draft_tokens[req_idx - 1].item()
end_idx = cu_num_draft_tokens[req_idx].item()
num_draft_tokens = end_idx - start_idx

rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = draft_token_ids[start_idx + pos].item()
target_argmax_id = target_argmax[start_idx + pos].item()

output_token_ids[req_idx, pos] = target_argmax_id
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
assert batch_size == num_tokens
accept_req_mask = draft_token_ids == target_argmax
output_token_ids[:, 0] = target_argmax
bonus_token_ids = bonus_token_ids.squeeze(1)
output_token_ids[accept_req_mask, 1] = bonus_token_ids[accept_req_mask]

if draft_token_id != target_argmax_id:
rejected = True

if not rejected:
bonus_token_id = bonus_token_ids[req_idx].item()
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
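A minimal usage sketch of the new single-draft-token path, assuming a batch of two greedy requests with one draft token each (token values are illustrative and -1 stands in for the placeholder id):

output = torch.full((2, 2), -1)        # [batch_size, max_spec_len + 1] with max_spec_len == 1
draft = torch.tensor([7, 8])           # one draft token per request
target_argmax = torch.tensor([7, 9])   # request 0 accepted, request 1 rejected
bonus = torch.tensor([[70], [90]])     # [batch_size, 1]

rejection_greedy_sample_spec_len_1_pytorch(output, draft, target_argmax, bonus)
# output -> [[7, 70],   # accepted: target token, then bonus token
#            [9, -1]]   # rejected: corrected target token only, no bonus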
def rejection_greedy_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
device = output_token_ids.device
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
device, non_blocking=True)
if is_greedy is None:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)

start_indices = cu_num_draft_tokens - draft_tokens_per_req
req_ids = torch.arange(batch_size, device=device)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
token_positions = torch.arange(
num_tokens, device=device) - start_indices[token_req_ids]

# Find the first mismatch position of each request.
mismatch_global = (draft_token_ids != target_argmax)
if max_spec_len == 0:
first_mismatch_pos_per_req = torch.zeros(batch_size,
dtype=torch.long,
device=device)
else:
# [bs, max_spec_len]
pos_matrix = torch.full((batch_size, max_spec_len),
-1,
dtype=torch.long,
device=device)
pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.full((batch_size, max_spec_len),
False,
dtype=torch.bool,
device=device)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
no_mismatch_mask]

    # Copy accepted target tokens (including the correcting token at the first mismatch) into output.
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1,
device=device).expand(batch_size, -1)
copy_mask = copy_indices < copy_len.unsqueeze(1)
greedy_mask = is_greedy.unsqueeze(1)
final_copy_mask = copy_mask & greedy_mask
global_idx = start_indices.unsqueeze(1) + copy_indices
output_token_ids[final_copy_mask] = target_argmax[
global_idx[final_copy_mask]].to(output_token_ids.dtype)
# Fill bonus token.
needs_bonus = is_greedy & (first_mismatch_pos_per_req
>= draft_tokens_per_req)
if torch.any(needs_bonus):
bonus_rows = torch.where(needs_bonus)[0]
bonus_cols = draft_tokens_per_req[bonus_rows]
bonus_token_ids = bonus_token_ids.squeeze(1)
output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]

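A minimal worked example of the general vectorized path, assuming two greedy requests with 2 and 3 draft tokens respectively (token values are illustrative and -1 stands in for the placeholder id):

out = torch.full((2, 4), -1)                        # [batch_size, max_spec_len + 1]
cu = torch.tensor([2, 5])                           # cumulative draft-token counts
draft = torch.tensor([10, 11, 20, 21, 22])
target_argmax = torch.tensor([10, 11, 20, 99, 22])  # request 1 mismatches at position 1
bonus = torch.tensor([[100], [200]])

rejection_greedy_sample_pytorch(out, cu, draft, target_argmax, bonus,
                                draft_tokens_per_req=[2, 3], max_spec_len=3)
# Request 0: every draft token matches -> copy [10, 11] and append bonus 100.
# Request 1: first mismatch at position 1 -> copy [20, 99] and stop, no bonus.
# out -> [[10, 11, 100, -1],
#         [20, 99,  -1, -1]]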

def rejection_random_sample_pytorch(