Accelerate Utilities Follow-up #224

Merged
46 changes: 9 additions & 37 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -29,7 +29,11 @@
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
from compressed_tensors.utils import has_offloaded_params, register_offload_parameter
from compressed_tensors.utils import (
disable_hf_hook,
has_offloaded_params,
register_offload_parameter,
)
from torch.nn import Module, Parameter


@@ -112,42 +116,10 @@ def initialize_module_for_quantization(
module.quantization_scheme = scheme
module.quantization_status = QuantizationStatus.INITIALIZED

offloaded = False
if has_offloaded_params(module):
try:
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import PrefixedDataset
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Offloaded model detected. To use CPU offloading with "
"compressed-tensors the `accelerate` package must be installed, "
"run `pip install compressed-tensors[accelerate]`"
)

offloaded = True
hook = module._hf_hook
prefix_dict = module._hf_hook.weights_map
new_prefix = {}

# recreate the prefix dict (since it is immutable)
# and add quantization parameters
for key, data in module.named_parameters():
if key not in prefix_dict:
new_prefix[f"{prefix_dict.prefix}{key}"] = data
else:
new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
remove_hook_from_module(module)

# wrap forward call of module to perform
# quantized actions based on calltime status
wrap_module_forward_quantized(module, scheme)

if offloaded:
# we need to re-add the hook for offloading now that we've wrapped forward
add_hook_to_module(module, hook)
if prefix_dict is not None:
module._hf_hook.weights_map = new_prefix_dict
with disable_hf_hook(module):
# wrap forward call of module to perform
# quantized actions based on calltime status
wrap_module_forward_quantized(module, scheme)


def is_attention_module(module: Module):
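The hunk above swaps the hand-rolled hook removal and re-attachment for the new disable_hf_hook context manager added to offload.py below. A minimal usage sketch, assuming accelerate is installed; the Linear layer and the forward re-wrap are illustrative stand-ins, not code from this diff:

import torch
from torch.nn import Linear
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils import disable_hf_hook

layer = Linear(4, 4)
attach_align_device_hook(layer, offload=True)  # parameters move to the meta device

with disable_hf_hook(layer):
    # inside the context the accelerate hook is detached, so forward can be
    # re-wrapped (as wrap_module_forward_quantized does) without also capturing
    # the hook's device-alignment wrapper
    assert not hasattr(layer, "_hf_hook")
    original_forward = layer.forward
    layer.forward = lambda x: torch.relu(original_forward(x))

# on exit the saved hook is re-attached around the newly wrapped forward
assert hasattr(layer, "_hf_hook")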
94 changes: 87 additions & 7 deletions src/compressed_tensors/utils/offload.py
@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from functools import wraps
from typing import Any, Callable, Optional

import torch
from compressed_tensors.utils.helpers import getattr_chain


try:
from accelerate.hooks import AlignDevicesHook
from accelerate.hooks import (
AlignDevicesHook,
add_hook_to_module,
remove_hook_from_module,
)
from accelerate.utils import (
OffloadedWeightsLoader,
PrefixedDataset,
@@ -42,6 +46,8 @@
"update_offload_data",
"delete_offload_parameter",
"has_offloaded_params",
"disable_hf_hook",
"align_module_device",
]


@@ -167,6 +173,7 @@ def update_offload_data(
:param data: tensor to update parameter with
"""
param = getattr(module, name)
data = data.to(param.dtype)

# copy data into onloaded parameter if applicable
if param.device != "meta":
@@ -178,23 +185,34 @@

# for upstreaming, better to add write capabilities to weight map classes first
if isinstance(weights_map, PrefixedDataset):
dataset = getattr_chain(module, "module._hf_hook.weights_map.dataset", None)
dataset = getattr(weights_map, "dataset", None)
if dataset is not None:
prefix = module._hf_hook.weights_map.prefix
key = f"{prefix}{name}"

offload_device = (
dataset[key].device
if key in dataset
else next(dataset.values()).device
else next(iter(dataset.values())).device
)
dataset[key] = param.data.to(device=offload_device)
dataset[key] = data.to(device=offload_device)

elif isinstance(weights_map, dict):
offload_device = (
weights_map[name].device
if name in weights_map
else next(iter(weights_map.values())).device
)
weights_map[name] = data.to(device=offload_device)

if isinstance(weights_map, OffloadedWeightsLoader):
elif isinstance(weights_map, OffloadedWeightsLoader):
raise NotImplementedError()

else:
raise NotImplementedError()
raise NotImplementedError(
"Updating offload data not implemented for weights_map of type "
f"{type(weights_map)}"
)


def delete_offload_parameter(module: torch.nn.Module, name: str):
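The update_offload_data changes above add a dtype cast and a branch for plain dict weights maps (the kind attach_align_device_hook builds when no weights_map is supplied). A hedged usage sketch; the three-argument signature update_offload_data(module, name, data) is inferred from the body shown, and the layer setup is illustrative:

import torch
from torch.nn import Linear
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils import update_offload_data

layer = Linear(4, 4)
attach_align_device_hook(layer, offload=True)  # hook stores weights in a plain dict

# the new value is cast to the parameter dtype and written back into the offload
# dict on the device that map already uses (CPU here)
update_offload_data(layer, "weight", torch.zeros(4, 4, dtype=torch.float64))
print(layer._hf_hook.weights_map["weight"].dtype)  # expected: torch.float32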
@@ -216,6 +234,9 @@ def delete_offload_parameter(module: torch.nn.Module, name: str):
if dataset is not None:
del dataset[f"{prefix}{name}"]

elif isinstance(weights_map, dict):
del weights_map[name]

elif isinstance(weights_map, OffloadedWeightsLoader):
raise NotImplementedError()

@@ -225,6 +246,20 @@ def delete_offload_parameter(module: torch.nn.Module, name: str):
)


@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def disable_hf_hook(module: torch.nn.Module, recurse: bool = False):
offloaded = has_offloaded_params(module)
if offloaded:
hook = module._hf_hook
remove_hook_from_module(module, recurse=recurse)

yield

if offloaded:
add_hook_to_module(module, hook)


""" Upstreamed Functions """


@@ -247,3 +282,48 @@ def has_offloaded_params(module: torch.nn.Module) -> bool:
and isinstance(module._hf_hook, AlignDevicesHook)
and module._hf_hook.offload
)


# introduced in accelerate v1.1.0
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def align_module_device(
module: torch.nn.Module, execution_device: Optional[torch.device] = None
):
"""
Context manager that moves a module's parameters to the specified execution device.

Args:
module (`torch.nn.Module`):
Module with parameters to align.
execution_device (`torch.device`, *optional*):
If provided, overrides the module's execution device within the context.
Otherwise, use the hook's execution device, or leave parameters in place if the module is not offloaded.
"""
if has_offloaded_params(module):
if execution_device is not None:
original_device = module._hf_hook.execution_device
module._hf_hook.execution_device = execution_device

try:
module._hf_hook.pre_forward(module)
yield
finally:
module._hf_hook.post_forward(module, None)
if execution_device is not None:
module._hf_hook.execution_device = original_device

elif execution_device is not None:
devices = {
name: param.device for name, param in module.named_parameters(recurse=False)
}
try:
for name in devices:
set_module_tensor_to_device(module, name, execution_device)
yield
finally:
for name, device in devices.items():
set_module_tensor_to_device(module, name, device)

else:
yield
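align_module_device mirrors the helper introduced in accelerate v1.1.0: it temporarily onloads an offloaded module's parameters so they can be read or used on an execution device. A short sketch under the same assumptions as above (illustrative layer, CPU execution device):

import torch
from torch.nn import Linear
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils import align_module_device

layer = Linear(4, 4)
attach_align_device_hook(layer, offload=True)
print(layer.weight.device)  # meta while offloaded

with align_module_device(layer, execution_device=torch.device("cpu")):
    # parameters are materialized on the requested device, e.g. for computing
    # quantization statistics during calibration
    scale = layer.weight.abs().max() / 127

print(layer.weight.device)  # back to meta once the context exits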
43 changes: 41 additions & 2 deletions tests/test_quantization/lifecycle/test_initialize.py
@@ -19,12 +19,18 @@
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationStatus
from tests.testing_utils import requires_accelerate
from torch.nn import Linear


NUM_BITS = 8


@pytest.fixture
def layer():
return Linear(4, 4)


@pytest.mark.parametrize(
"weights,input_activations",
[
@@ -43,14 +49,13 @@
],
)
def test_initialize_module_for_quantization(
create_quantization_scheme, weights, input_activations
create_quantization_scheme, weights, input_activations, layer
):
quantization_scheme = create_quantization_scheme(
targets=["*"],
weights=weights,
input_activations=input_activations,
)
layer = Linear(4, 4)

assert not hasattr(layer, "quantization_scheme")
assert not hasattr(layer, "quantization_status")
@@ -77,3 +82,37 @@ def test_initialize_module_for_quantization(
assert hasattr(layer, "quantization_status")

assert layer.quantization_status == QuantizationStatus.INITIALIZED


@requires_accelerate()
@pytest.mark.parametrize(
"weights,input_activations",
[
(
QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
None,
),
(
None,
QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
),
(
QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
),
],
)
def test_initialize_module_for_quantization_offloaded(
create_quantization_scheme, weights, input_activations
):
from accelerate.hooks import attach_align_device_hook

layer = Linear(4, 4)
attach_align_device_hook(layer, offload=True)

test_initialize_module_for_quantization(
create_quantization_scheme,
weights,
input_activations,
layer,
)
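requires_accelerate is imported from tests.testing_utils, which is outside this diff. A plausible, purely hypothetical sketch of such a marker (the real helper may differ):

import importlib.util

import pytest


def requires_accelerate():
    # skip the decorated test when the accelerate package is not installed
    return pytest.mark.skipif(
        importlib.util.find_spec("accelerate") is None,
        reason="requires accelerate",
    )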