diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 009215e0..9cb99b2e 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -29,7 +29,11 @@
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import has_offloaded_params, register_offload_parameter
+from compressed_tensors.utils import (
+    disable_hf_hook,
+    has_offloaded_params,
+    register_offload_parameter,
+)
 from torch.nn import Module, Parameter

@@ -112,42 +116,10 @@
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED

-    offloaded = False
-    if has_offloaded_params(module):
-        try:
-            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-            from accelerate.utils import PrefixedDataset
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Offloaded model detected. To use CPU offloading with "
-                "compressed-tensors the `accelerate` package must be installed, "
-                "run `pip install compressed-tensors[accelerate]`"
-            )
-
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
-            else:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-        remove_hook_from_module(module)
-
-    # wrap forward call of module to perform
-    # quantized actions based on calltime status
-    wrap_module_forward_quantized(module, scheme)
-
-    if offloaded:
-        # we need to re-add the hook for offloading now that we've wrapped forward
-        add_hook_to_module(module, hook)
-        if prefix_dict is not None:
-            module._hf_hook.weights_map = new_prefix_dict
+    with disable_hf_hook(module):
+        # wrap forward call of module to perform
+        # quantized actions based on call-time status
+        wrap_module_forward_quantized(module, scheme)


 def is_attention_module(module: Module):
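The rewrite above works because `disable_hf_hook` (added in `utils/offload.py` below) detaches the accelerate hook, yields, and re-attaches it. Since `add_hook_to_module` saves the module's current `forward` as `_old_forward` and installs a device-aligning wrapper around it, `forward` must be patched while the hook is detached; re-attaching afterwards makes the hook wrap the patched function. A minimal sketch of the pattern, assuming `accelerate` is installed (`scaled_forward` is a hypothetical stand-in for `wrap_module_forward_quantized`):

```python
import torch
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils import disable_hf_hook

module = torch.nn.Linear(4, 4)
attach_align_device_hook(
    module,
    execution_device=torch.device("cpu"),
    offload=True,
    weights_map=module.state_dict(),
)

with disable_hf_hook(module):
    # the hook is detached here, so this is the original, unwrapped forward
    original_forward = module.forward

    def scaled_forward(x: torch.Tensor) -> torch.Tensor:
        # hypothetical call-time behavior layered onto the original forward
        return original_forward(x) * 2.0

    module.forward = scaled_forward

# on exit the hook is re-attached around scaled_forward, so parameters are
# still onloaded from the weights_map before each call and offloaded after
output = module(torch.randn(1, 4))
```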
diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py
index 0d7b0bbe..7b7cc864 100644
--- a/src/compressed_tensors/utils/offload.py
+++ b/src/compressed_tensors/utils/offload.py
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 from functools import wraps
 from typing import Any, Callable, Optional

 import torch
-from compressed_tensors.utils.helpers import getattr_chain


 try:
-    from accelerate.hooks import AlignDevicesHook
+    from accelerate.hooks import (
+        AlignDevicesHook,
+        add_hook_to_module,
+        remove_hook_from_module,
+    )
     from accelerate.utils import (
         OffloadedWeightsLoader,
         PrefixedDataset,
@@ -42,6 +46,8 @@
     "update_offload_data",
     "delete_offload_parameter",
     "has_offloaded_params",
+    "disable_hf_hook",
+    "align_module_device",
 ]


@@ -167,6 +173,7 @@
     :param data: tensor to update parameter with
     """
     param = getattr(module, name)
+    data = data.to(param.dtype)

     # copy data into onloaded parameter if applicable
     if param.device != "meta":
@@ -178,7 +185,7 @@
     # for upstreaming, better to add write capabilities to weight map classes first
     if isinstance(weights_map, PrefixedDataset):
-        dataset = getattr_chain(module, "module._hf_hook.weights_map.dataset", None)
+        dataset = getattr(weights_map, "dataset", None)
         if dataset is not None:
             prefix = module._hf_hook.weights_map.prefix
             key = f"{prefix}{name}"

@@ -186,15 +193,26 @@
             offload_device = (
                 dataset[key].device
                 if key in dataset
-                else next(dataset.values()).device
+                else next(iter(dataset.values())).device
             )
-            dataset[key] = param.data.to(device=offload_device)
+            dataset[key] = data.to(device=offload_device)
+
+    elif isinstance(weights_map, dict):
+        offload_device = (
+            weights_map[name].device
+            if name in weights_map
+            else next(iter(weights_map.values())).device
+        )
+        weights_map[name] = data.to(device=offload_device)

-    if isinstance(weights_map, OffloadedWeightsLoader):
+    elif isinstance(weights_map, OffloadedWeightsLoader):
         raise NotImplementedError()

     else:
-        raise NotImplementedError()
+        raise NotImplementedError(
+            "Updating offload data not implemented for weights_map of type "
+            f"{type(weights_map)}"
+        )


 def delete_offload_parameter(module: torch.nn.Module, name: str):
@@ -216,6 +234,9 @@
         if dataset is not None:
             del dataset[f"{prefix}{name}"]

+    elif isinstance(weights_map, dict):
+        del weights_map[name]
+
     elif isinstance(weights_map, OffloadedWeightsLoader):
         raise NotImplementedError()

@@ -225,6 +246,20 @@
     )


+@check_accelerate(fallback=contextlib.nullcontext())
+@contextlib.contextmanager
+def disable_hf_hook(module: torch.nn.Module, recurse: bool = False):
+    offloaded = has_offloaded_params(module)
+    if offloaded:
+        hook = module._hf_hook
+        remove_hook_from_module(module, recurse=recurse)
+
+    yield
+
+    if offloaded:
+        add_hook_to_module(module, hook)
+
+
 """
 Upstreamed Functions
 """
@@ -247,3 +282,48 @@
         and isinstance(module._hf_hook, AlignDevicesHook)
         and module._hf_hook.offload
     )
+
+
+# introduced in accelerate v1.1.0
+@check_accelerate(fallback=contextlib.nullcontext())
+@contextlib.contextmanager
+def align_module_device(
+    module: torch.nn.Module, execution_device: Optional[torch.device] = None
+):
+    """
+    Context manager that moves a module's parameters to the specified execution device.
+
+    Args:
+        module (`torch.nn.Module`):
+            Module with parameters to align.
+        execution_device (`torch.device`, *optional*):
+            If provided, overrides the module's execution device within the context.
+            Otherwise, the hook's execution device is used; with no hook and no
+            device given, parameters are left on their current devices.
+    """
+    if has_offloaded_params(module):
+        if execution_device is not None:
+            original_device = module._hf_hook.execution_device
+            module._hf_hook.execution_device = execution_device
+
+        try:
+            module._hf_hook.pre_forward(module)
+            yield
+        finally:
+            module._hf_hook.post_forward(module, None)
+            if execution_device is not None:
+                module._hf_hook.execution_device = original_device
+
+    elif execution_device is not None:
+        devices = {
+            name: param.device for name, param in module.named_parameters(recurse=False)
+        }
+        try:
+            for name in devices:
+                set_module_tensor_to_device(module, name, execution_device)
+            yield
+        finally:
+            for name, device in devices.items():
+                set_module_tensor_to_device(module, name, device)
+
+    else:
+        yield
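`align_module_device` (upstreamed to accelerate in v1.1.0, per the comment above) onloads an offloaded module's parameters for the duration of a `with` block, which is useful for reading or computing with weights outside of a forward pass. A minimal sketch, assuming `accelerate` is installed:

```python
import torch
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils import align_module_device

module = torch.nn.Linear(8, 8)
attach_align_device_hook(
    module,
    execution_device=torch.device("cpu"),
    offload=True,
    weights_map=module.state_dict(),
)
assert module.weight.device == torch.device("meta")

with align_module_device(module):
    # the hook's pre_forward has run: parameters are materialized on the
    # hook's execution device (cpu here)
    weight_norm = module.weight.norm().item()

# the hook's post_forward has run: parameters are offloaded back to meta
assert module.weight.device == torch.device("meta")
```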
diff --git a/tests/test_quantization/lifecycle/test_initialize.py b/tests/test_quantization/lifecycle/test_initialize.py
index 987b2ae2..8252b545 100644
--- a/tests/test_quantization/lifecycle/test_initialize.py
+++ b/tests/test_quantization/lifecycle/test_initialize.py
@@ -19,12 +19,18 @@
 )
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from tests.testing_utils import requires_accelerate
 from torch.nn import Linear


 NUM_BITS = 8


+@pytest.fixture
+def layer():
+    return Linear(4, 4)
+
+
 @pytest.mark.parametrize(
     "weights,input_activations",
     [
@@ -43,14 +49,13 @@
     ],
 )
 def test_initialize_module_for_quantization(
-    create_quantization_scheme, weights, input_activations
+    create_quantization_scheme, weights, input_activations, layer
 ):
     quantization_scheme = create_quantization_scheme(
         targets=["*"],
         weights=weights,
         input_activations=input_activations,
     )
-    layer = Linear(4, 4)

     assert not hasattr(layer, "quantization_scheme")
     assert not hasattr(layer, "quantization_status")
@@ -77,3 +82,37 @@

     assert hasattr(layer, "quantization_status")
     assert layer.quantization_status == QuantizationStatus.INITIALIZED
+
+
+@requires_accelerate()
+@pytest.mark.parametrize(
+    "weights,input_activations",
+    [
+        (
+            QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
+            None,
+        ),
+        (
+            None,
+            QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
+        ),
+        (
+            QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
+            QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
+        ),
+    ],
+)
+def test_initialize_module_for_quantization_offloaded(
+    create_quantization_scheme, weights, input_activations
+):
+    from accelerate.hooks import attach_align_device_hook
+
+    layer = Linear(4, 4)
+    attach_align_device_hook(layer, offload=True)
+
+    test_initialize_module_for_quantization(
+        create_quantization_scheme,
+        weights,
+        input_activations,
+        layer,
+    )
diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py
new file mode 100644
index 00000000..c127dd98
--- /dev/null
+++ b/tests/test_utils/test_offload.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from compressed_tensors.utils import (
+    align_module_device,
+    delete_offload_parameter,
+    disable_hf_hook,
+    has_offloaded_params,
+    register_offload_parameter,
+    update_offload_data,
+)
+from tests.testing_utils import requires_accelerate
+
+
+class ExampleModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.a = torch.nn.Parameter(torch.tensor(0).float())
+        self.b = torch.nn.Parameter(torch.tensor(0).float())
+
+    def forward(self, x):
+        return x * self.a + self.b
+
+
+@requires_accelerate()
+def test_has_offloaded_params():
+    from accelerate.big_modeling import cpu_offload_with_hook
+    from accelerate.hooks import attach_align_device_hook, remove_hook_from_module
+
+    module = ExampleModule()
+    assert not has_offloaded_params(module)
+
+    attach_align_device_hook(module, offload=False)
+    assert not has_offloaded_params(module)
+
+    remove_hook_from_module(module)
+    module, _ = cpu_offload_with_hook(module)
+    assert not has_offloaded_params(module)
+
+    remove_hook_from_module(module)
+    attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
+    assert has_offloaded_params(module)
+
+
+@requires_accelerate()
+def test_register_offload_parameter():
+    from accelerate.hooks import attach_align_device_hook
+
+    module = ExampleModule()
+    parameter = torch.nn.Parameter(torch.tensor(1.0))
+
+    # register a param prior to offloading
+    register_offload_parameter(module, "c", parameter)
+    assert hasattr(module, "c") and module.c == parameter
+
+    # offload the module; check that the previously registered param was offloaded
+    attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
+    assert "c" in module._hf_hook.weights_map
+
+    # register a param after offloading; check that it is offloaded as well
+    register_offload_parameter(module, "d", parameter)
+    assert hasattr(module, "d") and module.d.device == torch.device("meta")
+    assert "d" in module._hf_hook.weights_map
+
+    # added parameters can be onloaded and offloaded
+    with align_module_device(module, execution_device="cpu"):
+        assert module.c.device == torch.device("cpu")
+        assert module.d.device == torch.device("cpu")
+    assert module.c.device == torch.device("meta")
+    assert module.d.device == torch.device("meta")
+
+
+@requires_accelerate()
+def test_update_offload_data():
+    from accelerate.hooks import attach_align_device_hook
+
+    module = ExampleModule()
+    param_a = torch.nn.Parameter(torch.tensor(1.0))
+    param_b = torch.nn.Parameter(torch.tensor(2.0))
+
+    # can update parameters of modules which are not offloaded
+    update_offload_data(module, "a", param_a)
+    assert module.a == param_a
+
+    # can update parameters of modules which are offloaded
+    attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
+    update_offload_data(module, "b", param_b)
+    assert module.b.device == torch.device("meta")
+    assert module._hf_hook.weights_map["b"] == param_b.data
+
+    # data persists across onloading
+    with align_module_device(module, execution_device="cpu"):
+        assert module.a == param_a
+        assert module.b == param_b
+        assert module._hf_hook.weights_map["a"] == param_a.data
+        assert module._hf_hook.weights_map["b"] == param_b.data
+
+    # data persists across offloading
+    assert module.a.device == torch.device("meta")
+    assert module.b.device == torch.device("meta")
+    assert module._hf_hook.weights_map["a"] == param_a.data
+    assert module._hf_hook.weights_map["b"] == param_b.data
+
+
+@requires_accelerate()
+def test_delete_offload_parameter():
+    from accelerate.hooks import attach_align_device_hook
+
+    module = ExampleModule()
+    param_c = torch.nn.Parameter(torch.tensor(1.0))
+    param_d = torch.nn.Parameter(torch.tensor(2.0))
+    register_offload_parameter(module, "c", param_c)
+    register_offload_parameter(module, "d", param_d)
+
+    # parameters can be deleted prior to offloading
+    delete_offload_parameter(module, "a")
+    delete_offload_parameter(module, "c")
+    assert not hasattr(module, "a")
+    assert hasattr(module, "b")
+    assert not hasattr(module, "c")
+    assert hasattr(module, "d")
+
+    # parameters and their offloaded copies are deleted after offloading
+    attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
+    delete_offload_parameter(module, "b")
+    delete_offload_parameter(module, "d")
+    assert not hasattr(module, "a")
+    assert not hasattr(module, "b")
+    assert not hasattr(module, "c")
+    assert not hasattr(module, "d")
+    assert "a" not in module._hf_hook.weights_map
+    assert "b" not in module._hf_hook.weights_map
+    assert "c" not in module._hf_hook.weights_map
+    assert "d" not in module._hf_hook.weights_map
+
+
+@requires_accelerate()
+def test_disable_hf_hook():
+    from accelerate.hooks import attach_align_device_hook
+
+    module = ExampleModule()
+
+    def custom_forward():
+        pass
+
+    attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
+    with disable_hf_hook(module):
+        assert not hasattr(module, "_hf_hook")
+        module.forward = custom_forward
+
+    assert hasattr(module, "_hf_hook")
+    assert module._old_forward == custom_forward