
Commit 8c24a6f

[fix]: resolve format issues

Signed-off-by: zhuohuan <zxdu1997@gmail.com>

Parent: eef22e1

File tree: 9 files changed (+468, -360 lines)


vllm_ascend/attention/mla_v1.py

Lines changed: 26 additions & 25 deletions
@@ -4,12 +4,6 @@
 import numpy as np
 import torch
 import torch_npu
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
-from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
-
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
@@ -20,6 +14,12 @@
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
+from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
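The two hunks above move the first-party vllm_ascend imports below the third-party vllm block, the grouping that import sorters such as isort enforce. A minimal way to check the result locally, assuming isort is installed (the project's actual lint entry point may differ):

# Hypothetical local check; isort.check_file returns True when the file's
# import order already matches the configured grouping.
import isort

if not isort.check_file("vllm_ascend/attention/mla_v1.py", show_diff=True):
    print("imports need regrouping")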
@@ -123,8 +123,8 @@ def __post_init__(self):
         # f"received {self.head_dim}.")
 
     def split_metadata_for_multistream(
-        self,
-        ms_split_config: MSAttentionMetadataSplitConfig,
+        self,
+        ms_split_config: MSAttentionMetadataSplitConfig,
     ) -> list["AscendMLAMetadata"]:
         """Split metadata for multi-stream with AscendMLAMetadata"""
         return model_input_split_v1_mla_attn(
@@ -133,6 +133,7 @@ def split_metadata_for_multistream(
             _metadata_cls=AscendMLAMetadata,
         )
 
+
 M = TypeVar("M", bound=AscendMLAMetadata)

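The signature hunk above is a whitespace-only re-indent; the rendered diff hides the exact spacing, which is why the removed and added lines look identical (the same applies to several -/+ pairs in the hunks below). For context, split_metadata_for_multistream is the entry point that carves one AscendMLAMetadata into per-micro-batch metadata for dual-batch overlap. A usage sketch, with the MSAttentionMetadataSplitConfig field name assumed rather than taken from this diff:

# Sketch only: num_micro_batches is an assumed field name on
# MSAttentionMetadataSplitConfig, not confirmed by this diff.
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig

attn_metadata: "AscendMLAMetadata" = ...  # built by the metadata builder
split_config = MSAttentionMetadataSplitConfig(num_micro_batches=2)
# One AscendMLAMetadata per micro-batch, via model_input_split_v1_mla_attn.
ms_metadata = attn_metadata.split_metadata_for_multistream(split_config)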
@@ -574,14 +575,14 @@ def _forward_prefill(
         )
         attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
-
+
         # A better way is to modify the communication ops or RowParallel Layer in vllm;
         from vllm_ascend.multistream.context import \
             get_multistream_comm_context
-        current_ms_metadata = get_multistream_comm_context()
+        current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
             return self.o_proj(attn_output)[0]
-        else:
+        else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 current_ms_metadata.before_comm_event.wait()
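Taken together, the prefill hunk implements a two-stream overlap: when a multistream context is active, the output projection (and its tensor-parallel all-reduce) runs on a separate communication stream, fenced by an event so it cannot start before the attention output is ready. A commented restatement of that control flow, using the names from the diff:

# Commented restatement of the method body above; `self` and `attn_output`
# belong to the surrounding MLA implementation class.
from vllm_ascend.multistream.context import get_multistream_comm_context

current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
    # No dual-batch overlap active: project on the default stream as usual.
    out = self.o_proj(attn_output)[0]
else:
    # Mark where attention compute finished on the default stream.
    current_ms_metadata.before_comm_event.record()
    with torch.npu.stream(current_ms_metadata.comm_stream):
        # The comm stream waits on that event, so o_proj (and its TP
        # all-reduce) is ordered after the attention kernels while the
        # default stream moves on to the other micro-batch.
        current_ms_metadata.before_comm_event.wait()
        out = self.o_proj(attn_output)[0]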
@@ -687,16 +688,15 @@ def _forward_decode(
             out=attn_output)
         from vllm_ascend.multistream.context import \
             get_multistream_comm_context
-        current_ms_metadata = get_multistream_comm_context()
+        current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
             return self._v_up_proj_and_o_proj(attn_output)
-        else:
+        else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 current_ms_metadata.before_comm_event.wait()
                 return self._v_up_proj_and_o_proj(attn_output)
 
-
     def forward(
         self,
         layer: AttentionLayer,
@@ -820,14 +820,15 @@ def forward(
             key_cache=kv_cache,
             slot_indices=attn_metadata.slot_mapping.flatten())
         if has_prefill:
-            # FIX: aicore move should be also placed on the comm stream in dbo,
-            # otherwise it may affect the accuracy
+            # FIX: aicore move should be also placed on the comm stream in dbo,
+            # otherwise it may affect the accuracy
             # TODO: use an elegant way to overlap
             from vllm_ascend.multistream.context import \
                 get_multistream_comm_context
-            output_prefill = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+            output_prefill = self._forward_prefill(prefill_q,
+                                                   prefill_k_c_normed,
+                                                   prefill_k_pe, kv_cache,
+                                                   attn_metadata)
             current_ms_metadata = get_multistream_comm_context()
             if current_ms_metadata is not None:
                 with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -836,7 +837,6 @@ def forward(
             else:
                 output[num_decode_tokens:] = output_prefill
 
-
         if has_decode:
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
@@ -845,16 +845,17 @@ def forward(
             else:
                 from vllm_ascend.multistream.context import \
                     get_multistream_comm_context
-                output_decode = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    kv_cache, attn_metadata)
-                current_ms_metadata = get_multistream_comm_context()
+                output_decode = self._forward_decode(decode_ql_nope,
+                                                     decode_q_pe,
+                                                     decode_k_nope,
+                                                     decode_k_pe, kv_cache,
+                                                     attn_metadata)
+                current_ms_metadata = get_multistream_comm_context()
                 if current_ms_metadata is not None:
                     with torch.npu.stream(current_ms_metadata.comm_stream):
                         output[:num_decode_tokens] = output_decode
                         current_ms_metadata.after_comm_event.record()
                 else:
                     output[:num_decode_tokens] = output_decode
 
-
         return output_padded
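The FIX comment in the hunks above is about stream ordering: output_decode is produced on the comm stream, so the copy into the shared output buffer (decode tokens at the front, prefill tokens after num_decode_tokens) must also be enqueued on that stream and fenced with after_comm_event; a copy issued on the default stream could race the projection. A sketch of that ordering requirement, with buffer allocation elided:

# Sketch of the ordering rule the FIX comment describes.
with torch.npu.stream(current_ms_metadata.comm_stream):
    # Enqueue the aicore copy on the same stream that produced output_decode...
    output[:num_decode_tokens] = output_decode
    # ...then record the fence that downstream consumers wait on.
    current_ms_metadata.after_comm_event.record()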
