From c65ac8752652a577792e9d54d2901b353fc1d7fa Mon Sep 17 00:00:00 2001
From: Qing
Date: Wed, 13 Dec 2023 14:03:18 +0800
Subject: [PATCH] load int4 GPTQ/AWQ safetensors weights on CPU to save VRAM

---
 examples/qwen/weight.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/qwen/weight.py b/examples/qwen/weight.py
index 2dd64666..d964fda1 100644
--- a/examples/qwen/weight.py
+++ b/examples/qwen/weight.py
@@ -713,7 +713,7 @@ def load_from_gptq_qwen(

     if quant_ckpt_path.endswith(".safetensors"):
         groupwise_qweight_safetensors = safe_open(
-            quant_ckpt_path, framework="pt", device=0
+            quant_ckpt_path, framework="pt", device="cpu"
         )
         model_params = {
             key: groupwise_qweight_safetensors.get_tensor(key)
@@ -944,7 +944,7 @@ def load_from_awq_qwen(tensorrt_llm_qwen: QWenForCausalLM,
     if quant_ckpt_path.endswith(".safetensors"):
         groupwise_qweight_safetensors = safe_open(quant_ckpt_path,
                                                   framework="pt",
-                                                  device=0)
+                                                  device="cpu")
         model_params = {
             key: groupwise_qweight_safetensors.get_tensor(key)
             for key in groupwise_qweight_safetensors.keys()
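
For context, a minimal sketch of the behavior this change relies on, using
only the public `safetensors` API: `device=0` is interpreted as GPU 0, so the
entire quantized checkpoint is materialized in VRAM before any weight
conversion starts, while `device="cpu"` keeps the raw tensors in host RAM.
The checkpoint path below is hypothetical, for illustration only.

    from safetensors import safe_open

    ckpt = "qwen-int4-gptq.safetensors"  # hypothetical checkpoint path

    # device="cpu" keeps every deserialized tensor in host memory; the
    # weights only reach the GPU if a later step moves them explicitly.
    with safe_open(ckpt, framework="pt", device="cpu") as f:
        model_params = {key: f.get_tensor(key) for key in f.keys()}

    # All tensors now live on the CPU rather than on cuda:0.
    assert all(t.device.type == "cpu" for t in model_params.values())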