diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 4201911e96..49fc491fc5 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -310,14 +310,24 @@ def _set_resolved_mappings(self, model: Module) -> None:
                         if not balance_layer:
                             continue
 
-                        # exclude v_proj/o_proj mappings whose shapes are incompatible
+                        # exclude v_proj->o_proj mappings whose shapes are incompatible
                         # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
                         if (
-                            ".v_proj" in layer_name
-                            and ".o_proj" in balance_name
-                            and isinstance(smooth_layer, torch.nn.Linear)
+                            isinstance(smooth_layer, torch.nn.Linear)
                             and isinstance(balance_layer, torch.nn.Linear)
-                            and smooth_layer.weight.shape != balance_layer.weight.shape
+                            and ".o_proj" in balance_name
+                            and (
+                                (
+                                    ".v_proj" in layer_name
+                                    and smooth_layer.out_features
+                                    != balance_layer.in_features
+                                )
+                                or (
+                                    ".qkv_proj" in layer_name
+                                    and smooth_layer.out_features
+                                    != 3 * balance_layer.in_features
+                                )
+                            )
                         ):
                             num_skipped_oproj_mappings += 1
                             continue
@@ -466,33 +476,42 @@ def _apply_smoothing(self, model: Module) -> None:
                 inp, w_mean, x_mean, module2inspect, balance_layers, fp16_output
             )
 
-            scales = best_scales
-
            @torch.no_grad()
             def smooth(module):
                 with align_module_device(module):
+                    scales = best_scales.to(module.weight.device)
                     if module in balance_layers:
-                        module.weight.mul_(scales.view(1, -1).to(module.weight.device))
+                        update_offload_parameter(
+                            module,
+                            "weight",
+                            module.weight.mul_(scales.view(1, -1)),
+                        )
                     elif module == smooth_layer:
                         if module.weight.ndim == 1:
                             update_offload_parameter(
                                 module,
                                 "weight",
-                                module.weight.div(scales.to(module.weight.device)),
+                                module.weight.div_(scales),
                             )
                         else:
+                            # NOTE: edge case where the smooth layer's out_features
+                            # does not equal the balance layer's in_features,
+                            # e.g. when a fused qkv_proj is used to smooth o_proj.
+                            # In that case, scale only the last out_features rows,
+                            # since the desired smooth layer is v_proj.
+                            # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123
                             update_offload_parameter(
                                 module,
                                 "weight",
-                                module.weight.div(
-                                    scales.view(-1, 1).to(module.weight.device)
+                                module.weight[-scales.size(0) :].div_(
+                                    scales.view(-1, 1)
                                 ),
                             )
                         if hasattr(module, "bias") and module.bias is not None:
                             update_offload_parameter(
                                 module,
                                 "bias",
-                                module.bias.div(scales.to(module.bias.device)),
+                                module.bias.div_(scales),
                             )
 
             parent = get_fsdp_parent(mapping.smooth_name, model)
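
The slicing in `smooth()` above can be sanity-checked on a toy example. The sketch below is illustrative only and not part of the patch: the dimensions (`hidden_size = 8`, no grouped-query attention) are hypothetical, and the fused projection is assumed to store rows in q, k, v order (the same assumption the patch makes when it scales the last rows). It shows that dividing only the last `scales.size(0)` rows of a fused `qkv_proj` weight, while multiplying `o_proj`'s input channels, leaves the v_proj -> o_proj composition numerically unchanged.

```python
import torch

# Toy dimensions for illustration (assumptions, not from the patch): the fused
# qkv_proj has out_features == 3 * hidden_size, so o_proj.in_features is exactly
# one third of it, which is the case the new mapping check accepts.
hidden_size = 8
qkv_proj = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False)
o_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)

x = torch.randn(2, hidden_size)
v_before = qkv_proj(x)[:, -hidden_size:]   # v_proj slice of the fused output
out_before = o_proj(v_before)

# One AWQ scale per o_proj input channel.
scales = torch.rand(o_proj.in_features) + 0.5

with torch.no_grad():
    # Balance layer: scale o_proj's input channels (weight columns) up.
    o_proj.weight.mul_(scales.view(1, -1))
    # Smooth layer: only the last scales.size(0) rows of the fused qkv weight
    # belong to v_proj, so divide just that slice, mirroring
    # module.weight[-scales.size(0):].div_(scales.view(-1, 1)) in the patch.
    qkv_proj.weight[-scales.size(0):].div_(scales.view(-1, 1))

v_after = qkv_proj(x)[:, -hidden_size:]
out_after = o_proj(v_after)
# The v_proj -> o_proj composition is preserved (up to fp error), while the
# q_proj and k_proj rows of the fused weight are left untouched.
assert torch.allclose(out_before, out_after, atol=1e-5)
```

Fused or grouped-query layouts where these dimensions do not line up are exactly the cases the revised check in `_set_resolved_mappings` now skips.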