diff --git a/tests/singlecard/test_aclgraph.py b/tests/singlecard/test_aclgraph.py
index e0bfb65cf8..fb02555956 100644
--- a/tests/singlecard/test_aclgraph.py
+++ b/tests/singlecard/test_aclgraph.py
@@ -36,9 +36,11 @@
                     reason="aclgraph only support on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("full_graph", [False])
 def test_models(
     model: str,
     max_tokens: int,
+    full_graph: bool,
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     with monkeypatch.context() as m:
@@ -54,7 +56,15 @@ def test_models(
                                      temperature=0.0)
     # TODO: change to use vllmrunner when the registry of custom op is solved
     # while running pytest
-    vllm_model = LLM(model)
+    if full_graph:
+        vllm_model = LLM(model,
+                         compilation_config={
+                             "full_cuda_graph": True,
+                             "cudagraph_capture_sizes":
+                             [1, 4, 16, 64, 256]
+                         })
+    else:
+        vllm_model = LLM(model)
     vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
     del vllm_model
     torch.npu.empty_cache()
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 4a73d4b14e..75fd71c859 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -55,6 +55,10 @@ def set_ascend_forward_context(
         forward_context.in_profile_run = in_profile_run
 
+        # NOTE: This cannot be set using set_forward_context
+        # due to multiple warmups before actual capturing
+        forward_context.capturing = False
+
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
             forward_context.max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index adb6de2af4..3417bb87fb 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -24,12 +24,16 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import get_graph_params
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -114,6 +118,7 @@ class AscendMetadata:
     query_start_loc: torch.Tensor
     query_lens: torch.Tensor
     seq_lens: torch.Tensor
+    seq_lens_list: list
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
     # (num_tokens,). The indices of the token slots that input tokens will be
@@ -149,23 +154,26 @@ def build(self,
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              common_prefix_len,
-              enable_dbo_across_dp: bool = False):
+              common_attn_metadata: CommonAttentionMetadata,
+              enable_dbo_across_dp: bool = False,
+              *args,
+              **kwargs):
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
         block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
             block_table[:num_reqs])
 
-        query_lens = self.runner.query_lens
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
+        # TODO: Refactor these two params into common metadata in runners,
+        # preparing for the hybrid KV groups feature
+        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
+        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+
+        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -173,6 +181,7 @@
             query_start_loc=query_start_loc,
             query_lens=query_lens,
             seq_lens=seq_lens,
+            seq_lens_list=seq_lens_list,
             max_query_len=max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
@@ -180,6 +189,34 @@
             enable_dbo_across_dp=enable_dbo_across_dp)
         return attn_metadata
 
+    def build_dummy_metadata(self, num_actual_tokens, num_reqs,
+                             num_scheduled_tokens, attn_state):
+        if attn_state == AscendAttentionState.DecodeOnly:
+            # NOTE: We only need to pay attention to seq_lens_list and block_table here
+            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
+                                                           num_reqs)
+
+            block_table = self.runner.input_batch.block_table[0].block_table
+            block_table[:num_reqs, 0] = torch.arange(1,
+                                                     num_reqs + 1,
+                                                     device=block_table.device,
+                                                     dtype=block_table.dtype)
+
+            attn_metadata = self.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=num_actual_tokens,
+                max_query_len=num_scheduled_tokens.max(),
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+
 
 class AscendAttentionBackendImpl(AttentionImpl):
 
@@ -217,6 +254,10 @@ def __init__(
         self.key_cache = None
         self.value_cache = None
 
+        vllm_config = get_current_vllm_config()
+        self.full_graph = vllm_config.compilation_config.full_cuda_graph
+        self.block_size = vllm_config.cache_config.block_size
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -228,21 +269,7 @@ def forward(
         output: Optional[torch.Tensor] = None,
         trace_flag: bool = True,
     ) -> torch.Tensor:
-        """Forward pass with Ascend attention.
-        Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            kv_cache: shape = [2, num_blocks, block_size,
-                               num_kv_heads, head_size]
-                      key_cache = [num_blocks, block_size,
-                                   num_kv_heads, head_size]
-                      value_cache = [num_blocks, block_size,
-                                     num_kv_heads, head_size]
-            attn_metadata: Metadata for attention.
-        Returns:
-            shape = [batch_size * seq_len, num_heads, head_size]
-        """
+        """Forward pass with Ascend attention."""
         num_tokens = query.shape[0]
         if output is None:
             output = torch.empty(num_tokens,
@@ -322,16 +349,92 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            if self.full_graph:
+                graph_params = get_graph_params()
+                q = query.view(num_tokens, -1, self.hidden_size)
+                k = self.key_cache.view(  # type: ignore
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
+                v = self.value_cache.view(  # type: ignore
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
+                actual_seq_lens = attn_metadata.seq_lens_list
+                attn_args = {
+                    "query": q,
+                    "key": k,
+                    "value": v,
+                    "actual_seq_lengths_kv": actual_seq_lens,
+                    "block_table": attn_metadata.block_tables,
+                    "num_heads": self.num_heads,
+                    "scale": self.scale,
+                    "input_layout": "BSH",
+                    "num_key_value_heads": self.num_kv_heads,
+                    "block_size": self.block_size,
+                }
+
+                # Prepare tensors for attention output
+                # TODO: Refactor this to step-level instead of layer-level
+                attn_output = torch.empty(num_tokens,
+                                          1,
+                                          self.hidden_size,
+                                          dtype=output.dtype,
+                                          device=output.device)
+                softmax_lse = torch.empty(num_tokens,
+                                          dtype=output.dtype,
+                                          device=output.device)
+
+                # Get workspace from cache or calculate it if not present.
+                workspace = graph_params.workspaces.get(num_tokens)
+                if workspace is None:
+                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        **attn_args)
+                    graph_params.workspaces[num_tokens] = workspace
+
+                forward_context = get_forward_context()
+                if not forward_context.capturing:
+                    # Execute attention kernel directly in non-capturing mode
+                    torch.ops.npu.npu_fused_infer_attention_score.out(
+                        workspace=workspace,
+                        out=[attn_output, softmax_lse],
+                        **attn_args)
+                else:
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append(
+                        (q, k, v, actual_seq_lens,
+                         attn_metadata.block_tables, self.num_heads,
+                         self.scale, self.num_kv_heads, attn_output,
+                         softmax_lse))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch.ops.npu.npu_fused_infer_attention_score.out(
+                        workspace=workspace,
+                        out=[attn_output, softmax_lse],
+                        **attn_args)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+
+                # Reshape output to match the expected format
+                output.copy_(
+                    attn_output.view(num_tokens, self.num_heads,
+                                     self.head_size))
+            else:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
         # Normal V1 situation.
         else:
             # use chunked prefill for head size 192 scenario, like deepseek
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 98f0a3389c..816d93c028 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -16,6 +16,8 @@
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -28,20 +30,6 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 
-@dataclass
-class CommonAttentionMetadata:
-    """
-    Attention metadata attributes that can be shared by layers in different KV
-    cache groups and thus having different block table.
-    """
-
-    query_start_loc: torch.Tensor
-    """(batch_size + 1,), the start location of each request in query Tensor"""
-    seq_lens: torch.Tensor
-    """(batch_size,), the length of each request including both computed tokens
-    and newly scheduled tokens"""
-
-
 class AscendMLABackend(AttentionBackend):
 
     accept_output_buffer: bool = True
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
new file mode 100644
index 0000000000..c2b7bc156a
--- /dev/null
+++ b/vllm_ascend/attention/utils.py
@@ -0,0 +1,23 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+
+@dataclass
+class AscendCommonAttentionMetadata:
+    """
+    Attention metadata attributes that can be shared by layers in different KV
+    cache groups and thus having different block table.
+    """
+
+    query_start_loc: torch.Tensor = None
+    """(batch_size + 1,), the start location of each request in query Tensor"""
+    seq_lens: Optional[torch.Tensor] = None
+    """(batch_size,), the length of each request including both computed tokens
+    and newly scheduled tokens"""
+    query_lens: Optional[torch.Tensor] = None
+    """(batch_size,), the length of each request including only the newly
+    scheduled tokens"""
+    seq_lens_list: Optional[list] = None
+    """(num_input_tokens,), note that this is specifically for FIA kernel"""
diff --git a/vllm_ascend/compilation/piecewise_backend.py b/vllm_ascend/compilation/piecewise_backend.py
index c6a800b3d8..aafe639373 100644
--- a/vllm_ascend/compilation/piecewise_backend.py
+++ b/vllm_ascend/compilation/piecewise_backend.py
@@ -28,9 +28,13 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import end_monitoring_torch_compile
 from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.utils import weak_ref_tensors
 
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.utils import get_graph_params, set_graph_params
+
 
 @dataclasses.dataclass
 class ConcreteSizeEntry:
@@ -95,6 +99,10 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
 
         self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
 
+        if self.compilation_config.full_cuda_graph:
+            self.update_stream = torch.npu.Stream()
+            set_graph_params(self.aclgraph_capture_sizes)
+
         # the entries for different shapes that we need to either
         # compile or capture aclgraph
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
@@ -116,7 +124,40 @@ def check_for_ending_compilation(self):
                 self.vllm_backend.compiler_manager.save_to_file()
             end_monitoring_torch_compile(self.vllm_config)
 
+    def update_attn_params(self, graph_params, forward_context, runtime_shape):
+        for layer_idx in range(len(graph_params.handles[runtime_shape])):
+            query, key, value, actual_seq_lens, block_table, num_heads, scale, num_kv_heads, output, softmax_lse = graph_params.attn_params[
+                runtime_shape][layer_idx]
+            block_table = forward_context.attn_metadata.block_tables
+            actual_seq_lens = forward_context.attn_metadata.seq_lens_list
+
+            with torch.npu.stream(self.update_stream):
+                torch.npu.graph_task_update_begin(
+                    self.update_stream,
+                    graph_params.handles[runtime_shape][layer_idx])
+                torch.ops.npu.npu_fused_infer_attention_score.out(
+                    query,
+                    key,
+                    value,
+                    workspace=graph_params.workspaces[runtime_shape],
+                    actual_seq_lengths_kv=actual_seq_lens,
+                    block_table=block_table,
+                    num_heads=num_heads,
+                    scale=scale,
+                    input_layout="BSH",
+                    num_key_value_heads=num_kv_heads,
+                    block_size=128,
+                    out=[output, softmax_lse],
+                )
+                torch.npu.graph_task_update_end(self.update_stream)
+
+            graph_params.events[runtime_shape][layer_idx].record(
+                self.update_stream)
+
     def __call__(self, *args) -> Any:
+        forward_context = get_forward_context()
+        graph_params = get_graph_params()
+
         if not self.first_run_finished:
             self.first_run_finished = True
             self.check_for_ending_compilation()
@@ -127,6 +168,11 @@ def __call__(self, *args) -> Any:
             # we don't need to do anything for this shape
             return self.compiled_graph_for_general_shape(*args)
 
+        if (getattr(forward_context.attn_metadata, "attn_state",
+                    None) != AscendAttentionState.DecodeOnly
+                and self.compilation_config.full_cuda_graph):
+            return self.compiled_graph_for_general_shape(*args)
+
         entry = self.concrete_size_entries[runtime_shape]
 
         if entry.runnable is None:
@@ -189,6 +235,7 @@ def __call__(self, *args) -> Any:
                     patch("torch.npu.empty_cache", lambda: None))
 
             # mind-exploding: carefully manage the reference and memory.
+            forward_context.capturing = True
             with torch.npu.graph(aclgraph, pool=self.graph_pool):
                 # `output` is managed by pytorch's aclgraph pool
                 output = entry.runnable(*args)
@@ -222,4 +269,9 @@ def __call__(self, *args) -> Any:
             )
 
         entry.aclgraph.replay()
+
+        if self.compilation_config.full_cuda_graph:
+            self.update_attn_params(graph_params, forward_context,
+                                    runtime_shape)
+
         return entry.output
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index b4ea85692b..a7e2bc0eea 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -162,8 +162,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                 "using only ACL Graph mode")
             compilation_config.use_inductor = False
-            compilation_config.splitting_ops.extend(
-                ["vllm.unified_ascend_attention_with_output"])
+            if not compilation_config.full_cuda_graph:
+                compilation_config.splitting_ops.extend(
+                    ["vllm.unified_ascend_attention_with_output"])
             update_aclgraph_sizes(vllm_config)
 
         if parallel_config and parallel_config.worker_cls == "auto":
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 71f584eb86..f7ca0aba2e 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -20,9 +20,10 @@
 import atexit
 import math
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
 from enum import Enum
 from threading import Lock
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 import torch_npu
@@ -171,6 +172,27 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     original_sizes, compilation_config.cudagraph_capture_sizes = \
         compilation_config.cudagraph_capture_sizes, None
 
+    if compilation_config.full_cuda_graph:
+        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        truncated_sizes = [x for x in original_sizes if x <= max_num_seqs]
+        compilation_config.init_with_cudagraph_sizes(truncated_sizes)
+
+        warning_message = """\033[91m
+        **********************************************************************************
+        * WARNING: You have enabled the *full graph* feature.
+        * This is an early experimental stage and may involve various unknown issues.
+        * A known problem is that capturing too many batch sizes can lead to OOM
+        * (Out of Memory) errors or inference hangs. If you encounter such issues,
+        * consider reducing `gpu_memory_utilization` or manually specifying a smaller
+        * batch size for graph capture.
+        * For more details, please refer to:
+        * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
+        **********************************************************************************\033[0m
+        """
+
+        logger.warning(warning_message)
+        return
+
     # Calculate parallel configuration factor
     num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
     parallel_config = vllm_config.parallel_config
@@ -305,3 +327,34 @@ def get_ascend_soc_version():
     global _ascend_soc_version
     assert _ascend_soc_version is not None
     return _ascend_soc_version
+
+
+@dataclass
+class GraphParams:
+    events: dict[int, list[torch.npu.ExternalEvent]]
+    workspaces: dict[int, torch.Tensor]
+    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
+    attn_params: dict[int, list[tuple]]
+
+
+_graph_params: Optional[GraphParams] = None
+
+
+def set_graph_params(aclgraph_capture_sizes: set[int]):
+    global _graph_params
+    if _graph_params is not None:
+        raise ValueError("Graph parameters have already been set!")
+    _graph_params = GraphParams(
+        {size: []
+         for size in aclgraph_capture_sizes},
+        {size: None
+         for size in aclgraph_capture_sizes},
+        {size: []
+         for size in aclgraph_capture_sizes},
+        {size: []
+         for size in aclgraph_capture_sizes},
+    )
+
+
+def get_graph_params():
+    return _graph_params
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index e0ab79be45..4b2feefa90 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -79,7 +79,8 @@
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -258,6 +259,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.seq_lens = torch.zeros(self.max_num_reqs,
                                     dtype=torch.int32,
                                     device=self.device)
+        self.slot_mapping = torch.zeros(self.max_num_tokens,
+                                        dtype=torch.int32,
+                                        device=self.device)
+        self.query_lens = torch.zeros(self.max_num_reqs,
+                                      dtype=torch.int32,
+                                      device=self.device)
         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
 
@@ -956,15 +963,21 @@ def _process_reqs(
             self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
         self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                        non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
 
         # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
         self.seq_lens[num_reqs:].fill_(0)
         self.query_start_loc[num_reqs + 1:].fill_(-1)
 
         query_start_loc = self.query_start_loc[:num_reqs + 1]
-        seq_lens = self.seq_lens[:num_reqs]
+        # Use host tensor, otherwise we get an error: tensor.hostData is null
         common_attn_metadata = CommonAttentionMetadata(
-            query_start_loc=query_start_loc, seq_lens=seq_lens)
+            query_start_loc=query_start_loc,
+            seq_lens=self.seq_lens_cpu[:num_reqs])
+        self.seq_lens_list = self.seq_lens_np.tolist()[:num_input_tokens]
         with_prefill = attn_state not in [
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
@@ -1011,6 +1024,7 @@ def _process_reqs(
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
+            common_attn_metadata=common_attn_metadata,
             common_prefix_len=None,
             **extra_builder_kwargs,
         )
@@ -1559,6 +1573,7 @@ def _dummy_run(
         skip_attn: bool = True,
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
     ) -> torch.Tensor:
         if self.torchair_graph_enabled and not with_prefill:
             num_tokens = self.select_torchair_padded_batch_size(num_tokens)
@@ -1598,8 +1613,12 @@ def _dummy_run(
         elif skip_attn:
             attn_metadata = None
         else:
-            # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
-            attn_metadata = None
+            attn_metadata = self.attn_metadata_builder.build_dummy_metadata(
+                num_actual_tokens=num_tokens,
+                num_reqs=num_reqs,
+                num_scheduled_tokens=num_scheduled_tokens,
+                attn_state=attn_state,
+            )
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
@@ -1982,11 +2001,13 @@ def capture_model(self) -> None:
             # can reuse the memory pool allocated for the large shapes.
             # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
             with graph_capture(device=self.device):
+                skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
+                # TODO: Make sure attn_state is passed to _dummy_run in the future
                 for num_tokens in reversed(self.aclgraph_batch_sizes):
                     for _ in range(self.vllm_config.compilation_config.
                                    cudagraph_num_of_warmups):
-                        self._dummy_run(num_tokens)
-                    self._dummy_run(num_tokens)
+                        self._dummy_run(num_tokens, skip_attn=skip_attn)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn)
         else:
             logger.info("Skipping NPU graph capture for eager mode.")
             return
diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py
index 1f0f75fda7..04a7d617b5 100644
--- a/vllm_ascend/worker/mtp_proposer_v1.py
+++ b/vllm_ascend/worker/mtp_proposer_v1.py
@@ -10,7 +10,8 @@
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
 from vllm_ascend.utils import ProfileExecuteDuration
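
Usage sketch (not part of the patch): this mirrors how the updated test drives the new path through the LLM entry point. The model name is a placeholder, and "cudagraph_capture_sizes" is optional; with full_cuda_graph set, update_aclgraph_sizes drops any capture size above max_num_seqs.

    from vllm import LLM, SamplingParams

    # Enable full-graph capture on NPU. Without "full_cuda_graph" the plugin
    # keeps splitting attention out via
    # vllm.unified_ascend_attention_with_output and captures piecewise graphs.
    llm = LLM(
        "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
        compilation_config={
            "full_cuda_graph": True,
            # Optional; sizes above max_num_seqs are truncated automatically.
            "cudagraph_capture_sizes": [1, 4, 16, 64, 256],
        })

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=32))

For reference, the capture/replay flow introduced above works like this: during capture, each DecodeOnly attention layer stores its workspace, task-group handle, output buffers, and an ExternalEvent in GraphParams keyed by runtime shape; on replay, update_attn_params re-issues npu_fused_infer_attention_score.out on a dedicated update stream with the current block_tables and seq_lens_list and records the event, so the captured graph runs with refreshed attention parameters.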