@@ -122,6 +122,7 @@ class AscendMetadata:
122122
123123 # **************************** Basic Properties ****************************
124124 attn_mask : Optional [torch .Tensor ] = None
125+
125126 # Current state of this attention run.
126127 attn_state : AscendAttentionState = AscendAttentionState .ChunkedPrefill
127128
@@ -134,7 +135,9 @@ class AscendMetadata:
134135 seq_lens : torch .Tensor = None
135136
136137 query_start_loc : torch .Tensor = None
138+
137139 query_lens : torch .Tensor = None
140+
138141 # Maximum query length in the batch (None for decoding).
139142 max_query_len : Optional [int ] = None
140143
@@ -339,6 +342,7 @@ def _forward_prefill_no_cache(
339342 ) -> torch .Tensor :
340343 assert attn_metadata is not None
341344 assert attn_metadata .attn_mask is not None
345+
342346 mask = attn_metadata .attn_mask
343347
344348 if is_310p ():
@@ -520,16 +524,17 @@ def forward(
520524
521525 # V0-Style scheduler situation.
522526 if attn_metadata .attn_state == AscendAttentionState .PrefillNoCache :
523- output = self ._forward_prefill_no_cache (attn_metadata , query , key ,
524- value , output , num_tokens )
527+ output = self ._forward_prefill_no_cache (query , key , value ,
528+ attn_metadata , output ,
529+ num_tokens )
525530 elif attn_metadata .attn_state == AscendAttentionState .PrefillCacheHit :
526- output = self ._forward_prefill_cache_hit (attn_metadata , query ,
531+ output = self ._forward_prefill_cache_hit (query , attn_metadata ,
527532 output )
528533 elif attn_metadata .attn_state == AscendAttentionState .DecodeOnly :
529- output = self ._forward_decode_only (attn_metadata , query , output )
534+ output = self ._forward_decode_only (query , attn_metadata , output )
530535 # Normal V1 situation.
531536 else :
532- output = self ._forward_v1_style (attn_metadata , query , output )
537+ output = self ._forward_v1_style (query , attn_metadata , output )
533538
534539 # Copy the first num_tokens rows of the result into ori_output in place.
535540 ori_output [:, :, :] = output [:num_tokens , :, :]
0 commit comments