Skip to content

Commit

Permalink
Merge pull request #24 from funkelab/23-remove-hypo-dim
Browse files Browse the repository at this point in the history
Refactor multihypothesis segmentation handling
  • Loading branch information
cmalinmayor authored Nov 26, 2024
2 parents 4ca103e + 2ffe2c5 commit 4986526
Show file tree
Hide file tree
Showing 12 changed files with 179 additions and 133 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ ignore = [
"S101", # Use of assert detected
]

unfixable = [
"B905", # currently adds strict=False to zips. Should add strict=True (manually)
]

[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["D", "S"]
"*/__init__.py" = ["F401", "D"]
Expand Down
6 changes: 5 additions & 1 deletion src/motile_toolbox/candidate_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from .compute_graph import get_candidate_graph, get_candidate_graph_from_points_list
from .compute_graph import (
compute_graph_from_multiseg,
compute_graph_from_points_list,
compute_graph_from_seg,
)
from .graph_attributes import EdgeAttr, NodeAttr
from .graph_to_nx import graph_to_nx
from .iou import add_iou
Expand Down
80 changes: 66 additions & 14 deletions src/motile_toolbox/candidate_graph/compute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@
logger = logging.getLogger(__name__)


def get_candidate_graph(
def compute_graph_from_seg(
segmentation: np.ndarray,
max_edge_distance: float,
iou: bool = False,
scale: list[float] | None = None,
) -> tuple[nx.DiGraph, list[set[Any]] | None]:
) -> nx.DiGraph:
"""Construct a candidate graph from a segmentation array. Nodes are placed at the
centroid of each segmentation and edges are added for all nodes in adjacent frames
within max_edge_distance. If segmentation contains multiple hypotheses, will also
return a list of conflicting node ids that cannot be selected together.
within max_edge_distance.
Args:
segmentation (np.ndarray): A numpy array with integer labels and dimensions
(t, h, [z], y, x), where h is the number of hypotheses.
(t, [z], y, x).
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes with centroids within this distance in adjacent frames
        will be connected with a candidate edge.
Expand All @@ -35,11 +34,8 @@ def get_candidate_graph(
Defaults to None, which implies the data is isotropic.
Returns:
tuple[nx.DiGraph, list[set[Any]] | None]: A candidate graph that can be passed
to the motile solver, and a list of conflicting node ids.
nx.DiGraph: A candidate graph that can be passed to the motile solver
"""
num_hypotheses = segmentation.shape[1]

# add nodes
cand_graph, node_frame_dict = nodes_from_segmentation(segmentation, scale=scale)
logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}")
Expand All @@ -57,16 +53,73 @@ def get_candidate_graph(

logger.info(f"Candidate edges: {cand_graph.number_of_edges()}")

return cand_graph


def compute_graph_from_multiseg(
segmentations: np.ndarray,
max_edge_distance: float,
iou: bool = False,
scale: list[float] | None = None,
) -> tuple[nx.DiGraph, list[set[Any]]]:
"""Construct a candidate graph from a segmentation array. Nodes are placed at the
centroid of each segmentation and edges are added for all nodes in adjacent frames
within max_edge_distance.
Args:
        segmentations (np.ndarray): numpy array with multiple possible segmentations
stacked. Each segmentation has integer labels and
dimensions (h, t, [z], y, x). Assumes unique labels even between hypotheses.
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes with centroids within this distance in adjacent frames
            will be connected with a candidate edge.
iou (bool, optional): Whether to include IOU on the candidate graph.
Defaults to False.
scale (list[float] | None, optional): The scale of the segmentation data.
Will be used to rescale the point locations and attribute computations.
Defaults to None, which implies the data is isotropic.
Returns:
        tuple[nx.DiGraph, list[set[Any]]]: A candidate graph that can be passed to the
motile solver, and a list of conflicting node sets
"""
# add nodes
cand_graph = nx.DiGraph()
node_frame_dict: dict[int, Any] = {}
for hypo_id, seg in enumerate(segmentations):
seg_node_graph, seg_node_frame_dict = nodes_from_segmentation(
seg, scale=scale, seg_hypo=hypo_id
)
cand_graph.update(seg_node_graph)
for frame, nodes in seg_node_frame_dict.items():
if frame not in node_frame_dict:
node_frame_dict[frame] = []
node_frame_dict[frame].extend(nodes)
logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}")

# add edges
add_cand_edges(
cand_graph,
max_edge_distance=max_edge_distance,
node_frame_dict=node_frame_dict,
)
if iou:
# Scale does not matter to IOU, because both numerator and denominator
# are scaled by the anisotropy.
add_iou(cand_graph, segmentations, node_frame_dict, multiseg=True)

logger.info(f"Candidate edges: {cand_graph.number_of_edges()}")

# Compute conflict sets between segmentations
conflicts = []
if num_hypotheses > 1:
for time, segs in enumerate(segmentation):
conflicts.extend(compute_conflict_sets(segs, time))
for time in range(segmentations.shape[1]):
segs = segmentations[:, time]
conflicts.extend(compute_conflict_sets(segs, time))

return cand_graph, conflicts


def get_candidate_graph_from_points_list(
def compute_graph_from_points_list(
points_list: np.ndarray,
max_edge_distance: float,
scale: list[float] | None = None,
Expand All @@ -86,7 +139,6 @@ def get_candidate_graph_from_points_list(
Returns:
nx.DiGraph: A candidate graph that can be passed to the motile solver.
Multiple hypotheses not supported for points input.
"""
# add nodes
cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale)
Expand Down
49 changes: 31 additions & 18 deletions src/motile_toolbox/candidate_graph/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def _compute_ious(

values, counts = np.unique(flattened_stacked, axis=1, return_counts=True)
frame1_values, frame1_counts = np.unique(frame1, return_counts=True)
frame1_label_sizes = dict(zip(frame1_values, frame1_counts))
frame1_label_sizes = dict(zip(frame1_values, frame1_counts, strict=True))
frame2_values, frame2_counts = np.unique(frame2, return_counts=True)
frame2_label_sizes = dict(zip(frame2_values, frame2_counts))
frame2_label_sizes = dict(zip(frame2_values, frame2_counts, strict=True))
ious: list[tuple[int, int, float]] = []
for index in range(values.shape[1]):
pair = values[:, index]
Expand All @@ -45,37 +45,46 @@ def _compute_ious(
return ious


def _get_iou_dict(segmentation) -> dict[str, dict[str, float]]:
"""Get all ious values for the provided segmentation (all frames).
def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]:
"""Get all ious values for the provided segmentations (all frames).
Will return as map from node_id -> dict[node_id] -> iou for easy
navigation when adding to candidate graph.
Args:
segmentation (np.ndarray): Segmentation that was used to create cand_graph.
Has shape (t, h, [z], y, x), where h is the number of hypotheses.
segmentation (np.ndarray): Segmentations that were used to create cand_graph.
Has shape ([h], t, [z], y, x), where h is the number of hypotheses
if multiseg is True.
multiseg (bool): Flag indicating if the provided segmentation contains
multiple hypothesis segmentations. Defaults to False.
Returns:
dict[str, dict[str, float]]: A map from node id to another dictionary, which
contains node_ids to iou values.
"""
iou_dict: dict[str, dict[str, float]] = {}
hypo_pairs: list[tuple[int | None, ...]]
num_hypotheses = segmentation.shape[1]
if num_hypotheses > 1:
hypo_pairs = list(product(range(num_hypotheses), repeat=2))
hypo_pairs: list[tuple[int, ...]] = [(0, 0)]
if multiseg:
num_hypotheses = segmentation.shape[0]
if num_hypotheses > 1:
hypo_pairs = list(product(range(num_hypotheses), repeat=2))
else:
hypo_pairs = [(None, None)]
segmentation = np.expand_dims(segmentation, 0)

for frame in range(len(segmentation) - 1):
for frame in range(segmentation.shape[1] - 1):
for hypo1, hypo2 in hypo_pairs:
seg1 = segmentation[frame][hypo1]
seg2 = segmentation[frame + 1][hypo2]
seg1 = segmentation[hypo1][frame]
seg2 = segmentation[hypo2][frame + 1]
ious = _compute_ious(seg1, seg2)
for label1, label2, iou in ious:
node_id1 = get_node_id(frame, label1, hypo1)
if multiseg:
node_id1 = get_node_id(frame, label1, hypo1)
node_id2 = get_node_id(frame + 1, label2, hypo2)
else:
node_id1 = get_node_id(frame, label1)
node_id2 = get_node_id(frame + 1, label2)

if node_id1 not in iou_dict:
iou_dict[node_id1] = {}
node_id2 = get_node_id(frame + 1, label2, hypo2)
iou_dict[node_id1][node_id2] = iou
return iou_dict

Expand All @@ -84,22 +93,26 @@ def add_iou(
cand_graph: nx.DiGraph,
segmentation: np.ndarray,
node_frame_dict: dict[int, list[Any]] | None = None,
multiseg=False,
) -> None:
"""Add IOU to the candidate graph.
Args:
cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated
segmentation (np.ndarray): segmentation that was used to create cand_graph.
Has shape (t, h, [z], y, x), where h is the number of hypotheses.
Has shape ([h], t, [z], y, x), where h is the number of hypotheses if
multiseg is True.
node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from
time frames to nodes in that frame. Will be computed if not provided,
but can be provided for efficiency (e.g. after running
nodes_from_segmentation). Defaults to None.
multiseg (bool): Flag indicating if the given segmentation is actually multiple
stacked segmentations. Defaults to False.
"""
if node_frame_dict is None:
node_frame_dict = _compute_node_frame_dict(cand_graph)
frames = sorted(node_frame_dict.keys())
ious = _get_iou_dict(segmentation)
ious = _get_iou_dict(segmentation, multiseg=multiseg)
for frame in tqdm(frames):
if frame + 1 not in node_frame_dict.keys():
continue
Expand Down
64 changes: 31 additions & 33 deletions src/motile_toolbox/candidate_graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> s
def nodes_from_segmentation(
segmentation: np.ndarray,
scale: list[float] | None = None,
seg_hypo=None,
) -> tuple[nx.DiGraph, dict[int, list[Any]]]:
"""Extract candidate nodes from a segmentation. Returns a networkx graph
with only nodes, and also a dictionary from frames to node_ids for
Expand All @@ -48,55 +49,52 @@ def nodes_from_segmentation(
- position
- segmentation id
- area
- hypothesis id (optional)
Args:
segmentation (np.ndarray): A numpy array with integer labels and dimensions
(t, h, [z], y, x), where h is the number of hypotheses.
scale (list[float] | None, optional): The scale of the segmentation data.
(t, [z], y, x).
scale (list[float] | None, optional): The scale of the segmentation data in all
dimensions (including time, which should have a dummy 1 value).
Will be used to rescale the point locations and attribute computations.
Defaults to None, which implies the data is isotropic. Should include
time and all spatial dimentsions.
Defaults to None, which implies the data is isotropic.
seg_hypo (int | None): A number to be stored in NodeAttr.SEG_HYPO, if given.
Returns:
tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,
and a mapping from time frames to node ids.
"""
logger.debug("Extracting nodes from segmentation")
cand_graph = nx.DiGraph()
# also construct a dictionary from time frame to node_id for efficiency
node_frame_dict: dict[int, list[Any]] = {}
logger.info("Extracting nodes from segmentation")
num_hypotheses = segmentation.shape[1]

if scale is None:
scale = [
1,
] * (segmentation.ndim - 1) # don't include hypothesis
] * segmentation.ndim
else:
assert (
len(scale) == segmentation.ndim - 1
), f"Scale {scale} should have {segmentation.ndim - 1} dims"
len(scale) == segmentation.ndim
), f"Scale {scale} should have {segmentation.ndim} dims"

for t in tqdm(range(len(segmentation))):
segs = segmentation[t]
hypo_id: int | None
for hypo_id, hypo in enumerate(segs):
if num_hypotheses == 1:
hypo_id = None
nodes_in_frame = []
props = regionprops(hypo, spacing=tuple(scale[1:]))
for regionprop in props:
node_id = get_node_id(t, regionprop.label, hypothesis_id=hypo_id)
attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area}
attrs[NodeAttr.SEG_ID.value] = regionprop.label
if hypo_id is not None:
attrs[NodeAttr.SEG_HYPO.value] = hypo_id
centroid = regionprop.centroid # [z,] y, x
attrs[NodeAttr.POS.value] = centroid
cand_graph.add_node(node_id, **attrs)
nodes_in_frame.append(node_id)
if nodes_in_frame:
if t not in node_frame_dict:
node_frame_dict[t] = []
node_frame_dict[t].extend(nodes_in_frame)
nodes_in_frame = []
props = regionprops(segs, spacing=tuple(scale[1:]))
for regionprop in props:
node_id = get_node_id(t, regionprop.label, hypothesis_id=seg_hypo)
attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area}
attrs[NodeAttr.SEG_ID.value] = regionprop.label
if seg_hypo:
attrs[NodeAttr.SEG_HYPO.value] = seg_hypo
centroid = regionprop.centroid # [z,] y, x
attrs[NodeAttr.POS.value] = centroid
cand_graph.add_node(node_id, **attrs)
nodes_in_frame.append(node_id)
if nodes_in_frame:
if t not in node_frame_dict:
node_frame_dict[t] = []
node_frame_dict[t].extend(nodes_in_frame)
return cand_graph, node_frame_dict


Expand All @@ -113,9 +111,9 @@ def nodes_from_points_list(
points_list (np.ndarray): An NxD numpy array with N points and D
(3 or 4) dimensions. Dimensions should be in order (t, [z], y, x).
scale (list[float] | None, optional): Amount to scale the points in each
dimension. Only needed if the provided points are in "voxel" coordinates
instead of world coordinates. Defaults to None, which implies the data is
isotropic.
dimension (including time). Only needed if the provided points are in
"voxel" coordinates instead of world coordinates. Defaults to None, which
implies the data is isotropic.
Returns:
tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,
Expand Down
21 changes: 6 additions & 15 deletions src/motile_toolbox/utils/relabel_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@ def relabel_segmentation(
Args:
solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use
for relabeling. Nodes not in graph will be removed from seg. Original
segmentation ids and hypothesis ids have to be stored in the graph so we
segmentation ids have to be stored in the graph so we
can map them back.
segmentation (np.ndarray): Original (potentially multi-hypothesis)
segmentation with dimensions (t,h,[z],y,x), where h is 1 for single
input segmentation.
segmentation (np.ndarray): Original segmentation with dimensions (t, [z], y, x)
Returns:
np.ndarray: Relabeled segmentation array where nodes in same track share same
id with shape (t,1,[z],y,x)
id with shape (t,[z],y,x)
"""
output_shape = (segmentation.shape[0], 1, *segmentation.shape[2:])
tracked_masks = np.zeros_like(segmentation, shape=output_shape)
tracked_masks = np.zeros_like(segmentation)
id_counter = 1
parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1]
soln_copy = solution_nx_graph.copy()
Expand All @@ -36,13 +33,7 @@ def relabel_segmentation(
for node in node_set:
time_frame = solution_nx_graph.nodes[node][NodeAttr.TIME.value]
previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value]
if NodeAttr.SEG_HYPO.value in solution_nx_graph.nodes[node]:
hypothesis_id = solution_nx_graph.nodes[node][NodeAttr.SEG_HYPO.value]
else:
hypothesis_id = 0
previous_seg_mask = (
segmentation[time_frame, hypothesis_id] == previous_seg_id
)
tracked_masks[time_frame, 0][previous_seg_mask] = id_counter
previous_seg_mask = segmentation[time_frame] == previous_seg_id
tracked_masks[time_frame][previous_seg_mask] = id_counter
id_counter += 1
return tracked_masks
Loading

0 comments on commit 4986526

Please sign in to comment.