diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index 36166633..0bf66177 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -160,9 +160,8 @@ def __init__(
         self.seed = seed
         set_seed(self.seed)
         assert not unsupport_meta_device(model), (
-            "autoround does not support for params on meta device by transformers` interfaces,"
-            "please do not using device_map='auto' in model loading, "
-            "or follow examples/language-modeling/main.py to enable low_cpu_mem_usage")
+            "AutoRound does not support params on meta device."
+            " Please use more GPUs by setting `--device 0,1,2,3`, or just use one GPU.")
 
         ## important tuning hype-parameters
         self.amp = amp
@@ -283,12 +282,12 @@ def quantize(self):
         layer_names = self.get_quantized_layer_names_outside_blocks()
         self.start_time = time.time()
         all_first_block_names = [block[0] for block in all_blocks]
-        logger.info("start calibration")
+        logger.info("start to cache block inputs")
         all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names)
         if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
             accelerate.hooks.remove_hook_from_submodules(self.model)  ##self.model.hf_device_map has not been changed
         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
-        logger.info("calibration done")
+        logger.info("caching done")
         for block_names in all_blocks:
             inputs = all_inputs[block_names[0]]
             all_inputs.pop(block_names[0])
@@ -615,8 +614,8 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
                 clear_memory()
         except RuntimeError as e:
             if "CUDA out of memory" in str(e):
-                logger.info("switch to cpu to cache inputs")
-                if (("lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] <= 16) or
+                logger.info("switch to cpu to cache block inputs")
+                if (("lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] < 16) or
                         self.__class__.__name__ == "AutoRoundMLLM"):
                     logger.warning(f"we strongly recommend using additional CUDA/HPU devices,e.g. "
                                    f"set `--device '0,1'` in our cmd line usage or "
diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py
index eb137b33..d153cb07 100644
--- a/auto_round/script/llm.py
+++ b/auto_round/script/llm.py
@@ -143,6 +143,8 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--not_use_best_mse", action='store_true',
                           help="whether to use the iter of best mes loss in the tuning phase")
+        self.add_argument("--enable_torch_compile", default=None, type=bool,
+                          help="whether to enable torch compile")
 
 
 def setup_parser():
     parser = BasicArgumentParser()
@@ -208,6 +210,7 @@ def setup_fast_parser():
     parser.add_argument("--nsamples", default=128, type=int,
                         help="number of samples")
+
 
     args = parser.parse_args()
 
     return args
@@ -254,7 +257,8 @@ def tune(args):
         os.environ["CUDA_VISIBLE_DEVICES"] = args.device
         args.device = ",".join(map(str, range(len(devices))))
         devices = args.device.replace(" ", "").split(',')
-        use_auto_mapping = True
+        if len(devices) > 1:  ## for a 70B model on a single card, "auto" mapping would offload some layers to cpu
+            use_auto_mapping = True
 
     import re
     import torch
@@ -328,10 +332,14 @@ def tune(args):
         seqlen = 2048
 
     if args.model_dtype != None:
-        if args.model_dtype == "float16" or args.model_dtype == "fp16":
-            model = model.to(torch.float16)
-        if args.model_dtype == "bfloat16" or args.model_dtype == "bfp16":
-            model = model.to(torch.bfloat16)
+        try:
+            if args.model_dtype == "float16" or args.model_dtype == "fp16":
+                model = model.to(torch.float16)
+            if args.model_dtype == "bfloat16" or args.model_dtype == "bfp16":
+                model = model.to(torch.bfloat16)
+        except Exception:
+            logger.error("please use more devices, e.g. `--device 0,1,2,3`, to fit the model, or just use one device")
+            exit()
 
     if hasattr(tokenizer, "model_max_length"):
         if tokenizer.model_max_length < seqlen:
@@ -395,7 +403,8 @@ def tune(args):
                           gradient_accumulate_steps=args.gradient_accumulate_steps, layer_config=layer_config,
                           enable_minmax_tuning=not args.disable_minmax_tuning, act_bits=args.act_bits,
                           low_cpu_mem_usage=low_cpu_mem_usage, data_type=args.data_type,
-                          enable_norm_bias_tuning=args.enable_norm_bias_tuning, not_use_best_mse=args.not_use_best_mse)
+                          enable_norm_bias_tuning=args.enable_norm_bias_tuning, not_use_best_mse=args.not_use_best_mse,
+                          enable_torch_compile=args.enable_torch_compile)
     model, _ = autoround.quantize()
     model_name = args.model.rstrip("/")
     if args.low_cpu_mem_mode == 1 or args.low_cpu_mem_mode == 2:
@@ -473,7 +482,8 @@ def eval(args):
         os.environ["CUDA_VISIBLE_DEVICES"] = args.device
         args.device = ",".join(map(str, range(len(devices))))
         devices = args.device.replace(" ", "").split(',')
-        parallelism = True
+        if len(devices) > 1:
+            parallelism = True
         device_str = None
     else:
         device_str = detect_device(args.device.replace(" ", ""))
diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py
index d3337468..5e18b932 100644
--- a/auto_round/script/mllm.py
+++ b/auto_round/script/mllm.py
@@ -134,6 +134,9 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--not_use_best_mse", action='store_true',
                           help="whether to use the iter of best mes loss in the tuning phase")
+        self.add_argument("--enable_torch_compile", default=None, type=bool,
+                          help="whether to enable torch compile")
+
 
         ## ======================= VLM =======================
         self.add_argument("--quant_nontext_module", action='store_true',
                           help="whether to quantize non-text module, e.g. vision component")
@@ -337,7 +340,8 @@ def tune(args):
         device=device_str, seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps,
         scale_dtype=args.scale_dtype, layer_config=layer_config, template=args.template,
         enable_minmax_tuning=not args.disable_minmax_tuning, act_bits=args.act_bits,
-        quant_nontext_module=args.quant_nontext_module, not_use_best_mse=args.not_use_best_mse)
+        quant_nontext_module=args.quant_nontext_module, not_use_best_mse=args.not_use_best_mse,
+        enable_torch_compile=args.enable_torch_compile)
     model, _ = autoround.quantize()
 
     model.eval()
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 4037a19b..4a04ff18 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -906,7 +906,7 @@ def compile_func_on_hpu(func):
 
 
 def compile_func_on_cuda_or_cpu(func, enable_torch_compile):
-    if enable_torch_compile or TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE:
+    if enable_torch_compile or (TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE and enable_torch_compile != False):
         return torch.compile(func)
     else:
         return func
diff --git a/examples/language-modeling/main.py b/examples/language-modeling/main.py
index 4a3ca875..f0594f5d 100644
--- a/examples/language-modeling/main.py
+++ b/examples/language-modeling/main.py
@@ -130,6 +130,9 @@
     parser.add_argument("--fp_layers", default="", type=str,
                         help="List of Layers to maintain original data type")
+    parser.add_argument("--enable_torch_compile", default=None, type=bool,
+                        help="whether to enable torch compile")
+
 
     args = parser.parse_args()
 
     print(
@@ -151,7 +154,7 @@
     if args.format is None:
         args.format = "auto_round"
 
-    supported_formats = ["auto_round", "auto_gptq", "auto_awq", "auto_round:auto_gptq","auto_round:auto_awq",
+    supported_formats = ["auto_round", "auto_gptq", "auto_awq", "auto_round:auto_gptq", "auto_round:auto_awq",
                          "auto_gptq:marlin", "itrex", "iterx_xpu", "fake"]
     formats = args.format.replace(' ', '').split(",")
     for format in formats:
@@ -176,7 +179,8 @@
         os.environ["CUDA_VISIBLE_DEVICES"] = args.device
         args.device = ",".join(map(str, range(len(devices))))
         devices = args.device.replace(" ", "").split(',')
-        use_auto_mapping = True
+        if len(devices) > 1:
+            use_auto_mapping = True
 
     import torch
     import transformers
@@ -260,10 +264,14 @@
         seqlen = args.seqlen
 
     if args.model_dtype != None:
-        if args.model_dtype == "float16" or args.model_dtype == "fp16":
-            model = model.to(torch.float16)
-        if args.model_dtype == "bfloat16" or args.model_dtype == "bfp16":
-            model = model.to(torch.bfloat16)
+        try:
+            if args.model_dtype == "float16" or args.model_dtype == "fp16":
+                model = model.to(torch.float16)
+            if args.model_dtype == "bfloat16" or args.model_dtype == "bfp16":
+                model = model.to(torch.bfloat16)
+        except Exception:
+            print("please use more devices, e.g. `--device 0,1,2,3`, to fit the model, or just use one device")
+            exit()
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=not args.disable_trust_remote_code)
@@ -340,7 +348,8 @@
                           scale_dtype=args.scale_dtype, layer_config=layer_config,
                           enable_minmax_tuning=not args.disable_minmax_tuning, act_bits=args.act_bits,
                           low_cpu_mem_usage=low_cpu_mem_usage, data_type=args.data_type,
-                          enable_norm_bias_tuning=args.enable_norm_bias_tuning)
+                          enable_norm_bias_tuning=args.enable_norm_bias_tuning,
+                          enable_torch_compile=args.enable_torch_compile)
     model, _ = autoround.quantize()
     model_name = args.model_name.rstrip("/")
     if args.low_cpu_mem_mode == 1 or args.low_cpu_mem_mode == 2:
diff --git a/test/test_mllm.py b/test/test_mllm.py
index ce8de559..616233f2 100644
--- a/test/test_mllm.py
+++ b/test/test_mllm.py
@@ -48,8 +48,8 @@ def test_tune(self):
         bits, group_size = 4, 128
         autoround = AutoRoundMLLM(
             model, tokenizer, bits=bits, group_size=group_size,
-            nsamples=2,
-            batch_size=1, iters=2, dataset=self.dataset)
+            nsamples=1,
+            batch_size=1, iters=2, dataset=self.dataset, seqlen=256)
         autoround.quantize()
         autoround.save_quantized("./saved/", format="auto_gptq", inplace=False)
         autoround.save_quantized("./saved/", format="auto_round", inplace=False)
@@ -63,8 +63,8 @@ def test_quant_vision(self):  ## bug need to fix
         bits, group_size = 4, 128
         autoround = AutoRoundMLLM(
             model, tokenizer, bits=bits, group_size=group_size,
-            nsamples=2,
-            batch_size=1, iters=2, dataset=self.dataset, quant_nontext_module=True)
+            nsamples=1,
+            batch_size=1, iters=2, dataset=self.dataset, quant_nontext_module=True, seqlen=256)
         autoround.quantize()
         autoround.save_quantized("./saved/", format="auto_round", inplace=True)
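
The change to compile_func_on_cuda_or_cpu above makes enable_torch_compile a tri-state switch: True forces torch.compile, False disables it, and None (the new CLI default) compiles only when the installed torch is a 2.6 pre-release or newer. Below is a minimal, self-contained sketch of that gate; the packaging-based version check is an illustrative stand-in for the library's TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE constant, and the helper name maybe_compile is hypothetical.

import torch
from packaging.version import Version


def _torch_at_least_2_6() -> bool:
    # Drop any local build suffix (e.g. "2.6.0.dev20241001+cu121") before comparing.
    return Version(torch.__version__.split("+")[0]) >= Version("2.6.0.dev0")


def maybe_compile(func, enable_torch_compile=None):
    # True  -> always compile; False -> never compile;
    # None  -> compile only on a new-enough torch (mirrors the patched condition).
    if enable_torch_compile or (_torch_at_least_2_6() and enable_torch_compile is not False):
        return torch.compile(func)
    return func

One usage note on the new flag as wired through argparse with type=bool: any non-empty string value, including "False", parses as True, so only omitting --enable_torch_compile keeps the None default; an explicit opt-out is most easily expressed programmatically, e.g. AutoRound(..., enable_torch_compile=False).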