 # SPDX-License-Identifier: Apache-2.0
+import os
 import torch
 import torch.nn as nn
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config, set_current_vllm_config)
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
-from vllm_ascend.attention.attention_v1 import AscendMetadata
+from vllm_ascend.attention.attention_v1 import AscendMetadata, AscendAttentionState
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm_ascend.attention.attention import AttentionMaskBuilder

@@ -23,11 +26,13 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         device: torch.device,
+        runner=None
     ):
         self.vllm_config = vllm_config
         self.speculative_config = vllm_config.speculative_config
         self.draft_model_config = self.speculative_config.draft_model_config
         self.method = self.speculative_config.method
+        self.runner = runner
         self.model_config = vllm_config.model_config
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
@@ -95,10 +100,13 @@ def propose(
         sampling_metadata: SamplingMetadata,

     ) -> torch.Tensor:
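+        # Keep a handle to the accelerator device; the bookkeeping tensors
+        # below are moved to CPU for host-side metadata computation.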
+        device = cu_num_tokens.device
+        cu_num_tokens = cu_num_tokens.cpu()
+        block_table = block_table.cpu()
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
-
+        target_positions = target_positions.cpu()
         if self.method == "eagle3":
             assert isinstance(self.model, Eagle3LlamaForCausalLM)
             target_hidden_states = self.model.combine_hidden_states(
@@ -112,47 +120,29 @@ def propose(
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
         self.input_ids[last_token_indices] = next_token_ids[0]

-
+        query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
+        max_query_len = query_lens.max().item()
         # FA requires seq_len to have dtype int32.
         seq_lens = (target_positions[last_token_indices] + 1).int().to('cpu')

         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        max_seq_len = seq_lens.max().item()
-        max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
-
-
-        # attn_mask = torch.zeros((20, 20), dtype=torch.bfloat16)
-        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
-                                              query_lens=seq_lens,
-                                              position=target_positions,
-                                              )
-
-        attn_metadata = AscendMetadata(
+        # max_seq_len = seq_lens.max().item()
+        # max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
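+        # Delegate attention metadata construction to the runner's Ascend
+        # metadata builder instead of filling AscendMetadata by hand.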
+        attn_metadata = self.runner.attn_metadata_builder.build(
+            num_reqs=batch_size,
             num_actual_tokens=num_tokens,
-            max_query_len=max_num_tokens,
-            query_start_loc=cu_num_tokens,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            query_lens=seq_lens,
-            block_table=block_table,
-            block_tables=block_table,
-            slot_mapping=target_slot_mapping,
-            # TODO(woosuk): Support cascade attention.
+            max_query_len=max_query_len,
             common_prefix_len=0,
         )
         if self.use_cuda_graph and \
                 num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
         else:
             num_input_tokens = num_tokens
         # copy inputs to buffer for cudagraph
-        self.positions[:num_tokens] = target_positions
+        self.positions[:num_tokens] = target_positions.to(device)
         self.hidden_states[:num_tokens] = target_hidden_states
-        int_positions = int(target_positions[last_token_indices])
+        attn_metadata.block_tables = block_table.to(device)
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
@@ -161,22 +151,22 @@ def propose(
                 positions=self.positions[:num_input_tokens],
                 hidden_states=self.hidden_states[:num_input_tokens],
             )
-        print(f"last_token_indices={last_token_indices}")
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)
-
+
         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
             # [batch_size, 1]
             return draft_token_ids.view(-1, 1)

         # Generate the remaining draft tokens.
-        draft_token_ids_list = [draft_token_ids]
-        draft_token_ids_tensor = torch.zeros((self.num_speculative_tokens, *draft_token_ids.shape), dtype=draft_token_ids.dtype)
+        draft_token_ids_tensor = torch.zeros(
+            (self.num_speculative_tokens, *draft_token_ids.shape),
+            dtype=draft_token_ids.dtype)
         draft_token_ids_tensor[0] = draft_token_ids
-
-        positions = target_positions[last_token_indices]
+
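+        # Keep the per-request positions on CPU as int64; they are incremented
+        # and clamped on the host inside the speculative loop below.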
+        positions_cpu = target_positions[last_token_indices].cpu().to(torch.int64)
         hidden_states = hidden_states[last_token_indices]
         if self.use_cuda_graph and \
                 batch_size <= self.cudagraph_batch_sizes[-1]:
@@ -188,75 +178,73 @@ def propose(
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]

         if self.num_speculative_tokens > 2:
-            raise ValueError("Speculative tokens > 2 are not yet supported.")
+            raise ValueError("Speculative tokens > 2 are not supported yet.")

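+        # The extra speculative steps below run through the chunked-prefill
+        # attention path.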
+        attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         for now_speculative in range(self.num_speculative_tokens - 1):
             # Update the inputs.
             # cast to int32 is crucial when eagle model is compiled.
             # tensor.argmax() returns int64 by default.
-            input_ids = draft_token_ids_tensor[now_speculative]
-            # input_ids = draft_token_ids_list[-1]
-            # positions += 1
-            int_positions += 1
-            positions = torch.tensor([int_positions], dtype=torch.int64, device='npu:0')
+            input_ids = draft_token_ids_tensor[now_speculative].to(device)
+            positions_cpu += 1
+

             # NOTE(woosuk): We should handle the case where the draft model
             # generates tokens beyond the max model length. Since it is complex
             # to remove such requests from the batch, we keep them in the batch
             # but adjust the position ids and slot mappings to avoid the
             # out-of-range access during the model execution. The draft tokens
             # generated with this adjustment should be ignored.
-            exceeds_max_model_len = positions >= self.max_model_len
+            exceeds_max_model_len = positions_cpu >= self.max_model_len
             # print(f"exceeds_max_model_len={exceeds_max_model_len}")
             # Mask out the position ids that exceed the max model length.
             # Otherwise, we may get out-of-range error in RoPE.
-            clamped_positions = torch.where(exceeds_max_model_len, 0,
-                                            positions)
+            clamped_positions_cpu = torch.where(exceeds_max_model_len, 0,
+                                                positions_cpu)
+            clamped_positions = clamped_positions_cpu.to(device)

-            # Increment the sequence lengths.
-            attn_metadata.max_seq_len += 1
+            # TODO: Increment the sequence lengths.
             # attn_metadata.max_seq_len += 1
+
             attn_metadata.seq_lens += 1
-            # Consider max model length.
-            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
-                                            self.max_model_len)
+            # TODO: Consider max model length.
+            # attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+            #                                 self.max_model_len)
             # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
-            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len.to('cpu'), 1)
-
-            # block_table_indices = (req_indices * self.max_num_blocks_per_req +
-            #                        positions_np // self.block_size)
-            # block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
-            # block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
-            # block_offsets = positions_np % self.block_size
-            # np.add(block_numbers * self.block_size,
-            #        block_offsets,
-            #        out=self.slot_mapping_np[:total_num_scheduled_tokens])
+            # TODO: sequence length to 1 to minimize their overheads in attention.
+            # attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len.to('cpu'), 1)
+
             # Compute the slot mapping.
-            block_numbers = clamped_positions // self.block_size
+            block_numbers = (clamped_positions_cpu // self.block_size)
             block_ids = block_table.gather(dim=1,
                                            index=block_numbers.view(-1, 1))
             block_ids = block_ids.view(-1)
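+            # A slot address is block_id * block_size + the token's offset
+            # within that block.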
-            attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          clamped_positions % self.block_size)
+            slot_mapping_cpu = (block_ids * self.block_size +
+                                clamped_positions_cpu % self.block_size)
+
+            # attn_metadata.slot_mapping = (block_ids * self.block_size +
+            #                               clamped_positions % self.block_size)
             # Mask out the slot mappings that exceed the max model length.
             # Otherwise, the KV cache will be inadvertently updated with the
             # padding tokens.
-            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
+            # attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
+            #                                         PADDING_SLOT_ID)
+            slot_mapping_cpu.masked_fill_(exceeds_max_model_len,
                                           PADDING_SLOT_ID)
-
+            # NOTE: on Ascend, the slot mapping must be computed on CPU
+            attn_metadata.slot_mapping = slot_mapping_cpu.to(torch.int32).to(device)
             # attn_metadata.num_actual_tokens = attn_metadata.seq_lens
             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
-            # self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
-
+            positions = positions_cpu.to(device)
             attn_mask = self._make_attention_mask(seq_lens=attn_metadata.seq_lens,
                                                   query_lens=attn_metadata.max_query_len,
-                                                  position=torch.tensor([int_positions], dtype=torch.int64, device='npu:0'),
+                                                  position=positions,
                                                   )
             attn_metadata.attn_mask = attn_mask
+            attn_metadata.block_tables = block_table.to(device)
             # Run the model.
             with set_forward_context(attn_metadata,
                                      self.vllm_config,
@@ -273,15 +261,12 @@ def propose(

             # TODO(wenlong): get more than one token for tree attention
             draft_token_ids = logits.argmax(dim=-1)
-            # for _id in range(len(old_draft_token_ids_list)):
-            #     draft_token_ids_list[_id] = old_draft_token_ids_list[_id]
-            # draft_token_ids_list.append(draft_token_ids)
-            draft_token_ids_tensor[now_speculative + 1] = draft_token_ids
+            draft_token_ids_tensor[now_speculative + 1] = draft_token_ids.cpu()


         # [batch_size, num_speculative_tokens]
-        # draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
         draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
+        print(f"draft_token_ids_tensor={draft_token_ids}")
         return draft_token_ids

     @staticmethod
@@ -326,42 +311,43 @@ def prepare_inputs(
         return cu_num_tokens, token_indices

     def load_model(self, target_model: nn.Module) -> None:
-        loader = get_model_loader(self.vllm_config.load_config)
-        target_layer_num = self.vllm_config.model_config.get_num_layers(
-            self.vllm_config.parallel_config)
+        draft_model_config = \
+            self.vllm_config.speculative_config.draft_model_config
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config, Attention).keys())

-        draft_model_config = \
-            self.vllm_config.speculative_config.draft_model_config
-        # FIXME(lily): This does not handle with distributed inference.
-        target_device = self.vllm_config.device_config.device
-        # We need to set the vllm_config here to register attention
-        # layers in the forward context.
-        with set_default_torch_dtype(
-                draft_model_config.dtype), set_current_vllm_config(
-                    self.vllm_config):
-            draft_model_cls, arch = ModelRegistry.resolve_model_cls(
-                draft_model_config.architectures)
-            self.model = draft_model_cls(
-                vllm_config=self.vllm_config,
-                start_layer_id=target_layer_num).to(target_device)
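+        # Load the draft model through vLLM's generic model loader using the
+        # draft model config.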
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=draft_model_config)

         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
             target_attn_layer_names)
-        assert len(draft_attn_layer_names) == 1
+
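+        # Attention layers that belong only to the draft (EAGLE) head.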
+        self.attn_layer_names = list(draft_attn_layer_names)
         self.attn_layer_name = next(iter(draft_attn_layer_names))
-        loaded_weights = self.model.load_weights(
-            loader.get_all_weights(draft_model_config, self.model))
-        if self.vllm_config.speculative_config.method == "eagle3":
-            if "model.embed_tokens.weight" not in loaded_weights:
-                logger.info(
-                    "Loading EAGLE embedding weights from the target model.")
-                self.model.model.embed_tokens = target_model.model.embed_tokens
+        # share embed_tokens with the target model if needed
+        if get_pp_group().world_size == 1:
+            logger.info(
+                "The EAGLE head shares the same vocab embedding" \
+                " with the target model."
+            )
+            self.model.model.embed_tokens = target_model.model.embed_tokens
         else:
+            logger.info(
+                "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
+                " weights instead of sharing them with the target model."
+            )
+
+        # share lm_head with the target model if needed
+        # some model definitions do not define lm_head explicitly
+        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
+        if self.vllm_config.speculative_config.method != "eagle3" and \
+                hasattr(target_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
-            self.model.lm_head = target_model.lm_head
+            if supports_multimodal(target_model):
+                self.model.lm_head = target_model.get_language_model().lm_head
+            else:
+                self.model.lm_head = target_model.lm_head

     @torch.inference_mode()
     def dummy_run(