From e3934198a3b625913f6b1faed8c2085ddd859512 Mon Sep 17 00:00:00 2001 From: jeffhataws Date: Wed, 24 Jan 2024 03:57:45 -0800 Subject: [PATCH] Use save_safetensors to disable safe serialization for XLA (#28669) * Use save_safetensors to disable safe serialization for XLA https://github.com/huggingface/transformers/issues/28438 * Style fixup --- src/transformers/trainer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3416bcd72ce378..9bc577c7b7e3f3 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2907,13 +2907,19 @@ def _save_tpu(self, output_dir: Optional[str] = None): is_main_process=self.args.should_save, state_dict=model.state_dict(), save_function=xm.save, + safe_serialization=self.args.save_safetensors, ) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") state_dict = model.state_dict() xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: - model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) + model.save_pretrained( output_dir, is_main_process=self.args.should_save, save_function=xm.save, safe_serialization=self.args.save_safetensors, ) if self.tokenizer is not None and self.args.should_save: self.tokenizer.save_pretrained(output_dir)