
Commit 84eebc5

update

Signed-off-by: shen-shanshan <467638484@qq.com>

Parent: d0c1e2a

File tree

1 file changed (+10, -5 lines)


vllm_ascend/attention/attention_v1.py

Lines changed: 10 additions & 5 deletions
@@ -122,6 +122,7 @@ class AscendMetadata:
 
     # **************************** Basic Properties ****************************
     attn_mask: Optional[torch.Tensor] = None
+
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
 
@@ -134,7 +135,9 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
 
     query_start_loc: torch.Tensor = None
+
     query_lens: torch.Tensor = None
+
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
 
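The two hunks above only add blank lines that group the AscendMetadata fields. For reference, here is a partial sketch of how those fields read after this commit, reconstructed from the visible diff context; the @dataclass decorator, the stand-in AscendAttentionState enum (listing only the members this diff references), and anything between the two hunks are assumptions rather than part of the change:

# Partial sketch of AscendMetadata after this commit, reconstructed from the
# diff context above. Only fields visible in the two hunks are shown; the
# @dataclass decorator and the stand-in enum below are assumptions.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

import torch


class AscendAttentionState(Enum):
    # Stand-in for the enum defined elsewhere in attention_v1.py; only the
    # members referenced in this diff are listed.
    PrefillNoCache = auto()
    PrefillCacheHit = auto()
    DecodeOnly = auto()
    ChunkedPrefill = auto()


@dataclass
class AscendMetadata:
    # **************************** Basic Properties ****************************
    attn_mask: Optional[torch.Tensor] = None

    # Current state of this attention run.
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    # ... fields between the two hunks are not visible in this diff ...

    seq_lens: torch.Tensor = None

    query_start_loc: torch.Tensor = None

    query_lens: torch.Tensor = None

    # Maximum query length in the batch (None for decoding).
    max_query_len: Optional[int] = None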
@@ -325,6 +328,7 @@ def _forward_prefill_no_cache(
     ) -> torch.Tensor:
         assert attn_metadata is not None
         assert attn_metadata.attn_mask is not None
+
         mask = attn_metadata.attn_mask
 
         if is_310p():
@@ -506,16 +510,17 @@ def forward(
 
         # V0-Style scheduler situation.
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            output = self._forward_prefill_no_cache(attn_metadata, query, key,
-                                                    value, output, num_tokens)
+            output = self._forward_prefill_no_cache(query, key, value,
+                                                    attn_metadata, output,
+                                                    num_tokens)
         elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
-            output = self._forward_prefill_cache_hit(attn_metadata, query,
+            output = self._forward_prefill_cache_hit(query, attn_metadata,
                                                      output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            output = self._forward_decode_only(attn_metadata, query, output)
+            output = self._forward_decode_only(query, attn_metadata, output)
         # Normal V1 situation.
         else:
-            output = self._forward_v1_style(attn_metadata, query, output)
+            output = self._forward_v1_style(query, attn_metadata, output)
 
         # to make in-place change to the output tensor
         ori_output[:, :, :] = output[:num_tokens, :, :]
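
The forward() hunk reorders the call-site arguments so that query comes first and attn_metadata follows the tensor inputs, which makes all four branches call their helper in the same way. The matching parameter reorder in the helper definitions is outside the hunks shown here; a minimal sketch of the signatures these call sites assume (the enclosing class name and the type annotations are guesses, not shown in this diff):

import torch


class AscendAttentionBackendImpl:  # enclosing class name is an assumption
    # Parameter order inferred from the reordered call sites in forward();
    # the real definitions are not part of the hunks shown above.
    def _forward_prefill_no_cache(self, query: torch.Tensor,
                                  key: torch.Tensor, value: torch.Tensor,
                                  attn_metadata, output: torch.Tensor,
                                  num_tokens: int) -> torch.Tensor:
        ...

    def _forward_prefill_cache_hit(self, query: torch.Tensor, attn_metadata,
                                   output: torch.Tensor) -> torch.Tensor:
        ...

    def _forward_decode_only(self, query: torch.Tensor, attn_metadata,
                             output: torch.Tensor) -> torch.Tensor:
        ...

    def _forward_v1_style(self, query: torch.Tensor, attn_metadata,
                          output: torch.Tensor) -> torch.Tensor:
        ...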
