Skip to content

Commit

Permalink
Revert "remove unnecessary for loop, adjust tests, change default thr…
Browse files Browse the repository at this point in the history
…eshold"

This reverts commit 3baf219, reversing
changes made to 8cc046c.
  • Loading branch information
grquach committed Jul 17, 2024
1 parent 3baf219 commit 61fe572
Show file tree
Hide file tree
Showing 31 changed files with 17 additions and 114 deletions.
6 changes: 2 additions & 4 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ optional arguments:

```none
usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--suffix SUFFIX]
training_job_path [labels_path]
Expand Down Expand Up @@ -68,8 +68,6 @@ optional arguments:
--save_viz Enable saving of prediction visualizations to the run
folder if not already specified in the training job
config.
--keep_viz Keep prediction visualization images in the run
folder after training if --save_viz is enabled.
--zmq Enable ZMQ logging (for GUI) if not already specified
in the training job config.
--run_name RUN_NAME Run name to use when saving file, overrides other run
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@
" \"runs_folder\": \"models\",\n",
" \"tags\": [],\n",
" \"save_visualizations\": true,\n",
" \"keep_viz_images\": true,\n",
" \"delete_viz_images\": true,\n",
" \"zip_outputs\": false,\n",
" \"log_to_csv\": true,\n",
" \"checkpointing\": {\n",
Expand Down Expand Up @@ -727,7 +727,7 @@
" \"runs_folder\": \"models\",\n",
" \"tags\": [],\n",
" \"save_visualizations\": true,\n",
" \"keep_viz_images\": true,\n",
" \"delete_viz_images\": true,\n",
" \"zip_outputs\": false,\n",
" \"log_to_csv\": true,\n",
" \"checkpointing\": {\n",
Expand Down
5 changes: 0 additions & 5 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,6 @@ training:
type: bool
default: true

- name: _keep_viz
label: Keep Prediction Visualization Images After Training
type: bool
default: false

- name: _predict_frames
label: Predict On
type: list
Expand Down
17 changes: 3 additions & 14 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Run training/inference in background process via CLI."""

import abc
import attr
import os
Expand Down Expand Up @@ -501,11 +500,9 @@ def write_pipeline_files(
"data_path": os.path.basename(data_path),
"models": [Path(p).as_posix() for p in new_cfg_filenames],
"output_path": prediction_output_path,
"type": (
"labels"
if type(item_for_inference) == DatasetItemForInference
else "video"
),
"type": "labels"
if type(item_for_inference) == DatasetItemForInference
else "video",
"only_suggested_frames": only_suggested_frames,
"tracking": tracking_args,
}
Expand Down Expand Up @@ -547,7 +544,6 @@ def run_learning_pipeline(
"""

save_viz = inference_params.get("_save_viz", False)
keep_viz = inference_params.get("_keep_viz", False)

if "movenet" in inference_params["_pipeline"]:
trained_job_paths = [inference_params["_pipeline"]]
Expand All @@ -561,7 +557,6 @@ def run_learning_pipeline(
inference_params=inference_params,
gui=True,
save_viz=save_viz,
keep_viz=keep_viz,
)

# Check that all the models were trained
Expand Down Expand Up @@ -590,7 +585,6 @@ def run_gui_training(
inference_params: Dict[str, Any],
gui: bool = True,
save_viz: bool = False,
keep_viz: bool = False,
) -> Dict[Text, Text]:
"""
Runs training for each training job.
Expand All @@ -600,7 +594,6 @@ def run_gui_training(
config_info_list: List of ConfigFileInfo with configs for training.
gui: Whether to show gui windows and process gui events.
save_viz: Whether to save visualizations from training.
keep_viz: Whether to keep prediction visualization images after training.
Returns:
Dictionary, keys are head name, values are path to trained config.
Expand Down Expand Up @@ -690,7 +683,6 @@ def waiting():
video_paths=video_path_list,
waiting_callback=waiting,
save_viz=save_viz,
keep_viz=keep_viz,
)

if ret == "success":
Expand Down Expand Up @@ -833,7 +825,6 @@ def train_subprocess(
video_paths: Optional[List[Text]] = None,
waiting_callback: Optional[Callable] = None,
save_viz: bool = False,
keep_viz: bool = False,
):
"""Runs training inside subprocess."""
run_path = job_config.outputs.run_path
Expand Down Expand Up @@ -862,8 +853,6 @@ def train_subprocess(

if save_viz:
cli_args.append("--save_viz")
if keep_viz:
cli_args.append("--keep_viz")

# Use cli arg since cli ignores setting in config
if job_config.outputs.tensorboard.write_logs:
Expand Down
6 changes: 3 additions & 3 deletions sleap/nn/config/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ class OutputsConfig:
save_visualizations: If True, will render and save visualizations of the model
predictions as PNGs to "{run_folder}/viz/{split}.{epoch:04d}.png", where the
split is one of "train", "validation", "test".
keep_viz_images: If True, keep the saved visualization images after training
completes. This is useful unchecked to reduce the model folder size if you do not need
delete_viz_images: If True, delete the saved visualizations after training
completes. This is useful to reduce the model folder size if you do not need
to keep the visualization images.
zip_outputs: If True, compress the run folder to a zip file. This will be named
"{run_folder}.zip".
Expand All @@ -170,7 +170,7 @@ class OutputsConfig:
runs_folder: Text = "models"
tags: List[Text] = attr.ib(factory=list)
save_visualizations: bool = True
keep_viz_images: bool = False
delete_viz_images: bool = True
zip_outputs: bool = False
log_to_csv: bool = True
checkpointing: CheckpointingConfig = attr.ib(factory=CheckpointingConfig)
Expand Down
13 changes: 2 additions & 11 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def train(self):
if self.config.outputs.save_outputs:
if (
self.config.outputs.save_visualizations
and not self.config.outputs.keep_viz_images
and self.config.outputs.delete_viz_images
):
self.cleanup()

Expand Down Expand Up @@ -997,7 +997,7 @@ def cleanup(self):

def package(self):
"""Package model folder into a zip file for portability."""
if not self.config.outputs.keep_viz_images:
if self.config.outputs.delete_viz_images:
self.cleanup()
logger.info(f"Packaging results to: {self.run_path}.zip")
shutil.make_archive(
Expand Down Expand Up @@ -1864,14 +1864,6 @@ def create_trainer_using_cli(args: Optional[List] = None):
"already specified in the training job config."
),
)
parser.add_argument(
"--keep_viz",
action="store_true",
help=(
"Keep prediction visualization images in the run folder after training when "
"--save_viz is enabled."
),
)
parser.add_argument(
"--zmq",
action="store_true",
Expand Down Expand Up @@ -1957,7 +1949,6 @@ def create_trainer_using_cli(args: Optional[List] = None):
if args.suffix != "":
job_config.outputs.run_name_suffix = args.suffix
job_config.outputs.save_visualizations |= args.save_viz
job_config.outputs.keep_viz_images = args.keep_viz
if args.labels_path == "":
args.labels_path = None
args.video_paths = args.video_paths.split(",")
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/baseline.centroid.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/baseline_large_rf.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/baseline_large_rf.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/baseline_large_rf.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/baseline_medium_rf.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/baseline_medium_rf.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/baseline_medium_rf.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/pretrained.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/pretrained.centroid.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/pretrained.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion sleap/training_profiles/pretrained.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"delete_viz_images": true,
"zip_outputs": false,
"log_to_csv": true,
"checkpointing": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"delete_viz_images": true,
"zip_outputs": false,
"log_to_csv": true,
"checkpointing": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@
""
],
"save_visualizations": false,
"keep_viz_images": true,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@
""
],
"save_visualizations": false,
"keep_viz_images": true,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions tests/gui/test_dialogs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module to test the dialogs of the GUI (contained in sleap/gui/dialogs)."""


import os
from pathlib import Path

Expand Down
Loading

0 comments on commit 61fe572

Please sign in to comment.