From 30d1023450a4e6a1103abff55abb1b2f16290263 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 25 Nov 2024 20:17:46 -0500 Subject: [PATCH 1/4] Refactor multihypothesis segmentation handling --- .../candidate_graph/__init__.py | 6 +- .../candidate_graph/compute_graph.py | 80 +++++++++++++++---- src/motile_toolbox/candidate_graph/iou.py | 50 +++++++----- src/motile_toolbox/candidate_graph/utils.py | 52 ++++++------ .../utils/relabel_segmentation.py | 21 ++--- tests/conftest.py | 21 ++--- .../test_compute_graph.py | 26 +++--- .../test_conflict_sets.py | 4 +- tests/test_candidate_graph/test_iou.py | 12 +-- tests/test_candidate_graph/test_utils.py | 21 ----- tests/test_utils/test_relabel_segmentation.py | 4 +- 11 files changed, 170 insertions(+), 127 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/__init__.py b/src/motile_toolbox/candidate_graph/__init__.py index 3293cf3..fd4201f 100644 --- a/src/motile_toolbox/candidate_graph/__init__.py +++ b/src/motile_toolbox/candidate_graph/__init__.py @@ -1,4 +1,8 @@ -from .compute_graph import get_candidate_graph, get_candidate_graph_from_points_list +from .compute_graph import ( + compute_graph_from_multiseg, + compute_graph_from_points_list, + compute_graph_from_seg, +) from .graph_attributes import EdgeAttr, NodeAttr from .graph_to_nx import graph_to_nx from .iou import add_iou diff --git a/src/motile_toolbox/candidate_graph/compute_graph.py b/src/motile_toolbox/candidate_graph/compute_graph.py index fc05e14..e1e40e3 100644 --- a/src/motile_toolbox/candidate_graph/compute_graph.py +++ b/src/motile_toolbox/candidate_graph/compute_graph.py @@ -11,20 +11,19 @@ logger = logging.getLogger(__name__) -def get_candidate_graph( +def compute_graph_from_seg( segmentation: np.ndarray, max_edge_distance: float, iou: bool = False, scale: list[float] | None = None, -) -> tuple[nx.DiGraph, list[set[Any]] | None]: +) -> nx.DiGraph: """Construct a candidate graph from a segmentation array. Nodes are placed at the centroid of each segmentation and edges are added for all nodes in adjacent frames - within max_edge_distance. If segmentation contains multiple hypotheses, will also - return a list of conflicting node ids that cannot be selected together. + within max_edge_distance. Args: segmentation (np.ndarray): A numpy array with integer labels and dimensions - (t, h, [z], y, x), where h is the number of hypotheses. + (t, [z], y, x). max_edge_distance (float): Maximum distance that objects can travel between frames. All nodes with centroids within this distance in adjacent frames will by connected with a candidate edge. @@ -35,11 +34,8 @@ def get_candidate_graph( Defaults to None, which implies the data is isotropic. Returns: - tuple[nx.DiGraph, list[set[Any]] | None]: A candidate graph that can be passed - to the motile solver, and a list of conflicting node ids. + nx.DiGraph: A candidate graph that can be passed to the motile solver """ - num_hypotheses = segmentation.shape[1] - # add nodes cand_graph, node_frame_dict = nodes_from_segmentation(segmentation, scale=scale) logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") @@ -57,16 +53,73 @@ def get_candidate_graph( logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") + return cand_graph + + +def compute_graph_from_multiseg( + segmentations: np.ndarray, + max_edge_distance: float, + iou: bool = False, + scale: list[float] | None = None, +) -> tuple[nx.DiGraph, list[set[Any]]]: + """Construct a candidate graph from a segmentation array. Nodes are placed at the + centroid of each segmentation and edges are added for all nodes in adjacent frames + within max_edge_distance. + + Args: + segmentations (np.ndarray): numpy array with mupliple possible segmentations + stacked. Each segmentation has integer labels and + dimensions (h, t, [z], y, x). Assumes unique labels even between hypotheses. + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes with centroids within this distance in adjacent frames + will by connected with a candidate edge. + iou (bool, optional): Whether to include IOU on the candidate graph. + Defaults to False. + scale (list[float] | None, optional): The scale of the segmentation data. + Will be used to rescale the point locations and attribute computations. + Defaults to None, which implies the data is isotropic. + + Returns: + tuple[nx.DiGraph, list[set[Any]]: A candidate graph that can be passed to the + motile solver, and a list of conflicting node sets + """ + # add nodes + cand_graph = nx.DiGraph() + node_frame_dict: dict[int, Any] = {} + for hypo_id, seg in enumerate(segmentations): + seg_node_graph, seg_node_frame_dict = nodes_from_segmentation( + seg, scale=scale, seg_hypo=hypo_id + ) + cand_graph.update(seg_node_graph) + for frame, nodes in seg_node_frame_dict.items(): + if frame not in node_frame_dict: + node_frame_dict[frame] = [] + node_frame_dict[frame].extend(nodes) + logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}") + + # add edges + add_cand_edges( + cand_graph, + max_edge_distance=max_edge_distance, + node_frame_dict=node_frame_dict, + ) + if iou: + # Scale does not matter to IOU, because both numerator and denominator + # are scaled by the anisotropy. + add_iou(cand_graph, segmentations, node_frame_dict, multiseg=True) + + logger.info(f"Candidate edges: {cand_graph.number_of_edges()}") + # Compute conflict sets between segmentations conflicts = [] - if num_hypotheses > 1: - for time, segs in enumerate(segmentation): - conflicts.extend(compute_conflict_sets(segs, time)) + for time in range(segmentations.shape[1]): + segs = segmentations[:, time] + conflicts.extend(compute_conflict_sets(segs, time)) return cand_graph, conflicts -def get_candidate_graph_from_points_list( +def compute_graph_from_points_list( points_list: np.ndarray, max_edge_distance: float, scale: list[float] | None = None, @@ -86,7 +139,6 @@ def get_candidate_graph_from_points_list( Returns: nx.DiGraph: A candidate graph that can be passed to the motile solver. - Multiple hypotheses not supported for points input. """ # add nodes cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale) diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index d29f2a3..90751fd 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -32,9 +32,9 @@ def _compute_ious( values, counts = np.unique(flattened_stacked, axis=1, return_counts=True) frame1_values, frame1_counts = np.unique(frame1, return_counts=True) - frame1_label_sizes = dict(zip(frame1_values, frame1_counts)) + frame1_label_sizes = dict(zip(frame1_values, frame1_counts, strict=False)) frame2_values, frame2_counts = np.unique(frame2, return_counts=True) - frame2_label_sizes = dict(zip(frame2_values, frame2_counts)) + frame2_label_sizes = dict(zip(frame2_values, frame2_counts, strict=False)) ious: list[tuple[int, int, float]] = [] for index in range(values.shape[1]): pair = values[:, index] @@ -45,37 +45,47 @@ def _compute_ious( return ious -def _get_iou_dict(segmentation) -> dict[str, dict[str, float]]: - """Get all ious values for the provided segmentation (all frames). +def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]: + """Get all ious values for the provided segmentations (all frames). Will return as map from node_id -> dict[node_id] -> iou for easy navigation when adding to candidate graph. Args: - segmentation (np.ndarray): Segmentation that was used to create cand_graph. - Has shape (t, h, [z], y, x), where h is the number of hypotheses. + segmentation (np.ndarray): Segmentations that were used to create cand_graph. + Has shape ([h], t, [z], y, x), where h is the number of hypotheses + if multiseg is True. + multiseg (bool): Flag indicating if the provided segmentation contains + multiple hypothesis segmentations. Defaults to False. Returns: dict[str, dict[str, float]]: A map from node id to another dictionary, which contains node_ids to iou values. """ iou_dict: dict[str, dict[str, float]] = {} - hypo_pairs: list[tuple[int | None, ...]] - num_hypotheses = segmentation.shape[1] - if num_hypotheses > 1: - hypo_pairs = list(product(range(num_hypotheses), repeat=2)) + hypo_pairs: list[tuple[int, ...]] = [(0, 0)] + if multiseg: + num_hypotheses = segmentation.shape[0] + if num_hypotheses > 1: + hypo_pairs = list(product(range(num_hypotheses), repeat=2)) else: - hypo_pairs = [(None, None)] + segmentation = np.expand_dims(segmentation, 0) - for frame in range(len(segmentation) - 1): + for frame in range(segmentation.shape[1] - 1): for hypo1, hypo2 in hypo_pairs: - seg1 = segmentation[frame][hypo1] - seg2 = segmentation[frame + 1][hypo2] + seg1 = segmentation[hypo1][frame] + seg2 = segmentation[hypo2][frame + 1] ious = _compute_ious(seg1, seg2) + print(hypo1, hypo2, ious) for label1, label2, iou in ious: - node_id1 = get_node_id(frame, label1, hypo1) + if multiseg: + node_id1 = get_node_id(frame, label1, hypo1) + node_id2 = get_node_id(frame + 1, label2, hypo2) + else: + node_id1 = get_node_id(frame, label1) + node_id2 = get_node_id(frame + 1, label2) + if node_id1 not in iou_dict: iou_dict[node_id1] = {} - node_id2 = get_node_id(frame + 1, label2, hypo2) iou_dict[node_id1][node_id2] = iou return iou_dict @@ -84,22 +94,26 @@ def add_iou( cand_graph: nx.DiGraph, segmentation: np.ndarray, node_frame_dict: dict[int, list[Any]] | None = None, + multiseg=False, ) -> None: """Add IOU to the candidate graph. Args: cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated segmentation (np.ndarray): segmentation that was used to create cand_graph. - Has shape (t, h, [z], y, x), where h is the number of hypotheses. + Has shape ([h], t, [z], y, x), where h is the number of hypotheses if + multiseg is True. node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from time frames to nodes in that frame. Will be computed if not provided, but can be provided for efficiency (e.g. after running nodes_from_segmentation). Defaults to None. + multiseg (bool): Flag indicating if the given segmentation is actually multiple + stacked segmentations. Defaults to False. """ if node_frame_dict is None: node_frame_dict = _compute_node_frame_dict(cand_graph) frames = sorted(node_frame_dict.keys()) - ious = _get_iou_dict(segmentation) + ious = _get_iou_dict(segmentation, multiseg=multiseg) for frame in tqdm(frames): if frame + 1 not in node_frame_dict.keys(): continue diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index efa48d8..4e4dffe 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -38,6 +38,7 @@ def get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> s def nodes_from_segmentation( segmentation: np.ndarray, scale: list[float] | None = None, + seg_hypo=None, ) -> tuple[nx.DiGraph, dict[int, list[Any]]]: """Extract candidate nodes from a segmentation. Returns a networkx graph with only nodes, and also a dictionary from frames to node_ids for @@ -48,55 +49,52 @@ def nodes_from_segmentation( - position - segmentation id - area - - hypothesis id (optional) Args: segmentation (np.ndarray): A numpy array with integer labels and dimensions - (t, h, [z], y, x), where h is the number of hypotheses. + (t, [z], y, x). scale (list[float] | None, optional): The scale of the segmentation data. Will be used to rescale the point locations and attribute computations. Defaults to None, which implies the data is isotropic. Should include time and all spatial dimentsions. + seg_hypo (int | None): A number to be stored in NodeAttr.SEG_HYPO, if given. Returns: tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, and a mapping from time frames to node ids. """ + logger.debug("Extracting nodes from segmentation") cand_graph = nx.DiGraph() # also construct a dictionary from time frame to node_id for efficiency node_frame_dict: dict[int, list[Any]] = {} - logger.info("Extracting nodes from segmentation") - num_hypotheses = segmentation.shape[1] + if scale is None: scale = [ 1, - ] * (segmentation.ndim - 1) # don't include hypothesis + ] * segmentation.ndim else: assert ( - len(scale) == segmentation.ndim - 1 - ), f"Scale {scale} should have {segmentation.ndim - 1} dims" + len(scale) == segmentation.ndim + ), f"Scale {scale} should have {segmentation.ndim} dims" + for t in tqdm(range(len(segmentation))): segs = segmentation[t] - hypo_id: int | None - for hypo_id, hypo in enumerate(segs): - if num_hypotheses == 1: - hypo_id = None - nodes_in_frame = [] - props = regionprops(hypo, spacing=tuple(scale[1:])) - for regionprop in props: - node_id = get_node_id(t, regionprop.label, hypothesis_id=hypo_id) - attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area} - attrs[NodeAttr.SEG_ID.value] = regionprop.label - if hypo_id is not None: - attrs[NodeAttr.SEG_HYPO.value] = hypo_id - centroid = regionprop.centroid # [z,] y, x - attrs[NodeAttr.POS.value] = centroid - cand_graph.add_node(node_id, **attrs) - nodes_in_frame.append(node_id) - if nodes_in_frame: - if t not in node_frame_dict: - node_frame_dict[t] = [] - node_frame_dict[t].extend(nodes_in_frame) + nodes_in_frame = [] + props = regionprops(segs, spacing=tuple(scale[1:])) + for regionprop in props: + node_id = get_node_id(t, regionprop.label, hypothesis_id=seg_hypo) + attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area} + attrs[NodeAttr.SEG_ID.value] = regionprop.label + if seg_hypo: + attrs[NodeAttr.SEG_HYPO.value] = seg_hypo + centroid = regionprop.centroid # [z,] y, x + attrs[NodeAttr.POS.value] = centroid + cand_graph.add_node(node_id, **attrs) + nodes_in_frame.append(node_id) + if nodes_in_frame: + if t not in node_frame_dict: + node_frame_dict[t] = [] + node_frame_dict[t].extend(nodes_in_frame) return cand_graph, node_frame_dict diff --git a/src/motile_toolbox/utils/relabel_segmentation.py b/src/motile_toolbox/utils/relabel_segmentation.py index 9b41c07..5fd5a71 100644 --- a/src/motile_toolbox/utils/relabel_segmentation.py +++ b/src/motile_toolbox/utils/relabel_segmentation.py @@ -14,18 +14,15 @@ def relabel_segmentation( Args: solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use for relabeling. Nodes not in graph will be removed from seg. Original - segmentation ids and hypothesis ids have to be stored in the graph so we + segmentation ids have to be stored in the graph so we can map them back. - segmentation (np.ndarray): Original (potentially multi-hypothesis) - segmentation with dimensions (t,h,[z],y,x), where h is 1 for single - input segmentation. + segmentation (np.ndarray): Original segmentation with dimensions (t, [z], y, x) Returns: np.ndarray: Relabeled segmentation array where nodes in same track share same - id with shape (t,1,[z],y,x) + id with shape (t,[z],y,x) """ - output_shape = (segmentation.shape[0], 1, *segmentation.shape[2:]) - tracked_masks = np.zeros_like(segmentation, shape=output_shape) + tracked_masks = np.zeros_like(segmentation) id_counter = 1 parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1] soln_copy = solution_nx_graph.copy() @@ -36,13 +33,7 @@ def relabel_segmentation( for node in node_set: time_frame = solution_nx_graph.nodes[node][NodeAttr.TIME.value] previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value] - if NodeAttr.SEG_HYPO.value in solution_nx_graph.nodes[node]: - hypothesis_id = solution_nx_graph.nodes[node][NodeAttr.SEG_HYPO.value] - else: - hypothesis_id = 0 - previous_seg_mask = ( - segmentation[time_frame, hypothesis_id] == previous_seg_id - ) - tracked_masks[time_frame, 0][previous_seg_mask] = id_counter + previous_seg_mask = segmentation[time_frame] == previous_seg_id + tracked_masks[time_frame][previous_seg_mask] = id_counter id_counter += 1 return tracked_masks diff --git a/tests/conftest.py b/tests/conftest.py index b88eaaf..f2e9958 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ def segmentation_2d(): rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) segmentation[1][rr, cc] = 2 - return np.expand_dims(segmentation, 1) + return segmentation @pytest.fixture @@ -33,7 +33,7 @@ def multi_hypothesis_segmentation_2d(): """ frame_shape = (100, 100) - total_shape = (2, 2, *frame_shape) # 2 time points, 2 hypotheses layers, H, W + total_shape = (2, 2, *frame_shape) # 2 hypotheses, 2 time points, H, W segmentation = np.zeros(total_shape, dtype="int32") # make frame with one cell in center with label 1 (hypo 1) rr0, cc0 = disk(center=(50, 50), radius=20, shape=frame_shape) @@ -41,21 +41,21 @@ def multi_hypothesis_segmentation_2d(): rr1, cc1 = disk(center=(45, 45), radius=15, shape=frame_shape) segmentation[0, 0][rr0, cc0] = 1 - segmentation[0, 1][rr1, cc1] = 1 + segmentation[1, 0][rr1, cc1] = 1 # make frame with two cells # first cell centered at (20, 80) with label 1 rr0, cc0 = disk(center=(20, 80), radius=10, shape=frame_shape) rr1, cc1 = disk(center=(15, 75), radius=15, shape=frame_shape) - segmentation[1, 0][rr0, cc0] = 1 + segmentation[0, 1][rr0, cc0] = 1 segmentation[1, 1][rr1, cc1] = 1 # second cell centered at (60, 45) with label 2 rr0, cc0 = disk(center=(60, 45), radius=15, shape=frame_shape) rr1, cc1 = disk(center=(55, 40), radius=20, shape=frame_shape) - segmentation[1, 0][rr0, cc0] = 2 + segmentation[0, 1][rr0, cc0] = 2 segmentation[1, 1][rr1, cc1] = 2 return segmentation @@ -220,7 +220,7 @@ def segmentation_3d(): mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) segmentation[1][mask] = 2 - return np.expand_dims(segmentation, 1) + return segmentation @pytest.fixture @@ -235,20 +235,21 @@ def multi_hypothesis_segmentation_3d(): # make first frame with one cell in center with label 1 mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) segmentation[0, 0][mask] = 1 + + # make second hypothesis first frame with one cell in center with label 1 mask = sphere(center=(45, 50, 55), radius=20, shape=frame_shape) - segmentation[0, 1][mask] = 1 + segmentation[1, 0][mask] = 1 # make second frame, first hypothesis with two cells # first cell centered at (20, 50, 80) with label 1 # second cell centered at (60, 50, 45) with label 2 mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) - segmentation[1, 0][mask] = 1 + segmentation[0, 1][mask] = 1 mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) - segmentation[1, 0][mask] = 2 + segmentation[0, 1][mask] = 2 # make second frame, second hypothesis with one cell # first cell centered at (15, 50, 70) with label 1 - # second cell centered at (55, 55, 45) with label 2 mask = sphere(center=(15, 50, 70), radius=10, shape=frame_shape) segmentation[1, 1][mask] = 1 diff --git a/tests/test_candidate_graph/test_compute_graph.py b/tests/test_candidate_graph/test_compute_graph.py index 755abf2..0b0a13b 100644 --- a/tests/test_candidate_graph/test_compute_graph.py +++ b/tests/test_candidate_graph/test_compute_graph.py @@ -2,16 +2,20 @@ import numpy as np import pytest -from motile_toolbox.candidate_graph import EdgeAttr, get_candidate_graph +from motile_toolbox.candidate_graph import ( + EdgeAttr, + compute_graph_from_multiseg, + compute_graph_from_seg, +) from motile_toolbox.candidate_graph.compute_graph import ( - get_candidate_graph_from_points_list, + compute_graph_from_points_list, ) from motile_toolbox.candidate_graph.graph_attributes import NodeAttr def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): # test with 2D segmentation - cand_graph, _ = get_candidate_graph( + cand_graph = compute_graph_from_seg( segmentation=segmentation_2d, max_edge_distance=100, iou=True, @@ -28,7 +32,7 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): ) # lower edge distance - cand_graph, _ = get_candidate_graph( + cand_graph = compute_graph_from_seg( segmentation=segmentation_2d, max_edge_distance=15, ) @@ -38,7 +42,7 @@ def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): # test with 3D segmentation - cand_graph, _ = get_candidate_graph( + cand_graph = compute_graph_from_seg( segmentation=segmentation_3d, max_edge_distance=100, ) @@ -54,8 +58,8 @@ def test_graph_from_multi_segmentation_2d( multi_hypothesis_segmentation_2d, multi_hypothesis_graph_2d ): # test with 2D segmentation - cand_graph, conflict_set = get_candidate_graph( - segmentation=multi_hypothesis_segmentation_2d, + cand_graph, conflict_set = compute_graph_from_multiseg( + segmentations=multi_hypothesis_segmentation_2d, max_edge_distance=100, iou=True, ) @@ -77,8 +81,8 @@ def test_graph_from_multi_segmentation_2d( # TODO: Test conflict set # lower edge distance - cand_graph, _ = get_candidate_graph( - segmentation=multi_hypothesis_segmentation_2d, + cand_graph, _ = compute_graph_from_multiseg( + segmentations=multi_hypothesis_segmentation_2d, max_edge_distance=14, ) assert Counter(list(cand_graph.nodes)) == Counter( @@ -100,12 +104,12 @@ def test_graph_from_points_list(): [2, 1, 1, 1], ] ) - cand_graph = get_candidate_graph_from_points_list(points_list, max_edge_distance=3) + cand_graph = compute_graph_from_points_list(points_list, max_edge_distance=3) assert cand_graph.number_of_edges() == 3 assert len(cand_graph.in_edges(3)) == 0 # test scale - cand_graph = get_candidate_graph_from_points_list( + cand_graph = compute_graph_from_points_list( points_list, max_edge_distance=3, scale=[1, 1, 1, 5] ) assert cand_graph.number_of_edges() == 0 diff --git a/tests/test_candidate_graph/test_conflict_sets.py b/tests/test_candidate_graph/test_conflict_sets.py index c07e4e8..a820609 100644 --- a/tests/test_candidate_graph/test_conflict_sets.py +++ b/tests/test_candidate_graph/test_conflict_sets.py @@ -5,7 +5,7 @@ def test_conflict_sets_2d(multi_hypothesis_segmentation_2d): for t in range(multi_hypothesis_segmentation_2d.shape[0]): - conflict_set = compute_conflict_sets(multi_hypothesis_segmentation_2d[t], t) + conflict_set = compute_conflict_sets(multi_hypothesis_segmentation_2d[:, t], t) if t == 0: expected = [{"0_1_1", "0_0_1"}] assert len(conflict_set) == 1 @@ -22,7 +22,7 @@ def test_conflict_sets_2d_reshaped(multi_hypothesis_segmentation_2d): reshaped = np.asarray( [ multi_hypothesis_segmentation_2d[0, 0], # hypothesis 0 - multi_hypothesis_segmentation_2d[1, 0], # hypothesis 1 + multi_hypothesis_segmentation_2d[0, 1], # hypothesis 1 multi_hypothesis_segmentation_2d[1, 1], ] ) # hypothesis 2 diff --git a/tests/test_candidate_graph/test_iou.py b/tests/test_candidate_graph/test_iou.py index d192e2f..cba2906 100644 --- a/tests/test_candidate_graph/test_iou.py +++ b/tests/test_candidate_graph/test_iou.py @@ -9,24 +9,24 @@ def test_compute_ious_2d(segmentation_2d): expected = [ (1, 2, 555.46 / 1408.0), ] - for iou, expected_iou in zip(ious, expected): + for iou, expected_iou in zip(ious, expected, strict=False): assert iou == pytest.approx(expected_iou, abs=0.01) ious = _compute_ious(segmentation_2d[1], segmentation_2d[1]) expected = [(1, 1, 1.0), (2, 2, 1.0)] - for iou, expected_iou in zip(ious, expected): + for iou, expected_iou in zip(ious, expected, strict=False): assert iou == pytest.approx(expected_iou, abs=0.01) def test_compute_ious_3d(segmentation_3d): ious = _compute_ious(segmentation_3d[0], segmentation_3d[1]) expected = [(1, 2, 0.30)] - for iou, expected_iou in zip(ious, expected): + for iou, expected_iou in zip(ious, expected, strict=False): assert iou == pytest.approx(expected_iou, abs=0.01) ious = _compute_ious(segmentation_3d[1], segmentation_3d[1]) expected = [(1, 1, 1.0), (2, 2, 1.0)] - for iou, expected_iou in zip(ious, expected): + for iou, expected_iou in zip(ious, expected, strict=False): assert iou == pytest.approx(expected_iou, abs=0.01) @@ -46,9 +46,9 @@ def test_multi_hypo_iou_2d(multi_hypothesis_segmentation_2d, multi_hypothesis_gr expected = multi_hypothesis_graph_2d input_graph = multi_hypothesis_graph_2d.copy() nx.set_edge_attributes(input_graph, -1, name=EdgeAttr.IOU.value) - add_iou(input_graph, multi_hypothesis_segmentation_2d) + add_iou(input_graph, multi_hypothesis_segmentation_2d, multiseg=True) for s, t, attrs in expected.edges(data=True): - print(s, t) + print(s, t, attrs) assert ( pytest.approx(attrs[EdgeAttr.IOU.value], abs=0.01) == input_graph.edges[(s, t)][EdgeAttr.IOU.value] diff --git a/tests/test_candidate_graph/test_utils.py b/tests/test_candidate_graph/test_utils.py index a2f9462..3513a82 100644 --- a/tests/test_candidate_graph/test_utils.py +++ b/tests/test_candidate_graph/test_utils.py @@ -52,26 +52,6 @@ def test_nodes_from_segmentation_2d(segmentation_2d): assert Counter(node_frame_dict[1]) == Counter(["1_1", "1_2"]) -def test_nodes_from_segmentation_2d_hypo( - multi_hypothesis_segmentation_2d, multi_hypothesis_graph_2d -): - # test with 2D segmentation - node_graph, node_frame_dict = nodes_from_segmentation( - segmentation=multi_hypothesis_segmentation_2d - ) - assert Counter(list(node_graph.nodes)) == Counter( - list(multi_hypothesis_graph_2d.nodes) - ) - assert node_graph.nodes["1_0_1"][NodeAttr.SEG_ID.value] == 1 - assert node_graph.nodes["1_0_1"][NodeAttr.SEG_HYPO.value] == 0 - assert node_graph.nodes["1_0_1"][NodeAttr.TIME.value] == 1 - assert node_graph.nodes["1_0_1"][NodeAttr.AREA.value] == 305 - assert node_graph.nodes["1_0_1"][NodeAttr.POS.value] == (20, 80) - - assert Counter(node_frame_dict[0]) == Counter(["0_0_1", "0_1_1"]) - assert Counter(node_frame_dict[1]) == Counter(["1_0_1", "1_0_2", "1_1_1", "1_1_2"]) - - def test_nodes_from_segmentation_3d(segmentation_3d): # test with 3D segmentation node_graph, node_frame_dict = nodes_from_segmentation( @@ -116,7 +96,6 @@ def test_add_cand_edges_3d(graph_3d): def test_get_node_id(): assert get_node_id(0, 2) == "0_2" - assert get_node_id(2, 10, 3) == "2_3_10" def test_compute_node_frame_dict(graph_2d): diff --git a/tests/test_utils/test_relabel_segmentation.py b/tests/test_utils/test_relabel_segmentation.py index 257d9ed..57d796f 100644 --- a/tests/test_utils/test_relabel_segmentation.py +++ b/tests/test_utils/test_relabel_segmentation.py @@ -9,11 +9,11 @@ def test_relabel_segmentation(segmentation_2d, graph_2d): expected = np.zeros(segmentation_2d.shape, dtype="int32") # make frame with one cell in center with label 1 rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) - expected[0, 0][rr, cc] = 1 + expected[0][rr, cc] = 1 # make frame with cell centered at (20, 80) with label 1 rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) - expected[1, 0][rr, cc] = 1 + expected[1][rr, cc] = 1 graph_2d.remove_node("1_2") relabeled_seg = relabel_segmentation(graph_2d, segmentation_2d) From 7c638508de77d0db115ed888f08178c82581f545 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 26 Nov 2024 13:06:19 -0500 Subject: [PATCH 2/4] fix strict zip linting --- pyproject.toml | 4 ++++ src/motile_toolbox/candidate_graph/iou.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1ac3323..7169e3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,10 @@ ignore = [ "S101", # Use of assert detected ] +unfixable = [ + "B905", # currently adds strict=False to zips. Should add strict=True (manually) +] + [tool.ruff.lint.per-file-ignores] "tests/*.py" = ["D", "S"] "*/__init__.py" = ["F401", "D"] diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index 90751fd..ea979f8 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -32,9 +32,9 @@ def _compute_ious( values, counts = np.unique(flattened_stacked, axis=1, return_counts=True) frame1_values, frame1_counts = np.unique(frame1, return_counts=True) - frame1_label_sizes = dict(zip(frame1_values, frame1_counts, strict=False)) + frame1_label_sizes = dict(zip(frame1_values, frame1_counts, strict=True)) frame2_values, frame2_counts = np.unique(frame2, return_counts=True) - frame2_label_sizes = dict(zip(frame2_values, frame2_counts, strict=False)) + frame2_label_sizes = dict(zip(frame2_values, frame2_counts, strict=True)) ious: list[tuple[int, int, float]] = [] for index in range(values.shape[1]): pair = values[:, index] From e83471b73de4f7940075b6203dc3d2cb0885817e Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 26 Nov 2024 13:07:47 -0500 Subject: [PATCH 3/4] Remove print statement --- src/motile_toolbox/candidate_graph/iou.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/motile_toolbox/candidate_graph/iou.py b/src/motile_toolbox/candidate_graph/iou.py index ea979f8..d134ebc 100644 --- a/src/motile_toolbox/candidate_graph/iou.py +++ b/src/motile_toolbox/candidate_graph/iou.py @@ -75,7 +75,6 @@ def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]: seg1 = segmentation[hypo1][frame] seg2 = segmentation[hypo2][frame + 1] ious = _compute_ious(seg1, seg2) - print(hypo1, hypo2, ious) for label1, label2, iou in ious: if multiseg: node_id1 = get_node_id(frame, label1, hypo1) From 2ffe2c58fb7a60bd21383ff6759b6b2df33067e5 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 26 Nov 2024 14:30:36 -0500 Subject: [PATCH 4/4] Explicitly mention dummy value for scaling time --- src/motile_toolbox/candidate_graph/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/motile_toolbox/candidate_graph/utils.py b/src/motile_toolbox/candidate_graph/utils.py index 4e4dffe..9a430d6 100644 --- a/src/motile_toolbox/candidate_graph/utils.py +++ b/src/motile_toolbox/candidate_graph/utils.py @@ -53,10 +53,10 @@ def nodes_from_segmentation( Args: segmentation (np.ndarray): A numpy array with integer labels and dimensions (t, [z], y, x). - scale (list[float] | None, optional): The scale of the segmentation data. + scale (list[float] | None, optional): The scale of the segmentation data in all + dimensions (including time, which should have a dummy 1 value). Will be used to rescale the point locations and attribute computations. - Defaults to None, which implies the data is isotropic. Should include - time and all spatial dimentsions. + Defaults to None, which implies the data is isotropic. seg_hypo (int | None): A number to be stored in NodeAttr.SEG_HYPO, if given. Returns: @@ -111,9 +111,9 @@ def nodes_from_points_list( points_list (np.ndarray): An NxD numpy array with N points and D (3 or 4) dimensions. Dimensions should be in order (t, [z], y, x). scale (list[float] | None, optional): Amount to scale the points in each - dimension. Only needed if the provided points are in "voxel" coordinates - instead of world coordinates. Defaults to None, which implies the data is - isotropic. + dimension (including time). Only needed if the provided points are in + "voxel" coordinates instead of world coordinates. Defaults to None, which + implies the data is isotropic. Returns: tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,