Skip to content

Commit

Permalink
black linting for cli resumable training
Browse files Browse the repository at this point in the history
  • Loading branch information
jimzers committed Jan 19, 2023
1 parent 34dbbdb commit 9b5f163
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
8 changes: 2 additions & 6 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1871,16 +1871,12 @@ def main(args: Optional[List] = None):
parser.add_argument(
"--resume",
action="store_true",
help=(
"Resume training from last checkpoint."
),
help=("Resume training from last checkpoint."),
)
parser.add_argument(
"--base_checkpoint",
default="",
help=(
"Path to base checkpoint to resume training from."
),
help=("Path to base checkpoint to resume training from."),
)

device_group = parser.add_mutually_exclusive_group(required=False)
Expand Down
27 changes: 19 additions & 8 deletions tests/nn/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def test_train_cropping(
== 0
)


def test_resume_training_cli(tmp_path, training_labels, cfg):
"""
Test CLI to resume training.
Expand All @@ -319,14 +320,24 @@ def test_resume_training_cli(tmp_path, training_labels, cfg):

# path to sleap/tests/data/models/minimal_robot.UNet.single_instance
base_checkpoint_path = os.path.join(
os.path.dirname(__file__), "..", "data", "models", "minimal_robot.UNet.single_instance"
)
json_path = os.path.join(
base_checkpoint_path, "training_config.json"
)
labels_path = os.path.join(
base_checkpoint_path, "labels_gt.slp"
os.path.dirname(__file__),
"..",
"data",
"models",
"minimal_robot.UNet.single_instance",
)
json_path = os.path.join(base_checkpoint_path, "training_config.json")
labels_path = os.path.join(base_checkpoint_path, "labels_gt.slp")

# run CLI to resume training
main(["--training_job_path", json_path, "--labels_path", labels_path, "--resume", "--base_checkpoint", base_checkpoint_path])
main(
[
"--training_job_path",
json_path,
"--labels_path",
labels_path,
"--resume",
"--base_checkpoint",
base_checkpoint_path,
]
)

0 comments on commit 9b5f163

Please sign in to comment.