diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 00f4bc8a5207ac..479bb602843f33 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2910,13 +2910,19 @@ def _save_tpu(self, output_dir: Optional[str] = None):
                     is_main_process=self.args.should_save,
                     state_dict=model.state_dict(),
                     save_function=xm.save,
+                    safe_serialization=self.args.save_safetensors,
                 )
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                 state_dict = model.state_dict()
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
+            model.save_pretrained(
+                output_dir,
+                is_main_process=self.args.should_save,
+                save_function=xm.save,
+                safe_serialization=self.args.save_safetensors,
+            )
         if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
 
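
The change threads `save_safetensors` through both TPU save paths, so the flag is honored whether or not the model is a `PreTrainedModel`. A minimal sketch of how the flag is set from user code; the `output_dir` value is hypothetical, while `save_safetensors` is the real `TrainingArguments` field the diff reads as `self.args.save_safetensors`:

```python
from transformers import TrainingArguments

# With this flag set, Trainer._save_tpu now forwards
# safe_serialization=True to save_pretrained (with xm.save as the
# save_function), so TPU checkpoints are written in the safetensors
# format rather than as a pickled state dict.
args = TrainingArguments(
    output_dir="./tpu-checkpoints",  # hypothetical path
    save_safetensors=True,
)
```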