Fix SmoothQuant offload bug #978

Merged 5 commits on Dec 18, 2024
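This PR fixes SmoothQuant failing on models whose layers accelerate has offloaded (weights kept on CPU or disk rather than the execution device): the modifier now onloads a module's weights before reading or mutating them and offloads them again afterwards. A minimal sketch of the pattern the diff applies, assuming a module wrapped by an accelerate hook; `update_weight` is a hypothetical helper, while `is_module_offloaded` and the `_hf_hook` calls are the ones used in the diff below:

```python
import torch
from compressed_tensors.utils.offload import is_module_offloaded

@torch.no_grad()
def update_weight(module: torch.nn.Module, scales: torch.Tensor):
    # Hypothetical helper showing the onload/offload pattern from the diff.
    offloaded = is_module_offloaded(module)
    if offloaded:
        # Materialize the offloaded weights on the execution device.
        module._hf_hook.pre_forward(module)

    module.weight.mul_(scales)  # weight is now a real on-device tensor

    if offloaded:
        # Move the (updated) weights back to their offload location.
        module._hf_hook.post_forward(module, None)
```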
17 changes: 17 additions & 0 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -2,6 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
+from compressed_tensors.utils.offload import is_module_offloaded
 from loguru import logger
 from torch.nn import Module
 
@@ -282,6 +283,10 @@ def _apply_smoothing(self, model: Module):
 
             @torch.no_grad()
             def smooth(module):
+                offloaded = is_module_offloaded(module)
+                if offloaded:
+                    module._hf_hook.pre_forward(module)
+
                 if module in balance_layers:
                     module.weight.mul_(scales.view(1, -1))
                 elif module == smooth_layer:
@@ -292,6 +297,9 @@ def smooth(module):
                 if hasattr(module, "bias") and module.bias is not None:
                     module.bias.div_(scales)
 
+                if offloaded:
+                    module._hf_hook.post_forward(module, None)
+
             parent = get_fsdp_parent(mapping.smooth_name, model)
             if parent is not None:
                 parent.apply(smooth)
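For context on when this path triggers: accelerate attaches an `AlignDevicesHook` as `module._hf_hook` when a model is loaded with CPU offloading. A hedged sketch of how such offloaded modules typically arise (the model id and setup are illustrative, not from this PR):

```python
from transformers import AutoModelForCausalLM

# With device_map="auto" under GPU memory pressure, accelerate keeps some
# weights on CPU and attaches hooks (module._hf_hook) that move them to the
# GPU only around forward calls, which is why smooth() must call
# pre_forward/post_forward before touching module.weight directly.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative model id
    device_map="auto",
)
```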
@@ -318,8 +326,16 @@ def _calculate_smoothing_scales(
         # get the channel-wise dynamic range for each layer to be balanced
         weight_scales = []
         for layer in balance_layers:
+            offloaded = is_module_offloaded(layer)
+            if offloaded:
+                layer._hf_hook.pre_forward(layer)
+
             scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
             weight_scales.append(scale)
+
+            if offloaded:
+                layer._hf_hook.post_forward(layer, None)
+
         weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]
 
         # calculate the amount of smoothing to apply
@@ -329,4 +345,5 @@ def _calculate_smoothing_scales(
             1 - self.smoothing_strength
         )
         scales = torch.where(weight_scales > 0.0, scales, activation_scales)
+
         return scales
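For reference, the scale computation at the end of `_calculate_smoothing_scales` follows the standard SmoothQuant migration formula, `s_j = max(|X_j|)^alpha / max(|W_j|)^(1 - alpha)`, with `smoothing_strength` as alpha. A minimal numeric sketch of that final step (the tensor values are made up for illustration):

```python
import torch

alpha = 0.5  # smoothing_strength
activation_scales = torch.tensor([4.0, 9.0, 1.0])  # max |X_j| per channel
weight_scales = torch.tensor([1.0, 4.0, 0.0])      # channel-wise weight range

scales = activation_scales.pow(alpha) / weight_scales.pow(1 - alpha)
# A zero weight range would divide by zero (inf), so those channels fall
# back to the raw activation scale via the torch.where guard:
scales = torch.where(weight_scales > 0.0, scales, activation_scales)
print(scales)  # tensor([2.0000, 1.5000, 1.0000])
```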