Model Offloading Support (vllm-project#113)
* compute zp, scale if weight exists in module

* WIP, gets through 1 forward pass

* fix for zeroed out scales

* fix model load

* style

* offload helper fns

* pass tests

* add test to check that observers are used to populate zp and scale in initialization

* fix no calibration case

* clean up for PR

* fix test

* update dependencies

* fix forward bug

* don't calibrate on weights

* dont calib weight in forward

* fix zp load

* check calibration

---------

Co-authored-by: George Ohashi <george@neuralmagic.com>
Sara Adkins and horheynm authored Jul 30, 2024
1 parent c214cbc commit 622f721
Showing 10 changed files with 189 additions and 27 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -46,7 +46,7 @@ def _setup_packages() -> List:
)

def _setup_install_requires() -> List:
return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
return ["torch>=1.7.0", "transformers", "accelerate", "pydantic>=2.0"]

def _setup_extras() -> Dict:
return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "nbconvert>=7.16.3"]}
14 changes: 6 additions & 8 deletions src/compressed_tensors/compressors/model_compressor.py
@@ -39,10 +39,10 @@
is_module_quantized,
iter_named_leaf_modules,
)
- from compressed_tensors.utils import get_safetensors_folder
+ from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
from compressed_tensors.utils.helpers import fix_fsdp_module_name
from torch import Tensor
- from torch.nn import Module, Parameter
+ from torch.nn import Module
from tqdm import tqdm
from transformers import AutoConfig
from transformers.file_utils import CONFIG_NAME
@@ -307,12 +310,10 @@ def update_config(self, save_directory: str):

def _replace_weights(self, dense_weight_generator, model):
for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
- # loading the decompressed weights into the model
- model_device = operator.attrgetter(name)(model).device
- data_old = operator.attrgetter(name)(model)
- data_dtype = data_old.dtype
- data_new = Parameter(data.to(model_device).to(data_dtype))
- data_old.data = data_new.data
+ split_name = name.split(".")
+ prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
+ module = operator.attrgetter(prefix)(model)
+ update_parameter_data(module, data, param_name)


def map_modules_to_quant_args(model: Module) -> Dict:
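The decompression path above now resolves the owning module from the dotted parameter name and hands the copy off to update_parameter_data (added in this PR), which also refreshes the offloaded copy of the tensor. A minimal sketch of that flow; load_dense_weight is an illustrative helper, not part of the repo:

import operator

import torch
from compressed_tensors.utils import update_parameter_data

def load_dense_weight(model: torch.nn.Module, name: str, data: torch.Tensor):
    # illustrative helper: "model.layers.0.q_proj.weight" ->
    # module "model.layers.0.q_proj", param "weight"
    prefix, _, param_name = name.rpartition(".")
    module = operator.attrgetter(prefix)(model)
    # copies into the live parameter and, if the module is offloaded,
    # into its offloaded state dict as well
    update_parameter_data(module, data, param_name)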
25 changes: 12 additions & 13 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -43,6 +43,7 @@
iter_named_leaf_modules,
)
from compressed_tensors.utils.helpers import fix_fsdp_module_name
from compressed_tensors.utils.offload import update_parameter_data
from compressed_tensors.utils.safetensors_load import get_safetensors_folder
from torch.nn import Module

@@ -265,19 +266,17 @@ def _load_quant_args_from_state_dict(
"""
scale_name = f"{base_name}_scale"
zp_name = f"{base_name}_zero_point"
- device = next(module.parameters()).device
- scale = getattr(module, scale_name, None)
- zp = getattr(module, zp_name, None)
- if scale is not None:
- state_dict_scale = state_dict[f"{module_name}.{scale_name}"]
- scale.data = state_dict_scale.to(device).to(scale.dtype)
- if zp is not None:
- zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
- if zp_from_state is not None: # load the non-zero zero points
- zp.data = zp_from_state.to(device).to(zp.dtype)
- else: # fill with zeros matching scale shape
- zp.data = torch.zeros_like(scale, dtype=zp.dtype).to(device)

+ state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
+ state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)

+ if state_dict_scale is not None:
+ # module is quantized
+ update_parameter_data(module, state_dict_scale, scale_name)
+ if state_dict_zp is None:
+ # fill in zero point for symmetric quantization
+ state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
+ update_parameter_data(module, state_dict_zp, zp_name)


def _scheme_from_targets(
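For symmetric schemes the serialized checkpoint carries scales but no zero points, so the loader above substitutes zeros of matching shape. A small standalone sketch of that fallback; fill_missing_zero_point is an illustrative name, not a repo function:

import torch

def fill_missing_zero_point(state_dict: dict, module_name: str, base_name: str):
    # mirror of the logic above: scale present but zero point absent -> zeros
    scale = state_dict.get(f"{module_name}.{base_name}_scale")
    zero_point = state_dict.get(f"{module_name}.{base_name}_zero_point")
    if scale is not None and zero_point is None:
        zero_point = torch.zeros_like(scale, device="cpu")
    return scale, zero_point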
17 changes: 17 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/calibration.py
@@ -16,6 +16,7 @@
import logging

from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.utils import is_module_offloaded, update_parameter_data
from torch.nn import Module


@@ -48,4 +49,20 @@ def set_module_for_calibration(module: Module):
"to re-calibrate a frozen module"
)

if module.quantization_scheme.weights is not None:
# set weight scale and zero_point up front, calibration data doesn't affect it
observer = module.weight_observer

offloaded = False
if is_module_offloaded(module):
module._hf_hook.pre_forward(module)
offloaded = True

scale, zero_point = observer(module.weight)
update_parameter_data(module, scale, "weight_scale")
update_parameter_data(module, zero_point, "weight_zero_point")

if offloaded:
module._hf_hook.post_forward(module, None)

module.quantization_status = QuantizationStatus.CALIBRATION
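Weight scale and zero point are computed once here, up front, because calibration batches never change the weights; only activation observers keep running during calibration. The offload-aware pattern is: materialize the weights through the accelerate hook, run the observer, write the results back so they survive the next offload cycle. A condensed sketch of that pattern, mirroring the code above; calibrate_weight_qparams is an illustrative name:

from compressed_tensors.utils import is_module_offloaded, update_parameter_data

def calibrate_weight_qparams(module):
    # illustrative wrapper around the logic added above
    observer = module.weight_observer              # attached during initialization
    offloaded = is_module_offloaded(module)
    if offloaded:
        module._hf_hook.pre_forward(module)        # pull weights onto the execution device
    scale, zero_point = observer(module.weight)
    update_parameter_data(module, scale, "weight_scale")
    update_parameter_data(module, zero_point, "weight_zero_point")
    if offloaded:
        module._hf_hook.post_forward(module, None) # send weights back to the offload device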
12 changes: 8 additions & 4 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -25,6 +25,7 @@
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.utils import update_parameter_data
from torch.nn import Module


@@ -312,16 +313,19 @@ def maybe_calibrate_or_quantize(
scale = getattr(module, f"{base_name}_scale")
zero_point = getattr(module, f"{base_name}_zero_point")

- if module.quantization_status == QuantizationStatus.CALIBRATION:
+ if (
+ module.quantization_status == QuantizationStatus.CALIBRATION
+ and base_name != "weight"
+ ):
# calibration mode - get new quant params from observer
observer = getattr(module, f"{base_name}_observer")

updated_scale, updated_zero_point = observer(value)

# update scale and zero point
- device = next(module.parameters()).device
- scale.data = updated_scale.to(device)
- zero_point.data = updated_zero_point.to(device)
+ update_parameter_data(module, updated_scale, f"{base_name}_scale")
+ update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")

return fake_quantize(value, scale, zero_point, args)


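The extra base_name != "weight" check keeps the weight observer from re-running on every forward pass: weight quantization parameters were already fixed in set_module_for_calibration, so during calibration only input and output activations still consult their observers. As a tiny standalone predicate, the condition reads (sketch only; should_run_observer is an illustrative name):

from compressed_tensors.quantization.quant_config import QuantizationStatus

def should_run_observer(module, base_name: str) -> bool:
    # weights are calibrated once up front; activations are observed per batch
    return (
        module.quantization_status == QuantizationStatus.CALIBRATION
        and base_name != "weight"
    )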
28 changes: 28 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,6 +17,8 @@
from typing import Optional

import torch
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import PrefixedDataset
from compressed_tensors.quantization.lifecycle.forward import (
wrap_module_forward_quantized,
)
@@ -26,6 +28,7 @@
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.utils import get_execution_device, is_module_offloaded
from torch.nn import Module, Parameter


@@ -81,9 +84,32 @@ def initialize_module_for_quantization(
module.quantization_scheme = scheme
module.quantization_status = QuantizationStatus.INITIALIZED

offloaded = False
if is_module_offloaded(module):
offloaded = True
hook = module._hf_hook
prefix_dict = module._hf_hook.weights_map
new_prefix = {}

# recreate the prefix dict (since it is immutable)
# and add quantization parameters
for key, data in module.named_parameters():
if key not in prefix_dict:
new_prefix[f"{prefix_dict.prefix}{key}"] = data
else:
new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
remove_hook_from_module(module)

# wrap forward call of module to perform quantized actions based on calltime status
wrap_module_forward_quantized(module, scheme)

if offloaded:
# we need to re-add the hook for offloading now that we've wrapped forward
add_hook_to_module(module, hook)
if prefix_dict is not None:
module._hf_hook.weights_map = new_prefix_dict


def _initialize_scale_zero_point_observer(
module: Module,
@@ -99,6 +125,8 @@ def _initialize_scale_zero_point_observer(
return # no need to register a scale and zero point for a dynamic observer

device = next(module.parameters()).device
if is_module_offloaded(module):
device = get_execution_device(module)

# infer expected scale/zero point shape
expected_shape = 1 # per tensor
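Because an accelerate offload hook keeps its weights_map immutable, the newly registered quantization parameters have to be folded into a rebuilt PrefixedDataset, and the hook has to be detached before the forward pass is wrapped and re-attached afterwards. A hedged end-to-end sketch of the situation this handles; the Linear module and CUDA device below are placeholders:

import torch
from accelerate import cpu_offload

layer = torch.nn.Linear(512, 512)
# attaches layer._hf_hook with offload=True (assumes a CUDA device is available)
cpu_offload(layer, execution_device=torch.device("cuda:0"))

# initialize_module_for_quantization(layer, scheme) can now:
#   1. create weight_scale / weight_zero_point on the execution device
#   2. rebuild the hook's weights_map so the new parameters are tracked when offloaded
#   3. remove the hook, wrap forward() for quantized execution, then re-attach the hook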
11 changes: 11 additions & 0 deletions src/compressed_tensors/quantization/quant_config.py
@@ -239,3 +239,14 @@ def from_pretrained(
format=format,
ignore=consolidated_ignore,
)

def requires_calibration_data(self):
for _, scheme in self.config_groups.items():
if scheme.input_activations is not None:
if not scheme.input_activations.dynamic:
return True
if scheme.output_activations is not None:
if not scheme.output_activations.dynamic:
return True

return False
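requires_calibration_data returns True only when some config group quantizes input or output activations with static (non-dynamic) scales; weight-only and fully dynamic configs can skip sample data entirely. A small illustrative driver (the function name, loader argument, and error text are placeholders):

def maybe_check_calibration(model, config, calibration_loader=None):
    # static activation scales need observed sample batches; everything else does not
    if config.requires_calibration_data() and calibration_loader is None:
        raise ValueError(
            "config has static activation quantization; calibration data is required"
        )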
1 change: 1 addition & 0 deletions src/compressed_tensors/utils/__init__.py
@@ -14,6 +14,7 @@
# flake8: noqa

from .helpers import *
from .offload import *
from .permutations_24 import *
from .safetensors_load import *
from .semi_structured_conversions import *
104 changes: 104 additions & 0 deletions src/compressed_tensors/utils/offload.py
@@ -0,0 +1,104 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.nn import Module


__all__ = [
"is_module_offloaded",
"get_execution_device",
"get_offloaded_device",
"update_prefix_dict",
"update_parameter_data",
]


def is_module_offloaded(module: Module) -> bool:
"""
:param module: layer to check
:return: True if layer is offloaded from GPU, False otherwise
"""
return hasattr(module, "_hf_hook") and module._hf_hook.offload


def get_execution_device(module: Module) -> torch.device:
"""
:param module: layer to check
:return: device layer is loaded onto during forward pass
"""
if is_module_offloaded(module):
return module._hf_hook.execution_device
return next(module.parameters()).device


def get_offloaded_device(module: Module) -> torch.device:
"""
:param module: layer to check
:return: device the layer is offloaded to after the forward pass
"""
if is_module_offloaded(module):
first_key = list(module._hf_hook.weights_map.keys())[0]
prefix_dataset = module._hf_hook.weights_map.dataset
return prefix_dataset[first_key].device
return next(module.parameters()).device


def update_prefix_dict(module: Module, key: str, data: torch.Tensor):
"""
Updates the offloaded state dict for a given module. The parameter named key is
replaced by data. This is necessary because parameter updates for offloaded modules
do not persist automatically between loads. This function only affects the offloaded
state dict and not the current state of the loaded module.
:param module: layer containing the parameter to update
:param key: name of parameter to update
:param data: tensor to update parameter with in the offloaded state dict
"""
if not is_module_offloaded(module):
raise ValueError("Prefix dict is only applicable to offloaded modules")
prefix_dict = module._hf_hook.weights_map
prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data


def update_parameter_data(
module: Module, new_param_data: torch.Tensor, param_name: str
):
"""
Updates the parameter value named param_name for a given module. This function
updates both the currently loaded module state and the offloaded state dict if
the module is offloaded. This is necessary because parameter updates for offloaded
modules do not persist automatically between loads.
:param module: layer containing the parameter to update
:param new_param_data: tensor to update parameter with
:param param_name: name of the parameter to update
"""
device = next(module.parameters()).device

offloaded = False
if is_module_offloaded(module):
offload_device = get_offloaded_device(module)
offloaded = True

parameter = getattr(module, param_name, None)
dtype = parameter.dtype
parameter.data = new_param_data.to(device).to(dtype)

if offloaded:
prefix_dict = module._hf_hook.weights_map.dataset
prefix = module._hf_hook.weights_map.prefix
prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to(
dtype
)
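Taken together, these helpers give callers a single call site that keeps the live module and its offloaded state dict in sync. A hedged usage sketch; rescale_weight and the printout are illustrative only:

import torch
from compressed_tensors.utils import (
    get_execution_device,
    get_offloaded_device,
    is_module_offloaded,
    update_parameter_data,
)

def rescale_weight(module: torch.nn.Module, new_scale: torch.Tensor):
    # illustrative helper, not part of the repo
    if is_module_offloaded(module):
        # parameters normally rest on the offload device (e.g. cpu) and are moved
        # to the execution device only for forward passes
        print(get_offloaded_device(module), get_execution_device(module))
    # writes the live parameter and, for offloaded modules, the offloaded copy too,
    # so the new scale survives the next offload/reload cycle
    update_parameter_data(module, new_scale, "weight_scale")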
2 changes: 1 addition & 1 deletion tests/test_quantization/lifecycle/test_enabled.py
@@ -32,7 +32,7 @@ def test_quantization_enabled_disabled():
apply_quantization_config(
model=quantized_model,
config=QuantizationConfig(
config_groups=dict(W4A16=["Linear"]),
config_groups=dict(W8A8=["Linear"]),
quantization_status="calibration",
),
)
