Implement HooksMixin #917

Merged: 18 commits, Dec 6, 2024
Changes from 11 commits
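
For context: the core of this PR is a mixin that owns every PyTorch hook a modifier registers, so all of them can be detached with a single remove_hooks() call instead of per-callsite bookkeeping. The sketch below is an illustrative reconstruction based only on how the diff calls register_hook(target, hook, hook_type) and remove_hooks(); the internal attribute names and dispatch logic are assumptions, not the actual implementation in llmcompressor.modifiers.utils.hooks.

from typing import Any, Callable, List

import torch
from pydantic import BaseModel, PrivateAttr
from torch.utils.hooks import RemovableHandle


class HooksMixin(BaseModel):
    """Sketch: records every handle returned by PyTorch hook registration
    so that remove_hooks() can detach them all at once."""

    _hooks: List[RemovableHandle] = PrivateAttr(default_factory=list)

    def register_hook(
        self,
        target: Any,        # nn.Module for module hooks, Tensor/Parameter for "" hooks
        hook: Callable,
        hook_type: str,     # "forward", "forward_pre", or "" for Tensor.register_hook
        **kwargs,
    ) -> RemovableHandle:
        if hook_type == "":
            handle = target.register_hook(hook)  # plain tensor/parameter hook
        else:
            # dispatches to register_forward_hook, register_forward_pre_hook, ...
            handle = getattr(target, f"register_{hook_type}_hook")(hook, **kwargs)
        self._hooks.append(handle)
        return handle

    def remove_hooks(self):
        # Detach every hook registered through this mixin.
        for handle in self._hooks:
            handle.remove()
        self._hooks.clear()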
7 changes: 3 additions & 4 deletions src/llmcompressor/modifiers/modifier.py
@@ -1,16 +1,15 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import Optional

from pydantic import BaseModel

from llmcompressor.core.events import Event, EventType
from llmcompressor.core.state import State
from llmcompressor.modifiers.interface import ModifierInterface
from llmcompressor.modifiers.utils.hooks import HooksMixin

__all__ = ["Modifier"]


class Modifier(BaseModel, ModifierInterface, ABC):
class Modifier(ModifierInterface, HooksMixin):
"""
A base class for all modifiers to inherit from.
Modifiers are used to modify the training process for a model.
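
With Modifier now inheriting HooksMixin, subclasses can register hooks through self.register_hook(...) and tear them all down with self.remove_hooks(). A minimal usage sketch follows: ExampleModifier is hypothetical, and the on_initialize/on_finalize lifecycle methods are assumed from the existing Modifier API rather than defined in this diff.

import torch

from llmcompressor.core.state import State
from llmcompressor.modifiers.modifier import Modifier


class ExampleModifier(Modifier):
    # Hypothetical modifier, for illustration only; not part of this PR.

    def on_initialize(self, state: State, **kwargs) -> bool:
        def log_input(module, args):
            # forward_pre hooks receive the positional inputs as a tuple
            print(f"{module.__class__.__name__} input shape: {args[0].shape}")

        for module in state.model.modules():
            if isinstance(module, torch.nn.Linear):
                self.register_hook(module, log_input, "forward_pre")
        return True

    def on_finalize(self, state: State, **kwargs) -> bool:
        self.remove_hooks()  # detaches every hook registered above
        return True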
72 changes: 35 additions & 37 deletions src/llmcompressor/modifiers/obcq/base.py
@@ -1,3 +1,4 @@
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
@@ -130,7 +131,8 @@ def initialize_compression(
"Inferring layer-wise sparsities from "
f"{len(dataloader)} calibration samples..."
)
self.sparsity = self._infer_layer_sparsity(dataloader)
activations = self._get_activations(dataloader)
self.sparsity = self._infer_layer_sparsity(activations)
self._validate_layerwise_sparsity()

for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
@@ -254,19 +256,17 @@ def _infer_mask_block_size(self):

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))

def _infer_layer_sparsity(self, calibration_dataloader):
acts = _get_activations(self.model, calibration_dataloader)
def _infer_layer_sparsity(self, activations):
sparsegpt_groups = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0)
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z])

acts = None
del acts
del activations
torch.cuda.empty_cache()

outlier_ratios = {}
@@ -300,36 +300,34 @@ def _infer_layer_sparsity(self, calibration_dataloader):
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

@torch.no_grad()
def _get_activations(self, data_loader, nsamples=128):
self.model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)
else:
acts[name] += (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)

for name, mod in self.model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")

device = next(self.model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
self.model(**batch)
batch = None
torch.cuda.empty_cache()

@torch.no_grad()
def _get_activations(model, data_loader, nsamples=128):
import functools

model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
else:
acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()

hooks = []
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
hooks.append(
mod.register_forward_pre_hook(functools.partial(save_acts, name=name))
)
device = next(model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
model(**batch)
batch = None
torch.cuda.empty_cache()

for h in hooks:
h.remove()
self.remove_hooks()

return acts
return acts
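
A note on what the hooks above accumulate: for every Linear input, save_acts adds 1/nsamples * sqrt(sum of squared activations over the batch and sequence dimensions), so acts[name] ends up as a 1-D tensor with one entry per input channel of that layer. A self-contained illustration of the statistic (not code from this PR):

import torch

nsamples = 128
hidden = 16
x = torch.randn(4, 32, hidden)  # one calibration batch: (batch, seq_len, hidden)

# contribution of this batch to acts[name], as computed inside save_acts
contribution = 1.0 / nsamples * x.detach().pow(2).sum(dim=(0, 1)).sqrt()
print(contribution.shape)  # torch.Size([16]) -- one value per input channel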
21 changes: 5 additions & 16 deletions src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py
@@ -2,11 +2,10 @@
from typing import Dict

import torch
from pydantic import BaseModel
from torch.nn import Parameter
from torch.utils.hooks import RemovableHandle

from llmcompressor.core import ModelParameterizedLayer
from llmcompressor.modifiers.utils.hooks import HooksMixin

__all__ = ["LayerParamMasking", "param_mask_name"]

@@ -39,11 +38,9 @@ class ParameterizedLayerMaskSettings:
use_hooks: bool = False


class LayerParamMasking(BaseModel):
class LayerParamMasking(HooksMixin):
_mask_settings: Dict[str, ParameterizedLayerMaskSettings] = {}
_masked_layer_params: Dict[str, ModelParameterizedLayer] = {}
_forward_hooks: Dict[str, RemovableHandle] = {}
_backward_hooks: Dict[str, RemovableHandle] = {}
enabled_: bool = False

def add_mask(
@@ -100,12 +97,8 @@ def _backward_hook_fn(gradients):

return gradients

self._forward_hooks[layer_param_name] = (
parameterized_layer.layer.register_forward_hook(_forward_hook_fn)
)
self._backward_hooks[layer_param_name] = (
parameterized_layer.param.register_hook(_backward_hook_fn)
)
self.register_hook(parameterized_layer.layer, _forward_hook_fn, "forward")
self.register_hook(parameterized_layer.param, _backward_hook_fn, "")

def update_mask(
self,
@@ -131,11 +124,7 @@ def remove_mask(self, layer_param_name: str):
del self._mask_settings[layer_param_name]

if mask_settings.use_hooks:
self._forward_hooks[layer_param_name].remove()
self._backward_hooks[layer_param_name].remove()

del self._forward_hooks[layer_param_name]
del self._backward_hooks[layer_param_name]
self.remove_hooks()

def apply_mask_weight(self, layer_param_name: str):
if not self.enabled_:
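
Worth noting in the hunk above: the forward hook attaches to the module, while the backward hook attaches directly to the Parameter via Tensor.register_hook, which is what the empty-string hook type passed to register_hook corresponds to. A small standalone sketch of the underlying PyTorch behavior, independent of this PR:

import torch

layer = torch.nn.Linear(4, 4)

def forward_hook(module, inputs, output):
    # module forward hook: returning a value replaces the output
    return output * 2

def grad_hook(grad):
    # tensor hook on the parameter: returning a value replaces the gradient
    return grad * 0.5

handles = [
    layer.register_forward_hook(forward_hook),
    layer.weight.register_hook(grad_hook),
]

layer(torch.randn(2, 4)).sum().backward()

for handle in handles:  # what HooksMixin.remove_hooks() does in bulk
    handle.remove()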
72 changes: 35 additions & 37 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -1,3 +1,4 @@
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
@@ -121,7 +122,8 @@ def initialize_compression(
"Inferring layer-wise sparsities from "
f"{len(dataloader) if dataloader else 0} calibration samples..."
)
self.sparsity = self._infer_layer_sparsity(dataloader)
activations = self._get_activations(dataloader)
self.sparsity = self._infer_layer_sparsity(activations)
self._validate_layerwise_sparsity()

for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
@@ -224,19 +226,17 @@ def _infer_mask_block_size(self):

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))

def _infer_layer_sparsity(self, calibration_dataloader):
acts = _get_activations(self.model, calibration_dataloader)
def _infer_layer_sparsity(self, activations):
wanda = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0)
m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
wanda[name] = torch.cat([item.flatten().cpu() for item in z])

acts = None
del acts
del activations
torch.cuda.empty_cache()

outlier_ratios = {}
@@ -268,36 +268,34 @@ def _infer_layer_sparsity(self, calibration_dataloader):
logger.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities

@torch.no_grad()
def _get_activations(self, data_loader, nsamples=128):
self.model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)
else:
acts[name] += (
1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
)

for name, mod in self.model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
self.register_hook(mod, partial(save_acts, name=name), "forward_pre")

device = next(self.model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
self.model(**batch)
batch = None
torch.cuda.empty_cache()

@torch.no_grad()
def _get_activations(model, data_loader, nsamples=128):
import functools

model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
else:
acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()

hooks = []
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
hooks.append(
mod.register_forward_pre_hook(functools.partial(save_acts, name=name))
)
device = next(model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
model(**batch)
batch = None
torch.cuda.empty_cache()

for h in hooks:
h.remove()
self.remove_hooks()

return acts
return acts
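
The activation norms gathered above feed the Wanda importance score in _infer_layer_sparsity: each weight is scored by its magnitude scaled by the per-input-channel activation norm, computed as m.weight.abs() * activations[...].unsqueeze(0). A short standalone illustration of the shape arithmetic (not code from this PR):

import torch

out_features, in_features = 8, 16
weight = torch.randn(out_features, in_features)
act_norm = torch.rand(in_features)  # per-input-channel norm, as gathered by save_acts

# Wanda-style importance: weight magnitude scaled per column by the activation norm
score = weight.abs() * act_norm.unsqueeze(0)
print(score.shape)  # torch.Size([8, 16]) -- one score per weight element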
70 changes: 28 additions & 42 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Dict, Optional, Tuple

import torch
from compressed_tensors.quantization import QuantizationStatus, is_attention_module
@@ -146,71 +146,57 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
)


def calibrate_input_hook():
def calibrate_input_hook(module: Module, args: Any):
"""
Hook to calibrate input activations.
Will call the observers to update the scales/zp before applying
input QDQ in the module's forward pass.
"""
args = args[0] if isinstance(args, tuple) else args
calibrate_activations(module, value=args, base_name="input")

def hook_fn(module: Module, inp):
inp = inp[0] if isinstance(inp, tuple) else inp
calibrate_activations(module, value=inp, base_name="input")

return hook_fn


def calibrate_output_hook():
def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
"""
Hook to calibrate output activations.
Will call the observers to update the scales/zp before applying
output QDQ.
"""

def hook_fn(module: Module, inp, output: torch.Tensor):
calibrate_activations(
module,
value=output,
base_name="output",
)
output = forward_quantize(
module=module,
value=output,
base_name="output",
args=module.quantization_scheme.output_activations,
)
return output

return hook_fn
calibrate_activations(
module,
value=output,
base_name="output",
)
output = forward_quantize(
module=module,
value=output,
base_name="output",
args=module.quantization_scheme.output_activations,
)
return output


def calibrate_kv_cache_input_hook():
def calibrate_kv_cache_input_hook(
module: Module, args: Any, kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
"""
Hook to update inputs to attention layers when running
kv_cache quantization. Will update the passed in
kv_cache to singleton QuantizedKVParameterCache.
"""
kv_cache = getattr(module, "kv_cache")
kwargs["past_key_value"] = kv_cache
kwargs["use_cache"] = False
return args, kwargs

def hook_fn(module: Module, args, kwargs):
kv_cache = getattr(module, "kv_cache")
kwargs["past_key_value"] = kv_cache
kwargs["use_cache"] = False
return args, kwargs

return hook_fn


def calibrate_kv_cache_output_hook():
def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
"""
Hook to update k_scale and v_scale parameters when running kv_cache quantization.
"""

def hook_fn(module: Module, inpt, output: torch.Tensor):
kv_cache = getattr(module, "kv_cache")
update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")

return hook_fn
kv_cache = getattr(module, "kv_cache")
update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")


def set_unset_kv_cache(module: Module):
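
With the closure factories removed, these functions now match PyTorch's native hook signatures, so a class that mixes in HooksMixin can pass them to register_hook directly. The registration itself happens in the quantization modifier, outside this diff; the sketch below is an assumed illustration of how the new signatures line up with hook types (register_calibration_hooks is a hypothetical helper, and with_kwargs=True assumes the kwargs-style forward_pre hook available in torch >= 2.0).

from compressed_tensors.quantization import is_attention_module

from llmcompressor.modifiers.quantization.calibration import (
    calibrate_input_hook,
    calibrate_kv_cache_input_hook,
    calibrate_kv_cache_output_hook,
    calibrate_output_hook,
)


def register_calibration_hooks(self, module):
    # Assumes `self` mixes in HooksMixin, e.g. a quantization modifier.
    scheme = getattr(module, "quantization_scheme", None)
    if scheme is None:
        return

    if scheme.input_activations is not None:
        # (module, args) -> None  ==> forward_pre hook
        self.register_hook(module, calibrate_input_hook, "forward_pre")

    if scheme.output_activations is not None:
        # (module, _args, output) -> output  ==> forward hook
        self.register_hook(module, calibrate_output_hook, "forward")

    if is_attention_module(module):
        # (module, args, kwargs) -> (args, kwargs)  ==> forward_pre hook with kwargs
        self.register_hook(
            module, calibrate_kv_cache_input_hook, "forward_pre", with_kwargs=True
        )
        self.register_hook(module, calibrate_kv_cache_output_hook, "forward")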