-
-
Notifications
You must be signed in to change notification settings - Fork 11.3k
[Feat]: Add support for Dynamic Quant 4 bit CPU kleidiai kernels #17112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
135 changes: 135 additions & 0 deletions
135
...el_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from typing import Callable, Optional | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( | ||
| CompressedTensorsScheme) | ||
| from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( | ||
| MPLinearLayerConfig, choose_mp_linear_kernel) | ||
| from vllm.model_executor.parameter import (ChannelQuantScaleParameter, | ||
| GroupQuantScaleParameter, | ||
| ModelWeightParameter) | ||
| from vllm.scalar_type import scalar_types | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| __all__ = ["CompressedTensorsW4A8Int"] | ||
| W4A8_SUPPORTED_TYPES_MAP = { | ||
| 4: scalar_types.int4, | ||
| } | ||
| W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys()) | ||
|
|
||
|
|
||
| class CompressedTensorsW4A8Int(CompressedTensorsScheme): | ||
| _kernel_backends_being_used: set[str] = set() | ||
|
|
||
| def __init__(self, | ||
| strategy: str, | ||
| num_bits: int, | ||
| group_size: Optional[int] = None, | ||
| is_static_input_scheme: bool = False, | ||
| input_symmetric: bool = True): | ||
| self.strategy = strategy | ||
| self.group_size = -1 if group_size is None else group_size | ||
| self.is_static_input_scheme = is_static_input_scheme | ||
| self.input_symmetric = input_symmetric | ||
|
|
||
| if num_bits not in W4A8_SUPPORTED_TYPES_MAP: | ||
| raise ValueError( | ||
| f"Unsupported num_bits = {num_bits}." | ||
| f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}") | ||
| self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits] | ||
|
|
||
| @classmethod | ||
| def get_min_capability(cls) -> int: | ||
| return 1 | ||
|
|
||
| def create_weights(self, layer: torch.nn.Module, output_size: int, | ||
| input_size: int, output_partition_sizes: list[int], | ||
| input_size_per_partition: int, | ||
| params_dtype: torch.dtype, weight_loader: Callable, | ||
| **kwargs): | ||
| output_size_per_partition = sum(output_partition_sizes) | ||
| row_parallel = (input_size != input_size_per_partition) | ||
|
|
||
| # Compute effective group_size | ||
| if self.group_size == -1: | ||
| effective_group_size = (input_size_per_partition | ||
| if row_parallel else input_size) | ||
| else: | ||
| effective_group_size = self.group_size | ||
|
|
||
| # Ensure group_size divides input_size_per_partition | ||
| assert input_size_per_partition % effective_group_size == 0, ( | ||
| f"input_size_per_partition {input_size_per_partition}" | ||
| f" not divisible by group_size {effective_group_size}") | ||
|
|
||
| # Determine scale partitioning | ||
| is_channelwise = (self.group_size == -1) | ||
| repeat_scales = (is_channelwise and row_parallel) | ||
| partition_scales = not repeat_scales | ||
|
|
||
| mp_linear_kernel_config = MPLinearLayerConfig( | ||
| full_weight_shape=(input_size, output_size), | ||
| partition_weight_shape=(input_size_per_partition, | ||
| output_size_per_partition), | ||
| weight_type=self.quant_type, | ||
| act_type=params_dtype, | ||
| group_size=effective_group_size, | ||
| zero_points=False, | ||
| has_g_idx=False, | ||
| ) | ||
|
|
||
| kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) | ||
| if kernel_type.__name__ not in self._kernel_backends_being_used: | ||
| logger.info("Using %s for CompressedTensorsW4A8Int", | ||
| kernel_type.__name__) | ||
| self._kernel_backends_being_used.add(kernel_type.__name__) | ||
|
|
||
| scales_and_zp_size = input_size_per_partition // effective_group_size | ||
|
|
||
| weight = ModelWeightParameter(data=torch.empty( | ||
| output_size_per_partition, | ||
| input_size_per_partition, | ||
| dtype=torch.int8), | ||
| input_dim=1, | ||
| output_dim=0, | ||
| weight_loader=weight_loader) | ||
| layer.register_parameter("weight", weight) | ||
|
|
||
| weight_scale_args = { | ||
| "weight_loader": | ||
| weight_loader, | ||
| "data": | ||
| torch.empty(output_size_per_partition, | ||
| scales_and_zp_size, | ||
| dtype=params_dtype) | ||
| } | ||
|
|
||
| if partition_scales: | ||
| weight_scale = GroupQuantScaleParameter(output_dim=0, | ||
| input_dim=1, | ||
| **weight_scale_args) | ||
| else: | ||
| weight_scale = ChannelQuantScaleParameter(output_dim=0, | ||
| **weight_scale_args) | ||
|
|
||
| layer.register_parameter("weight_packed", weight) | ||
| layer.register_parameter("weight_scale", weight_scale) | ||
|
|
||
| self.kernel = kernel_type(mp_linear_kernel_config, | ||
| w_q_param_name="weight_packed", | ||
| w_s_param_name="weight_scale", | ||
| w_zp_param_name=None, | ||
| w_gidx_param_name=None) | ||
|
|
||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||
| self.kernel.process_weights_after_loading(layer) | ||
|
|
||
| def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, | ||
| bias: Optional[torch.Tensor]) -> torch.Tensor: | ||
| return self.kernel.apply_weights(layer, x, bias) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.model_executor.layers.quantization.utils import replace_parameter | ||
| from vllm.platforms import CpuArchEnum, current_platform | ||
| from vllm.scalar_type import scalar_types | ||
|
|
||
| from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig | ||
|
|
||
|
|
||
| class Dynamic4bitLinearKernel(MPLinearKernel): | ||
| SUPPORTED_QUANT_TYPES = [scalar_types.int4] | ||
|
|
||
| @classmethod | ||
| def get_min_capability(cls) -> int: | ||
| return 1 | ||
|
|
||
| @classmethod | ||
| def can_implement(cls, | ||
| c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: | ||
| if not current_platform.is_cpu(): | ||
| return False, "Only CPU is supported" | ||
| if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: | ||
| return False, f"Unsupported quant type {c.weight_type}" | ||
| if current_platform.get_cpu_architecture( | ||
| ) == CpuArchEnum.ARM and c.act_type not in [ | ||
| torch.float32, | ||
| ]: | ||
| return False, "Dynamic4bitLinearKernel on Arm requires"\ | ||
| " Float32 activations" | ||
| if c.full_weight_shape[0] % c.group_size != 0: | ||
| return False, f"Group size ({c.group_size}) does not evenly divide"\ | ||
| " the number of input features "\ | ||
| f"({c.full_weight_shape[0]})" | ||
| if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: | ||
| try: | ||
| # Attempt to retrieve the operation | ||
| _ = torch.ops.aten._dyn_quant_matmul_4bit | ||
| except AttributeError: | ||
| return False, f"PyTorch {torch.__version__} does not support"\ | ||
| " _dyn_quant_matmul_4bit. Install a newer version" | ||
| return True, None | ||
|
|
||
| def process_weights_after_loading(self, layer: torch.nn.Module): | ||
| c = self.config | ||
| packed_weight = getattr(layer, self.w_q_name) | ||
| packed_weight = packed_weight.add(8) | ||
| uint8_packed = (packed_weight[::, 1::2] << 4 | ||
| | packed_weight[::, ::2]).to(torch.uint8) | ||
|
|
||
| scales = getattr(layer, self.w_s_name) | ||
| block_size = c.group_size | ||
|
|
||
| # Handle scaling factors for partitioned weights | ||
| if block_size == c.partition_weight_shape[0]: | ||
| scales = scales.to( | ||
| torch.float32 | ||
| ) # Float32 & Bfloat16 variants requires float32 scales | ||
| scales = scales.view(-1, 1) # Channel-wise scales | ||
| if layer.bias is not None: | ||
| layer.bias = layer.bias.to( | ||
| torch.float32 | ||
| ) # Float32 & Bfloat16 variants requires float32 bias | ||
| else: | ||
| # KleidiAI kernel requires bfloat16 scales with groupwise scheme | ||
| scales = scales.to(torch.bfloat16) | ||
|
|
||
| # Repack weights as per kernel requirement | ||
| w = torch.ops.aten._dyn_quant_pack_4bit_weight( | ||
| uint8_packed, scales, layer.bias, block_size, | ||
| c.partition_weight_shape[0], c.partition_weight_shape[1]) | ||
| replace_parameter(layer, self.w_q_name, | ||
| torch.nn.Parameter(w, requires_grad=False)) | ||
| setattr(layer, self.w_s_name, None) | ||
|
|
||
| def apply_weights(self, | ||
| layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| bias: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
| c = self.config | ||
| x_2d = x.reshape(-1, x.shape[-1]) | ||
| out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) | ||
|
|
||
| w_q = getattr(layer, self.w_q_name) | ||
| output = torch.ops.aten._dyn_quant_matmul_4bit( | ||
| x_2d, w_q, c.group_size, c.partition_weight_shape[0], | ||
| c.partition_weight_shape[1]) | ||
| return output.reshape(out_shape) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.