diff --git a/examples/qwen/weight.py b/examples/qwen/weight.py index 2dd64666..d964fda1 100644 --- a/examples/qwen/weight.py +++ b/examples/qwen/weight.py @@ -713,7 +713,7 @@ def load_from_gptq_qwen( if quant_ckpt_path.endswith(".safetensors"): groupwise_qweight_safetensors = safe_open( - quant_ckpt_path, framework="pt", device=0 + quant_ckpt_path, framework="pt", device="cpu" ) model_params = { key: groupwise_qweight_safetensors.get_tensor(key) @@ -944,7 +944,7 @@ def load_from_awq_qwen(tensorrt_llm_qwen: QWenForCausalLM, if quant_ckpt_path.endswith(".safetensors"): groupwise_qweight_safetensors = safe_open(quant_ckpt_path, framework="pt", - device=0) + device="cpu") model_params = { key: groupwise_qweight_safetensors.get_tensor(key) for key in groupwise_qweight_safetensors.keys()