diff --git a/tests/e2e/singlecard/sample/test_rejection_sampler.py b/tests/e2e/singlecard/sample/test_rejection_sampler.py
index 123e7c20c0..2a33120286 100644
--- a/tests/e2e/singlecard/sample/test_rejection_sampler.py
+++ b/tests/e2e/singlecard/sample/test_rejection_sampler.py
@@ -77,8 +77,9 @@ def test_perfect_match(rejection_sampler):
 
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
 
@@ -102,8 +103,9 @@ def test_early_mismatch(rejection_sampler):
 
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
 
@@ -131,7 +133,9 @@ def test_multiple_sequences(rejection_sampler):
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor(
-        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
+        [output_tokens[0][-1], output_tokens[1][-1]],
+        device=logits.device,
+        dtype=torch.int32).unsqueeze(1)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
 
@@ -155,8 +159,9 @@ def test_single_token_sequence(rejection_sampler):
 
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
 
@@ -178,8 +183,9 @@ def test_empty_sequence(rejection_sampler):
 
     metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
 
@@ -203,7 +209,9 @@ def test_multiple_mismatches(rejection_sampler):
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor(
-        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
+        [output_tokens[0][-1], output_tokens[1][-1]],
+        device=logits.device,
+        dtype=torch.int32).unsqueeze(1)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
 
@@ -237,7 +245,8 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
-                                      device=logits.device)
+                                      device=logits.device,
+                                      dtype=torch.int32).unsqueeze(1)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)
 
diff --git a/tests/ut/sample/test_rejection_sampler.py b/tests/ut/sample/test_rejection_sampler.py
index b6aaf868c5..adbf376dd7 100644
--- a/tests/ut/sample/test_rejection_sampler.py
+++ b/tests/ut/sample/test_rejection_sampler.py
@@ -32,11 +32,12 @@ class TestAscendRejectionSampler(TestBase):
     def test_rejection_greedy_sample_pytorch(self):
         """Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
         batch_size = 2
-        max_spec_len = 3
+        max_spec_len = 2
         output_token_ids = torch.full((batch_size, max_spec_len + 1),
                                       PLACEHOLDER_TOKEN_ID)
 
         cu_num_draft_tokens = torch.tensor([2, 4])
+        num_draft_tokens = [2, 2]
         draft_token_ids = torch.tensor([10, 11, 20, 21])
         target_argmax = torch.tensor([10, 99, 20, 22])
         bonus_token_ids = torch.tensor([[100], [200]])
@@ -49,8 +50,9 @@ def test_rejection_greedy_sample_pytorch(self):
             draft_token_ids,
             target_argmax,
             bonus_token_ids,
-            is_greedy,
+            num_draft_tokens,
             max_spec_len,
+            is_greedy,
         )
 
         assert output_token_ids[0, 0].item() == 10
diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index 832f0179dd..e0d770df26 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -147,16 +147,25 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        rejection_greedy_sample_pytorch(
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            target_argmax,
-            bonus_token_ids,
-            is_greedy,
-            max_spec_len,
-            # num_warps=1,
-        )
+        if min(num_draft_tokens) == 1 and max(
+                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+            rejection_greedy_sample_spec_len_1_pytorch(
+                output_token_ids,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+            )
+        else:
+            rejection_greedy_sample_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                num_draft_tokens,
+                max_spec_len,
+                is_greedy,
+            )
 
     if sampling_metadata.all_greedy:
         return output_token_ids
@@ -284,47 +293,89 @@ def sample_recovered_tokens(
     return recovered_token_ids
 
 
-def rejection_greedy_sample_pytorch(
-    output_token_ids,  # [batch_size, max_spec_len + 1]
-    cu_num_draft_tokens,  # [batch_size]
-    draft_token_ids,  # [num_tokens]
-    target_argmax,  # [num_tokens]
-    bonus_token_ids,  # [batch_size]
-    is_greedy=None,  # [batch_size] or None
-    max_spec_len=None,
+def rejection_greedy_sample_spec_len_1_pytorch(
+    output_token_ids,  # [batch_size, 2]
+    draft_token_ids,  # [num_tokens]
+    target_argmax,  # [num_tokens]
+    bonus_token_ids,  # [batch_size]
 ):
-    batch_size = output_token_ids.shape[0]
-
-    if is_greedy is None:
-        is_greedy = torch.ones(batch_size,
-                               dtype=torch.bool,
-                               device=output_token_ids.device)
-
-    for req_idx in range(batch_size):
-        if not is_greedy[req_idx]:
-            continue
-
-        if req_idx == 0:
-            start_idx = 0
-        else:
-            start_idx = cu_num_draft_tokens[req_idx - 1].item()
-        end_idx = cu_num_draft_tokens[req_idx].item()
-        num_draft_tokens = end_idx - start_idx
-
-        rejected = False
-        for pos in range(num_draft_tokens):
-            if not rejected:
-                draft_token_id = draft_token_ids[start_idx + pos].item()
-                target_argmax_id = target_argmax[start_idx + pos].item()
-
-                output_token_ids[req_idx, pos] = target_argmax_id
+    batch_size = output_token_ids.size(0)
+    num_tokens = draft_token_ids.size(0)
+    assert batch_size == num_tokens
+    accept_req_mask = draft_token_ids == target_argmax
+    output_token_ids[:, 0] = target_argmax
+    bonus_token_ids = bonus_token_ids.squeeze(1)
+    output_token_ids[accept_req_mask, 1] = bonus_token_ids[accept_req_mask]
 
-                if draft_token_id != target_argmax_id:
-                    rejected = True
-
-        if not rejected:
-            bonus_token_id = bonus_token_ids[req_idx].item()
-            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+
+def rejection_greedy_sample_pytorch(
+    output_token_ids,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens,  # [batch_size]
+    draft_token_ids,  # [num_tokens]
+    target_argmax,  # [num_tokens]
+    bonus_token_ids,  # [batch_size]
+    draft_tokens_per_req,  # [batch_size], list
+    max_spec_len,
+    is_greedy=None,  # [batch_size] or None
+):
+    batch_size = output_token_ids.size(0)
+    num_tokens = draft_token_ids.size(0)
+    device = output_token_ids.device
+    draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
+        device, non_blocking=True)
+    if is_greedy is None:
+        is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
+
+    start_indices = cu_num_draft_tokens - draft_tokens_per_req
+    req_ids = torch.arange(batch_size, device=device)
+    token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
+    token_positions = torch.arange(
+        num_tokens, device=device) - start_indices[token_req_ids]
+
+    # Find the first mismatch position of each request.
+    mismatch_global = (draft_token_ids != target_argmax)
+    if max_spec_len == 0:
+        first_mismatch_pos_per_req = torch.zeros(batch_size,
+                                                 dtype=torch.long,
+                                                 device=device)
+    else:
+        # [bs, max_spec_len]
+        pos_matrix = torch.full((batch_size, max_spec_len),
+                                -1,
+                                dtype=torch.long,
+                                device=device)
+        pos_matrix[token_req_ids, token_positions] = token_positions
+        mismatch_matrix = torch.full((batch_size, max_spec_len),
+                                     False,
+                                     dtype=torch.bool,
+                                     device=device)
+        mismatch_matrix[token_req_ids, token_positions] = mismatch_global
+        mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
+                                         max_spec_len * 2)
+        first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
+        no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
+        first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
+            no_mismatch_mask]
+
+    # Copy matched target tokens into output.
+    copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
+                             draft_tokens_per_req)
+    copy_indices = torch.arange(max_spec_len + 1,
+                                device=device).expand(batch_size, -1)
+    copy_mask = copy_indices < copy_len.unsqueeze(1)
+    greedy_mask = is_greedy.unsqueeze(1)
+    final_copy_mask = copy_mask & greedy_mask
+    global_idx = start_indices.unsqueeze(1) + copy_indices
+    output_token_ids[final_copy_mask] = target_argmax[
+        global_idx[final_copy_mask]].to(output_token_ids.dtype)
+    # Fill bonus token.
+    needs_bonus = is_greedy & (first_mismatch_pos_per_req
+                               >= draft_tokens_per_req)
+    if torch.any(needs_bonus):
+        bonus_rows = torch.where(needs_bonus)[0]
+        bonus_cols = draft_tokens_per_req[bonus_rows]
+        bonus_token_ids = bonus_token_ids.squeeze(1)
+        output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]
 
 
 def rejection_random_sample_pytorch(
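The new `rejection_greedy_sample_spec_len_1_pytorch` fast path relies on the fact that with exactly one draft token per request, accept/reject collapses to a single elementwise comparison, with no per-request loop or position bookkeeping. Below is a minimal standalone sketch of that idea; the concrete tensor values and the `-1` placeholder are illustrative assumptions, not taken from the patch:

```python
import torch

# Three requests, one draft token each; requests 0 and 2 match the target.
draft_token_ids = torch.tensor([7, 8, 9])
target_argmax = torch.tensor([7, 5, 9])
bonus_token_ids = torch.tensor([[70], [80], [90]])  # [batch_size, 1]

# Room for the target token plus an optional bonus token per request;
# -1 stands in for the sampler's PLACEHOLDER_TOKEN_ID here.
output = torch.full((3, 2), -1)

accept_mask = draft_token_ids == target_argmax
output[:, 0] = target_argmax  # the target token is always emitted
output[accept_mask, 1] = bonus_token_ids.squeeze(1)[accept_mask]

print(output)  # tensor([[ 7, 70], [ 5, -1], [ 9, 90]])
```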
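The rewritten `rejection_greedy_sample_pytorch` replaces the old per-request Python loop with batched tensor ops. Its key step finds each request's first draft/target mismatch by scattering per-token mismatch flags into a dense `[batch_size, max_spec_len]` grid and taking a row-wise min. A rough self-contained sketch of just that step, assuming the same flattened token layout as the diff (the helper name is ours, and the two scatter matrices of the patch are folded into one grid for brevity):

```python
import torch

def first_mismatch_per_req(draft_token_ids, target_argmax,
                           draft_tokens_per_req, max_spec_len):
    # draft_token_ids / target_argmax are flat [num_tokens]; requests are
    # laid out back to back, request i owning draft_tokens_per_req[i] slots.
    lens = torch.tensor(draft_tokens_per_req)
    batch_size, num_tokens = lens.numel(), draft_token_ids.numel()
    start = torch.cumsum(lens, dim=0) - lens
    # Row (request id) and column (position within request) of each token.
    rows = torch.repeat_interleave(torch.arange(batch_size), lens)
    cols = torch.arange(num_tokens) - start[rows]
    # Scatter mismatch positions into a dense grid; a sentinel larger than
    # any valid position marks "no mismatch at this slot".
    sentinel = max_spec_len * 2
    grid = torch.full((batch_size, max_spec_len), sentinel)
    mismatch = draft_token_ids != target_argmax
    grid[rows[mismatch], cols[mismatch]] = cols[mismatch]
    first, _ = grid.min(dim=1)
    # Requests with no mismatch accept all of their draft tokens.
    return torch.where(first == sentinel, lens, first)

# Request 0 is rejected at position 1; request 1 accepts both drafts.
print(first_mismatch_per_req(
    torch.tensor([10, 11, 20, 21]),   # draft tokens
    torch.tensor([10, 99, 20, 21]),   # greedy target tokens
    [2, 2], max_spec_len=2))          # -> tensor([1, 2])
```

Once the first mismatch position is known per request, the copy mask (`min(first + 1, num_drafts)` accepted tokens) and the bonus-token fill in the patch follow from plain broadcasting, which is why the `max_spec_len == 0` case is special-cased there: an empty grid has no dimension to take a min over.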