
Conversation

@kylesayrs (Collaborator) commented Oct 22, 2025

Summary

To allow for arbitrary heterogeneous quantization schemes, this PR switches several AWQ helpers from AutoAWQ's quantization logic to the observer and QDQ logic. AWQ no longer requires that every config_group in the quantization config share the same group_size, symmetric, and num_bits settings.
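
For illustration, a minimal sketch of the kind of heterogeneous recipe this enables, assuming AWQModifier accepts config_groups through the shared QuantizationMixin; the group names and target regexes are hypothetical, not taken from this PR:

```python
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)
from llmcompressor.modifiers.awq import AWQModifier

# Two config groups with different num_bits / group_size / symmetric settings,
# which this PR allows AWQ to handle in a single recipe.
recipe = AWQModifier(
    ignore=["lm_head"],
    config_groups={
        "attn_w4": QuantizationScheme(
            targets=["re:.*self_attn\\..*_proj$"],
            weights=QuantizationArgs(
                num_bits=4,
                type=QuantizationType.INT,
                symmetric=False,
                strategy=QuantizationStrategy.GROUP,
                group_size=128,
            ),
        ),
        "mlp_w8": QuantizationScheme(
            targets=["re:.*mlp\\..*_proj$"],
            weights=QuantizationArgs(
                num_bits=8,
                type=QuantizationType.INT,
                symmetric=True,
                strategy=QuantizationStrategy.CHANNEL,
            ),
        ),
    },
)
```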

Resolves #1657

Prerequisites:

Test plan

  • When running llm-compressor/examples/awq/llama_example.py with this change (and duo_scaling="both"), logging the best (ratio, duo_scaling) configuration shows a healthy mix of True and False: a good percentage of best_scales were found with duo_scaling=False and a good percentage with duo_scaling=True. Generated model output looks good.
  • When using awq_oneshot.py (pasted below), Wikitext PPL on this branch is consistent with main for W4A16 and W4A16_ASYM, and better than what was reported in a previous AWQ PR (though those runs may have been configured differently). For W4A16_ASYM, main and this branch are effectively tied. This is what we've historically used to test for regressions.

    | Scheme | Wikitext PPL (RTN) | Wikitext PPL (AWQ, main) | Wikitext PPL (AWQ, this branch) |
    |---|---|---|---|
    | W4A16 | 13.784 | 13.477 | 13.426 |
    | W4A16_ASYM | 13.606 | 13.346 | 13.377 |
  • I see a small regression in recovery when running CADENCE=weekly TEST_DATA_FILE=~/projects/llm-compressor/tests/lmeval/configs/w4a16_awq_sym.yaml pytest -s ~/projects/llm-compressor/tests/lmeval/test_lmeval.py on this branch, which causes the test to fail. The regression persists even when using pseudo_quantize_tensor instead of call_observer/forward_quantize, as shown in this diff. Since that diff gives the same result, the quantization logic in CT at least appears consistent with AutoAWQ. (A minimal sketch of the observer/QDQ path is included after the result tables below.)
    Output:

```
<main>
2025-11-17T18:26:04.682699+0000 | _validate_recovery | INFO - ✓ exact_match,strict-match                 | Base: 0.7650 | Compressed: 0.7090 | Recovery: 92.68% ↑ | Threshold: ≥92.00%
2025-11-17T18:26:04.682811+0000 | _validate_recovery | INFO - ✓ exact_match,flexible-extract             | Base: 0.7630 | Compressed: 0.7100 | Recovery: 93.05% ↑ | Threshold: ≥93.00%

<this branch>
2025-11-17T17:55:00.648672+0000 | _validate_recovery | ERROR - ✗ exact_match,strict-match                 | Base: 0.7650 | Compressed: 0.6950 | Recovery: 90.85% ↑ | Threshold: ≥92.00%
2025-11-17T17:55:00.648967+0000 | _validate_recovery | ERROR - ✗ exact_match,flexible-extract             | Base: 0.7630 | Compressed: 0.6960 | Recovery: 91.22% ↑ | Threshold: ≥93.00%
```

Even on main, this is already a fairly large drop in recovery; should we revisit this test?

  • Further regression testing against main was done in this commit (see run.sh as of that commit; it was removed in the final PR). Comparing branch and main, results look reasonable: some metrics are up, some are down, all within the margin of error.

    Test Group Quantization (w4a16_awq_sym)

    | Branch | Metric | Base | Compressed | Recovery |
    |---|---|---|---|---|
    | On Branch | exact_match,strict-match | 0.7620 | 0.7170 | 94.09% ↑ |
    | On Branch | exact_match,flexible-extract | 0.7600 | 0.7130 | 93.82% ↑ |
    | On Main | exact_match,strict-match | 0.7620 | 0.7090 | 93.04% |
    | On Main | exact_match,flexible-extract | 0.7600 | 0.7060 | 92.89% |

    Test Tensor Quantization (int8_tensor)

    | Branch | Metric | Base | Compressed | Recovery |
    |---|---|---|---|---|
    | On Branch | exact_match,strict-match | 0.7620 | 0.7220 | 94.75% ↓ |
    | On Branch | exact_match,flexible-extract | 0.7600 | 0.7240 | 95.26% ↓ |
    | On Main | exact_match,strict-match | 0.7620 | 0.7280 | 95.54% |
    | On Main | exact_match,flexible-extract | 0.7600 | 0.7310 | 96.18% |

    Test Channel Quantization (fp8_dynamic)

    | Branch | Metric | Base | Compressed | Recovery |
    |---|---|---|---|---|
    | On Branch | exact_match,strict-match | 0.7650 | 0.7610 | 99.48% |
    | On Branch | exact_match,flexible-extract | 0.7630 | 0.7580 | 99.34% |

    Test Block Quantization (fp8_block)

    | Branch | Metric | Base | Compressed | Recovery |
    |---|---|---|---|---|
    | On Branch | exact_match,strict-match | 0.7650 | 0.7720 | 100.92% |
    | On Branch | exact_match,flexible-extract | 0.7630 | 0.7690 | 100.79% |
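
As referenced above, the core change these tests exercise is the QDQ path: instead of AutoAWQ's pseudo_quantize_tensor, the scaled weight is observed and then fake-quantized with the observed qparams. A minimal sketch of that flow, assuming the call_observer helper from llm-compressor's calibration module and fake_quantize from compressed-tensors (signatures approximate, not this PR's exact code):

```python
import torch
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle.forward import fake_quantize
from llmcompressor.modifiers.quantization.calibration import call_observer


def observe_and_qdq_weight(layer: torch.nn.Module) -> torch.Tensor:
    """Quantize-dequantize a layer's weight using its attached observer/qparams.

    Assumes the layer was initialized by the quantization lifecycle, i.e. it
    carries `quantization_scheme`, `weight_scale`, and `weight_zero_point`.
    """
    # update scale / zero point from the current (possibly AWQ-scaled) weight
    call_observer(layer, "weight", layer.weight)
    args: QuantizationArgs = layer.quantization_scheme.weights
    # QDQ the weight with the freshly observed quantization parameters
    return fake_quantize(layer.weight, layer.weight_scale, layer.weight_zero_point, args)
```
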
<details>
<summary>awq_oneshot.py script</summary>

```python
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from llmcompressor import oneshot, active_session
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)

MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"

# Configure the quantization algorithm to run.

recipe = [
    AWQModifier(
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
            "re:.*mlp.shared_expert_gate$",
            "re:visual.*",
        ],
        scheme="W4A16_ASYM",
        duo_scaling="both",
        targets=["Linear"],
        # offload_device=torch.device("cpu"),
    ),
]

# Select calibration dataset.

DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.

NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

def get_calib_dataset(tokenizer):
    from datasets import load_dataset

    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
    )

    def preprocess(example):
        return {"input_ids": tokenizer.encode(example["text"].strip())}

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        .select(range(NUM_CALIBRATION_SAMPLES))
    )

    return ds

if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    ###
    ### Apply algorithms.
    ###
    oneshot(
        model=model,
        dataset=get_calib_dataset(tokenizer),
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        log_dir=None,
        trust_remote_code_model=True,
    )

    # Confirm generations of the quantized model look sane.
    dispatch_for_generation(model)
    print("\n\n")
    print("========== SAMPLE GENERATION ==============")
    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
    output = model.generate(input_ids, max_new_tokens=100)
    print(tokenizer.decode(output[0]))
    print("==========================================\n\n")

    # Save compressed model to disk.
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)

    ## LM EVAL
    active_session().reset()
    del model
    del tokenizer
    torch.cuda.empty_cache()

    import lm_eval
    from lm_eval.utils import make_table

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args={
            "pretrained": SAVE_DIR,
            "add_bos_token": True,
            "dtype": "bfloat16",
            "gpu_memory_utilization": 0.7,
            "max_model_len": 4096,
            # "max_num_batched_tokens": 128,
            # "max_num_seqs": 128,
        },
        tasks=["wikitext"],
        batch_size=128,
    )
    print(make_table(results))
```
</details>

@github-actions (bot) commented:

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@brian-dellabetta brian-dellabetta changed the title [WIP] Generalize AWQ quantization [AWQ] Generalize AWQ quantization Nov 13, 2025
@kylesayrs (Collaborator, Author) left a comment:

So long as you feel confident that _compute_layer_means will work as expected for all the supported strategies, this looks good to me!

@kylesayrs (Collaborator, Author) left a comment:

Approved from my side.

HDCharles
HDCharles previously approved these changes Nov 17, 2025
fynnsu
fynnsu previously approved these changes Nov 17, 2025
@fynnsu (Collaborator) left a comment:

Looks good, added a couple comments below!

@HDCharles HDCharles dismissed stale reviews from fynnsu and themself via 4e480ce December 8, 2025 18:47
@HDCharles HDCharles force-pushed the kylesayrs/awq-generalize-quant branch from bdcdca4 to 4e480ce Compare December 8, 2025 18:47
```python
QuantizationMixin.initialize_quantization(self, state.model)

# Validate that duo_scaling is only used with per-channel quantization
if self.duo_scaling != False:
```
A collaborator commented:

@kylesayrs added check for duo_scaling + per Tensor strategy
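
A hedged sketch of the kind of check being referred to (not the PR's exact code): reject duo scaling when a targeted weight scheme uses the per-tensor strategy, in line with the in-diff comment above.

```python
from compressed_tensors.quantization import QuantizationStrategy


def _validate_duo_scaling(modifier) -> None:
    # no-op when duo scaling is disabled entirely
    if modifier.duo_scaling is False:
        return
    for scheme in (modifier.config_groups or {}).values():
        weights = scheme.weights
        if weights is not None and weights.strategy == QuantizationStrategy.TENSOR:
            raise ValueError(
                "duo_scaling requires per-channel or grouped weight quantization; "
                f"got strategy={weights.strategy} for targets={scheme.targets}"
            )
```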

```python
continue

balance_layer.weight.mul_(_scalesview)
call_observer(balance_layer, "weight", balance_layer.weight)
```
@kylesayrs (Collaborator, Author) commented:

If you pass `should_calculate_gparam=(scheme.strategy == "tensor_group")`, then I think this should be able to support NVFP4.

A collaborator replied:

will add that

A collaborator replied:

Actually added a TODO; will do in another PR.

A collaborator replied:

(Added TODO) Going to move this and the simplified logic to a new PR.
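
For reference, a hedged sketch of the suggestion in this thread (deferred to a follow-up PR; assumes call_observer accepts should_calculate_gparam and that the module carries a compressed-tensors quantization_scheme attribute):

```python
from compressed_tensors.quantization import QuantizationStrategy
from llmcompressor.modifiers.quantization.calibration import call_observer


def observe_balanced_weight(balance_layer) -> None:
    # only request a global scale parameter for tensor-group schemes (e.g. NVFP4)
    strategy = balance_layer.quantization_scheme.weights.strategy
    call_observer(
        balance_layer,
        "weight",
        balance_layer.weight,
        should_calculate_gparam=(strategy == QuantizationStrategy.TENSOR_GROUP),
    )
```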

@HDCharles HDCharles marked this pull request as ready for review December 10, 2025 15:01
@HDCharles HDCharles force-pushed the kylesayrs/awq-generalize-quant branch from 9d2d033 to 57ade6b Compare December 10, 2025 15:04
@kylesayrs (Collaborator, Author) left a comment:

LGTM, approved from my side

```python
sequential_targets: Union[str, List[str], None] = None
mappings: Optional[List[AWQMapping]] = None
offload_device: Optional[torch.device] = None
duo_scaling: str | bool = True
```
@kylesayrs (Collaborator, Author) commented:

`Literal["both"]` doesn't work anymore?
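
For context, a hedged sketch of what duo_scaling="both" means during the scale search, based on the test plan above and AutoAWQ's scale formula rather than this PR's exact code: each candidate ratio is tried with and without duo scaling, and the (ratio, duo_scaling) pair with the lowest reconstruction loss wins.

```python
import itertools

import torch


def candidate_scales(x_mean: torch.Tensor, w_mean: torch.Tensor, n_grid: int = 20):
    """Yield (ratio, use_duo, scales) candidates for the AWQ scale search."""
    for grid_idx, use_duo in itertools.product(range(n_grid), (True, False)):
        ratio = grid_idx / n_grid
        if use_duo:
            # AWQ-style duo scaling: balance activation and weight magnitudes
            scales = x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)
        else:
            # activation-only scaling
            scales = x_mean.pow(ratio).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()
        yield ratio, use_duo, scales
```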


```python
# avoid scaling values that overflow
scales[torch.isinf(scales)] = 1
scales[torch.isnan(scales)] = 1
```
@kylesayrs (Collaborator, Author) commented:

Definitely annoying that we still have to have this line.

```python
.item()
)
batch_loss = torch.nn.functional.mse_loss(
    fp16_batch.to(device), int_w_batch.to(device)
```
@kylesayrs (Collaborator, Author) commented:

This algorithm has a lot of device movement.

@kylesayrs kylesayrs added the ready When a PR is ready for review label Dec 10, 2025
kylesayrs and others added 5 commits December 10, 2025 13:47
(commit sign-offs: Kyle Sayers <kylesayrs@gmail.com>, Brian Dellabetta <bdellabe@redhat.com>)
brian-dellabetta and others added 23 commits December 10, 2025 13:48
(commit sign-offs: Brian Dellabetta <bdellabe@redhat.com>, HDCharles <charlesdavidhernandez@gmail.com>)
@HDCharles HDCharles force-pushed the kylesayrs/awq-generalize-quant branch from 57ade6b to 157cf48 Compare December 10, 2025 19:00

Labels: ready (When a PR is ready for review)

Successfully merging this pull request may close these issues: W4fp8 AWQ

5 participants