Skip to content

Commit

Permalink
Use save_safetensor to disable safe serialization for XLA (#28669)
Browse files Browse the repository at this point in the history
* Use save_safetensor to disable safe serialization for XLA

#28438

* Style fixup
  • Loading branch information
jeffhataws authored Jan 24, 2024
1 parent c5c6909 commit 0549000
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2910,13 +2910,19 @@ def _save_tpu(self, output_dir: Optional[str] = None):
is_main_process=self.args.should_save,
state_dict=model.state_dict(),
save_function=xm.save,
safe_serialization=self.args.save_safetensors,
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
model.save_pretrained(
output_dir,
is_main_process=self.args.should_save,
save_function=xm.save,
safe_serialization=self.args.save_safetensors,
)
if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir)

Expand Down

0 comments on commit 0549000

Please sign in to comment.