diff --git a/examples/awq/qwen3_moe_example.py b/examples/awq/qwen3_moe_example.py
new file mode 100644
index 0000000000..b8f4a4ec18
--- /dev/null
+++ b/examples/awq/qwen3_moe_example.py
@@ -0,0 +1,82 @@
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.awq import AWQModifier
+
+# Select model and load it.
+MODEL_ID = "Qwen/Qwen3-30B-A3B"
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, device_map="auto", torch_dtype="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+# Select calibration dataset.
+DATASET_ID = "mit-han-lab/pile-val-backup"
+DATASET_SPLIT = "validation"
+
+# Select number of samples. 256 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 256
+MAX_SEQUENCE_LENGTH = 512
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            [{"role": "user", "content": example["text"]}],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+# Configure the quantization algorithm to run.
+# NOTE: vllm currently does not support asym MoE, using symmetric here
+recipe = [
+    AWQModifier(
+        ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
+        scheme="W4A16",
+        targets=["Linear"],
+    ),
+]
+
+# Apply algorithms.
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+)
+
+# Confirm generations of the quantized model look sane.
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk compressed.
+SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-sym"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index d7decaf374..f95aaaea84 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -2,7 +2,10 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import (
+    disable_quantization,
+    find_name_or_class_matches,
+)
 from compressed_tensors.utils import (
     align_module_device,
     get_execution_device,
@@ -26,11 +29,7 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_matching_layer,
-    get_parent_by_name,
-)
+from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers
 
 __all__ = ["AWQModifier"]
 
@@ -307,77 +306,82 @@ def _set_resolved_mappings(self, model: Module) -> None:
         repeat for model.layer.1 and so on
         """
         resolved_mappings: list[ResolvedMapping] = []
-        num_skipped_oproj_mappings = 0
-        for mapping in self.mappings:
-            to_smooth_layers = get_layers(mapping.smooth_layer, model)
-            for layer_name, smooth_layer in to_smooth_layers.items():
-                # always exclude `.weight_observer`, only want `.weight`
-                if layer_name not in self.ignore and not layer_name.endswith(
-                    "_observer"
-                ):
-                    balance_layers, balance_names = [], []
-                    for balance_suffix in mapping.balance_layers:
-                        # find the submodule that matches the activation layer
-                        balance_name, balance_layer = get_matching_layer(
-                            balance_suffix, layer_name, model
-                        )
-                        if not balance_layer:
-                            continue
+        for mapping_idx, mapping in enumerate(self.mappings):
+            smooth_layers = get_layers(mapping.smooth_layer, model)
+            smooth_names = [
+                smooth_name
+                for smooth_name in smooth_layers
+                if not find_name_or_class_matches(
+                    smooth_name, model, self.ignore + ["re:.*_observer$"]
+                )
+            ]
+
+            num_skipped_mappings = 0
+            pbar = tqdm(smooth_names)
+            for smooth_name in pbar:
+                pbar.set_description(
+                    f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
+                    f" ({num_skipped_mappings} skipped)"
+                )
+                smooth_layer = smooth_layers[smooth_name]
+
+                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
+                smooth_parent = get_layer_by_name(smooth_parent_name, model)
+
+                balance_layers, balance_names = [], []
+                for balance_regex in mapping.balance_layers:
+                    # find the submodules that match the activation layer
+                    for balance_suffix, balance_layer in get_layers(
+                        balance_regex,
+                        smooth_parent,
+                    ).items():
+                        balance_name = f"{smooth_parent_name}.{balance_suffix}"
 
                         # exclude v_proj->o_proj mappings whose shapes are incompatible
                         # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
                         if (
                             isinstance(smooth_layer, torch.nn.Linear)
                             and isinstance(balance_layer, torch.nn.Linear)
-                            and ".o_proj" in balance_name
+                            and balance_name.endswith(".o_proj")
                             and (
                                 (
-                                    ".v_proj" in layer_name
+                                    smooth_name.endswith(".v_proj")
                                     and smooth_layer.out_features
                                     != balance_layer.in_features
                                 )
                                 or (
-                                    ".qkv_proj" in layer_name
+                                    smooth_name.endswith(".qkv_proj")
                                     and smooth_layer.out_features
                                     != 3 * balance_layer.in_features
                                 )
                             )
                         ):
-                            num_skipped_oproj_mappings += 1
+                            num_skipped_mappings += 1
                             continue
 
                         balance_layers.append(balance_layer)
                         balance_names.append(balance_name)
 
-                    if len(balance_layers) == 0:
-                        continue
-
-                    # each mapping can contain multiple layers to balance, but only
-                    # one layer to smooth
-                    if len(balance_layers) == 1:
-                        # for single balance layer, parent is the balance layer
-                        parent_name, parent = balance_name, balance_layer
-                    else:
-                        # for multiple balance layers,
-                        # parent of any balance layer is the parent
-                        parent_name, parent = get_parent_by_name(
-                            layer_name=balance_name, model=model
-                        )
-                    resolved_mappings.append(
-                        ResolvedMapping(
-                            layer_name,
-                            smooth_layer,
-                            balance_layers,
-                            balance_names=balance_names,
-                            parent=parent,
-                            parent_name=parent_name,
-                        )
+                if len(balance_layers) == 0:
+                    continue
+
+                elif len(balance_layers) == 1:
+                    # for single balance layer, parent is the balance layer
+                    parent_name, parent = balance_name, balance_layer
+                else:
+                    # for multiple balance layers, find lowest common parent
+                    parent_name, parent = get_lowest_common_parent(balance_names, model)
+
+                resolved_mappings.append(
+                    ResolvedMapping(
+                        smooth_name,
+                        smooth_layer,
+                        balance_layers,
+                        balance_names=balance_names,
+                        parent=parent,
+                        parent_name=parent_name,
                     )
-        if num_skipped_oproj_mappings > 0:
-            logger.info(
-                f"Excluded {num_skipped_oproj_mappings} from resolved "
-                "mappings due to shape mismatch"
-            )
+                )
 
         self._resolved_mappings = resolved_mappings
         return
@@ -401,11 +405,9 @@ def cache_smooth_activations_hook(
             args: Tuple[torch.Tensor, ...],
             _output: torch.Tensor,
         ):
-            # Assume that first argument is the input
-            inp = args[0].cpu().detach().squeeze()
-
             self._smooth_activation_means[smooth_name] = _accumulate_mean(
-                inp,
+                # Assume that first argument is the input
+                args[0].cpu().detach().squeeze(),
                 self._smooth_activation_means.get(smooth_name, None),
             )
 
@@ -444,12 +446,14 @@ def _apply_smoothing(self, model: Module) -> None:
 
         :param model: model to apply smoothing to
         """
-        for mapping in tqdm(self._resolved_mappings, desc="Smoothing"):
-            # NOTE: When using SequentialPipeline, not all the mappings
-            # will have cached activations in the segment being udpated
-            if mapping.smooth_name not in self._smooth_activation_means:
-                continue
-
+        # NOTE: When using SequentialPipeline, not all the mappings
+        # will have cached activations in the segment being updated
+        mappings_to_smooth = [
+            mapping
+            for mapping in self._resolved_mappings
+            if mapping.smooth_name in self._smooth_activation_means
+        ]
+        for mapping in tqdm(mappings_to_smooth, desc="Smoothing"):
             smooth_layer = mapping.smooth_layer
             balance_layers = mapping.balance_layers
             parent_module = mapping.parent
@@ -473,10 +477,15 @@ def _apply_smoothing(self, model: Module) -> None:
             # [STEP 3]: Compute output of module
             # could cache from hook, rather than recomputing here
             fp16_output = self._run_samples(parent_module)
-            fp16_output = fp16_output.clip(
-                torch.finfo(fp16_output.dtype).min,
-                torch.finfo(fp16_output.dtype).max,
-            )
+            if fp16_output.numel() == 0:
+                logger.info(
+                    f"Skipping smooth_layer {mapping.smooth_name}, no activations "
+                    "found to scale. This can occasionally occur in MoE models "
+                    "when certain experts are not activated by calibration samples."
+                )
+                del self._smooth_activation_means[mapping.smooth_name]
+                continue
+
             x_mean = self._smooth_activation_means[mapping.smooth_name][0]
 
             # [STEP 4]: Compute loss
@@ -536,10 +545,15 @@ def smooth(module):
 
     def _run_samples(self, module: Module) -> torch.Tensor:
         with align_module_device(module):
+            outputs = [
+                module(**batch_kwargs)
+                for batch_kwargs in self._parent_args_cache[module]
+            ]
             return torch.cat(
                 [
-                    module(**batch_kwargs)[0]
-                    for batch_kwargs in self._parent_args_cache[module]
+                    # If Tuple, assume that first element is the output
+                    output[0] if isinstance(output, Tuple) else output
+                    for output in outputs
                 ],
                 dim=0,
             )
@@ -736,3 +750,35 @@ def _accumulate_mean(
     new_count = prev_count + num_added
 
     return (prev_sum + sum_added) / new_count, new_count
+
+
+def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
+    """
+    Given a list of names, returns the lowest-scope common parent.
+
+    NOTE: function excludes parents of type ModuleList, which don't play
+    nicely with hooks because their forward method is never directly
+    called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
+    are selected based on router output and their forward method is called.
+    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
+
+    Returns name of parent and pointer to parent module
+
+    Implementation is a small alteration of os.path.commonprefix
+    https://docs.python.org/3/library/os.path.html#os.path.commonprefix
+    """
+    s1 = min(names)
+    s2 = max(names)
+    parent_name = ""
+    for i, c in enumerate(s1):
+        if c != s2[i]:
+            parent_name = s1[:i].rstrip(".")
+            break
+
+    while True:
+        if parent_name == "":
+            return "", module
+        parent = get_layer_by_name(parent_name, module)
+        if not isinstance(parent, torch.nn.ModuleList):
+            return parent_name, parent
+        parent_name = ".".join(parent_name.split(".")[:-1])
diff --git a/src/llmcompressor/modifiers/awq/mappings.py b/src/llmcompressor/modifiers/awq/mappings.py
index 700525ed8b..6390445c8c 100644
--- a/src/llmcompressor/modifiers/awq/mappings.py
+++ b/src/llmcompressor/modifiers/awq/mappings.py
@@ -25,17 +25,33 @@ class AWQMapping:
 
 _default_mappings = [
     AWQMapping(
-        "re:.*input_layernorm",
-        ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
+        "re:.*input_layernorm$",
+        ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
     ),
-    AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
+    AWQMapping("re:.*v_proj$", ["re:.*o_proj$"]),
     AWQMapping(
-        "re:.*post_attention_layernorm",
-        ["re:.*gate_proj", "re:.*up_proj"],
+        "re:.*post_attention_layernorm$",
+        ["re:.*gate_proj$", "re:.*up_proj$"],
     ),
     AWQMapping(
-        "re:.*up_proj",
-        ["re:.*down_proj"],
+        "re:.*up_proj$",
+        ["re:.*down_proj$"],
+    ),
+]
+
+_moe_default_mappings = [
+    AWQMapping(
+        "re:.*input_layernorm$",
+        ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
+    ),
+    AWQMapping("re:.*v_proj$", ["re:.*o_proj$"]),
+    AWQMapping(
+        "re:.*post_attention_layernorm$",
+        ["re:.*mlp.experts.*.gate_proj$", "re:.*mlp.experts.*.up_proj$"],
+    ),
+    AWQMapping(
+        "re:.*up_proj$",
+        ["re:.*down_proj$"],
     ),
 ]
 
@@ -44,27 +60,29 @@ class AWQMapping:
 # gate and up proj layers into a single gate_up_proj layer
 _phi_mappings = [
     AWQMapping(
-        "re:.*input_layernorm",
-        ["re:.*qkv_proj"],
+        "re:.*input_layernorm$",
+        ["re:.*qkv_proj$"],
     ),
-    AWQMapping("re:.*qkv_proj", ["re:.*o_proj"]),
+    AWQMapping("re:.*qkv_proj$", ["re:.*o_proj$"]),
     AWQMapping(
-        "re:.*post_attention_layernorm",
-        ["re:.*gate_up_proj"],
+        "re:.*post_attention_layernorm$",
["re:.*gate_up_proj$"], ), AWQMapping( - "re:.*gate_up_proj", - ["re:.*down_proj"], + "re:.*gate_up_proj$", + ["re:.*down_proj$"], ), ] AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = { "LlamaForCausalLM": _default_mappings, - "Qwen2ForCausalLM": _default_mappings, - "Qwen3ForCausalLM": _default_mappings, "MistralForCausalLM": _default_mappings, "Phi3ForCausalLM": _phi_mappings, "Phi3VForCausalLM": _phi_mappings, + "Qwen2ForCausalLM": _default_mappings, + "Qwen2MoeForCausalLM": _moe_default_mappings, + "Qwen3ForCausalLM": _default_mappings, + "Qwen3MoeForCausalLM": _moe_default_mappings, } diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index 1bb3e3f701..835493fa3d 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -4,6 +4,7 @@ import difflib import re +from operator import attrgetter from typing import Dict, List, Optional, Tuple, Union import torch @@ -53,7 +54,6 @@ "set_layer", "get_params", "get_param", - "set_param", "get_terminal_layers", "get_prunable_layers", "get_quantizable_layers", @@ -61,7 +61,7 @@ "get_layers_params", "get_matching_layer", "get_no_split_params", - "get_parent_by_name", + "get_layer_by_name", ] @@ -208,15 +208,6 @@ def get_param(target: str, module: Module) -> Tuple[str, Parameter]: return name, param -def set_param(target: str, param: Parameter, module: Module) -> Parameter: - layer_name, param_name = target.rsplit(".", 1) - layer = get_layer(layer_name, module)[1] - old_param = getattr(layer, param_name) - setattr(layer, param_name, param) - - return old_param - - def get_terminal_layers(module: Module) -> Dict[str, Module]: terminal = {} @@ -344,20 +335,12 @@ def get_no_split_params(model: PreTrainedModel) -> Union[str, List[str]]: return no_split_modules -def get_parent_by_name(layer_name: str, model: Module) -> Tuple[str, Module]: +# https://discuss.pytorch.org/t/how-to-access-to-a-layer-by-module-name/83797/8 +def get_layer_by_name(layer_name: str, module: Module) -> Module: """ - Get the parent layer of a layer by name. - :param layer_name: Name of the layer to find the parent of. - :param model: Model to search for the parent layer. - :return: Tuple containing the name of the parent layer - and the parent layer itself. + Get the layer of a module by name. + :param layer_name: Name of the layer to find. 
+    :param module: Module in which to search for layer_name
+    :return: Module, the layer with name layer_name
     """
-    if not any(layer_name == name for name, _ in model.named_modules()):
-        raise ValueError(f"Layer '{layer_name}' not found in model")
-
-    parent_name_parts = layer_name.split(".")[:-1]
-    if not parent_name_parts:
-        return "", model
-
-    parent_name = ".".join(parent_name_parts)
-    return get_layer(parent_name, model)
+    return attrgetter(layer_name)(module)
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index ae38da08ca..a4adfbdac0 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -4,6 +4,7 @@
 from pydantic import ValidationError
 
 from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
+from llmcompressor.modifiers.awq.base import get_lowest_common_parent
 from llmcompressor.modifiers.factory import ModifierFactory
 from tests.llmcompressor.modifiers.conf import setup_modifier_factory
 
@@ -56,32 +57,36 @@ def test_set_resolved_mappings():
     )
     model = torch.nn.ModuleDict(
         {
-            "self_attn": self_attn,
-            "input_layernorm": torch.nn.LayerNorm(4),
-            "mlp": mlp,
+            "decoder": torch.nn.ModuleDict(
+                {
+                    "self_attn": self_attn,
+                    "input_layernorm": torch.nn.LayerNorm(4),
+                    "mlp": mlp,
+                }
+            )
         }
     )
     awq._set_resolved_mappings(model)
 
     for mapping in awq._resolved_mappings:
         if "input_layernorm" in mapping.smooth_name:
             assert set(mapping.balance_names) == {
-                "self_attn.q_proj",
-                "self_attn.k_proj",
-                "self_attn.v_proj",
+                "decoder.self_attn.q_proj",
+                "decoder.self_attn.k_proj",
+                "decoder.self_attn.v_proj",
             }
             assert set(mapping.balance_layers) == {
                 self_attn.q_proj,
                 self_attn.k_proj,
                 self_attn.v_proj,
             }
-            assert mapping.parent_name == "self_attn"
+            assert mapping.parent_name == "decoder.self_attn"
             assert mapping.parent == self_attn
         if "self_attn.v_proj" in mapping.smooth_name:
-            assert set(mapping.balance_names) == {"self_attn.o_proj"}
-            assert mapping.parent_name == "self_attn.o_proj"
+            assert set(mapping.balance_names) == {"decoder.self_attn.o_proj"}
+            assert mapping.parent_name == "decoder.self_attn.o_proj"
         if "mlp.up_proj" in mapping.smooth_name:
-            assert set(mapping.balance_names) == {"mlp.down_proj"}
-            assert mapping.parent_name == "mlp.down_proj"
+            assert set(mapping.balance_names) == {"decoder.mlp.down_proj"}
+            assert mapping.parent_name == "decoder.mlp.down_proj"
 
     # make sure we exclude case where o_proj/v_proj shapes are mismatched
     awq = AWQModifier(
@@ -92,12 +97,16 @@ def test_set_resolved_mappings():
     )
     model = torch.nn.ModuleDict(
         {
-            "self_attn": torch.nn.ModuleDict(
+            "decoder": torch.nn.ModuleDict(
                 {
-                    "q_proj": torch.nn.Linear(4, 2),
-                    "k_proj": torch.nn.Linear(4, 2),
-                    "v_proj": torch.nn.Linear(4, 2),
-                    "o_proj": torch.nn.Linear(4, 4),
+                    "self_attn": torch.nn.ModuleDict(
+                        {
+                            "q_proj": torch.nn.Linear(4, 2),
+                            "k_proj": torch.nn.Linear(4, 2),
+                            "v_proj": torch.nn.Linear(4, 2),
+                            "o_proj": torch.nn.Linear(4, 4),
+                        }
+                    )
                 }
             )
         }
     )
@@ -164,3 +173,61 @@ def test_validate():
             ),
         }
     )
+
+
+@pytest.mark.unit
+def test_get_lowest_common_parent():
+    mlp = torch.nn.ModuleDict(
+        {
+            "experts": torch.nn.ModuleList(
+                [
+                    torch.nn.ModuleDict(
+                        {
+                            "gate_proj": torch.nn.Linear(4, 2),
+                            "down_proj": torch.nn.Linear(4, 2),
+                        }
+                    )
+                    for _ in range(10)
+                ]
+            )
+        }
+    )
+    self_attn = torch.nn.ModuleDict(
+        {
+            "q_proj": torch.nn.Linear(4, 2),
+            "k_proj": torch.nn.Linear(4, 2),
+            "v_proj": torch.nn.Linear(4, 2),
+            "o_proj": torch.nn.Linear(4, 4),
+        }
+    )
+    model = torch.nn.ModuleDict(
+        {
+            "embed_tokens": torch.nn.Linear(4, 2),
+            "decoder": torch.nn.ModuleDict(
+                {
+                    "self_attn": self_attn,
+                    "mlp": mlp,
+                }
+            ),
+        }
+    )
+
+    parent_name, parent = get_lowest_common_parent(
+        ["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model
+    )
+    assert parent_name == "decoder.mlp" and parent == mlp
+
+    parent_name, parent = get_lowest_common_parent(
+        ["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model
+    )
+    assert parent_name == "decoder.self_attn" and parent == self_attn
+
+    parent_name, parent = get_lowest_common_parent(
+        ["decoder.mlp.experts.1.gate_proj", "decoder.self_attn.v_proj"], model
+    )
+    assert parent_name == "decoder" and parent == model["decoder"]
+
+    parent_name, parent = get_lowest_common_parent(
+        ["embed_tokens", "decoder.self_attn.v_proj"], model
+    )
+    assert parent_name == "" and parent == model
diff --git a/tests/llmcompressor/utils/pytorch/test_module.py b/tests/llmcompressor/utils/pytorch/test_module.py
index 22763aba90..1ab40aa159 100644
--- a/tests/llmcompressor/utils/pytorch/test_module.py
+++ b/tests/llmcompressor/utils/pytorch/test_module.py
@@ -1,7 +1,7 @@
 import pytest
 import torch.nn as nn
 
-from llmcompressor.utils.pytorch import get_parent_by_name
+from llmcompressor.utils.pytorch import get_layer_by_name
 
 
 @pytest.fixture
@@ -15,28 +15,20 @@ def example_nested_module() -> str:
 
 
 @pytest.mark.unit
-def test_get_parent_by_name(example_nested_module):
-    # Test getting the parent of the first layer
-    name, parent = get_parent_by_name("0", example_nested_module)
-    assert parent == example_nested_module
-
+def test_get_layer_by_name(example_nested_module):
     # Test getting the parent of a nested layer
-    name, parent = get_parent_by_name("1.0", example_nested_module)
-    assert parent == example_nested_module[1]
-    assert name == "1"
+    layer = get_layer_by_name("0", example_nested_module)
+    assert layer == example_nested_module[0]
 
-    name, parent = get_parent_by_name("1.1", example_nested_module)
-    assert parent == example_nested_module[1]
-    assert name == "1"
+    layer = get_layer_by_name("1.1", example_nested_module)
+    assert layer == example_nested_module[1][1]
 
-    name, parent = get_parent_by_name("2.0", example_nested_module)
-    assert parent == example_nested_module[2]
-    assert name == "2"
+    layer = get_layer_by_name("2.0", example_nested_module)
+    assert layer == example_nested_module[2][0]
 
-    name, parent = get_parent_by_name("2.1", example_nested_module)
-    assert parent == example_nested_module[2]
-    assert name == "2"
+    layer = get_layer_by_name("2.1", example_nested_module)
+    assert layer == example_nested_module[2][1]
 
     # Test getting the parent of a non-existent layer
-    with pytest.raises(ValueError):
-        get_parent_by_name("non_existent_layer", example_nested_module)
+    with pytest.raises(AttributeError):
+        get_layer_by_name("non_existent_layer", example_nested_module)
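
A note on the `$` anchors added throughout mappings.py and used in the example's `ignore` list (an illustration, not part of the patch): assuming the `re:`-prefixed patterns are resolved with Python's `re.match`, which only anchors at the start of a module name, an un-anchored pattern also matches children of the intended module, such as observers attached during calibration. The module names below are hypothetical and for illustration only.

```python
import re

# Hypothetical module names as they might appear in a calibrated Qwen3-MoE layer.
names = [
    "model.layers.0.mlp.experts.3.up_proj",
    "model.layers.0.mlp.experts.3.up_proj.weight_observer",
]

loose = r".*up_proj"      # un-anchored pattern (old style)
anchored = r".*up_proj$"  # anchored pattern as used in this patch

print([n for n in names if re.match(loose, n)])
# matches both names: re.match also succeeds on the observer child
print([n for n in names if re.match(anchored, n)])
# matches only the Linear module itself
```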
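
As a follow-up to the new example, a minimal smoke test of the compressed checkpoint it writes might look like the sketch below. This is not part of the patch; it assumes vLLM is installed and that `SAVE_DIR` matches the directory written by examples/awq/qwen3_moe_example.py (`Qwen3-30B-A3B-awq-sym`).

```python
from vllm import LLM, SamplingParams

# Directory written by examples/awq/qwen3_moe_example.py (assumed).
SAVE_DIR = "Qwen3-30B-A3B-awq-sym"

# vLLM reads the compressed-tensors config produced by save_compressed=True.
llm = LLM(model=SAVE_DIR)
outputs = llm.generate(
    ["Hello my name is"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```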