Commit 01e35d3

update
Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 0bccf05 commit 01e35d3

1 file changed: +10 -5 lines changed


vllm_ascend/attention/attention_v1.py

Lines changed: 10 additions & 5 deletions
@@ -122,6 +122,7 @@ class AscendMetadata:
 
     # **************************** Basic Properties ****************************
     attn_mask: Optional[torch.Tensor] = None
+
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
 
@@ -134,7 +135,9 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
 
     query_start_loc: torch.Tensor = None
+
     query_lens: torch.Tensor = None
+
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
 
@@ -339,6 +342,7 @@ def _forward_prefill_no_cache(
     ) -> torch.Tensor:
         assert attn_metadata is not None
         assert attn_metadata.attn_mask is not None
+
         mask = attn_metadata.attn_mask
 
         if is_310p():
@@ -520,16 +524,17 @@ def forward(
 
         # V0-Style scheduler situation.
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            output = self._forward_prefill_no_cache(attn_metadata, query, key,
-                                                    value, output, num_tokens)
+            output = self._forward_prefill_no_cache(query, key, value,
+                                                    attn_metadata, output,
+                                                    num_tokens)
         elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
-            output = self._forward_prefill_cache_hit(attn_metadata, query,
+            output = self._forward_prefill_cache_hit(query, attn_metadata,
                                                      output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            output = self._forward_decode_only(attn_metadata, query, output)
+            output = self._forward_decode_only(query, attn_metadata, output)
         # Normal V1 situation.
         else:
-            output = self._forward_v1_style(attn_metadata, query, output)
+            output = self._forward_v1_style(query, attn_metadata, output)
 
         # to make in-place change to the output tensor
         ori_output[:, :, :] = output[:num_tokens, :, :]
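
For context, the final hunk only reorders the arguments of the per-state forward helpers so that the tensor inputs (query, key, value) come first and attn_metadata follows. Below is a minimal, runnable sketch of that dispatch pattern; the enum values, helper bodies, and class name are simplified stand-ins inferred from this diff, not the actual vllm-ascend implementation.

# Minimal sketch of the query-first dispatch pattern from the hunk above.
# AscendAttentionState values and the _forward_* helpers are simplified
# stand-ins, not the real vllm-ascend implementations.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

import torch


class AscendAttentionState(Enum):
    PrefillNoCache = auto()
    PrefillCacheHit = auto()
    DecodeOnly = auto()
    ChunkedPrefill = auto()


@dataclass
class AscendMetadata:
    attn_mask: Optional[torch.Tensor] = None
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill


class AttentionImplSketch:
    # Each helper takes the tensor arguments first and attn_metadata after
    # them, matching the argument order introduced by this commit.
    def _forward_prefill_no_cache(self, query, key, value, attn_metadata,
                                  output, num_tokens):
        return output  # placeholder body

    def _forward_prefill_cache_hit(self, query, attn_metadata, output):
        return output  # placeholder body

    def _forward_decode_only(self, query, attn_metadata, output):
        return output  # placeholder body

    def _forward_v1_style(self, query, attn_metadata, output):
        return output  # placeholder body

    def forward(self, query, key, value, attn_metadata, output, num_tokens):
        # Dispatch on the scheduler state carried in the metadata, as the
        # V0-style branch in the diff does.
        state = attn_metadata.attn_state
        if state == AscendAttentionState.PrefillNoCache:
            output = self._forward_prefill_no_cache(query, key, value,
                                                    attn_metadata, output,
                                                    num_tokens)
        elif state == AscendAttentionState.PrefillCacheHit:
            output = self._forward_prefill_cache_hit(query, attn_metadata,
                                                     output)
        elif state == AscendAttentionState.DecodeOnly:
            output = self._forward_decode_only(query, attn_metadata, output)
        else:  # normal V1 (chunked prefill) path
            output = self._forward_v1_style(query, attn_metadata, output)
        return output

Putting query/key/value ahead of attn_metadata makes the four helper signatures consistent with one another, which appears to be the point of the reordering.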
