diff --git a/tests/e2e/singlecard/test_quantization.py b/tests/e2e/singlecard/test_quantization.py
index 6ab54084357..9a5c8fcccbb 100644
--- a/tests/e2e/singlecard/test_quantization.py
+++ b/tests/e2e/singlecard/test_quantization.py
@@ -18,6 +18,7 @@
 from modelscope import snapshot_download  # type: ignore[import-untyped]
 
 from tests.e2e.conftest import VllmRunner
+from tests.e2e.model_utils import check_outputs_equal
 
 
 def test_qwen3_w8a8_quant():
@@ -25,10 +26,53 @@ def test_qwen3_w8a8_quant():
     example_prompts = [
         "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
     ]
+    vllm_target_outputs = [([
+        85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
+        13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
+    ], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be'
+                            )]
+
     with VllmRunner(
             snapshot_download("vllm-ascend/Qwen3-0.6B-W8A8"),
             max_model_len=8192,
             gpu_memory_utilization=0.7,
             quantization="ascend",
     ) as vllm_model:
-        vllm_model.generate_greedy(example_prompts, max_tokens)
+        vllm_quant_w8a8_outputs = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_target_outputs,
+        outputs_1_lst=vllm_quant_w8a8_outputs,
+        name_0="vllm_target_outputs",
+        name_1="vllm_w8a8_outputs",
+    )
+
+
+def test_qwen3_dense_w8a16():
+    max_tokens = 5
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
+    ]
+    vllm_target_outputs = [([
+        85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
+        13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
+    ], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be'
+                            )]
+
+    with VllmRunner(
+            snapshot_download("vllm-ascend/Qwen3-0.6B-W8A16"),
+            max_model_len=8192,
+            enforce_eager=False,
+            gpu_memory_utilization=0.7,
+            quantization="ascend",
+    ) as vllm_model:
+        vllm_quant_w8a16_outputs = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_target_outputs,
+        outputs_1_lst=vllm_quant_w8a16_outputs,
+        name_0="vllm_target_outputs",
+        name_1="vllm_w8a16_outputs",
+    )
diff --git a/tests/ut/quantization/test_w8a16.py b/tests/ut/quantization/test_w8a16.py
new file mode 100644
index 00000000000..1d839bfa763
--- /dev/null
+++ b/tests/ut/quantization/test_w8a16.py
@@ -0,0 +1,91 @@
+import os
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.quantization.w8a16 import AscendW8A16LinearMethod
+
+
+class TestAscendW8A16LinearMethod(TestBase):
+
+    def setUp(self):
+        self.method = AscendW8A16LinearMethod()
+
+    def test_get_weight(self):
+        weight = self.method.get_weight(10, 20)
+        self.assertEqual(weight['weight'].dtype, torch.int8)
+        self.assertEqual(weight['weight'].shape, (20, 10))
+
+    @patch("torch_npu.npu_weight_quant_batchmatmul")
+    def test_apply(self, mock_npu_weight_quant_batchmatmul):
+        layer = MagicMock()
+        layer.weight.data = torch.randn(128, 256)
+        layer.weight_scale.data = torch.randn(128, 1)
+        layer.weight_offset.data = torch.randn(128, 1)
+
+        x = torch.randn(32, 128)
+        bias = torch.randn(256)
+
+        expected_y_output = torch.randn(32, 256)
+        mock_npu_weight_quant_batchmatmul.return_value = expected_y_output
+
+        output = self.method.apply(layer, x, bias)
+        mock_npu_weight_quant_batchmatmul.assert_called_once()
+        self.assertTrue(torch.equal(output, expected_y_output))
+
+    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
+    @patch('torch_npu.npu_format_cast')
+    def test_process_weights_after_loading_with_nz0(self,
+                                                    mock_npu_format_cast):
+        layer = MagicMock()
+        layer.weight.data = torch.randint(-127,
+                                          128, (128, 256),
+                                          dtype=torch.int8)
+        layer.weight_scale.data = torch.randn(128, 1)
+        layer.weight_offset.data = torch.randn(128, 1)
+
+        mock_npu_format_cast.return_value = MagicMock()
+        self.method.process_weights_after_loading(layer)
+
+        self.assertEqual(layer.weight_scale.data.shape, (128, ))
+        self.assertEqual(layer.weight_offset.data.shape, (128, ))
+        mock_npu_format_cast.assert_not_called()
+
+    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"})
+    @patch('torch_npu.npu_format_cast')
+    def test_process_weights_after_loading_with_nz1(self,
+                                                    mock_npu_format_cast):
+        layer = MagicMock()
+
+        layer.weight.data = torch.randint(-127,
+                                          128, (128, 256),
+                                          dtype=torch.int8)
+        layer.weight_scale.data = torch.randn(128, 1)
+        layer.weight_offset.data = torch.randn(128, 1)
+
+        mock_npu_format_cast.return_value = MagicMock()
+        self.method.process_weights_after_loading(layer)
+
+        self.assertEqual(layer.weight_scale.data.shape, (128, ))
+        self.assertEqual(layer.weight_offset.data.shape, (128, ))
+        mock_npu_format_cast.assert_called_once()
+
+    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"})
+    @patch('torch_npu.npu_format_cast')
+    def test_process_weights_after_loading_with_nz2(self,
+                                                    mock_npu_format_cast):
+        layer = MagicMock()
+
+        layer.weight.data = torch.randint(-127,
+                                          128, (128, 256),
+                                          dtype=torch.int8)
+        layer.weight_scale.data = torch.randn(128, 1)
+        layer.weight_offset.data = torch.randn(128, 1)
+
+        mock_npu_format_cast.return_value = MagicMock()
+        self.method.process_weights_after_loading(layer)
+
+        self.assertEqual(layer.weight_scale.data.shape, (128, ))
+        self.assertEqual(layer.weight_offset.data.shape, (128, ))
+        mock_npu_format_cast.assert_called_once()
diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py
index 58fe6db131b..71db5269b09 100644
--- a/vllm_ascend/quantization/utils.py
+++ b/vllm_ascend/quantization/utils.py
@@ -14,6 +14,7 @@
                            AscendW8A8DynamicLinearMethod)
 from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
                          AscendW8A8PDMixLinearMethod)
+from .w8a16 import AscendW8A16LinearMethod
 
 ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
     "W4A16": {
@@ -36,6 +37,9 @@
     "W8A8_MIX": {
         "linear": AscendW8A8PDMixLinearMethod,
         "moe": AscendW8A8PDMixFusedMoeMethod,
+    },
+    "W8A16": {
+        "linear": AscendW8A16LinearMethod,
     }
 }
 
diff --git a/vllm_ascend/quantization/w8a16.py b/vllm_ascend/quantization/w8a16.py
new file mode 100644
index 00000000000..1e66c5e8420
--- /dev/null
+++ b/vllm_ascend/quantization/w8a16.py
@@ -0,0 +1,89 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch_npu
+
+from vllm_ascend.utils import maybe_trans_nz
+
+
+class AscendW8A16LinearMethod:
+    """Linear method for Ascend W8A16.
+
+    Int8 weights, floating-point activations."""
+
+    def __init__(self) -> None:
+        pass
+
+    @staticmethod
+    def get_weight(
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype = torch.bfloat16,
+    ) -> Dict[str, Any]:
+        params_dict = {
+            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
+        }
+        return params_dict
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        return {}
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["weight_scale"] = torch.empty(output_size,
+                                                  1,
+                                                  dtype=params_dtype)
+        params_dict["weight_offset"] = torch.empty(output_size,
+                                                   1,
+                                                   dtype=params_dtype)
+        return params_dict
+
+    def get_pergroup_param(self,
+                           input_size: int,
+                           output_size: int,
+                           params_dtype: torch.dtype,
+                           layer_type: Optional[str] = None) -> Dict[str, Any]:
+        return {}
+
+    @staticmethod
+    def apply(
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        tp_rank: Optional[int] = 0,
+    ) -> torch.Tensor:
+        output = torch_npu.npu_weight_quant_batchmatmul(
+            x=x,
+            weight=layer.weight,
+            antiquant_scale=layer.weight_scale,
+            antiquant_offset=layer.weight_offset,
+            bias=bias)
+        return output
+
+    def process_weights_after_loading(self, layer):
+        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight.data = maybe_trans_nz(layer.weight.data)
+        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
+        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
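A minimal usage sketch of the W8A16 path this diff enables, mirroring test_qwen3_dense_w8a16 above. Every name it uses (VllmRunner, snapshot_download, the vllm-ascend/Qwen3-0.6B-W8A16 checkpoint, quantization="ascend", generate_greedy) comes from the diff itself; the wiring comments reflect the registration added in utils.py and are descriptive, not a definitive recipe.

# Sketch only: reuses the runner and checkpoint from the e2e test in this diff.
from modelscope import snapshot_download  # type: ignore[import-untyped]

from tests.e2e.conftest import VllmRunner

prompts = [
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
]

# quantization="ascend" routes linear layers to the method registered for the
# checkpoint's quant scheme; for W8A16 that is AscendW8A16LinearMethod, which
# keeps weights as int8 and dequantizes them inside
# torch_npu.npu_weight_quant_batchmatmul when apply() runs.
with VllmRunner(
        snapshot_download("vllm-ascend/Qwen3-0.6B-W8A16"),
        max_model_len=8192,
        gpu_memory_utilization=0.7,
        quantization="ascend",
) as vllm_model:
    # generate_greedy returns one (token_ids, text) pair per prompt.
    token_ids, text = vllm_model.generate_greedy(prompts, 5)[0]
    print(text)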