From bb1bfb1c0375d1f9d1d3b77afefe09db9eb74eff Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Fri, 3 Feb 2023 14:43:51 -0800 Subject: [PATCH 1/8] Refactor/update 'use trained' and 'resume training' checkbox logic --- sleap/gui/learning/dialog.py | 77 +++++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 27 deletions(-) diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 0cff9ae70..0ad4ccd80 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -1028,6 +1028,7 @@ def acceptSelectedConfigInfo(self, cfg_info: configs.ConfigFileInfo): self._use_trained_model.setVisible(has_trained_model) self._use_trained_model.setEnabled(has_trained_model) self._resume_training.setVisible(has_trained_model) + self._resume_training.setEnabled(has_trained_model) self.update_receptive_field() @@ -1068,36 +1069,58 @@ def _load_config(self, cfg_info: configs.ConfigFileInfo): # self._cfg_list_widget.setUserConfigData(cfg_form_data_dict) def _update_use_trained(self, check_state=0): - if self._require_trained: - use_trained = True - else: - use_trained = check_state == QtCore.Qt.CheckState.Checked - - # check if self._resume_training is None, because it might not be initialized yet - resume_training = self._resume_training and self._resume_training.isChecked() - - if self._use_trained_model is not None: - for form in self.form_widgets.values(): - form.set_enabled( - (not self._use_trained_model.isChecked()) - or ( - self._use_trained_model.isChecked() - and self._resume_training.isChecked() - ) - ) + """Update config GUI based on _use_trained_model and _resume_training checkboxes. - if self._use_trained_model.isChecked(): - self._resume_training.setEnabled(True) + This function is called when either _use_trained_model or _resume_training checkbox + is checked/unchecked or when _require_trained is changed. - if resume_training: - self.form_widgets["model"].set_enabled(False) - else: - # change the checkbox to unchecked + If _require_trained is True, then we'll disable all fields. + If _use_trained_model is checked, then we'll disable all fields. + If _resume_training is checked, then we'll disable only the model field. + + Args: + check_state (int, optional): Check state of checkbox. Defaults to 0. Unused. + + Returns: + None + + Side Effects: + Disables/Enables fields based on checkbox values (and _required_training). + """ + + # If we're requiring a trained model, then we don't plan on retraining the model + sender = self.sender() + use_trained = self._require_trained + resume_training = False + + # Either _use_trained_model or _resume_training checkbox value changed + if not self._require_trained: + + # Uncheck _resume_training checkbox if _use_trained_model is unchecked + if ( + self._use_trained_model == sender + and not self._use_trained_model.isChecked() + ): self._resume_training.setChecked(False) - self._resume_training.setEnabled(False) - else: - for form in self.form_widgets.values(): - form.set_enabled(not use_trained) + + # Check _use_trained_model checkbox if _resume_training is checked + elif self._resume_training == sender and self._resume_training.isChecked(): + self._use_trained_model.setChecked(True) + + # Determine what to do with the form widgets + use_trained = ( + self._use_trained_model.isChecked() + and not self._resume_training.isChecked() + ) + resume_training = ( + self._resume_training and self._resume_training.isChecked() + ) + + # Update form widgets + for form in self.form_widgets.values(): + form.set_enabled(not use_trained) + if resume_training: + self.form_widgets["model"].set_enabled(False) # If user wants to use trained model, then reset form to match config if use_trained and self._cfg_list_widget and (not resume_training): From 7be9fb2f19c8fdd49b8041294e4d72ffef7d2913 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Tue, 7 Feb 2023 19:16:23 -0800 Subject: [PATCH 2/8] Simplify checkbox logic and reset model field when resume training --- sleap/gui/app.py | 13 +++++ sleap/gui/learning/dialog.py | 93 +++++++++++++++++++----------------- 2 files changed, 63 insertions(+), 43 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 4a3f6da95..273b8537a 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1835,3 +1835,16 @@ def main(args: Optional[list] = None): app.exec_() pass + + +if __name__ == "__main__": + import os + + ds = os.environ["ds-dmc"] + + app = QApplication([]) + + window = MainWindow(labels_path=ds, no_usage_data=True) + window._show_learning_dialog("training") + + app.exec_() diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 0ad4ccd80..46e18361f 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -75,7 +75,7 @@ def __init__( self.current_pipeline = "" - self.tabs = dict() + self.tabs: Dict[str, TrainingEditorWidget] = dict() self.shown_tab_names = [] self._cfg_getter = configs.TrainingConfigsGetter.make_from_labels_filename( @@ -907,7 +907,7 @@ def __init__( yaml_name = "training_editor_form" - self.form_widgets = dict() + self.form_widgets: Dict[str, YamlFormWidget] = dict() for key in ("model", "data", "augmentation", "optimization", "outputs"): self.form_widgets[key] = YamlFormWidget.from_name( @@ -1024,9 +1024,11 @@ def acceptSelectedConfigInfo(self, cfg_info: configs.ConfigFileInfo): self._load_config(cfg_info) has_trained_model = cfg_info.has_trained_model - if self._use_trained_model: + if self._use_trained_model is not None: self._use_trained_model.setVisible(has_trained_model) self._use_trained_model.setEnabled(has_trained_model) + # Redundant check (for readability) since this checkbox exists if the above does + if self._resume_training is not None: self._resume_training.setVisible(has_trained_model) self._resume_training.setEnabled(has_trained_model) @@ -1088,42 +1090,44 @@ def _update_use_trained(self, check_state=0): Disables/Enables fields based on checkbox values (and _required_training). """ - # If we're requiring a trained model, then we don't plan on retraining the model + # Check which checkbox changed its value (if any) sender = self.sender() - use_trained = self._require_trained - resume_training = False - # Either _use_trained_model or _resume_training checkbox value changed - if not self._require_trained: - - # Uncheck _resume_training checkbox if _use_trained_model is unchecked - if ( - self._use_trained_model == sender - and not self._use_trained_model.isChecked() - ): - self._resume_training.setChecked(False) - - # Check _use_trained_model checkbox if _resume_training is checked - elif self._resume_training == sender and self._resume_training.isChecked(): - self._use_trained_model.setChecked(True) + # Uncheck _resume_training checkbox if _use_trained_model is unchecked + if (sender == self._use_trained_model) and ( + not self._use_trained_model.isChecked() + ): + self._resume_training.setChecked(False) - # Determine what to do with the form widgets - use_trained = ( - self._use_trained_model.isChecked() - and not self._resume_training.isChecked() - ) - resume_training = ( - self._resume_training and self._resume_training.isChecked() - ) + # Check _use_trained_model checkbox if _resume_training is checked + elif (sender == self._resume_training) and self._resume_training.isChecked(): + self._use_trained_model.setChecked(True) # Update form widgets + + use_trained_params = self.use_trained + use_model_params = self.resume_training for form in self.form_widgets.values(): - form.set_enabled(not use_trained) - if resume_training: + form.set_enabled(not use_trained_params) + + if use_trained_params or use_model_params: + cfg_info = self._cfg_list_widget.getSelectedConfigInfo() + + # If user wants to resume training, then reset only model form to match config + if use_model_params: self.form_widgets["model"].set_enabled(False) - # If user wants to use trained model, then reset form to match config - if use_trained and self._cfg_list_widget and (not resume_training): + # Set model form to match config + cfg = cfg_info.config + cfg_dict = cattr.unstructure(cfg) + model_dict = {"model": cfg_dict["model"]} + key_val_dict = scopedkeydict.ScopedKeyDict.from_hierarchical_dict( + model_dict + ).key_val_dict + self.set_fields_from_key_val_dict(key_val_dict) + + # If user wants to use trained model, then reset entire form to match config + if use_trained_params: cfg_info = self._cfg_list_widget.getSelectedConfigInfo() self._load_config(cfg_info) @@ -1154,19 +1158,23 @@ def _set_backbone_from_key_val_dict(self, cfg_key_val_dict): @property def use_trained(self) -> bool: - use_trained = False if self._require_trained: - use_trained = True - elif self._use_trained_model and self._use_trained_model.isChecked(): - use_trained = True - return use_trained + return True + + if ( + (self._use_trained_model is not None) + and self._use_trained_model.isChecked() + and (not self.resume_training) + ): + return True + + return False @property def resume_training(self) -> bool: - resume_training = False - if self._resume_training and self._resume_training.isChecked(): - resume_training = True - return resume_training + if (self._resume_training is not None) and self._resume_training.isChecked(): + return True + return False @property def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]: @@ -1175,8 +1183,7 @@ def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]: return None if self.use_trained: - if not self.resume_training: - trained_config_info.dont_retrain = True + trained_config_info.dont_retrain = True else: # Set certain parameters to defaults trained_config = trained_config_info.config @@ -1186,7 +1193,7 @@ def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]: trained_config.outputs.run_name_suffix = None if self.resume_training: - # get the folder path of trained_config_info.path and set it as the output folder + # Get the folder path of trained config and set it as the output folder trained_config_info.config.model.base_checkpoint = os.path.dirname( trained_config_info.path ) From 021bab6278331703ffd21af25df927d0f04f1b83 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Wed, 8 Feb 2023 12:31:57 -0800 Subject: [PATCH 3/8] Reset checkboxes upon changing config selection --- sleap/gui/learning/dialog.py | 66 +++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 46e18361f..ce4a10da0 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -14,7 +14,7 @@ from sleap.gui.dialogs.formbuilder import YamlFormWidget from sleap.gui.learning import runners, scopedkeydict, configs, datagen, receptivefield -from typing import Dict, List, Optional, Text, Optional +from typing import Dict, List, Optional, Text, Optional, cast from qtpy import QtWidgets, QtCore @@ -962,11 +962,10 @@ def __init__( # If we have an object which gets a list of config files, # then we'll show a menu to allow selection from the list. - - if self._cfg_getter: + if self._cfg_getter is not None: self._cfg_list_widget = configs.TrainingConfigFilesWidget( cfg_getter=self._cfg_getter, - head_name=head, + head_name=cast(str, head), # Expect head to be a string require_trained=require_trained, ) self._cfg_list_widget.onConfigSelection.connect( @@ -976,25 +975,22 @@ def __init__( layout.addWidget(self._cfg_list_widget) - # Add option for using trained model from selected config - if self._require_trained: - self._update_use_trained() - else: - self._use_trained_model = QtWidgets.QCheckBox("Use Trained Model") - self._use_trained_model.setEnabled(False) - self._use_trained_model.setVisible(False) - self._resume_training = QtWidgets.QCheckBox("Resume Training") - self._resume_training.setEnabled(False) - self._resume_training.setVisible(False) - - self._use_trained_model.stateChanged.connect(self._update_use_trained) - self._resume_training.stateChanged.connect(self._update_use_trained) + if self._require_trained: + self._update_use_trained() + elif self._cfg_list_widget is not None: + # Add option for using trained model from selected config file + self._use_trained_model = QtWidgets.QCheckBox("Use Trained Model") + self._use_trained_model.setEnabled(False) + self._use_trained_model.setVisible(False) + self._resume_training = QtWidgets.QCheckBox("Resume Training") + self._resume_training.setEnabled(False) + self._resume_training.setVisible(False) - layout.addWidget(self._use_trained_model) - layout.addWidget(self._resume_training) + self._use_trained_model.stateChanged.connect(self._update_use_trained) + self._resume_training.stateChanged.connect(self._update_use_trained) - elif self._require_trained: - self._update_use_trained() + layout.addWidget(self._use_trained_model) + layout.addWidget(self._resume_training) layout.addWidget(self._layout_widget(col_layout)) self.setLayout(layout) @@ -1025,10 +1021,12 @@ def acceptSelectedConfigInfo(self, cfg_info: configs.ConfigFileInfo): has_trained_model = cfg_info.has_trained_model if self._use_trained_model is not None: + self._use_trained_model.setChecked(self._require_trained) self._use_trained_model.setVisible(has_trained_model) self._use_trained_model.setEnabled(has_trained_model) # Redundant check (for readability) since this checkbox exists if the above does if self._resume_training is not None: + self._use_trained_model.setChecked(False) self._resume_training.setVisible(has_trained_model) self._resume_training.setEnabled(has_trained_model) @@ -1104,7 +1102,6 @@ def _update_use_trained(self, check_state=0): self._use_trained_model.setChecked(True) # Update form widgets - use_trained_params = self.use_trained use_model_params = self.resume_training for form in self.form_widgets.values(): @@ -1128,7 +1125,6 @@ def _update_use_trained(self, check_state=0): # If user wants to use trained model, then reset entire form to match config if use_trained_params: - cfg_info = self._cfg_list_widget.getSelectedConfigInfo() self._load_config(cfg_info) self._set_head() @@ -1158,10 +1154,7 @@ def _set_backbone_from_key_val_dict(self, cfg_key_val_dict): @property def use_trained(self) -> bool: - if self._require_trained: - return True - - if ( + if self._require_trained or ( (self._use_trained_model is not None) and self._use_trained_model.isChecked() and (not self.resume_training) @@ -1178,8 +1171,15 @@ def resume_training(self) -> bool: @property def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]: - trained_config_info = self._cfg_list_widget.getSelectedConfigInfo() - if trained_config_info is None: + # If `TrainingEditorWidget` was initialized with a config getter, then + # we expect to have a list of config files + if self._cfg_list_widget is None: + return None + + trained_config_info: Optional[ + configs.ConfigFileInfo + ] = self._cfg_list_widget.getSelectedConfigInfo() + if (trained_config_info is None) or (not trained_config_info.has_trained_model): return None if self.use_trained: @@ -1194,8 +1194,8 @@ def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]: if self.resume_training: # Get the folder path of trained config and set it as the output folder - trained_config_info.config.model.base_checkpoint = os.path.dirname( - trained_config_info.path + trained_config_info.config.model.base_checkpoint = str( + Path(cast(str, trained_config_info.path)).parent ) else: trained_config_info.config.model.base_checkpoint = None @@ -1204,9 +1204,13 @@ def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]: @property def has_trained_config_selected(self) -> bool: + if self._cfg_list_widget is None: + return False + cfg_info = self._cfg_list_widget.getSelectedConfigInfo() if cfg_info and cfg_info.has_trained_model: return True + return False def get_all_form_data(self) -> dict: From 2f2d6c3faa93c70cfd151c851f3bad5f058ff8e1 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Wed, 8 Feb 2023 13:02:20 -0800 Subject: [PATCH 4/8] Handle case for updating TrainingEditor when sender is not a checkbox --- sleap/gui/app.py | 2 +- sleap/gui/learning/dialog.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 273b8537a..3c333d19b 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1845,6 +1845,6 @@ def main(args: Optional[list] = None): app = QApplication([]) window = MainWindow(labels_path=ds, no_usage_data=True) - window._show_learning_dialog("training") + window._show_learning_dialog("inference") app.exec_() diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index ce4a10da0..f65bf9f85 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -1091,8 +1091,10 @@ def _update_use_trained(self, check_state=0): # Check which checkbox changed its value (if any) sender = self.sender() + if sender is None: # If sender is None, then _required_training is True + pass # Uncheck _resume_training checkbox if _use_trained_model is unchecked - if (sender == self._use_trained_model) and ( + elif (sender == self._use_trained_model) and ( not self._use_trained_model.isChecked() ): self._resume_training.setChecked(False) From 1895bb9cda5993bba5e66ed7bc53236a802ae7cc Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Wed, 8 Feb 2023 19:02:59 -0800 Subject: [PATCH 5/8] Add complete state space GUI test for checkboxes --- sleap/gui/app.py | 1 + tests/gui/learning/test_dialog.py | 199 +++++++++++++++++++++++++++++- 2 files changed, 195 insertions(+), 5 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 3c333d19b..da8f6d8bf 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1837,6 +1837,7 @@ def main(args: Optional[list] = None): pass +# TODO (LM): Remove after testing if __name__ == "__main__": import os diff --git a/tests/gui/learning/test_dialog.py b/tests/gui/learning/test_dialog.py index a244fd2ed..737a7b23c 100644 --- a/tests/gui/learning/test_dialog.py +++ b/tests/gui/learning/test_dialog.py @@ -1,15 +1,26 @@ +import shutil +from typing import Optional, List, Callable, Set +from pathlib import Path +import traceback + +import cattr +import pytest +from qtpy import QtWidgets + from sleap.gui.learning.dialog import LearningDialog, TrainingEditorWidget -from sleap.gui.learning.configs import TrainingConfigFilesWidget -from sleap.gui.learning.configs import ConfigFileInfo +from sleap.gui.learning.configs import ( + TrainingConfigFilesWidget, + ConfigFileInfo, + TrainingConfigsGetter, +) from sleap.gui.learning.scopedkeydict import ( ScopedKeyDict, apply_cfg_transforms_to_key_val_dict, ) from sleap.gui.app import MainWindow +from sleap.io.dataset import Labels from sleap.nn.config import TrainingJobConfig, UNetConfig - -import cattr -from pathlib import Path +from sleap.util import get_package_file def test_use_hidden_params_from_loaded_config( @@ -116,3 +127,181 @@ def test_update_loaded_config(): scoped_cfg.key_val_dict["optimization.augmentation_config.rotation_min_angle"] == -180 ) + + +# Parameters: LearningDialog type (training, inference), config type (trained, untrained) +def test_training_editor_checkbox_states( + qtbot, tmpdir, min_labels: Labels, min_centroid_model_path: str +): + """Test that Use Trained Model and Resume Training checkboxes operate correctly.""" + + def assert_checkbox_states( + ted: TrainingEditorWidget, + use_trained: Optional[bool] = None, + resume_training: Optional[bool] = None, + ): + assert ( + ted._use_trained_model.isChecked() == use_trained + if use_trained is not None + else True + ) + assert ( + ted._resume_training.isChecked() == resume_training + if resume_training is not None + else True + ) + + def switch_states( + ted: TrainingEditorWidget, + prev_use_trained: Optional[bool] = None, + prev_resume_training: Optional[bool] = None, + new_use_trained: Optional[bool] = None, + new_resume_training: Optional[bool] = None, + ): + """Switch the states of the checkboxes.""" + + # Assert previous checkbox state + assert_checkbox_states( + ted, use_trained=prev_use_trained, resume_training=prev_resume_training + ) + + # Switch states + if new_use_trained is not None: + ted._use_trained_model.setChecked(new_use_trained) + if new_resume_training is not None: + ted._resume_training.setChecked(new_resume_training) + + # Assert new checkbox state + assert_checkbox_states( + ted, use_trained=new_use_trained, resume_training=new_resume_training + ) + + def check_resume_training( + ted: TrainingEditorWidget, prev_use_trained: Optional[bool] = None + ): + """Check the Resume Training checkbox.""" + switch_states( + ted, + prev_use_trained=prev_use_trained, + new_use_trained=True, + new_resume_training=True, + ) + assert not ted.use_trained + assert ted.resume_training + + def check_resume_training_00(ted: TrainingEditorWidget): + """Check the Resume Training checkbox when Use Trained is unchecked.""" + check_resume_training(ted, prev_use_trained=False) + + def check_resume_training_10(ted: TrainingEditorWidget): + """Check the Resume Training checkbox when Use Trained is checked.""" + check_resume_training(ted, prev_use_trained=True) + + def check_use_trained(ted: TrainingEditorWidget): + """Check the Use Trained checkbox when Resume Training is unchecked.""" + switch_states(ted, prev_resume_training=False, new_use_trained=True) + assert ted.use_trained + assert not ted.resume_training + + def uncheck_resume_training(ted: TrainingEditorWidget): + """Uncheck the Resume Training checkbox when Use Trained is checked.""" + switch_states(ted, prev_use_trained=True, new_resume_training=False) + assert ted.use_trained + assert not ted.resume_training + + def uncheck_use_trained( + ted: TrainingEditorWidget, prev_resume_training: Optional[bool] = None + ): + """Uncheck the Use Trained checkbox.""" + switch_states( + ted, + prev_resume_training=prev_resume_training, + new_use_trained=False, + new_resume_training=False, + ) + assert not ted.use_trained + assert not ted.resume_training + + def uncheck_use_trained_10(ted: TrainingEditorWidget): + """Uncheck the Use Trained checkbox when Resume Training is unchecked.""" + uncheck_use_trained(ted, prev_resume_training=False) + + def uncheck_use_trained_11(ted: TrainingEditorWidget): + """Uncheck the Use Trained checkbox when Resume Training is checked.""" + uncheck_use_trained(ted, prev_resume_training=True) + + def assert_form_state( + change_state: Callable, + ted: TrainingEditorWidget, + og_form_data: dict, + reset_causing_actions: Set[Callable] = { + check_use_trained, + uncheck_resume_training, + }, + ): + expected_form_data = dict() + actual_form_data = dict() + + # Read form values before changing state + if change_state not in reset_causing_actions: + for key in ted.form_widgets.keys(): + expected_form_data[key] = ted.form_widgets[key].get_form_data() + + # Change state + change_state(ted) + + # Modify expected form values depending on state, and check if form is enabled + if ted.resume_training: + expected_form_data["model"] = og_form_data["model"] + elif ted.use_trained: + expected_form_data = og_form_data + + # Read form values after changing state + for key in ted.form_widgets.keys(): + actual_form_data[key] = ted.form_widgets[key].get_form_data() + assert expected_form_data == actual_form_data + + # Load the data + labels: Labels = min_labels + video = labels.video + skeleton = labels.skeleton + model_path = Path(min_centroid_model_path) + head_name = (model_path.name).split(".")[-1] + mode = "training" + + # Create a training TrainingEditorWidget + cfg_getter = TrainingConfigsGetter( + dir_paths=[str(model_path)], head_filter=head_name + ) + ted = TrainingEditorWidget( + video=video, + skeleton=skeleton, + head=head_name, + cfg_getter=cfg_getter, + require_trained=(mode == "inference"), + ) + ted.update_file_list() + + og_form_data = dict() + for key in ted.form_widgets.keys(): + og_form_data[key] = ted.form_widgets[key].get_form_data() + + # The action trajectory below should cover the entire state space of the checkboxes + action_trajectory: List[Callable] = [ + check_resume_training_00, + uncheck_use_trained_11, + check_use_trained, + check_resume_training_10, + uncheck_resume_training, + uncheck_use_trained_10, + ] + for action in action_trajectory: + assert_form_state(action, ted, og_form_data) + + # TODO (LM): Add test for when an untrained model is selected (check that boxes are unchecked) + # TODO (LM): Add test for when mode is inference (check that boxes are unchecked) + + +# TODO (LM): Remove after testing +if __name__ == "__main__": + pytest.main([f"{__file__}::test_training_editor_checkbox_states", "-vv", "-rP"]) From f330c1daab98ec592ef0c96103043f5aba7128be Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 9 Feb 2023 11:48:52 -0800 Subject: [PATCH 6/8] Finish combobox test --- tests/gui/learning/test_dialog.py | 48 ++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/tests/gui/learning/test_dialog.py b/tests/gui/learning/test_dialog.py index 737a7b23c..4616f5e38 100644 --- a/tests/gui/learning/test_dialog.py +++ b/tests/gui/learning/test_dialog.py @@ -129,7 +129,6 @@ def test_update_loaded_config(): ) -# Parameters: LearningDialog type (training, inference), config type (trained, untrained) def test_training_editor_checkbox_states( qtbot, tmpdir, min_labels: Labels, min_centroid_model_path: str ): @@ -266,12 +265,20 @@ def assert_form_state( video = labels.video skeleton = labels.skeleton model_path = Path(min_centroid_model_path) - head_name = (model_path.name).split(".")[-1] - mode = "training" + + # Spoof an untrained model + untrained_model_path = Path(tmpdir, model_path.parts[-1]) + untrained_model_path.mkdir() + shutil.copy( + Path(model_path, "training_config.json"), + Path(untrained_model_path, "training_config.json"), + ) # Create a training TrainingEditorWidget + head_name = (model_path.name).split(".")[-1] + mode = "training" cfg_getter = TrainingConfigsGetter( - dir_paths=[str(model_path)], head_filter=head_name + dir_paths=[str(model_path), str(untrained_model_path)], head_filter=head_name ) ted = TrainingEditorWidget( video=video, @@ -298,10 +305,29 @@ def assert_form_state( for action in action_trajectory: assert_form_state(action, ted, og_form_data) - # TODO (LM): Add test for when an untrained model is selected (check that boxes are unchecked) - # TODO (LM): Add test for when mode is inference (check that boxes are unchecked) - - -# TODO (LM): Remove after testing -if __name__ == "__main__": - pytest.main([f"{__file__}::test_training_editor_checkbox_states", "-vv", "-rP"]) + # Test the case where the user selectes untrained model + ted._cfg_list_widget.setCurrentIndex(1) + assert not ted.has_trained_config_selected + assert not ted._resume_training.isChecked() + assert not ted._resume_training.isVisible() + assert not ted._use_trained_model.isVisible() + assert not ted._use_trained_model.isChecked() + assert not ted.use_trained + assert not ted.resume_training + + # Test the case where the user opts to perform inference + mode = "inference" + ted = TrainingEditorWidget( + video=video, + skeleton=skeleton, + head=head_name, + cfg_getter=cfg_getter, + require_trained=(mode == "inference"), + ) + ted.update_file_list() + assert len(ted._cfg_list_widget._cfg_list) == 1 + assert ted.has_trained_config_selected + assert ted._resume_training is None + assert ted._use_trained_model is None + assert ted.use_trained + assert not ted.resume_training From 22af25dff03f32e11df209d77a9dc7dfd72cbc9f Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 9 Feb 2023 12:39:15 -0800 Subject: [PATCH 7/8] Test that form is reset --- tests/gui/learning/test_dialog.py | 36 +++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/tests/gui/learning/test_dialog.py b/tests/gui/learning/test_dialog.py index 4616f5e38..2bf9e63b3 100644 --- a/tests/gui/learning/test_dialog.py +++ b/tests/gui/learning/test_dialog.py @@ -239,25 +239,24 @@ def assert_form_state( }, ): expected_form_data = dict() - actual_form_data = dict() # Read form values before changing state if change_state not in reset_causing_actions: - for key in ted.form_widgets.keys(): - expected_form_data[key] = ted.form_widgets[key].get_form_data() + expected_form_data = ted.get_all_form_data() # Change state change_state(ted) # Modify expected form values depending on state, and check if form is enabled if ted.resume_training: - expected_form_data["model"] = og_form_data["model"] + for key, val in og_form_data.items(): + if key.startswith("model."): + expected_form_data[key] = val elif ted.use_trained: expected_form_data = og_form_data # Read form values after changing state - for key in ted.form_widgets.keys(): - actual_form_data[key] = ted.form_widgets[key].get_form_data() + actual_form_data = ted.get_all_form_data() assert expected_form_data == actual_form_data # Load the data @@ -289,9 +288,15 @@ def assert_form_state( ) ted.update_file_list() - og_form_data = dict() - for key in ted.form_widgets.keys(): - og_form_data[key] = ted.form_widgets[key].get_form_data() + og_form_data = ted.get_all_form_data() + + # Modify the form data + copy_form_data_everything = og_form_data.copy() + copy_form_data_everything["data.labels.validation_fraction"] = 0.3 + copy_form_data_everything["optimization.augmentation_config.rotate"] = True + copy_form_data_everything["optimization.epochs"] = 50 + copy_form_data_except_model = copy_form_data_everything.copy() + copy_form_data_everything["_backbone_name"] = "leap" # The action trajectory below should cover the entire state space of the checkboxes action_trajectory: List[Callable] = [ @@ -302,8 +307,21 @@ def assert_form_state( uncheck_resume_training, uncheck_use_trained_10, ] + + actions_that_allow_change_everything_except_model = { + check_resume_training_00, + check_resume_training_10, + } + actions_that_allow_change_everything = { + uncheck_use_trained_10, + uncheck_use_trained_10, + } for action in action_trajectory: assert_form_state(action, ted, og_form_data) + if action in actions_that_allow_change_everything: + ted.set_fields_from_key_val_dict(copy_form_data_everything) + elif action in actions_that_allow_change_everything_except_model: + ted.set_fields_from_key_val_dict(copy_form_data_except_model) # Test the case where the user selectes untrained model ted._cfg_list_widget.setCurrentIndex(1) From aaa0f3b1e373dfcb98c4721f739328de423fc1d2 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 9 Feb 2023 12:43:00 -0800 Subject: [PATCH 8/8] Remove straggling TODO --- sleap/gui/learning/dialog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index f65bf9f85..a8ea86015 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -628,7 +628,7 @@ def run(self): """Run with current dialog settings.""" pipeline_form_data = self.pipeline_form_widget.get_form_data() - # TODO: debug ^^^ + items_for_inference = self.get_items_for_inference(pipeline_form_data) config_info_list = self.get_every_head_config_data(pipeline_form_data)