Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add label stitching functionality #106

Merged
merged 10 commits into from
Dec 30, 2024
163 changes: 148 additions & 15 deletions elf/segmentation/stitching.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import multiprocessing
from concurrent import futures
from typing import Tuple, Optional, Callable

import nifty.tools as nt
import numpy as np
import vigra
from nifty.ground_truth import overlap
import numpy as np

import nifty.tools as nt
from nifty.ground_truth import overlap as compute_overlap

try:
from napari.utils import progress as tqdm
Expand All @@ -16,19 +18,25 @@


def stitch_segmentation(
input_, segmentation_function,
tile_shape, tile_overlap, beta=0.5,
shape=None, with_background=True, n_threads=None,
return_before_stitching=False, verbose=True,
):
input_: np.ndarray,
segmentation_function: Callable,
tile_shape: Tuple[int, int],
tile_overlap: Tuple[int, int],
beta: float = 0.5,
shape: Optional[Tuple[int, int]] = None,
with_background: bool = True,
n_threads: Optional[int] = None,
return_before_stitching: bool = False,
verbose: bool = True,
) -> np.ndarray:
"""Run segmentation function tilewise and stitch the results based on overlap.

Arguments:
input_ [np.ndarray] - the input data. If the data has channels they need to be passed as last dimension,
e.g. XYC for a 2D image with channels.
segmentation_function [callable] - the function to perform segmentation for each tile.
Needs to be a segmentation that takes the input (for the tile) as well as the id of the tile as input.
I.e. the function needs to have a signature like this: 'def my_seg_func(tile_input_, tile_id)'.
i.e. the function needs to have a signature like this: 'def my_seg_func(tile_input_, tile_id)'.
The tile_id is passed in case the segmentation routine differs depending on the tile;
it can be ignored in most cases.
tile_shape [tuple] - shape of the individual tiles.
Expand Down Expand Up @@ -101,10 +109,12 @@ def _compute_overlaps(block_id):
this_seg, ngb_seg = block_segs[block_id], block_segs[ngb_id]

# get the global coordinates of the block face
face = tuple(slice(beg_out, end_out) if d != axis else slice(beg_out, beg_in + tile_overlap[d])
for d, (beg_out, end_out, beg_in) in enumerate(zip(this_block.outerBlock.begin,
this_block.outerBlock.end,
this_block.innerBlock.begin)))
face = tuple(
slice(beg_out, end_out) if d != axis else slice(beg_out, beg_in + tile_overlap[d])
for d, (beg_out, end_out, beg_in) in enumerate(
zip(this_block.outerBlock.begin, this_block.outerBlock.end, this_block.innerBlock.begin)
))

# map to the two local face coordinates
this_face_bb = tuple(
slice(fa.start - offset, fa.stop - offset) for fa, offset in zip(face, this_block.outerBlock.begin)
Expand All @@ -116,10 +126,10 @@ def _compute_overlaps(block_id):
# load the two segmentations for the face
this_face = this_seg[this_face_bb]
ngb_face = ngb_seg[ngb_face_bb]
assert this_face.shape == ngb_face.shape
assert this_face.shape == ngb_face.shape, (this_face.shape, ngb_face.shape)

# compute the object overlaps
overlap_comp = overlap(this_face, ngb_face)
overlap_comp = compute_overlap(this_face, ngb_face)
this_ids = np.unique(this_face)
overlaps = {this_id: overlap_comp.overlapArraysNormalized(this_id, sorted=False) for this_id in this_ids}
overlap_ids = {this_id: ovlps[0] for this_id, ovlps in overlaps.items()}
Expand Down Expand Up @@ -175,4 +185,127 @@ def _compute_overlaps(block_id):

if return_before_stitching:
return seg_stitched, seg

return seg_stitched


def stitch_tiled_segmentation(
segmentation: np.ndarray,
tile_shape: Tuple[int, int],
overlap: int = 1,
n_threads: Optional[int] = None,
verbose: bool = True,
) -> np.ndarray:
"""Functionality for stitching segmentations tile-wise based on overlap.

Args:
segmentation: The input segmentation.
tile_shape: The shape of inidividual tiles.
overlap: The overlap of tiles.
It is responsible to compute the edge nodes for the desired overlap region.
n_threads: The number of threads used for parallelized operations.
Set to the number of cores by default.
verbose: Whether to print the progress bars.

Returns:
The stitched segmentation with merged labels.
"""
shape = segmentation.shape
ndim = len(shape)
blocking = nt.blocking([0] * ndim, shape, tile_shape)
n_blocks = blocking.numberOfBlocks

block_segs = []

# Get the tiles from the segmentation of shape: 'tile_shape'.
def _fetch_tiles(block_id):
block = blocking.getBlock(block_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
block_seg = segmentation[bb]
block_segs.append(block_seg)

n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
with futures.ThreadPoolExecutor(n_threads) as tp:
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
list(tqdm(tp.map(
_fetch_tiles, range(n_blocks)), total=n_blocks, desc="Get tiles from the segmentation", disable=not verbose,
))

# Conpute the Region Adjacency Graph (RAG) for the tiled segmentation.
# and the edges between block boundaries (stitch edges).
seg_ids = np.unique(segmentation)
rag = compute_rag(segmentation)

# We initialize the edge disaffinities with a high value (corresponding to a low overlap)
# so that merging things that are not on the edge is very unlikely
# but not completely impossible in case it is needed for a consistent solution.
edge_disaffinities = np.full(rag.numberOfEdges, 0.9, dtype="float32")

def _compute_overlaps(block_id):
# For each axis, load the face with the lower block neighbor and compute the object overlaps
for axis in range(ndim):
ngb_id = blocking.getNeighborId(block_id, axis, lower=True)
if ngb_id == -1:
continue

# Load the respective tiles.
this_seg, ngb_seg = block_segs[block_id], block_segs[ngb_id]

# Get the local face coordinates of the respective tiles.
# We get the face region of the shape defined by 'overlap'
# eg. The default '1' returns a 1d cross-section of the tile interfaces.
face_bb = tuple(slice(None) if d != axis else slice(0, overlap) for d in range(ndim))
ngb_face_bb = tuple(
slice(None) if d != axis else slice(ngb_seg.shape[d] - overlap, ngb_seg.shape[d]) for d in range(ndim)
)

# Load the two segmentations for the face.
this_face = this_seg[face_bb]
ngb_face = ngb_seg[ngb_face_bb]

# Both the faces from each tile are expected to be of the same shape
assert this_face.shape == ngb_face.shape, (this_face.shape, ngb_face.shape)

# Compute the object overlaps.
# In this step, we compute the per-instance overlap over both faces
overlap_comp = compute_overlap(this_face, ngb_face)
this_ids = np.unique(this_face).astype("uint32")
overlaps = {this_id: overlap_comp.overlapArraysNormalized(this_id, sorted=False) for this_id in this_ids}
overlap_ids = {this_id: ovlps[0] for this_id, ovlps in overlaps.items()}
overlap_values = {this_id: ovlps[1] for this_id, ovlps in overlaps.items()}
overlap_uv_ids = np.array([
[this_id, ovlp_id] for this_id, ovlp_ids in overlap_ids.items() for ovlp_id in ovlp_ids
])
overlap_values = np.array([ovlp for ovlps in overlap_values.values() for ovlp in ovlps], dtype="float32")
assert len(overlap_uv_ids) == len(overlap_values)

# Next, we remove the invalid edges.
# We might have ids in the overlaps that are not in the segmentation. We filter them out.
valid_uv_ids = np.isin(overlap_uv_ids, seg_ids).all(axis=1)
if valid_uv_ids.sum() == 0:
continue
overlap_uv_ids, overlap_values = overlap_uv_ids[valid_uv_ids], overlap_values[valid_uv_ids]
assert len(overlap_uv_ids) == len(overlap_values)

# Get the edge ids.
edge_ids = rag.findEdges(overlap_uv_ids)
valid_edges = edge_ids != -1
if valid_edges.sum() == 0:
continue
edge_ids, overlap_values = edge_ids[valid_edges], overlap_values[valid_edges]
assert len(edge_ids) == len(overlap_values)

# And set the global edge disaffinities to (1 - overlap).
edge_disaffinities[edge_ids] = (1.0 - overlap_values)

with futures.ThreadPoolExecutor(n_threads) as tp:
list(tqdm(tp.map(
_compute_overlaps, range(n_blocks)), total=n_blocks, desc="Compute object overlaps", disable=not verbose,
))

costs = compute_edge_costs(edge_disaffinities, beta=0.5)

# Run multicut to get the segmentation result.
node_labels = multicut_decomposition(rag, costs)
seg_stitched = project_node_labels_to_pixels(rag, node_labels)

return seg_stitched
Loading