[Misc] Update w4a16 compressed-tensors support to include w8a16 #5794

Merged · 6 commits · Jun 25, 2024
tests/quantization/test_compressed_tensors.py (12 additions & 11 deletions)

@@ -8,9 +8,9 @@
 
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsLinearMethod, CompressedTensorsW4A16,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
+    CompressedTensorsWNA16)
 
 
 @pytest.mark.parametrize("model_args", [
@@ -74,26 +74,27 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
         assert qkv_proj.weight.dtype is torch.int8
 
 
-@pytest.mark.parametrize("w4a16_args", [
-    ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None),
-    ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128),
-])
-def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
-    model, strategy, group = w4a16_args
+@pytest.mark.parametrize(
+    "wNa16_args",
+    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
+     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
+     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
+def test_compressed_tensors_w4a16(vllm_runner, wNa16_args):
+    model, strategy, group, pack_factor = wNa16_args
     with vllm_runner(model) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
 
         qkv_proj = layer.self_attn.qkv_proj
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)
 
         assert qkv_proj.scheme.strategy == strategy
         assert qkv_proj.scheme.group_size == group
 
         assert qkv_proj.weight_packed.dtype is torch.int32
         assert qkv_proj.weight_scale.dtype is torch.float16
-        assert qkv_proj.weight_packed.pack_factor == 8
+        assert qkv_proj.weight_packed.pack_factor == pack_factor
 
 
 def test_compressed_tensors_w4a16_marlin24(vllm_runner):
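The new fourth element in the parametrize tuples above is the expected pack_factor: how many quantized weights fit into one 32-bit storage word (weight_packed is torch.int32, per the dtype assert). A minimal sketch of that relationship, assuming the int32 packing shown in the test; the helper name is our own illustration, not vLLM API:

# Illustrative sketch, not vLLM code: N-bit values packed into int32 words.
STORAGE_BITS = 32  # weight_packed.dtype is torch.int32

def pack_factor(num_bits: int) -> int:
    assert STORAGE_BITS % num_bits == 0
    return STORAGE_BITS // num_bits

assert pack_factor(4) == 8  # the two w4a16 checkpoints above
assert pack_factor(8) == 4  # the new w8a16 checkpoint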
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -7,9 +7,10 @@
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW4A16,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
+    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
@@ -108,26 +109,31 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
 
         return is_8_bits and is_token and is_symmetric and is_dynamic
 
-    def _is_w4a16(self, weight_quant: BaseModel,
-                  input_quant: BaseModel) -> bool:
+    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
+                                input_quant: BaseModel) -> bool:
         input_quant_none = input_quant is None
-        is_4_bits = weight_quant.num_bits == 4
         is_symmetric = weight_quant.symmetric
+        is_channel_group = (
+            weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+            or weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_static = not weight_quant.dynamic
 
-        return is_4_bits and input_quant_none and is_symmetric and is_static
+        return (is_channel_group and input_quant_none and is_symmetric
+                and is_static)
 
     def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
 
-        if self._is_w4a16(weight_quant, input_quant):
-            if self.quant_format == CompressionFormat.marlin_24.value:
+        if self._is_wNa16_group_channel(weight_quant, input_quant):
+            if (self.quant_format == CompressionFormat.marlin_24.value
+                    and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                 return CompressedTensorsW4A16Sparse24(
                     strategy=weight_quant.strategy,
                     num_bits=weight_quant.num_bits,
                     group_size=weight_quant.group_size)
-            if self.quant_format == CompressionFormat.pack_quantized.value:
-                return CompressedTensorsW4A16(
+            if (self.quant_format == CompressionFormat.pack_quantized.value
+                    and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
+                return CompressedTensorsWNA16(
                     num_bits=weight_quant.num_bits,
                     strategy=weight_quant.strategy,
                     group_size=weight_quant.group_size)
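The reshaped branch above now keys on both the compression format and the weight bit width. A simplified standalone sketch of that dispatch (plain strings stand in for the CompressionFormat enum values and the scheme classes; this is not the verbatim vLLM method):

W4A16SPARSE24_SUPPORTED_BITS = [4]
WNA16_SUPPORTED_BITS = [4, 8]

def pick_scheme(quant_format: str, num_bits: int) -> str:
    # The sparse marlin-24 kernel stays 4-bit only; the packed path
    # now accepts both 4-bit and 8-bit weight-only checkpoints.
    if quant_format == "marlin-24" and num_bits in W4A16SPARSE24_SUPPORTED_BITS:
        return "CompressedTensorsW4A16Sparse24"
    if quant_format == "pack-quantized" and num_bits in WNA16_SUPPORTED_BITS:
        return "CompressedTensorsWNA16"
    raise NotImplementedError(
        f"no scheme for {quant_format} at {num_bits} bits")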
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -1,10 +1,11 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
-from .compressed_tensors_w4a16 import CompressedTensorsW4A16  # noqa: F401
 from .compressed_tensors_w4a16_24 import (  # noqa: F401
-    CompressedTensorsW4A16Sparse24)
+    W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
     CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
     CompressedTensorsW8A8StaticTensor)
+from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS  # noqa: F401
+from .compressed_tensors_wNa16 import CompressedTensorsWNA16  # noqa: F401
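With these re-exports in place, consumers such as compressed_tensors.py above pull both the scheme and its supported bit widths from the package root:

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    WNA16_SUPPORTED_BITS, CompressedTensorsWNA16)

assert WNA16_SUPPORTED_BITS == [4, 8]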
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -11,6 +11,7 @@
 from vllm.model_executor.utils import set_weight_attrs
 
 __all__ = ["CompressedTensorsW4A16Sparse24"]
+W4A16SPARSE24_SUPPORTED_BITS = [4]
 
 
 class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
vllm/model_executor/layers/quantization/compressed_tensors/schemes/{compressed_tensors_w4a16.py → compressed_tensors_wNa16.py}
@@ -11,10 +11,11 @@
     marlin_permute_scales)
 from vllm.model_executor.utils import set_weight_attrs
 
-__all__ = ["CompressedTensorsW4A16"]
+__all__ = ["CompressedTensorsWNA16"]
+WNA16_SUPPORTED_BITS = [4, 8]
 
 
-class CompressedTensorsW4A16(CompressedTensorsScheme):
+class CompressedTensorsWNA16(CompressedTensorsScheme):
 
     def __init__(self,
                  strategy: str,
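With num_bits now a constructor argument rather than a hard-coded 4, a single scheme class serves both checkpoint types. A hedged usage sketch (keyword names as in the _get_schema call site above; passing group_size=None for channel-wise quantization is our assumption):

# Sketch: one scheme covers 4-bit and 8-bit weight-only quantization.
w4 = CompressedTensorsWNA16(num_bits=4, strategy="group", group_size=128)
w8 = CompressedTensorsWNA16(num_bits=8, strategy="channel", group_size=None)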