
Commit 48b296e

revert eagle
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent c1c0db7

File tree

1 file changed: +15 -26 lines changed


vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 15 additions & 26 deletions
```diff
@@ -32,27 +32,16 @@ def __init__(self,
                  device: torch.device,
                  runner=None):
         self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3
-        self.device = device
         self.vllm_config = vllm_config
-        self.speculative_config = vllm_config.speculative_config
-        self.draft_model_config = self.speculative_config.draft_model_config
-        self.method = self.speculative_config.method
-
+        self.device = device
         self.runner = runner
-        self.dtype = vllm_config.model_config.dtype
-        self.max_model_len = vllm_config.model_config.max_model_len
-        self.block_size = vllm_config.cache_config.block_size
-        self.num_speculative_tokens = (
-            self.speculative_config.num_speculative_tokens)
-        self.max_num_tokens = (
-            vllm_config.scheduler_config.max_num_batched_tokens)
-        self.token_arange_np = np.arange(self.max_num_tokens)
 
         self.block_size = vllm_config.cache_config.block_size
         # We need to get the hidden size from the draft model config because
         # the draft model's hidden size can be different from the target model's
         # hidden size (e.g., Llama 3.3 70B).
-        self.hidden_size = self.draft_model_config.get_hidden_size()
+        self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
+        )
 
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
```
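The comment in this hunk carries the rationale: buffers that hold draft-model activations must be sized with the draft model's hidden size, which can differ from the target's. A minimal sketch of that constraint, using a placeholder config class (`ModelConfig` here is illustrative, not the real vLLM type):

```python
from dataclasses import dataclass

import torch

# Illustrative stand-in for the config objects; not the actual vLLM classes.
@dataclass
class ModelConfig:
    hidden_size: int

    def get_hidden_size(self) -> int:
        return self.hidden_size

target = ModelConfig(hidden_size=8192)  # e.g. a large target model
draft = ModelConfig(hidden_size=4096)   # the draft model can be smaller

# Activation buffers for the draft model must use the draft's hidden size;
# sizing them from the target model would allocate the wrong shape.
hidden_states = torch.zeros((16, draft.get_hidden_size()))
assert hidden_states.shape == (16, 4096)
```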
@@ -62,15 +51,18 @@ def __init__(self,
6251
self.vllm_config.compilation_config.cudagraph_capture_sizes))
6352

6453
# persistent buffers for cuda graph
65-
self.input_ids = torch.zeros(self.max_num_tokens,
66-
dtype=torch.int32,
67-
device=device)
68-
self.positions = torch.zeros(self.max_num_tokens,
69-
dtype=torch.int64,
70-
device=device)
54+
self.input_ids = torch.zeros(
55+
self.vllm_config.scheduler_config.max_num_batched_tokens,
56+
dtype=torch.int32,
57+
device=device)
58+
self.positions = torch.zeros(
59+
self.vllm_config.scheduler_config.max_num_batched_tokens,
60+
dtype=torch.int64,
61+
device=device)
7162
self.hidden_states = torch.zeros(
72-
(self.max_num_tokens, self.hidden_size),
73-
dtype=self.dtype,
63+
(self.vllm_config.scheduler_config.max_num_batched_tokens,
64+
self.hidden_size),
65+
dtype=self.vllm_config.model_config.dtype,
7466
device=device)
7567
# We need +1 here because the arange is used to set query_start_loc,
7668
# which has one more element than batch_size.
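These buffers follow the usual graph-capture pattern: allocate worst-case shapes (`max_num_batched_tokens` rows) once in `__init__`, then write into a prefix slice on every step so captured graphs replay against the same storage. A minimal sketch of that pattern; the shapes and dtypes come from the hunk above, while the concrete values and the slicing usage are illustrative assumptions:

```python
import torch

max_num_batched_tokens = 2048  # illustrative; comes from scheduler_config
device = torch.device("cpu")   # stands in for the NPU device

# Persistent, worst-case-shaped buffers allocated once (as in the hunk).
input_ids = torch.zeros(max_num_batched_tokens, dtype=torch.int32, device=device)
positions = torch.zeros(max_num_batched_tokens, dtype=torch.int64, device=device)

# Per step, only the first num_tokens entries are filled and consumed, so a
# captured graph always sees the same underlying buffers.
num_tokens = 7
input_ids[:num_tokens] = torch.ones(num_tokens, dtype=torch.int32)
positions[:num_tokens] = torch.arange(num_tokens, dtype=torch.int64)
```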
@@ -406,17 +398,14 @@ def _propose(
406398
# [batch_size, max_num_blocks_per_req]
407399
block_table: torch.Tensor,
408400
sampling_metadata: SamplingMetadata,
409-
last_token_indices: Optional[torch.Tensor],
410401
) -> torch.Tensor:
411402
device = cu_num_tokens.device
412403
cu_num_tokens = cu_num_tokens.cpu()
413404
block_table = block_table.cpu()
414405
num_tokens = target_token_ids.shape[0]
415406
batch_size = next_token_ids.shape[0]
416-
if last_token_indices is None:
417-
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
407+
last_token_indices = cu_num_tokens[1:] - 1
418408
target_positions = target_positions.cpu()
419-
420409
if self.name == SpecDcodeType.EAGLE3:
421410
assert isinstance(self.model, Eagle3LlamaForCausalLM)
422411
target_hidden_states = self.model.combine_hidden_states(
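The reverted line recovers each request's last-token position from the cumulative token counts rather than taking it as a parameter. A quick worked example (the tensor values are made up for illustration):

```python
import torch

# Cumulative token counts with a leading zero: three requests contributing
# 3, 2, and 4 tokens to the flattened batch.
cu_num_tokens = torch.tensor([0, 3, 5, 9])

# Each request's last token sits one position before the next request's
# start offset, hence cu_num_tokens[1:] - 1.
last_token_indices = cu_num_tokens[1:] - 1
print(last_token_indices)  # tensor([2, 4, 8])
```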
