Commit

Add composability test
rahul-tuli committed Dec 3, 2024
1 parent acaa685 commit ea8b8b5
Showing 3 changed files with 185 additions and 1 deletion.
@@ -0,0 +1,38 @@
pruning_stage:
  obcq_modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      sequential_update: true
      mask_structure: "2:4"
      targets: ['re:model.layers.\d*$']
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 8
            type: float
            strategy: tensor
            dynamic: false
            symmetric: true
          input_activations:
            num_bits: 8
            type: float
            strategy: tensor
            dynamic: true
            symmetric: true
          targets: ["Linear"]
  pruning_modifiers:
    ConstantPruningModifier:
      targets: [
        're:.*q_proj.weight',
        're:.*k_proj.weight',
        're:.*v_proj.weight',
        're:.*o_proj.weight',
        're:.*gate_proj.weight',
        're:.*up_proj.weight',
        're:.*down_proj.weight',
      ]
      start: 0
@@ -0,0 +1,38 @@
pruning_stage:
  obcq_modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      sequential_update: true
      mask_structure: "2:4"
      targets: ['re:model.layers.\d*$']
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 8
            type: int
            strategy: tensor
            dynamic: false
            symmetric: true
          input_activations:
            num_bits: 8
            type: int
            strategy: tensor
            dynamic: true
            symmetric: true
          targets: ["Linear"]
  pruning_modifiers:
    ConstantPruningModifier:
      targets: [
        're:.*q_proj.weight',
        're:.*k_proj.weight',
        're:.*v_proj.weight',
        're:.*o_proj.weight',
        're:.*gate_proj.weight',
        're:.*up_proj.weight',
        're:.*down_proj.weight',
      ]
      start: 0
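
The two recipes above differ only in the quantization dtype (float for FP8, int for INT8); each stacks a 2:4 SparseGPT pruning stage with per-tensor 8-bit quantization, while a ConstantPruningModifier keeps the pruned weights fixed during the quantization stage. As a minimal sketch of how such a recipe is applied, mirroring the oneshot call in the test added below: the import path for oneshot is an assumption about where the entrypoint is exposed, while the model stub, dataset, recipe path, and other arguments are the values used by the test.

# Sketch only: apply one of the stacked sparsity + quantization recipes via
# llmcompressor's oneshot entrypoint, as the composability test below does.
import torch
from llmcompressor.transformers import oneshot  # assumed import location

oneshot(
    model="Xenova/llama2.c-stories110M",
    dataset="open_platypus",
    recipe="tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml",
    num_calibration_samples=64,
    splits={"calibration": "train[:10%]"},
    oneshot_device="cuda" if torch.cuda.is_available() else "cpu",
    clear_sparse_session=False,
)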
@@ -6,7 +6,7 @@
import torch
from accelerate import cpu_offload
from accelerate.accelerator import get_state_dict_offloaded_model
-from compressed_tensors import QUANTIZATION_CONFIG_NAME
+from compressed_tensors import QUANTIZATION_CONFIG_NAME, CompressionFormat
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig
from compressed_tensors.quantization import QuantizationStatus
@@ -364,3 +364,111 @@ def test_model_shared_tensors_gpu(
    test_model_shared_tensors(
        offload, torch_dtype, tie_word_embeddings, device_map, tmp_path
    )


@pytest.mark.parametrize(
    "model_stub, recipe, sparse_format, quant_format",
    [
        (
            "Xenova/llama2.c-stories110M",
            "tests/llmcompressor/transformers/compression/recipes/sparse_24_int8.yaml",
            CompressionFormat.sparse_24.value,
            CompressionFormat.int_quantized.value,
        ),
    ],
)
def test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path):
    from llmcompressor.pytorch.model_load.helpers import get_session_model

    device = "cuda"
    if not torch.cuda.is_available():
        device = "cpu"
    dataset = "open_platypus"
    concatenate_data = False
    num_calibration_samples = 64
    splits = {"calibration": "train[:10%]"}
    empty_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype="auto")

    oneshot(
        model=model_stub,
        dataset=dataset,
        num_calibration_samples=num_calibration_samples,
        recipe=recipe,
        concatenate_data=concatenate_data,
        splits=splits,
        oneshot_device=device,
        clear_sparse_session=False,
    )

    # Fetch the oneshot model
    model = get_session_model()
    og_state_dict = model.state_dict()
    path = tmp_path / "compressed"

    # Compress and save
    model.save_pretrained(
        path,
        quantization_format=quant_format,
        save_compressed=True,
    )

    # Verify config on disk
    config = AutoConfig.from_pretrained(path)
    compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
    quant_config = ModelCompressor.parse_quantization_config(compression_config)

    # As HFQuantizer doesn't decompress the model, use the compressor to decompress
    # the model instead
    compressor = ModelCompressor.from_compression_config(compression_config)

    assert (
        compressor.sparsity_compressor is not None
    ), "Sparse compressor not initialized"
    assert compressor.sparsity_config.format == sparse_format

    assert (
        compressor.quantization_compressor is not None
    ), "Quantization compressor not initialized"
    assert quant_config["format"] == quant_format

    compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
    compressor.decompress(model_path=path, model=empty_model)

    # Verify the abs difference between the decompressed model
    # and the original model
    reconstructed_state_dict = empty_model.state_dict()
    assert len(og_state_dict) == len(reconstructed_state_dict)
    for key in og_state_dict.keys():
        dense_tensor = og_state_dict[key].to(device)
        reconstructed_tensor = reconstructed_state_dict[key].to(device)
        assert dense_tensor.dtype == reconstructed_tensor.dtype
        if key.endswith("weight") and quant_format != "dense":
            # we don't expect an exact match for compressed
            diff = torch.abs(dense_tensor - reconstructed_tensor)
            assert not torch.any(
                diff > 0.01
            ).item(), f"{key} has a diff greater than 0.01"
        else:
            assert torch.equal(dense_tensor, reconstructed_tensor)
    shutil.rmtree(tmp_path)


# This parameterization should be added to the test_compressor_stacking test
# once the lossy nature of FP8 compress-decompress is resolved.
# Until then, this test is marked as xfail.
@pytest.mark.xfail(reason="Known issue with FP8 compress-decompress")
@pytest.mark.parametrize(
    "model_stub, recipe, sparse_format, quant_format",
    [
        (
            "Xenova/llama2.c-stories110M",
            "tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml",
            CompressionFormat.sparse_24.value,
            CompressionFormat.float_quantized.value,
        ),
    ],
)
def test_compressor_stacking_fp8(
    model_stub, recipe, sparse_format, quant_format, tmp_path
):
    test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path)
