diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index 7cadb49f..95efc8d1 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -52,7 +52,8 @@
     mv_module_from_gpu,
     unsupport_meta_device, clear_memory,
     compile_func,
-    find_matching_blocks, is_debug_mode
+    find_matching_blocks, is_debug_mode,
+    TORCH_VERSION_AT_LEAST_2_6
 )
 from .low_cpu_mem.utils import get_layers_before_block

@@ -159,7 +160,7 @@ def __init__(
         act_dynamic: bool = True,
         to_quant_block_names: Union[str, list] = None,
         enable_norm_bias_tuning: bool = False,
-        enable_torch_compile: bool = None,
+        enable_torch_compile: bool = False,
         device_map: Union[str, dict] = None,
         **kwargs,
     ):
@@ -232,19 +233,24 @@ def __init__(
             logger.info(f"using {self.model.dtype} for quantization tuning")

         self.enable_torch_compile = enable_torch_compile
-        if self.act_bits <= 8 and self.enable_torch_compile != False:
+        if not self.enable_torch_compile and TORCH_VERSION_AT_LEAST_2_6 and self.act_bits > 8 and not is_debug_mode() \
+                and self.low_cpu_mem_usage != True and "fp8" not in self.data_type and "fp8" not in self.act_data_type:
+            logger.info("'enable_torch_compile' is set to `False` by default. " \
+                        "Enabling it can reduce tuning cost by 20%, but it might throw an exception.")
+
+        if self.act_bits <= 8 and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as activation quantization is enabled")

-        if self.low_cpu_mem_usage == True and self.enable_torch_compile != False:
+        if self.low_cpu_mem_usage == True and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled")

-        if is_debug_mode() and self.enable_torch_compile != False:
+        if is_debug_mode() and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as debug mode is enabled")

-        if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile != False:
+        if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as fp8 is enabled")

@@ -493,13 +499,8 @@ def quant_layers(self, layer_names, layer_inputs):
         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
         device = next(self.model.parameters()).device
-        if self.enable_torch_compile != False:
-            try:
-                quant_layer = compile_func(self.quant_layer, self.device, self.enable_torch_compile)
-            except:
-                logger.warning("torch compile failed, reset it to `False`")
-                self.enable_torch_compile = False
-                quant_layer = self.quant_layer
+        if self.enable_torch_compile:
+            quant_layer = compile_func(self.quant_layer, self.device)
         else:
             quant_layer = self.quant_layer
         for layer_name in layer_names:
@@ -1311,13 +1312,8 @@ def quant_blocks(
             elif isinstance(input_others[key], list):
                 for i in range(len(input_others[key])):
                     to_dtype(input_others[key][i], tmp_dtype)
-        if self.enable_torch_compile != False:
-            try:
-                quant_block = compile_func(self.quant_block, device, self.enable_torch_compile)
-            except:
-                logger.warning("torch compile failed, reset it to `False`")
-                self.enable_torch_compile = False
-                quant_block = self.quant_block
+        if self.enable_torch_compile:
+            quant_block = compile_func(self.quant_block, device)
         else:
             quant_block = self.quant_block

@@ -1648,7 +1644,7 @@ def __init__(
         act_dynamic: bool = True,
         to_quant_block_names: Union[str, list] = None,
         enable_norm_bias_tuning: bool = False,
-        enable_torch_compile: bool = None,
+        enable_torch_compile: bool = False,
         device_map: Union[str, dict] = None,
         optimizer="AdamW",
         **kwargs,
@@ -1822,7 +1818,7 @@ def __init__(
         act_dynamic: bool = True,
         to_quant_block_names: Union[str, list] = None,
         enable_norm_bias_tuning: bool = False,
-        enable_torch_compile: bool = None,
+        enable_torch_compile: bool = False,
         device_map: Union[str, dict] = None,
         optimizer="AdamW",
         **kwargs,
@@ -1868,3 +1864,4 @@ def __init__(
             optimizer=optimizer,
             **kwargs,
         )
+
diff --git a/auto_round/mllm/autoround_mllm.py b/auto_round/mllm/autoround_mllm.py
index 429e914a..1427df4d 100644
--- a/auto_round/mllm/autoround_mllm.py
+++ b/auto_round/mllm/autoround_mllm.py
@@ -112,7 +112,7 @@ class AutoRoundMLLM(AutoRound):
         act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
         to_quant_block_names (str|list): A string or list whose elements are list of
                             block's layer names to be quantized.
-        enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True
+        enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer
         **kwargs: Additional keyword arguments.

@@ -160,7 +160,7 @@ def __init__(
         to_quant_block_names: Union[str, list] = None,
         enable_norm_bias_tuning: bool = False,
         truncation: bool = None,
-        enable_torch_compile: bool = None,
+        enable_torch_compile: bool = False,
         **kwargs,
     ):
         all_blocks = get_multimodal_block_names(model, quant_nontext_module)
@@ -410,3 +410,4 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
         compressed_model = super().save_quantized(
             output_dir=output_dir, format=format, inplace=inplace, processor=self.processor, **kwargs)
         return compressed_model
+
diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py
index 98ab958d..0b0498cd 100644
--- a/auto_round/script/llm.py
+++ b/auto_round/script/llm.py
@@ -169,8 +169,8 @@ def __init__(self, *args, **kwargs):
                           type=str,
                           help="Names of quantitative blocks, please use commas to separate them.")

-        self.add_argument("--disable_torch_compile", action='store_true',
-                          help="whether to disable torch compile")
+        self.add_argument("--enable_torch_compile", action='store_true',
+                          help="whether to enable torch compile")

         self.add_argument("--act_data_type", default=None, type=str, help="activation data type")
@@ -353,9 +353,9 @@ def tune(args):
     # logger.info("`torch.use_deterministic_algorithms` is enabled by default for reproducibility "
     #             "and can be disabled using the `--disable_deterministic_algorithms` argument.")

-    if not args.disable_torch_compile:
-        logger.info("`torch.compile` is enabled by default to reduce tuning costs. "
-                    "If it causes issues, you can disable it using the `--disable_torch_compile` argument.")
+    if args.enable_torch_compile:
+        logger.info("`torch.compile` is enabled to reduce tuning costs. "
+                    "If it causes issues, you can disable it by removing the `--enable_torch_compile` argument.")

     model_name = args.model
     if model_name[-1] == "/":
@@ -482,7 +482,7 @@ def tune(args):
     if not awq_supported:
         logger.warning(f"The AutoAWQ format may not be supported due to {info}")

-    enable_torch_compile = False if "--disable_torch_compile" in sys.argv else None
+    enable_torch_compile = True if "--enable_torch_compile" in sys.argv else False

     autoround = round(
         model,
@@ -621,3 +621,4 @@ def eval_sequence(args):
             for key in res_keys:
                 res_all[key].update(res[key])
         print(make_table(res_all))
+
diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py
index 0080bd65..3b00d74e 100644
--- a/auto_round/script/mllm.py
+++ b/auto_round/script/mllm.py
@@ -152,8 +152,8 @@ def __init__(self, *args, **kwargs):
                           action='store_true',
                           help="whether to use the iter of best mes loss in the tuning phase")

-        self.add_argument("--disable_torch_compile", action='store_true',
-                          help="whether to disable torch compile")
+        self.add_argument("--enable_torch_compile", action='store_true',
+                          help="whether to enable torch compile")

         self.add_argument("--disable_deterministic_algorithms", action='store_true',
                           help="disable torch deterministic algorithms.")
@@ -446,7 +446,7 @@ def tune(args):
     if not awq_supported:
         logger.warning(f"The AutoAWQ format may not be supported due to {info}")

-    enable_torch_compile = False if "--disable_torch_compile" in sys.argv else None
+    enable_torch_compile = True if "--enable_torch_compile" in sys.argv else False

     autoround = round(
         model,
@@ -598,3 +598,4 @@ def lmms_eval(args):
     )

     return results
+
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 4c540654..f863cd3a 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -1009,18 +1009,15 @@ def compile_func_on_hpu(func):
     return func


-def compile_func_on_cuda_or_cpu(func, enable_torch_compile):
-    if enable_torch_compile or (TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE and enable_torch_compile != False):
-        return torch.compile(func)
-    else:
-        return func
+def compile_func_on_cuda_or_cpu(func):
+    return torch.compile(func)


-def compile_func(fun, device, enable_torch_compile):
+def compile_func(fun, device):
     if "hpu" in str(device):
         return compile_func_on_hpu(fun)  ## use auto by default
     else:
-        return compile_func_on_cuda_or_cpu(fun, enable_torch_compile)
+        return compile_func_on_cuda_or_cpu(fun)


 def is_numba_available():  # pragma: no cover
@@ -1201,3 +1198,4 @@ def is_debug_mode():
         bool: True if debugging is enabled, False otherwise.
     """
     return sys.gettrace() is not None or sys.flags.debug == 1
+