Composability with sparse and quantization compressors #948

Open · wants to merge 7 commits into base: main
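This PR makes the sparsity and quantization compressors composable: a pruned-then-quantized model can now be saved with a sparse compressor (sparse_24_bitmask for 2:4 models, sparse_bitmask otherwise) stacked with a quantization compressor, instead of the sparse side being forced to dense whenever the model is quantized. A minimal sketch of the intended flow, assuming the `oneshot` entrypoint and reusing the model stub and recipe path from the new test at the bottom of this diff:

```python
# Sketch only: the model stub, dataset, and recipe path are borrowed from the
# new test_compressor_stacking test below; the import paths are assumptions
# based on the surrounding llm-compressor code.
from llmcompressor.pytorch.model_load.helpers import get_session_model
from llmcompressor.transformers import oneshot

oneshot(
    model="Xenova/llama2.c-stories15M",
    dataset="open_platypus",
    recipe="tests/llmcompressor/transformers/compression/recipes/sparse_int8.yaml",
    num_calibration_samples=64,
    splits={"calibration": "train[:10%]"},
    clear_sparse_session=False,
)

# With this PR, save_compressed=True writes both a sparsity config and a
# quantization config, and applies both compressors when serializing weights.
model = get_session_model()
model.save_pretrained("./compressed_out", save_compressed=True)
```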
@@ -1,7 +1,7 @@
 from typing import Optional

 from compressed_tensors import CompressionFormat
-from compressed_tensors.config import SparsityCompressionConfig
+from compressed_tensors.config import SparsityStructure
 from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
 from compressed_tensors.quantization.utils import (
     is_model_quantized,
@@ -16,7 +16,7 @@ def infer_quantization_format(
     model,
     quantization_format: Optional[str] = None,
     save_compressed: bool = False,
-    sparsity_config: Optional[SparsityCompressionConfig] = None,
+    sparsity_structure: Optional[str] = None,
 ) -> str:
     """
     Infers a quantization format based on model state and compression args
@@ -37,7 +37,7 @@ def infer_quantization_format(
     if save_compressed:
         weight_args, input_args = _get_unique_quant_args(model)
         is_24_structure = (
-            sparsity_config and sparsity_config.sparsity_structure == "2:4"
+            SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
         )
         is_weight_only = len(input_args) == 0 and len(weight_args) > 0

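Context for the change above: SparsityStructure is an enum in compressed-tensors that normalizes structure strings, so format inference no longer string-matches "2:4" against a config object that may be None. A small illustration of the behavior this relies on; the None fallback is an assumption based on the return-statement change in sparsity_config.py below:

```python
from compressed_tensors.config import SparsityStructure

# "2:4" parses to the TWO_FOUR member, matching the enum comparison above
assert SparsityStructure("2:4") is SparsityStructure.TWO_FOUR

# Assumed: passing None falls back to the unstructured member, which is what
# lets `return SparsityStructure(sparsity_structure).value` below replace the
# old `sparsity_structure or "unstructured"` default.
assert SparsityStructure(None).value == "unstructured"
```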
22 changes: 15 additions & 7 deletions src/llmcompressor/transformers/compression/sparsity_config.py
@@ -1,7 +1,7 @@
 from typing import Dict, Optional

 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
-from compressed_tensors.quantization.utils import is_model_quantized
+from compressed_tensors.config import SparsityStructure
 from torch import Tensor
 from torch.nn import Module

@@ -20,7 +20,7 @@ class SparsityConfigMetadata:
     metadata from the model
     """

-    SPARSITY_THRESHOLD: float = 0.4
+    SPARSITY_THRESHOLD: float = 0.5

     @staticmethod
     def infer_global_sparsity(
@@ -67,13 +67,14 @@ def infer_sparsity_structure(model: Optional[Module] = None) -> str:
         if model and sparsity_structure is None:
             sparsity_structure = infer_sparsity_structure_from_model(model)

-        return sparsity_structure or "unstructured"
+        return SparsityStructure(sparsity_structure).value

     @staticmethod
     def from_pretrained(
         model: Module,
         state_dict: Optional[Dict[str, Tensor]] = None,
         compress: bool = False,
+        is_marlin: bool = False,
     ) -> Optional["SparsityCompressionConfig"]:
         """
         Determines compression type and informational parameters for a given model
@@ -82,6 +83,7 @@ def from_pretrained(
         :param state_dict: optional state_dict to replace that in model, used for
             gathering global FSDP model info
         :param compress: whether or not to compress the model on disk
+        :param is_marlin: whether or not marlin compression is being used
         :return: compression config inferred from the model
         """

@@ -95,11 +97,17 @@ def from_pretrained(
         sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
             model=model
         )
-        if is_model_quantized(model):
-            # compressing a sparse quantized model is not supported yet
+        if is_marlin:
+            # sparse compressor should be dense for marlin compression
             format = CompressionFormat.dense.value
-        elif compress:
-            format = CompressionFormat.sparse_bitmask.value
+        elif compress:
+            format = (
+                CompressionFormat.sparse_24_bitmask.value
+                if sparsity_structure == SparsityStructure.TWO_FOUR.value
+                else CompressionFormat.sparse_bitmask.value
+            )
         else:
             format = CompressionFormat.dense.value

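Restated outside the diff, the sparse-format decision added to from_pretrained reduces to the following; `_sparse_format` is a hypothetical helper for illustration, not a function in the codebase:

```python
from compressed_tensors import CompressionFormat
from compressed_tensors.config import SparsityStructure

def _sparse_format(sparsity_structure: str, compress: bool, is_marlin: bool) -> str:
    """Hypothetical restatement of the branch added above."""
    if is_marlin:
        # marlin_24 kernels expect dense weights on disk, so never
        # sparse-compress underneath a marlin quantization format
        return CompressionFormat.dense.value
    if compress:
        # 2:4 models get the dedicated 24-bitmask format; other structures
        # fall back to the generic bitmask compressor
        if sparsity_structure == SparsityStructure.TWO_FOUR.value:
            return CompressionFormat.sparse_24_bitmask.value
        return CompressionFormat.sparse_bitmask.value
    return CompressionFormat.dense.value
```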
@@ -8,6 +8,7 @@
 import transformers
 from accelerate.accelerator import get_state_dict_offloaded_model
 from compressed_tensors import (
+    CompressionFormat,
     ModelCompressor,
     SparsityCompressionConfig,
     is_module_offloaded,
@@ -273,13 +274,20 @@ def get_model_compressor(
     if state_dict is None:
         state_dict = get_state_dict_offloaded_model(model)

+    sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model)
+    quantization_format = infer_quantization_format(
+        model=model,
+        quantization_format=quantization_format,
+        save_compressed=save_compressed,
+        sparsity_structure=sparsity_structure,
+    )
+    is_marlin = quantization_format == CompressionFormat.marlin_24.value
+
     if sparsity_config is not None:
         sparsity_config.global_sparsity = SparsityConfigMetadata.infer_global_sparsity(
             model, state_dict=state_dict
         )
-        sparsity_config.sparsity_structure = (
-            SparsityConfigMetadata.infer_sparsity_structure()
-        )
+        sparsity_config.sparsity_structure = sparsity_structure
     elif not skip_compression_stats:
         # try to infer a sparsity config from the model if none is provided
         logger.info(
             "skip_compression_stats=True"
         )
         sparsity_config = SparsityConfigMetadata.from_pretrained(
-            model, state_dict=state_dict, compress=save_compressed
+            model,
+            state_dict=state_dict,
+            compress=save_compressed,
+            is_marlin=is_marlin,
         )

-    quantization_format = infer_quantization_format(
-        model=model,
-        quantization_format=quantization_format,
-        save_compressed=save_compressed,
-        sparsity_config=sparsity_config,
-    )
     return ModelCompressor.from_pretrained_model(
         model,
         sparsity_config=sparsity_config,
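The reordering above is the crux of this file's change: infer_quantization_format now runs before the sparsity config is built, since SparsityConfigMetadata.from_pretrained needs to know whether the quantization format resolved to marlin_24 (in which case the sparse compressor must stay dense). The sparsity structure is likewise inferred once up front and shared, rather than being re-inferred inside the sparsity branch and again inside infer_quantization_format.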
@@ -0,0 +1,38 @@
+pruning_stage:
+  obcq_modifiers:
+    SparseGPTModifier:
+      sparsity: 0.5
+      sequential_update: true
+      mask_structure: "2:4"
+      targets: ['re:model.layers.\d*$']
+quant_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      ignore: ["lm_head"]
+      config_groups:
+        group_0:
+          weights:
+            num_bits: 8
+            type: float
+            strategy: tensor
+            dynamic: false
+            symmetric: true
+          input_activations:
+            num_bits: 8
+            type: float
+            strategy: tensor
+            dynamic: true
+            symmetric: true
+          targets: ["Linear"]
+  pruning_modifiers:
+    ConstantPruningModifier:
+      targets: [
+        're:.*q_proj.weight',
+        're:.*k_proj.weight',
+        're:.*v_proj.weight',
+        're:.*o_proj.weight',
+        're:.*gate_proj.weight',
+        're:.*up_proj.weight',
+        're:.*down_proj.weight',
+      ]
+      start: 0
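This first recipe stacks the pieces end to end: SparseGPTModifier prunes to 2:4, QuantizationModifier applies 8-bit float weights with dynamic per-tensor input activations, and ConstantPruningModifier pins the pruned projection weights during the quantization stage so calibration cannot regrow them. Saved with save_compressed=True, a model produced by this recipe should take the new sparse_24_bitmask path.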
@@ -0,0 +1,46 @@
+pruning_stage:
+  obcq_modifiers:
+    SparseGPTModifier:
+      sparsity: 0.5
+      sequential_update: true
+      mask_structure: "0:0"
+      targets: ['re:model.layers.\d*$']
+test_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      ignore: ["lm_head"]
+      config_groups:
+        group_0:
+          weights:
+            num_bits: 8
+            type: "int"
+            symmetric: true
+            strategy: "tensor"
+          input_activations:
+            num_bits: 8
+            type: "int"
+            symmetric: false
+            strategy: "tensor"
+          output_activations: null
+          targets: ["Linear"]
+        group_1:
+          weights:
+            num_bits: 8
+            type: "int"
+            symmetric: true
+            strategy: "tensor"
+          input_activations: null
+          output_activations: null
+          targets: ["Embedding"]
+  pruning_modifiers:
+    ConstantPruningModifier:
+      targets: [
+        're:.*q_proj.weight',
+        're:.*k_proj.weight',
+        're:.*v_proj.weight',
+        're:.*o_proj.weight',
+        're:.*gate_proj.weight',
+        're:.*up_proj.weight',
+        're:.*down_proj.weight',
+      ]
+      start: 0
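The second recipe is the unstructured counterpart: mask_structure "0:0" places no N:M constraint on the mask, so SparsityConfigMetadata.from_pretrained should select the generic sparse_bitmask format rather than sparse_24_bitmask, and it additionally quantizes an Embedding group with no input activations. This appears to be the sparse_int8.yaml recipe that the new test_compressor_stacking test below pairs with the int-quantized weight format.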
@@ -6,7 +6,7 @@
 import torch
 from accelerate import cpu_offload
 from accelerate.accelerator import get_state_dict_offloaded_model
-from compressed_tensors import QUANTIZATION_CONFIG_NAME
+from compressed_tensors import QUANTIZATION_CONFIG_NAME, CompressionFormat
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig
 from compressed_tensors.quantization import QuantizationStatus
@@ -364,3 +364,88 @@ def test_model_shared_tensors_gpu(
     test_model_shared_tensors(
         offload, torch_dtype, tie_word_embeddings, device_map, tmp_path
     )
+
+
+@pytest.mark.parametrize(
+    "model_stub, recipe, sparse_format, quant_format",
+    [
+        (
+            "Xenova/llama2.c-stories15M",
+            "tests/llmcompressor/transformers/compression/recipes/sparse_int8.yaml",
+            CompressionFormat.sparse_bitmask.value,
+            CompressionFormat.int_quantized.value,
+        ),
+    ],
+)
+def test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path):
+    from llmcompressor.pytorch.model_load.helpers import get_session_model
+
+    device = "cuda"
+    if not torch.cuda.is_available():
+        device = "cpu"
+    dataset = "open_platypus"
+    concatenate_data = False
+    num_calibration_samples = 64
+    splits = {"calibration": "train[:10%]"}
+    empty_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype="auto")
+
+    oneshot(
+        model=model_stub,
+        dataset=dataset,
+        num_calibration_samples=num_calibration_samples,
+        recipe=recipe,
+        concatenate_data=concatenate_data,
+        splits=splits,
+        oneshot_device=device,
+        clear_sparse_session=False,
+    )
+
+    # Fetch the oneshot model
+    model = get_session_model()
+    og_state_dict = model.state_dict()
+    path = tmp_path / "compressed"
+
+    # Compress and save
+    model.save_pretrained(
+        path,
+        quantization_format=quant_format,
+        save_compressed=True,
+    )
+
+    # Verify config on disk
+    config = AutoConfig.from_pretrained(path)
+    compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+    quant_config = ModelCompressor.parse_quantization_config(compression_config)
+
+    # As HFQuantizer doesn't decompress the model, use the compressor to
+    # decompress the model instead
+    compressor = ModelCompressor.from_compression_config(compression_config)
+
+    assert (
+        compressor.sparsity_compressor is not None
+    ), "Sparse compressor not initialized"
+    assert compressor.sparsity_config.format == sparse_format
+
+    assert (
+        compressor.quantization_compressor is not None
+    ), "Quantization compressor not initialized"
+    assert quant_config["format"] == quant_format
+
+    compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
+    compressor.decompress(model_path=path, model=empty_model)
+
+    # Verify the abs difference between the decompressed model
+    # and the original model
+    reconstructed_state_dict = empty_model.state_dict()
+    assert len(og_state_dict) == len(reconstructed_state_dict)
+    for key in og_state_dict.keys():
+        dense_tensor = og_state_dict[key].to(device)
+        reconstructed_tensor = reconstructed_state_dict[key].to(device)
+        assert dense_tensor.dtype == reconstructed_tensor.dtype
+        if key.endswith("weight") and quant_format != "dense":
+            # we don't expect an exact match for compressed weights
+            diff = torch.abs(dense_tensor - reconstructed_tensor)
+            assert not torch.any(diff > 0.01), f"Max diff: {torch.max(diff)}"
+        else:
+            assert torch.equal(dense_tensor, reconstructed_tensor)
+    shutil.rmtree(tmp_path)