
Commit

[XLA] Re-enable broken _tpu_save for XLATensors, by explicitly moving to cpu
yeounoh committed Dec 1, 2023
1 parent 2c658b5 commit 0e18b25
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/transformers/trainer.py
@@ -2831,20 +2831,20 @@ def _save_tpu(self, output_dir: Optional[str] = None):
         xm.rendezvous("saving_checkpoint")
         if not isinstance(self.model, PreTrainedModel):
             if isinstance(unwrap_model(self.model), PreTrainedModel):
-                unwrap_model(self.model).save_pretrained(
+                unwrap_model(self.model).to('cpu').save_pretrained(
                     output_dir,
                     is_main_process=self.args.should_save,
                     state_dict=self.model.state_dict(),
                     save_function=xm.save,
                 )
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                state_dict = self.model.state_dict()
+                state_dict = self.model.state_dict().to('cpu')
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
+            self.model.to('cpu').save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
         if self.tokenizer is not None and self.args.should_save:
-            self.tokenizer.save_pretrained(output_dir)
+            self.tokenizer.to('cpu').save_pretrained(output_dir)

     def _save(self, output_dir: Optional[str] = None, state_dict=None):
         # If we are executing this function, we are the process zero, so we don't check for that.
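For context, a minimal sketch (not part of the commit) of the pattern the diff applies inside Trainer._save_tpu: moving an XLA-resident model to CPU before serializing, so the saved state dict holds ordinary CPU tensors. The toy model and file name below are hypothetical; it assumes torch and torch_xla are installed and an XLA device is available.

# Minimal sketch, assuming torch_xla is available; not the Trainer code itself.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(4, 2).to(device)  # hypothetical toy model on the XLA device

# Moving the module to CPU materializes its XLA tensors, so the resulting
# state dict contains plain CPU tensors that any save function can serialize.
cpu_model = model.to("cpu")
torch.save(cpu_model.state_dict(), "toy_model.pt")  # hypothetical output path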