Skip to content

Commit

Permalink
simplify resumable checkpoint CLI fn to a single CLI arg
Browse files Browse the repository at this point in the history
  • Loading branch information
jimzers committed Jan 19, 2023
1 parent 9b5f163 commit bd8b2da
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 12 deletions.
2 changes: 0 additions & 2 deletions sleap/nn/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,11 +653,9 @@ class ModelConfig:
Attributes:
backbone: Configurations related to the main network architecture.
heads: Configurations related to the output heads.
resume_training: If `True`, resume training from the latest checkpoint.
base_checkpoint: Path to model folder for loading a checkpoint. Should contain the .h5 file
"""

backbone: BackboneConfig = attr.ib(factory=BackboneConfig)
heads: HeadsConfig = attr.ib(factory=HeadsConfig)
resume_training: bool = False
base_checkpoint: Optional[Text] = None
10 changes: 2 additions & 8 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def _setup_model(self):
logger.info(f" [{i}] = {output}")

# Resuming training if flagged
if self.config.model.resume_training:
if self.config.model.base_checkpoint:
# grab the 'best_model.h5' file from the previous training run
# and load it into the current model
previous_model_path = os.path.join(
Expand Down Expand Up @@ -1868,14 +1868,9 @@ def main(args: Optional[List] = None):
parser.add_argument("--prefix", default="", help="Prefix to prepend to run name.")
parser.add_argument("--suffix", default="", help="Suffix to append to run name.")

parser.add_argument(
"--resume",
action="store_true",
help=("Resume training from last checkpoint."),
)
parser.add_argument(
"--base_checkpoint",
default="",
type=str,
help=("Path to base checkpoint to resume training from."),
)

Expand Down Expand Up @@ -1938,7 +1933,6 @@ def main(args: Optional[List] = None):
if len(args.video_paths) == 0:
args.video_paths = None

job_config.model.resume_training = args.resume
job_config.model.base_checkpoint = args.base_checkpoint

logger.info("Versions:")
Expand Down
2 changes: 0 additions & 2 deletions tests/nn/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def test_train_load_single_instance(min_labels_robot, cfg, tmp_path):
trainer.train()

# now load a new model and resume the checkpoint
cfg.model.resume_training = True
# set the model checkpoint folder
cfg.model.base_checkpoint = cfg.outputs.run_path
# unset save directory
Expand Down Expand Up @@ -336,7 +335,6 @@ def test_resume_training_cli(tmp_path, training_labels, cfg):
json_path,
"--labels_path",
labels_path,
"--resume",
"--base_checkpoint",
base_checkpoint_path,
]
Expand Down

0 comments on commit bd8b2da

Please sign in to comment.