Enable torch.compile on HPU #307

Merged · 7 commits merged on Nov 7, 2024
Changes from 2 commits
13 changes: 3 additions & 10 deletions auto_round/autoround.py
@@ -51,6 +51,7 @@
     mv_module_from_gpu,
     unsupport_meta_device, detect_device_count, clear_memory,
     get_multimodal_block_names, get_library_version,
+    compile_func,
 )
 from .low_cpu_mem.utils import get_layers_before_block

@@ -386,11 +387,7 @@ def quant_layers(self, layer_names, layer_inputs):

         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
-        torch_version = get_library_version("torch")
-        if version.parse(torch_version) >= version.parse("2.5.99"):
-            quant_layer = torch.compile(self.quant_layer)
-        else:
-            quant_layer = self.quant_layer
+        quant_layer = compile_func(self.quant_layer, self.device)
         for layer_name in layer_names:
             layer_input = layer_inputs[layer_name]
             layer_input = to_device(layer_input, self.cache_device)
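For context, compile_func (added in auto_round/utils.py below) centralizes the version- and device-gated compilation that the removed lines did inline; roughly, a sketch of what the new call expands to:

    # Sketch of compile_func's dispatch, using names from this PR:
    if "hpu" in str(self.device):
        quant_layer = torch.compile(self.quant_layer, backend="hpu_backend")  # torch >= 2.4
    else:
        quant_layer = torch.compile(self.quant_layer)  # torch >= 2.6; otherwise left uncompiled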
@@ -1110,11 +1107,7 @@ def quant_blocks(
             elif isinstance(input_others[key], list):
                 for i in range(len(input_others[key])):
                     to_dtype(input_others[key][i], tmp_dtype)
-        torch_version = get_library_version("torch")
-        if version.parse(torch_version) >= version.parse("2.5.99"):
-            quant_block = torch.compile(self.quant_block)
-        else:
-            quant_block = self.quant_block
+        quant_block = compile_func(self.quant_block, device)

        pbar = tqdm(range(0, len(block_names), nblocks))
        for i in pbar:
63 changes: 56 additions & 7 deletions auto_round/utils.py
@@ -18,7 +18,7 @@
 import sys
 import subprocess
 from collections import UserDict
-
+import re
 # for cpu usage
 import cpuinfo
 import numpy as np
@@ -751,14 +751,10 @@ def is_autoround_exllamav2_available():
         res = False
     return res


 @lru_cache(None)
 def is_hpu_supported():  # pragma: no cover
     try:
-        import subprocess
         import habana_frameworks.torch.core as htcore  # pylint: disable=E0401
-        hqt_version = subprocess.check_output(['pip', 'show', \
-            'habana_quantization_toolkit']).decode().split('\n')[1].split(': ')[1]
-        assert (hqt_version >= "1.17")
     except ImportError as e:
         return False
     return True
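The simplified check treats any importable Habana PyTorch bridge as HPU support, dropping the pip-based habana_quantization_toolkit version probe. A minimal usage sketch (htcore.mark_step() is Habana's call for flushing queued device ops; an assumption here, not part of this diff):

    if is_hpu_supported():
        import habana_frameworks.torch.core as htcore  # pylint: disable=E0401
        htcore.mark_step()  # flush ops queued for the HPU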
@@ -859,9 +855,62 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False):
     return QuantLinear


-def clear_memory(tensor=None):
+def _clear_memory_for_cpu_and_cuda(tensor=None):
     if tensor is not None:
         del tensor
     gc.collect()
     torch.cuda.empty_cache()

+def clear_memory(tensor=None):
+    if is_hpu_supported():
+        # hpu does not have empty_cache
+        return
+    else:
+        _clear_memory_for_cpu_and_cuda(tensor)
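A usage sketch for the split helpers, assuming a CUDA or CPU run (on HPU the call returns immediately; layer_inputs stands in for any large temporary):

    del layer_inputs   # drop the caller's reference so gc.collect() can actually free it
    clear_memory()     # gc.collect() + torch.cuda.empty_cache() off-HPU; no-op on HPU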


+# Copied from TorchAO
+def parse_version(version_string):
+    # Extract just the X.Y.Z part from the version string
+    match = re.match(r"(\d+\.\d+\.\d+)", version_string)
+    if match:
+        version = match.group(1)
+        return [int(x) for x in version.split(".")]
+    else:
+        raise ValueError(f"Invalid version string format: {version_string}")


+def compare_versions(v1, v2):
+    v1_parts = parse_version(v1)
+    v2_parts = parse_version(v2)
+    return (v1_parts > v2_parts) - (v1_parts < v2_parts)


+def torch_version_at_least(min_version):
+    return compare_versions(torch.__version__, min_version) >= 0
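compare_versions uses the (a > b) - (a < b) idiom to return -1, 0, or 1, and torch_version_at_least builds on it; for example:

    compare_versions("2.4.0", "2.5.0")  # -1
    compare_versions("2.5.1", "2.5.1")  #  0
    torch_version_at_least("2.4.0")     # True on, e.g., torch 2.5.1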


+TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
+TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
+TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0")


+def compile_func_on_hpu(func):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        return torch.compile(func, backend="hpu_backend")
+    return func
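The "hpu_backend" compile backend is registered when Habana's PyTorch bridge is imported, so callers are expected to have imported it already (an assumption based on Habana's documentation, not shown in this diff):

    import habana_frameworks.torch  # pylint: disable=E0401  # registers "hpu_backend"
    compiled = torch.compile(my_func, backend="hpu_backend")  # my_func: any callable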


+def compile_func_on_cuda_or_cpu(func):
+    if TORCH_VERSION_AT_LEAST_2_6:
+        return torch.compile(func)
+    else:
+        return func


+def compile_func(fun, device):
+    if "hpu" in str(device):
+        return compile_func_on_hpu(fun)
+    else:
+        return compile_func_on_cuda_or_cpu(fun)
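An end-to-end sketch of the new dispatcher (matmul is a toy stand-in; the PR applies this to quant_layer and quant_block):

    def matmul(a, b):
        return a @ b

    fast_hpu = compile_func(matmul, "hpu")     # hpu_backend on torch >= 2.4, else eager
    fast_gpu = compile_func(matmul, "cuda:0")  # inductor on torch >= 2.6, else eager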