From c658ffe69789c746d0d9d938866e9c7d54b4cc71 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Wed, 13 Dec 2023 01:58:43 +0000
Subject: [PATCH] When saving a model, make a copy to be moved to CPU; don't
 move the original model

---
 src/transformers/trainer.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d6ccc4334dd46d..2a38d522468eb7 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2803,7 +2803,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             output_dir = self.args.output_dir
 
         if is_torch_tpu_available():
-            self._save_tpu(output_dir)
+            model = copy.deepcopy(self.model)
+            model.to("cpu")
+            self._save_tpu(model, output_dir)
         elif is_sagemaker_mp_enabled():
             # Calling the state_dict needs to be done on the wrapped model and on all processes.
             os.makedirs(output_dir, exist_ok=True)
@@ -2843,7 +2845,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
         if self.args.push_to_hub and not _internal_call:
             self.push_to_hub(commit_message="Model save")
 
-    def _save_tpu(self, output_dir: Optional[str] = None):
+    def _save_tpu(self, model, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         logger.info(f"Saving model checkpoint to {output_dir}")
 
@@ -2854,22 +2856,20 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
         xm.rendezvous("saving_checkpoint")
-        if not isinstance(self.model, PreTrainedModel):
-            if isinstance(unwrap_model(self.model), PreTrainedModel):
-                unwrap_model(self.model).to("cpu").save_pretrained(
+        if not isinstance(model, PreTrainedModel):
+            if isinstance(unwrap_model(model), PreTrainedModel):
+                unwrap_model(model).save_pretrained(
                     output_dir,
                     is_main_process=self.args.should_save,
-                    state_dict=self.model.state_dict(),
+                    state_dict=model.state_dict(),
                     save_function=xm.save,
                 )
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                state_dict = self.model.state_dict().to("cpu")
+                state_dict = model.state_dict()
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.to("cpu").save_pretrained(
-                output_dir, is_main_process=self.args.should_save, save_function=xm.save
-            )
+            model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
 
         if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
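
For context, the copy-then-save pattern this patch introduces in save_model can be sketched outside the Trainer as well. The example below is a minimal, hypothetical illustration and not part of the patch (the model name and output path are placeholders): the model is deep-copied, only the copy is moved to CPU and saved with save_pretrained, and the original keeps its device placement so training can continue without moving weights back.

    import copy

    import torch
    from transformers import AutoModel

    # Hypothetical setup: any PreTrainedModel living on an accelerator device.
    model = AutoModel.from_pretrained("bert-base-uncased")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # ... training runs here; the weights stay on `device` ...

    # Pattern from the patch: copy first, then move only the copy to CPU.
    model_copy = copy.deepcopy(model)
    model_copy.to("cpu")
    model_copy.save_pretrained("./checkpoint")  # placeholder output directory

    # The original model is untouched and still lives on `device`, so training
    # can resume without a device round trip.
    assert next(model.parameters()).device.type == device.type

The trade-off is that the deep copy temporarily doubles the memory held by the weights; in exchange, the model being trained is never moved off the accelerator during checkpointing.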