set torch compile to false by default (#447)
* align auto_quantizer with main branch in Transformers

Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>

* rename torch_compile argument, set torch_compile to False by default

Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>

* Update auto_quantizer.py

* fix typos

Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>

* refine code

Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>

* fix typo and refine compile func

Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>

* fix scan issue

Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>

* fix typo

Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>

---------

Signed-off-by: Zhang, Weiwei1 <weiwei1.zhang@intel.com>
WeiweiZhang1 authored Feb 21, 2025
1 parent 3a14328 commit b650162
Showing 5 changed files with 38 additions and 40 deletions.
41 changes: 19 additions & 22 deletions auto_round/autoround.py
@@ -52,7 +52,8 @@
     mv_module_from_gpu,
     unsupport_meta_device, clear_memory,
     compile_func,
-    find_matching_blocks, is_debug_mode
+    find_matching_blocks, is_debug_mode,
+    TORCH_VERSION_AT_LEAST_2_6
 )
 from .low_cpu_mem.utils import get_layers_before_block

@@ -159,7 +160,7 @@ def __init__(
             act_dynamic: bool = True,
             to_quant_block_names: Union[str, list] = None,
             enable_norm_bias_tuning: bool = False,
-            enable_torch_compile: bool = None,
+            enable_torch_compile: bool = False,
             device_map: Union[str, dict] = None,
             **kwargs,
     ):
@@ -232,19 +233,24 @@ def __init__(
             logger.info(f"using {self.model.dtype} for quantization tuning")

         self.enable_torch_compile = enable_torch_compile
-        if self.act_bits <= 8 and self.enable_torch_compile != False:
+        if not self.enable_torch_compile and TORCH_VERSION_AT_LEAST_2_6 and self.act_bits > 8 and not is_debug_mode() \
+                and self.low_cpu_mem_usage != True and "fp8" not in self.data_type and "fp8" not in self.act_data_type:
+            logger.info("'enable_torch_compile' is set to `False` by default. " \
+                        "Enabling it can reduce tuning cost by 20%, but it might throw an exception.")
+
+        if self.act_bits <= 8 and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as activation quantization is enabled")

-        if self.low_cpu_mem_usage == True and self.enable_torch_compile != False:
+        if self.low_cpu_mem_usage == True and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled")

-        if is_debug_mode() and self.enable_torch_compile != False:
+        if is_debug_mode() and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as debug mode is enabled")

-        if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile != False:
+        if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as fp8 is enabled")

@@ -493,13 +499,8 @@ def quant_layers(self, layer_names, layer_inputs):
         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
         device = next(self.model.parameters()).device
-        if self.enable_torch_compile != False:
-            try:
-                quant_layer = compile_func(self.quant_layer, self.device, self.enable_torch_compile)
-            except:
-                logger.warning("torch compile failed, reset it to `False`")
-                self.enable_torch_compile = False
-                quant_layer = self.quant_layer
+        if self.enable_torch_compile:
+            quant_layer = compile_func(self.quant_layer, self.device)
         else:
             quant_layer = self.quant_layer
         for layer_name in layer_names:
@@ -1311,13 +1312,8 @@ def quant_blocks(
                 elif isinstance(input_others[key], list):
                     for i in range(len(input_others[key])):
                         to_dtype(input_others[key][i], tmp_dtype)
-        if self.enable_torch_compile != False:
-            try:
-                quant_block = compile_func(self.quant_block, device, self.enable_torch_compile)
-            except:
-                logger.warning("torch compile failed, reset it to `False`")
-                self.enable_torch_compile = False
-                quant_block = self.quant_block
+        if self.enable_torch_compile:
+            quant_block = compile_func(self.quant_block, device)
         else:
             quant_block = self.quant_block

@@ -1648,7 +1644,7 @@ def __init__(
             act_dynamic: bool = True,
             to_quant_block_names: Union[str, list] = None,
             enable_norm_bias_tuning: bool = False,
-            enable_torch_compile: bool = None,
+            enable_torch_compile: bool = False,
             device_map: Union[str, dict] = None,
             optimizer="AdamW",
             **kwargs,
@@ -1822,7 +1818,7 @@ def __init__(
             act_dynamic: bool = True,
             to_quant_block_names: Union[str, list] = None,
             enable_norm_bias_tuning: bool = False,
-            enable_torch_compile: bool = None,
+            enable_torch_compile: bool = False,
             device_map: Union[str, dict] = None,
             optimizer="AdamW",
             **kwargs,
@@ -1868,3 +1864,4 @@ def __init__(
             optimizer=optimizer,
             **kwargs,
         )
+
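
After this change, torch compile is strictly opt-in at the API level. A minimal usage sketch of the new default (the model/tokenizer loading and the bits/group_size arguments are assumptions for illustration, not part of this diff):

    from auto_round import AutoRound

    # enable_torch_compile now defaults to False; pass True to opt in.
    # Compiling quant_block/layer can cut tuning cost by ~20% but may throw
    # an exception, and is still forced off for act_bits <= 8,
    # low_cpu_mem_usage, debug mode, or fp8 data types.
    autoround = AutoRound(model, tokenizer, bits=4, group_size=128,
                          enable_torch_compile=True)
    autoround.quantize()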
5 changes: 3 additions & 2 deletions auto_round/mllm/autoround_mllm.py
@@ -112,7 +112,7 @@ class AutoRoundMLLM(AutoRound):
         act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
         to_quant_block_names (str|list): A string or list whose elements are list of
             block's layer names to be quantized.
-        enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True
+        enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer
         **kwargs: Additional keyword arguments.
@@ -160,7 +160,7 @@ def __init__(
             to_quant_block_names: Union[str, list] = None,
             enable_norm_bias_tuning: bool = False,
             truncation: bool = None,
-            enable_torch_compile: bool = None,
+            enable_torch_compile: bool = False,
             **kwargs,
     ):
         all_blocks = get_multimodal_block_names(model, quant_nontext_module)
@@ -410,3 +410,4 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs):
         compressed_model = super().save_quantized(
             output_dir=output_dir, format=format, inplace=inplace, processor=self.processor, **kwargs)
         return compressed_model
+
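
The multimodal entry point now shares the same opt-in default; a hedged sketch (the processor argument and the object loading are assumptions):

    from auto_round import AutoRoundMLLM

    # Same flag and default as the text-only API above.
    autoround = AutoRoundMLLM(model, tokenizer, processor=processor,
                              enable_torch_compile=True)
    autoround.quantize()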
13 changes: 7 additions & 6 deletions auto_round/script/llm.py
@@ -169,8 +169,8 @@ def __init__(self, *args, **kwargs):
             type=str,
             help="Names of quantitative blocks, please use commas to separate them.")

-        self.add_argument("--disable_torch_compile", action='store_true',
-                          help="whether to disable torch compile")
+        self.add_argument("--enable_torch_compile", action='store_true',
+                          help="whether to enable torch compile")

         self.add_argument("--act_data_type", default=None, type=str, help="activation data type")

@@ -353,9 +353,9 @@ def tune(args):
     # logger.info("`torch.use_deterministic_algorithms` is enabled by default for reproducibility "
     #             "and can be disabled using the `--disable_deterministic_algorithms` argument.")

-    if not args.disable_torch_compile:
-        logger.info("`torch.compile` is enabled by default to reduce tuning costs. "
-                    "If it causes issues, you can disable it using the `--disable_torch_compile` argument.")
+    if args.enable_torch_compile:
+        logger.info("`torch.compile` is enabled to reduce tuning costs. "
+                    "If it causes issues, you can disable it by remove `--enable_torch_compile` argument.")

     model_name = args.model
     if model_name[-1] == "/":
@@ -482,7 +482,7 @@ def tune(args):
     if not awq_supported:
         logger.warning(f"The AutoAWQ format may not be supported due to {info}")

-    enable_torch_compile = False if "--disable_torch_compile" in sys.argv else None
+    enable_torch_compile = True if "--enable_torch_compile" in sys.argv else False

     autoround = round(
         model,
@@ -621,3 +621,4 @@ def eval_sequence(args):
     for key in res_keys:
         res_all[key].update(res[key])
     print(make_table(res_all))
+
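
On the command line the flag's polarity flips to match: compile is off unless requested. A sketch, assuming the packaged `auto-round` entry point and a local model path:

    # torch.compile is now disabled unless the flag is passed explicitly
    auto-round --model /path/to/model --enable_torch_compile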
7 changes: 4 additions & 3 deletions auto_round/script/mllm.py
@@ -152,8 +152,8 @@ def __init__(self, *args, **kwargs):
             action='store_true',
             help="whether to use the iter of best mes loss in the tuning phase")

-        self.add_argument("--disable_torch_compile", action='store_true',
-                          help="whether to disable torch compile")
+        self.add_argument("--enable_torch_compile", action='store_true',
+                          help="whether to enable torch compile")

         self.add_argument("--disable_deterministic_algorithms", action='store_true',
                           help="disable torch deterministic algorithms.")
@@ -446,7 +446,7 @@ def tune(args):
     if not awq_supported:
         logger.warning(f"The AutoAWQ format may not be supported due to {info}")

-    enable_torch_compile = False if "--disable_torch_compile" in sys.argv else None
+    enable_torch_compile = True if "--enable_torch_compile" in sys.argv else False

     autoround = round(
         model,
@@ -598,3 +598,4 @@ def lmms_eval(args):
     )
     return results
+
12 changes: 5 additions & 7 deletions auto_round/utils.py
@@ -1009,18 +1009,15 @@ def compile_func_on_hpu(func):
     return func


-def compile_func_on_cuda_or_cpu(func, enable_torch_compile):
-    if enable_torch_compile or (TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE and enable_torch_compile != False):
-        return torch.compile(func)
-    else:
-        return func
+def compile_func_on_cuda_or_cpu(func):
+    return torch.compile(func)


-def compile_func(fun, device, enable_torch_compile):
+def compile_func(fun, device):
     if "hpu" in str(device):
         return compile_func_on_hpu(fun) ## use auto by default
     else:
-        return compile_func_on_cuda_or_cpu(fun, enable_torch_compile)
+        return compile_func_on_cuda_or_cpu(fun)


 def is_numba_available():  # pragma: no cover
@@ -1201,3 +1198,4 @@ def is_debug_mode():
         bool: True if debugging is enabled, False otherwise.
     """
     return sys.gettrace() is not None or sys.flags.debug == 1
+
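
With the flag removed, compile_func(fun, device) compiles unconditionally on CUDA/CPU (HPU still routes through compile_func_on_hpu), so all gating now lives at the call sites shown above. A minimal, self-contained sketch of what the CUDA/CPU path now does (the toy function is an assumption for illustration):

    import torch

    def toy_block(x):
        # stand-in for quant_block/quant_layer
        return torch.round(x * 127.0) / 127.0

    # compile_func_on_cuda_or_cpu now reduces to a bare torch.compile(func)
    compiled = torch.compile(toy_block)
    print(compiled(torch.randn(4)))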