From 7ef7330a3747f218d7f8872396473011a6152097 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Thu, 19 Jan 2023 11:11:03 -0800 Subject: [PATCH 01/10] Create signal that updates plot instead of removing and replotting items (#1134) * Create signal that updates plot instead of redrawing * Remove debug code * Non-functional self-review changes --- sleap/gui/app.py | 16 +++++++++--- sleap/gui/overlays/base.py | 48 ++++++++++++++++++++++++++++++------ sleap/gui/overlays/tracks.py | 4 ++- sleap/gui/widgets/video.py | 8 +++++- 4 files changed, 63 insertions(+), 13 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index de79a7d85..4a3f6da95 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -311,7 +311,8 @@ def _create_video_player(self): self.player = QtVideoPlayer( color_manager=self.color_manager, state=self.state, context=self.commands ) - self.player.changedPlot.connect(self._after_plot_update) + self.player.changedPlot.connect(self._after_plot_change) + self.player.updatedPlot.connect(self._after_plot_update) self.player.view.instanceDoubleClicked.connect( self._handle_instance_double_click @@ -1185,14 +1186,16 @@ def goto_suggestion(*args): def _load_overlays(self): """Load all standard video overlays.""" - self.overlays["track_labels"] = TrackListOverlay(self.labels, self.player) + self.overlays["track_labels"] = TrackListOverlay( + labels=self.labels, player=self.player + ) self.overlays["trails"] = TrackTrailOverlay( labels=self.labels, player=self.player, trail_shade=self.state["trail_shade"], ) self.overlays["instance"] = InstanceOverlay( - self.labels, self.player, self.state + labels=self.labels, player=self.player, state=self.state ) # When gui state changes, we also want to set corresponding attribute @@ -1380,7 +1383,12 @@ def plotFrame(self, *args, **kwargs): self.player.plot() - def _after_plot_update(self, player, frame_idx, selected_inst): + def _after_plot_update(self, frame_idx): + """Run after plot is updated, but stay on same frame.""" + overlay: TrackTrailOverlay = self.overlays["trails"] + overlay.redraw(self.state["video"], frame_idx) + + def _after_plot_change(self, player, frame_idx, selected_inst): """Called each time a new frame is drawn.""" # Store the current LabeledFrame (or make new, empty object) diff --git a/sleap/gui/overlays/base.py b/sleap/gui/overlays/base.py index 60b04de1d..f648c5a43 100644 --- a/sleap/gui/overlays/base.py +++ b/sleap/gui/overlays/base.py @@ -13,7 +13,9 @@ import attr import abc import numpy as np -from typing import Sequence, Union +from typing import Sequence, Union, Optional, List + +from qtpy.QtWidgets import QGraphicsItem from sleap import Labels, Video from sleap.gui.widgets.video import QtVideoPlayer @@ -23,20 +25,52 @@ @attr.s(auto_attribs=True) class BaseOverlay(abc.ABC): - """ - Abstract base class for overlays. + """Abstract base class for overlays. Most overlays need rely on the `Labels` from which to get data and need the - `QtVideoPlayer` to which a `QGraphicsObject` item will be added, so these + `QtVideoPlayer` to which a `QGraphicsItem` will be added, so these attributes are included in the base class. + + Args: + labels: the `Labels` from which to get data + player: the `QtVideoPlayer` to which a `QGraphicsObject` item will be added + items: stores all `QGraphicsItem`s currently added to the player from this + overlay """ - labels: Labels = None - player: QtVideoPlayer = None + labels: Optional[Labels] = None + player: Optional[QtVideoPlayer] = None + items: Optional[List[QGraphicsItem]] = None @abc.abstractmethod def add_to_scene(self, video: Video, frame_idx: int): - pass + """Add items to scene. + + To use the `remove_from_scene` and `redraw` methods, keep track of a list of + `QGraphicsItem`s added in this function. + """ + # Start your method with: + self.items = [] + + # As items are added to the `QtVideoPlayer`, keep track of these items: + item = self.player.scene.addItem(...) + self.items.append(item) + + def remove_from_scene(self): + """Remove all items added to scene by this overlay. + + This method does not need to be called when changing the plot to a new frame. + """ + for item in self.items: + self.player.scene.removeItem(item) + + def redraw(self, video, frame_idx, *args, **kwargs): + """Remove all items from the scene before adding new items to the scene. + + This method does not need to be called when changing the plot to a new frame. + """ + self.remove_from_scene(*args, **kwargs) + self.add_to_scene(video, frame_idx, *args, **kwargs) @attr.s(auto_attribs=True) diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index 9382bab55..361585719 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -144,6 +144,7 @@ def add_to_scene(self, video: Video, frame_idx: int): frame_idx: index of the frame to which the trail is attached """ + self.items = [] if not self.show or self.trail_length == 0: return @@ -188,7 +189,8 @@ def add_to_scene(self, video: Video, frame_idx: int): for segment in segments: pen.setWidthF(width) path = self.map_to_qt_path(segment) - self.player.scene.addPath(path, pen) + item = self.player.scene.addPath(path, pen) + self.items.append(item) width /= 2 @staticmethod diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 53d2bb35f..865cae97e 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -193,6 +193,7 @@ class QtVideoPlayer(QWidget): Signals: * changedPlot: Emitted whenever the plot is redrawn + * updatedPlot: Emitted whenever a node is moved (updates trails overlays) Attributes: video: The :class:`Video` to display @@ -202,6 +203,7 @@ class QtVideoPlayer(QWidget): """ changedPlot = QtCore.Signal(QWidget, int, Instance) + updatedPlot = QtCore.Signal(int) def __init__( self, @@ -489,6 +491,10 @@ def plot(self, *args): self._video_image_loader.video = self.video self._video_image_loader.request(idx) + def update_plot(self): + idx = self.state["frame_idx"] or 0 + self.updatedPlot.emit(idx) + def showInstances(self, show): """Show/hide all instances in viewer. @@ -1568,7 +1574,7 @@ def mouseReleaseEvent(self, event): super(QtNode, self).mouseReleaseEvent(event) self.updatePoint(user_change=True) self.dragParent = False - self.player.plot() # Redraw trails after node is moved + self.player.update_plot() # Redraw trails after node is moved def wheelEvent(self, event): """Custom event handler for mouse scroll wheel.""" From 67c6da69608ad60c97b42b44b1c5a48019560970 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Thu, 19 Jan 2023 17:44:59 -0800 Subject: [PATCH 02/10] Fix symmetric skeletons (via table input) (#1136) Ensure variable initialized before calling it --- sleap/gui/dataviews.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index 47fdf3884..71189657a 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -150,12 +150,14 @@ def setData(self, index: QtCore.QModelIndex, value: str, role=QtCore.Qt.EditRole item, key = self.get_from_idx(index) # If nothing changed of the item, return true. (Issue #1013) - if isinstance(item, dict) and key in item: - item_value = item[key] - if hasattr(item, key): + if isinstance(item, dict): + item_value = item.get(key, None) + elif hasattr(item, key): item_value = getattr(item, key) + else: + item_value = None - if item_value == value: + if (item_value is not None) and (item_value == value): return True # Otherwise set the item From b660e004e7878e4ddd1edb08a2cd4c13cf591b89 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Fri, 20 Jan 2023 18:48:37 +0100 Subject: [PATCH 03/10] Nix export of tracking results (#1068) * [io] export tracking results to NIX file * [io] nix added to export filter only if available * [nixio] refactor, add scores link data as mtag * [nixio] speeding up export by chunked writing * [nixio] rename point score to node score * [nixio] fix missing dimension descriptor for node scores * [export analysis] support multiple formats also for bulk export * [nixio] export centroid, some documentation * [nixio] fix double dot before filename suffix * [nixio] fix bug when not all nodes were found * [nixio] housekeeping * [nix] add nix analyis output format to convert * [nix] tiny fix, catch file write error and properly close file * [inference] main takes optional args. Can be imported to run inference form scripts * [convert] simplify if else structure and outfile handling for analysis export * [nix] use pathlib instead of os * [nix] catch if there are instances with a None frame_idx ... not sure why this occurred. The nix adaptor cannot save instances that are not related to a frame. * [nix] move checks to top of write function * [nix] use absolute imports * [nix] use black to reformat * [commands] revert qtpy import and apply code style * [convert] use absolute imports, apply code style * [commands]fix imports * [inference/nix]fix linter complaint, adjust nix types for scores * [nix] add test case for nix export format * [nix] extended testing, some modifications of adaptor * [skeleton] add __eq__ to Skeleton ... make Node.name and Node.weight instance variables instead of class variables * [nix] add nixio to requirements, remove unused nix_available, ... allow for non-unique entries in node, track and skeleton. Extend node map to store the skeleton it is part of * [nix] make the linter happy * [Node] force definition of a name Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com> * [nix] use getattr for getting grayscale information Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com> * [nix] fixes according to review * [convert] break out of loop upon finding the video Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com> * [commands.py] use pathilb instead of splitting filename Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com> * [dev requirements] remove linebreak at last line * [skeleton] revert attribute creation back to original * [nix] break lines in class documentation * Ensure all file references are closed * Make the linter happy * Add tests for ExportAnalysis and (docs for) sleap-convert Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com> --- docs/guides/cli.md | 5 +- requirements.txt | 1 + sleap/gui/commands.py | 69 +++-- sleap/info/write_tracking_h5.py | 8 +- sleap/io/convert.py | 86 +++--- sleap/io/format/nix.py | 463 ++++++++++++++++++++++++++++++++ sleap/nn/inference.py | 2 +- tests/gui/test_commands.py | 37 +-- tests/io/test_convert.py | 24 +- tests/io/test_formats.py | 100 ++++++- 10 files changed, 703 insertions(+), 92 deletions(-) create mode 100644 sleap/io/format/nix.py diff --git a/docs/guides/cli.md b/docs/guides/cli.md index acdba57ef..827678ffb 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -307,8 +307,9 @@ optional arguments: only a single --output argument was specified, the analysis file for the latter video is given a default name. --format FORMAT Output format. Default ('slp') is SLEAP dataset; - 'analysis' results in analysis.h5 file; 'h5' or 'json' - results in SLEAP dataset with specified file format. + 'analysis' results in analysis.h5 file; 'analysis.nix' results + in an analysis nix file; 'h5' or 'json' results in SLEAP dataset + with specified file format. --video VIDEO Path to video (if needed for conversion). ``` diff --git a/requirements.txt b/requirements.txt index a9cecb183..becc335f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,3 +34,4 @@ rich==10.16.1 certifi>=2017.4.17,<=2021.10.8 pynwb ndx-pose +nixio>=1.5.3 diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 0ef7d219f..681a7faeb 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -43,8 +43,6 @@ class which inherits from `AppCommand` (or a more specialized class such as from qtpy import QtCore, QtWidgets, QtGui -from qtpy.QtWidgets import QMessageBox, QProgressDialog - from sleap.skeleton import Node, Skeleton from sleap.instance import Instance, PredictedInstance, Point, Track, LabeledFrame from sleap.io.video import Video @@ -1100,12 +1098,23 @@ def ask(context: CommandContext, params: dict) -> bool: class ExportAnalysisFile(AppCommand): + export_formats = { + "SLEAP Analysis HDF5 (*.h5)": "h5", + "NIX for Tracking data (*.nix)": "nix", + } + export_filter = ";;".join(export_formats.keys()) + @classmethod def do_action(cls, context: CommandContext, params: dict): from sleap.io.format.sleap_analysis import SleapAnalysisAdaptor + from sleap.io.format.nix import NixAdaptor for output_path, video in params["analysis_videos"]: - SleapAnalysisAdaptor.write( + if Path(output_path).suffix[1:] == "nix": + adaptor = NixAdaptor + else: + adaptor = SleapAnalysisAdaptor + adaptor.write( filename=output_path, source_object=context.labels, source_path=context.state["filename"], @@ -1120,14 +1129,14 @@ def ask_for_filename(default_name: str) -> str: context.app, caption="Export Analysis File...", dir=default_name, - filter="SLEAP Analysis HDF5 (*.h5)", + filter=ExportAnalysisFile.export_filter, ) return filename # Ensure labels has labeled frames labels = context.labels if len(labels.labeled_frames) == 0: - return False + raise ValueError("No labeled frames in project. Nothing to export.") # Get a subset of videos if params["all_videos"]: @@ -1138,11 +1147,12 @@ def ask_for_filename(default_name: str) -> str: # Only use videos with labeled frames videos = [video for video in all_videos if len(labels.get(video)) != 0] if len(videos) == 0: - return False + raise ValueError("No labeled frames in video(s). Nothing to export.") # Specify (how to get) the output filename default_name = context.state["filename"] or "labels" fn = PurePath(default_name) + file_extension = "h5" if len(videos) == 1: # Allow user to specify the filename use_default = False @@ -1155,6 +1165,18 @@ def ask_for_filename(default_name: str) -> str: caption="Select Folder to Export Analysis Files...", dir=str(fn.parent), ) + if len(ExportAnalysisFile.export_formats) > 1: + item, ok = QtWidgets.QInputDialog.getItem( + context.app, + "Select export format", + "Available export formats", + list(ExportAnalysisFile.export_formats.keys()), + 0, + False, + ) + if not ok: + return False + file_extension = ExportAnalysisFile.export_formats[item] if len(dirname) == 0: return False @@ -1168,9 +1190,10 @@ def ask_for_filename(default_name: str) -> str: video=video, output_path=dirname, output_prefix=str(fn.stem), + format_suffix=file_extension, ) - filename = default_name if use_default else ask_for_filename(default_name) + filename = default_name if use_default else ask_for_filename(default_name) # Check that filename is valid and create list of video / output paths if len(filename) != 0: analysis_videos.append(video) @@ -1327,7 +1350,9 @@ def export_dataset_gui( instances. suggested: If `True`, include image data for suggested frames. """ - win = QProgressDialog("Exporting dataset with frame images...", "Cancel", 0, 1) + win = QtWidgets.QProgressDialog( + "Exporting dataset with frame images...", "Cancel", 0, 1 + ) def update_progress(n, n_total): if win.wasCanceled(): @@ -1729,7 +1754,7 @@ def _get_truncation_message(truncation_messages, path, video): # Warn user: newly added labels will be discarded if project is not saved if not context.state["filename"] or context.state["has_changes"]: - QMessageBox( + QtWidgets.QMessageBox( text=("You have unsaved changes. Please save before replacing videos.") ).exec_() return False @@ -1799,15 +1824,15 @@ def ask(context: CommandContext, params: dict) -> bool: # Warn if there are labels that will be deleted if n > 0: - response = QMessageBox.critical( + response = QtWidgets.QMessageBox.critical( context.app, "Removing video with labels", f"{n} labeled frames in this video will be deleted, " "are you sure you want to remove this video?", - QMessageBox.Yes, - QMessageBox.No, + QtWidgets.QMessageBox.Yes, + QtWidgets.QMessageBox.No, ) - if response == QMessageBox.No: + if response == QtWidgets.QMessageBox.No: return False params["video"] = video @@ -2104,11 +2129,15 @@ def _confirm_deletion(context: CommandContext, lf_inst_list: List) -> bool: ) # Confirm that we want to delete - resp = QMessageBox.critical( - context.app, title, message, QMessageBox.Yes, QMessageBox.No + resp = QtWidgets.QMessageBox.critical( + context.app, + title, + message, + QtWidgets.QMessageBox.Yes, + QtWidgets.QMessageBox.No, ) - if resp == QMessageBox.No: + if resp == QtWidgets.QMessageBox.No: return False return True @@ -2565,14 +2594,14 @@ def ask(context: CommandContext, params: dict) -> bool: # Warn that suggestions will be cleared - response = QMessageBox.warning( + response = QtWidgets.QMessageBox.warning( context.app, "Clearing all suggestions", "Are you sure you want to remove all suggestions from the project?", - QMessageBox.Yes, - QMessageBox.No, + QtWidgets.QMessageBox.Yes, + QtWidgets.QMessageBox.No, ) - if response == QMessageBox.No: + if response == QtWidgets.QMessageBox.No: return False return True diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 255c8c61d..8bd583230 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -258,9 +258,11 @@ def write_occupancy_file( """ with h5.File(output_path, "w") as f: + print(f"\nExporting to SLEAP Analysis file...") for key, val in data_dict.items(): + print(f"\t{key}: ", end="") if isinstance(val, np.ndarray): - print(f"{key}: {val.shape}") + print(f"{val.shape}") if transpose: # Transpose since MATLAB expects column-major @@ -276,9 +278,9 @@ def write_occupancy_file( ) else: if isinstance(val, (str, int, type(None))): - print(f"{key}: {val}") + print(f"{val}") else: - print(f"{key}: {len(val)}") + print(f"{len(val)}") f.create_dataset(key, data=val) print(f"Saved as {output_path}") diff --git a/sleap/io/convert.py b/sleap/io/convert.py index 6e20a05e4..3353a169b 100644 --- a/sleap/io/convert.py +++ b/sleap/io/convert.py @@ -69,6 +69,7 @@ def create_parser(): default="slp", help="Output format. Default ('slp') is SLEAP dataset; " "'analysis' results in analysis.h5 file; " + "'analysis.nix' results in an analysis nix file;" "'h5' or 'json' results in SLEAP dataset " "with specified file format.", ) @@ -79,14 +80,18 @@ def create_parser(): def default_analysis_filename( - labels: Labels, video: Video, output_path: str, output_prefix: PurePath + labels: Labels, + video: Video, + output_path: str, + output_prefix: PurePath, + format_suffix: str = "h5", ) -> str: video_idx = labels.videos.index(video) vn = PurePath(video.backend.filename) filename = str( PurePath( output_path, - f"{output_prefix}.{video_idx:03}_{vn.stem}.analysis.h5", + f"{output_prefix}.{video_idx:03}_{vn.stem}.analysis.{format_suffix}", ) ) return filename @@ -117,36 +122,53 @@ def main(args: list = None): video_search=video_callback, video=video_path, ) - - if args.format == "analysis": - from sleap.info.write_tracking_h5 import main as write_analysis - - output_paths = [path for path in args.outputs] - - # Generate filenames if user has not specified (enough) output filenames - labels_path = args.input_path - fn = re.sub("(\\.json(\\.zip)?|\\.h5|\\.slp)$", "", labels_path) - fn = PurePath(fn) - default_names = [ - default_analysis_filename( - labels=labels, - video=video, - output_path=str(fn.parent), - output_prefix=str(fn.stem), - ) - for video in labels.videos[len(args.outputs) :] - ] - - output_paths.extend(default_names) - - for video, output_path in zip(labels.videos, output_paths): - write_analysis( - labels, - output_path=output_path, - labels_path=labels_path, - all_frames=True, - video=video, - ) + if "analysis" in args.format: + vids = [] + if len(args.video) > 0: # if a video is specified + for v in labels.videos: # check if it is among the videos in the project + if args.video in v.backend.filename: + vids.append(v) + break + else: + vids = labels.videos # otherwise all videos are converted + + outnames = [path for path in args.outputs] + if len(outnames) < len(vids): + # if there are less outnames provided than videos to convert... + out_suffix = "nix" if "nix" in args.format else "h5" + fn = args.input_path + fn = re.sub("(\.json(\.zip)?|\.h5|\.slp)$", "", fn) + fn = PurePath(fn) + + for video in vids[len(outnames) :]: + dflt_name = default_analysis_filename( + labels=labels, + video=video, + output_path=str(fn.parent), + output_prefix=str(fn.stem), + format_suffix=out_suffix, + ) + outnames.append(dflt_name) + + if "nix" in args.format: + from sleap.io.format.nix import NixAdaptor + + for video, outname in zip(vids, outnames): + try: + NixAdaptor.write(outname, labels, args.input_path, video) + except ValueError as e: + print(e.args[0]) + else: + from sleap.info.write_tracking_h5 import main as write_analysis + + for video, output_path in zip(vids, outnames): + write_analysis( + labels, + output_path=output_path, + labels_path=args.input_path, + all_frames=True, + video=video, + ) elif len(args.outputs) > 0: print(f"Output SLEAP dataset: {args.outputs[0]}") diff --git a/sleap/io/format/nix.py b/sleap/io/format/nix.py new file mode 100644 index 000000000..4c39ec8b6 --- /dev/null +++ b/sleap/io/format/nix.py @@ -0,0 +1,463 @@ +import numpy as np +import nixio as nix + +from pathlib import Path +from typing import Dict, List, Optional, cast +from sleap.instance import Track + +from sleap.io.format.adaptor import Adaptor, SleapObjectType +from sleap.io.format.filehandle import FileHandle +from sleap.io.dataset import Labels +from sleap.io.video import Video +from sleap.skeleton import Node, Skeleton + + +class NixAdaptor(Adaptor): + """Adaptor class for export of tracking analysis results to the generic + [NIX](https://github.com/g-node/nix) format. + NIX defines a generic data model for scientific data that combines data and data + annotations within the same container. The written files are hdf5 files that can + be read with any hdf5 library but follow the entity definitions of the NIX data + model. For reading nix-files with python install the nixio low-level library + ```pip install nixio``` or use the high-level api + [nixtrack](https://github.com/bendalab/nixtrack). + + So far the adaptor exports the tracked positions for each node of each instance, + the track and skeleton information along with the respective scores and the + centroid. Additionally, the video information is exported as metadata. + For more information on the mapping from sleap to nix see the docs on + [nixtrack](https://github.com/bendalab/nixtrack) (work in progress). + The adaptor uses a chunked writing approach which avoids numpy out of memory + exceptions when exporting large datasets. + + author: Jan Grewe (jan.grewe@g-node.org) + """ + + @property + def default_ext(self): + return "nix" + + @property + def all_exts(self) -> List[str]: + return [self.default_ext] + + @property + def handles(self): + return SleapObjectType.misc + + @property + def name(self) -> str: + """Human-reading name of the file format""" + return ( + "NIX file flavoured for animal tracking data https://github.com/g-node/nix" + ) + + @classmethod + def can_read_file(cls, file: FileHandle) -> bool: + """Returns whether this adaptor can read this file.""" + return False + + def can_write_filename(self, filename: str) -> bool: + """Returns whether this adaptor can write format of this filename.""" + return filename.endswith(tuple(self.all_exts)) + + @classmethod + def does_read(cls) -> bool: + """Returns whether this adaptor supports reading.""" + return False + + @classmethod + def does_write(cls) -> bool: + """Returns whether this adaptor supports writing.""" + return True + + @classmethod + def read(cls, file: FileHandle) -> object: + """Reads the file and returns the appropriate deserialized object.""" + raise NotImplementedError("NixAdaptor does not support reading.") + + @classmethod + def __check_video(cls, labels: Labels, video: Optional[Video]): + if (video is None) and (len(labels.videos) == 0): + raise ValueError( + f"There are no videos in this project. " + "No analysis file will be be written." + ) + if video is not None: + if video not in labels.videos: + raise ValueError( + f"Specified video {video} is not part of this project. " + "Skipping the analysis file for this video." + ) + if len(labels.get(video)) == 0: + raise ValueError( + f"No labeled frames in {video.backend.filename}. " + "Skipping the analysis file for this video." + ) + + @classmethod + def write( + cls, + filename: str, + source_object: object, + source_path: Optional[str] = None, + video: Optional[Video] = None, + ): + """Writes the object to a file.""" + source_object = cast(Labels, source_object) + + cls.__check_video(source_object, video) + + def create_file(filename: str, project: Optional[str], video: Video): + print(f"Creating nix file...", end="\t") + nf = nix.File.open(filename, nix.FileMode.Overwrite) + try: + s = nf.create_section("TrackingAnalysis", "nix.tracking.metadata") + s["version"] = "0.1.0" + s["format"] = "nix.tracking" + s["definitions"] = "https://github.com/bendalab/nixtrack" + s["writer"] = str(cls)[8:-2] + if project is not None: + s["project"] = project + + name = Path(video.backend.filename).name + b = nf.create_block(name, "nix.tracking_results") + + # add video metadata, if exists + src = b.create_source(name, "nix.tracking.source.video") + sec = src.file.create_section( + name, "nix.tracking.source.video.metadata" + ) + sec["filename"] = video.backend.filename + sec["fps"] = getattr(video.backend, "fps", 0.0) + sec.props["fps"].unit = "Hz" + sec["frames"] = video.num_frames + sec["grayscale"] = getattr(video.backend, "grayscale", None) + sec["height"] = video.backend.height + sec["width"] = video.backend.width + src.metadata = sec + except Exception as e: + nf.close() + raise e + + print("done") + return nf + + def track_map(source: Labels) -> Dict[Track, int]: + track_map: Dict[Track, int] = {} + for track in source.tracks: + if track in track_map: + continue + track_map[track] = len(track_map) + return track_map + + def skeleton_map(source: Labels) -> Dict[Skeleton, int]: + skel_map: Dict[Skeleton, int] = {} + for skeleton in source.skeletons: + if skeleton in skel_map: + continue + skel_map[skeleton] = len(skel_map) + return skel_map + + def node_map(source: Labels) -> Dict[Node, int]: + n_map: Dict[Node, int] = {} + for node in source.nodes: + if node in n_map: + continue + n_map[node] = len(n_map) + return n_map + + def create_feature_array(name, type, block, frame_index_array, shape, dtype): + array = block.create_data_array(name, type, dtype=dtype, shape=shape) + rd = array.append_range_dimension() + rd.link_data_array(frame_index_array, [-1]) + return array + + def create_positions_array( + name, type, block, frame_index_array, node_names, shape, dtype + ): + array = block.create_data_array( + name, type, dtype=dtype, shape=shape, label="pixel" + ) + rd = array.append_range_dimension() + rd.link_data_array(frame_index_array, [-1]) + array.append_set_dimension(["x", "y"]) + array.append_set_dimension(node_names) + return array + + def chunked_write( + instances, + frameid_array, + positions_array, + track_array, + skeleton_array, + pointscore_array, + instancescore_array, + trackingscore_array, + centroid_array, + track_map, + node_map, + skeleton_map, + chunksize=10000, + ): + data_written = 0 + indices = np.zeros(chunksize, dtype=int) + track = np.zeros_like(indices) + skeleton = np.zeros_like(indices) + instscore = np.zeros_like(indices, dtype=float) + positions = np.zeros((chunksize, 2, len(node_map.keys())), dtype=float) + centroids = np.zeros((chunksize, 2), dtype=float) + trackscore = np.zeros_like(instscore) + pointscore = np.zeros((chunksize, len(node_map.keys())), dtype=float) + dflt_pointscore = [0.0 for n in range(len(node_map.keys()))] + + while data_written < len(instances): + print(".", end="") + start = data_written + end = ( + len(instances) + if start + chunksize >= len(instances) + else start + chunksize + ) + for i in range(start, end): + inst = instances[i] + index = i - start + indices[index] = inst.frame_idx + if inst.track is not None: + track[index] = track_map[inst.track] + else: + track[index] = -1 + + skeleton[index] = skeleton_map[inst.skeleton] + + all_nodes = set([n.name for n in inst.nodes]) + used_nodes = set([n.name for n in node_map.keys()]) + missing_nodes = all_nodes.difference(used_nodes) + for n, p in zip(inst.nodes, inst.points): + positions[index, :, node_map[n]] = np.array([p.x, p.y]) + for m in missing_nodes: + positions[index, :, node_map[m]] = np.array([np.nan, np.nan]) + + centroids[index, :] = inst.centroid + if hasattr(inst, "score"): + instscore[index] = inst.score + trackscore[index] = inst.tracking_score + pointscore[index, :] = inst.scores + else: + instscore[index] = 0.0 + trackscore[index] = 0.0 + pointscore[index, :] = dflt_pointscore + + frameid_array[start:end] = indices[: end - start] + track_array[start:end] = track[: end - start] + positions_array[start:end, :, :] = positions[: end - start, :, :] + centroid_array[start:end, :] = centroids[: end - start, :] + skeleton_array[start:end] = skeleton[: end - start] + pointscore_array[start:end] = pointscore[: end - start] + instancescore_array[start:end] = instscore[: end - start] + trackingscore_array[start:end] = trackscore[: end - start] + data_written += end - start + + def write_data(block, source: Labels, video: Video): + instances = [ + instance + for instance in source.instances(video=video) + if instance.frame_idx is not None + ] + instances = sorted(instances, key=lambda i: i.frame_idx) + nodes = node_map(source) + tracks = track_map(source) + skeletons = skeleton_map(source) + positions_shape = (len(instances), 2, len(nodes)) + + frameid_array = block.create_data_array( + "frame", + "nix.tracking.instance_frameidx", + label="frame index", + shape=(len(instances),), + dtype=nix.DataType.Int64, + ) + frameid_array.append_range_dimension_using_self() + + positions_array = create_positions_array( + "position", + "nix.tracking.instance_position", + block, + frameid_array, + [node.name for node in nodes.keys()], + positions_shape, + nix.DataType.Float, + ) + + track_array = create_feature_array( + "track", + "nix.tracking.instance_track", + block, + frameid_array, + shape=(len(instances),), + dtype=nix.DataType.Int64, + ) + + skeleton_array = create_feature_array( + "skeleton", + "nix.tracking.instance_skeleton", + block, + frameid_array, + (len(instances),), + nix.DataType.Int64, + ) + + point_score = create_feature_array( + "node score", + "nix.tracking.nodes_score", + block, + frameid_array, + (len(instances), len(nodes)), + nix.DataType.Float, + ) + point_score.append_set_dimension([node.name for node in nodes.keys()]) + + centroid_array = create_feature_array( + "centroid", + "nix.tracking.centroid_position", + block, + frameid_array, + (len(instances), 2), + nix.DataType.Float, + ) + + centroid_array.append_set_dimension(["x", "y"]) + instance_score = create_feature_array( + "instance score", + "nix.tracking.instance_score", + block, + frameid_array, + (len(instances),), + nix.DataType.Float, + ) + + tracking_score = create_feature_array( + "tracking score", + "nix.tracking.tack_score", + block, + frameid_array, + (len(instances),), + nix.DataType.Float, + ) + + # bind all together using a nix.MultiTag + mtag = block.create_multi_tag( + "tracking results", "nix.tracking.results", positions=frameid_array + ) + mtag.references.append(positions_array) + mtag.create_feature(track_array, nix.LinkType.Indexed) + mtag.create_feature(skeleton_array, nix.LinkType.Indexed) + mtag.create_feature(point_score, nix.LinkType.Indexed) + mtag.create_feature(instance_score, nix.LinkType.Indexed) + mtag.create_feature(tracking_score, nix.LinkType.Indexed) + mtag.create_feature(centroid_array, nix.LinkType.Indexed) + + sm = block.create_data_frame( + "skeleton map", + "nix.tracking.skeleton_map", + col_names=["name", "index"], + col_dtypes=[nix.DataType.String, nix.DataType.Int8], + ) + table_data = [] + for track in skeletons.keys(): + table_data.append((track.name, skeletons[track])) + sm.append_rows(table_data) + + nm = block.create_data_frame( + "node map", + "nix.tracking.node_map", + col_names=["name", "weight", "index", "skeleton"], + col_dtypes=[ + nix.DataType.String, + nix.DataType.Float, + nix.DataType.Int8, + nix.DataType.Int8, + ], + ) + table_data = [] + for node in nodes.keys(): + skel_index = -1 # if node is not assigned to a skeleton + for track in skeletons: + if node in track.nodes: + skel_index = skeletons[track] + break + table_data.append((node.name, node.weight, nodes[node], skel_index)) + nm.append_rows(table_data) + + tm = block.create_data_frame( + "track map", + "nix.tracking.track_map", + col_names=["name", "spawned_on", "index"], + col_dtypes=[nix.DataType.String, nix.DataType.Int64, nix.DataType.Int8], + ) + table_data = [("none", -1, -1)] # default for user-labeled instances + for track in tracks.keys(): + table_data.append((track.name, track.spawned_on, tracks[track])) + tm.append_rows(table_data) + + # Print shape info + data_dict = { + "instances": instances, + "frameid_array": frameid_array, + "positions_array": positions_array, + "track_array": track_array, + "skeleton_array": skeleton_array, + "point_score": point_score, + "instance_score": instance_score, + "tracking_score": tracking_score, + "centroid_array": centroid_array, + "tracks": tracks, + "nodes": nodes, + "skeletons": skeletons, + } + for key, val in data_dict.items(): + print(f"\t{key}:", end=" ") + if hasattr(val, "shape"): + print(f"{val.shape}") + else: + print(f"{len(val)}") + + # Print labels/video info + print( + f"\tlabels path: {source_path}\n" + f"\tvideo path: {video.backend.filename}\n" + f"\tvideo index = {source_object.videos.index(video)}" + ) + + print(f"Writing to NIX file...") + chunked_write( + instances, + frameid_array, + positions_array, + track_array, + skeleton_array, + point_score, + instance_score, + tracking_score, + centroid_array, + tracks, + nodes, + skeletons, + ) + print(f"done") + + print(f"\nExporting to NIX analysis file...") + if video is None: + video = source_object.videos[0] + print(f"No video specified, exporting the first one...") + + nix_file = None + try: + nix_file = create_file(filename, source_path, video) + write_data(nix_file.blocks[0], source_object, video) + print(f"Saved as {filename}") + except Exception as e: + print(f"\n\tWriting failed with following error:\n{e}!") + finally: + if nix_file is not None: + nix_file.close() diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index bf994f18a..c9cd1f77b 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -4842,7 +4842,7 @@ def main(args: Optional[list] = None): parser = _make_cli_parser() # Parse inputs. - args, _ = parser.parse_known_args(args=args) + args, _ = parser.parse_known_args(args) print("Args:") pprint(vars(args)) print() diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 2747d53b7..a032830db 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -1,4 +1,10 @@ +import shutil +from pathlib import PurePath, Path +from typing import List from tempfile import tempdir + +import pytest + from sleap.gui.commands import ( CommandContext, ImportDeepLabCutFolder, @@ -16,15 +22,9 @@ from sleap.io.convert import default_analysis_filename from sleap.instance import Instance, LabeledFrame from sleap import Skeleton, Track - from tests.info.test_h5 import extract_meta_hdf5 from tests.io.test_video import assert_video_params - -from pathlib import PurePath, Path -from typing import List - -import shutil -import pytest +from tests.io.test_formats import read_nix_meta def test_delete_user_dialog(centered_pair_predictions): @@ -85,8 +85,12 @@ def test_get_new_version_filename(): ) +@pytest.mark.parametrize("out_suffix", ["h5", "nix"]) def test_ExportAnalysisFile( - centered_pair_predictions: Labels, small_robot_mp4_vid: Video, tmpdir + centered_pair_predictions: Labels, + small_robot_mp4_vid: Video, + out_suffix: str, + tmpdir, ): def ExportAnalysisFile_ask(context: CommandContext, params: dict): """Taken from ExportAnalysisFile.ask()""" @@ -98,7 +102,7 @@ def ask_for_filename(default_name: str) -> str: labels = context.labels if len(labels.labeled_frames) == 0: - return False + raise ValueError("No labeled frames in project. Nothing to export.") if params["all_videos"]: all_videos = context.labels.videos @@ -108,7 +112,7 @@ def ask_for_filename(default_name: str) -> str: # Check for labeled frames in each video videos = [video for video in all_videos if len(labels.get(video)) != 0] if len(videos) == 0: - return False + raise ValueError("No labeled frames in video(s). Nothing to export.") default_name = context.state["filename"] or "labels" fn = PurePath(tmpdir, default_name) @@ -132,6 +136,7 @@ def ask_for_filename(default_name: str) -> str: video=video, output_path=dirname, output_prefix=str(fn.stem), + format_suffix=out_suffix, ) filename = default_name if use_default else ask_for_filename(default_name) @@ -153,10 +158,10 @@ def assert_videos_written(num_videos: int, labels_path: str = None): output_paths.append(output_path) if labels_path is not None: - read_meta = extract_meta_hdf5( - output_path, dset_names_in=["labels_path"] - ) - assert read_meta["labels_path"] == labels_path + meta_reader = extract_meta_hdf5 if out_suffix == "h5" else read_nix_meta + labels_key = "labels_path" if out_suffix == "h5" else "project" + read_meta = meta_reader(output_path, dset_names_in=["labels_path"]) + assert read_meta[labels_key] == labels_path assert len(output_paths) == num_videos, "Wrong number of outputs written" assert len(set(output_paths)) == num_videos, "Some output paths overwritten" @@ -243,8 +248,8 @@ def assert_videos_written(num_videos: int, labels_path: str = None): labels.remove_video(labels.videos[-1]) params = {"all_videos": True} - okay = ExportAnalysisFile_ask(context=context, params=params) - assert okay == False + with pytest.raises(ValueError): + okay = ExportAnalysisFile_ask(context=context, params=params) def test_ToggleGrayscale(centered_pair_predictions: Labels): diff --git a/tests/io/test_convert.py b/tests/io/test_convert.py index 0d7114635..da1971c11 100644 --- a/tests/io/test_convert.py +++ b/tests/io/test_convert.py @@ -8,29 +8,33 @@ import pytest +@pytest.mark.parametrize("format", ["analysis", "analysis.nix"]) def test_analysis_format( min_labels_slp: Labels, min_labels_slp_path: Labels, small_robot_mp4_vid: Video, + format: str, tmpdir, ): labels = min_labels_slp slp_path = PurePath(min_labels_slp_path) tmpdir = PurePath(tmpdir) - def generate_filenames(paths): + def generate_filenames(paths, format="analysis"): output_paths = [path for path in paths] # Generate filenames if user has not specified (enough) output filenames labels_path = str(slp_path) fn = re.sub("(\\.json(\\.zip)?|\\.h5|\\.slp)$", "", labels_path) fn = PurePath(fn) + out_suffix = "nix" if "nix" in format else "h5" default_names = [ default_analysis_filename( labels=labels, video=video, output_path=str(fn.parent), output_prefix=str(fn.stem), + format_suffix=out_suffix, ) for video in labels.videos[len(paths) :] ] @@ -38,8 +42,8 @@ def generate_filenames(paths): output_paths.extend(default_names) return output_paths - def assert_analysis_existance(output_paths: list): - output_paths = generate_filenames(output_paths) + def assert_analysis_existance(output_paths: list, format="analysis"): + output_paths = generate_filenames(output_paths, format) for video, path in zip(labels.videos, output_paths): video_exists = Path(path).exists() if len(labels.get(video)) == 0: @@ -47,21 +51,21 @@ def assert_analysis_existance(output_paths: list): else: assert video_exists - def sleap_convert_assert(output_paths, slp_path): + def sleap_convert_assert(output_paths, slp_path, format="analysis"): output_args = "" for path in output_paths: output_args += f"-o {path} " - args = f"--format analysis {output_args}{slp_path}".split() + args = f"--format {format} {output_args}{slp_path}".split() sleap_convert(args) - assert_analysis_existance(output_paths) + assert_analysis_existance(output_paths, format) # No output specified output_paths = [] - sleap_convert_assert(output_paths, slp_path) + sleap_convert_assert(output_paths, slp_path, format) # Specify output and retest output_paths = [str(tmpdir.with_name("prefix")), str(tmpdir.with_name("prefix2"))] - sleap_convert_assert(output_paths, slp_path) + sleap_convert_assert(output_paths, slp_path, format) # Add video and retest labels.add_video(small_robot_mp4_vid) @@ -69,7 +73,7 @@ def sleap_convert_assert(output_paths, slp_path): labels.save(filename=slp_path) output_paths = [str(tmpdir.with_name("prefix"))] - sleap_convert_assert(output_paths, slp_path) + sleap_convert_assert(output_paths, slp_path, format) # Add labeled frame to video and retest labeled_frame = labels.find(video=labels.videos[1], frame_idx=0, return_new=True)[0] @@ -80,7 +84,7 @@ def sleap_convert_assert(output_paths, slp_path): labels.save(filename=slp_path) output_paths = [str(tmpdir.with_name("prefix"))] - sleap_convert_assert(output_paths, slp_path) + sleap_convert_assert(output_paths, slp_path, format) def test_sleap_format( diff --git a/tests/io/test_formats.py b/tests/io/test_formats.py index 1387b084e..5001febb5 100644 --- a/tests/io/test_formats.py +++ b/tests/io/test_formats.py @@ -1,21 +1,23 @@ +import os +from pathlib import Path, PurePath + +import numpy as np +from numpy.testing import assert_array_equal +import pytest +import nixio + +from sleap.io.video import Video from sleap.instance import Instance, LabeledFrame, PredictedInstance, Track from sleap.io.dataset import Labels from sleap.io.format import read, dispatch, adaptor, text, genericjson, hdf5, filehandle from sleap.io.format.adaptor import SleapObjectType from sleap.io.format.alphatracker import AlphaTrackerAdaptor from sleap.io.format.ndx_pose import NDXPoseAdaptor +from sleap.io.format.nix import NixAdaptor from sleap.gui.commands import ImportAlphaTracker from sleap.gui.app import MainWindow from sleap.gui.state import GuiState -import pytest -import os -from pathlib import Path, PurePath -import numpy as np -from numpy.testing import assert_array_equal - -from sleap.io.video import Video - def test_text_adaptor(tmpdir): disp = dispatch.Dispatch() @@ -397,3 +399,85 @@ def test_nwb( labels.instances = [] with pytest.raises(TypeError): NDXPoseAdaptor.write(NDXPoseAdaptor, filename, labels) + + +def test_nix_adaptor( + centered_pair_predictions: Labels, + small_robot_mp4_vid: Video, + tmpdir, +): + # general tests + na = NixAdaptor() + assert na.default_ext == "nix" + assert "nix" in na.all_exts + assert len(na.name) > 0 + assert na.can_write_filename("somefile.nix") + assert not na.can_write_filename("somefile.slp") + assert NixAdaptor.does_read() == False + assert NixAdaptor.does_write() == True + + with pytest.raises(NotImplementedError): + NixAdaptor.read("some file") + + print("writing test predictions to nix file...") + filename = str(PurePath(tmpdir, "ndx_pose_test.nix")) + with pytest.raises(ValueError): + NixAdaptor.write(filename, centered_pair_predictions, video=small_robot_mp4_vid) + NixAdaptor.write(filename, centered_pair_predictions) + NixAdaptor.write( + filename, centered_pair_predictions, video=centered_pair_predictions.videos[0] + ) + + # basic read tests using the generic nix library + import nixio + + file = nixio.File.open(filename, nixio.FileMode.ReadOnly) + try: + file_meta = file.sections[0] + assert file_meta["format"] == "nix.tracking" + assert "sleap" in file_meta["writer"].lower() + + assert len([b for b in file.blocks if b.type == "nix.tracking_results"]) > 0 + b = file.blocks[0] + assert ( + len( + [ + da + for da in b.data_arrays + if da.type == "nix.tracking.instance_position" + ] + ) + == 1 + ) + assert ( + len( + [ + da + for da in b.data_arrays + if da.type == "nix.tracking.instance_frameidx" + ] + ) + == 1 + ) + + inst_positions = b.data_arrays["position"] + assert len(inst_positions.shape) == 3 + assert len(inst_positions.shape) == len(inst_positions.dimensions) + assert inst_positions.shape[2] == len(centered_pair_predictions.nodes) + + frame_indices = b.data_arrays["frame"] + assert len(frame_indices.shape) == 1 + assert frame_indices.shape[0] == inst_positions.shape[0] + except Exception as e: + file.close() + raise e + + +def read_nix_meta(filename, *args, **kwargs): + file = nixio.File.open(filename, nixio.FileMode.ReadOnly) + try: + file_meta = file_meta = file.sections[0] + except Exception: + file.close() + + return file_meta From ae491b32921bb07dbc48efedd449662e30213b1d Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Tue, 24 Jan 2023 06:36:29 -0800 Subject: [PATCH 04/10] Fix body vs symmetry subgraph filtering (#1142) Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com> --- sleap/skeleton.py | 37 +++++++++++++++++-------------------- tests/test_skeleton.py | 25 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/sleap/skeleton.py b/sleap/skeleton.py index cd89bcd62..064105a1f 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -171,42 +171,39 @@ def dict_match(dict1, dict2): @property def is_arborescence(self) -> bool: """Return whether this skeleton graph forms an arborescence.""" - return nx.algorithms.tree.recognition.is_arborescence(self._graph) + return nx.algorithms.tree.recognition.is_arborescence(self.graph) @property def in_degree_over_one(self) -> List[Node]: - return [node for node, in_degree in self._graph.in_degree if in_degree > 1] + return [node for node, in_degree in self.graph.in_degree if in_degree > 1] @property def root_nodes(self) -> List[Node]: - return [node for node, in_degree in self._graph.in_degree if in_degree == 0] + return [node for node, in_degree in self.graph.in_degree if in_degree == 0] @property def cycles(self) -> List[List[Node]]: - return list(nx.algorithms.simple_cycles(self._graph)) + return list(nx.algorithms.simple_cycles(self.graph)) @property def graph(self): - """Return subgraph of BODY edges for skeleton.""" - edges = [ - (src, dst, key) - for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") - if edge_type == EdgeType.BODY - ] - # TODO: properly induce subgraph for MultiDiGraph - # Currently, NetworkX will just return the nodes in the subgraph. - # See: https://stackoverflow.com/questions/16150557/networkxcreating-a-subgraph-induced-from-edges - return self._graph.edge_subgraph(edges) + """Return a view on the subgraph of body nodes and edges for skeleton.""" + + def edge_filter_fn(src, dst, edge_key): + edge_data = self._graph.get_edge_data(src, dst, edge_key) + return edge_data["type"] == EdgeType.BODY + + return nx.subgraph_view(self._graph, filter_edge=edge_filter_fn) @property def graph_symmetry(self): """Return subgraph of symmetric edges for skeleton.""" - edges = [ - (src, dst, key) - for src, dst, key, edge_type in self._graph.edges(keys=True, data="type") - if edge_type == EdgeType.SYMMETRY - ] - return self._graph.edge_subgraph(edges) + + def edge_filter_fn(src, dst, edge_key): + edge_data = self._graph.get_edge_data(src, dst, edge_key) + return edge_data["type"] == EdgeType.SYMMETRY + + return nx.subgraph_view(self._graph, filter_edge=edge_filter_fn) @staticmethod def find_unique_nodes(skeletons: List["Skeleton"]) -> List[Node]: diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index e35aa5bec..e409f3bbe 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -169,6 +169,18 @@ def test_symmetry(): with pytest.raises(ValueError): s1.delete_symmetry("1", "5") + s2 = Skeleton() + s2.add_nodes(["1", "2", "3"]) + s2.add_edge("1", "2") + s2.add_edge("2", "3") + s2.add_symmetry("1", "3") + assert s2.graph.number_of_edges() == 2 + assert s2.graph_symmetry.number_of_edges() == 2 + assert list(s2.graph_symmetry.edges()) == [ + (s2.nodes[0], s2.nodes[2]), + (s2.nodes[2], s2.nodes[0]), + ] + def test_json(skeleton, tmpdir): """ @@ -254,6 +266,9 @@ def dict_match(dict1, dict2): def test_graph_property(skeleton): assert [node for node in skeleton.graph.nodes()] == skeleton.nodes + no_edge_skel = Skeleton.from_names_and_edge_inds(["A", "B"]) + assert [node for node in no_edge_skel.graph.nodes()] == no_edge_skel.nodes + def test_load_mat_format(): skeleton = Skeleton.load_mat( @@ -396,3 +411,13 @@ def test_arborescence(): assert len(skeleton.cycles) == 0 assert len(skeleton.root_nodes) == 1 assert len(skeleton.in_degree_over_one) == 1 + + # symmetry edges should be ignored + skeleton = Skeleton() + skeleton.add_node("a") + skeleton.add_node("b") + skeleton.add_node("c") + skeleton.add_edge("a", "b") + skeleton.add_edge("b", "c") + skeleton.add_symmetry("a", "c") + assert skeleton.is_arborescence From 1e6d4d122db46df893926169ed7a5190b02353f9 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Tue, 24 Jan 2023 07:17:05 -0800 Subject: [PATCH 05/10] Handle changing backbones in training editor GUI (#1140) Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com> --- sleap/gui/learning/dialog.py | 55 +++++++++++++++++++++-------- sleap/gui/learning/scopedkeydict.py | 2 +- tests/gui/learning/test_dialog.py | 32 +++++++++++++++-- 3 files changed, 71 insertions(+), 18 deletions(-) diff --git a/sleap/gui/learning/dialog.py b/sleap/gui/learning/dialog.py index 12c245409..3e0dd5b4d 100644 --- a/sleap/gui/learning/dialog.py +++ b/sleap/gui/learning/dialog.py @@ -412,6 +412,43 @@ def merge_pipeline_and_head_config_data(self, head_name, head_data, pipeline_dat continue head_data[key] = val + @staticmethod + def update_loaded_config( + loaded_cfg: configs.TrainingJobConfig, tab_cfg_key_val_dict: dict + ) -> scopedkeydict.ScopedKeyDict: + """Update a loaded preset config with values from the training editor. + + Args: + loaded_cfg: A `TrainingJobConfig` that was loaded from a preset or previous + training run. + tab_cfg_key_val_dict: A dictionary with the values extracted from the training + editor GUI tab. + + Returns: + A `ScopedKeyDict` with the loaded config values overriden by the corresponding + ones from the `tab_cfg_key_val_dict`. + """ + # Serialize training config + loaded_cfg_hierarchical: dict = cattr.unstructure(loaded_cfg) + + # Clear backbone subfields since these will be set by the GUI + if ( + "model" in loaded_cfg_hierarchical + and "backbone" in loaded_cfg_hierarchical["model"] + ): + for k in loaded_cfg_hierarchical["model"]["backbone"]: + loaded_cfg_hierarchical["model"]["backbone"][k] = None + + loaded_cfg_scoped: scopedkeydict.ScopedKeyDict = ( + scopedkeydict.ScopedKeyDict.from_hierarchical_dict(loaded_cfg_hierarchical) + ) + + # Replace params exposed in GUI with values from GUI + for param, value in tab_cfg_key_val_dict.items(): + loaded_cfg_scoped.key_val_dict[param] = value + + return loaded_cfg_scoped + def get_every_head_config_data( self, pipeline_form_data ) -> List[configs.ConfigFileInfo]: @@ -436,24 +473,14 @@ def get_every_head_config_data( scopedkeydict.apply_cfg_transforms_to_key_val_dict(tab_cfg_key_val_dict) if trained_cfg_info is None: - # Config could not be loaded + # Config could not be loaded, just use the values from the GUI loaded_cfg_scoped: dict = tab_cfg_key_val_dict else: - # Config loaded - loaded_cfg: configs.TrainingJobConfig = trained_cfg_info.config - - # Serialize and flatten training config - loaded_cfg_heirarchical: dict = cattr.unstructure(loaded_cfg) - loaded_cfg_scoped: scopedkeydict.ScopedKeyDict = ( - scopedkeydict.ScopedKeyDict.from_hierarchical_dict( - loaded_cfg_heirarchical - ) + # Config was loaded, override with the values from the GUI + loaded_cfg_scoped = LearningDialog.update_loaded_config( + trained_cfg_info.config, tab_cfg_key_val_dict ) - # Replace params exposed in GUI with values from GUI - for param, value in tab_cfg_key_val_dict.items(): - loaded_cfg_scoped.key_val_dict[param] = value - # Deserialize merged dict to object cfg = scopedkeydict.make_training_config_from_key_val_dict( loaded_cfg_scoped diff --git a/sleap/gui/learning/scopedkeydict.py b/sleap/gui/learning/scopedkeydict.py index a5b481f00..294673570 100644 --- a/sleap/gui/learning/scopedkeydict.py +++ b/sleap/gui/learning/scopedkeydict.py @@ -39,7 +39,7 @@ def set_hierarchical_key_val(cls, current_dict: dict, key: Text, val: Any): current_dict[key] = val else: top_key, *subkey_list = key.split(".") - if top_key not in current_dict: + if top_key not in current_dict or current_dict[top_key] is None: current_dict[top_key] = dict() subkey = ".".join(subkey_list) cls.set_hierarchical_key_val(current_dict[top_key], subkey, val) diff --git a/tests/gui/learning/test_dialog.py b/tests/gui/learning/test_dialog.py index 1016615f8..a244fd2ed 100644 --- a/tests/gui/learning/test_dialog.py +++ b/tests/gui/learning/test_dialog.py @@ -2,16 +2,14 @@ from sleap.gui.learning.configs import TrainingConfigFilesWidget from sleap.gui.learning.configs import ConfigFileInfo from sleap.gui.learning.scopedkeydict import ( - make_training_config_from_key_val_dict, ScopedKeyDict, apply_cfg_transforms_to_key_val_dict, ) from sleap.gui.app import MainWindow +from sleap.nn.config import TrainingJobConfig, UNetConfig -import pytest import cattr from pathlib import Path -from qtpy import QtWidgets def test_use_hidden_params_from_loaded_config( @@ -90,3 +88,31 @@ def test_use_hidden_params_from_loaded_config( elif k not in params_reset: # 2. Uses hidden parameters from loaded config assert config_info_dict[k] == training_cfg_info_dict[k] + + +def test_update_loaded_config(): + base_cfg = TrainingJobConfig() + base_cfg.data.preprocessing.input_scaling = 0.5 + base_cfg.model.backbone.unet = UNetConfig(max_stride=32, output_stride=2) + base_cfg.optimization.augmentation_config.rotation_max_angle = 180 + base_cfg.optimization.augmentation_config.rotation_min_angle = -180 + + gui_vals = { + "data.preprocessing.input_scaling": 1.0, + "model.backbone.pretrained_encoder.encoder": "vgg16", + } + + scoped_cfg = LearningDialog.update_loaded_config(base_cfg, gui_vals) + assert scoped_cfg.key_val_dict["data.preprocessing.input_scaling"] == 1.0 + assert scoped_cfg.key_val_dict["model.backbone.unet"] is None + assert ( + scoped_cfg.key_val_dict["model.backbone.pretrained_encoder.encoder"] == "vgg16" + ) + assert ( + scoped_cfg.key_val_dict["optimization.augmentation_config.rotation_max_angle"] + == 180 + ) + assert ( + scoped_cfg.key_val_dict["optimization.augmentation_config.rotation_min_angle"] + == -180 + ) From 36a27ab0ca6add05b75b9c1769379f94b011484c Mon Sep 17 00:00:00 2001 From: Sean Afshar <84047864+sean-afshar@users.noreply.github.com> Date: Tue, 24 Jan 2023 08:16:44 -0800 Subject: [PATCH 06/10] Added scaling functionality for both the instances and bounding box. (#1133) * Create VisibleBoundingBox class. * Added instance scaling functionality in addition to bounding box scaling functionality. * Update sleap/gui/widgets/video.py Co-authored-by: Talmo Pereira * Update sleap/gui/widgets/video.py Co-authored-by: Talmo Pereira * Update sleap/gui/widgets/video.py Co-authored-by: Talmo Pereira * Update sleap/gui/widgets/video.py Co-authored-by: Talmo Pereira * Update sleap/gui/widgets/video.py Co-authored-by: Talmo Pereira * Added new testing for scaling operation and simplified VisibleBoundingBox class code. * Added type hinting to the scaling padding and removed erroneous bounding rect initialization. Co-authored-by: Talmo Pereira Co-authored-by: Liezl Maree <38435167+roomrys@users.noreply.github.com> --- sleap/gui/widgets/video.py | 202 ++++++++++++++++++++++++++++++++- tests/gui/test_video_player.py | 34 ++++++ 2 files changed, 232 insertions(+), 4 deletions(-) diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 865cae97e..8c8bbdbac 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -50,6 +50,7 @@ QPen, QBrush, QColor, + QCursor, QFont, QPolygonF, QKeyEvent, @@ -1801,7 +1802,10 @@ def __init__( ) # Add box to go around instance for selection - self.box = QGraphicsRectItem(parent=self) + if self.predicted: + self.box = QGraphicsRectItem(parent=self) + else: + self.box = VisibleBoundingBox(rect=self._bounding_rect, parent=self) box_pen_width = color_manager.get_item_pen_width(self.instance) box_pen = QPen(QColor(*color), box_pen_width) box_pen.setStyle(Qt.DashLine) @@ -2063,6 +2067,197 @@ def paint(self, painter, option, widget=None): pass +class VisibleBoundingBox(QtWidgets.QGraphicsRectItem): + """QGraphicsRectItem for user instance bounding boxes. + + This object defines a scalable bounding box that encases an instance and handles + the relevant scaling operations. It is instantiated when its respective QtInstance + object is instantiated. + + When instantiated, it creates 4 boxes, which are properties of the overall object, + on the corners of the overall bounding box. These corner boxes can be dragged to + scale the overall bounding box. + + Args: + rect: The :class:`QRectF` object which defines the non-scalable bounding box. + parent: The :class:`QtInstance` to encompass. + + """ + + def __init__( + self, + rect: QRectF, + parent: QtInstance, + opacity: float = 0.8, + scaling_padding: float = 10.0, + ): + super().__init__(rect, parent) + self.box_width = parent.markerRadius + color_manager = parent.player.color_manager + int_color = color_manager.get_item_color(parent.instance) + self.int_color = QColor(*int_color) + self.corner_opacity = opacity + self.scaling_padding = scaling_padding + + self.parent = parent + self.resizing = None + self.origin = rect.topLeft() + self.ref_width = rect.width() + self.ref_height = rect.height() + + box_pen = QPen(Qt.black) + box_pen.setCosmetic(True) + box_brush = QBrush(self.int_color) + + # Create the edge boxes + self.top_left_box = QtWidgets.QGraphicsRectItem(parent=self) + self.bottom_left_box = QtWidgets.QGraphicsRectItem(parent=self) + self.top_right_box = QtWidgets.QGraphicsRectItem(parent=self) + self.bottom_right_box = QtWidgets.QGraphicsRectItem(parent=self) + + corner_boxes = [ + self.top_left_box, + self.bottom_left_box, + self.top_right_box, + self.bottom_right_box, + ] + for corner_box in corner_boxes: + corner_box.setPen(box_pen) + corner_box.setBrush(box_brush) + corner_box.setOpacity(self.corner_opacity) + corner_box.setCursor(QCursor(Qt.DragMoveCursor)) + + def setRect(self, rect: QRectF): + """Update edge boxes along with instance box""" + super().setRect(rect) + x1, y1, x2, y2 = rect.getCoords() + w = self.box_width + self.top_left_box.setRect(QRectF(QPointF(x1, y1), QPointF(x1 + w, y1 + w))) + self.top_right_box.setRect(QRectF(QPointF(x2 - w, y1), QPointF(x2, y1 + w))) + self.bottom_left_box.setRect(QRectF(QPointF(x1, y2 - w), QPointF(x1 + w, y2))) + self.bottom_right_box.setRect(QRectF(QPointF(x2 - w, y2 - w), QPointF(x2, y2))) + + def mousePressEvent(self, event): + """Custom event handler for pressing on an adjustable corner box. + + This function recognizes that the user has begun resizing the instance and + stores relevant information about the bounding box before the transformation. + """ + if event.button() == Qt.LeftButton: + if self.top_left_box.contains(event.pos()): + self.resizing = "top_left" + self.origin = self.rect().bottomRight() + elif self.top_right_box.contains(event.pos()): + self.resizing = "top_right" + self.origin = self.rect().bottomLeft() + elif self.bottom_left_box.contains(event.pos()): + self.resizing = "bottom_left" + self.origin = self.rect().topRight() + elif self.bottom_right_box.contains(event.pos()): + self.resizing = "bottom_right" + self.origin = self.rect().topLeft() + + self.ref_width = self.rect().width() + self.ref_height = self.rect().height() + + def mouseMoveEvent(self, event): + """Custom event handler for moving an adjustable corner box. + + This function resizes the bounding box as the user drags one of its corners. + """ + # Scale the bounding box and QtInstance if an edge box is selected + if event.buttons() & Qt.LeftButton: + x1, y1, x2, y2 = self.rect().getCoords() + new_x = event.pos().x() + new_y = event.pos().y() + + w = self.parent.player.video.width + h = self.parent.player.video.height + + if self.resizing == "top_left": + # Check to see if outside the range of the original bounding box + if new_x < 0: + new_x = 0 + if new_x >= x2 - self.scaling_padding - self.box_width: + new_x = x2 - self.scaling_padding - self.box_width + if new_y < 0: + new_y = 0 + if new_y >= y2 - self.scaling_padding - self.box_width: + new_y = y2 - self.scaling_padding - self.box_width + + # Update the bounding box + self.setRect(QRectF(QPointF(new_x, new_y), QPointF(x2, y2))) + + elif self.resizing == "top_right": + # Check to see if outside the range of the original bounding box + if new_x > w: + new_x = w + if new_x <= x1 + self.scaling_padding + self.box_width: + new_x = x1 + self.scaling_padding + self.box_width + if new_y < 0: + new_y = 0 + if new_y >= y2 - self.scaling_padding - self.box_width: + new_y = y2 - self.scaling_padding - self.box_width + + # Update the bounding box + self.setRect(QRectF(QPointF(x1, new_y), QPointF(new_x, y2))) + + elif self.resizing == "bottom_left": + # Check to see if outside the range of the original bounding box + if new_x < 0: + new_x = 0 + if new_x >= x2 - self.scaling_padding - self.box_width: + new_x = x2 - self.scaling_padding - self.box_width + if new_y > h: + new_y = h + if new_y <= y1 + self.scaling_padding + self.box_width: + new_y = y1 + self.scaling_padding + self.box_width + + # Update the bounding box + self.setRect(QRectF(QPointF(new_x, y1), QPointF(x2, new_y))) + + elif self.resizing == "bottom_right": + # Check to see if outside the range of the original bounding box + if new_x > w: + new_x = w + if new_x <= x1 + self.scaling_padding + self.box_width: + new_x = x1 + self.scaling_padding + self.box_width + if new_y > h: + new_y = h + if new_y <= y1 + self.scaling_padding + self.box_width: + new_y = y1 + self.scaling_padding + self.box_width + + # Update the bounding box + self.setRect(QRectF(QPointF(x1, y1), QPointF(new_x, new_y))) + + def mouseReleaseEvent(self, event): + """Custom event handler for releasing an adjustable corner box. + + This function recognizes the end of a scaling operation by transforming the + instance linked to the bounding box. This is done by updating the positions of + the nodes belonging to the instance and then calling the instance's updatePoints + function to update the entire instance. + """ + if event.button() == Qt.LeftButton: + # Scale the instance + scale_x = self.rect().width() / self.ref_width + scale_y = self.rect().height() / self.ref_height + + for node_key, node_value in self.parent.nodes.items(): + new_x = ( + scale_x * (node_value.point.x - self.origin.x()) + self.origin.x() + ) + new_y = ( + scale_y * (node_value.point.y - self.origin.y()) + self.origin.y() + ) + self.parent.nodes[node_key].setPos(new_x, new_y) + + # Update the instance + self.parent.updatePoints(complete=True, user_change=True) + + self.resizing = None + + class QtTextWithBackground(QGraphicsTextItem): """ Inherits methods/behavior of `QGraphicsTextItem`, but with background box. @@ -2161,11 +2356,10 @@ def plot_instances(scene, frame_idx, labels, video=None, fixed=True): if __name__ == "__main__": import argparse - from sleap.io.dataset import Labels parser = argparse.ArgumentParser() - parser.add_argument("data_path", help="Path to labels json file") + parser.add_argument("data_path", help="Path to labels file") args = parser.parse_args() - labels = Labels.load_json(args.data_path) + labels = sleap.load_file(args.data_path) video_demo(labels=labels, standalone=True) diff --git a/tests/gui/test_video_player.py b/tests/gui/test_video_player.py index 52ca49e68..b0661a4e1 100644 --- a/tests/gui/test_video_player.py +++ b/tests/gui/test_video_player.py @@ -6,6 +6,7 @@ QtInstance, QtVideoPlayer, QtTextWithBackground, + VisibleBoundingBox, ) from qtpy import QtCore, QtWidgets @@ -110,3 +111,36 @@ def test_QtTextWithBackground(qtbot): scene.addItem(txt) qtbot.addWidget(view) + + +def test_VisibleBoundingBox(qtbot, centered_pair_labels): + vp = QtVideoPlayer(centered_pair_labels.video) + + test_idx = 27 + for instance in centered_pair_labels.labeled_frames[test_idx].instances: + vp.addInstance(instance) + + inst = vp.instances[0] + + # Check if type of bounding box is correct + assert type(inst.box) == VisibleBoundingBox + + # Scale the bounding box + start_top_left = inst.box.rect().topLeft() + start_bottom_right = inst.box.rect().bottomRight() + initial_width = inst.box.rect().width() + initial_height = inst.box.rect().height() + + dx = 5 + dy = 10 + + end_top_left = QtCore.QPointF(start_top_left.x() - dx, start_top_left.y() - dy) + end_bottom_right = QtCore.QPointF( + start_bottom_right.x() + dx, start_bottom_right.y() + dy + ) + + inst.box.setRect(QtCore.QRectF(end_top_left, end_bottom_right)) + + # Check if bounding box scaled appropriately + assert inst.box.rect().width() - initial_width == 2 * dx + assert inst.box.rect().height() - initial_height == 2 * dy From b37b34f1b06d6e4e5d6973ad1aae8709c566bda8 Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Tue, 24 Jan 2023 10:18:46 -0800 Subject: [PATCH 07/10] Add better error message for top down (#1121) * Add better error message for top down * Add test for error message * Raise different error, fix test --- sleap/nn/inference.py | 9 ++++++++- tests/nn/test_inference.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index c9cd1f77b..2d5d195d8 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -2147,7 +2147,14 @@ def call( crop_output = self.centroid_crop(example) if isinstance(self.instance_peaks, FindInstancePeaksGroundTruth): - peaks_output = self.instance_peaks(example, crop_output) + + if "instances" in example: + peaks_output = self.instance_peaks(example, crop_output) + else: + raise ValueError( + "Ground truth data was not detected... " + "Please load both models when predicting on non-ground-truth data." + ) else: peaks_output = self.instance_peaks(crop_output) return peaks_output diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index f85466e2c..4c8df81a6 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1360,3 +1360,19 @@ def test_flow_tracker(centered_pair_predictions: Labels, tmpdir): for key in tracker.candidate_maker.shifted_instances.keys(): assert lf.frame_idx - key[0] <= track_window # Keys are pruned assert abs(key[0] - key[1]) <= track_window # References within window + + +def test_top_down_model(min_tracks_2node_labels: Labels, min_centroid_model_path: str): + labels = min_tracks_2node_labels + video = sleap.load_video(labels.videos[0].backend.filename) + predictor = sleap.load_model(min_centroid_model_path, batch_size=16) + + # Preload images + imgs = video[:3] + + # Raise better error message + with pytest.raises(ValueError): + predictor.predict(imgs[:1]) + + # Runs without error message + predictor.predict(labels.extract(inds=[0, 1])) From 1fc73870546bef08014342815896893e2ec88b7a Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 26 Jan 2023 14:00:09 -0800 Subject: [PATCH 08/10] Finialize change on load test --- sleap/gui/commands.py | 2 +- sleap/io/dataset.py | 1 + tests/gui/test_commands.py | 74 ++++++++++++++++++++++---------------- 3 files changed, 45 insertions(+), 32 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index d7793a7d3..6b8c65d51 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -670,7 +670,7 @@ def ask(context: "CommandContext", params: dict): has_loaded = True except ValueError as e: print(e) - QMessageBox(text=f"Unable to load {filename}.").exec_() + QtWidgets.QMessageBox(text=f"Unable to load {filename}.").exec_() params["labels"] = labels diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index f9d7765cb..044d55758 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -2582,6 +2582,7 @@ def video_callback( if fixed_path != filename: filenames[i] = fixed_path missing[i] = False + context["changed_on_load"] = True if use_gui: # If there are still missing paths, prompt user diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 23e9c9433..1bfc0b36e 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -23,6 +23,9 @@ from sleap.io.format.ndx_pose import NDXPoseAdaptor from sleap.io.pathutils import fix_path_separator from sleap.io.video import Video + +# These imports cause trouble when running `pytest.main()` from within the file +# Comment out to debug tests file via VSCode's "Debug Python File" from tests.info.test_h5 import extract_meta_hdf5 from tests.io.test_video import assert_video_params from tests.io.test_formats import read_nix_meta @@ -503,42 +506,51 @@ def test_SetSelectedInstanceTrack(centered_pair_predictions: Labels): assert pred_inst.track == new_instance.track -def test_LoadLabelsObject( - centered_pair_predictions: Labels, centered_pair_predictions_slp_path: str, tmpdir +@pytest.mark.parametrize("video_move_case", ["new_directory", "new_name"]) +def test_LoadProjectFile( + centered_pair_predictions_slp_path: str, + video_move_case, + tmpdir, ): """Test that changing a labels object on load flags any changes.""" - class FlexyObject: - """Object that allows adding object attributes.""" - def __getattr__(self, value): - self.value = FlexyObject() - return self.value - - def __call__(self, *args, **kwargs): - return self.__init__() - - def ask_LoadProjectFile(context): - gui_video_callback = Labels.make_video_callback(context=context) - labels = Labels.load_file(centered_pair_predictions_slp_path, video_search=gui_video_callback) + def ask_LoadProjectFile(params): + """Implement `LoadProjectFile.ask` without GUI elements.""" + filename: Path = params["filename"] + gui_video_callback = Labels.make_video_callback( + search_paths=[str(filename)], context=params + ) + labels = Labels.load_file( + centered_pair_predictions_slp_path, video_search=gui_video_callback + ) return labels - # Move labels video to tmpdir - labels = centered_pair_predictions - video_path = Path(labels.video.filename).absolute() - new_video_path = Path(tmpdir, "new_video.mp4") - try: - shutil.move(video_path, new_video_path) - except Exception: - pass + def load_and_assert_changes(new_video_path: Path): + # Load the project + params = {"filename": new_video_path} + ask_LoadProjectFile(params) - # Load the project - context: CommandContext = CommandContext(state=GuiState(), app=FlexyObject()) - labels = ask_LoadProjectFile(context = context) - context.loadLabelsObject(labels=labels, filename="it_doesnt_matter") + # Assert project has changes + assert params["changed_on_load"] - # Assert project has changes - print(context.state["has_changes"]) - assert context.state["has_changes"] + # Get labels and video path + labels = Labels.load_file(centered_pair_predictions_slp_path) + expected_video_path = Path(labels.video.backend.filename) -if __name__ == "__main__": - pytest.main([r"tests\gui\test_commands.py::test_LoadLabelsObject"]) \ No newline at end of file + # Move video to new location based on case + if video_move_case == "new_directory": # Needs to have same name + new_video_path = Path(tmpdir, expected_video_path.name) + else: # Needs to have different name + new_video_path = expected_video_path.with_name("new_name.mp4") + shutil.move(expected_video_path, new_video_path) # Move video to new location + + # Shorten video path if using directory location only + search_path = ( + new_video_path.parent if video_move_case == "new_directory" else new_video_path + ) + + # Load project and assert changes + try: + load_and_assert_changes(search_path) + finally: # Move video back to original location - for ease of re-testing + shutil.move(new_video_path, expected_video_path) From 1e123c962fd29edde4b35ff31a7a569473cd316b Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 26 Jan 2023 14:03:59 -0800 Subject: [PATCH 09/10] Remove unused imports --- tests/gui/test_commands.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 1bfc0b36e..9474b524c 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -1,6 +1,5 @@ from pathlib import PurePath, Path import shutil -from tempfile import tempdir from typing import List import pytest @@ -15,7 +14,6 @@ SaveProjectAs, get_new_version_filename, ) -from sleap.gui.state import GuiState from sleap.instance import Instance, LabeledFrame from sleap.io.convert import default_analysis_filename from sleap.io.dataset import Labels From 219875c790fef4905896f27ccfc71b03f49d77a1 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Thu, 26 Jan 2023 15:10:25 -0800 Subject: [PATCH 10/10] Skip test if on windows since files are being used in parallel --- tests/gui/test_commands.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 9474b524c..20d8ba6fa 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -1,5 +1,6 @@ from pathlib import PurePath, Path import shutil +import sys from typing import List import pytest @@ -504,6 +505,11 @@ def test_SetSelectedInstanceTrack(centered_pair_predictions: Labels): assert pred_inst.track == new_instance.track +@pytest.mark.skipif( + sys.platform.startswith("win"), + reason="Files being using in parallel by linux CI tests via Github Actions " + "(and linux tests give us codecov reports)", +) @pytest.mark.parametrize("video_move_case", ["new_directory", "new_name"]) def test_LoadProjectFile( centered_pair_predictions_slp_path: str,