15 changes: 15 additions & 0 deletions tests/e2e/singlecard/test_quantization.py
@@ -33,3 +33,18 @@ def test_quant_W8A8():
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)


def test_quant_W8A16():
max_tokens = 5
example_prompts = [
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
]
with VllmRunner(
snapshot_download("vllm-ascend/Qwen3-0.6B-W8A16"),
max_model_len=8192,
enforce_eager=False,
gpu_memory_utilization=0.7,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
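For reference, below is a minimal offline-inference sketch that exercises the same W8A16 checkpoint outside the test harness. It assumes a vllm-ascend installation on an Ascend NPU host and simply mirrors the VllmRunner options used in the new test; it is illustrative only, not part of the change.

```python
from vllm import LLM, SamplingParams

# Mirror the e2e test configuration (model repo, context length, memory
# utilization and the "ascend" quantization backend).
llm = LLM(
    model="vllm-ascend/Qwen3-0.6B-W8A16",
    max_model_len=8192,
    enforce_eager=False,
    gpu_memory_utilization=0.7,
    quantization="ascend",
)

prompt = ("vLLM is a high-throughput and memory-efficient inference and "
          "serving engine for LLMs.")
# temperature=0.0 approximates the test's generate_greedy call.
outputs = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)
```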
90 changes: 90 additions & 0 deletions tests/ut/quantization/test_w8a16.py
@@ -0,0 +1,90 @@
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a16 import AscendW8A16LinearMethod
from vllm_ascend.utils import AscendDeviceType


class TestAscendW8A16LinearMethod(TestBase):

def setUp(self):
self.method = AscendW8A16LinearMethod()

def test_get_weight(self):
weight = self.method.get_weight(10, 20)
self.assertEqual(weight['weight'].dtype, torch.int8)
self.assertEqual(weight['weight'].shape, (20, 10))

@patch("torch_npu.npu_weight_quant_batchmatmul")
    def test_apply(self, mock_npu_weight_quant_batchmatmul):
layer = MagicMock()
layer.weight.data = torch.randn(128, 256)
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)

x = torch.randn(32, 128)
bias = torch.randn(256)

expected_y_output = torch.randn(32, 256)
mock_npu_weight_quant_batchmatmul.return_value = expected_y_output

        output = self.method.apply(layer, x, bias)
        # apply() forwards bias into the fused operator, so the output is
        # exactly the (mocked) operator result.
        self.assertTrue(torch.equal(output, expected_y_output))

    @patch("vllm_ascend.quantization.w8a16.get_ascend_device_type",
           return_value=AscendDeviceType._310P)
    @patch("torch_npu.npu_weight_quant_batchmatmul")
    def test_apply_on_310p(self, mock_npu_weight_quant_batchmatmul,
                           mock_get_ascend_device_type):
layer = MagicMock()
layer.weight.data = torch.randn(128, 256)
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)

x = torch.randn(32, 128)
bias = torch.randn(256)

expected_y_output = torch.randn(32, 256)
mock_npu_weight_quant_batchmatmul.return_value = expected_y_output

        output = self.method.apply(layer, x, bias)
        # As above, bias goes into the fused operator, so the output is the
        # (mocked) operator result.
        self.assertTrue(torch.equal(output, expected_y_output))

@patch("vllm_ascend.quantization.w8a16.is_enable_nz")
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading_not_nz(self, mock_npu_format_cast,
mock_is_nz):
layer = MagicMock()
layer.weight.data = torch.randn(128, 256)
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)

        mock_is_nz.return_value = False
        mock_npu_format_cast.return_value = MagicMock()
self.method.process_weights_after_loading(layer)

self.assertEqual(layer.weight_scale.data.shape, (128, ))
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_not_called()

@patch("vllm_ascend.quantization.w8a16.is_enable_nz")
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading_nz(self, mock_npu_format_cast,
mock_is_nz):
layer = MagicMock()

layer.weight.data = torch.randn(128, 256)
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)

        mock_is_nz.return_value = True
        mock_npu_format_cast.return_value = MagicMock()
self.method.process_weights_after_loading(layer)

self.assertEqual(layer.weight_scale.data.shape, (128, ))
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_called_once()
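The tests above mock torch_npu.npu_weight_quant_batchmatmul, so here is a pure-PyTorch sketch of the weight-only dequantization ("antiquant") matmul that the operator performs on-device. The layout follows get_weight/get_perchannel_param (weight stored as (out_features, in_features), per-channel scale/offset of shape (out_features, 1)); the offset-sign convention of the NPU kernel is an assumption here, so treat this as an illustration rather than a bit-exact reference.

```python
from typing import Optional

import torch


def w8a16_matmul_reference(x: torch.Tensor,
                           weight_int8: torch.Tensor,
                           scale: torch.Tensor,
                           offset: torch.Tensor,
                           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Dequantize the int8 weight per output channel, then run a normal matmul.
    w_dequant = (weight_int8.to(x.dtype) + offset) * scale  # (out, in)
    y = x @ w_dequant.t()  # (batch, out)
    return y + bias if bias is not None else y
```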
4 changes: 4 additions & 0 deletions vllm_ascend/quantization/utils.py
@@ -12,6 +12,7 @@
AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
from .w8a16 import AscendW8A16LinearMethod

ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
"W4A8_DYNAMIC": {
@@ -33,6 +34,9 @@
"C8": {
"attention": AscendC8KVCacheMethod,
},
"W8A16": {
"linear": AscendW8A16LinearMethod,
}
}
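With the new entry registered, the W8A16 linear method can be resolved from the map by quantization type and layer kind. A small lookup sketch (the surrounding config plumbing in vllm-ascend is omitted):

```python
from vllm_ascend.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP

# Resolve the class registered above and instantiate it.
linear_method_cls = ASCEND_QUANTIZATION_METHOD_MAP["W8A16"]["linear"]
method = linear_method_cls()  # AscendW8A16LinearMethod
```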


105 changes: 105 additions & 0 deletions vllm_ascend/quantization/w8a16.py
@@ -0,0 +1,105 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Optional

import torch
import torch_npu

from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
get_ascend_device_type, is_enable_nz)


class AscendW8A16LinearMethod:
"""Linear method for Ascend W8A16.

"""

def __init__(self) -> None:
        # The aclnn weight-quant matmul expects matrix B to be transposed, so
        # transpose by default; on 310P the weight is transposed at apply time.
self.transpose_weight = get_ascend_device_type(
) != AscendDeviceType._310P

@staticmethod
def get_weight(
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict

@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}

@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
return params_dict

def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
return {}

@staticmethod
def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
if get_ascend_device_type() == AscendDeviceType._310P:
            # On the 300I Duo (310P) platform we need to transpose again when
            # using the NZ format; in torchair this transpose can be skipped.
output = torch_npu.npu_weight_quant_batchmatmul(
x=x,
weight=layer.weight.data.transpose(0, 1),
antiquant_scale=layer.weight_scale,
antiquant_offset=layer.weight_offset,
bias=bias)
else:
output = torch_npu.npu_weight_quant_batchmatmul(
x=x,
weight=layer.weight,
antiquant_scale=layer.weight_scale,
antiquant_offset=layer.weight_offset,
bias=bias)
return output

def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
if is_enable_nz():
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
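A shape-contract sketch for the new method: it shows how the parameters created by get_weight/get_perchannel_param relate to what process_weights_after_loading and apply expect. It assumes torch_npu is importable (the module imports it at load time) and stops short of calling the NPU matmul or the NZ format cast, which need an Ascend device.

```python
import torch

from vllm_ascend.quantization.w8a16 import AscendW8A16LinearMethod

method = AscendW8A16LinearMethod()

# The weight is created as (output_size, input_size) in int8; scale/offset are
# per-output-channel with shape (output_size, 1).
weight = method.get_weight(input_size=128, output_size=256)["weight"]
perchannel = method.get_perchannel_param(output_size=256,
                                         params_dtype=torch.bfloat16)
assert weight.shape == (256, 128) and weight.dtype == torch.int8
assert perchannel["weight_scale"].shape == (256, 1)
assert perchannel["weight_offset"].shape == (256, 1)

# On non-310P devices, process_weights_after_loading() transposes the weight to
# (input_size, output_size), optionally casts it to the NZ format, and flattens
# scale/offset to 1-D, which is what the unit tests above assert.
```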