Skip to content

Commit

Permalink
[HPU] Add lazy mode back (#371)
Browse files Browse the repository at this point in the history
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: Yi Liu <yiliu4@habana.ai>
Co-authored-by: Yi Liu <yiliu4@habana.ai>
  • Loading branch information
yiliu30 and Yi4Liu authored Dec 5, 2024
1 parent 67281e2 commit 7acb784
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 11 deletions.
11 changes: 8 additions & 3 deletions .azure-pipelines/scripts/ut/run_ut_hpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@ LOG_DIR=/auto-round/log_dir
mkdir -p ${LOG_DIR}
ut_log_name=${LOG_DIR}/ut.log

find . -name "test*hpu_only.py" | sed "s,\.\/,python -m pytest --cov=\"${auto_round_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run.sh
cat run.sh
bash run.sh 2>&1 | tee ${ut_log_name}
find . -name "test*hpu_only.py" | sed "s,\.\/,python -m pytest --cov=\"${auto_round_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run_lazy.sh
find . -name "test*hpu_only.py" | sed "s,\.\/,python -m pytest --mode compile --cov=\"${auto_round_path}\" --cov-report term --html=report.html --self-contained-html --cov-report xml:coverage.xml --cov-append -vs --disable-warnings ,g" > run_compile.sh

cat run_lazy.sh
bash run_lazy.sh 2>&1 | tee ${ut_log_name}

cat run_compile.sh
bash run_compile.sh 2>&1 | tee ${ut_log_name}

cp report.html ${LOG_DIR}/
cp coverage.xml ${LOG_DIR}/
Expand Down
29 changes: 22 additions & 7 deletions auto_round/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def extract_block_names_to_str(quant_block_list):
prefixes = [get_common_prefix(blocks) for blocks in quant_block_list]
# Join prefixes into a single string
return ','.join(prefixes)


def find_matching_blocks(model, all_blocks, to_quant_block_names):
"""
Expand Down Expand Up @@ -966,20 +966,36 @@ def torch_version_at_least(version_string):
TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0")

# Note on HPU usage:
# There are two modes available for enabling auto-round on HPU:
# 1. Compile Mode
# 1) Use PyTorch version ≥ 2.4 (Intel® Gaudi® v1.18 or later)
# 2) Set `PT_HPU_LAZY_MODE=0` and `PT_ENABLE_INT64_SUPPORT=1`
# The compile mode can speed up quantization process but still in experimental stage.
# 2. Lazy Mode (By default)

def check_hpu_compile_mode():

def _check_hpu_compile_mode():
assert (
os.getenv("PT_HPU_LAZY_MODE") == "0"
os.getenv("PT_HPU_LAZY_MODE") == "0"
), "Please set `PT_HPU_LAZY_MODE=0` to use HPU compile mode"
# Note: this is a temporary solution, will be removed in the future
assert (
os.getenv("PT_ENABLE_INT64_SUPPORT") == "1"
os.getenv("PT_ENABLE_INT64_SUPPORT") == "1"
), "Please set `PT_ENABLE_INT64_SUPPORT=1` to use HPU compile mode"


def is_hpu_lazy_mode():
return os.getenv("PT_HPU_LAZY_MODE") != "0"


def _use_hpu_compile_mode():
return TORCH_VERSION_AT_LEAST_2_4 and not is_hpu_lazy_mode()


def compile_func_on_hpu(func):
if TORCH_VERSION_AT_LEAST_2_4:
check_hpu_compile_mode()
if _use_hpu_compile_mode():
_check_hpu_compile_mode()
return torch.compile(func, backend="hpu_backend")
return func

Expand Down Expand Up @@ -1097,4 +1113,3 @@ def get_fp_layer_names(model, fp_layers):
not_to_quantized_layers.append(name)

return not_to_quantized_layers

9 changes: 9 additions & 0 deletions test/_test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest


def is_pytest_mode_compile():
return pytest.mode == "compile"


def is_pytest_mode_lazy():
return pytest.mode == "lazy"
34 changes: 34 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
from typing import Mapping

import pytest


def pytest_addoption(parser):
parser.addoption(
"--mode",
action="store",
default="lazy",
help="{compile|lazy}, default lazy. Choose mode to run tests",
)


backup_env = pytest.StashKey[Mapping]()


def pytest_configure(config):
pytest.mode = config.getoption("--mode")
assert pytest.mode.lower() in ["lazy", "compile"]

config.stash[backup_env] = os.environ

if pytest.mode == "lazy":
os.environ["PT_HPU_LAZY_MODE"] = "1"
elif pytest.mode == "compile":
os.environ["PT_HPU_LAZY_MODE"] = "0"
os.environ["PT_ENABLE_INT64_SUPPORT"] = "1"


def pytest_unconfigure(config):
os.environ.clear()
os.environ.update(config.stash[backup_env])
45 changes: 44 additions & 1 deletion test/test_auto_round_hpu_only.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,46 @@
import pytest
import torch
from auto_round.utils import is_hpu_supported

from _test_helpers import is_pytest_mode_compile, is_pytest_mode_lazy


def run_opt_125m_on_hpu():
from auto_round import AutoRound
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

bits, group_size, sym = 4, 128, False
autoround = AutoRound(
model,
tokenizer,
bits=bits,
group_size=group_size,
sym=sym,
iters=2,
seqlen=2,
)
q_model, qconfig = autoround.quantize()
assert q_model is not None, f"Expected q_model to be not None"


@pytest.mark.skipif(not is_hpu_supported(), reason="HPU is not supported")
@pytest.mark.skipif(not is_pytest_mode_lazy(), reason="Only for lazy mode")
def test_opt_125m_lazy_mode():
run_opt_125m_on_hpu()


@pytest.mark.skipif(not is_hpu_supported(), reason="HPU is not supported")
@pytest.mark.skipif(not is_pytest_mode_compile(), reason="Only for compile mode")
def test_opt_125m_compile_mode():
torch._dynamo.reset()
run_opt_125m_on_hpu()


def test_import():
from auto_round import AutoRound
from auto_round.export.export_to_itrex.export import save_quantized_as_itrex, WeightOnlyLinear
from auto_round.export.export_to_itrex.export import (
WeightOnlyLinear, save_quantized_as_itrex)

0 comments on commit 7acb784

Please sign in to comment.