@@ -122,6 +122,7 @@ class AscendMetadata:
122122
123123 # **************************** Basic Properties ****************************
124124 attn_mask : Optional [torch .Tensor ] = None
125+
125126 # Current state of this attention run.
126127 attn_state : AscendAttentionState = AscendAttentionState .ChunkedPrefill
127128
@@ -134,7 +135,9 @@ class AscendMetadata:
134135 seq_lens : torch .Tensor = None
135136
136137 query_start_loc : torch .Tensor = None
138+
137139 query_lens : torch .Tensor = None
140+
138141 # Maximum query length in the batch (None for decoding).
139142 max_query_len : Optional [int ] = None
140143
@@ -325,6 +328,7 @@ def _forward_prefill_no_cache(
325328 ) -> torch .Tensor :
326329 assert attn_metadata is not None
327330 assert attn_metadata .attn_mask is not None
331+
328332 mask = attn_metadata .attn_mask
329333
330334 if is_310p ():
@@ -506,16 +510,17 @@ def forward(
506510
507511 # V0-Style scheduler situation.
508512 if attn_metadata .attn_state == AscendAttentionState .PrefillNoCache :
509- output = self ._forward_prefill_no_cache (attn_metadata , query , key ,
510- value , output , num_tokens )
513+ output = self ._forward_prefill_no_cache (query , key , value ,
514+ attn_metadata , output ,
515+ num_tokens )
511516 elif attn_metadata .attn_state == AscendAttentionState .PrefillCacheHit :
512- output = self ._forward_prefill_cache_hit (attn_metadata , query ,
517+ output = self ._forward_prefill_cache_hit (query , attn_metadata ,
513518 output )
514519 elif attn_metadata .attn_state == AscendAttentionState .DecodeOnly :
515- output = self ._forward_decode_only (attn_metadata , query , output )
520+ output = self ._forward_decode_only (query , attn_metadata , output )
516521 # Normal V1 situation.
517522 else :
518- output = self ._forward_v1_style (attn_metadata , query , output )
523+ output = self ._forward_v1_style (query , attn_metadata , output )
519524
520525 # to make in-place change to the output tensor
521526 ori_output [:, :, :] = output [:num_tokens , :, :]
0 commit comments