Commit

When saving a model, make a copy to be moved to CPU; don't move the original model

qihqi committed Dec 13, 2023
1 parent f4db565 commit c658ffe
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/transformers/trainer.py
@@ -2803,7 +2803,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
             output_dir = self.args.output_dir
 
         if is_torch_tpu_available():
-            self._save_tpu(output_dir)
+            model = copy.deepcopy(self.model)
+            model.to("cpu")
+            self._save_tpu(model, output_dir)
         elif is_sagemaker_mp_enabled():
             # Calling the state_dict needs to be done on the wrapped model and on all processes.
             os.makedirs(output_dir, exist_ok=True)
@@ -2843,7 +2845,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         if self.args.push_to_hub and not _internal_call:
             self.push_to_hub(commit_message="Model save")
 
-    def _save_tpu(self, output_dir: Optional[str] = None):
+    def _save_tpu(self, model, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         logger.info(f"Saving model checkpoint to {output_dir}")

@@ -2854,22 +2856,20 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
         xm.rendezvous("saving_checkpoint")
-        if not isinstance(self.model, PreTrainedModel):
-            if isinstance(unwrap_model(self.model), PreTrainedModel):
-                unwrap_model(self.model).to("cpu").save_pretrained(
+        if not isinstance(model, PreTrainedModel):
+            if isinstance(unwrap_model(model), PreTrainedModel):
+                unwrap_model(model).save_pretrained(
                     output_dir,
                     is_main_process=self.args.should_save,
-                    state_dict=self.model.state_dict(),
+                    state_dict=model.state_dict(),
                     save_function=xm.save,
                 )
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                state_dict = self.model.state_dict().to("cpu")
+                state_dict = model.state_dict()
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.to("cpu").save_pretrained(
-                output_dir, is_main_process=self.args.should_save, save_function=xm.save
-            )
+            model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
         if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
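
Read on its own, the change boils down to: deep-copy the model, move only the copy to CPU, and serialize that copy, so the original parameters never leave the XLA device mid-training. Below is a minimal stand-alone sketch of that pattern in plain PyTorch; the helper name save_cpu_copy and the direct torch.save call are illustrative assumptions, not what trainer.py actually does.

# Minimal sketch of the copy-then-save pattern; save_cpu_copy is a
# hypothetical helper, not a transformers API.
import copy
import os

import torch


def save_cpu_copy(model: torch.nn.Module, output_dir: str) -> None:
    # Deep-copy the model so the original stays untouched on its device
    # (e.g. the XLA/TPU device used for training).
    cpu_model = copy.deepcopy(model)
    cpu_model.to("cpu")

    # Serialize only the CPU copy. A Hugging Face PreTrainedModel would call
    # cpu_model.save_pretrained(output_dir) instead of torch.save.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(cpu_model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))

The cost is a temporary second copy of the weights in host memory; the benefit is that the live training model and its device placement are left alone.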
