Fix observing offloaded weight (#896)
* load the weight within the onloading window, so the observer sees the onloaded tensor (see the sketch below)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* remove moving the activation to the execution device; this already happens because activation calibration always runs within the forward pass

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
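
To make the first bullet concrete, here is a minimal, self-contained sketch (not part of the commit) of why the read order matters for offloaded modules. FakeOffloadHook and toy_observe are invented stand-ins; only module.weight and the _hf_hook.pre_forward / post_forward call pattern are taken from the changed file, everything else is illustrative.

import torch
from torch import nn


class FakeOffloadHook:
    """Invented stand-in for an accelerate-style offload hook: the module holds a
    placeholder parameter until pre_forward swaps the real weight back in."""

    def __init__(self, module: nn.Module):
        self.real_weight = module.weight.detach().clone()
        module.weight = nn.Parameter(torch.zeros_like(self.real_weight))  # "offloaded"

    def pre_forward(self, module: nn.Module):
        module.weight = nn.Parameter(self.real_weight.clone())  # onload real values

    def post_forward(self, module: nn.Module, output):
        module.weight = nn.Parameter(torch.zeros_like(self.real_weight))  # offload again
        return output


def toy_observe(value: torch.Tensor) -> torch.Tensor:
    return value.abs().amax() / 127.0  # toy symmetric int8 scale


linear = nn.Linear(8, 8)
linear._hf_hook = FakeOffloadHook(linear)

# Old behaviour: the caller read module.weight *before* onloading, so the observer
# calibrated against the placeholder and produced a bogus scale.
stale = linear.weight
linear._hf_hook.pre_forward(linear)
print("scale from stale read:   ", toy_observe(stale))          # 0.0
# Fixed behaviour: read the weight within onloading, as call_observer now does.
print("scale from onloaded read:", toy_observe(linear.weight))  # non-zero
linear._hf_hook.post_forward(linear, None)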
kylesayrs and dsikka committed Nov 21, 2024
1 parent e962e33 commit f8777c7
Showing 1 changed file with 16 additions and 11 deletions.
src/llmcompressor/modifiers/quantization/calibration.py (27 changes: 16 additions & 11 deletions)

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from compressed_tensors.quantization import QuantizationStatus, is_attention_module
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
@@ -57,27 +59,30 @@ def initialize_observer(
     module.register_module(f"{base_name}_observer", observer)


-def call_observer(module: Module, base_name: str, value: torch.Tensor):
+def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor] = None):
     """
-    Call a module's attached input/output observer using a provided value.
-    Update the module's scale and zp using the observer's return
-    values.
+    Call a module's attached input/weight/output observer using a provided value.
+    Update the module's scale and zp using the observer's return values.

     :param module: torch.nn.Module
     :param base_name: substring used to fetch the observer, scales, and zp
-    :param value: torch.Tensor to be passed to the observer
+    :param value: torch.Tensor to be passed to the observer for activations. If
+        base_name is "weight", then the module's weight tensor will be used
     """
     offloaded = is_module_offloaded(module)
     if offloaded:
         module._hf_hook.pre_forward(module)

-    observer = getattr(module, f"{base_name}_observer")
-    g_idx = getattr(module, "weight_g_idx", None)
-
     if base_name == "weight":
-        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
+        value = module.weight
+        g_idx = getattr(module, "weight_g_idx", None)
+    elif value is not None:
+        g_idx = None
     else:
-        updated_scale, updated_zero_point = observer(value)
+        raise ValueError("Must provide a value to observe if not using weight observer")
+
+    observer = getattr(module, f"{base_name}_observer")
+    updated_scale, updated_zero_point = observer(value, g_idx=g_idx)

     # update scale and zero point
     update_parameter_data(module, updated_scale, f"{base_name}_scale")
@@ -116,7 +121,7 @@ def update_weight_zp_scale(module: Module):

     if module.quantization_scheme.weights is not None:
         # set weight scale and zero_point up front, calibration data doesn't affect it
-        call_observer(module=module, base_name="weight", value=module.weight)
+        call_observer(module=module, base_name="weight")


 def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
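
For reference, a short standalone sketch of the new calling convention. toy_call_observer is a simplified stand-in, not the library function: it skips onloading, g_idx handling, and writing the f"{base_name}_scale" / f"{base_name}_zero_point" parameters back to the module, and it uses a toy observer; the base-name strings are illustrative.

from typing import Optional

import torch
from torch import nn


def toy_call_observer(
    module: nn.Module, base_name: str, value: Optional[torch.Tensor] = None
):
    # Mirrors the new dispatch: the weight branch reads the tensor from the module
    # itself (after onloading, in the real code); other base names require a value.
    if base_name == "weight":
        value = module.weight
    elif value is None:
        raise ValueError("Must provide a value to observe if not using weight observer")

    scale = value.abs().amax() / 127.0                     # toy symmetric int8 observer
    zero_point = torch.zeros_like(scale, dtype=torch.int8)
    return scale, zero_point


linear = nn.Linear(4, 4)
print(toy_call_observer(linear, "weight"))                    # weight: no value needed
print(toy_call_observer(linear, "input", torch.randn(2, 4)))  # activations: pass the value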
