From c0a3741d3e0c52125bf2ae2ccde082583c19657d Mon Sep 17 00:00:00 2001 From: jeffhataws Date: Wed, 24 Jan 2024 03:57:45 -0800 Subject: [PATCH] Use save_safetensor to disable safe serialization for XLA (#28669) * Use save_safetensor to disable safe serialization for XLA https://github.com/huggingface/transformers/issues/28438 * Style fixup --- src/transformers/trainer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 00f4bc8a5207ac..479bb602843f33 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2910,13 +2910,19 @@ def _save_tpu(self, output_dir: Optional[str] = None): is_main_process=self.args.should_save, state_dict=model.state_dict(), save_function=xm.save, + safe_serialization=self.args.save_safetensors, ) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") state_dict = model.state_dict() xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: - model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) + model.save_pretrained( + output_dir, + is_main_process=self.args.should_save, + save_function=xm.save, + safe_serialization=self.args.save_safetensors, + ) if self.tokenizer is not None and self.args.should_save: self.tokenizer.save_pretrained(output_dir)