From 206476e41522025347f0f30859f4b6a0cfdcd549 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 11 Oct 2024 06:03:42 -0600 Subject: [PATCH] load state from input_dir --- helpers/training/save_hooks.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index 47d9ef1f..61580bdf 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -491,11 +491,12 @@ def _load_full_model(self, models, input_dir): def load_model_hook(self, models, input_dir): # Check the checkpoint dir for a "training_state.json" file to load - if os.path.exists(self.training_state_path): - StateTracker.load_training_state(self.training_state_path) + training_state_path = os.path.join(input_dir, self.training_state_path) + if os.path.exists(training_state_path): + StateTracker.load_training_state(training_state_path) else: logger.warning( - f"Could not find {self.training_state_path} in checkpoint dir {input_dir}" + f"Could not find {training_state_path} in checkpoint dir {input_dir}" ) if "lora" in self.args.model_type and self.args.lora_type == "standard": self._load_lora(models=models, input_dir=input_dir)