Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Quantization Save Defaults #22

Merged
merged 4 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions src/llmcompressor/transformers/compression/quantization_format.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional

from compressed_tensors import CompressionFormat
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
from compressed_tensors.quantization.utils import (
is_model_quantized,
is_module_quantized,
Expand All @@ -11,7 +13,10 @@


def infer_quantization_format(
model, quantization_format: Optional[str] = None, save_compressed: bool = False
model,
quantization_format: Optional[str] = None,
save_compressed: bool = False,
sparsity_config: Optional[SparsityCompressionConfig] = None,
) -> str:
"""
Infers a quantization format based on model state and compression args
Expand All @@ -30,32 +35,52 @@ def infer_quantization_format(
return quantization_format

if save_compressed:
quant_types = _get_quant_types(model)
if quant_types == ["int4"]: # save packed if everything is int4
return CompressionFormat.pack_quantized
elif quant_types == ["float8"]:
return CompressionFormat.float_quantized
weight_args, input_args = _get_unique_quant_args(model)
is_24_structure = (
sparsity_config and sparsity_config.sparsity_structure == "2:4"
)
is_weight_only = len(input_args) == 0 and len(weight_args) > 0

# otherwise just quantize to int8
return CompressionFormat.int_quantized
if is_weight_only: # w4a16 and w8a16
is_valid_pack = (
len(weight_args) == 1
and weight_args[0].num_bits in [4, 8]
and weight_args[0].type == QuantizationType.INT.value
)
if not is_valid_pack: # packing only valid for int4 and int 8
return CompressionFormat.naive_quantized
if is_24_structure:
for arg in weight_args:
if (
arg.strategy is not QuantizationStrategy.CHANNEL.value
and arg.strategy is not QuantizationStrategy.GROUP.value
):
# marlin24 kernel only applicable for channel/group quantization
return CompressionFormat.pack_quantized
return CompressionFormat.marlin_24
return CompressionFormat.pack_quantized
else: # w8a8 float and int
return CompressionFormat.naive_quantized
else:
# format will be inferred from config
return None


def _get_quant_types(model):
def _get_unique_quant_args(model):
"""
Gets a list of all the quantized types present in model
Gets a list of all the unique quantization settings present in model
"""
quant_info = []
quant_info_weight = []
quant_info_inputs = []
for _, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
weight_scheme = submodule.quantization_scheme.weights
input_scheme = submodule.quantization_scheme.input_activations
if weight_scheme is not None:
weight_bit_depth = weight_scheme.num_bits
weight_type = weight_scheme.type
weight_info = f"{weight_type}{weight_bit_depth}"
if weight_info not in quant_info:
quant_info.append(weight_info)
if weight_scheme not in quant_info_weight:
quant_info_weight.append(weight_scheme)
if input_scheme is not None:
if input_scheme not in quant_info_inputs:
quant_info_inputs.append(input_scheme)

return quant_info
return quant_info_weight, quant_info_inputs
3 changes: 2 additions & 1 deletion src/llmcompressor/transformers/finetune/session_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def __init__(
if self.is_fsdp_enabled:
self._prepare_model_for_fsdp()

self.min_tokens_per_module = data_args.min_tokens_per_module
if data_args is not None:
self.min_tokens_per_module = data_args.min_tokens_per_module

def initialize_session(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/finetune/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class TrainingArguments(HFTrainingArgs):
},
)
save_compressed: Optional[bool] = field(
default=False,
default=True,
metadata={"help": "Whether to compress sparse models during save"},
)
do_oneshot: Optional[bool] = field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def save_pretrained_wrapper(
save_directory: str,
sparsity_config: Optional[SparsityCompressionConfig] = None,
quantization_format: Optional[str] = None,
save_compressed: bool = False,
save_compressed: bool = True,
skip_compression_stats: bool = False,
**kwargs,
):
Expand Down Expand Up @@ -88,6 +88,7 @@ def save_pretrained_wrapper(
model=model,
quantization_format=quantization_format,
save_compressed=save_compressed,
sparsity_config=sparsity_config,
)
compressor = ModelCompressor.from_pretrained_model(
model,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import preset_name_to_scheme

from llmcompressor.transformers.compression.quantization_format import (
infer_quantization_format,
)
from tests.llmcompressor.pytorch.helpers import LinearNet


@pytest.mark.parametrize(
"preset,sparsity_structure,expected_format",
[
["W8A8", "unstructured", "naive-quantized"],
["W8A16", "unstructured", "pack-quantized"],
["W8A16", "2:4", "marlin-24"],
["W4A16", "unstructured", "pack-quantized"],
["W4A16", "2:4", "marlin-24"],
["FP8", "unstructured", "naive-quantized"],
],
)
def test_infer_quant_format(preset, sparsity_structure, expected_format):
sparsity_config = SparsityCompressionConfig(
format="dense", sparsity_structure=sparsity_structure
)
quant_scheme = preset_name_to_scheme(preset, targets=["Linear"])

dummy_model = LinearNet()
for _, module in dummy_model.named_modules():
module.quantization_scheme = quant_scheme

inferred_format = infer_quantization_format(
dummy_model, save_compressed=True, sparsity_config=sparsity_config
)
assert inferred_format.value == expected_format
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _run_oneshot(model, recipe, dataset, output_dir):
pad_to_max_length=pad_to_max_length,
clear_sparse_session=True,
splits={"calibration": "train_gen[:5%]"},
save_compressed=False,
)

def _get_quant_info(self, model):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def test_mask_structure_preserved(self):
output_dir=self.output_first,
oneshot_device=self.device,
clear_sparse_session=False,
save_compressed=False,
)
first_tiny_model = get_session_model()
targetted_layer = first_tiny_model.model.layers[0].self_attn.k_proj
Expand All @@ -87,6 +88,7 @@ def test_mask_structure_preserved(self):
output_dir=self.output_second,
oneshot_device=self.device,
clear_sparse_session=False,
save_compressed=False,
)

second_tiny_model = get_session_model()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@
from compressed_tensors import COMPRESSION_CONFIG_NAME
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig
from compressed_tensors.quantization import (
QuantizationStatus,
compress_quantized_weights,
freeze_module_quantization,
)
from safetensors import safe_open
from compressed_tensors.quantization import QuantizationStatus
from transformers import AutoConfig

from llmcompressor.core import reset_session
Expand Down Expand Up @@ -213,64 +208,3 @@ def test_quant_model_reload(format, dtype, tmp_path):
assert torch.equal(dense_tensor, reconstructed_tensor)

shutil.rmtree(tmp_path)


@pytest.mark.parametrize(
"status,expected_format,expected_dtype",
[
[QuantizationStatus.FROZEN, "dense", torch.float32],
[QuantizationStatus.COMPRESSED, "int-quantized", torch.int8],
],
)
def test_quant_infer_format(status, expected_format, expected_dtype, tmp_path):
recipe_str = (
"tests/llmcompressor/transformers/compression/recipes/new_quant_simple.yaml"
)
model_path = "Xenova/llama2.c-stories15M"
device = "cuda:0"
if not torch.cuda.is_available():
device = "cpu"
dataset = "open_platypus"
concatenate_data = False
num_calibration_samples = 64
output_dir = tmp_path / "oneshot_out"
splits = {"calibration": "train[:10%]"}

model = SparseAutoModelForCausalLM.from_pretrained(model_path)

# create a quantized model
oneshot(
model=model,
dataset=dataset,
output_dir=output_dir,
num_calibration_samples=num_calibration_samples,
recipe=recipe_str,
concatenate_data=concatenate_data,
splits=splits,
oneshot_device=device,
)

if status == QuantizationStatus.FROZEN:
model.apply(freeze_module_quantization)
elif status == QuantizationStatus.COMPRESSED:
model.apply(compress_quantized_weights)

for _, module in model.named_modules():
if hasattr(module, "quantization_scheme"):
assert module.quantization_status == status

model.save_pretrained(tmp_path / "compress_out")

config = AutoConfig.from_pretrained(tmp_path / "compress_out")
compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
quant_config = ModelCompressor.parse_quantization_config(compression_config)
assert quant_config["quantization_status"] == status.value
assert quant_config["format"] == expected_format

with safe_open(
tmp_path / "compress_out" / "model.safetensors", framework="pt", device=device
) as f:
test_tensor = f.get_tensor("model.layers.0.mlp.down_proj.weight")
assert test_tensor.dtype == expected_dtype

shutil.rmtree(tmp_path)
Loading