Conversation

@brian-dellabetta (Collaborator) commented on May 16, 2025:

SUMMARY:
In AWQ, resolving mappings can take a while because the search traverses the entire model tree, rather than just the parent module, to find the balance layers. This PR scopes the search to the parent module. For MoE models, the previous implementation also found only a single layer for each regex string provided in mappings; this PR updates it to find every matching layer, which is necessary for mappings like

```python
AWQMapping(
    "re:.*post_attention_layernorm$",
    ["re:.*mlp.experts.*.gate_proj$", "re:.*mlp.experts.*.up_proj$"],
)
```

which match multiple gate_proj and up_proj layers, one pair per expert.
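
For illustration, here is a minimal sketch of the scoped lookup described above (a hypothetical helper, not the actual `AWQModifier` implementation; it assumes the `re:` prefix convention shown in the mapping and matches module names relative to the parent):

```python
import re

import torch.nn as nn


def find_balance_layers(parent: nn.Module, balance_patterns: list[str]) -> list[nn.Module]:
    """Collect every submodule of `parent` whose name matches a balance-layer regex."""
    found = []
    for pattern in balance_patterns:
        regex = re.compile(pattern.removeprefix("re:"))
        for name, module in parent.named_modules():
            # Names here are relative to `parent` (e.g. "mlp.experts.3.gate_proj"),
            # which still match patterns that begin with ".*".
            if regex.match(name):
                found.append(module)
    return found
```

Because the loop keeps appending instead of returning on the first hit, a single `re:.*mlp.experts.*.gate_proj$` pattern resolves to one `gate_proj` per expert.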

gsm8k results with the `Qwen/Qwen3-30B-A3B` MoE model after AWQ W4A16 symmetric quantization:

| Tasks | Version | Filter           | n-shot | Metric      |   | Value  |   | Stderr |
|-------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k |       3 | flexible-extract |      5 | exact_match | ↑ | 0.3813 | ± | 0.0134 |
|       |         | strict-match     |      5 | exact_match | ↑ | 0.8810 | ± | 0.0089 |
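
For reference, a minimal sketch of how these gsm8k numbers could be reproduced with lm-eval on a vLLM backend (the local path and batch size are placeholders; the arguments mirror the evaluation snippet later in this thread):

```python
import lm_eval
from lm_eval.utils import make_table

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args={
        "pretrained": "./Qwen3-30B-A3B-awq-w4a16-sym",  # hypothetical local save dir
        "add_bos_token": True,
        "dtype": "bfloat16",
        "max_model_len": 4096,
    },
    tasks=["gsm8k"],
    num_fewshot=5,  # matches the 5-shot results above
    batch_size=64,
)
print(make_table(results))
```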

TEST PLAN:

- [x] Working with `Qwen/Qwen3-30B-A3B` using the [same set of mappings used in AutoAWQ](https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/qwen3_moe.py#L24). An example is included in this PR in `examples/awq/qwen3_moe_example.py`. Ran successfully in ~2 hours on a single H100, using ~70GB of the 80GB available (additional memory is needed during saving).
- [x] Same wikitext PPL (14.0814) as on `main` for `meta-llama/Llama-3.2-3B-Instruct`.

@github-actions

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite; please only add the label once the PR is code complete and local testing has been performed.

brian-dellabetta force-pushed the bdellabe/awq-fast-resolve-mappings branch 2 times, most recently from 83333a6 to 63012b4 on May 21, 2025
This was referenced May 22, 2025
brian-dellabetta force-pushed the bdellabe/awq-fast-resolve-mappings branch from 6749f77 to 661454f on May 29, 2025
brian-dellabetta force-pushed the bdellabe/awq-fast-resolve-mappings branch from ba26683 to 0ec8e0e on June 2, 2025
brian-dellabetta marked this pull request as ready for review on June 3, 2025
brian-dellabetta changed the title from "AWQModifier fast resolve mappings, better logging" to "AWQModifier fast resolve mappings, better logging, MoE support" on June 3, 2025
brian-dellabetta added the ready label on June 3, 2025
brian-dellabetta force-pushed the bdellabe/awq-fast-resolve-mappings branch from f3d9f10 to ac26dbf on June 3, 2025
brian-dellabetta requested review from dsikka, kylesayrs, rahul-tuli, and shanjiaz, and removed the review requests for dsikka and kylesayrs on June 3, 2025
rahul-tuli previously approved these changes on June 3, 2025
kylesayrs (Collaborator) previously approved these changes and left a comment on June 5, 2025:

Many nits, otherwise looks good

brian-dellabetta dismissed stale reviews from kylesayrs and rahul-tuli via 96d5f59 on June 5, 2025
brian-dellabetta merged commit ceffa64 into main on June 5, 2025 (11 checks passed)
brian-dellabetta deleted the bdellabe/awq-fast-resolve-mappings branch on June 5, 2025
@Chao-Xue

https://github.com/vllm-project/llm-compressor/blob/ceffa644072b1d440df3d99b0f98f6416a05bf2f/examples/awq/qwen3_moe_example.py#L55~L56
It looks like `re:.*mlp.shared_expert_gate$` is included in the ignore list, but Qwen3-MoE doesn't seem to have a `shared_expert_gate` module. Probably not an issue, though.

@brian-dellabetta (Collaborator, Author):

@Chao-Xue, yes, that entry is for other Qwen MoE architectures that have a shared expert we need to ignore; not all Qwen MoE models have one.
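
If anyone wants to double-check, here is a quick illustrative way to see whether an ignore pattern matches anything in a given checkpoint (assuming a transformers version that supports Qwen3-MoE; instantiating on the meta device avoids downloading weights):

```python
import re

import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Qwen/Qwen3-30B-A3B")
with torch.device("meta"):  # build the module tree only, no weights loaded
    model = AutoModelForCausalLM.from_config(config)

pattern = re.compile(r".*mlp\.shared_expert_gate$")
matches = [name for name, _ in model.named_modules() if pattern.match(name)]
print(matches)  # an empty list means the ignore entry is a no-op for this model
```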

aireilly pushed a commit to aireilly/llm-compressor that referenced this pull request on Jul 30, 2025:
…project#1444)

HDCharles added a commit that referenced this pull request Dec 11, 2025
### Summary
To allow for arbitrary heterogeneous quantization schemes, this PR switches several helpers from AutoAWQ to the observer and QDQ logic. AWQ no longer requires the quantization config to use the same group_size, symmetric, and num_bits settings for every config_group.
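
For context, a rough sketch of the kind of heterogeneous recipe this change should enable (assuming `AWQModifier` accepts per-group schemes via `config_groups`, as the summary implies; the target patterns and settings are made up for illustration, and the class/field names follow the compressed-tensors imports in the script below):

```python
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)
from llmcompressor.modifiers.awq import AWQModifier

recipe = AWQModifier(
    ignore=["lm_head"],
    config_groups={
        # 4-bit symmetric weights, group size 128, for the MLP projections
        "group_0": QuantizationScheme(
            targets=["re:.*gate_proj$", "re:.*up_proj$", "re:.*down_proj$"],
            weights=QuantizationArgs(
                num_bits=4,
                type=QuantizationType.INT,
                symmetric=True,
                strategy=QuantizationStrategy.GROUP,
                group_size=128,
            ),
        ),
        # 8-bit asymmetric weights, smaller groups, for the attention projections
        "group_1": QuantizationScheme(
            targets=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$", "re:.*o_proj$"],
            weights=QuantizationArgs(
                num_bits=8,
                type=QuantizationType.INT,
                symmetric=False,
                strategy=QuantizationStrategy.GROUP,
                group_size=64,
            ),
        ),
    },
)
```

Each group can now use its own bit width, symmetry, and group size rather than sharing one global setting.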

Resolves #1657 

Prerequisites:
* vllm-project/compressed-tensors#519

### Test plan
- [x] When running `llm-compressor/examples/awq/llama_example.py` with this change (with `duo_scaling="both"`) and logging the best configuration of `(ratio, duo_scaling)`, I see a good mix of Falses and Trues, i.e. a good percentage of `best_scales` were found with `duo_scaling=False` and a good percentage with `duo_scaling=True`. Generated model output looks good.
- [x] When using `awq_one_shot.py` (pasted below), Wikitext PPL is
consistent for w4a16 and w4a16_asym on this branch when compared to
main, and better than what was reported in a [previous AWQ
PR](#1444 (comment)),
but those might have been differently configured. For W4A16_ASYM, the
results are both 13.41 for main and this branch. This is what we've been
historically using to test regressions.

|Scheme|Wikitext PPL RTN|AWQ main|AWQ this branch|
|-----------:|---------------------:|----------|-----:|
|W4A16|   13.784   |13.477| 13.426|
|W4A16_ASYM | 13.606 | 13.346 | 13.377|

- [x] I see a small regression in recovery when running `CADENCE=weekly
TEST_DATA_FILE=~/projects/llm-compressor/tests/lmeval/configs/w4a16_awq_sym.yaml
pytest -s ~/projects/llm-compressor/tests/lmeval/test_lmeval.py` on this
branch, which causes the test to fail. This persists even when using
`pseudo_quantize_tensor` instead of `call_observer`/`forward_quantize`,
as shown in [this
diff](https://github.com/vllm-project/llm-compressor/compare/kylesayrs/awq-generalize-quant...bdellabe/awq-generalize-quant?expand=1).
I get the same result with that diff, so at least the quantization logic in
compressed-tensors is consistent with AutoAWQ.
Output:
```
<main>
2025-11-17T18:26:04.682699+0000 | _validate_recovery | INFO - ✓ exact_match,strict-match                 | Base: 0.7650 | Compressed: 0.7090 | Recovery: 92.68% ↑ | Threshold: ≥92.00%
2025-11-17T18:26:04.682811+0000 | _validate_recovery | INFO - ✓ exact_match,flexible-extract             | Base: 0.7630 | Compressed: 0.7100 | Recovery: 93.05% ↑ | Threshold: ≥93.00%
<this branch>
2025-11-17T17:55:00.648672+0000 | _validate_recovery | ERROR - ✗ exact_match,strict-match                 | Base: 0.7650 | Compressed: 0.6950 | Recovery: 90.85% ↑ | Threshold: ≥92.00%
2025-11-17T17:55:00.648967+0000 | _validate_recovery | ERROR - ✗ exact_match,flexible-extract             | Base: 0.7630 | Compressed: 0.6960 | Recovery: 91.22% ↑ | Threshold: ≥93.00%
```
This is already a pretty high drop in recovery; should we revisit this
test?



- [x] Further regression testing against main was done in this
[commit](8b6b0a5); see
[run.sh](https://github.com/vllm-project/llm-compressor/blob/8b6b0a5ae27084756df5d7e3fd0eca60cbe07b87/run.sh)
as of that commit (the script was removed in the final PR). Results
comparing this branch and main look reasonable: some metrics up, some
down, all within the margin of error.

  Test Group Quantization (w4a16_awq_sym)

| Branch    | Metric                       | Base   | Compressed | Recovery |
|-----------|------------------------------|--------|------------|----------|
| On Branch | exact_match,strict-match     | 0.7620 | 0.7170     | 94.09% ↑ |
| On Branch | exact_match,flexible-extract | 0.7600 | 0.7130     | 93.82% ↑ |
| On Main   | exact_match,strict-match     | 0.7620 | 0.7090     | 93.04%   |
| On Main   | exact_match,flexible-extract | 0.7600 | 0.7060     | 92.89%   |

  Test Tensor Quantization (int8_tensor)

| Branch    | Metric                       | Base   | Compressed | Recovery |
|-----------|------------------------------|--------|------------|----------|
| On Branch | exact_match,strict-match     | 0.7620 | 0.7220     | 94.75% ↓ |
| On Branch | exact_match,flexible-extract | 0.7600 | 0.7240     | 95.26% ↓ |
| On Main   | exact_match,strict-match     | 0.7620 | 0.7280     | 95.54%   |
| On Main   | exact_match,flexible-extract | 0.7600 | 0.7310     | 96.18%   |

  Test Channel Quantization (fp8_dynamic)

| Branch    | Metric                       | Base   | Compressed | Recovery |
|-----------|------------------------------|--------|------------|----------|
| On Branch | exact_match,strict-match     | 0.7650 | 0.7610     | 99.48%   |
| On Branch | exact_match,flexible-extract | 0.7630 | 0.7580     | 99.34%   |

  Test Block Quantization (fp8_block)

| Branch    | Metric                       | Base   | Compressed | Recovery |
|-----------|------------------------------|--------|------------|----------|
| On Branch | exact_match,strict-match     | 0.7650 | 0.7720     | 100.92%  |
| On Branch | exact_match,flexible-extract | 0.7630 | 0.7690     | 100.79%  |

<details>
<summary>awq_one_shot.py script</summary>

```python
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from llmcompressor import oneshot, active_session
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)


MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-asym"


# Configure the quantization algorithm to run.
recipe = [
    AWQModifier(
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
            "re:.*mlp.shared_expert_gate$",
            "re:visual.*",
        ],
        scheme="W4A16_ASYM",
        duo_scaling="both",
        targets=["Linear"],
        # offload_device=torch.device("cpu"),
    ),
]

# Select calibration dataset.
DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512


def get_calib_dataset(tokenizer):
    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
    )

    def preprocess(example):
        return {"input_ids": tokenizer.encode(example["text"].strip())}

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        .select(range(NUM_CALIBRATION_SAMPLES))
    )

    return ds


if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    ###
    ### Apply algorithms.
    ###
    oneshot(
        model=model,
        dataset=get_calib_dataset(tokenizer),
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        log_dir=None,
        trust_remote_code_model=True,
    )

    # Confirm generations of the quantized model look sane.
    dispatch_for_generation(model)
    print("\n\n")
    print("========== SAMPLE GENERATION ==============")
    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
    output = model.generate(input_ids, max_new_tokens=100)
    print(tokenizer.decode(output[0]))
    print("==========================================\n\n")

    # Save to disk compressed.
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)

    ### LM Eval on the saved compressed model.

    active_session().reset()
    del model
    del tokenizer
    torch.cuda.empty_cache()

    import lm_eval
    from lm_eval.utils import make_table

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args={
            "pretrained": SAVE_DIR,
            "add_bos_token": True,
            "dtype": "bfloat16",
            "gpu_memory_utilization": 0.7,
            "max_model_len": 4096,
            # "max_num_batched_tokens": 128,
            # "max_num_seqs": 128,
        },
        tasks=["wikitext"],
        batch_size=128,
    )
    print(make_table(results))
```
</details>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Co-authored-by: Brian Dellabetta <bdellabe@redhat.com>
Co-authored-by: HDCharles <charlesdavidhernandez@gmail.com>
Co-authored-by: Fynn Schmitt-Ulms <fynnsu@outlook.com>