Commit 1e798e9

Sync code with commit afc4c0c on the main branch
1 parent b80be0d commit 1e798e9

File tree

2 files changed: +372 -330 lines changed


vllm_ascend/spec_decode/eagle_v1.py

Lines changed: 86 additions & 100 deletions
@@ -1,16 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
+import os
 import torch
 import torch.nn as nn
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config, set_current_vllm_config)
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
-from vllm_ascend.attention.attention_v1 import AscendMetadata
+from vllm_ascend.attention.attention_v1 import AscendMetadata, AscendAttentionState
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm_ascend.attention.attention import AttentionMaskBuilder

@@ -23,11 +26,13 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
+        runner=None
     ):
         self.vllm_config = vllm_config
         self.speculative_config = vllm_config.speculative_config
         self.draft_model_config = self.speculative_config.draft_model_config
         self.method = self.speculative_config.method
+        self.runner = runner
         self.model_config = vllm_config.model_config
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
@@ -95,10 +100,13 @@ def propose(
         sampling_metadata: SamplingMetadata,

     ) -> torch.Tensor:
+        device = cu_num_tokens.device
+        cu_num_tokens = cu_num_tokens.cpu()
+        block_table = block_table.cpu()
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
-
+        target_positions = target_positions.cpu()
         if self.method == "eagle3":
             assert isinstance(self.model, Eagle3LlamaForCausalLM)
             target_hidden_states = self.model.combine_hidden_states(
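
Not part of the diff above, but for readers unfamiliar with the cumulative-token-count layout used in propose(): a minimal standalone sketch (with invented values) of how cu_num_tokens maps to the index of each request's last token.

import torch

# Three requests contributing 2, 3, and 1 target tokens respectively.
num_tokens_per_req = torch.tensor([2, 3, 1])
cu_num_tokens = torch.zeros(4, dtype=torch.int64)
cu_num_tokens[1:] = num_tokens_per_req.cumsum(dim=0)  # [0, 2, 5, 6]

# Index of each request's last token in the flattened token stream.
last_token_indices = cu_num_tokens[1:] - 1            # [1, 4, 5]
print(last_token_indices.tolist())
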
@@ -112,47 +120,29 @@ def propose(
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
         self.input_ids[last_token_indices] = next_token_ids[0]

-
+        query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
+        max_query_len = query_lens.max().item()
         # FA requires seq_len to have dtype int32.
         seq_lens = (target_positions[last_token_indices] + 1).int().to('cpu')

         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        max_seq_len = seq_lens.max().item()
-        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
-
-
-        # attn_mask = torch.zeros((20, 20), dtype=torch.bfloat16)
-        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
-                                              query_lens=seq_lens,
-                                              position=target_positions,
-                                              )
-
-        attn_metadata = AscendMetadata(
+        # max_seq_len = seq_lens.max().item()
+        # max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
+        attn_metadata = self.runner.attn_metadata_builder.build(
+            num_reqs=batch_size,
             num_actual_tokens=num_tokens,
-            max_query_len=max_num_tokens,
-            query_start_loc=cu_num_tokens,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            query_lens=seq_lens,
-            block_table=block_table,
-            block_tables=block_table,
-            slot_mapping=target_slot_mapping,
-            # TODO(woosuk): Support cascade attention.
+            max_query_len=max_query_len,
             common_prefix_len=0,
-            attn_mask=attn_mask,
-            cu_prefix_query_lens=None,
-            prefix_kv_lens=None,
-            suffix_kv_lens=None,
         )
         if self.use_cuda_graph and \
                 num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
         else:
             num_input_tokens = num_tokens
         # copy inputs to buffer for cudagraph
-        self.positions[:num_tokens] = target_positions
+        self.positions[:num_tokens] = target_positions.to(device)
         self.hidden_states[:num_tokens] = target_hidden_states
-        int_positions = int(target_positions[last_token_indices])
+        attn_metadata.block_tables = block_table.to(device)
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
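
Not part of the diff above: a minimal standalone sketch (with invented values) of how this hunk derives per-request query lengths and sequence lengths before handing them to the attention-metadata builder.

import torch

cu_num_tokens = torch.tensor([0, 2, 5, 6])
target_positions = torch.tensor([7, 8, 3, 4, 5, 0])          # flattened positions

query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]          # [2, 3, 1]
max_query_len = query_lens.max().item()                      # 3

# seq_len of a request is the position of its last token plus one.
last_token_indices = cu_num_tokens[1:] - 1
seq_lens = (target_positions[last_token_indices] + 1).int()  # [9, 6, 1]
print(query_lens.tolist(), max_query_len, seq_lens.tolist())
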
@@ -161,22 +151,22 @@ def propose(
                 positions=self.positions[:num_input_tokens],
                 hidden_states=self.hidden_states[:num_input_tokens],
             )
-        print(f"last_token_indices={last_token_indices}")
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)
-
+
         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
             # [batch_size, 1]
             return draft_token_ids.view(-1, 1)

         # Generate the remaining draft tokens.
-        draft_token_ids_list = [draft_token_ids]
-        draft_token_ids_tensor = torch.zeros((self.num_speculative_tokens, *draft_token_ids.shape), dtype=draft_token_ids.dtype)
+        draft_token_ids_tensor = torch.zeros(
+            (self.num_speculative_tokens, *draft_token_ids.shape),
+            dtype=draft_token_ids.dtype)
         draft_token_ids_tensor[0] = draft_token_ids
-
-        positions = target_positions[last_token_indices]
+
+        positions_cpu = target_positions[last_token_indices].cpu().to(torch.int64)
         hidden_states = hidden_states[last_token_indices]
         if self.use_cuda_graph and \
                 batch_size <= self.cudagraph_batch_sizes[-1]:
@@ -188,75 +178,73 @@ def propose(
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]

         if self.num_speculative_tokens > 2:
-            raise ValueError("Speculative tokens > 2 are not yet supported.")
+            raise ValueError("Speculative tokens > 2 are not supported yet.")

+        attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         for now_speculative in range(self.num_speculative_tokens - 1):
             # Update the inputs.
             # cast to int32 is crucial when eagle model is compiled.
             # tensor.argmax() returns int64 by default.
-            input_ids = draft_token_ids_tensor[now_speculative]
-            # input_ids = draft_token_ids_list[-1]
-            # positions += 1
-            int_positions += 1
-            positions = torch.tensor([int_positions], dtype=torch.int64, device='npu:0')
+            input_ids = draft_token_ids_tensor[now_speculative].to(device)
+            positions_cpu += 1
+

             # NOTE(woosuk): We should handle the case where the draft model
             # generates tokens beyond the max model length. Since it is complex
             # to remove such requests from the batch, we keep them in the batch
             # but adjust the position ids and slot mappings to avoid the
             # out-of-range access during the model execution. The draft tokens
             # generated with this adjustment should be ignored.
-            exceeds_max_model_len = positions >= self.max_model_len
+            exceeds_max_model_len = positions_cpu >= self.max_model_len
             # print(f"exceeds_max_model_len={exceeds_max_model_len}")
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
-            clamped_positions = torch.where(exceeds_max_model_len, 0,
-                                            positions)
+            clamped_positions_cpu = torch.where(exceeds_max_model_len, 0,
+                                                positions_cpu)
+            clamped_positions = clamped_positions_cpu.to(device)

-            # Increment the sequence lengths.
-            attn_metadata.max_seq_len += 1
+            # TODO: Increment the sequence lengths.
             # attn_metadata.max_seq_len += 1
+
             attn_metadata.seq_lens += 1
-            # Consider max model length.
-            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
-                                            self.max_model_len)
+            # TODO: Consider max model length.
+            # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+            #                                 self.max_model_len)
             # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
-            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len.to('cpu'), 1)
-
-            # block_table_indices = (req_indices * self.max_num_blocks_per_req +
-            #                        positions_np // self.block_size)
-            # block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
-            # block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
-            # block_offsets = positions_np % self.block_size
-            # np.add(block_numbers * self.block_size,
-            #        block_offsets,
-            #        out=self.slot_mapping_np[:total_num_scheduled_tokens])
+            # TODO: sequence length to 1 to minimize their overheads in attention.
+            # attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len.to('cpu'), 1)
+
             # Compute the slot mapping.
-            block_numbers = clamped_positions // self.block_size
+            block_numbers = (clamped_positions_cpu // self.block_size)
             block_ids = block_table.gather(dim=1,
                                            index=block_numbers.view(-1, 1))
             block_ids = block_ids.view(-1)
-            attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          clamped_positions % self.block_size)
+            slot_mapping_cpu = (block_ids * self.block_size +
+                                clamped_positions_cpu % self.block_size)
+
+            # attn_metadata.slot_mapping = (block_ids * self.block_size +
+            #                               clamped_positions % self.block_size)
             # Mask out the slot mappings that exceed the max model length.
             # Otherwise, the KV cache will be inadvertently updated with the
             # padding tokens.
-            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
+            # attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
+            #                                         PADDING_SLOT_ID)
+            slot_mapping_cpu.masked_fill_(exceeds_max_model_len,
                                           PADDING_SLOT_ID)
-
+            # NOTE: ASCEND slot_mapping must on cpu
+            attn_metadata.slot_mapping = slot_mapping_cpu.to(torch.int32).to(device)
             # attn_metadata.num_actual_tokens = attn_metadata.seq_lens
             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
-            # self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
-
+            positions = positions_cpu.to(device)
             attn_mask = self._make_attention_mask(seq_lens=attn_metadata.seq_lens,
                                                   query_lens=attn_metadata.max_query_len,
-                                                  position=torch.tensor([int_positions], dtype=torch.int64, device='npu:0'),
+                                                  position=positions,
                                                   )
             attn_metadata.attn_mask = attn_mask
+            attn_metadata.block_tables = block_table.to(device)
             # Run the model.
             with set_forward_context(attn_metadata,
                                      self.vllm_config,
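
Not part of the diff above: a minimal standalone sketch of the paged-KV slot-mapping computation performed inside the loop. The block table, block size, positions, and the PADDING_SLOT_ID value below are invented for illustration; in the real code they come from the runner and the module-level constant.

import torch

PADDING_SLOT_ID = -1
block_size = 4
max_model_len = 16

# One row of physical block ids per request.
block_table = torch.tensor([[10, 11, 12, 13],
                            [20, 21, 22, 23]])
positions = torch.tensor([5, 17])                          # second one overflows

exceeds_max_model_len = positions >= max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

block_numbers = clamped_positions // block_size            # logical block index
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped_positions % block_size

# Requests past the model length write to a padding slot, not the KV cache.
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
print(slot_mapping.tolist())                               # [45, -1]
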
@@ -273,15 +261,12 @@ def propose(

             # TODO(wenlong): get more than one token for tree attention
             draft_token_ids = logits.argmax(dim=-1)
-            # for _id in range(len(old_draft_token_ids_list)):
-            #     draft_token_ids_list[_id] = old_draft_token_ids_list[_id]
-            # draft_token_ids_list.append(draft_token_ids)
-            draft_token_ids_tensor[now_speculative+1] = draft_token_ids
+            draft_token_ids_tensor[now_speculative+1] = draft_token_ids.cpu()


         # [batch_size, num_speculative_tokens]
-        # draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
         draft_token_ids = draft_token_ids_tensor.swapaxes(0,1)
+        print(f"draft_token_ids_tensor={draft_token_ids}")
         return draft_token_ids

     @staticmethod
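
Not part of the diff above: a minimal standalone sketch (invented values) of the preallocated draft-token buffer pattern, filled one speculative step at a time and then transposed from [num_speculative_tokens, batch_size] to [batch_size, num_speculative_tokens] with swapaxes.

import torch

num_speculative_tokens, batch_size = 2, 3
draft_token_ids_tensor = torch.zeros(
    (num_speculative_tokens, batch_size), dtype=torch.int64)

for step in range(num_speculative_tokens):
    # Stand-in for the per-step argmax over the draft model's logits.
    draft_token_ids_tensor[step] = torch.tensor([100, 200, 300]) + step

draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)   # [batch, num_spec]
print(draft_token_ids.tolist())   # [[100, 101], [200, 201], [300, 301]]
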
@@ -326,42 +311,43 @@ def prepare_inputs(
         return cu_num_tokens, token_indices

     def load_model(self, target_model: nn.Module) -> None:
-        loader = get_model_loader(self.vllm_config.load_config)
-        target_layer_num = self.vllm_config.model_config.get_num_layers(
-            self.vllm_config.parallel_config)
+        draft_model_config = \
+            self.vllm_config.speculative_config.draft_model_config
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config, Attention).keys())

-        draft_model_config = \
-            self.vllm_config.speculative_config.draft_model_config
-        # FIXME(lily): This does not handle with distributed inference.
-        target_device = self.vllm_config.device_config.device
-        # We need to set the vllm_config here to register attention
-        # layers in the forward context.
-        with set_default_torch_dtype(
-                draft_model_config.dtype), set_current_vllm_config(
-                    self.vllm_config):
-            draft_model_cls, arch = ModelRegistry.resolve_model_cls(
-                draft_model_config.architectures)
-            self.model = draft_model_cls(
-                vllm_config=self.vllm_config,
-                start_layer_id=target_layer_num).to(target_device)
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=draft_model_config)

         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)
-        assert len(draft_attn_layer_names) == 1
+
+        self.attn_layer_names = list(draft_attn_layer_names)
         self.attn_layer_name = next(iter(draft_attn_layer_names))
-        loaded_weights = self.model.load_weights(
-            loader.get_all_weights(draft_model_config, self.model))
-        if self.vllm_config.speculative_config.method == "eagle3":
-            if "model.embed_tokens.weight" not in loaded_weights:
-                logger.info(
-                    "Loading EAGLE embedding weights from the target model.")
-                self.model.model.embed_tokens = target_model.model.embed_tokens
+        # share embed_tokens with the target model if needed
+        if get_pp_group().world_size == 1:
+            logger.info(
+                "The EAGLE head shares the same vocab embedding" \
+                " with the target model."
+            )
+            self.model.model.embed_tokens = target_model.model.embed_tokens
         else:
+            logger.info(
+                "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
+                " weights instead of sharing them with the target model."
+            )
+
+        # share lm_head with the target model if needed
+        # some model definition do not define lm_head explicitly
+        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
+        if self.vllm_config.speculative_config.method != "eagle3" and \
+                hasattr(target_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
-            self.model.lm_head = target_model.lm_head
+            if supports_multimodal(target_model):
+                self.model.lm_head = target_model.get_language_model().lm_head
+            else:
+                self.model.lm_head = target_model.lm_head

     @torch.inference_mode()
     def dummy_run(
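
Not part of the diff above: a minimal standalone sketch of how load_model isolates the attention layers that belong only to the draft model. The layer names here are invented; in the real code both sets come from get_layers_from_vllm_config(self.vllm_config, Attention).keys(), captured before and after the draft model is constructed.

# Attention layers registered by the target model.
target_attn_layer_names = {
    "model.layers.0.self_attn.attn",
    "model.layers.1.self_attn.attn",
}
# After building the draft model, the registry also contains its layers.
all_attn_layer_names = target_attn_layer_names | {
    "draft_model.layers.0.self_attn.attn",
}

draft_attn_layer_names = all_attn_layer_names - target_attn_layer_names
attn_layer_names = list(draft_attn_layer_names)
attn_layer_name = next(iter(draft_attn_layer_names))
print(attn_layer_name)   # draft_model.layers.0.self_attn.attn
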
