from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a16 import AscendW8A16LinearMethod


class TestAscendW8A16LinearMethod(TestBase):

    def setUp(self):
        self.method = AscendW8A16LinearMethod()

    def test_get_weight(self):
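        """get_weight should create an int8 weight of shape (output_size, input_size)."""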
        weight = self.method.get_weight(10, 20)
        self.assertEqual(weight['weight'].dtype, torch.int8)
        self.assertEqual(weight['weight'].shape, (20, 10))

    @patch("torch_npu.npu_weight_quant_batchmatmul")
    def test_apply(self, mock_npu_weight_quant_batchmatmul):
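        """apply should return the mocked batch-matmul result with the bias added."""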
        layer = MagicMock()
        layer.weight.data = torch.randn(128, 256)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)

        x = torch.randn(32, 128)
        bias = torch.randn(256)

        expected_y_output = torch.randn(32, 256)
        mock_npu_weight_quant_batchmatmul.return_value = expected_y_output

        output = self.method.apply(layer, x, bias)
        expected_y_output += bias
        self.assertTrue(torch.equal(output, expected_y_output))

    @patch("vllm_ascend.quantization.w8a16.is_310p", return_value=True)
    @patch("torch_npu.npu_weight_quant_batchmatmul")
    def test_apply_on_310p(self, mock_npu_weight_quant_batchmatmul,
                           mock_is_310p):
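        """On 310P hardware, apply should still return the mocked result plus bias."""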
        layer = MagicMock()
        layer.weight.data = torch.randn(128, 256)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)

        x = torch.randn(32, 128)
        bias = torch.randn(256)

        expected_y_output = torch.randn(32, 256)
        mock_npu_weight_quant_batchmatmul.return_value = expected_y_output

        output = self.method.apply(layer, x, bias)
        expected_y_output += bias
        self.assertTrue(torch.equal(output, expected_y_output))

    @patch("vllm_ascend.quantization.w8a16.is_enable_nz")
    @patch("torch_npu.npu_format_cast")
    def test_process_weights_after_loading_not_nz(self, mock_npu_format_cast,
                                                  mock_is_nz):
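        """Without NZ enabled, scales/offsets are flattened and no format cast happens."""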
        layer = MagicMock()
        layer.weight.data = torch.randn(128, 256)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)

        mock_is_nz.return_value = False
        mock_npu_format_cast.return_value = MagicMock()
        self.method.process_weights_after_loading(layer)

        self.assertEqual(layer.weight_scale.data.shape, (128, ))
        self.assertEqual(layer.weight_offset.data.shape, (128, ))
        mock_npu_format_cast.assert_not_called()

    @patch("vllm_ascend.quantization.w8a16.is_enable_nz")
    @patch("torch_npu.npu_format_cast")
    def test_process_weights_after_loading_nz(self, mock_npu_format_cast,
                                              mock_is_nz):
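        """With NZ enabled, scales/offsets are flattened and the weight is format-cast once."""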
        layer = MagicMock()
        layer.weight.data = torch.randn(128, 256)
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)

        mock_is_nz.return_value = True
        mock_npu_format_cast.return_value = MagicMock()
        self.method.process_weights_after_loading(layer)

        self.assertEqual(layer.weight_scale.data.shape, (128, ))
        self.assertEqual(layer.weight_offset.data.shape, (128, ))
        mock_npu_format_cast.assert_called_once()