diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index 845e88ee63..4dd62acc52 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -21,7 +21,10 @@

 import torch
 from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context

+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
+                               maybe_converting_weight_acl_format)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


@@ -55,3 +58,58 @@ def _get_forward_metadata_across_dp_and_pad(
             maybe_padded_num_tokens = num_tokens

         return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
+        # NOTE: If torchair graph mode and not with_prefill,
+        # we can't skip_attn, it will cause graph recompile.
+        if not with_prefill:
+            attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
+                num_reqs=num_reqs, num_actual_tokens=1)
+        else:
+            attn_metadata = super()._build_attention_metadata(
+                with_prefill, num_reqs, skip_attn)
+        return attn_metadata
+
+    def _generate_dummy_run_hidden_states(self, with_prefill,
+                                          is_torchair_compile, input_ids,
+                                          positions, attn_metadata, num_tokens,
+                                          intermediate_tensors, inputs_embeds):
+
+        if not with_prefill:
+            # Only mark static while compiling
+            if is_torchair_compile:
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(get_forward_context().mc2_mask)
+                if hasattr(attn_metadata.decode, "sin"):
+                    torch._dynamo.mark_static(attn_metadata.decode.sin)
+                    torch._dynamo.mark_static(attn_metadata.decode.cos)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                if self.speculative_config:
+                    torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
+                for kv in self.kv_caches:
+                    assert isinstance(kv, tuple), "kv_cache must be a tuple"
+                    torch._dynamo.mark_static(kv[0])
+                    torch._dynamo.mark_static(kv[1])
+
+            maybe_converting_weight_acl_format(self.model,
+                                               ACL_FORMAT_FRACTAL_NZ)
+
+            compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
+            model_kwargs = {}
+            model_kwargs["kv_caches"] = self.kv_caches
+            model_kwargs["attn_metadata"] = attn_metadata
+            hidden_states = compiled_model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=None,
+                **model_kwargs,
+            )
+        else:
+            hidden_states = super()._generate_dummy_run_hidden_states(
+                with_prefill, is_torchair_compile, input_ids, positions,
+                attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
+        return hidden_states
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index d3a29852bd..a0fe9e0802 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1832,6 +1832,31 @@ def get_finished_kv_transfer(
                 scheduler_output.finished_req_ids)
         return None, None

+    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
+        if skip_attn:
+            attn_metadata = None
+        else:
+            # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
+            attn_metadata = None
+        return attn_metadata
+
+    def _generate_dummy_run_hidden_states(self, with_prefill,
+                                          is_torchair_compile, input_ids,
+                                          positions, attn_metadata, num_tokens,
+                                          intermediate_tensors, inputs_embeds):
+        maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
+        hidden_states = self.model(input_ids=input_ids,
+                                   positions=positions,
+                                   intermediate_tensors=intermediate_tensors,
+                                   inputs_embeds=inputs_embeds)
+        if self.use_aux_hidden_state_outputs:
+            hidden_states, _ = hidden_states
+        else:
+            hidden_states = hidden_states
+        if self.use_spec_decode and isinstance(self.drafter, EagleProposer):
+            self.drafter.dummy_run(num_tokens)
+        return hidden_states
+
     @torch.inference_mode()
     def _dummy_run(
         self,
@@ -1868,20 +1893,11 @@
         if self.is_kv_producer:
             with_prefill = True

-        # NOTE: If torchair graph mode and not with_prefill,
-        # we can't skip_attn, it will cause graph recompile.
-        if self.torchair_graph_enabled and not with_prefill:
-            attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_reqs, num_actual_tokens=1)
-        elif skip_attn:
-            attn_metadata = None
-        else:
-            # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
-            attn_metadata = None
+        attn_metadata = self._build_attention_metadata(with_prefill, num_reqs,
+                                                       skip_attn)

         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            model = self.model
             if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -1917,61 +1933,10 @@
                     in_profile_run=self.in_profile_run,
                     num_actual_tokens=0,
             ):
-                model_kwargs = {}
-                if self.torchair_graph_enabled and not with_prefill:
-                    # Only mark static while compiling
-                    if is_torchair_compile:
-                        torch._dynamo.mark_static(input_ids)
-                        torch._dynamo.mark_static(positions)
-                        torch._dynamo.mark_static(
-                            attn_metadata.decode.block_table)
-                        torch._dynamo.mark_static(
-                            attn_metadata.decode.input_positions)
-                        torch._dynamo.mark_static(
-                            get_forward_context().mc2_mask)
-                        if hasattr(attn_metadata.decode, "sin"):
-                            torch._dynamo.mark_static(attn_metadata.decode.sin)
-                            torch._dynamo.mark_static(attn_metadata.decode.cos)
-                        torch._dynamo.mark_static(attn_metadata.slot_mapping)
-                        if self.speculative_config:
-                            torch._dynamo.mark_static(
-                                attn_metadata.decode.attn_mask)
-                        for kv in self.kv_caches:
-                            assert isinstance(
-                                kv, tuple), "kv_cache must be a tuple"
-                            torch._dynamo.mark_static(kv[0])
-                            torch._dynamo.mark_static(kv[1])
-
-                    maybe_converting_weight_acl_format(self.model,
-                                                       ACL_FORMAT_FRACTAL_NZ)
-
-                    compiled_model = self._get_torchair_lazy_compiled_model(
-                        num_tokens)
-                    model_kwargs["kv_caches"] = self.kv_caches
-                    model_kwargs["attn_metadata"] = attn_metadata
-                    hidden_states = compiled_model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=None,
-                        **model_kwargs,
-                    )
-                else:
-                    maybe_converting_weight_acl_format(self.model,
-                                                       ACL_FORMAT_FRACTAL_ND)
-
-                    hidden_states = model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=inputs_embeds)
-                    if self.use_aux_hidden_state_outputs:
-                        hidden_states, _ = hidden_states
-                    else:
-                        hidden_states = hidden_states
-                    if self.use_spec_decode and isinstance(
-                            self.drafter, EagleProposer):
-                        self.drafter.dummy_run(num_tokens)
+                hidden_states = self._generate_dummy_run_hidden_states(
+                    with_prefill, is_torchair_compile, input_ids, positions,
+                    attn_metadata, num_tokens, intermediate_tensors,
+                    inputs_embeds)
                if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                     assert isinstance(self.drafter, MtpProposer)
                     self.drafter.dummy_run(
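Net effect of the diff: the torchair-specific branch that used to live inline in NPUModelRunner._dummy_run now sits behind two overridable hooks, _build_attention_metadata and _generate_dummy_run_hidden_states, which NPUTorchairModelRunner specializes for the decode-only graph path. Below is a minimal, self-contained sketch of that dispatch shape only; BaseRunner/TorchairRunner and their toy return values are placeholders, not the real vllm-ascend runner code.

class BaseRunner:
    """Simplified stand-in for NPUModelRunner (sketch only)."""

    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
        # Default eager/aclgraph path: no dummy attention metadata is built.
        return None

    def _generate_dummy_run_hidden_states(self, with_prefill, is_torchair_compile,
                                          input_ids, positions, attn_metadata,
                                          num_tokens, intermediate_tensors,
                                          inputs_embeds):
        # Default path: run the plain (uncompiled) model once.
        return f"eager forward, {num_tokens} tokens"

    def _dummy_run(self, num_tokens, num_reqs=1, with_prefill=False, skip_attn=True):
        # Template method: the base class owns the control flow, subclasses
        # specialize the two hooks above.
        attn_metadata = self._build_attention_metadata(with_prefill, num_reqs,
                                                       skip_attn)
        return self._generate_dummy_run_hidden_states(
            with_prefill, False, None, None, attn_metadata, num_tokens, None, None)


class TorchairRunner(BaseRunner):
    """Simplified stand-in for NPUTorchairModelRunner (sketch only)."""

    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
        # Decode-only torchair graph mode always builds dummy metadata so the
        # captured graph sees stable inputs and does not recompile.
        if not with_prefill:
            return {"decode_dummy": True, "num_reqs": num_reqs}
        return super()._build_attention_metadata(with_prefill, num_reqs, skip_attn)

    def _generate_dummy_run_hidden_states(self, with_prefill, is_torchair_compile,
                                          input_ids, positions, attn_metadata,
                                          num_tokens, intermediate_tensors,
                                          inputs_embeds):
        if not with_prefill:
            return f"torchair graph forward, {num_tokens} tokens, meta={attn_metadata}"
        return super()._generate_dummy_run_hidden_states(
            with_prefill, is_torchair_compile, input_ids, positions, attn_metadata,
            num_tokens, intermediate_tensors, inputs_embeds)


if __name__ == "__main__":
    runner = TorchairRunner()
    print(runner._dummy_run(num_tokens=8, with_prefill=False))  # torchair decode path
    print(runner._dummy_run(num_tokens=8, with_prefill=True))   # falls back to the base path

With this split, _dummy_run in model_runner_v1.py no longer needs any torchair_graph_enabled checks; the subclass decides per call whether to take the compiled-graph path or defer to super().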