@@ -316,6 +316,64 @@ def test_apply_decode_only(self, mock_quant, mock_scatter):
         self.assertEqual(mock_scatter.call_count, 2)
         self.assertTrue(torch.equal(result, expected_output))
 
+    @patch('torch_npu.npu_scatter_nd_update_')
+    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
+    def test_apply_attn_metadata_without_decode(self, mock_quant,
+                                                mock_scatter):
+
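+        # Query, key and value cover two decode tokens.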
+        num_tokens = 2
+        query = torch.randn(num_tokens,
+                            self.layer.num_heads * self.layer.head_size)
+        key = torch.randn(num_tokens,
+                          self.layer.num_kv_heads * self.layer.head_size)
+        value = torch.randn(num_tokens,
+                            self.layer.num_kv_heads * self.layer.head_size)
+        output = torch.empty_like(query)
+
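+        # The mock spec limits attn_metadata to these attributes only.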
+        attn_metadata = MagicMock(spec=[
+            'attn_state', 'seq_lens', 'block_tables', 'slot_mapping',
+            'attn_mask'
+        ])
+        attn_metadata.attn_state = AscendAttentionState.DecodeOnly
+        attn_metadata.seq_lens = [10, 10]
+        attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
+        attn_metadata.slot_mapping = torch.tensor([0, 1])
+        attn_metadata.attn_mask = None
+
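+        # Paged KV cache shaped (num_blocks, block_size, num_kv_heads, head_size).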
+        block_size = 16
+        key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
+                                self.layer.head_size)
+        value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
+                                  self.layer.head_size)
+        kv_cache = (key_cache, value_cache)
+
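+        # quant_per_tensor runs once for the key and once for the value.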
+        mock_quant.side_effect = [key, value]
+
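+        # Set unit antiquant scales before post-load weight processing.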
+        self.layer.key_antiquant_scale.data = torch.ones(
+            self.layer.num_kv_heads * self.layer.head_size)
+        self.layer.value_antiquant_scale.data = torch.ones(
+            self.layer.num_kv_heads * self.layer.head_size)
+        self.method.process_weights_after_loading(self.layer)
+
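+        # Stub the decode attention kernel so apply() returns a known tensor.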
+        expected_output = torch.randn(
+            num_tokens, self.layer.num_heads * self.layer.head_size)
+        with patch('torch_npu.npu_incre_flash_attention',
+                   return_value=expected_output):
+            result = self.method.apply(self.layer, query, key, value, kv_cache,
+                                       attn_metadata,
+                                       self.attention_type.DECODER, 1.0,
+                                       output)
+
+        self.assertEqual(mock_quant.call_count, 2)
+        self.assertEqual(mock_scatter.call_count, 2)
+        self.assertTrue(torch.equal(result, expected_output))
+
     @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
     @patch('torch_npu._npu_flash_attention')
     def test_apply_prefill_no_cache(self, mock_flash, mock_quant):