diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index f4117e31d..9381348b1 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -2,6 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
+from compressed_tensors.utils.offload import is_module_offloaded
 from loguru import logger
 from torch.nn import Module
 
@@ -282,6 +283,10 @@ def _apply_smoothing(self, model: Module):
 
         @torch.no_grad()
         def smooth(module):
+            offloaded = is_module_offloaded(module)
+            if offloaded:
+                module._hf_hook.pre_forward(module)
+
             if module in balance_layers:
                 module.weight.mul_(scales.view(1, -1))
             elif module == smooth_layer:
@@ -292,6 +297,9 @@ def smooth(module):
                 if hasattr(module, "bias") and module.bias is not None:
                     module.bias.div_(scales)
 
+            if offloaded:
+                module._hf_hook.post_forward(module, None)
+
         parent = get_fsdp_parent(mapping.smooth_name, model)
         if parent is not None:
             parent.apply(smooth)
@@ -318,8 +326,16 @@ def _calculate_smoothing_scales(
         # get the channel-wise dynamic range for each layer to be balanced
         weight_scales = []
         for layer in balance_layers:
+            offloaded = is_module_offloaded(layer)
+            if offloaded:
+                layer._hf_hook.pre_forward(layer)
+
             scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
             weight_scales.append(scale)
+
+            if offloaded:
+                layer._hf_hook.post_forward(layer, None)
+
         weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]
 
         # calculate the amount of smoothing to apply
@@ -329,4 +345,5 @@ def _calculate_smoothing_scales(
             1 - self.smoothing_strength
         )
         scales = torch.where(weight_scales > 0.0, scales, activation_scales)
+        return scales
 
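
For context on the pattern these hunks introduce: modules dispatched with accelerate-style offloading carry a device-alignment hook at module._hf_hook, whose pre_forward materializes the parameters on the execution device and whose post_forward returns them to the offload store. Because SmoothQuant mutates weights in place outside of a forward pass, the hook has to be driven manually around each edit, guarded by is_module_offloaded so plain modules are untouched. Below is a minimal sketch of that onload/mutate/offload guard, assuming compressed_tensors is installed; apply_inplace_scale is a hypothetical helper for illustration, not part of llm-compressor.

import torch
from compressed_tensors.utils.offload import is_module_offloaded


@torch.no_grad()  # in-place edits on leaf parameters require grad tracking off
def apply_inplace_scale(module: torch.nn.Module, scales: torch.Tensor) -> None:
    """Hypothetical helper mirroring the guard added in this diff."""
    offloaded = is_module_offloaded(module)
    if offloaded:
        # Onload: move parameters to the execution device before touching them.
        module._hf_hook.pre_forward(module)

    # The actual in-place edit, as in SmoothQuantModifier._apply_smoothing.
    module.weight.mul_(scales.view(1, -1))

    if offloaded:
        # Offload again; the second argument is the forward output the hook
        # normally post-processes, unused when invoked manually.
        module._hf_hook.post_forward(module, None)


layer = torch.nn.Linear(4, 4)  # not offloaded, so both hook calls are skipped
apply_inplace_scale(layer, torch.full((4,), 0.5))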