
Commit 85c50bf

Refactor spec decode to support efficient padded speculation

Signed-off-by: xuyexiong <xuyexiong@huawei.com>

1 parent: daa4dd0

8 files changed: +1143 −436 lines changed

tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 17 additions & 0 deletions
@@ -5,6 +5,9 @@
 from vllm.config import CompilationConfig, CUDAGraphMode
 
 from tests.e2e.conftest import VllmRunner
+import os
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
 
 @pytest.fixture
@@ -22,6 +25,7 @@ def mtp_correctness(
         model_name: str,
         num_speculative_tokens: int,
         graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
+        disable_padded_drafter_batch=True,
 ):
     example_prompts = [
         "Hello, my name is",
@@ -54,6 +58,7 @@ def mtp_correctness(
         speculative_config={
             "method": "deepseek_mtp",
             "num_speculative_tokens": num_speculative_tokens,
+            "disable_padded_drafter_batch": disable_padded_drafter_batch,
         },
         enforce_eager=False,
         max_model_len=2000,
@@ -108,3 +113,15 @@ def test_mtp2_correctness_full_graph(
         model_name: str,
 ):
     mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
+
+def test_mtp1_correctness_piecewise_graph_with_pad(
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    mtp_correctness(sampling_config, model_name, 1, disable_padded_drafter_batch=False)
+
+def test_mtp2_correctness_piecewise_graph_with_pad(
+        sampling_config: SamplingParams,
+        model_name: str,
+):
+    mtp_correctness(sampling_config, model_name, 2, disable_padded_drafter_batch=False)
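
Note: the new flag rides in through the speculative_config dict, so padded
speculation can be toggled per run. A minimal usage sketch, assuming a recent
vLLM v1 build; the model id and sampling values below are illustrative
placeholders, not values from this commit:

    from vllm import LLM, SamplingParams

    # Hypothetical invocation: the padded drafter path is active when
    # "disable_padded_drafter_batch" is False.
    llm = LLM(
        model="some/mtp-capable-model",  # placeholder model id
        speculative_config={
            "method": "deepseek_mtp",
            "num_speculative_tokens": 2,
            "disable_padded_drafter_batch": False,
        },
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=32))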

vllm_ascend/spec_decode/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -19,14 +19,21 @@
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
+from vllm_ascend.torchair.mtp_torchair_proposer import MtpTorchairProposer
 
 
-def get_spec_decode_method(method, vllm_config, device, runner):
+def get_spec_decode_method(method,
+                           vllm_config,
+                           device,
+                           runner,
+                           is_torchair_graph=False):
     if method == "ngram":
         return NgramProposer(vllm_config, device, runner)
     elif method in ["eagle", "eagle3"]:
         return EagleProposer(vllm_config, device, runner)
     elif method == 'deepseek_mtp':
+        if is_torchair_graph:
+            return MtpTorchairProposer(vllm_config, device, runner)
         return MtpProposer(vllm_config, device, runner)
     else:
         raise ValueError("Unknown speculative decoding method: "
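
With the extra keyword, callers can route MTP drafting to the torchair graph
implementation. A hedged sketch of the dispatch (vllm_config, device, and
runner are assumed to already exist in the caller, as they do in the model
runner):

    from vllm_ascend.spec_decode import get_spec_decode_method

    # MtpTorchairProposer is selected only for "deepseek_mtp" when the
    # torchair graph mode is active; other methods ignore the new flag.
    proposer = get_spec_decode_method("deepseek_mtp",
                                      vllm_config,
                                      device,
                                      runner,
                                      is_torchair_graph=True)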

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 24 additions & 9 deletions
@@ -32,16 +32,28 @@ def __init__(self,
                  device: torch.device,
                  runner=None):
         self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3
-        self.vllm_config = vllm_config
         self.device = device
+        self.vllm_config = vllm_config
+        self.speculative_config = vllm_config.speculative_config
+        self.draft_model_config = self.speculative_config.draft_model_config
+        self.method = self.speculative_config.method
+
         self.runner = runner
+        self.dtype = vllm_config.model_config.dtype
+        self.max_model_len = vllm_config.model_config.max_model_len
+        self.block_size = vllm_config.cache_config.block_size
+        self.num_speculative_tokens = (
+            self.speculative_config.num_speculative_tokens)
+        self.max_num_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        self.token_arange_np = np.arange(self.max_num_tokens)
 
         self.block_size = vllm_config.cache_config.block_size
         # We need to get the hidden size from the draft model config because
         # the draft model's hidden size can be different from the target model's
         # hidden size (e.g., Llama 3.3 70B).
-        self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
-        )
+        self.hidden_size = self.draft_model_config.get_hidden_size()
+
 
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
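
Caching token_arange_np once avoids re-allocating an arange on every step, and
the padded path can derive per-token index math from it by broadcasting. A toy
numpy sketch of that pattern (the batch shape is illustrative, not taken from
the commit):

    import numpy as np

    # Precomputed once, as in __init__ above.
    max_num_tokens = 16
    token_arange_np = np.arange(max_num_tokens)

    # Example batch: 3 requests with 4, 5 and 3 scheduled tokens.
    cu_num_tokens = np.array([0, 4, 9, 12])
    tokens_per_req = np.diff(cu_num_tokens)  # [4, 5, 3]
    total = cu_num_tokens[-1]

    # Position of each flat token inside its own request, computed without a
    # Python loop by reusing the cached arange.
    offsets = token_arange_np[:total] - np.repeat(cu_num_tokens[:-1],
                                                  tokens_per_req)
    print(offsets)  # [0 1 2 3 0 1 2 3 4 0 1 2]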
@@ -52,17 +64,16 @@ def __init__(self,
 
         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(
-            self.vllm_config.scheduler_config.max_num_batched_tokens,
+            self.max_num_tokens,
             dtype=torch.int32,
             device=device)
         self.positions = torch.zeros(
-            self.vllm_config.scheduler_config.max_num_batched_tokens,
+            self.max_num_tokens,
             dtype=torch.int64,
             device=device)
         self.hidden_states = torch.zeros(
-            (self.vllm_config.scheduler_config.max_num_batched_tokens,
-             self.hidden_size),
-            dtype=self.vllm_config.model_config.dtype,
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
             device=device)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
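
The fixed-size buffers exist so a captured graph can replay over stable
storage: each step copies the live batch into a prefix slice instead of
allocating new tensors. A device-agnostic toy sketch of the pattern (shapes
are illustrative; float32 and the CPU device stand in for the model dtype and
NPU device):

    import torch

    max_num_tokens, hidden_size = 16, 8

    # Persistent buffers, allocated once.
    input_ids = torch.zeros(max_num_tokens, dtype=torch.int32)
    positions = torch.zeros(max_num_tokens, dtype=torch.int64)
    hidden_states = torch.zeros((max_num_tokens, hidden_size),
                                dtype=torch.float32)

    # Per step: write the current batch into a prefix of each buffer, so the
    # same storage backs every replay regardless of batch size.
    num_tokens = 11
    step_ids = torch.arange(num_tokens, dtype=torch.int32)
    input_ids[:num_tokens].copy_(step_ids)
    positions[:num_tokens].copy_(step_ids.to(torch.int64))
    hidden_states[:num_tokens].fill_(1.0)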
@@ -398,14 +409,18 @@ def _propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
+        last_token_indices: Optional[torch.Tensor],
+
     ) -> torch.Tensor:
         device = cu_num_tokens.device
         cu_num_tokens = cu_num_tokens.cpu()
         block_table = block_table.cpu()
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
-        last_token_indices = cu_num_tokens[1:] - 1
+        if last_token_indices is None:
+            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
         target_positions = target_positions.cpu()
+
         if self.name == SpecDcodeType.EAGLE3:
             assert isinstance(self.model, Eagle3LlamaForCausalLM)
             target_hidden_states = self.model.combine_hidden_states(
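
Making last_token_indices an explicit parameter lets the caller precompute
sampling positions over the padded batch layout, while the None fallback keeps
the old per-request behaviour. A small sketch of the fallback computation (the
request lengths are illustrative):

    import torch

    # query_start_loc for 3 requests with 4, 5 and 3 tokens in the flat batch.
    query_start_loc = torch.tensor([0, 4, 9, 12])

    # Fallback used when the caller passes last_token_indices=None: each
    # request's last scheduled token sits one slot before the next start.
    last_token_indices = query_start_loc[1:] - 1
    print(last_token_indices)  # tensor([ 3,  8, 11])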
