Skip to content

Commit

Permalink
Fix checkpoint loading with run_id (#1403)
Browse files Browse the repository at this point in the history
* fix

* replace change

* use os
Louis-Dupont authored Aug 23, 2023
1 parent 54e5a23 commit a6d9003
Showing 2 changed files with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -108,6 +108,7 @@ def get_checkpoints_dir_path(experiment_name: str, ckpt_root_dir: Optional[str]
"""
experiment_dir = get_experiment_dir_path(checkpoints_root_dir=ckpt_root_dir, experiment_name=experiment_name)
checkpoint_dir = experiment_dir if run_id is None else os.path.join(experiment_dir, run_id)
os.makedirs(checkpoint_dir, exist_ok=True)
return checkpoint_dir


2 changes: 1 addition & 1 deletion src/super_gradients/common/sg_loggers/base_sg_logger.py
Original file line number Diff line number Diff line change
@@ -133,7 +133,7 @@ def _setup_dir(self):
# Only if it exists, i.e. if hydra was used.
if os.path.exists(source_hydra_path):
destination_hydra_path = os.path.join(self._local_dir, ".hydra")
shutil.copytree(source_hydra_path, destination_hydra_path)
shutil.copytree(source_hydra_path, destination_hydra_path, dirs_exist_ok=True)

@multi_process_safe
def _init_log_file(self):

0 comments on commit a6d9003

Please sign in to comment.