fix multiple device bug #321

Merged: 11 commits, Nov 14, 2024
13 changes: 6 additions & 7 deletions auto_round/autoround.py
@@ -160,9 +160,8 @@ def __init__(
         self.seed = seed
         set_seed(self.seed)
         assert not unsupport_meta_device(model), (
-            "autoround does not support for params on meta device by transformers` interfaces,"
-            "please do not using device_map='auto' in model loading, "
-            "or follow examples/language-modeling/main.py to enable low_cpu_mem_usage")
+            "AutoRound does not support params on the meta device."
+            " Please use more GPUs via `--device 0,1,2,3` or just use one GPU.")

         ## important tuning hype-parameters
         self.amp = amp
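The assertion above now steers users toward the `--device` flag instead of device_map='auto'. As a rough sketch of what such a guard checks (the repo's real helper is unsupport_meta_device; the standalone name and shape below are assumptions, not the project's code):

import torch

def has_meta_params(model: torch.nn.Module) -> bool:
    # A parameter on the "meta" device is only a shape/dtype placeholder with
    # no real storage, so there is nothing to calibrate or tune against.
    return any(p.device.type == "meta" for p in model.parameters())

Per the old message, loading with device_map='auto' can leave parameters in this state, which is why the new text suggests spreading the model across GPUs via `--device 0,1,2,3` or keeping it on a single device.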
@@ -283,12 +282,12 @@ def quantize(self):
         layer_names = self.get_quantized_layer_names_outside_blocks()
         self.start_time = time.time()
         all_first_block_names = [block[0] for block in all_blocks]
-        logger.info("start calibration")
+        logger.info("start to cache block inputs")
         all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names)
         if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
             accelerate.hooks.remove_hook_from_submodules(self.model)  ##self.model.hf_device_map has not been changed
         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
-        logger.info("calibration done")
+        logger.info("caching done")
         for block_names in all_blocks:
             inputs = all_inputs[block_names[0]]
             all_inputs.pop(block_names[0])
@@ -615,8 +614,8 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
             clear_memory()
         except RuntimeError as e:
             if "CUDA out of memory" in str(e):
-                logger.info("switch to cpu to cache inputs")
-                if (("lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] <= 16) or
+                logger.info("switch to cpu to cache block inputs")
+                if (("lm_head" in self.layer_config and self.layer_config["lm_head"]["bits"] < 16) or
                         self.__class__.__name__ == "AutoRoundMLLM"):
                     logger.warning(f"we strongly recommend using additional CUDA/HPU devices,e.g. "
                                    f"set `--device '0,1'` in our cmd line usage or "
24 changes: 17 additions & 7 deletions auto_round/script/llm.py
@@ -143,6 +143,8 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--not_use_best_mse", action='store_true',
                           help="whether to use the iter of best mes loss in the tuning phase")

+        self.add_argument("--enable_torch_compile", default=None, type=bool,
+                          help="whether to enable torch compile")

 def setup_parser():
     parser = BasicArgumentParser()
@@ -208,6 +210,7 @@ def setup_fast_parser():
     parser.add_argument("--nsamples", default=128, type=int,
                         help="number of samples")

+
     args = parser.parse_args()

     return args
@@ -254,7 +257,8 @@ def tune(args):
         os.environ["CUDA_VISIBLE_DEVICES"] = args.device
         args.device = ",".join(map(str, range(len(devices))))
         devices = args.device.replace(" ", "").split(',')
-        use_auto_mapping = True
+        if len(devices) > 1:  ## for a 70B model on a single card, "auto" mapping would offload some layers to cpu
+            use_auto_mapping = True

     import re
     import torch
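Both llm.py and main.py now gate auto device mapping on the number of requested cards. A minimal standalone sketch of that normalization step (the helper name and return shape are assumptions for illustration):

import os

def normalize_devices(device_arg: str):
    # Expose only the requested GPUs and renumber them from 0 so later code
    # can address them as 0, 1, 2, ...
    devices = device_arg.replace(" ", "").split(",")
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(devices)
    renumbered = ",".join(str(i) for i in range(len(devices)))
    # "auto" mapping only helps across several cards; on a single card it can
    # offload layers to CPU (the 70B case noted in the comment above).
    use_auto_mapping = len(devices) > 1
    return renumbered, use_auto_mapping

For example, normalize_devices("2,3") returns ("0,1", True), while normalize_devices("0") returns ("0", False) and leaves auto mapping off.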
@@ -328,10 +332,14 @@ def tune(args):
         seqlen = 2048

     if args.model_dtype != None:
-        if args.model_dtype == "float16" or args.model_dtype == "fp16":
-            model = model.to(torch.float16)
-        if args.model_dtype == "bfloat16" or args.model_dtype == "bfp16":
-            model = model.to(torch.bfloat16)
+        try:
+            if args.model_dtype == "float16" or args.model_dtype == "fp16":
+                model = model.to(torch.float16)
+            if args.model_dtype == "bfloat16" or args.model_dtype == "bfp16":
+                model = model.to(torch.bfloat16)
+        except:
+            logger.error("please use more devices, e.g. `--device 0,1,2,3`, to fit the model, or just use one device")
+            exit()

     if hasattr(tokenizer, "model_max_length"):
         if tokenizer.model_max_length < seqlen:
@@ -395,7 +403,8 @@ def tune(args):
         gradient_accumulate_steps=args.gradient_accumulate_steps, layer_config=layer_config,
         enable_minmax_tuning=not args.disable_minmax_tuning, act_bits=args.act_bits,
         low_cpu_mem_usage=low_cpu_mem_usage, data_type=args.data_type,
-        enable_norm_bias_tuning=args.enable_norm_bias_tuning, not_use_best_mse=args.not_use_best_mse)
+        enable_norm_bias_tuning=args.enable_norm_bias_tuning, not_use_best_mse=args.not_use_best_mse,
+        enable_torch_compile=args.enable_torch_compile)
     model, _ = autoround.quantize()
     model_name = args.model.rstrip("/")
     if args.low_cpu_mem_mode == 1 or args.low_cpu_mem_mode == 2:
@@ -473,7 +482,8 @@ def eval(args):
         os.environ["CUDA_VISIBLE_DEVICES"] = args.device
         args.device = ",".join(map(str, range(len(devices))))
         devices = args.device.replace(" ", "").split(',')
-        parallelism = True
+        if len(devices) > 1:
+            parallelism = True
         device_str = None
     else:
         device_str = detect_device(args.device.replace(" ", ""))
6 changes: 5 additions & 1 deletion auto_round/script/mllm.py
@@ -134,6 +134,9 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--not_use_best_mse", action='store_true',
                           help="whether to use the iter of best mes loss in the tuning phase")

+        self.add_argument("--enable_torch_compile", default=None, type=bool,
+                          help="whether to enable torch compile")
+
         ## ======================= VLM =======================
         self.add_argument("--quant_nontext_module", action='store_true',
                           help="whether to quantize non-text module, e.g. vision component")
@@ -337,7 +340,8 @@ def tune(args):
         device=device_str, seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps,
         scale_dtype=args.scale_dtype, layer_config=layer_config, template=args.template,
         enable_minmax_tuning=not args.disable_minmax_tuning, act_bits=args.act_bits,
-        quant_nontext_module=args.quant_nontext_module, not_use_best_mse=args.not_use_best_mse)
+        quant_nontext_module=args.quant_nontext_module, not_use_best_mse=args.not_use_best_mse,
+        enable_torch_compile=args.enable_torch_compile)
     model, _ = autoround.quantize()

     model.eval()
2 changes: 1 addition & 1 deletion auto_round/utils.py
@@ -906,7 +906,7 @@ def compile_func_on_hpu(func):


 def compile_func_on_cuda_or_cpu(func, enable_torch_compile):
-    if enable_torch_compile or TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE:
+    if enable_torch_compile or (TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE and enable_torch_compile != False):
Collaborator: Can it be simplified to 'if enable_torch_compile'?

Contributor (author): When enable_torch_compile is set to None and the torch version is >= 2.6, enable_torch_compile is automatically treated as True; the extra '!= False' check is what lets an explicit False still disable compilation.

         return torch.compile(func)
     else:
         return func
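The review thread above settles on a three-valued flag: True forces compilation, False disables it, and None defers to the torch version. A small sketch of that behaviour with the version check passed in explicitly (a standalone illustration, not the repo's exact function):

import torch

def maybe_compile(func, enable_torch_compile, torch_is_2_6_or_newer: bool):
    # True  -> always compile
    # None  -> compile only on torch >= 2.6 (flag left at its default)
    # False -> never compile
    if enable_torch_compile or (torch_is_2_6_or_newer and enable_torch_compile is not False):
        return torch.compile(func)
    return func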
23 changes: 16 additions & 7 deletions examples/language-modeling/main.py
@@ -130,6 +130,9 @@
     parser.add_argument("--fp_layers", default="", type=str,
                         help="List of Layers to maintain original data type")

+    parser.add_argument("--enable_torch_compile", default=None, type=bool,
+                        help="whether to enable torch compile")
+
     args = parser.parse_args()

     print(
@@ -151,7 +154,7 @@

     if args.format is None:
         args.format = "auto_round"
-    supported_formats = ["auto_round", "auto_gptq", "auto_awq", "auto_round:auto_gptq","auto_round:auto_awq",
+    supported_formats = ["auto_round", "auto_gptq", "auto_awq", "auto_round:auto_gptq", "auto_round:auto_awq",
                          "auto_gptq:marlin", "itrex", "iterx_xpu", "fake"]
     formats = args.format.replace(' ', '').split(",")
     for format in formats:
@@ -176,7 +179,8 @@
         os.environ["CUDA_VISIBLE_DEVICES"] = args.device
         args.device = ",".join(map(str, range(len(devices))))
         devices = args.device.replace(" ", "").split(',')
-        use_auto_mapping = True
+        if len(devices) > 1:
+            use_auto_mapping = True

     import torch
     import transformers
@@ -260,10 +264,14 @@
     seqlen = args.seqlen

     if args.model_dtype != None:
-        if args.model_dtype == "float16" or args.model_dtype == "fp16":
-            model = model.to(torch.float16)
-        if args.model_dtype == "bfloat16" or args.model_dtype == "bfp16":
-            model = model.to(torch.bfloat16)
+        try:
+            if args.model_dtype == "float16" or args.model_dtype == "fp16":
+                model = model.to(torch.float16)
+            if args.model_dtype == "bfloat16" or args.model_dtype == "bfp16":
+                model = model.to(torch.bfloat16)
+        except:
+            print("please use more devices, e.g. `--device 0,1,2,3`, to fit the model, or just use one device")
+            exit()

     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)

@@ -340,7 +348,8 @@
         scale_dtype=args.scale_dtype, layer_config=layer_config,
         enable_minmax_tuning=not args.disable_minmax_tuning, act_bits=args.act_bits,
         low_cpu_mem_usage=low_cpu_mem_usage, data_type=args.data_type,
-        enable_norm_bias_tuning=args.enable_norm_bias_tuning)
+        enable_norm_bias_tuning=args.enable_norm_bias_tuning,
+        enable_torch_compile=args.enable_torch_compile)
     model, _ = autoround.quantize()
     model_name = args.model_name.rstrip("/")
     if args.low_cpu_mem_mode == 1 or args.low_cpu_mem_mode == 2:
8 changes: 4 additions & 4 deletions test/test_mllm.py
@@ -48,8 +48,8 @@ def test_tune(self):
         bits, group_size = 4, 128
         autoround = AutoRoundMLLM(
             model, tokenizer, bits=bits, group_size=group_size,
-            nsamples=2,
-            batch_size=1, iters=2, dataset=self.dataset)
+            nsamples=1,
+            batch_size=1, iters=2, dataset=self.dataset, seqlen=256)
         autoround.quantize()
         autoround.save_quantized("./saved/", format="auto_gptq", inplace=False)
         autoround.save_quantized("./saved/", format="auto_round", inplace=False)
@@ -63,8 +63,8 @@ def test_quant_vision(self):  ## bug need to fix
         bits, group_size = 4, 128
         autoround = AutoRoundMLLM(
             model, tokenizer, bits=bits, group_size=group_size,
-            nsamples=2,
-            batch_size=1, iters=2, dataset=self.dataset, quant_nontext_module=True)
+            nsamples=1,
+            batch_size=1, iters=2, dataset=self.dataset, quant_nontext_module=True, seqlen=256)
         autoround.quantize()
         autoround.save_quantized("./saved/", format="auto_round", inplace=True)