Skip to content

Commit

Permalink
make preset more explicit (vllm-project#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sara Adkins authored Jul 3, 2024
1 parent b341803 commit aecb127
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 11 deletions.
55 changes: 48 additions & 7 deletions src/compressed_tensors/quantization/quant_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
)
from pydantic import BaseModel
Expand Down Expand Up @@ -110,15 +111,55 @@ def is_preset_scheme(name: str) -> bool:
return name.upper() in PRESET_SCHEMES


# W8A8 preset: 8-bit symmetric integer weights quantized per-channel,
# paired with dynamic 8-bit symmetric integer activations quantized
# per-token (dynamic=True -> activation ranges computed at runtime).
# NOTE: the stale one-line `W8A8 = dict(...)` duplicate left over from the
# diff has been dropped; only the explicit definition remains.
W8A8 = dict(
    weights=QuantizationArgs(
        num_bits=8,
        symmetric=True,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
    ),
    input_activations=QuantizationArgs(
        num_bits=8,
        symmetric=True,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        dynamic=True,
    ),
)

# W4A16 preset: 4-bit weights with a group size of 128; other
# QuantizationArgs fields take their library defaults.
W4A16 = {"weights": QuantizationArgs(num_bits=4, group_size=128)}
# W8A16 preset: weight-only 8-bit symmetric integer quantization with one
# scale per output channel; activations are left in 16-bit.
W8A16 = {
    "weights": QuantizationArgs(
        num_bits=8,
        symmetric=True,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
    )
}

# FP8 preset: float quantization of both weights and input activations,
# relying on QuantizationArgs defaults for bit-width and strategy.
# FIX: the closing `)` of this dict was lost in the scraped diff, which made
# the following `W4A16 = dict(...)` parse as a keyword argument of FP8's
# call and left `W4A16` unbound; both definitions are now properly closed.
FP8 = dict(
    weights=QuantizationArgs(type=QuantizationType.FLOAT),
    input_activations=QuantizationArgs(type=QuantizationType.FLOAT),
)

# W4A16 preset: 4-bit symmetric integer weights quantized in groups of 128.
W4A16 = dict(
    weights=QuantizationArgs(
        num_bits=4,
        symmetric=True,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
    )
)

# FP8 preset: 8-bit float weights and activations with a single per-tensor
# scale each; activations are static (dynamic=False).
FP8 = dict(
    weights=QuantizationArgs(
        num_bits=8,
        symmetric=True,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR,
    ),
    input_activations=QuantizationArgs(
        num_bits=8,
        symmetric=True,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR,
        dynamic=False,
    ),
)

# Registry of named presets looked up by is_preset_scheme().
# FIX: the scraped diff left three PRESET_SCHEMES assignments (two stale
# copies missing "W8A16"); keep only the final, complete mapping.
PRESET_SCHEMES = {"W8A8": W8A8, "W8A16": W8A16, "W4A16": W4A16, "FP8": FP8}
5 changes: 1 addition & 4 deletions tests/test_quantization/test_quant_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,7 @@ def test_need_config_groups():

@pytest.mark.parametrize(
"scheme_name",
[
"W8A8",
"W4A16",
],
["W8A8", "W8A16", "W4A16", "FP8"],
)
def test_load_scheme_from_preset(scheme_name: str):
targets = ["Linear"]
Expand Down

0 comments on commit aecb127

Please sign in to comment.