Commit d27c8c3

Cyrilvallez authored and ArthurZucker committed
Remove HQQ from caching allocator warmup (#37347)
Update modeling_utils.py
1 parent 04c0ced commit d27c8c3

1 file changed: +2, -1


src/transformers/modeling_utils.py

Lines changed: 2 additions & 1 deletion
@@ -4622,6 +4622,7 @@ def _load_pretrained_model(
     ):
         # Useful flags
         is_quantized = hf_quantizer is not None
+        is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
         is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
             QuantizationMethod.HQQ,
             QuantizationMethod.BITS_AND_BYTES,
@@ -4787,7 +4788,7 @@ def _load_pretrained_model(
         expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)

         # Warmup cuda to load the weights much faster on devices
-        if device_map is not None:
+        if device_map is not None and not is_hqq:
             expanded_device_map = expand_device_map(device_map, expected_keys)
             caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
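
For context on what the gated call does: caching_allocator_warmup pre-reserves accelerator memory before the per-tensor weight copies begin, so torch's caching allocator can serve the many small loads from its pool instead of issuing repeated cudaMalloc calls; this commit skips that step entirely when the model is HQQ-quantized. The sketch below illustrates the general warmup pattern only, not the transformers implementation: the helper name warmup_sketch, the flat parameter-name-to-device map, and the use of factor to scale the reserved block are all assumptions for illustration.

    import torch

    def warmup_sketch(model: torch.nn.Module, device_map: dict, factor: int = 2) -> None:
        # Minimal sketch of a caching-allocator warmup; NOT the transformers
        # implementation. Assumption: `device_map` maps parameter names to
        # device identifiers (e.g. 0, 1, "cpu", "disk").
        bytes_per_device: dict = {}
        for name, param in model.named_parameters():
            device = device_map.get(name, "cpu")
            if device in ("cpu", "disk"):
                continue  # warming up only helps accelerator allocations
            size = param.numel() * param.element_size()
            bytes_per_device[device] = bytes_per_device.get(device, 0) + size
        for device, total in bytes_per_device.items():
            # Allocate one large block and release it; torch's caching allocator
            # keeps the memory in its pool, so the weight copies that follow
            # avoid repeated cudaMalloc calls. Here `factor` simply scales the
            # reserved block down; the real helper's semantics may differ.
            block = torch.empty(total // factor, dtype=torch.uint8, device=device)
            del block

With a helper like this, the patched gate reads naturally: the warmup runs only when a device map is present and the quantizer is not HQQ, e.g. `if device_map is not None and not is_hqq: warmup_sketch(model_to_load, expanded_device_map, factor=4)`.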
