diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index 564ba6fa..a4f171f6 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -51,6 +51,7 @@
     mv_module_from_gpu,
     unsupport_meta_device,
     detect_device_count,
     clear_memory,
     get_multimodal_block_names,
     get_library_version,
+    compile_func,
 )
 from .low_cpu_mem.utils import get_layers_before_block
@@ -392,11 +393,8 @@ def quant_layers(self, layer_names, layer_inputs):
         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
-        torch_version = get_library_version("torch")
-        if version.parse(torch_version) >= version.parse("2.5.99"):
-            quant_layer = torch.compile(self.quant_layer)
-        else:
-            quant_layer = self.quant_layer
+        device = next(self.model.parameters()).device
+        quant_layer = compile_func(self.quant_layer, device)
         for layer_name in layer_names:
             layer_input = layer_inputs[layer_name]
             layer_input = to_device(layer_input, self.cache_device)
@@ -1126,11 +1124,7 @@ def quant_blocks(
             elif isinstance(input_others[key], list):
                 for i in range(len(input_others[key])):
                     to_dtype(input_others[key][i], tmp_dtype)
-        torch_version = get_library_version("torch")
-        if version.parse(torch_version) >= version.parse("2.5.99"):
-            quant_block = torch.compile(self.quant_block)
-        else:
-            quant_block = self.quant_block
+        quant_block = compile_func(self.quant_block, device)
         pbar = tqdm(range(0, len(block_names), nblocks))
         for i in pbar:
diff --git a/auto_round/utils.py b/auto_round/utils.py
index d12d2164..e8226698 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -18,7 +18,7 @@
 import sys
 import subprocess
 from collections import UserDict
-
+import re
 # for cpu usage
 import cpuinfo
 import numpy as np
@@ -751,14 +751,10 @@ def is_autoround_exllamav2_available():
         res = False
     return res
 
-
+@lru_cache(None)
 def is_hpu_supported():  # pragma: no cover
     try:
-        import subprocess
         import habana_frameworks.torch.core as htcore  # pylint: disable=E0401
-        hqt_version = subprocess.check_output(['pip', 'show', \
-            'habana_quantization_toolkit']).decode().split('\n')[1].split(': ')[1]
-        assert (hqt_version >= "1.17")
     except ImportError as e:
         return False
     return True
@@ -859,9 +855,61 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False):
     return QuantLinear
 
 
-def clear_memory(tensor=None):
+def _clear_memory_for_cpu_and_cuda(tensor=None):
     if tensor is not None:
         del tensor
     gc.collect()
     torch.cuda.empty_cache()
+
+def clear_memory(tensor=None):
+    if is_hpu_supported():
+        # hpu does not have empty_cache
+        return
+    else:
+        _clear_memory_for_cpu_and_cuda(tensor)
+
+
+def compare_versions(v1, v2):
+    return version.parse(v1) >= version.parse(v2)
+
+
+def torch_version_at_least(version_string):
+    return compare_versions(torch.__version__, version_string)
+
+
+TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE = torch_version_at_least("2.5.99")
+TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
+TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
+TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0")
+
+
+def check_hpu_compile_mode():
+    assert (
+        os.environ["PT_HPU_LAZY_MODE"] == "0"
+    ), "Please set `PT_HPU_LAZY_MODE=0` to use HPU compile mode"
+    # Note: this is a temporary solution, will be removed in the future
+    assert (
+        os.environ["PT_ENABLE_INT64_SUPPORT"] == "1"
+    ), "Please set `PT_ENABLE_INT64_SUPPORT=1` to use HPU compile mode"
+
+
+def compile_func_on_hpu(func):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        check_hpu_compile_mode()
+        return torch.compile(func, backend="hpu_backend")
+    return func
+
+
+def compile_func_on_cuda_or_cpu(func):
+    if TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE:
+        return torch.compile(func)
+    else:
+        return func
+
+
+def compile_func(fun, device):
+    if "hpu" in str(device):
+        return compile_func_on_hpu(fun)
+    else:
+        return compile_func_on_cuda_or_cpu(fun)
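
Usage note (not part of the diff): a minimal sketch of how the new `compile_func` helper is expected to be called, mirroring the updated call sites in `quant_layers`/`quant_blocks` above. Only `compile_func(func, device)` and its import path come from this diff; the toy `torch.nn.Linear` model and `forward_fn` below are illustrative assumptions.

```python
# Sketch: device-based dispatch through compile_func, as introduced in this PR.
import torch
from auto_round.utils import compile_func

model = torch.nn.Linear(16, 16)          # hypothetical toy model
device = next(model.parameters()).device  # e.g. cpu, cuda:0, or hpu


def forward_fn(x):
    return model(x)


# On HPU devices this routes to torch.compile(..., backend="hpu_backend")
# (torch >= 2.4, with PT_HPU_LAZY_MODE=0 and PT_ENABLE_INT64_SUPPORT=1);
# on CUDA/CPU it compiles only when torch is at least a 2.6 pre-release,
# otherwise the original callable is returned unchanged.
compiled = compile_func(forward_fn, device)
out = compiled(torch.randn(2, 16))
```

The dispatch keeps the HPU-specific backend selection and environment checks inside `auto_round/utils.py`, so the quantization loops only pass the current device instead of repeating torch-version checks.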