
Commit e56f12c

Merge branch 'padded-spec' of https://github.com/JC-ut0/vllm-ascend into padded-spec
2 parents: 8eef0f9 + 48b296e

File tree

3 files changed: +27 -30 lines changed

tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 10 additions & 2 deletions
@@ -113,14 +113,22 @@ def test_mtp2_correctness_full_graph(
 ):
     mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
 
+
 def test_mtp1_correctness_piecewise_graph_with_pad(
     sampling_config: SamplingParams,
     model_name: str,
 ):
-    mtp_correctness(sampling_config, model_name, 1, disable_padded_drafter_batch=False)
+    mtp_correctness(sampling_config,
+                    model_name,
+                    1,
+                    disable_padded_drafter_batch=False)
+
 
 def test_mtp2_correctness_piecewise_graph_with_pad(
     sampling_config: SamplingParams,
     model_name: str,
 ):
-    mtp_correctness(sampling_config, model_name, 2, disable_padded_drafter_batch=False)
+    mtp_correctness(sampling_config,
+                    model_name,
+                    2,
+                    disable_padded_drafter_batch=False)
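Note: these tests pin disable_padded_drafter_batch=False so the padded drafter-batch path stays enabled. As a rough sketch of what that flag reaches, assuming the helper forwards it into vLLM's speculative_config (the keys and prompt below are assumptions, not the actual body of mtp_correctness):

# Hypothetical sketch of the flag the tests exercise; the real plumbing
# lives inside mtp_correctness. The speculative_config keys are assumptions.
from vllm import LLM, SamplingParams

def mtp_generate(model_name: str, num_speculative_tokens: int):
    llm = LLM(
        model=model_name,
        speculative_config={
            "num_speculative_tokens": num_speculative_tokens,
            # False keeps the padded drafter-batch path enabled
            "disable_padded_drafter_batch": False,
        },
    )
    return llm.generate(["Hello, world"], SamplingParams(temperature=0.0))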

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 15 additions & 26 deletions
@@ -32,27 +32,16 @@ def __init__(self,
                  device: torch.device,
                  runner=None):
         self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3
-        self.device = device
         self.vllm_config = vllm_config
-        self.speculative_config = vllm_config.speculative_config
-        self.draft_model_config = self.speculative_config.draft_model_config
-        self.method = self.speculative_config.method
-
+        self.device = device
         self.runner = runner
-        self.dtype = vllm_config.model_config.dtype
-        self.max_model_len = vllm_config.model_config.max_model_len
-        self.block_size = vllm_config.cache_config.block_size
-        self.num_speculative_tokens = (
-            self.speculative_config.num_speculative_tokens)
-        self.max_num_tokens = (
-            vllm_config.scheduler_config.max_num_batched_tokens)
-        self.token_arange_np = np.arange(self.max_num_tokens)
 
         self.block_size = vllm_config.cache_config.block_size
         # We need to get the hidden size from the draft model config because
         # the draft model's hidden size can be different from the target model's
         # hidden size (e.g., Llama 3.3 70B).
-        self.hidden_size = self.draft_model_config.get_hidden_size()
+        self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
+        )
 
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
@@ -62,15 +51,18 @@ def __init__(self,
             self.vllm_config.compilation_config.cudagraph_capture_sizes))
 
         # persistent buffers for cuda graph
-        self.input_ids = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int32,
-                                     device=device)
-        self.positions = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int64,
-                                     device=device)
+        self.input_ids = torch.zeros(
+            self.vllm_config.scheduler_config.max_num_batched_tokens,
+            dtype=torch.int32,
+            device=device)
+        self.positions = torch.zeros(
+            self.vllm_config.scheduler_config.max_num_batched_tokens,
+            dtype=torch.int64,
+            device=device)
         self.hidden_states = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
+            (self.vllm_config.scheduler_config.max_num_batched_tokens,
+             self.hidden_size),
+            dtype=self.vllm_config.model_config.dtype,
             device=device)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
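Note: the refactor drops the cached self.max_num_tokens/self.dtype attributes and reads the config inline, but the allocation pattern is unchanged: persistent buffers are sized once at max_num_batched_tokens so graph capture replays against stable tensor addresses. A minimal standalone sketch of that pattern (the toy capacity and CPU device are assumptions; the proposer allocates on the accelerator device):

import torch

# Persistent-buffer pattern: allocate once at the maximum size, then fill
# only a prefix per batch so a captured graph always sees the same addresses.
max_num_batched_tokens = 8  # assumption: small toy capacity
device = torch.device("cpu")  # assumption: CPU stand-in for the NPU device

input_ids = torch.zeros(max_num_batched_tokens, dtype=torch.int32, device=device)
positions = torch.zeros(max_num_batched_tokens, dtype=torch.int64, device=device)

def stage_batch(token_ids: torch.Tensor, pos: torch.Tensor) -> None:
    n = token_ids.shape[0]
    input_ids[:n].copy_(token_ids)  # in-place copy keeps the buffer address stable
    positions[:n].copy_(pos)

stage_batch(torch.tensor([1, 2, 3], dtype=torch.int32),
            torch.tensor([0, 1, 2], dtype=torch.int64))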
@@ -406,17 +398,14 @@ def _propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        last_token_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         device = cu_num_tokens.device
         cu_num_tokens = cu_num_tokens.cpu()
         block_table = block_table.cpu()
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
-        if last_token_indices is None:
-            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
+        last_token_indices = cu_num_tokens[1:] - 1
         target_positions = target_positions.cpu()
-
         if self.name == SpecDcodeType.EAGLE3:
             assert isinstance(self.model, Eagle3LlamaForCausalLM)
             target_hidden_states = self.model.combine_hidden_states(
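Note: with last_token_indices no longer a parameter, it is derived directly from the cumulative token counts. A small worked example of that index math:

import torch

# With per-request token counts [3, 2, 4], the cumulative offsets are
# [0, 3, 5, 9], so the last token of each request sits at
# cu_num_tokens[1:] - 1 = [2, 4, 8].
num_tokens_per_req = torch.tensor([3, 2, 4])
cu_num_tokens = torch.zeros(4, dtype=torch.int64)
cu_num_tokens[1:] = torch.cumsum(num_tokens_per_req, dim=0)
last_token_indices = cu_num_tokens[1:] - 1
assert last_token_indices.tolist() == [2, 4, 8]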

vllm_ascend/torchair/mtp_torchair_proposer.py

Lines changed: 2 additions & 2 deletions
@@ -203,7 +203,7 @@ def generate_token_ids(self,
             attn_metadata.slot_mapping[:num_scheduled_tokens],
         )
 
-        draft_token_ids = self._propose(
+        draft_token_ids = self._propose_torchair(
             target_token_ids=target_token_ids,
             target_positions=target_positions,
             target_hidden_states=target_hidden_states,
@@ -251,7 +251,7 @@ def _torchair_prepare_inputs(
 
         return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
 
-    def _propose(
+    def _propose_torchair(
         self,
         # [num_tokens]
         target_token_ids: torch.Tensor,
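Note: the rename keeps the torchair proposer from sharing the _propose name while the eagle proposer's _propose signature changes in this same commit. A hypothetical illustration of the hazard being avoided, assuming the two proposers can end up in an override relationship (the class names below are placeholders, not the real vllm_ascend classes):

# Placeholder classes. If the torchair proposer overrode _propose while the
# parent narrowed its signature, shared callers would pass mismatched
# arguments to the override.
class BaseProposer:
    def _propose(self, target_token_ids):  # signature narrowed upstream
        return target_token_ids

class TorchairProposer(BaseProposer):
    # A distinct name sidesteps the override entirely; torchair callers
    # invoke it explicitly, as generate_token_ids now does.
    def _propose_torchair(self, target_token_ids, target_slot_mapping):
        return target_token_ids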
