Merged
@@ -26,9 +26,10 @@
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
@@ -74,7 +75,7 @@ def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)

def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
return [torch.float32, torch.float16, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
@@ -299,6 +300,22 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
# Only symmetric weight quantization supported.
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_weight_4_bits = weight_quant.num_bits == 4
is_activation_8_bits = input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.GROUP.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_dynamic = not weight_quant.dynamic and input_quant.dynamic

# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return (is_weight_4_bits and is_activation_8_bits and is_token
and weight_quant.symmetric and is_dynamic)

def _is_fp8_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
# Confirm weights and activations quantized.
@@ -374,7 +391,6 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
def _get_scheme_from_parts(
self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":

# Detect If Mixed Precision
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A16Fp4()
@@ -443,6 +459,16 @@ def _get_scheme_from_parts(
is_static_input_scheme=False,
input_symmetric=input_quant.symmetric)

if self._is_dynamic_token_w4a8_int(weight_quant, input_quant):
is_static_input_scheme = (input_quant
and not input_quant.dynamic)
return CompressedTensorsW4A8Int(
num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size,
is_static_input_scheme=is_static_input_scheme,
input_symmetric=input_quant.symmetric)

raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")

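Note: a minimal sketch (not part of the diff) of the weight/activation settings the new _is_dynamic_token_w4a8_int check accepts. The objects below are hypothetical stand-ins for the compressed-tensors BaseModel arguments, assuming the usual lowercase string values for QuantizationStrategy:

from types import SimpleNamespace

# int4 weights: symmetric, statically quantized, group-wise (group_size is illustrative).
weight_quant = SimpleNamespace(num_bits=4, strategy="group", group_size=32,
                               symmetric=True, dynamic=False)
# int8 activations: dynamically quantized per token.
input_quant = SimpleNamespace(num_bits=8, strategy="token", symmetric=True,
                              dynamic=True)

# With these values _is_dynamic_token_w4a8_int(...) returns True, so
# _get_scheme_from_parts(...) builds a CompressedTensorsW4A8Int scheme.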
@@ -3,6 +3,7 @@

from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
@@ -20,5 +21,5 @@
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
"CompressedTensorsW4A4Fp4"
"CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int"
]
@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Callable, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

__all__ = ["CompressedTensorsW4A8Int"]
W4A8_SUPPORTED_TYPES_MAP = {
4: scalar_types.int4,
}
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())


class CompressedTensorsW4A8Int(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()

def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
is_static_input_scheme: bool = False,
input_symmetric: bool = True):
self.strategy = strategy
self.group_size = -1 if group_size is None else group_size
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric

if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}."
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]

@classmethod
def get_min_capability(cls) -> int:
return 1

def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
row_parallel = (input_size != input_size_per_partition)

# Compute effective group_size
if self.group_size == -1:
effective_group_size = (input_size_per_partition
if row_parallel else input_size)
else:
effective_group_size = self.group_size

# Ensure group_size divides input_size_per_partition
assert input_size_per_partition % effective_group_size == 0, (
f"input_size_per_partition {input_size_per_partition}"
f" not divisible by group_size {effective_group_size}")

# Determine scale partitioning
is_channelwise = (self.group_size == -1)
repeat_scales = (is_channelwise and row_parallel)
partition_scales = not repeat_scales

mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=(input_size_per_partition,
output_size_per_partition),
weight_type=self.quant_type,
act_type=params_dtype,
group_size=effective_group_size,
zero_points=False,
has_g_idx=False,
)

kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW4A8Int",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)

scales_and_zp_size = input_size_per_partition // effective_group_size

weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.int8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)

weight_scale_args = {
"weight_loader":
weight_loader,
"data":
torch.empty(output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype)
}

if partition_scales:
weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
else:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)

layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)

self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name=None,
w_gidx_param_name=None)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
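For reference, a shape sketch (numbers are made up) of what create_weights registers for this scheme; it follows directly from the code above rather than from any additional vLLM documentation:

# Hypothetical partition: 4096 input features, 11008 output features, group_size=128.
in_features, out_features, group_size = 4096, 11008, 128

weight_shape = (out_features, in_features)                # int8 storage, one int4 value per element
scales_shape = (out_features, in_features // group_size)  # (11008, 32), GroupQuantScaleParameter

# Channel-wise case (group_size == -1): the effective group size becomes the whole
# input dimension, so the scale tensor collapses to one column per output channel,
# i.e. shape (out_features, 1).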
@@ -10,6 +10,8 @@
BitBLASLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
ConchLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501
Dynamic4bitLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
@@ -25,6 +27,7 @@
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
Dynamic4bitLinearKernel,
BitBLASLinearKernel,
ConchLinearKernel,
ExllamaLinearKernel,
@@ -56,20 +59,21 @@ def choose_mp_linear_kernel(
if current_platform is None:
raise ValueError("Cannot determine compute capability")
_cc = current_platform.get_device_capability()
compute_capability = _cc[0] * 10 + _cc[1]
if _cc is not None:
compute_capability = _cc[0] * 10 + _cc[1]

failure_reasons = []
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
continue

if kernel.get_min_capability() > compute_capability:
if (compute_capability is not None
and kernel.get_min_capability() > compute_capability):
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel.get_min_capability()}, current compute capability "
f"is {compute_capability}")
f"{kernel.get_min_capability()}, current compute "
f" capability is {compute_capability}")
continue

can_implement, failure_reason = kernel.can_implement(config)
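A rough usage sketch of the selection path this hunk guards (values are illustrative; on CPU-only builds get_device_capability() can return None, which is why the capability comparison is now conditional):

import torch

from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
    MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.scalar_type import scalar_types

cfg = MPLinearLayerConfig(
    full_weight_shape=(4096, 11008),      # (input_size, output_size), made-up sizes
    partition_weight_shape=(4096, 11008),
    weight_type=scalar_types.int4,
    act_type=torch.float32,               # the Arm path requires float32 activations
    group_size=32,
    zero_points=False,
    has_g_idx=False,
)
# On a CPU build this is expected to resolve to Dynamic4bitLinearKernel, since the
# GPU kernels fail their capability/can_implement() checks.
kernel_cls = choose_mp_linear_kernel(cfg)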
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class Dynamic4bitLinearKernel(MPLinearKernel):
SUPPORTED_QUANT_TYPES = [scalar_types.int4]

@classmethod
def get_min_capability(cls) -> int:
return 1

@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
if not current_platform.is_cpu():
return False, "Only CPU is supported"
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
return False, f"Unsupported quant type {c.weight_type}"
if current_platform.get_cpu_architecture(
) == CpuArchEnum.ARM and c.act_type not in [
torch.float32,
]:
return False, "Dynamic4bitLinearKernel on Arm requires"\
" Float32 activations"
if c.full_weight_shape[0] % c.group_size != 0:
return False, f"Group size ({c.group_size}) does not evenly divide"\
" the number of input features "\
f"({c.full_weight_shape[0]})"
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
try:
# Attempt to retrieve the operation
_ = torch.ops.aten._dyn_quant_matmul_4bit
except AttributeError:
return False, f"PyTorch {torch.__version__} does not support"\
" _dyn_quant_matmul_4bit. Install a newer version"
return True, None

def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
packed_weight = getattr(layer, self.w_q_name)
packed_weight = packed_weight.add(8)
uint8_packed = (packed_weight[::, 1::2] << 4
| packed_weight[::, ::2]).to(torch.uint8)

scales = getattr(layer, self.w_s_name)
block_size = c.group_size

# Handle scaling factors for partitioned weights
if block_size == c.partition_weight_shape[0]:
scales = scales.to(
torch.float32
) # Float32 & Bfloat16 variants requires float32 scales
scales = scales.view(-1, 1) # Channel-wise scales
if layer.bias is not None:
layer.bias = layer.bias.to(
torch.float32
) # Float32 & Bfloat16 variants requires float32 bias
else:
# KleidiAI kernel requires bfloat16 scales with groupwise scheme
scales = scales.to(torch.bfloat16)

# Repack weights as per kernel requirement
w = torch.ops.aten._dyn_quant_pack_4bit_weight(
uint8_packed, scales, layer.bias, block_size,
c.partition_weight_shape[0], c.partition_weight_shape[1])
replace_parameter(layer, self.w_q_name,
torch.nn.Parameter(w, requires_grad=False))
setattr(layer, self.w_s_name, None)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

w_q = getattr(layer, self.w_q_name)
output = torch.ops.aten._dyn_quant_matmul_4bit(
x_2d, w_q, c.group_size, c.partition_weight_shape[0],
c.partition_weight_shape[1])
return output.reshape(out_shape)
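To make the repacking step in process_weights_after_loading concrete, a small standalone sketch (values are made up, independent of vLLM) of the nibble packing applied before _dyn_quant_pack_4bit_weight:

import torch

# Two output rows, four input columns of signed int4 values in [-8, 7].
w_int4 = torch.tensor([[-8, 7, 0, -1],
                       [ 3, -4, 5, 2]], dtype=torch.int8)

# Shift into the unsigned range [0, 15], then pack pairs of columns into one byte:
# even columns land in the low nibble, odd columns in the high nibble.
shifted = w_int4.add(8)
packed = (shifted[:, 1::2] << 4 | shifted[:, ::2]).to(torch.uint8)

print(packed.shape)  # torch.Size([2, 2]) -- half as many columns
print(packed[0, 0])  # tensor(240, dtype=torch.uint8): low nibble 0 (= -8 + 8), high nibble 15 (= 7 + 8)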