diff --git a/docs/guides/cli.md b/docs/guides/cli.md
index 827678ffb..17500786b 100644
--- a/docs/guides/cli.md
+++ b/docs/guides/cli.md
@@ -60,6 +60,9 @@ optional arguments:
                         Path to labels file to use for test. If specified,
                         overrides the path specified in the training job config.
+  --base_checkpoint BASE_CHECKPOINT
+                        Path to base checkpoint (directory containing best_model.h5)
+                        to resume training from.
   --tensorboard         Enable TensorBoard logging to the run path if not already
                         specified in the training job config.
   --save_viz            Enable saving of prediction visualizations to the run
diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py
index 3e0dd5b4d..a8ea86015 100644
--- a/sleap/gui/learning/dialog.py
+++ b/sleap/gui/learning/dialog.py
@@ -14,7 +14,7 @@
 from sleap.gui.dialogs.formbuilder import YamlFormWidget
 from sleap.gui.learning import runners, scopedkeydict, configs, datagen, receptivefield
 
-from typing import Dict, List, Optional, Text, Optional
+from typing import Dict, List, Optional, Text, cast
 
 from qtpy import QtWidgets, QtCore
 
@@ -75,7 +75,7 @@ def __init__(
 
         self.current_pipeline = ""
 
-        self.tabs = dict()
+        self.tabs: Dict[str, TrainingEditorWidget] = dict()
         self.shown_tab_names = []
 
         self._cfg_getter = configs.TrainingConfigsGetter.make_from_labels_filename(
@@ -628,6 +628,7 @@ def run(self):
         """Run with current dialog settings."""
         pipeline_form_data = self.pipeline_form_widget.get_form_data()
+
         items_for_inference = self.get_items_for_inference(pipeline_form_data)
 
         config_info_list = self.get_every_head_config_data(pipeline_form_data)
@@ -900,12 +901,13 @@ def __init__(
         self._cfg_list_widget = None
         self._receptive_field_widget = None
         self._use_trained_model = None
+        self._resume_training = None
         self._require_trained = require_trained
         self.head = head
 
         yaml_name = "training_editor_form"
 
-        self.form_widgets = dict()
+        self.form_widgets: Dict[str, YamlFormWidget] = dict()
 
         for key in ("model", "data", "augmentation", "optimization", "outputs"):
             self.form_widgets[key] = YamlFormWidget.from_name(
@@ -960,11 +962,10 @@ def __init__(
         # If we have an object which gets a list of config files,
         # then we'll show a menu to allow selection from the list.
-
-        if self._cfg_getter:
+        if self._cfg_getter is not None:
             self._cfg_list_widget = configs.TrainingConfigFilesWidget(
                 cfg_getter=self._cfg_getter,
-                head_name=head,
+                head_name=cast(str, head),  # Expect head to be a string
                 require_trained=require_trained,
             )
             self._cfg_list_widget.onConfigSelection.connect(
@@ -974,20 +975,22 @@
 
             layout.addWidget(self._cfg_list_widget)
 
-        # Add option for using trained model from selected config
-        if self._require_trained:
-            self._update_use_trained()
-        else:
-            self._use_trained_model = QtWidgets.QCheckBox("Use Trained Model")
-            self._use_trained_model.setEnabled(False)
-            self._use_trained_model.setVisible(False)
-
-            self._use_trained_model.stateChanged.connect(self._update_use_trained)
+        if self._require_trained:
+            self._update_use_trained()
+        elif self._cfg_list_widget is not None:
+            # Add option for using trained model from selected config file
+            self._use_trained_model = QtWidgets.QCheckBox("Use Trained Model")
+            self._use_trained_model.setEnabled(False)
+            self._use_trained_model.setVisible(False)
+            self._resume_training = QtWidgets.QCheckBox("Resume Training")
+            self._resume_training.setEnabled(False)
+            self._resume_training.setVisible(False)
 
-            layout.addWidget(self._use_trained_model)
+            self._use_trained_model.stateChanged.connect(self._update_use_trained)
+            self._resume_training.stateChanged.connect(self._update_use_trained)
 
-        elif self._require_trained:
-            self._update_use_trained()
+            layout.addWidget(self._use_trained_model)
+            layout.addWidget(self._resume_training)
 
         layout.addWidget(self._layout_widget(col_layout))
         self.setLayout(layout)
@@ -1017,9 +1020,15 @@ def acceptSelectedConfigInfo(self, cfg_info: configs.ConfigFileInfo):
         self._load_config(cfg_info)
 
         has_trained_model = cfg_info.has_trained_model
-        if self._use_trained_model:
+        if self._use_trained_model is not None:
+            self._use_trained_model.setChecked(False)
             self._use_trained_model.setVisible(has_trained_model)
             self._use_trained_model.setEnabled(has_trained_model)
+        # Redundant check (for readability) since this checkbox exists if the above does
+        if self._resume_training is not None:
+            self._resume_training.setChecked(False)
+            self._resume_training.setVisible(has_trained_model)
+            self._resume_training.setEnabled(has_trained_model)
 
         self.update_receptive_field()
 
@@ -1060,17 +1069,64 @@ def _load_config(self, cfg_info: configs.ConfigFileInfo):
         #     self._cfg_list_widget.setUserConfigData(cfg_form_data_dict)
 
     def _update_use_trained(self, check_state=0):
-        if self._require_trained:
-            use_trained = True
-        else:
-            use_trained = check_state == QtCore.Qt.CheckState.Checked
+        """Update config GUI based on _use_trained_model and _resume_training checkboxes.
+
+        This method runs when either the _use_trained_model or _resume_training
+        checkbox is checked/unchecked, and is also called directly (with no signal
+        sender) when _require_trained is True.
+
+        If _require_trained is True, then we'll disable all fields.
+        If _use_trained_model is checked, then we'll disable all fields.
+        If _resume_training is checked, then we'll disable only the model form.
+
+        Args:
+            check_state (int, optional): Check state of checkbox. Defaults to 0. Unused.
+
+        Returns:
+            None
+
+        Side Effects:
+            Disables/enables fields based on checkbox values (and _require_trained).
+        """
+
+        # Check which checkbox changed its value (if any)
+        sender = self.sender()
+
+        if sender is None:  # If sender is None, then _require_trained is True
+            pass
+        # Uncheck _resume_training checkbox if _use_trained_model is unchecked
+        elif (sender == self._use_trained_model) and (
+            not self._use_trained_model.isChecked()
+        ):
+            self._resume_training.setChecked(False)
+
+        # Check _use_trained_model checkbox if _resume_training is checked
+        elif (sender == self._resume_training) and self._resume_training.isChecked():
+            self._use_trained_model.setChecked(True)
+
+        # Update form widgets
+        use_trained_params = self.use_trained
+        use_model_params = self.resume_training
         for form in self.form_widgets.values():
-            form.set_enabled(not use_trained)
+            form.set_enabled(not use_trained_params)
 
-        # If user wants to use trained model, then reset form to match config
-        if use_trained and self._cfg_list_widget:
+        if (use_trained_params or use_model_params) and self._cfg_list_widget:
             cfg_info = self._cfg_list_widget.getSelectedConfigInfo()
+
+            # If user wants to resume training, then reset only model form to match config
+            if use_model_params:
+                self.form_widgets["model"].set_enabled(False)
+
+                # Set model form to match config
+                cfg = cfg_info.config
+                cfg_dict = cattr.unstructure(cfg)
+                model_dict = {"model": cfg_dict["model"]}
+                key_val_dict = scopedkeydict.ScopedKeyDict.from_hierarchical_dict(
+                    model_dict
+                ).key_val_dict
+                self.set_fields_from_key_val_dict(key_val_dict)
+
+            # If user wants to use trained model, then reset entire form to match config
+            if use_trained_params:
                 self._load_config(cfg_info)
 
         self._set_head()
@@ -1100,17 +1156,32 @@ def _set_backbone_from_key_val_dict(self, cfg_key_val_dict):
 
     @property
     def use_trained(self) -> bool:
-        use_trained = False
-        if self._require_trained:
-            use_trained = True
-        elif self._use_trained_model and self._use_trained_model.isChecked():
-            use_trained = True
-        return use_trained
+        if self._require_trained or (
+            (self._use_trained_model is not None)
+            and self._use_trained_model.isChecked()
+            and (not self.resume_training)
+        ):
+            return True
+
+        return False
+
+    @property
+    def resume_training(self) -> bool:
+        if (self._resume_training is not None) and self._resume_training.isChecked():
+            return True
+        return False
 
     @property
     def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]:
-        trained_config_info = self._cfg_list_widget.getSelectedConfigInfo()
-        if trained_config_info is None:
+        # If `TrainingEditorWidget` was initialized with a config getter, then
+        # we expect to have a list of config files
+        if self._cfg_list_widget is None:
+            return None
+
+        trained_config_info: Optional[
+            configs.ConfigFileInfo
+        ] = self._cfg_list_widget.getSelectedConfigInfo()
+        if (trained_config_info is None) or (not trained_config_info.has_trained_model):
             return None
 
         if self.use_trained:
@@ -1123,13 +1194,25 @@ def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]:
             trained_config.outputs.run_name_prefix = ""
             trained_config.outputs.run_name_suffix = None
 
+        if self.resume_training:
+            # Use the folder containing the trained config as the base checkpoint folder
+            trained_config_info.config.model.base_checkpoint = str(
+                Path(cast(str, trained_config_info.path)).parent
+            )
+        else:
+            trained_config_info.config.model.base_checkpoint = None
+
         return trained_config_info
 
     @property
     def has_trained_config_selected(self) -> bool:
+        if self._cfg_list_widget is None:
+            return False
+
         cfg_info = self._cfg_list_widget.getSelectedConfigInfo()
         if cfg_info and cfg_info.has_trained_model:
             return True
+        return False
 
     def get_all_form_data(self) -> dict:
diff --git a/sleap/nn/config/model.py b/sleap/nn/config/model.py
index ca4573266..76056782f 100644
--- a/sleap/nn/config/model.py
+++ b/sleap/nn/config/model.py
@@ -653,7 +653,9 @@ class ModelConfig:
     Attributes:
         backbone: Configurations related to the main network architecture.
        heads: Configurations related to the output heads.
+        base_checkpoint: Path to a model folder to load checkpoint weights from; should contain a best_model.h5 file.
     """
 
     backbone: BackboneConfig = attr.ib(factory=BackboneConfig)
     heads: HeadsConfig = attr.ib(factory=HeadsConfig)
+    base_checkpoint: Optional[Text] = None
diff --git a/sleap/nn/training.py b/sleap/nn/training.py
index 164174e02..43be5cf31 100644
--- a/sleap/nn/training.py
+++ b/sleap/nn/training.py
@@ -745,6 +745,22 @@ def _setup_model(self):
         for i, output in enumerate(self.model.keras_model.outputs):
             logger.info(f"  [{i}] = {output}")
 
+        # Resume training if flagged
+        if self.config.model.base_checkpoint is not None:
+            # TODO (AL): Add flexibility to resume from any checkpoint (e.g.
+            # latest_model, specific epoch, etc.)
+
+            # Grab the 'best_model.h5' file from the previous training run
+            # and load it into the current model
+            previous_model_path = os.path.join(
+                self.config.model.base_checkpoint, "best_model.h5"
+            )
+
+            self.keras_model.load_weights(previous_model_path)
+            logger.info(f"Loaded previous model weights from {previous_model_path}")
+        else:
+            logger.info("Training from scratch")
+
     @property
     def keras_model(self) -> tf.keras.Model:
         """Alias for `self.model.keras_model`."""
@@ -1783,7 +1799,7 @@ def visualize_example(example):
     )
 
 
-def main():
+def main(args: Optional[List[str]] = None):
     """Create CLI for training and run."""
     import argparse
 
@@ -1825,6 +1841,14 @@ def main():
             "specified in the training job config."
         ),
     )
+    parser.add_argument(
+        "--base_checkpoint",
+        type=str,
+        help=(
+            "Path to base checkpoint (directory containing best_model.h5) to resume "
+            "training from."
+        ),
+    )
     parser.add_argument(
         "--tensorboard",
         action="store_true",
@@ -1883,7 +1907,7 @@ def main():
         ),
     )
 
-    args, _ = parser.parse_known_args()
+    args, _ = parser.parse_known_args(args)
 
     # Find job configuration file.
     job_filename = args.training_job_path
@@ -1916,6 +1940,8 @@
     if len(args.video_paths) == 0:
         args.video_paths = None
 
+    job_config.model.base_checkpoint = args.base_checkpoint
+
     logger.info("Versions:")
     sleap.versions()
@@ -1980,6 +2006,8 @@
     )
 
     trainer.train()
 
+    return trainer
+
 
 if __name__ == "__main__":
     main()
diff --git a/tests/gui/learning/test_dialog.py b/tests/gui/learning/test_dialog.py
index a244fd2ed..2bf9e63b3 100644
--- a/tests/gui/learning/test_dialog.py
+++ b/tests/gui/learning/test_dialog.py
@@ -1,15 +1,26 @@
+import shutil
+from typing import Optional, List, Callable, Set
+from pathlib import Path
+
+import cattr
+import pytest
+from qtpy import QtWidgets
+
 from sleap.gui.learning.dialog import LearningDialog, TrainingEditorWidget
-from sleap.gui.learning.configs import TrainingConfigFilesWidget
-from sleap.gui.learning.configs import ConfigFileInfo
+from sleap.gui.learning.configs import (
+    TrainingConfigFilesWidget,
+    ConfigFileInfo,
+    TrainingConfigsGetter,
+)
 from sleap.gui.learning.scopedkeydict import (
     ScopedKeyDict,
     apply_cfg_transforms_to_key_val_dict,
 )
 from sleap.gui.app import MainWindow
+from sleap.io.dataset import Labels
 from sleap.nn.config import TrainingJobConfig, UNetConfig
-
-import cattr
-from pathlib import Path
 
 
 def test_use_hidden_params_from_loaded_config(
@@ -116,3 +127,225 @@ def test_update_loaded_config():
         scoped_cfg.key_val_dict["optimization.augmentation_config.rotation_min_angle"]
         == -180
     )
+
+
+def test_training_editor_checkbox_states(
+    qtbot, tmpdir, min_labels: Labels, min_centroid_model_path: str
+):
+    """Test that Use Trained Model and Resume Training checkboxes operate correctly."""
+
+    def assert_checkbox_states(
+        ted: TrainingEditorWidget,
+        use_trained: Optional[bool] = None,
+        resume_training: Optional[bool] = None,
+    ):
+        if use_trained is not None:
+            assert ted._use_trained_model.isChecked() == use_trained
+        if resume_training is not None:
+            assert ted._resume_training.isChecked() == resume_training
+
+    def switch_states(
+        ted: TrainingEditorWidget,
+        prev_use_trained: Optional[bool] = None,
+        prev_resume_training: Optional[bool] = None,
+        new_use_trained: Optional[bool] = None,
+        new_resume_training: Optional[bool] = None,
+    ):
+        """Switch the states of the checkboxes."""
+
+        # Assert previous checkbox state
+        assert_checkbox_states(
+            ted, use_trained=prev_use_trained, resume_training=prev_resume_training
+        )
+
+        # Switch states
+        if new_use_trained is not None:
+            ted._use_trained_model.setChecked(new_use_trained)
+        if new_resume_training is not None:
+            ted._resume_training.setChecked(new_resume_training)
+
+        # Assert new checkbox state
+        assert_checkbox_states(
+            ted, use_trained=new_use_trained, resume_training=new_resume_training
+        )
+
+    def check_resume_training(
+        ted: TrainingEditorWidget, prev_use_trained: Optional[bool] = None
+    ):
+        """Check the Resume Training checkbox."""
+        switch_states(
+            ted,
+            prev_use_trained=prev_use_trained,
+            new_use_trained=True,
+            new_resume_training=True,
+        )
+        assert not ted.use_trained
+        assert ted.resume_training
+
+    def check_resume_training_00(ted: TrainingEditorWidget):
+        """Check the Resume Training checkbox when Use Trained is unchecked."""
+        check_resume_training(ted, prev_use_trained=False)
+
+    def check_resume_training_10(ted: TrainingEditorWidget):
+        """Check the Resume Training checkbox when Use Trained is checked."""
+        check_resume_training(ted, prev_use_trained=True)
+
+    def check_use_trained(ted: TrainingEditorWidget):
+        """Check the Use Trained checkbox when Resume Training is unchecked."""
+        switch_states(ted, prev_resume_training=False, new_use_trained=True)
+        assert ted.use_trained
+        assert not ted.resume_training
+
+    def uncheck_resume_training(ted: TrainingEditorWidget):
+        """Uncheck the Resume Training checkbox when Use Trained is checked."""
+        switch_states(ted, prev_use_trained=True, new_resume_training=False)
+        assert ted.use_trained
+        assert not ted.resume_training
+
+    def uncheck_use_trained(
+        ted: TrainingEditorWidget, prev_resume_training: Optional[bool] = None
+    ):
+        """Uncheck the Use Trained checkbox."""
+        switch_states(
+            ted,
+            prev_resume_training=prev_resume_training,
+            new_use_trained=False,
+            new_resume_training=False,
+        )
+        assert not ted.use_trained
+        assert not ted.resume_training
+
+    def uncheck_use_trained_10(ted: TrainingEditorWidget):
+        """Uncheck the Use Trained checkbox when Resume Training is unchecked."""
+        uncheck_use_trained(ted, prev_resume_training=False)
+
+    def uncheck_use_trained_11(ted: TrainingEditorWidget):
+        """Uncheck the Use Trained checkbox when Resume Training is checked."""
+        uncheck_use_trained(ted, prev_resume_training=True)
+
+    def assert_form_state(
+        change_state: Callable,
+        ted: TrainingEditorWidget,
+        og_form_data: dict,
+        reset_causing_actions: Set[Callable] = {
+            check_use_trained,
+            uncheck_resume_training,
+        },
+    ):
+        expected_form_data = dict()
+
+        # Read form values before changing state
+        if change_state not in reset_causing_actions:
+            expected_form_data = ted.get_all_form_data()
+
+        # Change state
+        change_state(ted)
+
+        # Modify expected form values depending on state, and check if form is enabled
+        if ted.resume_training:
+            for key, val in og_form_data.items():
+                if key.startswith("model."):
+                    expected_form_data[key] = val
+        elif ted.use_trained:
+            expected_form_data = og_form_data
+
+        # Read form values after changing state
+        actual_form_data = ted.get_all_form_data()
+        assert expected_form_data == actual_form_data
+
+    # Load the data
+    labels: Labels = min_labels
+    video = labels.video
+    skeleton = labels.skeleton
+    model_path = Path(min_centroid_model_path)
+
+    # Spoof an untrained model
+    untrained_model_path = Path(tmpdir, model_path.parts[-1])
+    untrained_model_path.mkdir()
+    shutil.copy(
+        Path(model_path, "training_config.json"),
+        Path(untrained_model_path, "training_config.json"),
+    )
+
+    # Create a TrainingEditorWidget in training mode
+    head_name = (model_path.name).split(".")[-1]
+    mode = "training"
+    cfg_getter = TrainingConfigsGetter(
+        dir_paths=[str(model_path), str(untrained_model_path)], head_filter=head_name
+    )
+    ted = TrainingEditorWidget(
+        video=video,
+        skeleton=skeleton,
+        head=head_name,
+        cfg_getter=cfg_getter,
+        require_trained=(mode == "inference"),
+    )
+    ted.update_file_list()
+
+    og_form_data = ted.get_all_form_data()
+
+    # Modify the form data
+    copy_form_data_everything = og_form_data.copy()
+    copy_form_data_everything["data.labels.validation_fraction"] = 0.3
+    copy_form_data_everything["optimization.augmentation_config.rotate"] = True
+    copy_form_data_everything["optimization.epochs"] = 50
+    copy_form_data_except_model = copy_form_data_everything.copy()
+    copy_form_data_everything["_backbone_name"] = "leap"
+
+    # The action trajectory below should cover the entire state space of the checkboxes
+    action_trajectory: List[Callable] = [
+        check_resume_training_00,
+        uncheck_use_trained_11,
+        check_use_trained,
+        check_resume_training_10,
+        uncheck_resume_training,
+        uncheck_use_trained_10,
+    ]
+
+    actions_that_allow_change_everything_except_model = {
+        check_resume_training_00,
+        check_resume_training_10,
+    }
+    actions_that_allow_change_everything = {
+        uncheck_use_trained_10,
+        uncheck_use_trained_11,
+    }
+    for action in action_trajectory:
+        assert_form_state(action, ted, og_form_data)
+        if action in actions_that_allow_change_everything:
+            ted.set_fields_from_key_val_dict(copy_form_data_everything)
+        elif action in actions_that_allow_change_everything_except_model:
+            ted.set_fields_from_key_val_dict(copy_form_data_except_model)
+
+    # Test the case where the user selects an untrained model
+    ted._cfg_list_widget.setCurrentIndex(1)
+    assert not ted.has_trained_config_selected
+    assert not ted._resume_training.isChecked()
+    assert not ted._resume_training.isVisible()
+    assert not ted._use_trained_model.isVisible()
+    assert not ted._use_trained_model.isChecked()
+    assert not ted.use_trained
+    assert not ted.resume_training
+
+    # Test the case where the user opts to perform inference
+    mode = "inference"
+    ted = TrainingEditorWidget(
+        video=video,
+        skeleton=skeleton,
+        head=head_name,
+        cfg_getter=cfg_getter,
+        require_trained=(mode == "inference"),
+    )
+    ted.update_file_list()
+    assert len(ted._cfg_list_widget._cfg_list) == 1
+    assert ted.has_trained_config_selected
+    assert ted._resume_training is None
+    assert ted._use_trained_model is None
+    assert ted.use_trained
+    assert not ted.resume_training
diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py
index b8f9cc072..b3eda8676 100644
--- a/tests/nn/test_training.py
+++ b/tests/nn/test_training.py
@@ -1,4 +1,7 @@
+from pathlib import Path
+
 import pytest
+
 import sleap
 from sleap.io.dataset import Labels
 from sleap.nn.config.data import LabelsConfig
@@ -19,6 +22,7 @@
     TopdownConfmapsModelTrainer,
     TopDownMultiClassModelTrainer,
     Trainer,
+    main as sleap_train,
 )
 
 sleap.use_cpu_only()
@@ -76,6 +80,49 @@ def test_data_reader(min_labels_slp_path):
     assert data_readers.test_labels_reader.example_indices == [0]
 
 
+def test_train_load_single_instance(
+    min_labels_robot: Labels, cfg: TrainingJobConfig, tmp_path: Path
+):
+    # Set save directory
+    cfg.outputs.run_name = "test_run"
+    cfg.outputs.runs_folder = str(tmp_path / "training_runs")  # ensure it's a string
+    cfg.outputs.save_outputs = True  # enable saving
+    cfg.outputs.checkpointing.latest_model = True  # save latest model
+
+    cfg.model.heads.single_instance = SingleInstanceConfmapsHeadConfig(
+        sigma=1.5, output_stride=1, offset_refinement=False
+    )
+    trainer = SingleInstanceModelTrainer.from_config(
+        cfg, training_labels=min_labels_robot
+    )
+    trainer.setup()
+    trainer.train()
+
+    # Now load a new model and resume from the checkpoint.
+    # Set the model checkpoint folder
+    cfg.model.base_checkpoint = cfg.outputs.run_path
+    # Unset save directory
+    cfg.outputs.run_name = None
+    cfg.outputs.runs_folder = None
+    cfg.outputs.save_outputs = False  # disable saving
+    cfg.outputs.checkpointing.latest_model = False  # disable saving latest model
+
+    trainer2 = SingleInstanceModelTrainer.from_config(
+        cfg, training_labels=min_labels_robot
+    )
+    trainer2.setup()
+
+    # Check that the two models have identical weights
+    for layer, layer2 in zip(trainer.keras_model.layers, trainer2.keras_model.layers):
+        # Grab the weights from the first (trained) model
+        weights = layer.get_weights()
+        # Grab the weights from the second (resumed) model
+        weights2 = layer2.get_weights()
+        # Compare them elementwise
+        for w, w2 in zip(weights, weights2):
+            assert (w == w2).all()
+
+
 def test_train_single_instance(min_labels_robot, cfg):
     cfg.model.heads.single_instance = SingleInstanceConfmapsHeadConfig(
         sigma=1.5, output_stride=1, offset_refinement=False
@@ -267,3 +314,45 @@ def test_train_cropping(
         trainer.config.data.instance_cropping.crop_size % trainer.model.maximum_stride
         == 0
     )
+
+
+def test_resume_training_cli(
+    min_single_instance_robot_model_path: str, small_robot_mp4_path: str, tmp_path: Path
+):
+    """Test CLI to resume training."""
+
+    base_checkpoint_path = min_single_instance_robot_model_path
+    cfg = TrainingJobConfig.load_json(
+        str(Path(base_checkpoint_path, "training_config.json"))
+    )
+    cfg.optimization.preload_data = False
+    cfg.optimization.batch_size = 1
+    cfg.optimization.batches_per_epoch = 2
+    cfg.optimization.epochs = 1
+    cfg.outputs.save_outputs = False
+
+    # Save training config to tmp folder
+    cfg_path = str(Path(tmp_path, "training_config.json"))
+    cfg.save_json(cfg_path)
+
+    # TODO (LM): Stop saving absolute paths in labels files!
+    # We need to do this reload because we save absolute paths (for the video).
+    labels_path = str(Path(base_checkpoint_path, "labels_gt.train.slp"))
+    labels: Labels = sleap.load_file(labels_path, search_paths=[small_robot_mp4_path])
+    labels_path = str(Path(tmp_path, "labels_gt.train.slp"))
+    labels.save_file(labels, labels_path)
+
+    # Run CLI to resume training
+    trainer = sleap_train(
+        [
+            cfg_path,
+            labels_path,
+            "--base_checkpoint",
+            base_checkpoint_path,
+        ]
+    )
+    assert trainer.config.model.base_checkpoint == base_checkpoint_path
+
+    # Run CLI without base checkpoint
+    trainer = sleap_train([cfg_path, labels_path])
+    assert trainer.config.model.base_checkpoint is None
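---

Usage sketch, mirroring `test_resume_training_cli` above: the new flag can be exercised end to end by calling the CLI entry point directly. `training_config.json`, `labels.slp`, and `models/baseline` are placeholder paths; `models/baseline` stands for any previous run directory that contains a `best_model.h5`.

```python
from sleap.nn.training import main as sleap_train

# Equivalent to the shell command:
#   sleap-train training_config.json labels.slp --base_checkpoint models/baseline
trainer = sleap_train(
    ["training_config.json", "labels.slp", "--base_checkpoint", "models/baseline"]
)
assert trainer.config.model.base_checkpoint == "models/baseline"
```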
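The same resume behavior is available from the Python API by setting `model.base_checkpoint` on a training job config before building a trainer, as `test_train_load_single_instance` does. A minimal sketch, again with placeholder paths and using the single-instance trainer from that test:

```python
import sleap
from sleap.nn.config import TrainingJobConfig
from sleap.nn.training import SingleInstanceModelTrainer

labels = sleap.load_file("labels.slp")
cfg = TrainingJobConfig.load_json("models/baseline/training_config.json")

# _setup_model() will load models/baseline/best_model.h5 into the new model.
cfg.model.base_checkpoint = "models/baseline"

trainer = SingleInstanceModelTrainer.from_config(cfg, training_labels=labels)
trainer.setup()  # builds the Keras model and loads the base checkpoint weights
trainer.train()  # resumes optimization from the loaded weights
```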