Merged
82 changes: 82 additions & 0 deletions examples/awq/qwen3_moe_example.py
@@ -0,0 +1,82 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# Select model and load it.
MODEL_ID = "Qwen/Qwen3-30B-A3B"

model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Select calibration dataset.
DATASET_ID = "mit-han-lab/pile-val-backup"
DATASET_SPLIT = "validation"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 512

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
return {
"text": tokenizer.apply_chat_template(
[{"role": "user", "content": example["text"]}],
tokenize=False,
)
}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)


ds = ds.map(tokenize, remove_columns=ds.column_names)


# Configure the quantization algorithm to run.
# NOTE: vLLM does not currently support asymmetric quantization for MoE
# models, so a symmetric scheme is used here.
recipe = [
AWQModifier(
ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
scheme="W4A16",
targets=["Linear"],
),
]

# Apply algorithms.
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-sym"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
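
For reference, the compressed checkpoint saved above can typically be loaded for inference with vLLM, which reads the compressed-tensors config written by save_compressed=True. The following is a minimal, hypothetical sketch, not part of this PR; it assumes vLLM is installed with W4A16 compressed-tensors support, and the sampling settings are illustrative:

from vllm import LLM, SamplingParams

# Load the AWQ-compressed checkpoint produced by the example above.
llm = LLM(model="Qwen3-30B-A3B-awq-sym")

# Generate a short completion to sanity-check the quantized model.
params = SamplingParams(max_tokens=64, temperature=0.7)
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)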
186 changes: 116 additions & 70 deletions src/llmcompressor/modifiers/awq/base.py
@@ -2,7 +2,10 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import (
+    disable_quantization,
+    find_name_or_class_matches,
+)
from compressed_tensors.utils import (
align_module_device,
get_execution_device,
@@ -26,11 +29,7 @@
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_matching_layer,
-    get_parent_by_name,
-)
+from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers

__all__ = ["AWQModifier"]

@@ -307,77 +306,82 @@ def _set_resolved_mappings(self, model: Module) -> None:
repeat for model.layer.1 and so on
"""
resolved_mappings: list[ResolvedMapping] = []
-        num_skipped_oproj_mappings = 0
-        for mapping in self.mappings:
-            to_smooth_layers = get_layers(mapping.smooth_layer, model)
-            for layer_name, smooth_layer in to_smooth_layers.items():
-                # always exclude `.weight_observer`, only want `.weight`
-                if layer_name not in self.ignore and not layer_name.endswith(
-                    "_observer"
-                ):
-                    balance_layers, balance_names = [], []
-                    for balance_suffix in mapping.balance_layers:
-                        # find the submodule that matches the activation layer
-                        balance_name, balance_layer = get_matching_layer(
-                            balance_suffix, layer_name, model
-                        )
-                        if not balance_layer:
-                            continue
+        for mapping_idx, mapping in enumerate(self.mappings):
+            smooth_layers = get_layers(mapping.smooth_layer, model)
+            smooth_names = [
+                smooth_name
+                for smooth_name in smooth_layers
+                if not find_name_or_class_matches(
+                    smooth_name, model, self.ignore + ["re:.*_observer$"]
+                )
+            ]
+
+            num_skipped_mappings = 0
+            pbar = tqdm(smooth_names)
+            for smooth_name in pbar:
+                pbar.set_description(
+                    f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
+                    f" ({num_skipped_mappings} skipped)"
+                )
+                smooth_layer = smooth_layers[smooth_name]
+
+                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
+                smooth_parent = get_layer_by_name(smooth_parent_name, model)
+
+                balance_layers, balance_names = [], []
+                for balance_regex in mapping.balance_layers:
+                    # find the submodules that match the activation layer
+                    for balance_suffix, balance_layer in get_layers(
+                        balance_regex,
+                        smooth_parent,
+                    ).items():
+                        balance_name = f"{smooth_parent_name}.{balance_suffix}"

# exclude v_proj->o_proj mappings whose shapes are incompatible
# https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
if (
isinstance(smooth_layer, torch.nn.Linear)
and isinstance(balance_layer, torch.nn.Linear)
and ".o_proj" in balance_name
and balance_name.endswith(".o_proj")
and (
(
".v_proj" in layer_name
smooth_name.endswith(".v_proj")
and smooth_layer.out_features
!= balance_layer.in_features
)
or (
".qkv_proj" in layer_name
smooth_name.endswith(".qkv_proj")
and smooth_layer.out_features
!= 3 * balance_layer.in_features
)
)
):
-                            num_skipped_oproj_mappings += 1
+                            num_skipped_mappings += 1
continue

balance_layers.append(balance_layer)
balance_names.append(balance_name)

-                    if len(balance_layers) == 0:
-                        continue
-
-                    # each mapping can contain multiple layers to balance, but only
-                    # one layer to smooth
-                    if len(balance_layers) == 1:
-                        # for single balance layer, parent is the balance layer
-                        parent_name, parent = balance_name, balance_layer
-                    else:
-                        # for multiple balance layers,
-                        # parent of any balance layer is the parent
-                        parent_name, parent = get_parent_by_name(
-                            layer_name=balance_name, model=model
-                        )
-                    resolved_mappings.append(
-                        ResolvedMapping(
-                            layer_name,
-                            smooth_layer,
-                            balance_layers,
-                            balance_names=balance_names,
-                            parent=parent,
-                            parent_name=parent_name,
-                        )
-                    )
-        if num_skipped_oproj_mappings > 0:
-            logger.info(
-                f"Excluded {num_skipped_oproj_mappings} from resolved "
-                "mappings due to shape mismatch"
-            )
+                if len(balance_layers) == 0:
+                    continue
+
+                elif len(balance_layers) == 1:
+                    # for single balance layer, parent is the balance layer
+                    parent_name, parent = balance_name, balance_layer
+                else:
+                    # for multiple balance layers, find lowest common parent
+                    parent_name, parent = get_lowest_common_parent(balance_names, model)
+
+                resolved_mappings.append(
+                    ResolvedMapping(
+                        smooth_name,
+                        smooth_layer,
+                        balance_layers,
+                        balance_names=balance_names,
+                        parent=parent,
+                        parent_name=parent_name,
+                    )
+                )
self._resolved_mappings = resolved_mappings
return
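
To make the new resolution concrete: on a Qwen3-MoE layer, the mapping whose smooth layer is the post_attention_layernorm now balances into every expert's gate_proj and up_proj, and the lowest common parent of those Linears is the sparse-MoE block rather than any single projection. A hypothetical sketch of one resulting ResolvedMapping (module names follow the Qwen3 model tree and are illustrative, not output from this PR):

# smooth_name:   "model.layers.0.post_attention_layernorm"
# balance_names: ["model.layers.0.mlp.experts.0.gate_proj",
#                 "model.layers.0.mlp.experts.0.up_proj",
#                 ...,
#                 "model.layers.0.mlp.experts.127.up_proj"]
# parent_name:   "model.layers.0.mlp"  (a Qwen3MoeSparseMoeBlock --
#                 the lowest common parent that is not a ModuleList)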

@@ -401,11 +405,9 @@ def cache_smooth_activations_hook(
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
-        # Assume that first argument is the input
-        inp = args[0].cpu().detach().squeeze()
-
        self._smooth_activation_means[smooth_name] = _accumulate_mean(
-            inp,
+            # Assume that first argument is the input
+            args[0].cpu().detach().squeeze(),
self._smooth_activation_means.get(smooth_name, None),
)

@@ -444,12 +446,14 @@ def _apply_smoothing(self, model: Module) -> None:

:param model: model to apply smoothing to
"""
-        for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
-            # NOTE: When using SequentialPipeline, not all the mappings
-            # will have cached activations in the segment being updated
-            if mapping.smooth_name not in self._smooth_activation_means:
-                continue
-
+        # NOTE: When using SequentialPipeline, not all the mappings
+        # will have cached activations in the segment being updated
+        mappings_to_smooth = [
+            mapping
+            for mapping in self._resolved_mappings
+            if mapping.smooth_name in self._smooth_activation_means
+        ]
+        for mapping in tqdm(mappings_to_smooth, desc="Smoothing"):
smooth_layer = mapping.smooth_layer
balance_layers = mapping.balance_layers
parent_module = mapping.parent
@@ -473,10 +477,15 @@ def _apply_smoothing(self, model: Module) -> None:
# [STEP 3]: Compute output of module
# could cache from hook, rather than recomputing here
fp16_output = self._run_samples(parent_module)
-            fp16_output = fp16_output.clip(
-                torch.finfo(fp16_output.dtype).min,
-                torch.finfo(fp16_output.dtype).max,
-            )
+            if fp16_output.numel() == 0:
+                logger.info(
+                    f"Skipping smooth_layer {mapping.smooth_name}, no activations "
+                    "found to scale. This can occasionally occur in MoE models "
+                    "when certain experts are not activated by calibration samples."
+                )
+                del self._smooth_activation_means[mapping.smooth_name]
+                continue

x_mean = self._smooth_activation_means[mapping.smooth_name][0]

# [STEP 4]: Compute loss
@@ -536,10 +545,15 @@ def smooth(module):

def _run_samples(self, module: Module) -> torch.Tensor:
with align_module_device(module):
+            outputs = [
+                module(**batch_kwargs)
+                for batch_kwargs in self._parent_args_cache[module]
+            ]
            return torch.cat(
                [
-                    module(**batch_kwargs)[0]
-                    for batch_kwargs in self._parent_args_cache[module]
+                    # If Tuple, assume that first argument is the input
+                    output[0] if isinstance(output, Tuple) else output
+                    for output in outputs
],
dim=0,
)
@@ -736,3 +750,35 @@ def _accumulate_mean(
new_count = prev_count + num_added

return (prev_sum + sum_added) / new_count, new_count


def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
"""
Given a list of names, returns the lowest-scope common parent.

    NOTE: this function excludes parents of type ModuleList, which don't play
    nicely with hooks because their forward method is never directly
    called for MoE models. See Qwen3MoeSparseMoeBlock for an example: experts
    are selected based on the router output, and each expert's forward
    method is called directly.
    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233

    Returns the name of the parent and a pointer to the parent module.

Implementation is a small alteration of os.path.commonprefix
https://docs.python.org/3/library/os.path.html#os.path.commonprefix
"""
s1 = min(names)
s2 = max(names)
parent_name = ""
for i, c in enumerate(s1):
if c != s2[i]:
parent_name = s1[:i].rstrip(".")
break

while True:
if parent_name == "":
return "", module
parent = get_layer_by_name(parent_name, module)
if not isinstance(parent, torch.nn.ModuleList):
return parent_name, parent
parent_name = ".".join(parent_name.split(".")[:-1])