diff --git a/examples/awq/llama_example.py b/examples/awq/llama_example.py index d06a2ccb9..e31304b29 100644 --- a/examples/awq/llama_example.py +++ b/examples/awq/llama_example.py @@ -50,7 +50,9 @@ def tokenize(sample): # Configure the quantization algorithm to run. recipe = [ - AWQModifier(ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"]), + AWQModifier( + ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"], duo_scaling="both" + ), ] # Apply algorithms. diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 5a19a96b3..426ec0e22 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -3,17 +3,24 @@ from typing import Iterator, Literal import torch -from compressed_tensors.quantization import disable_quantization +from compressed_tensors.quantization import ( + QuantizationStrategy, + disable_quantization, + forward_quantize, +) +from compressed_tensors.quantization.utils import strategy_cdiv from compressed_tensors.utils import ( align_modules, get_execution_device, get_lowest_common_ancestor_name, + getattr_chain, match_modules_set, match_named_modules, + patch_attrs, update_offload_parameter, ) from loguru import logger -from pydantic import ConfigDict, PrivateAttr, model_validator +from pydantic import ConfigDict, PrivateAttr, field_validator from torch.nn import Module from torch.utils._pytree import tree_leaves from tqdm import tqdm @@ -25,9 +32,13 @@ ResolvedMapping, get_layer_mappings_from_architecture, ) -from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale +from llmcompressor.modifiers.quantization.calibration import ( + call_observer, + update_weight_zp_scale, +) from llmcompressor.modifiers.quantization.quantization import QuantizationMixin from llmcompressor.modifiers.utils.hooks import HooksMixin +from llmcompressor.observers.base import Observer from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.helpers import calibration_forward_context @@ -138,11 +149,6 @@ class AWQModifier(Modifier, QuantizationMixin): duo_scaling: bool | Literal["both"] = True n_grid: int = 20 - # Private vars set during validation - _num_bits: int | None = PrivateAttr(default=None) - _symmetric: bool | None = PrivateAttr(default=None) - _group_size: int | None = PrivateAttr(default=None) - # Private vars set during initialization, cleared during finalization _resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list) # Cache list of forward input args for each parent module, one dict for each batch @@ -154,74 +160,6 @@ class AWQModifier(Modifier, QuantizationMixin): default_factory=dict ) - # NOTE: different name chosen to avoid collision with - # QuantizationMixin.validate_model_after, which must be called first - @model_validator(mode="after") - def validate_awq_after(model: "AWQModifier") -> "AWQModifier": - """ - Confirm only one configuration for group_size, symmetric, and num_bits, - as AWQ algorithm depends on it - Confirm no activation quantization, as AWQ only works with WNA16 - """ - config = model.resolve_quantization_config() - - num_bits_set = set( - group.weights.num_bits - for group in config.config_groups.values() - if group.weights is not None - ) - assert ( - len(num_bits_set) == 1 - ), "In AWQ, all config groups must use the same configuration for num_bits" - - model._num_bits = next(iter(num_bits_set)) - - symmetric_set = set( - 
group.weights.symmetric - for group in config.config_groups.values() - if group.weights is not None - ) - assert ( - len(symmetric_set) == 1 - ), "In AWQ, all config groups must use the same configuration for symmetric" - - model._symmetric = next(iter(symmetric_set)) - - group_size_set = set( - group.weights.group_size - for group in config.config_groups.values() - if group.weights is not None - ) - assert ( - len(group_size_set) == 1 - ), "In AWQ, all config groups must use the same configuration for group_size" - - model._group_size = next(iter(group_size_set)) - if model._group_size is None: - model._group_size = -1 - - in_num_bits_set = set( - group.input_activations.num_bits - for group in config.config_groups.values() - if group.input_activations is not None - ) - assert len(in_num_bits_set) == 0 or in_num_bits_set == {16}, ( - "AWQ activations must be 16-bit precision, " - f"input activations {in_num_bits_set} not allowed" - ) - - out_num_bits_set = set( - group.output_activations.num_bits - for group in config.config_groups.values() - if group.output_activations is not None - ) - assert len(out_num_bits_set) == 0 or out_num_bits_set == {16}, ( - "AWQ activations must be 16-bit precision, " - f"output activations {out_num_bits_set} not allowed" - ) - - return model - def on_initialize(self, state: State, **kwargs) -> bool: """ Initialize AWQ on the given state @@ -235,6 +173,24 @@ def on_initialize(self, state: State, **kwargs) -> bool: if QuantizationMixin.has_config(self): QuantizationMixin.initialize_quantization(self, state.model) + # Validate that duo_scaling is only used with per-channel quantization + if self.duo_scaling is not False: + for _, module in match_named_modules( + state.model, self.resolved_targets, self.ignore + ): + if ( + hasattr(module, "quantization_scheme") + and hasattr(module.quantization_scheme, "weights") + and module.quantization_scheme.weights.strategy + == QuantizationStrategy.TENSOR + ): + raise ValueError( + "duo_scaling is only supported with per-channel quantization " + "strategies (group or channel), but found TENSOR strategy. " + "Please set duo_scaling=False or use a per-channel " + "quantization strategy." 
+ ) + if self.mappings is None: logger.info("No AWQModifier.mappings provided, inferring from model...") self.mappings = get_layer_mappings_from_architecture( @@ -387,7 +343,7 @@ def _setup_activation_cache_hooks(self) -> None: """ def cache_parent_kwargs_hook( - module: torch.nn.Module, + module: Module, args: tuple[torch.Tensor, ...], kwargs, ): @@ -396,7 +352,7 @@ def cache_parent_kwargs_hook( def create_cache_smooth_activations_hook_fn(smooth_name): def cache_smooth_activations_hook( - _module: torch.nn.Module, + _module: Module, args: tuple[torch.Tensor, ...], _output: torch.Tensor, ): @@ -442,7 +398,7 @@ def _apply_smoothing(self, model: Module) -> None: :param model: model to apply smoothing to """ # NOTE: When using SequentialPipeline, not all the mappings - # will have cached activations in the segment being udpated + # will have cached activations in the segment being updated mappings_to_smooth = [ mapping for mapping in self._resolved_mappings @@ -458,28 +414,7 @@ def _apply_smoothing(self, model: Module) -> None: calibration_forward_context(model), HooksMixin.disable_hooks(), ): - # [STEP 1]: Compute per-channel mean of normalised weights - # All layer weights are concatted together - weight = torch.cat([bl.weight for bl in balance_layers], dim=0) - org_shape = weight.shape - # The weights are reshaped to be organised by quantization group - if self._group_size > 0: - weight = weight.view(-1, self._group_size) - # Calculates the relative magnitude of the weights within - # each of the quantization groups, and rescales each group - # individually so that each group has weights on a 0-1 scale. - weight.abs_() - weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6) - if self._group_size > 0: - # Resizes the rescaled weight matrix back up to - # its original dimensions - weight = weight.view(org_shape) - # Gets the average rescaled magnitude for each output channel - w_mean = weight.mean(0) - del weight - - # [STEP 3]: Compute output of module - # could cache from hook, rather than recomputing here + # Compute output of unquantized module fp16_outputs = self._run_samples(parent_module) if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs): logger.info( @@ -504,15 +439,10 @@ def _apply_smoothing(self, model: Module) -> None: del self._smooth_activation_means[mapping.smooth_name] continue - x_mean = self._smooth_activation_means[mapping.smooth_name][0] - - # [STEP 4]: Compute loss - best_scales = self._compute_best_scale( - x_mean, w_mean, parent_module, balance_layers, fp16_outputs - ) + best_scales = self._compute_best_scale(mapping, fp16_outputs) @torch.no_grad() - def _smooth(module): + def _smooth(module: Module): scales = best_scales.to(module.weight.device) if module in balance_layers: update_offload_parameter( @@ -565,27 +495,31 @@ def _run_samples(self, module: Module) -> list[torch.Tensor]: module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module] ] return [ - # If Tuple, assume that first argument is the input + # If tuple, assume that first argument is the input output[0] if isinstance(output, tuple) else output for output in outputs ] def _compute_best_scale( self, - x_mean: torch.Tensor, - w_mean: torch.Tensor, - parent_module: torch.nn.Module, - linears2scale: list[torch.nn.Linear], + mapping: ResolvedMapping, fp16_outputs: list[torch.Tensor], ) -> torch.Tensor: """ - Compute loss and select best scales + Select best scales for a given mapping in a grid search + Best scales are those that minimize MSE loss of quantized weight + outputs 
compared to fp16_outputs L(s) = || Q(W * s) (s^-1 * X) - W * X || Q: weight quantization function | _pseudo_quantize_tensor(W * s) X: inputs from calib dataset | X W: original weights in FP16 | layer s: per channel scaling factor | s^-1 * X + + :param mapping: best scales will be found for the ResolvedMapping. + :param fp16_outputs: output of mapping.parent in unquantized case, + one tensor for each batch. + :return: tensor of best scales, one for each channel """ history = [] best_ratio = -1 @@ -594,13 +528,15 @@ def _compute_best_scale( org_sd = { k: v.cpu() - for k, v in parent_module.state_dict().items() + for k, v in mapping.parent.state_dict().items() if v.device != torch.device("meta") } - device = get_execution_device(parent_module) - x_mean = x_mean.view(-1).to(device) - w_mean = w_mean.view(-1).to(device) + device = get_execution_device(mapping.parent) + + x_mean = self._smooth_activation_means[mapping.smooth_name][0].to(device) + if self.duo_scaling: + w_mean = self._compute_layer_means(mapping.balance_layers).to(device) match self.duo_scaling: # if self.duo_scaling is "both", perform half the grid search with @@ -611,52 +547,88 @@ def _compute_best_scale( case _: n_grid = self.n_grid duo_scalings = [self.duo_scaling] - for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): - # create new scales - ratio = grid_idx / n_grid - - # NOTE: s^-1 * x is fused here, according to paper - if use_duo_scaling: - scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( - min=1e-4 - ) - else: - scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) - scales = scales / (scales.max() * scales.min()).sqrt() - _scalesview = scales.view(1, -1).to(device) - - # avoid scaling values that overflow - scales[torch.isinf(scales)] = 1 - scales[torch.isnan(scales)] = 1 - - # Q(W * s) - for linear in linears2scale: - linear.weight.mul_(_scalesview) - update_offload_parameter( - linear, - "weight", - _pseudo_quantize_tensor( - w=linear.weight.data, - symmetric=self._symmetric, - bit_width=self._num_bits, - group_size=self._group_size, - )[0] - / _scalesview, + + # Where appropriate, replace observers with memoryless_minmax + # for duration of grid search + balance_layers_to_patch = [ + balance_layer + for balance_layer in mapping.balance_layers + if hasattr(balance_layer, "quantization_scheme") + and hasattr(balance_layer.quantization_scheme, "weights") + ] + with patch_attrs( + balance_layers_to_patch, + "weight_observer", + [ + Observer.load_from_registry( + "memoryless_minmax", + base_name="weight", + args=balance_layer.quantization_scheme.weights, + module=balance_layer, ) + for balance_layer in balance_layers_to_patch + ], + ): + for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): + # create new scales + ratio = grid_idx / n_grid + + # NOTE: s^-1 * x is fused here, according to paper + if use_duo_scaling: + scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( + min=1e-4 + ) + else: + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * scales.min()).sqrt() + _scalesview = scales.view(1, -1).to(device) + + # avoid scaling values that overflow + scales[torch.isinf(scales)] = 1 + scales[torch.isnan(scales)] = 1 + + # Q(W * s) + for balance_layer in balance_layers_to_patch: + if not hasattr(balance_layer, "quantization_scheme") or not hasattr( + balance_layer.quantization_scheme, "weights" + ): + continue + + w_qscheme = balance_layer.quantization_scheme.weights + balance_layer.weight.mul_(_scalesview) + 
call_observer( + balance_layer, + "weight", + balance_layer.weight, + # TODO test should_calculate_gparam for nvfp4 support + ) + update_offload_parameter( + balance_layer, + "weight", + forward_quantize( + balance_layer, + balance_layer.weight.data, + "weight", + w_qscheme, + ) + / _scalesview, + ) - # W * X - int_w_outputs = self._run_samples(parent_module) + # W * X + int_w_outputs = self._run_samples(mapping.parent) - # compute mean squared error (L2 norm) - loss = self._compute_loss(fp16_outputs, int_w_outputs, device) + # compute mean squared error (L2 norm) + loss = self._compute_loss(fp16_outputs, int_w_outputs) - history.append(loss) - if loss < best_error: - best_error = loss - best_ratio = ratio - best_scales = scales.clone() + history.append( + {"ratio": ratio, "duo_scaling": use_duo_scaling, "error": loss} + ) + if loss < best_error: + best_error = loss + best_ratio = ratio + best_scales = scales.clone() - parent_module.load_state_dict(org_sd, strict=False) + mapping.parent.load_state_dict(org_sd, strict=False) if best_ratio == -1: logger.debug(history) @@ -678,22 +650,15 @@ def _compute_loss( self, fp16_outputs: list[torch.Tensor], int_w_outputs: list[torch.Tensor], - device: torch.device, - ) -> torch.Tensor: + ) -> float: loss = 0.0 num_elements = 0 # Compute the MSE loss for each batch for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs): - batch_loss = ( - (fp16_batch.to(device) - int_w_batch.to(device)) - .view(-1) - .float() - .pow(2) - .sum() - .item() - ) - loss += batch_loss + loss += torch.nn.functional.mse_loss( + fp16_batch, int_w_batch.to(fp16_batch.device) + ).item() num_elements += fp16_batch.numel() # Normalize the loss by the total number of elements @@ -709,6 +674,133 @@ def _assert_all_activations_consumed(self): if len(self._smooth_activation_means) != 0: raise RuntimeError("Some cached activations were not used") + @staticmethod + def _compute_layer_means(layers: list[Module]) -> torch.Tensor: + """ + Compute per-channel/group/block/tensor mean of normalised weights + for all passed in layers taking into account the quantization_scheme. 
+
+        To minimize memory requirements, layers are reduced to a running total
+        of sums and counts when calculating the mean
+        """
+        # to calculate the mean without having to carry the full population
+        weight_total_count = 0
+        weight_total_sum = 0
+
+        for layer in layers:
+            if not hasattr(layer, "weight"):
+                logger.warning(
+                    "Unable to find weight param for targeted"
+                    f" layer {type(layer)}, skipping"
+                )
+                continue
+            weight = layer.weight
+            orig_shape = weight.shape
+
+            q_args = getattr_chain(layer, "quantization_scheme.weights", None)
+            if not q_args:
+                logger.warning(
+                    "Unable to find quantization scheme for "
+                    f"targeted layer {type(layer)}, skipping"
+                )
+                continue
+
+            # need to get to shape [num different chunks x size of each chunk]
+            weight = _orient_weight(weight, q_args)
+            # TODO ^ simplify logic and use flatten_for_calibration
+
+            weight.abs_()
+            weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)
+            # Reshape back to original dimensions
+            weight = _reorient_weight(weight, q_args, orig_shape)
+            # Gets the average rescaled magnitude for each output channel
+            weight_total_count += weight.size(0)
+            weight_sum = weight.sum(0, dtype=torch.float64)
+            weight_total_sum += weight_sum
+
+        return weight_total_sum / weight_total_count
+
+    @field_validator("duo_scaling")
+    @classmethod
+    def validate_duo_scaling(cls, v):
+        """Validate that duo_scaling is either True, False, or 'both' (lowercase)"""
+        if v not in (True, False, "both"):
+            raise ValueError(f"duo_scaling must be True, False, or 'both', got {v!r}")
+        return v
+
+
+def _orient_weight(weight: torch.Tensor, q_args) -> torch.Tensor:
+    """
+    Orient weight so we have shape
+    [num different chunks, size of each chunk].
+    Works for TENSOR, CHANNEL, GROUP, and BLOCK strategies
+    """
+    if q_args.strategy in [
+        QuantizationStrategy.TENSOR,
+        QuantizationStrategy.CHANNEL,
+        QuantizationStrategy.GROUP,
+    ]:
+        match q_args.strategy:
+            case QuantizationStrategy.TENSOR:
+                group_size = weight.numel()
+            case QuantizationStrategy.CHANNEL:
+                group_size = weight.size(1)
+            case QuantizationStrategy.GROUP:
+                group_size = q_args.group_size
+        weight = weight.view(-1, group_size)
+
+    elif q_args.strategy == QuantizationStrategy.BLOCK:
+        block_height, block_width = q_args.block_structure
+        block_size = block_height * block_width
+        rows, cols = weight.shape
+        num_heights = strategy_cdiv(rows, block_height, q_args.strategy, strict=True)
+        num_widths = strategy_cdiv(cols, block_width, q_args.strategy, strict=True)
+        weight = (
+            weight.reshape(  # nH*H=rows, nW*W=cols
+                num_heights, block_height, num_widths, block_width
+            )  # nH, H, nW, W
+            .transpose(1, 2)  # nH, nW, H, W
+            .reshape(-1, block_size)  # nH*nW, H*W
+        )
+    else:
+        raise NotImplementedError(
+            "expected weight quantization strategy to be one "
+            f"of TENSOR, CHANNEL, GROUP, or BLOCK, got {q_args.strategy}"
+        )
+    return weight
+
+
+def _reorient_weight(weight: torch.Tensor, q_args, orig_shape) -> torch.Tensor:
+    """
+    Undo the _orient_weight() operation, returning weight to its original shape
+    """
+    if q_args.strategy in [
+        QuantizationStrategy.TENSOR,
+        QuantizationStrategy.CHANNEL,
+        QuantizationStrategy.GROUP,
+    ]:
+        return weight.reshape(orig_shape)
+
+    elif q_args.strategy == QuantizationStrategy.BLOCK:
+        block_height, block_width = q_args.block_structure
+        rows, cols = orig_shape
+        num_heights = strategy_cdiv(rows, block_height, q_args.strategy, strict=True)
+        num_widths = strategy_cdiv(cols, block_width, q_args.strategy, strict=True)

+        weight = (
+            weight.view(  # nH*nW, H*W
+                num_heights, num_widths, block_height, block_width
+            )  # nH, nW, H, W
+
.transpose(1, 2) # nH, H, nW, W + .reshape(orig_shape) # nH*H=rows, nW*W=cols + ) + else: + raise NotImplementedError( + "expected weight quantization strategy to be " + f"one of TENSOR, CHANNEL, GROUP, or BLOCK, got {q_args.strategy}" + ) + return weight + def _check_layers_are_compatible( smooth_layer, smooth_name, balance_layers, balance_names @@ -764,49 +856,6 @@ def get_lowest_common_ancestor_with_avoid( ancestor_name = ".".join(ancestor_name.split(".")[:-1]) -def _pseudo_quantize_tensor( - w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 -): - org_w_shape = w.shape - if group_size > 0: - assert org_w_shape[-1] % group_size == 0, ( - f"org_w_shape ({org_w_shape[-1]}) must be a multiple " - + f"of group_size ({group_size})!" - ) - w = w.reshape(-1, group_size) - assert w.dim() == 2 - assert torch.isnan(w).sum() == 0 - - # zero point quantization - if not symmetric: - max_val = w.amax(dim=1, keepdim=True) - min_val = w.amin(dim=1, keepdim=True) - max_int = 2**bit_width - 1 - min_int = 0 - scales = (max_val - min_val).clamp(min=1e-5) / max_int - zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) - w = ( - torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros - ) * scales - zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1) - else: - max_val = w.abs().amax(dim=1, keepdim=True) - max_val = max_val.clamp(min=1e-5) - max_int = 2 ** (bit_width - 1) - 1 - min_int = -(2 ** (bit_width - 1)) - scales = max_val / max_int - zeros = None - w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales - - assert torch.isnan(scales).sum() == 0 - assert torch.isnan(w).sum() == 0 - - scales = scales.view(org_w_shape[0], -1) - w = w.reshape(org_w_shape) - - return w, scales, zeros - - def _accumulate_mean( inp: torch.Tensor, prev_mean_and_count: tuple[torch.FloatTensor, int] | None, diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 9109b8fa3..3017c4dea 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -81,7 +81,8 @@ def call_observer( base_name is "weight", then the module's weight tensor will be used """ with align_module_device(module): - value = module.weight if base_name == "weight" else value + if value is None and base_name == "weight": + value = module.weight observer: Observer = getattr(module, f"{base_name}_observer") if should_calculate_gparam: diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 5d82da29a..3801bd2fd 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -1,10 +1,22 @@ +from itertools import product + import pytest import torch -from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, +) from pydantic import ValidationError from torch.nn import Linear +from torch.testing import assert_close from llmcompressor.modifiers.awq import AWQMapping, AWQModifier +from llmcompressor.modifiers.awq.base import ( + _orient_weight, + _reorient_weight, + get_lowest_common_ancestor_with_avoid, +) from llmcompressor.modifiers.factory import ModifierFactory @@ -140,63 +152,6 @@ def test_set_resolved_mappings(): @pytest.mark.unit def test_validate(): - with pytest.raises(ValidationError): - 
AWQModifier(scheme="W8A8")
-
-    with pytest.raises(ValidationError):
-        AWQModifier(
-            config_groups={
-                "group_0": QuantizationScheme(
-                    targets=["Linear"],
-                    weights=QuantizationArgs(
-                        num_bits=4,
-                        group_size=64,
-                    ),
-                ),
-                "group_1": QuantizationScheme(
-                    targets=["Linear"],
-                    weights=QuantizationArgs(
-                        num_bits=4,
-                        group_size=128,
-                    ),
-                ),
-            }
-        )
-
-    with pytest.raises(ValidationError):
-        AWQModifier(
-            config_groups={
-                "group_0": QuantizationScheme(
-                    targets=["Linear"],
-                    weights=QuantizationArgs(
-                        num_bits=4,
-                        group_size=128,
-                    ),
-                ),
-                "group_1": QuantizationScheme(
-                    targets=["Linear"],
-                    weights=QuantizationArgs(
-                        num_bits=8,
-                        group_size=128,
-                    ),
-                ),
-            }
-        )
-
-    # valid configuration
-    AWQModifier(
-        config_groups={
-            "group_0": QuantizationScheme(
-                targets=["Linear"],
-                weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
-            ),
-            "group_1": QuantizationScheme(
-                targets=["Linear"],
-                weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
-            ),
-        }
-    )
-
     AWQModifier(scheme="W4A16", duo_scaling="both")
     with pytest.raises(ValidationError):
         AWQModifier(scheme="W4A16", duo_scaling="Both")
@@ -261,5 +216,171 @@ def test_moe_multiple_balance_layers():
     }
     assert set(mapping.balance_names) == expected_balance_names
-    assert mapping.parent_name == "layer.mlp"
-    assert mapping.parent == mlp
+    parent_name, parent = get_lowest_common_ancestor_with_avoid(
+        ["embed_tokens", "decoder.self_attn.v_proj"], model
+    )
+    assert parent_name == "" and parent == model
+
+
+def _auto_awq_normalize(layers: list[torch.nn.Module], group_size) -> torch.Tensor:
+    """
+    Original AutoAWQ implementation (need to call .mean(0) to get normalized layer
+    means)
+    """
+    # [STEP 1]: Compute per-channel mean of normalised weights
+    # All layer weights are concatted together
+    weight = torch.cat([bl.weight for bl in layers], dim=0)
+    orig_shape = weight.shape
+    # The weights are reshaped to be organised by quantization group
+    if group_size is not None:
+        weight = weight.view(-1, group_size)
+    # Calculates the relative magnitude of the weights within
+    # each of the quantization groups, and rescales each group
+    # individually so that each group has weights on a 0-1 scale.
+    weight.abs_()
+    weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)
+    return weight.view(orig_shape)
+
+
+@torch.no_grad
+@pytest.mark.unit
+@pytest.mark.parametrize(
+    "n_balance_layers, group_size, n_input_features",
+    [
+        (5, -1, 32),  # channel
+        (4, 10, 40),  # group
+        (4, torch.inf, 40),  # tensor
+    ],
+)
+def test_compute_layer_means(n_balance_layers, group_size, n_input_features):
+    """
+    Confirm our logic to compute duo_scaling layer means via a running tally
+    matches the original memory-intensive AutoAWQ implementation, which concats
+    all balance layers into a single tensor before reducing to the mean.
+    Large models were prone to fail at this step.
+ """ + balance_layers = [ + torch.nn.Linear(n_input_features, 10) for _ in range(n_balance_layers) + ] + group_size_arg = None + match group_size: + case -1: + strategy = QuantizationStrategy.CHANNEL + group_size = balance_layers[0].weight.shape[1] + case torch.inf: + strategy = QuantizationStrategy.TENSOR + group_size = n_input_features * 10 + case _: + strategy = QuantizationStrategy.GROUP + group_size_arg = group_size + + for balance_layer in balance_layers: + setattr( + balance_layer, + "quantization_scheme", + QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + strategy=strategy, + group_size=group_size_arg, + ), + ), + ) + + auto_awq_means = _auto_awq_normalize(balance_layers, group_size).mean(0) + + llmc_awq_means = AWQModifier._compute_layer_means(balance_layers).to( + auto_awq_means.dtype + ) + + assert_close(auto_awq_means, llmc_awq_means) + + +@pytest.mark.unit +@pytest.mark.parametrize( + "rows, cols, block_height, block_width", + [ + ( + 32, + 256, + 4, + 8, + ), + ( + 4, + 3, + 2, + 1, + ), + ( + 10, + 10, + 10, + 10, + ), + ( + 512, + 256, + 128, + 128, + ), + ], +) +@torch.no_grad +def test_block_strategy_compute_layer_means(rows, cols, block_height, block_width): + """ + Confirm our logic to compute layer means works for BLOCK quantization + """ + lin = torch.nn.Linear(cols, rows) + setattr( + lin, + "quantization_scheme", + QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + strategy=QuantizationStrategy.BLOCK, + block_structure=[block_height, block_width], + ), + ), + ) + # main + llmc_awq_means = AWQModifier._compute_layer_means([lin]) + + # ref + num_heights = rows // block_height + num_widths = cols // block_width + + ref_weight = torch.zeros_like(lin.weight) + with torch.no_grad(): + for i, j in product(range(num_heights), range(num_widths)): + block = lin.weight[ + i * block_height : (i + 1) * block_height, + j * block_width : (j + 1) * block_width, + ].abs() + block = block / (block.max() + 1e-6) + ref_weight[ + i * block_height : (i + 1) * block_height, + j * block_width : (j + 1) * block_width, + ] = block + ref_means = ref_weight.sum(0, dtype=torch.float64) / ref_weight.size(0) + + # auto awq + # we first reshape the weight such that it is effectively per-channel quantization + # so that we can compare to the existing _auto_awq_normalize function + orig_shape = lin.weight.shape + oriented_weight = _orient_weight(lin.weight, lin.quantization_scheme.weights) + lin.weight.data = oriented_weight + + auto_awq_means = ( + _reorient_weight( + _auto_awq_normalize([lin], None), + lin.quantization_scheme.weights, + orig_shape, + ) + .mean(0) + .to(llmc_awq_means.dtype) + ) + + # check + assert_close(llmc_awq_means, ref_means, atol=1e-5, rtol=1e-5) + assert_close(llmc_awq_means, auto_awq_means, atol=1e-5, rtol=1e-5)