From b0e3a732425de29623d251865a465650d0ea213a Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Sat, 4 Nov 2023 13:21:55 +0100 Subject: [PATCH 1/2] make torch.load a bit safer --- .../convert_pytorch_checkpoint_to_tf2.py | 2 +- .../modeling_flax_pytorch_utils.py | 4 ++-- src/transformers/modeling_tf_pytorch_utils.py | 2 +- src/transformers/modeling_utils.py | 2 +- src/transformers/trainer.py | 20 +++++++++---------- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/convert_pytorch_checkpoint_to_tf2.py b/src/transformers/convert_pytorch_checkpoint_to_tf2.py index f1358408a5cb57..f300b0bb92c661 100755 --- a/src/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/src/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -329,7 +329,7 @@ def convert_pt_checkpoint_to_tf( if compare_with_pt_model: tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network - state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu") + state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True) pt_model = pt_model_class.from_pretrained( pretrained_model_name_or_path=None, config=config, state_dict=state_dict ) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index f78c4e78c78ba8..f6014d7c208ab6 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -68,7 +68,7 @@ def load_pytorch_checkpoint_in_flax_state_dict( for k in f.keys(): pt_state_dict[k] = f.get_tensor(k) else: - pt_state_dict = torch.load(pt_path, map_location="cpu") + pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True) logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) @@ -249,7 +249,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): flax_state_dict = {} for shard_file in shard_filenames: # load using msgpack utils - pt_state_dict = torch.load(shard_file) + pt_state_dict = torch.load(shard_file, weights_only=True) pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} model_prefix = flax_model.base_model_prefix diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index c599b795bf1932..aca1b9e4d9dccf 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -186,7 +186,7 @@ def load_pytorch_checkpoint_in_tf2_model( if pt_path.endswith(".safetensors"): state_dict = safe_load_file(pt_path) else: - state_dict = torch.load(pt_path, map_location="cpu") + state_dict = torch.load(pt_path, map_location="cpu", weights_only=True) pt_state_dict.update(state_dict) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7e5d3e54e619e8..dc725054e0dff4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -516,7 +516,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): else: map_location = "cpu" - return torch.load(checkpoint_file, map_location=map_location) + return torch.load(checkpoint_file, map_location=map_location, weights_only=True) except Exception as e: try: with open(checkpoint_file) as f: diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3a4ff5528047ae..6d3bdd4b1b2d8b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2086,7 +2086,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): logger.warning( "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." ) - state_dict = torch.load(weights_file, map_location="cpu") + state_dict = torch.load(weights_file, map_location="cpu", weights_only=True) # Required for smp to not auto-translate state_dict from hf to smp (is already smp). state_dict["_smp_is_partial"] = False load_result = model.load_state_dict(state_dict, strict=True) @@ -2099,7 +2099,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): if self.args.save_safetensors and os.path.isfile(safe_weights_file): state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu") else: - state_dict = torch.load(weights_file, map_location="cpu") + state_dict = torch.load(weights_file, map_location="cpu", weights_only=True) # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # which takes *args instead of **kwargs @@ -2167,7 +2167,7 @@ def _load_best_model(self): if self.args.save_safetensors and os.path.isfile(best_safe_model_path): state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") else: - state_dict = torch.load(best_model_path, map_location="cpu") + state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True) state_dict["_smp_is_partial"] = False load_result = model.load_state_dict(state_dict, strict=True) @@ -2196,7 +2196,7 @@ def _load_best_model(self): if self.args.save_safetensors and os.path.isfile(best_safe_model_path): state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") else: - state_dict = torch.load(best_model_path, map_location="cpu") + state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True) # If the model is on the GPU, it still works! # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 @@ -2300,7 +2300,7 @@ def _load_rng_state(self, checkpoint): ) return - checkpoint_rng_state = torch.load(rng_file) + checkpoint_rng_state = torch.load(rng_file, weights_only=True) random.setstate(checkpoint_rng_state["python"]) np.random.set_state(checkpoint_rng_state["numpy"]) torch.random.set_rng_state(checkpoint_rng_state["cpu"]) @@ -2479,7 +2479,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper): with warnings.catch_warnings(record=True) as caught_warnings: - self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) + self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)) reissue_pt_warnings(caught_warnings) return @@ -2503,9 +2503,9 @@ def _load_optimizer_and_scheduler(self, checkpoint): # Load in optimizer and scheduler states if is_torch_tpu_available(): # On TPU we have to take some extra precautions to properly load the states on the right device. - optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") + optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True) with warnings.catch_warnings(record=True) as caught_warnings: - lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") + lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True) reissue_pt_warnings(caught_warnings) xm.send_cpu_data_to_device(optimizer_state, self.args.device) @@ -2546,10 +2546,10 @@ def opt_load_hook(mod, opt): ) else: self.optimizer.load_state_dict( - torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) + torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True) ) with warnings.catch_warnings(record=True) as caught_warnings: - self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) + self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)) reissue_pt_warnings(caught_warnings) def hyperparameter_search( From 3323d7b396225f5be5b62089091f882af3c675f6 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 15 Dec 2023 15:20:42 +0100 Subject: [PATCH 2/2] Fixes --- src/transformers/modeling_utils.py | 2 +- .../models/wav2vec2/modeling_wav2vec2.py | 2 +- src/transformers/trainer.py | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index dc725054e0dff4..8be9709d072afe 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -480,7 +480,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): error_message += f"\nMissing key(s): {str_unexpected_keys}." raise RuntimeError(error_message) - loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu") + loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True) for shard_file in shard_files: state_dict = loader(os.path.join(folder, shard_file)) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 3d97e7c73d3522..ddfa2e21263f0f 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1333,7 +1333,7 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs): cache_dir=cache_dir, ) - state_dict = torch.load(weight_path, map_location="cpu") + state_dict = torch.load(weight_path, map_location="cpu", weights_only=True) except EnvironmentError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 6d3bdd4b1b2d8b..0b56488907fc17 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2300,7 +2300,7 @@ def _load_rng_state(self, checkpoint): ) return - checkpoint_rng_state = torch.load(rng_file, weights_only=True) + checkpoint_rng_state = torch.load(rng_file) random.setstate(checkpoint_rng_state["python"]) np.random.set_state(checkpoint_rng_state["numpy"]) torch.random.set_rng_state(checkpoint_rng_state["cpu"]) @@ -2479,7 +2479,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper): with warnings.catch_warnings(record=True) as caught_warnings: - self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)) + self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) reissue_pt_warnings(caught_warnings) return @@ -2503,9 +2503,9 @@ def _load_optimizer_and_scheduler(self, checkpoint): # Load in optimizer and scheduler states if is_torch_tpu_available(): # On TPU we have to take some extra precautions to properly load the states on the right device. - optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True) + optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") with warnings.catch_warnings(record=True) as caught_warnings: - lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True) + lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") reissue_pt_warnings(caught_warnings) xm.send_cpu_data_to_device(optimizer_state, self.args.device) @@ -2546,10 +2546,10 @@ def opt_load_hook(mod, opt): ) else: self.optimizer.load_state_dict( - torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True) + torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) ) with warnings.catch_warnings(record=True) as caught_warnings: - self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)) + self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) reissue_pt_warnings(caught_warnings) def hyperparameter_search(