Skip to content

Commit 28eae8b

Browse files
authored
Add weights_only=True to torch.load (#37062)
1 parent bf46e44 commit 28eae8b

File tree

106 files changed

+161
-136
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

106 files changed

+161
-136
lines changed

src/transformers/data/datasets/glue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __init__(
122122
with FileLock(lock_path):
123123
if os.path.exists(cached_features_file) and not args.overwrite_cache:
124124
start = time.time()
125-
self.features = torch.load(cached_features_file)
125+
self.features = torch.load(cached_features_file, weights_only=True)
126126
logger.info(
127127
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
128128
)

src/transformers/models/bark/convert_suno_to_hf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
109109
if not os.path.exists(ckpt_path):
110110
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
111111
_download(model_info["repo_id"], model_info["file_name"])
112-
checkpoint = torch.load(ckpt_path, map_location=device)
112+
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
113113
# this is a hack
114114
model_args = checkpoint["model_args"]
115115
if "input_vocab_size" not in model_args:

src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def rename_key(dct, old, new):
7171

7272
def load_xsum_checkpoint(checkpoint_path):
7373
"""Checkpoint path should end in model.pt"""
74-
sd = torch.load(checkpoint_path, map_location="cpu")
74+
sd = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
7575
hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
7676
hub_interface.model.load_state_dict(sd["model"])
7777
return hub_interface

src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def main(raw_args=None):
101101

102102
model = BertModel.from_pretrained(
103103
pretrained_model_name_or_path=args.model_name,
104-
state_dict=torch.load(args.pytorch_model_path),
104+
state_dict=torch.load(args.pytorch_model_path, weights_only=True),
105105
cache_dir=args.cache_dir,
106106
)
107107

src/transformers/models/biogpt/convert_biogpt_original_pytorch_checkpoint_to_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_fo
168168
checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
169169
if not os.path.isfile(checkpoint_file):
170170
raise ValueError(f"path to the file {checkpoint_file} does not exist!")
171-
chkpt = torch.load(checkpoint_file, map_location="cpu")
171+
chkpt = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
172172

173173
args = chkpt["cfg"]["model"]
174174

src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_
7979
"""
8080
Copy/paste/tweak model's weights to our BERT structure.
8181
"""
82-
model = torch.load(checkpoint_path, map_location="cpu")
82+
model = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
8383
sd = model["model"]
8484
cfg = BlenderbotConfig.from_json_file(config_json_path)
8585
m = BlenderbotForConditionalGeneration(cfg)

src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def convert_bloom_checkpoint_to_pytorch(
104104
for i in range(pretraining_tp):
105105
# load all TP files
106106
f_name = file.replace("model_00", f"model_0{i}")
107-
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
107+
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
108108

109109
# Rename keys in the transformers names
110110
keys = list(temp.keys())
@@ -164,7 +164,7 @@ def convert_bloom_checkpoint_to_pytorch(
164164
for i in range(pretraining_tp):
165165
# load all TP files
166166
f_name = file.replace("model_00", f"model_0{i}")
167-
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
167+
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu", weights_only=True)
168168

169169
# Rename keys in the transformers names
170170
keys = list(temp.keys())

src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,15 @@ def write_model(model_path, input_base_path, model_size, chameleon_version=1):
130130
for possible_name in ["consolidated.pth", "consolidated.00.pth"]:
131131
possible_path = os.path.join(input_model_path, possible_name)
132132
if os.path.exists(possible_path):
133-
loaded = torch.load(possible_path, map_location="cpu")
133+
loaded = torch.load(possible_path, map_location="cpu", weights_only=True)
134134
break
135135
assert loaded is not None
136136
else:
137137
# Sharded
138138
loaded = [
139-
torch.load(os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
139+
torch.load(
140+
os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu", weights_only=True
141+
)
140142
for i in range(num_shards)
141143
]
142144

@@ -314,7 +316,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
314316

315317
# Load VQGAN weights
316318
vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt")
317-
vqgan_state_dict = torch.load(vqgan_path, map_location="cpu")["state_dict"]
319+
vqgan_state_dict = torch.load(vqgan_path, map_location="cpu", weights_only=True)["state_dict"]
318320
for k, v in vqgan_state_dict.items():
319321
if "decoder" in k:
320322
continue # we dont do image generation yet

src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, c
104104

105105
hf_model = ChineseCLIPModel(config).eval()
106106

107-
pt_weights = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
107+
pt_weights = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
108108
pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
109109

110110
copy_text_model_and_projection(hf_model, pt_weights)

src/transformers/models/clipseg/convert_clipseg_original_pytorch_to_hf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def convert_clipseg_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_
169169
model = CLIPSegForImageSegmentation(config)
170170
model.eval()
171171

172-
state_dict = torch.load(checkpoint_path, map_location="cpu")
172+
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
173173

174174
# remove some keys
175175
for key in state_dict.copy().keys():

0 commit comments

Comments
 (0)