Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finish-up last resumable training PR #1165

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,9 @@ optional arguments:
only a single --output argument was specified, the
analysis file for the latter video is given a default name.
--format FORMAT Output format. Default ('slp') is SLEAP dataset;
'analysis' results in analysis.h5 file; 'h5' or 'json'
results in SLEAP dataset with specified file format.
'analysis' results in analysis.h5 file; 'analysis.nix' results
in an analysis nix file; 'h5' or 'json' results in SLEAP dataset
with specified file format.
--video VIDEO Path to video (if needed for conversion).
```

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ rich==10.16.1
certifi>=2017.4.17,<=2021.10.8
pynwb
ndx-pose
nixio>=1.5.3
16 changes: 12 additions & 4 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ def _create_video_player(self):
self.player = QtVideoPlayer(
color_manager=self.color_manager, state=self.state, context=self.commands
)
self.player.changedPlot.connect(self._after_plot_update)
self.player.changedPlot.connect(self._after_plot_change)
self.player.updatedPlot.connect(self._after_plot_update)

self.player.view.instanceDoubleClicked.connect(
self._handle_instance_double_click
Expand Down Expand Up @@ -1185,14 +1186,16 @@ def goto_suggestion(*args):

def _load_overlays(self):
"""Load all standard video overlays."""
self.overlays["track_labels"] = TrackListOverlay(self.labels, self.player)
self.overlays["track_labels"] = TrackListOverlay(
labels=self.labels, player=self.player
)
self.overlays["trails"] = TrackTrailOverlay(
labels=self.labels,
player=self.player,
trail_shade=self.state["trail_shade"],
)
self.overlays["instance"] = InstanceOverlay(
self.labels, self.player, self.state
labels=self.labels, player=self.player, state=self.state
)

# When gui state changes, we also want to set corresponding attribute
Expand Down Expand Up @@ -1380,7 +1383,12 @@ def plotFrame(self, *args, **kwargs):

self.player.plot()

def _after_plot_update(self, player, frame_idx, selected_inst):
def _after_plot_update(self, frame_idx):
"""Run after plot is updated, but stay on same frame."""
overlay: TrackTrailOverlay = self.overlays["trails"]
overlay.redraw(self.state["video"], frame_idx)

def _after_plot_change(self, player, frame_idx, selected_inst):
"""Called each time a new frame is drawn."""

# Store the current LabeledFrame (or make new, empty object)
Expand Down
71 changes: 50 additions & 21 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ class which inherits from `AppCommand` (or a more specialized class such as

from qtpy import QtCore, QtWidgets, QtGui

from qtpy.QtWidgets import QMessageBox, QProgressDialog

from sleap.skeleton import Node, Skeleton
from sleap.instance import Instance, PredictedInstance, Point, Track, LabeledFrame
from sleap.io.video import Video
Expand Down Expand Up @@ -672,7 +670,7 @@ def ask(context: "CommandContext", params: dict):
has_loaded = True
except ValueError as e:
print(e)
QMessageBox(text=f"Unable to load {filename}.").exec_()
QtWidgets.QMessageBox(text=f"Unable to load {filename}.").exec_()

params["labels"] = labels

Expand Down Expand Up @@ -1100,12 +1098,23 @@ def ask(context: CommandContext, params: dict) -> bool:


class ExportAnalysisFile(AppCommand):
export_formats = {
"SLEAP Analysis HDF5 (*.h5)": "h5",
"NIX for Tracking data (*.nix)": "nix",
}
export_filter = ";;".join(export_formats.keys())

@classmethod
def do_action(cls, context: CommandContext, params: dict):
from sleap.io.format.sleap_analysis import SleapAnalysisAdaptor
from sleap.io.format.nix import NixAdaptor

for output_path, video in params["analysis_videos"]:
SleapAnalysisAdaptor.write(
if Path(output_path).suffix[1:] == "nix":
adaptor = NixAdaptor
else:
adaptor = SleapAnalysisAdaptor
adaptor.write(
filename=output_path,
source_object=context.labels,
source_path=context.state["filename"],
Expand All @@ -1120,14 +1129,14 @@ def ask_for_filename(default_name: str) -> str:
context.app,
caption="Export Analysis File...",
dir=default_name,
filter="SLEAP Analysis HDF5 (*.h5)",
filter=ExportAnalysisFile.export_filter,
)
return filename

# Ensure labels has labeled frames
labels = context.labels
if len(labels.labeled_frames) == 0:
return False
raise ValueError("No labeled frames in project. Nothing to export.")

# Get a subset of videos
if params["all_videos"]:
Expand All @@ -1138,11 +1147,12 @@ def ask_for_filename(default_name: str) -> str:
# Only use videos with labeled frames
videos = [video for video in all_videos if len(labels.get(video)) != 0]
if len(videos) == 0:
return False
raise ValueError("No labeled frames in video(s). Nothing to export.")

# Specify (how to get) the output filename
default_name = context.state["filename"] or "labels"
fn = PurePath(default_name)
file_extension = "h5"
if len(videos) == 1:
# Allow user to specify the filename
use_default = False
Expand All @@ -1155,6 +1165,18 @@ def ask_for_filename(default_name: str) -> str:
caption="Select Folder to Export Analysis Files...",
dir=str(fn.parent),
)
if len(ExportAnalysisFile.export_formats) > 1:
item, ok = QtWidgets.QInputDialog.getItem(
context.app,
"Select export format",
"Available export formats",
list(ExportAnalysisFile.export_formats.keys()),
0,
False,
)
if not ok:
return False
file_extension = ExportAnalysisFile.export_formats[item]
if len(dirname) == 0:
return False

Expand All @@ -1168,9 +1190,10 @@ def ask_for_filename(default_name: str) -> str:
video=video,
output_path=dirname,
output_prefix=str(fn.stem),
format_suffix=file_extension,
)
filename = default_name if use_default else ask_for_filename(default_name)

filename = default_name if use_default else ask_for_filename(default_name)
# Check that filename is valid and create list of video / output paths
if len(filename) != 0:
analysis_videos.append(video)
Expand Down Expand Up @@ -1327,7 +1350,9 @@ def export_dataset_gui(
instances.
suggested: If `True`, include image data for suggested frames.
"""
win = QProgressDialog("Exporting dataset with frame images...", "Cancel", 0, 1)
win = QtWidgets.QProgressDialog(
"Exporting dataset with frame images...", "Cancel", 0, 1
)

def update_progress(n, n_total):
if win.wasCanceled():
Expand Down Expand Up @@ -1729,7 +1754,7 @@ def _get_truncation_message(truncation_messages, path, video):

# Warn user: newly added labels will be discarded if project is not saved
if not context.state["filename"] or context.state["has_changes"]:
QMessageBox(
QtWidgets.QMessageBox(
text=("You have unsaved changes. Please save before replacing videos.")
).exec_()
return False
Expand Down Expand Up @@ -1799,15 +1824,15 @@ def ask(context: CommandContext, params: dict) -> bool:

# Warn if there are labels that will be deleted
if n > 0:
response = QMessageBox.critical(
response = QtWidgets.QMessageBox.critical(
context.app,
"Removing video with labels",
f"{n} labeled frames in this video will be deleted, "
"are you sure you want to remove this video?",
QMessageBox.Yes,
QMessageBox.No,
QtWidgets.QMessageBox.Yes,
QtWidgets.QMessageBox.No,
)
if response == QMessageBox.No:
if response == QtWidgets.QMessageBox.No:
return False

params["video"] = video
Expand Down Expand Up @@ -2104,11 +2129,15 @@ def _confirm_deletion(context: CommandContext, lf_inst_list: List) -> bool:
)

# Confirm that we want to delete
resp = QMessageBox.critical(
context.app, title, message, QMessageBox.Yes, QMessageBox.No
resp = QtWidgets.QMessageBox.critical(
context.app,
title,
message,
QtWidgets.QMessageBox.Yes,
QtWidgets.QMessageBox.No,
)

if resp == QMessageBox.No:
if resp == QtWidgets.QMessageBox.No:
return False

return True
Expand Down Expand Up @@ -2565,14 +2594,14 @@ def ask(context: CommandContext, params: dict) -> bool:

# Warn that suggestions will be cleared

response = QMessageBox.warning(
response = QtWidgets.QMessageBox.warning(
context.app,
"Clearing all suggestions",
"Are you sure you want to remove all suggestions from the project?",
QMessageBox.Yes,
QMessageBox.No,
QtWidgets.QMessageBox.Yes,
QtWidgets.QMessageBox.No,
)
if response == QMessageBox.No:
if response == QtWidgets.QMessageBox.No:
return False

return True
Expand Down
10 changes: 6 additions & 4 deletions sleap/gui/dataviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,14 @@ def setData(self, index: QtCore.QModelIndex, value: str, role=QtCore.Qt.EditRole
item, key = self.get_from_idx(index)

# If nothing changed of the item, return true. (Issue #1013)
if isinstance(item, dict) and key in item:
item_value = item[key]
if hasattr(item, key):
if isinstance(item, dict):
item_value = item.get(key, None)
elif hasattr(item, key):
item_value = getattr(item, key)
else:
item_value = None

if item_value == value:
if (item_value is not None) and (item_value == value):
return True

# Otherwise set the item
Expand Down
55 changes: 41 additions & 14 deletions sleap/gui/learning/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,43 @@ def merge_pipeline_and_head_config_data(self, head_name, head_data, pipeline_dat
continue
head_data[key] = val

@staticmethod
def update_loaded_config(
loaded_cfg: configs.TrainingJobConfig, tab_cfg_key_val_dict: dict
) -> scopedkeydict.ScopedKeyDict:
"""Update a loaded preset config with values from the training editor.

Args:
loaded_cfg: A `TrainingJobConfig` that was loaded from a preset or previous
training run.
tab_cfg_key_val_dict: A dictionary with the values extracted from the training
editor GUI tab.

Returns:
A `ScopedKeyDict` with the loaded config values overriden by the corresponding
ones from the `tab_cfg_key_val_dict`.
"""
# Serialize training config
loaded_cfg_hierarchical: dict = cattr.unstructure(loaded_cfg)

# Clear backbone subfields since these will be set by the GUI
if (
"model" in loaded_cfg_hierarchical
and "backbone" in loaded_cfg_hierarchical["model"]
):
for k in loaded_cfg_hierarchical["model"]["backbone"]:
loaded_cfg_hierarchical["model"]["backbone"][k] = None

loaded_cfg_scoped: scopedkeydict.ScopedKeyDict = (
scopedkeydict.ScopedKeyDict.from_hierarchical_dict(loaded_cfg_hierarchical)
)

# Replace params exposed in GUI with values from GUI
for param, value in tab_cfg_key_val_dict.items():
loaded_cfg_scoped.key_val_dict[param] = value

return loaded_cfg_scoped

def get_every_head_config_data(
self, pipeline_form_data
) -> List[configs.ConfigFileInfo]:
Expand All @@ -436,24 +473,14 @@ def get_every_head_config_data(
scopedkeydict.apply_cfg_transforms_to_key_val_dict(tab_cfg_key_val_dict)

if trained_cfg_info is None:
# Config could not be loaded
# Config could not be loaded, just use the values from the GUI
loaded_cfg_scoped: dict = tab_cfg_key_val_dict
else:
# Config loaded
loaded_cfg: configs.TrainingJobConfig = trained_cfg_info.config

# Serialize and flatten training config
loaded_cfg_heirarchical: dict = cattr.unstructure(loaded_cfg)
loaded_cfg_scoped: scopedkeydict.ScopedKeyDict = (
scopedkeydict.ScopedKeyDict.from_hierarchical_dict(
loaded_cfg_heirarchical
)
# Config was loaded, override with the values from the GUI
loaded_cfg_scoped = LearningDialog.update_loaded_config(
trained_cfg_info.config, tab_cfg_key_val_dict
)

# Replace params exposed in GUI with values from GUI
for param, value in tab_cfg_key_val_dict.items():
loaded_cfg_scoped.key_val_dict[param] = value

# Deserialize merged dict to object
cfg = scopedkeydict.make_training_config_from_key_val_dict(
loaded_cfg_scoped
Expand Down
2 changes: 1 addition & 1 deletion sleap/gui/learning/scopedkeydict.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def set_hierarchical_key_val(cls, current_dict: dict, key: Text, val: Any):
current_dict[key] = val
else:
top_key, *subkey_list = key.split(".")
if top_key not in current_dict:
if top_key not in current_dict or current_dict[top_key] is None:
current_dict[top_key] = dict()
subkey = ".".join(subkey_list)
cls.set_hierarchical_key_val(current_dict[top_key], subkey, val)
Expand Down
Loading