Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add finishing touches to resumable training PR (#1150) #1168

Merged
merged 8 commits into from
Feb 9, 2023
14 changes: 14 additions & 0 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1835,3 +1835,17 @@ def main(args: Optional[list] = None):
app.exec_()

pass


# TODO (LM): Remove after testing
if __name__ == "__main__":
import os

ds = os.environ["ds-dmc"]

app = QApplication([])

window = MainWindow(labels_path=ds, no_usage_data=True)
window._show_learning_dialog("inference")

app.exec_()
176 changes: 106 additions & 70 deletions sleap/gui/learning/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sleap.gui.dialogs.formbuilder import YamlFormWidget
from sleap.gui.learning import runners, scopedkeydict, configs, datagen, receptivefield

from typing import Dict, List, Optional, Text, Optional
from typing import Dict, List, Optional, Text, Optional, cast

from qtpy import QtWidgets, QtCore

Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(

self.current_pipeline = ""

self.tabs = dict()
self.tabs: Dict[str, TrainingEditorWidget] = dict()
self.shown_tab_names = []

self._cfg_getter = configs.TrainingConfigsGetter.make_from_labels_filename(
Expand Down Expand Up @@ -628,7 +628,7 @@ def run(self):
"""Run with current dialog settings."""

pipeline_form_data = self.pipeline_form_widget.get_form_data()
# TODO: debug ^^^

items_for_inference = self.get_items_for_inference(pipeline_form_data)

config_info_list = self.get_every_head_config_data(pipeline_form_data)
Expand Down Expand Up @@ -907,7 +907,7 @@ def __init__(

yaml_name = "training_editor_form"

self.form_widgets = dict()
self.form_widgets: Dict[str, YamlFormWidget] = dict()

for key in ("model", "data", "augmentation", "optimization", "outputs"):
self.form_widgets[key] = YamlFormWidget.from_name(
Expand Down Expand Up @@ -962,11 +962,10 @@ def __init__(

# If we have an object which gets a list of config files,
# then we'll show a menu to allow selection from the list.

if self._cfg_getter:
if self._cfg_getter is not None:
self._cfg_list_widget = configs.TrainingConfigFilesWidget(
cfg_getter=self._cfg_getter,
head_name=head,
head_name=cast(str, head), # Expect head to be a string
require_trained=require_trained,
)
self._cfg_list_widget.onConfigSelection.connect(
Expand All @@ -976,25 +975,22 @@ def __init__(

layout.addWidget(self._cfg_list_widget)

# Add option for using trained model from selected config
if self._require_trained:
self._update_use_trained()
else:
self._use_trained_model = QtWidgets.QCheckBox("Use Trained Model")
self._use_trained_model.setEnabled(False)
self._use_trained_model.setVisible(False)
self._resume_training = QtWidgets.QCheckBox("Resume Training")
self._resume_training.setEnabled(False)
self._resume_training.setVisible(False)

self._use_trained_model.stateChanged.connect(self._update_use_trained)
self._resume_training.stateChanged.connect(self._update_use_trained)
if self._require_trained:
self._update_use_trained()
elif self._cfg_list_widget is not None:
# Add option for using trained model from selected config file
self._use_trained_model = QtWidgets.QCheckBox("Use Trained Model")
self._use_trained_model.setEnabled(False)
self._use_trained_model.setVisible(False)
self._resume_training = QtWidgets.QCheckBox("Resume Training")
self._resume_training.setEnabled(False)
self._resume_training.setVisible(False)

layout.addWidget(self._use_trained_model)
layout.addWidget(self._resume_training)
self._use_trained_model.stateChanged.connect(self._update_use_trained)
self._resume_training.stateChanged.connect(self._update_use_trained)

elif self._require_trained:
self._update_use_trained()
layout.addWidget(self._use_trained_model)
layout.addWidget(self._resume_training)

layout.addWidget(self._layout_widget(col_layout))
self.setLayout(layout)
Expand Down Expand Up @@ -1024,10 +1020,15 @@ def acceptSelectedConfigInfo(self, cfg_info: configs.ConfigFileInfo):
self._load_config(cfg_info)

has_trained_model = cfg_info.has_trained_model
if self._use_trained_model:
if self._use_trained_model is not None:
self._use_trained_model.setChecked(self._require_trained)
self._use_trained_model.setVisible(has_trained_model)
self._use_trained_model.setEnabled(has_trained_model)
# Redundant check (for readability) since this checkbox exists if the above does
if self._resume_training is not None:
self._use_trained_model.setChecked(False)
self._resume_training.setVisible(has_trained_model)
self._resume_training.setEnabled(has_trained_model)

self.update_receptive_field()

Expand Down Expand Up @@ -1068,40 +1069,64 @@ def _load_config(self, cfg_info: configs.ConfigFileInfo):
# self._cfg_list_widget.setUserConfigData(cfg_form_data_dict)

def _update_use_trained(self, check_state=0):
if self._require_trained:
use_trained = True
else:
use_trained = check_state == QtCore.Qt.CheckState.Checked
"""Update config GUI based on _use_trained_model and _resume_training checkboxes.

# check if self._resume_training is None, because it might not be initialized yet
resume_training = self._resume_training and self._resume_training.isChecked()
This function is called when either _use_trained_model or _resume_training checkbox
is checked/unchecked or when _require_trained is changed.

if self._use_trained_model is not None:
for form in self.form_widgets.values():
form.set_enabled(
(not self._use_trained_model.isChecked())
or (
self._use_trained_model.isChecked()
and self._resume_training.isChecked()
)
)
If _require_trained is True, then we'll disable all fields.
If _use_trained_model is checked, then we'll disable all fields.
If _resume_training is checked, then we'll disable only the model field.

if self._use_trained_model.isChecked():
self._resume_training.setEnabled(True)
Args:
check_state (int, optional): Check state of checkbox. Defaults to 0. Unused.

if resume_training:
self.form_widgets["model"].set_enabled(False)
else:
# change the checkbox to unchecked
self._resume_training.setChecked(False)
self._resume_training.setEnabled(False)
else:
for form in self.form_widgets.values():
form.set_enabled(not use_trained)
Returns:
None

Side Effects:
Disables/Enables fields based on checkbox values (and _required_training).
"""

# If user wants to use trained model, then reset form to match config
if use_trained and self._cfg_list_widget and (not resume_training):
# Check which checkbox changed its value (if any)
sender = self.sender()

if sender is None: # If sender is None, then _required_training is True
pass
# Uncheck _resume_training checkbox if _use_trained_model is unchecked
elif (sender == self._use_trained_model) and (
not self._use_trained_model.isChecked()
):
self._resume_training.setChecked(False)

# Check _use_trained_model checkbox if _resume_training is checked
elif (sender == self._resume_training) and self._resume_training.isChecked():
self._use_trained_model.setChecked(True)

# Update form widgets
use_trained_params = self.use_trained
use_model_params = self.resume_training
for form in self.form_widgets.values():
form.set_enabled(not use_trained_params)

if use_trained_params or use_model_params:
cfg_info = self._cfg_list_widget.getSelectedConfigInfo()

# If user wants to resume training, then reset only model form to match config
if use_model_params:
self.form_widgets["model"].set_enabled(False)

# Set model form to match config
cfg = cfg_info.config
cfg_dict = cattr.unstructure(cfg)
model_dict = {"model": cfg_dict["model"]}
key_val_dict = scopedkeydict.ScopedKeyDict.from_hierarchical_dict(
model_dict
).key_val_dict
self.set_fields_from_key_val_dict(key_val_dict)

# If user wants to use trained model, then reset entire form to match config
if use_trained_params:
self._load_config(cfg_info)

self._set_head()
Expand Down Expand Up @@ -1131,29 +1156,36 @@ def _set_backbone_from_key_val_dict(self, cfg_key_val_dict):

@property
def use_trained(self) -> bool:
use_trained = False
if self._require_trained:
use_trained = True
elif self._use_trained_model and self._use_trained_model.isChecked():
use_trained = True
return use_trained
if self._require_trained or (
(self._use_trained_model is not None)
and self._use_trained_model.isChecked()
and (not self.resume_training)
):
return True

return False

@property
def resume_training(self) -> bool:
resume_training = False
if self._resume_training and self._resume_training.isChecked():
resume_training = True
return resume_training
if (self._resume_training is not None) and self._resume_training.isChecked():
return True
return False

@property
def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]:
trained_config_info = self._cfg_list_widget.getSelectedConfigInfo()
if trained_config_info is None:
# If `TrainingEditorWidget` was initialized with a config getter, then
# we expect to have a list of config files
if self._cfg_list_widget is None:
return None

trained_config_info: Optional[
configs.ConfigFileInfo
] = self._cfg_list_widget.getSelectedConfigInfo()
if (trained_config_info is None) or (not trained_config_info.has_trained_model):
return None

if self.use_trained:
if not self.resume_training:
trained_config_info.dont_retrain = True
trained_config_info.dont_retrain = True
else:
# Set certain parameters to defaults
trained_config = trained_config_info.config
Expand All @@ -1163,9 +1195,9 @@ def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]:
trained_config.outputs.run_name_suffix = None

if self.resume_training:
# get the folder path of trained_config_info.path and set it as the output folder
trained_config_info.config.model.base_checkpoint = os.path.dirname(
trained_config_info.path
# Get the folder path of trained config and set it as the output folder
trained_config_info.config.model.base_checkpoint = str(
Path(cast(str, trained_config_info.path)).parent
)
else:
trained_config_info.config.model.base_checkpoint = None
Expand All @@ -1174,9 +1206,13 @@ def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]:

@property
def has_trained_config_selected(self) -> bool:
if self._cfg_list_widget is None:
return False

cfg_info = self._cfg_list_widget.getSelectedConfigInfo()
if cfg_info and cfg_info.has_trained_model:
return True

return False

def get_all_form_data(self) -> dict:
Expand Down
Loading