From 75c07c869dd5af6c91d42f68dfe90ea24ca7b354 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Tue, 23 Jan 2024 10:50:52 -0800 Subject: [PATCH] Use save_safetensors to disable safe serialization for XLA https://github.com/huggingface/transformers/issues/28438 --- src/transformers/trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 00f4bc8a5207..96ee66010604 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2910,13 +2910,17 @@ def _save_tpu(self, output_dir: Optional[str] = None): is_main_process=self.args.should_save, state_dict=model.state_dict(), save_function=xm.save, + safe_serialization=self.args.save_safetensors, ) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") state_dict = model.state_dict() xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: - model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) + model.save_pretrained( + output_dir, is_main_process=self.args.should_save, save_function=xm.save, + safe_serialization=self.args.save_safetensors ) if self.tokenizer is not None and self.args.should_save: self.tokenizer.save_pretrained(output_dir)