Skip to content

Commit

Permalink
Fix SmoothQuant offload bug (#978)
Browse files (browse the repository at this point in the history)
* fix offload

Signed-off-by: Dipika <dipikasikka1@gmail.com>

* fix smoothquant offload bug

* remove logtime

---------

Signed-off-by: Dipika <dipikasikka1@gmail.com>
  • Loading branch information
dsikka authored and horheynm committed Dec 20, 2024
1 parent 0e1745e commit c939f67
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/llmcompressor/modifiers/smoothquant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, Dict, List, Optional, Tuple

import torch
from compressed_tensors.utils.offload import is_module_offloaded
from loguru import logger
from torch.nn import Module

Expand Down Expand Up @@ -282,6 +283,10 @@ def _apply_smoothing(self, model: Module):

@torch.no_grad()
def smooth(module):
offloaded = is_module_offloaded(module)
if offloaded:
module._hf_hook.pre_forward(module)

if module in balance_layers:
module.weight.mul_(scales.view(1, -1))
elif module == smooth_layer:
Expand All @@ -292,6 +297,9 @@ def smooth(module):
if hasattr(module, "bias") and module.bias is not None:
module.bias.div_(scales)

if offloaded:
module._hf_hook.post_forward(module, None)

parent = get_fsdp_parent(mapping.smooth_name, model)
if parent is not None:
parent.apply(smooth)
Expand All @@ -318,8 +326,16 @@ def _calculate_smoothing_scales(
# get the channel-wise dynamic range for each layer to be balanced
weight_scales = []
for layer in balance_layers:
offloaded = is_module_offloaded(layer)
if offloaded:
layer._hf_hook.pre_forward(layer)

scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)

if offloaded:
layer._hf_hook.post_forward(layer, None)

weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

# calculate the amount of smoothing to apply
Expand All @@ -329,4 +345,5 @@ def _calculate_smoothing_scales(
1 - self.smoothing_strength
)
scales = torch.where(weight_scales > 0.0, scales, activation_scales)

return scales

0 comments on commit c939f67

Please sign in to comment.