
Commit e7e1d81

Accelerate Utilities Follow-up (#224)
1 parent 81a1eab commit e7e1d81

File tree: 4 files changed (+300, -46 lines)


src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 9 additions & 37 deletions
@@ -29,7 +29,11 @@
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import has_offloaded_params, register_offload_parameter
+from compressed_tensors.utils import (
+    disable_hf_hook,
+    has_offloaded_params,
+    register_offload_parameter,
+)
 from torch.nn import Module, Parameter

@@ -112,42 +116,10 @@ def initialize_module_for_quantization(
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED

-    offloaded = False
-    if has_offloaded_params(module):
-        try:
-            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-            from accelerate.utils import PrefixedDataset
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Offloaded model detected. To use CPU offloading with "
-                "compressed-tensors the `accelerate` package must be installed, "
-                "run `pip install compressed-tensors[accelerate]`"
-            )
-
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
-            else:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-        remove_hook_from_module(module)
-
-    # wrap forward call of module to perform
-    # quantized actions based on calltime status
-    wrap_module_forward_quantized(module, scheme)
-
-    if offloaded:
-        # we need to re-add the hook for offloading now that we've wrapped forward
-        add_hook_to_module(module, hook)
-        if prefix_dict is not None:
-            module._hf_hook.weights_map = new_prefix_dict
+    with disable_hf_hook(module):
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized(module, scheme)


 def is_attention_module(module: Module):
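
Note: the block deleted above, which manually removed the accelerate hook, rebuilt the PrefixedDataset weights map, and re-attached the hook, collapses into a single disable_hf_hook context. A minimal sketch of the pattern (not part of the commit; assumes accelerate is installed and uses a toy Linear module):

import torch
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils import disable_hf_hook, has_offloaded_params

layer = torch.nn.Linear(4, 4)
attach_align_device_hook(layer, offload=True)  # offloads weights, attaches _hf_hook
assert has_offloaded_params(layer)

with disable_hf_hook(layer):
    # the hook is detached inside the context, so module.forward can be
    # wrapped without the hook's pre/post device-alignment logic in the way
    assert not has_offloaded_params(layer)

assert has_offloaded_params(layer)  # the hook is re-attached on exit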

src/compressed_tensors/utils/offload.py

Lines changed: 87 additions & 7 deletions
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 from functools import wraps
 from typing import Any, Callable, Optional

 import torch
-from compressed_tensors.utils.helpers import getattr_chain


 try:
-    from accelerate.hooks import AlignDevicesHook
+    from accelerate.hooks import (
+        AlignDevicesHook,
+        add_hook_to_module,
+        remove_hook_from_module,
+    )
     from accelerate.utils import (
         OffloadedWeightsLoader,
         PrefixedDataset,
@@ -42,6 +46,8 @@
     "update_offload_data",
     "delete_offload_parameter",
     "has_offloaded_params",
+    "disable_hf_hook",
+    "align_module_device",
 ]


@@ -167,6 +173,7 @@ def update_offload_data(
    :param data: tensor to update parameter with
    """
    param = getattr(module, name)
+    data = data.to(param.dtype)

    # copy data into onloaded parameter if applicable
    if param.device != "meta":
@@ -178,23 +185,34 @@ def update_offload_data(

     # for upstreaming, better to add write capabilities to weight map classes first
     if isinstance(weights_map, PrefixedDataset):
-        dataset = getattr_chain(module, "module._hf_hook.weights_map.dataset", None)
+        dataset = getattr(weights_map, "dataset", None)
         if dataset is not None:
             prefix = module._hf_hook.weights_map.prefix
             key = f"{prefix}{name}"

             offload_device = (
                 dataset[key].device
                 if key in dataset
-                else next(dataset.values()).device
+                else next(iter(dataset.values())).device
             )
-            dataset[key] = param.data.to(device=offload_device)
+            dataset[key] = data.to(device=offload_device)
+
+    elif isinstance(weights_map, dict):
+        offload_device = (
+            weights_map[name].device
+            if name in weights_map
+            else next(iter(weights_map.values())).device
+        )
+        weights_map[name] = data.to(device=offload_device)

-    if isinstance(weights_map, OffloadedWeightsLoader):
+    elif isinstance(weights_map, OffloadedWeightsLoader):
         raise NotImplementedError()

     else:
-        raise NotImplementedError()
+        raise NotImplementedError(
+            "Updating offload data not implemented for weights_map of type "
+            f"{type(weights_map)}"
+        )


 def delete_offload_parameter(module: torch.nn.Module, name: str):
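
The new dict branch handles the weights map that accelerate's AlignDevicesHook builds for itself when attach_align_device_hook(module, offload=True) is called without an explicit weights_map, which is exactly how the new test below constructs its offloaded layer. A rough sketch of the behavior this enables, assuming the (module, name, data) signature implied by this hunk:

import torch
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils import update_offload_data

layer = torch.nn.Linear(4, 4)
attach_align_device_hook(layer, offload=True)  # hook's weights_map is a plain dict

# the update lands in the hook's weights_map, since the parameter itself is on meta
update_offload_data(layer, "weight", torch.zeros(4, 4))
assert torch.equal(layer._hf_hook.weights_map["weight"], torch.zeros(4, 4))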
@@ -216,6 +234,9 @@ def delete_offload_parameter(module: torch.nn.Module, name: str):
         if dataset is not None:
             del dataset[f"{prefix}{name}"]

+    elif isinstance(weights_map, dict):
+        del weights_map[name]
+
     elif isinstance(weights_map, OffloadedWeightsLoader):
         raise NotImplementedError()

@@ -225,6 +246,20 @@ def delete_offload_parameter(module: torch.nn.Module, name: str):
     )


+@check_accelerate(fallback=contextlib.nullcontext())
+@contextlib.contextmanager
+def disable_hf_hook(module: torch.nn.Module, recurse: bool = False):
+    offloaded = has_offloaded_params(module)
+    if offloaded:
+        hook = module._hf_hook
+        remove_hook_from_module(module, recurse=recurse)
+
+    yield
+
+    if offloaded:
+        add_hook_to_module(module, hook)
+
+
 """ Upstreamed Functions """

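
Both new context managers lean on the module's existing check_accelerate decorator, with contextlib.nullcontext() as the fallback, so they degrade to no-ops when accelerate is not installed. check_accelerate itself is not part of this diff; the following is a hypothetical sketch of that pattern only (the _has_accelerate flag is an assumption, not copied from the source):

import contextlib
from functools import wraps

_has_accelerate = True  # assumed to be set False when the accelerate import fails

def check_accelerate(fallback):
    # hypothetical sketch: return the fallback value instead of calling the
    # wrapped function whenever accelerate is unavailable
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not _has_accelerate:
                return fallback
            return func(*args, **kwargs)
        return wrapper
    return decorator

One observable design choice: disable_hf_hook yields outside a try/finally, so an exception raised inside the context leaves the hook detached, whereas align_module_device below restores state in finally blocks.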

@@ -247,3 +282,48 @@ def has_offloaded_params(module: torch.nn.Module) -> bool:
         and isinstance(module._hf_hook, AlignDevicesHook)
         and module._hf_hook.offload
     )
+
+
+# introduced in accelerate v1.1.0
+@check_accelerate(fallback=contextlib.nullcontext())
+@contextlib.contextmanager
+def align_module_device(
+    module: torch.nn.Module, execution_device: Optional[torch.device] = None
+):
+    """
+    Context manager that moves a module's parameters to the specified execution device.
+
+    Args:
+        module (`torch.nn.Module`):
+            Module with parameters to align.
+        execution_device (`torch.device`, *optional*):
+            If provided, overrides the module's execution device within the context.
+            Otherwise, use hook execution device or pass
+    """
+    if has_offloaded_params(module):
+        if execution_device is not None:
+            original_device = module._hf_hook.execution_device
+            module._hf_hook.execution_device = execution_device
+
+        try:
+            module._hf_hook.pre_forward(module)
+            yield
+        finally:
+            module._hf_hook.post_forward(module, None)
+            if execution_device is not None:
+                module._hf_hook.execution_device = original_device
+
+    elif execution_device is not None:
+        devices = {
+            name: param.device for name, param in module.named_parameters(recurse=False)
+        }
+        try:
+            for name in devices:
+                set_module_tensor_to_device(module, name, execution_device)
+            yield
+        finally:
+            for name, device in devices.items():
+                set_module_tensor_to_device(module, name, device)
+
+    else:
+        yield
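
align_module_device mirrors the helper upstreamed into accelerate v1.1.0, as the comment notes. A minimal usage sketch (not part of the commit; assumes accelerate is installed):

import torch
from accelerate.hooks import attach_align_device_hook
from compressed_tensors.utils import align_module_device

layer = torch.nn.Linear(4, 4)
attach_align_device_hook(layer, offload=True)
assert layer.weight.device == torch.device("meta")  # weights are offloaded

with align_module_device(layer, execution_device=torch.device("cpu")):
    # parameters are materialized on cpu for the duration of the context
    total = layer.weight.sum()

assert layer.weight.device == torch.device("meta")  # offloaded again on exit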

tests/test_quantization/lifecycle/test_initialize.py

Lines changed: 41 additions & 2 deletions
@@ -19,12 +19,18 @@
 )
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from tests.testing_utils import requires_accelerate
 from torch.nn import Linear


 NUM_BITS = 8


+@pytest.fixture
+def layer():
+    return Linear(4, 4)
+
+
 @pytest.mark.parametrize(
     "weights,input_activations",
     [
@@ -43,14 +49,13 @@
     ],
 )
 def test_initialize_module_for_quantization(
-    create_quantization_scheme, weights, input_activations
+    create_quantization_scheme, weights, input_activations, layer
 ):
     quantization_scheme = create_quantization_scheme(
         targets=["*"],
         weights=weights,
         input_activations=input_activations,
     )
-    layer = Linear(4, 4)

     assert not hasattr(layer, "quantization_scheme")
     assert not hasattr(layer, "quantization_status")
@@ -77,3 +82,37 @@ def test_initialize_module_for_quantization(
     assert hasattr(layer, "quantization_status")

     assert layer.quantization_status == QuantizationStatus.INITIALIZED
+
+
+@requires_accelerate()
+@pytest.mark.parametrize(
+    "weights,input_activations",
+    [
+        (
+            QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
+            None,
+        ),
+        (
+            None,
+            QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
+        ),
+        (
+            QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
+            QuantizationArgs(num_bits=NUM_BITS, symmetric=True),
+        ),
+    ],
+)
+def test_initialize_module_for_quantization_offloaded(
+    create_quantization_scheme, weights, input_activations
+):
+    from accelerate.hooks import attach_align_device_hook
+
+    layer = Linear(4, 4)
+    attach_align_device_hook(layer, offload=True)
+
+    test_initialize_module_for_quantization(
+        create_quantization_scheme,
+        weights,
+        input_activations,
+        layer,
+    )
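
requires_accelerate is imported from tests/testing_utils.py, presumably the fourth changed file, whose diff is not shown on this page. A plausible sketch of such a marker, assuming the standard pytest skipif pattern (the actual implementation may differ):

import importlib.util
import pytest

def requires_accelerate():
    # hypothetical sketch: skip the decorated test when accelerate is absent
    return pytest.mark.skipif(
        importlib.util.find_spec("accelerate") is None,
        reason="requires accelerate to be installed",
    )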
