diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 03b806903..134461c60 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -207,7 +207,7 @@ optional arguments: --tracking.clean_iou_threshold TRACKING.CLEAN_IOU_THRESHOLD IOU to use when culling instances *after* tracking. (default: 0) --tracking.similarity TRACKING.SIMILARITY - Options: instance, centroid, iou (default: instance) + Options: instance, normalized_instance, object_keypoint, centroid, iou (default: instance) --tracking.match TRACKING.MATCH Options: hungarian, greedy (default: greedy) --tracking.robust TRACKING.ROBUST diff --git a/docs/guides/proofreading.md b/docs/guides/proofreading.md index fea1c5ebc..941b85154 100644 --- a/docs/guides/proofreading.md +++ b/docs/guides/proofreading.md @@ -50,6 +50,8 @@ There are currently three methods for matching instances in frame N against thes - “**centroid**” measures similarity by the distance between the instance centroids - “**iou**” measures similarity by the intersection/overlap of the instance bounding boxes - “**instance**” measures similarity by looking at the distances between corresponding nodes in the instances, normalized by the number of valid nodes in the candidate instance. +- “**normalized_instance**” measures similarity by looking at the distances between corresponding nodes in the instances, normalized by the number of valid nodes in the candidate instance and the keypoints normalized by the image size. +- “**object_keypoint**” measures similarity by measuring the distance between each keypoints from a reference instance and a query instance, takes the exp(-d**2), sum for all the keypoints and divide by the number of visible keypoints in the reference instance. Once SLEAP has measured the similarity between all the candidates and the instances in frame N, you need to choose a way to pair them up. You can do this either by picking the best match, and the picking the best remaining match for each remaining instance in turn—this is “**greedy**” matching—or you can find the way of matching identities which minimizes the total cost (or: maximizes the total similarity)—this is “**Hungarian**” matching. diff --git a/environment.yml b/environment.yml index 2aba3c7d2..06a0633d2 100644 --- a/environment.yml +++ b/environment.yml @@ -10,7 +10,7 @@ channels: dependencies: # Packages SLEAP uses directly - - conda-forge::attrs >=21.2.0 #,<=21.4.0 + - conda-forge::attrs >=21.2.0 - conda-forge::cattrs ==1.1.1 - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin diff --git a/environment_no_cuda.yml b/environment_no_cuda.yml index 2adee7a89..ba2b54a22 100644 --- a/environment_no_cuda.yml +++ b/environment_no_cuda.yml @@ -11,7 +11,7 @@ channels: dependencies: # Packages SLEAP uses directly - - conda-forge::attrs >=21.2.0 #,<=21.4.0 + - conda-forge::attrs >=21.2.0 - conda-forge::cattrs ==1.1.1 - conda-forge::imageio-ffmpeg # Required for imageio to read/write videos with ffmpeg - conda-forge::jsmin diff --git a/pypi_requirements.txt b/pypi_requirements.txt index 62c0c0ddc..ad34e2ad8 100644 --- a/pypi_requirements.txt +++ b/pypi_requirements.txt @@ -6,6 +6,8 @@ # These are also distributed through conda and not pip installed when using conda. attrs>=21.2.0,<=21.4.0 cattrs==1.1.1 +imageio +imageio-ffmpeg # certifi>=2017.4.17,<=2021.10.8 jsmin jsonpickle==1.2 diff --git a/sleap/config/pipeline_form.yaml b/sleap/config/pipeline_form.yaml index 0497e2815..ed05b91f8 100644 --- a/sleap/config/pipeline_form.yaml +++ b/sleap/config/pipeline_form.yaml @@ -435,7 +435,7 @@ inference: label: Similarity Method type: list default: instance - options: "instance,centroid,iou,object keypoint" + options: "instance,normalized_instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list @@ -536,7 +536,7 @@ inference: label: Similarity Method type: list default: instance - options: "instance,centroid,iou,object keypoint" + options: "instance,normalized_instance,centroid,iou,object keypoint" - name: tracking.match label: Matching Method type: list diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 4c75dac3f..2dbceb3b7 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -377,7 +377,9 @@ def add_menu_item(menu, key: str, name: str, action: Callable): def connect_check(key): self._menu_actions[key].setCheckable(True) self._menu_actions[key].setChecked(self.state[key]) - self.state.connect(key, self._menu_actions[key].setChecked) + self.state.connect( + key, lambda checked: self._menu_actions[key].setChecked(checked) + ) # add checkable menu item connected to state variable def add_menu_check_item(menu, key: str, name: str): diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index 0a008bea7..f68dc0180 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -413,13 +413,6 @@ def set_item(self, item, key, value): elif key == "symmetry": self.context.setNodeSymmetry(skeleton=self.obj, node=item, symmetry=value) - def get_item_color(self, item: Any, key: str): - if self.skeleton: - color = self.context.app.color_manager.get_item_color( - item, parent_skeleton=self.skeleton - ) - return QtGui.QColor(*color) - class SkeletonEdgesTableModel(GenericTableModel): """Table model for skeleton edges.""" @@ -436,14 +429,6 @@ def object_to_items(self, skeleton: Skeleton): ] return items - def get_item_color(self, item: Any, key: str): - if self.skeleton: - edge_pair = (item["source"], item["destination"]) - color = self.context.app.color_manager.get_item_color( - edge_pair, parent_skeleton=self.skeleton - ) - return QtGui.QColor(*color) - class LabeledFrameTableModel(GenericTableModel): """Table model for listing instances in labeled frame. diff --git a/sleap/gui/widgets/docks.py b/sleap/gui/widgets/docks.py index 3375e4713..bd20bf79a 100644 --- a/sleap/gui/widgets/docks.py +++ b/sleap/gui/widgets/docks.py @@ -30,10 +30,8 @@ ) from sleap.gui.dialogs.formbuilder import YamlFormWidget from sleap.gui.widgets.views import CollapsibleWidget -from sleap.skeleton import Skeleton -from sleap.util import decode_preview_image, find_files_by_suffix, get_package_file - -# from sleap.gui.app import MainWindow +from sleap.skeleton import Skeleton, SkeletonDecoder +from sleap.util import find_files_by_suffix, get_package_file class DockWidget(QDockWidget): @@ -365,7 +363,7 @@ def create_templates_groupbox(self) -> QGroupBox: def updatePreviewImage(preview_image_bytes: bytes): # Decode the preview image - preview_image = decode_preview_image(preview_image_bytes) + preview_image = SkeletonDecoder.decode_preview_image(preview_image_bytes) # Create a QImage from the Image preview_image = QtGui.QImage( diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index a1a083ba5..1c3b2d619 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -2651,6 +2651,7 @@ def _object_builder(): # Set tracks for predicted instances in this frame. predicted_instances = self.tracker.track( untracked_instances=predicted_instances, + img_hw=ex["image"].shape[-3:-1], img=image, t=frame_ind, ) @@ -3293,6 +3294,7 @@ def _object_builder(): # Set tracks for predicted instances in this frame. predicted_instances = self.tracker.track( untracked_instances=predicted_instances, + img_hw=ex["image"].shape[-3:-1], img=image, t=frame_ind, ) diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py index b2f35b21f..0b77f4ac9 100644 --- a/sleap/nn/tracker/components.py +++ b/sleap/nn/tracker/components.py @@ -12,6 +12,7 @@ """ + import operator from collections import defaultdict import logging @@ -29,6 +30,21 @@ InstanceType = TypeVar("InstanceType", Instance, PredictedInstance) +def normalized_instance_similarity( + ref_instance: InstanceType, query_instance: InstanceType, img_hw: Tuple[int] +) -> float: + """Computes similarity between instances with normalized keypoints.""" + + normalize_factors = np.array((img_hw[1], img_hw[0])) + ref_visible = ~(np.isnan(ref_instance.points_array).any(axis=1)) + normalized_query_keypoints = query_instance.points_array / normalize_factors + normalized_ref_keypoints = ref_instance.points_array / normalize_factors + dists = np.sum((normalized_query_keypoints - normalized_ref_keypoints) ** 2, axis=1) + similarity = np.nansum(np.exp(-dists)) / np.sum(ref_visible) + + return similarity + + def instance_similarity( ref_instance: InstanceType, query_instance: InstanceType ) -> float: diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 619b33f76..59d819643 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -11,12 +11,14 @@ import cv2 import numpy as np import rich.progress +import functools from sleap import Track, LabeledFrame, Skeleton from sleap.nn.tracker.components import ( factory_object_keypoint_similarity, instance_similarity, + normalized_instance_similarity, centroid_distance, instance_iou, hungarian_matching, @@ -504,7 +506,8 @@ def get_candidates( instance=instance_similarity, centroid=centroid_distance, iou=instance_iou, - object_keypoint=instance_similarity, + normalized_instance=normalized_instance_similarity, + object_keypoint=factory_object_keypoint_similarity, ) match_policies = dict( @@ -799,6 +802,7 @@ def uses_image(self): def track( self, untracked_instances: List[InstanceType], + img_hw: Tuple[int], img: Optional[np.ndarray] = None, t: int = None, ) -> List[InstanceType]: @@ -806,12 +810,18 @@ def track( Args: untracked_instances: List of instances to assign to tracks. + img_hw: (height, width) of the image used to normalize the keypoints. img: Image data of the current frame for flow shifting. t: Current timestep. If not provided, increments from the internal queue. Returns: A list of the instances that were tracked. """ + if self.similarity_function == normalized_instance_similarity: + factory_normalized_instance = functools.partial( + normalized_instance_similarity, img_hw=img_hw + ) + self.similarity_function = factory_normalized_instance if self.candidate_maker is None: return untracked_instances diff --git a/sleap/skeleton.py b/sleap/skeleton.py index eca393b8e..f6477cf66 100644 --- a/sleap/skeleton.py +++ b/sleap/skeleton.py @@ -6,24 +6,25 @@ their connection to each other, and needed meta-data. """ -import attr -import cattr -import numpy as np -import jsonpickle -import json -import h5py +import base64 import copy - +import json import operator from enum import Enum +from io import BytesIO from itertools import count -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Text +from typing import Any, Dict, Iterable, List, Optional, Text, Tuple, Union +import attr +import cattr +import h5py +import jsonpickle import networkx as nx +import numpy as np from networkx.readwrite import json_graph +from PIL import Image from scipy.io import loadmat - NodeRef = Union[str, "Node"] H5FileRef = Union[str, h5py.File] @@ -85,6 +86,483 @@ def matches(self, other: "Node") -> bool: return other.name == self.name and other.weight == self.weight +class SkeletonDecoder: + """Replace jsonpickle.decode with our own decoder. + + This function will decode the following from jsonpickle's encoded format: + + `Node` objects from + { + "py/object": "sleap.skeleton.Node", + "py/state": { "py/tuple": ["thorax1", 1.0] } + } + to `Node(name="thorax1", weight=1.0)` + + `EdgeType` objects from + { + "py/reduce": [ + { "py/type": "sleap.skeleton.EdgeType" }, + { "py/tuple": [1] } + ] + } + to `EdgeType(1)` + + `bytes` from + { + "py/b64": "aVZC..." + } + to `b"iVBO..."` + + and any repeated objects from + { + "py/id": 1 + } + to the object with the same reconstruction id (from top to bottom). + """ + + def __init__(self): + self.decoded_objects: List[Union[Node, EdgeType]] = [] + + def _decode_id(self, id: int) -> Union[Node, EdgeType]: + """Decode the object with the given `py/id` value of `id`. + + Args: + id: The `py/id` value to decode (1-indexed). + objects: The dictionary of objects that have already been decoded. + + Returns: + The object with the given `py/id` value. + """ + return self.decoded_objects[id - 1] + + @staticmethod + def _decode_state(state: dict) -> Node: + """Reconstruct the `Node` object from 'py/state' key in the serialized nx_graph. + + We support states in either dictionary or tuple format: + { + "py/state": { "py/tuple": ["thorax1", 1.0] } + } + or + { + "py/state": {"name": "thorax1", "weight": 1.0} + } + + Args: + state: The state to decode, i.e. state = dict["py/state"] + + Returns: + The `Node` object reconstructed from the state. + """ + + if "py/tuple" in state: + return Node(*state["py/tuple"]) + + return Node(**state) + + @staticmethod + def _decode_object_dict(object_dict) -> Node: + """Decode dict containing `py/object` key in the serialized nx_graph. + + Args: + object_dict: The dict to decode, i.e. + object_dict = {"py/object": ..., "py/state":...} + + Raises: + ValueError: If object_dict does not have 'py/object' and 'py/state' keys. + ValueError: If object_dict['py/object'] is not 'sleap.skeleton.Node'. + + Returns: + The decoded `Node` object. + """ + + if object_dict["py/object"] != "sleap.skeleton.Node": + raise ValueError("Only 'sleap.skeleton.Node' objects are supported.") + + node: Node = SkeletonDecoder._decode_state(state=object_dict["py/state"]) + return node + + def _decode_node(self, encoded_node: dict) -> Node: + """Decode an item believed to be an encoded `Node` object. + + Also updates the list of decoded objects. + + Args: + encoded_node: The encoded node to decode. + + Returns: + The decoded node and the updated list of decoded objects. + """ + + if isinstance(encoded_node, int): + # Using index mapping to replace the object (load from Labels) + return encoded_node + elif "py/object" in encoded_node: + decoded_node: Node = SkeletonDecoder._decode_object_dict(encoded_node) + self.decoded_objects.append(decoded_node) + elif "py/id" in encoded_node: + decoded_node: Node = self._decode_id(encoded_node["py/id"]) + + return decoded_node + + def _decode_nodes(self, encoded_nodes: List[dict]) -> List[Dict[str, Node]]: + """Decode the 'nodes' key in the serialized nx_graph. + + The encoded_nodes is a list of dictionary of two types: + - A dictionary with 'py/object' and 'py/state' keys. + - A dictionary with 'py/id' key. + + Args: + encoded_nodes: The list of encoded nodes to decode. + + Returns: + The decoded nodes. + """ + + decoded_nodes: List[Dict[str, Node]] = [] + for e_node_dict in encoded_nodes: + e_node = e_node_dict["id"] + d_node = self._decode_node(e_node) + decoded_nodes.append({"id": d_node}) + + return decoded_nodes + + def _decode_reduce_dict(self, reduce_dict: Dict[str, List[dict]]) -> EdgeType: + """Decode the 'reduce' key in the serialized nx_graph. + + The reduce_dict is a dictionary in the following format: + { + "py/reduce": [ + { "py/type": "sleap.skeleton.EdgeType" }, + { "py/tuple": [1] } + ] + } + + Args: + reduce_dict: The dictionary to decode i.e. reduce_dict = {"py/reduce": ...} + + Returns: + The decoded `EdgeType` object. + """ + + reduce_list = reduce_dict["py/reduce"] + has_py_type = has_py_tuple = False + for reduce_item in reduce_list: + if reduce_item is None: + # Sometimes the reduce list has None values, skip them + continue + if ( + "py/type" in reduce_item + and reduce_item["py/type"] == "sleap.skeleton.EdgeType" + ): + has_py_type = True + elif "py/tuple" in reduce_item: + edge_type: int = reduce_item["py/tuple"][0] + has_py_tuple = True + + if not has_py_type or not has_py_tuple: + raise ValueError( + "Only 'sleap.skeleton.EdgeType' objects are supported. " + "The 'py/reduce' list must have dictionaries with 'py/type' and " + "'py/tuple' keys." + f"\n\tHas py/type: {has_py_type}\n\tHas py/tuple: {has_py_tuple}" + ) + + edge = EdgeType(edge_type) + self.decoded_objects.append(edge) + + return edge + + def _decode_edge_type(self, encoded_edge_type: dict) -> EdgeType: + """Decode the 'type' key in the serialized nx_graph. + + Args: + encoded_edge_type: a dictionary with either 'py/id' or 'py/reduce' key. + + Returns: + The decoded `EdgeType` object. + """ + + if "py/reduce" in encoded_edge_type: + edge_type = self._decode_reduce_dict(encoded_edge_type) + else: + # Expect a "py/id" instead of "py/reduce" + edge_type = self._decode_id(encoded_edge_type["py/id"]) + return edge_type + + def _decode_links( + self, links: List[dict] + ) -> List[Dict[str, Union[int, Node, EdgeType]]]: + """Decode the 'links' key in the serialized nx_graph. + + The links are the edges in the graph and will have the following keys: + - source: The source node of the edge. + - target: The destination node of the edge. + - type: The type of the edge (e.g. BODY, SYMMETRY). + and more. + + Args: + encoded_links: The list of encoded links to decode. + """ + + for link in links: + for key, value in link.items(): + if key == "source": + link[key] = self._decode_node(value) + elif key == "target": + link[key] = self._decode_node(value) + elif key == "type": + link[key] = self._decode_edge_type(value) + + return links + + @staticmethod + def decode_preview_image( + img_b64: bytes, return_bytes: bool = False + ) -> Union[Image.Image, bytes]: + """Decode a skeleton preview image byte string representation to a `PIL.Image` + + Args: + img_b64: a byte string representation of a skeleton preview image + return_bytes: whether to return the decoded image as bytes + + Returns: + Either a PIL.Image of the skeleton preview image or the decoded image as bytes + (if `return_bytes` is True). + """ + bytes = base64.b64decode(img_b64) + if return_bytes: + return bytes + + buffer = BytesIO(bytes) + img = Image.open(buffer) + return img + + def _decode(self, json_str: str): + dicts = json.loads(json_str) + + # Enforce same format across template and non-template skeletons + if "nx_graph" not in dicts: + # Non-template skeletons use the dicts as the "nx_graph" + dicts = {"nx_graph": dicts} + + # Decode the graph + nx_graph = dicts["nx_graph"] + + self.decoded_objects = [] # Reset the decoded objects incase reusing decoder + for key, value in nx_graph.items(): + if key == "nodes": + nx_graph[key] = self._decode_nodes(value) + elif key == "links": + nx_graph[key] = self._decode_links(value) + + # Decode the preview image (if it exists) + preview_image = dicts.get("preview_image", None) + if preview_image is not None: + dicts["preview_image"] = SkeletonDecoder.decode_preview_image( + preview_image["py/b64"], return_bytes=True + ) + + return dicts + + @classmethod + def decode(cls, json_str: str) -> Dict: + """Decode the given json string into a dictionary. + + Returns: + A dict with `Node`s, `EdgeType`s, and `bytes` decoded/reconstructed. + """ + decoder = cls() + return decoder._decode(json_str) + + +class SkeletonEncoder: + """Replace jsonpickle.encode with our own encoder. + + The input is a dictionary containing python objects that need to be encoded as + JSON strings. The output is a JSON string that represents the input dictionary. + + `Node(name='neck', weight=1.0)` => + { + "py/object": "sleap.Skeleton.Node", + "py/state": {"py/tuple" ["neck", 1.0]} + } + + `` => + {"py/reduce": [ + {"py/type": "sleap.Skeleton.EdgeType"}, + {"py/tuple": [1] } + ] + }` + + Where `name` and `weight` are the attributes of the `Node` class; weight is always 1.0. + `EdgeType` is an enum with values `BODY = 1` and `SYMMETRY = 2`. + + See sleap.skeleton.Node and sleap.skeleton.EdgeType. + + If the object has been "seen" before, it will not be encoded as the full JSON string + but referenced by its `py/id`, which starts at 1 and indexes the objects in the + order they are seen so that the second time the first object is used, it will be + referenced as `{"py/id": 1}`. + """ + + def __init__(self): + """Initializes a SkeletonEncoder instance.""" + # Maps object id to py/id + self._encoded_objects: Dict[int, int] = {} + + @classmethod + def encode(cls, data: Dict[str, Any]) -> str: + """Encodes the input dictionary as a JSON string. + + Args: + data: The data to encode. + + Returns: + json_str: The JSON string representation of the data. + """ + encoder = cls() + encoded_data = encoder._encode(data) + json_str = json.dumps(encoded_data) + return json_str + + def _encode(self, obj: Any) -> Any: + """Recursively encodes the input object. + + Args: + obj: The object to encode. Can be a dictionary, list, Node, EdgeType or + primitive data type. + + Returns: + The encoded object as a dictionary. + """ + if isinstance(obj, dict): + encoded_obj = {} + for key, value in obj.items(): + if key == "links": + encoded_obj[key] = self._encode_links(value) + else: + encoded_obj[key] = self._encode(value) + return encoded_obj + elif isinstance(obj, list): + return [self._encode(v) for v in obj] + elif isinstance(obj, EdgeType): + return self._encode_edge_type(obj) + elif isinstance(obj, Node): + return self._encode_node(obj) + else: + return obj # Primitive data types + + def _encode_links(self, links: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Encodes the list of links (edges) in the skeleton graph. + + Args: + links: A list of dictionaries, each representing an edge in the graph. + + Returns: + A list of encoded edge dictionaries with keys ordered as specified. + """ + encoded_links = [] + for link in links: + # Use a regular dict (insertion order preserved in Python 3.7+) + encoded_link = {} + + for key, value in link.items(): + if key in ("source", "target"): + encoded_link[key] = self._encode_node(value) + elif key == "type": + encoded_link[key] = self._encode_edge_type(value) + else: + encoded_link[key] = self._encode(value) + encoded_links.append(encoded_link) + + return encoded_links + + def _encode_node(self, node: Union["Node", int]) -> Dict[str, Any]: + """Encodes a Node object. + + Args: + node: The Node object to encode or integer index. The latter requires that + the class has the `idx_to_node` attribute set. + + Returns: + The encoded `Node` object as a dictionary. + """ + if isinstance(node, int): + # We sometimes have the node object already replaced by its index (when + # `node_to_idx` is provided). In this case, the node is already encoded. + return node + + # Check if object has been encoded before + first_encoding = self._is_first_encoding(node) + py_id = self._get_or_assign_id(node, first_encoding) + if first_encoding: + # Full encoding + return { + "py/object": "sleap.skeleton.Node", + "py/state": {"py/tuple": [node.name, node.weight]}, + } + else: + # Reference by py/id + return {"py/id": py_id} + + def _encode_edge_type(self, edge_type: "EdgeType") -> Dict[str, Any]: + """Encodes an EdgeType object. + + Args: + edge_type: The EdgeType object to encode. Either `EdgeType.BODY` or + `EdgeType.SYMMETRY` enum with values 1 and 2 respectively. + + Returns: + The encoded EdgeType object as a dictionary. + """ + # Check if object has been encoded before + first_encoding = self._is_first_encoding(edge_type) + py_id = self._get_or_assign_id(edge_type, first_encoding) + if first_encoding: + # Full encoding + return { + "py/reduce": [ + {"py/type": "sleap.skeleton.EdgeType"}, + {"py/tuple": [edge_type.value]}, + ] + } + else: + # Reference by py/id + return {"py/id": py_id} + + def _get_or_assign_id(self, obj: Any, first_encoding: bool) -> int: + """Gets or assigns a py/id for the object. + + Args: + The object to get or assign a py/id for. + + Returns: + The py/id assigned to the object. + """ + # Object id is unique for each object in the current session + obj_id = id(obj) + # Assign a py/id to the object if it hasn't been assigned one yet + if first_encoding: + py_id = len(self._encoded_objects) + 1 # py/id starts at 1 + # Assign the py/id to the object and store it in _encoded_objects + self._encoded_objects[obj_id] = py_id + return self._encoded_objects[obj_id] + + def _is_first_encoding(self, obj: Any) -> bool: + """Checks if the object is being encoded for the first time. + + Args: + obj: The object to check. + + Returns: + True if this is the first encoding of the object, False otherwise. + """ + obj_id = id(obj) + first_time = obj_id not in self._encoded_objects + return first_time + + class Skeleton: """The main object for representing animal skeletons. @@ -937,7 +1415,7 @@ def to_dict(obj: "Skeleton", node_to_idx: Optional[Dict[Node, int]] = None) -> D # This is a weird hack to serialize the whole _graph into a dict. # I use the underlying to_json and parse it. - return json.loads(obj.to_json(node_to_idx)) + return json.loads(obj.to_json(node_to_idx=node_to_idx)) @classmethod def from_dict(cls, d: Dict, node_to_idx: Dict[Node, int] = None) -> "Skeleton": @@ -1001,10 +1479,10 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: """ jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4) if node_to_idx is not None: - indexed_node_graph = nx.relabel_nodes( - G=self._graph, mapping=node_to_idx - ) # map nodes to int + # Map Nodes to int + indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx) else: + # Keep graph nodes as Node objects indexed_node_graph = self._graph # Encode to JSON @@ -1023,7 +1501,7 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str: else: data = graph - json_str = jsonpickle.encode(data) + json_str = SkeletonEncoder.encode(data) return json_str @@ -1071,7 +1549,7 @@ def from_json( Returns: An instance of the `Skeleton` object decoded from the JSON. """ - dicts = jsonpickle.decode(json_str) + dicts: dict = SkeletonDecoder.decode(json_str) nx_graph = dicts.get("nx_graph", dicts) graph = json_graph.node_link_graph(nx_graph) diff --git a/sleap/util.py b/sleap/util.py index 1e59ea237..6c8a91f6c 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -11,7 +11,6 @@ import re import shutil from collections import defaultdict -from io import BytesIO from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Hashable, Iterable, List, Optional from urllib.parse import unquote, urlparse @@ -31,6 +30,7 @@ import yaml from PIL import Image + import sleap.version as sleap_version if TYPE_CHECKING: @@ -391,18 +391,3 @@ def find_files_by_suffix( def parse_uri_path(uri: str) -> str: """Parse a URI starting with 'file:///' to a posix path.""" return Path(url2pathname(urlparse(unquote(uri)).path)).as_posix() - - -def decode_preview_image(img_b64: bytes) -> Image: - """Decode a skeleton preview image byte string representation to a `PIL.Image` - - Args: - img_b64: a byte string representation of a skeleton preview image - - Returns: - A PIL.Image of the skeleton preview - """ - bytes = base64.b64decode(img_b64) - buffer = BytesIO(bytes) - img = Image.open(buffer) - return img diff --git a/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json b/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json new file mode 100644 index 000000000..eae83d6bc --- /dev/null +++ b/tests/data/skeleton/fly_skeleton_legs_pystate_dict.json @@ -0,0 +1 @@ +{"directed": true, "graph": {"name": "skeleton_legs.mat", "num_edges_inserted": 23}, "links": [{"edge_insert_idx": 1, "key": 0, "source": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "neck", "weight": 1.0}}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "head", "weight": 1.0}}, "type": {"py/reduce": [{"py/type": "sleap.skeleton.EdgeType"}, {"py/tuple": [1]}]}}, {"edge_insert_idx": 0, "key": 0, "source": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "thorax", "weight": 1.0}}, "target": {"py/id": 1}, "type": {"py/id": 3}}, {"edge_insert_idx": 2, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "abdomen", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 3, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "wingL", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 4, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "wingR", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 5, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 8, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 11, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 14, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 17, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 20, "key": 0, "source": {"py/id": 4}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR1", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 6, "key": 0, "source": {"py/id": 8}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 7, "key": 0, "source": {"py/id": 14}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 9, "key": 0, "source": {"py/id": 9}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 10, "key": 0, "source": {"py/id": 16}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "forelegR3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 12, "key": 0, "source": {"py/id": 10}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 13, "key": 0, "source": {"py/id": 18}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 15, "key": 0, "source": {"py/id": 11}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 16, "key": 0, "source": {"py/id": 20}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "midlegR3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 18, "key": 0, "source": {"py/id": 12}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 19, "key": 0, "source": {"py/id": 22}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegL3", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 21, "key": 0, "source": {"py/id": 13}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR2", "weight": 1.0}}, "type": {"py/id": 3}}, {"edge_insert_idx": 22, "key": 0, "source": {"py/id": 24}, "target": {"py/object": "sleap.skeleton.Node", "py/state": {"name": "hindlegR3", "weight": 1.0}}, "type": {"py/id": 3}}], "multigraph": true, "nodes": [{"id": {"py/id": 2}}, {"id": {"py/id": 1}}, {"id": {"py/id": 4}}, {"id": {"py/id": 5}}, {"id": {"py/id": 6}}, {"id": {"py/id": 7}}, {"id": {"py/id": 8}}, {"id": {"py/id": 14}}, {"id": {"py/id": 15}}, {"id": {"py/id": 9}}, {"id": {"py/id": 16}}, {"id": {"py/id": 17}}, {"id": {"py/id": 10}}, {"id": {"py/id": 18}}, {"id": {"py/id": 19}}, {"id": {"py/id": 11}}, {"id": {"py/id": 20}}, {"id": {"py/id": 21}}, {"id": {"py/id": 12}}, {"id": {"py/id": 22}}, {"id": {"py/id": 23}}, {"id": {"py/id": 13}}, {"id": {"py/id": 24}}, {"id": {"py/id": 25}}]} \ No newline at end of file diff --git a/tests/fixtures/skeletons.py b/tests/fixtures/skeletons.py index 311510e6a..b432ca2c7 100644 --- a/tests/fixtures/skeletons.py +++ b/tests/fixtures/skeletons.py @@ -3,14 +3,27 @@ from sleap.skeleton import Skeleton TEST_FLY_LEGS_SKELETON = "tests/data/skeleton/fly_skeleton_legs.json" +TEST_FLY_LEGS_SKELETON_DICT = "tests/data/skeleton/fly_skeleton_legs_pystate_dict.json" @pytest.fixture def fly_legs_skeleton_json(): - """Path to fly_skeleton_legs.json""" + """Path to fly_skeleton_legs.json + + This skeleton json has py/state in tuple format. + """ return TEST_FLY_LEGS_SKELETON +@pytest.fixture +def fly_legs_skeleton_dict_json(): + """Path to fly_skeleton_legs_pystate_dict.json + + This skeleton json has py/state dict format. + """ + return TEST_FLY_LEGS_SKELETON_DICT + + @pytest.fixture def stickman(): diff --git a/tests/gui/test_app.py b/tests/gui/test_app.py index 745989da1..def835b6e 100644 --- a/tests/gui/test_app.py +++ b/tests/gui/test_app.py @@ -414,6 +414,12 @@ def toggle_and_verify_visibility(expected_visibility: bool = True): window.showNormal() vp = window.player + # Change state and ensure menu-item check updates + color_predicted = window.state["color predicted"] + assert window._menu_actions["color predicted"].isChecked() == color_predicted + window.state["color predicted"] = not color_predicted + assert window._menu_actions["color predicted"].isChecked() == (not color_predicted) + # Enable distinct colors window.state["color predicted"] = True diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index fd615ea81..0a978de0a 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1932,7 +1932,11 @@ def test_flow_tracker(centered_pair_predictions_sorted: Labels, tmpdir): for inst in lf.instances: inst.track = None - track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) + track_args = dict( + untracked_instances=lf.instances, + img=lf.video[lf.frame_idx], + img_hw=lf.image.shape[-3:-1], + ) tracker.track(**track_args) # Check that saved instances are pruned to track window @@ -1975,7 +1979,11 @@ def test_max_tracks_matching_queue( for inst in lf.instances: inst.track = None - track_args = dict(untracked_instances=lf.instances, img=lf.video[lf.frame_idx]) + track_args = dict( + untracked_instances=lf.instances, + img=lf.video[lf.frame_idx], + img_hw=lf.image.shape[-3:-1], + ) tracker.track(**track_args) if trackername == "flowmaxtracks": diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index ffdd35257..9d3b65b38 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -35,7 +35,10 @@ def run_tracker_by_name(frames=None, img_scale: float = 0, **kwargs): @pytest.mark.parametrize( "tracker", ["simple", "flow", "simplemaxtracks", "flowmaxtracks"] ) -@pytest.mark.parametrize("similarity", ["instance", "iou", "centroid"]) +@pytest.mark.parametrize( + "similarity", + ["instance", "normalized_instance", "iou", "centroid", "object_keypoint"], +) @pytest.mark.parametrize("match", ["greedy", "hungarian"]) @pytest.mark.parametrize("img_scale", [0, 1, 0.25]) @pytest.mark.parametrize("count", [0, 2]) @@ -289,7 +292,7 @@ def test_max_tracking_large_gap_single_track(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) @@ -306,7 +309,7 @@ def test_max_tracking_large_gap_single_track(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) @@ -353,7 +356,7 @@ def test_max_tracking_small_gap_on_both_tracks(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) @@ -370,7 +373,7 @@ def test_max_tracking_small_gap_on_both_tracks(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) @@ -422,7 +425,7 @@ def test_max_tracking_extra_detections(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) @@ -439,7 +442,7 @@ def test_max_tracking_extra_detections(): tracked = [] for insts in preds: - tracked_insts = tracker.track(insts) + tracked_insts = tracker.track(insts, img_hw=(1, 1)) tracked.append(tracked_insts) all_tracks = list(set([inst.track for frame in tracked for inst in frame])) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 2852a72b7..caebe49ff 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -103,6 +103,8 @@ def main(f, dir): instance=sleap.nn.tracker.components.instance_similarity, centroid=sleap.nn.tracker.components.centroid_distance, iou=sleap.nn.tracker.components.instance_iou, + normalized_instance=sleap.nn.tracker.components.normalized_instance_similarity, + object_keypoint=sleap.nn.tracker.components.factory_object_keypoint_similarity(), ) scales = ( 1, diff --git a/tests/test_skeleton.py b/tests/test_skeleton.py index 1f7c3a853..7c5216316 100644 --- a/tests/test_skeleton.py +++ b/tests/test_skeleton.py @@ -1,10 +1,62 @@ -import os import copy - -import jsonpickle +import os import pytest +import json + +from networkx.readwrite import json_graph +from sleap.skeleton import Skeleton, SkeletonDecoder +from sleap.skeleton import SkeletonEncoder + + +def test_decoded_encoded_Skeleton_from_load_json(fly_legs_skeleton_json): + """ + Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton. + """ + # Get the skeleton from the fixture + skeleton = Skeleton.load_json(fly_legs_skeleton_json) + # Get the graph from the skeleton + indexed_node_graph = skeleton._graph + graph = json_graph.node_link_data(indexed_node_graph) + + # Encode the graph as a json string to test .encode method + encoded_json_str = SkeletonEncoder.encode(graph) + + # Get the skeleton from the encoded json string + decoded_skeleton = Skeleton.from_json(encoded_json_str) + + # Check that the decoded skeleton is the same as the original skeleton + assert skeleton.matches(decoded_skeleton) -from sleap.skeleton import Skeleton + +@pytest.mark.parametrize( + "skeleton_fixture_name", ["flies13_skeleton", "skeleton", "stickman"] +) +def test_decoded_encoded_Skeleton(skeleton_fixture_name, request): + """ + Test Skeleton decoded from SkeletonEncoder.encode matches the original Skeleton. + """ + # Use request.getfixturevalue to get the actual fixture value by name + skeleton = request.getfixturevalue(skeleton_fixture_name) + + # Get the graph from the skeleton + indexed_node_graph = skeleton._graph + graph = json_graph.node_link_data(indexed_node_graph) + + # Encode the graph as a json string to test .encode method + encoded_json_str = SkeletonEncoder.encode(graph) + + # Get the skeleton from the encoded json string + decoded_skeleton = Skeleton.from_json(encoded_json_str) + + # Check that the decoded skeleton is the same as the original skeleton + assert skeleton.matches(decoded_skeleton) + + # Now make everything into a JSON string + skeleton_json_str = skeleton.to_json() + decoded_skeleton_json_str = decoded_skeleton.to_json() + + # Check that the JSON strings are the same + assert json.loads(skeleton_json_str) == json.loads(decoded_skeleton_json_str) def test_add_dupe_node(skeleton): @@ -194,9 +246,9 @@ def test_json(skeleton: Skeleton, tmpdir): ) assert skeleton.is_template == False json_str = skeleton.to_json() - json_dict = jsonpickle.decode(json_str) + json_dict = SkeletonDecoder.decode(json_str) json_dict_keys = list(json_dict.keys()) - assert "nx_graph" not in json_dict_keys + assert "nx_graph" in json_dict_keys # SkeletonDecoder adds this key assert "preview_image" not in json_dict_keys assert "description" not in json_dict_keys @@ -208,7 +260,7 @@ def test_json(skeleton: Skeleton, tmpdir): skeleton._is_template = True json_str = skeleton.to_json() - json_dict = jsonpickle.decode(json_str) + json_dict = SkeletonDecoder.decode(json_str) json_dict_keys = list(json_dict.keys()) assert "nx_graph" in json_dict_keys assert "preview_image" in json_dict_keys @@ -224,6 +276,26 @@ def test_json(skeleton: Skeleton, tmpdir): assert skeleton.matches(skeleton_copy) +def test_decode_preview_image(flies13_skeleton: Skeleton): + skeleton = flies13_skeleton + img_b64 = skeleton.preview_image + img = SkeletonDecoder.decode_preview_image(img_b64) + assert img.mode == "RGBA" + + +def test_skeleton_decoder(fly_legs_skeleton_json, fly_legs_skeleton_dict_json): + """Test that SkeletonDecoder can decode both tuple and dict py/state formats.""" + + skeleton_tuple_pystate = Skeleton.load_json(fly_legs_skeleton_json) + assert isinstance(skeleton_tuple_pystate, Skeleton) + + skeleton_dict_pystate = Skeleton.load_json(fly_legs_skeleton_dict_json) + assert isinstance(skeleton_dict_pystate, Skeleton) + + # These are the same skeleton, so they should match + assert skeleton_dict_pystate.matches(skeleton_tuple_pystate) + + def test_hdf5(skeleton, stickman, tmpdir): filename = os.path.join(tmpdir, "skeleton.h5") diff --git a/tests/test_util.py b/tests/test_util.py index a7916d47f..35b41afa8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,5 +1,4 @@ import pytest -from sleap.skeleton import Skeleton from sleap.util import * @@ -147,10 +146,3 @@ def test_save_dict_to_hdf5(tmpdir): assert f["bar"][-1].decode() == "zop" assert f["cab"]["a"][()] == 2 - - -def test_decode_preview_image(flies13_skeleton: Skeleton): - skeleton = flies13_skeleton - img_b64 = skeleton.preview_image - img = decode_preview_image(img_b64) - assert img.mode == "RGBA"