Skip to content

Commit d3d6bb1

Browse files
authored
Set weights_only=True when using torch.load() (#12366)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
1 parent 24b0205 commit d3d6bb1

File tree

4 files changed

+10
-6
lines changed

4 files changed

+10
-6
lines changed

vllm/assets/image.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ def image_embeds(self) -> torch.Tensor:
2626
"""
2727
image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
2828
s3_prefix=VLM_IMAGES_DIR)
29-
return torch.load(image_path, map_location="cpu")
29+
return torch.load(image_path, map_location="cpu", weights_only=True)

vllm/lora/models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ def from_local_checkpoint(
273273
new_embeddings_tensor_path)
274274
elif os.path.isfile(new_embeddings_bin_file_path):
275275
embeddings = torch.load(new_embeddings_bin_file_path,
276-
map_location=device)
276+
map_location=device,
277+
weights_only=True)
277278

278279
return cls.from_lora_tensors(
279280
lora_model_id=get_lora_id()

vllm/model_executor/model_loader/weight_utils.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def convert_bin_to_safetensor_file(
9393
pt_filename: str,
9494
sf_filename: str,
9595
) -> None:
96-
loaded = torch.load(pt_filename, map_location="cpu")
96+
loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
9797
if "state_dict" in loaded:
9898
loaded = loaded["state_dict"]
9999
shared = _shared_pointers(loaded)
@@ -381,7 +381,9 @@ def np_cache_weights_iterator(
381381
disable=not enable_tqdm,
382382
bar_format=_BAR_FORMAT,
383383
):
384-
state = torch.load(bin_file, map_location="cpu")
384+
state = torch.load(bin_file,
385+
map_location="cpu",
386+
weights_only=True)
385387
for name, param in state.items():
386388
param_path = os.path.join(np_folder, name)
387389
with open(param_path, "wb") as f:
@@ -447,7 +449,7 @@ def pt_weights_iterator(
447449
disable=not enable_tqdm,
448450
bar_format=_BAR_FORMAT,
449451
):
450-
state = torch.load(bin_file, map_location="cpu")
452+
state = torch.load(bin_file, map_location="cpu", weights_only=True)
451453
yield from state.items()
452454
del state
453455
torch.cuda.empty_cache()

vllm/prompt_adapter/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def load_peft_weights(model_id: str,
8989
adapters_weights = safe_load_file(filename, device=device)
9090
else:
9191
adapters_weights = torch.load(filename,
92-
map_location=torch.device(device))
92+
map_location=torch.device(device),
93+
weights_only=True)
9394

9495
return adapters_weights

0 commit comments

Comments
 (0)