Commit 0c4f231

qyqc731 authored and angazenn committed
supply ut for new branch in w8a8
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
1 parent 4239f10 · commit 0c4f231

File tree

1 file changed: +52 −0 lines changed

tests/ut/quantization/test_w8a8.py

Lines changed: 52 additions & 0 deletions
@@ -316,6 +316,58 @@ def test_apply_decode_only(self, mock_quant, mock_scatter):
         self.assertEqual(mock_scatter.call_count, 2)
         self.assertTrue(torch.equal(result, expected_output))
 
+    @patch('torch_npu.npu_scatter_nd_update_')
+    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
+    def test_apply_attn_metadata_without_decode(self, mock_quant,
+                                                mock_scatter):
+
+        num_tokens = 2
+        query = torch.randn(num_tokens,
+                            self.layer.num_heads * self.layer.head_size)
+        key = torch.randn(num_tokens,
+                          self.layer.num_kv_heads * self.layer.head_size)
+        value = torch.randn(num_tokens,
+                            self.layer.num_kv_heads * self.layer.head_size)
+        output = torch.empty_like(query)
+
+        attn_metadata = MagicMock(spec=[
+            'attn_state', 'seq_lens', 'block_tables', 'slot_mapping',
+            'attn_mask'
+        ])
+        attn_metadata.attn_state = AscendAttentionState.DecodeOnly
+        attn_metadata.seq_lens = [10, 10]
+        attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
+        attn_metadata.slot_mapping = torch.tensor([0, 1])
+        attn_metadata.attn_mask = None
+
+        block_size = 16
+        key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
+                                self.layer.head_size)
+        value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
+                                  self.layer.head_size)
+        kv_cache = (key_cache, value_cache)
+
+        mock_quant.side_effect = [key, value]
+
+        self.layer.key_antiquant_scale.data = torch.ones(
+            self.layer.num_kv_heads * self.layer.head_size)
+        self.layer.value_antiquant_scale.data = torch.ones(
+            self.layer.num_kv_heads * self.layer.head_size)
+        self.method.process_weights_after_loading(self.layer)
+
+        expected_output = torch.randn(
+            num_tokens, self.layer.num_heads * self.layer.head_size)
+        with patch('torch_npu.npu_incre_flash_attention',
+                   return_value=expected_output):
+            result = self.method.apply(self.layer, query, key, value, kv_cache,
+                                       attn_metadata,
+                                       self.attention_type.DECODER, 1.0,
+                                       output)
+
+        self.assertEqual(mock_quant.call_count, 2)
+        self.assertEqual(mock_scatter.call_count, 2)
+        self.assertTrue(torch.equal(result, expected_output))
+
     @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
     @patch('torch_npu._npu_flash_attention')
     def test_apply_prefill_no_cache(self, mock_flash, mock_quant):
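
Editor's note (not part of the commit): the new test builds attn_metadata via MagicMock(spec=[...]), and the spec list deliberately omits any decode attribute. When a spec is given, accessing an attribute outside the list raises AttributeError, which appears to be how the code under test recognizes metadata "without decode", i.e. the new branch the commit title refers to. A minimal, self-contained sketch of that spec behavior:

    # Sketch (editor's addition): MagicMock with `spec` restricts attributes.
    from unittest.mock import MagicMock

    attn_metadata = MagicMock(spec=[
        'attn_state', 'seq_lens', 'block_tables', 'slot_mapping', 'attn_mask'
    ])

    print(hasattr(attn_metadata, 'slot_mapping'))  # True: listed in the spec
    print(hasattr(attn_metadata, 'decode'))        # False: not in the spec;
                                                   # direct access raises AttributeError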

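A second editor's sketch, also not from the commit: the test's parameter order (mock_quant, mock_scatter) relies on unittest.mock applying stacked @patch decorators bottom-up, so the bottom decorator supplies the first mock argument. Demonstrated here with stand-in standard-library targets:

    # Sketch (editor's addition): stacked @patch decorators inject mocks bottom-up.
    import os
    from unittest.mock import patch

    @patch('os.getcwd')          # top decorator    -> second mock argument
    @patch('os.path.exists')     # bottom decorator -> first mock argument
    def check(mock_exists, mock_getcwd):
        mock_exists.return_value = True
        mock_getcwd.return_value = '/stubbed'
        assert os.path.exists('anything') is True
        assert os.getcwd() == '/stubbed'

    check()  # the decorators inject both mocks at call time

Assuming the repository's unit tests run under pytest (an assumption; the commit does not name a runner), the new case can be exercised in isolation with: pytest tests/ut/quantization/test_w8a8.py -k test_apply_attn_metadata_without_decode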