Merged
58 changes: 58 additions & 0 deletions vllm_ascend/torchair/torchair_model_runner.py
@@ -21,7 +21,10 @@

import torch
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context

from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
maybe_converting_weight_acl_format)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


@@ -55,3 +58,58 @@ def _get_forward_metadata_across_dp_and_pad(
maybe_padded_num_tokens = num_tokens

return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo

def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
# NOTE: In torchair graph mode without prefill we cannot skip building
# attention metadata; skipping it would trigger a graph recompile.
if not with_prefill:
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_reqs, num_actual_tokens=1)
else:
attn_metadata = super()._build_attention_metadata(
with_prefill, num_reqs, skip_attn)
return attn_metadata

def _generate_dummy_run_hidden_states(self, with_prefill,
is_torchair_compile, input_ids,
positions, attn_metadata, num_tokens,
intermediate_tensors, inputs_embeds):

if not with_prefill:
# Only mark static while compiling
if is_torchair_compile:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
if self.speculative_config:
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
for kv in self.kv_caches:
assert isinstance(kv, tuple), "kv_cache must be a tuple"
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])

maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_NZ)

compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
model_kwargs = {}
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)
else:
hidden_states = super()._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
return hidden_states
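
The two overrides above pair with hooks of the same name added to NPUModelRunner in model_runner_v1.py below: _dummy_run now always calls _build_attention_metadata and _generate_dummy_run_hidden_states, and the torchair runner overrides those hooks instead of _dummy_run branching on torchair_graph_enabled. The following is a minimal, self-contained sketch of that hook split — toy classes and placeholder return values, not the actual vllm-ascend runners:

```python
from typing import Any, Optional


class BaseRunner:
    """Stands in for the default path added to model_runner_v1.py."""

    def _build_attention_metadata(self, with_prefill: bool, num_reqs: int,
                                  skip_attn: bool) -> Optional[Any]:
        # The default dummy run carries no attention metadata.
        return None

    def _generate_dummy_run_hidden_states(self, attn_metadata: Optional[Any],
                                          num_tokens: int) -> str:
        return f"eager dummy run, {num_tokens} tokens, metadata={attn_metadata}"

    def _dummy_run(self, num_reqs: int, num_tokens: int,
                   with_prefill: bool = False, skip_attn: bool = True) -> str:
        # _dummy_run no longer branches on the graph backend; it only calls
        # the two hooks, which a subclass may override.
        attn_metadata = self._build_attention_metadata(
            with_prefill, num_reqs, skip_attn)
        return self._generate_dummy_run_hidden_states(attn_metadata, num_tokens)


class TorchairRunner(BaseRunner):
    """Stands in for the override in torchair_model_runner.py."""

    def _build_attention_metadata(self, with_prefill: bool, num_reqs: int,
                                  skip_attn: bool) -> Optional[Any]:
        if not with_prefill:
            # A decode-only dummy run must carry graph-shaped metadata so the
            # compiled graph is not recompiled.
            return {"num_reqs": num_reqs, "num_actual_tokens": 1}
        return super()._build_attention_metadata(with_prefill, num_reqs,
                                                 skip_attn)


print(BaseRunner()._dummy_run(num_reqs=4, num_tokens=4))
print(TorchairRunner()._dummy_run(num_reqs=4, num_tokens=4))
```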
97 changes: 31 additions & 66 deletions vllm_ascend/worker/model_runner_v1.py
@@ -1832,6 +1832,31 @@ def get_finished_kv_transfer(
scheduler_output.finished_req_ids)
return None, None

def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
if skip_attn:
attn_metadata = None
else:
# TODO(zzzzwwjj): build attn_metadata here when running aclgraph in full graph mode
attn_metadata = None
return attn_metadata

def _generate_dummy_run_hidden_states(self, with_prefill,
is_torchair_compile, input_ids,
positions, attn_metadata, num_tokens,
intermediate_tensors, inputs_embeds):
maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
hidden_states = self.model(input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
if self.use_aux_hidden_state_outputs:
hidden_states, _ = hidden_states
else:
hidden_states = hidden_states
if self.use_spec_decode and isinstance(self.drafter, EagleProposer):
self.drafter.dummy_run(num_tokens)
return hidden_states

@torch.inference_mode()
def _dummy_run(
self,
@@ -1868,20 +1893,11 @@ def _dummy_run(
if self.is_kv_producer:
with_prefill = True

# NOTE: If torchair graph mode and not with_prefill,
# we can't skip_attn, it will cause graph recompile.
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_reqs, num_actual_tokens=1)
elif skip_attn:
attn_metadata = None
else:
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
attn_metadata = None
attn_metadata = self._build_attention_metadata(with_prefill, num_reqs,
skip_attn)

with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model = self.model
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -1917,61 +1933,10 @@ def _dummy_run(
in_profile_run=self.in_profile_run,
num_actual_tokens=0,
):
model_kwargs = {}
if self.torchair_graph_enabled and not with_prefill:
# Only mark static while compiling
if is_torchair_compile:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(
attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
torch._dynamo.mark_static(
get_forward_context().mc2_mask)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
if self.speculative_config:
torch._dynamo.mark_static(
attn_metadata.decode.attn_mask)
for kv in self.kv_caches:
assert isinstance(
kv, tuple), "kv_cache must be a tuple"
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])

maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_NZ)

compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)
else:
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_ND)

hidden_states = model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
if self.use_aux_hidden_state_outputs:
hidden_states, _ = hidden_states
else:
hidden_states = hidden_states
if self.use_spec_decode and isinstance(
self.drafter, EagleProposer):
self.drafter.dummy_run(num_tokens)
hidden_states = self._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors,
inputs_embeds)
if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
assert isinstance(self.drafter, MtpProposer)
self.drafter.dummy_run(
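
For context on the is_torchair_compile branch above: torch._dynamo.mark_static tells Dynamo to specialize on a tensor's current shape instead of guarding on it as dynamic, so repeated dummy runs do not trigger recompilation of the captured graph. Below is a minimal, self-contained sketch of that call pattern using plain torch.compile rather than torchair, with made-up tensors standing in for input_ids and positions; it assumes a PyTorch build that provides torch._dynamo.mark_static:

```python
import torch


def toy_decode_step(input_ids: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # Placeholder for a decode forward pass.
    return input_ids.float() * positions.float()


input_ids = torch.zeros(8, dtype=torch.long)
positions = torch.arange(8)

# Analogous to the is_torchair_compile branch in the diff: mark the dummy-run
# inputs static before compiling so their shapes are baked into the graph.
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)

compiled = torch.compile(toy_decode_step)

# Reusing the same static shapes keeps hitting the already-compiled graph.
for _ in range(3):
    out = compiled(input_ids, positions)
print(out.shape)
```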