Skip to content

Commit

Permalink
Implement test and example for stitching, fix cornercase leading to s…
Browse files Browse the repository at this point in the history
…egfault
  • Loading branch information
constantinpape committed May 19, 2023
1 parent c232223 commit 50a270d
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 6 deletions.
36 changes: 30 additions & 6 deletions elf/segmentation/stitching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import nifty.tools as nt
import numpy as np
import vigra
from nifty.ground_truth import overlap
from tqdm import trange, tqdm

Expand All @@ -13,7 +14,8 @@
def stitch_segmentation(
input_, segmentation_function,
tile_shape, tile_overlap, beta=0.5,
shape=None, with_background=True, n_threads=None
shape=None, with_background=True, n_threads=None,
return_before_stitching=False, verbose=True,
):
"""
"""
Expand All @@ -29,7 +31,7 @@ def stitch_segmentation(
n_blocks = blocking.numberOfBlocks
# TODO enable parallelisation
# run tiled segmentation
for block_id in trange(n_blocks, desc="Run tiled segmentation"):
for block_id in trange(n_blocks, desc="Run tiled segmentation", disable=not verbose):
block = blocking.getBlockWithHalo(block_id, list(tile_overlap))
outer_bb = tuple(slice(beg, end) for beg, end in zip(block.outerBlock.begin, block.outerBlock.end))

Expand All @@ -49,6 +51,7 @@ def stitch_segmentation(

# compute the region adjacency graph for the tiled segmentation
# and the edges between block boundaries (stitch edges)
seg_ids = np.unique(seg)
rag = compute_rag(seg, n_threads=n_threads)

# we initialize the edge disaffinities with a high value (corresponding to a low overlap)
Expand Down Expand Up @@ -101,17 +104,30 @@ def _compute_overlaps(block_id):
assert len(overlap_uv_ids) == len(overlap_values)

# - get the edge ids
# - exclude invalid edge (due to bg overlap)
# - exclude invalid edge
# - set the global edge disaffinities to 1 - overlap

# we might have ids in the overlaps that are not in the final seg, these need to be filtered
valid_uv_ids = np.isin(overlap_uv_ids, seg_ids).all(axis=1)
if valid_uv_ids.sum() == 0:
continue
overlap_uv_ids, overlap_values = overlap_uv_ids[valid_uv_ids], overlap_values[valid_uv_ids]
assert len(overlap_uv_ids) == len(overlap_values)

edge_ids = rag.findEdges(overlap_uv_ids)
valid_edges = edge_ids != -1
if valid_edges.sum() == 0:
continue
edge_ids, overlap_values = edge_ids[valid_edges], overlap_values[valid_edges]
assert len(edge_ids) == len(overlap_values)

edge_disaffinties[edge_ids] = (1.0 - overlap_values)

n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
with futures.ThreadPoolExecutor(n_threads) as tp:
list(tqdm(tp.map(_compute_overlaps, range(n_blocks)), total=n_blocks, desc="Compute object overlaps"))
list(tqdm(tp.map(
_compute_overlaps, range(n_blocks)), total=n_blocks, desc="Compute object overlaps", disable=not verbose
))

# if we have background set all the edges that are connecting 0 to another element
# to be very unlikely
Expand All @@ -123,5 +139,13 @@ def _compute_overlaps(block_id):

# run multicut to get the segmentation result
node_labels = multicut_decomposition(rag, costs)
seg = project_node_labels_to_pixels(rag, node_labels, n_threads=n_threads)
return seg
seg_stitched = project_node_labels_to_pixels(rag, node_labels, n_threads=n_threads)

if with_background:
vigra.analysis.relabelConsecutive(seg_stitched, out=seg_stitched, start_label=1, keep_zeros=True)
else:
vigra.analysis.relabelConsecutive(seg_stitched, out=seg_stitched, start_label=1, keep_zeros=False)

if return_before_stitching:
return seg_stitched, seg
return seg_stitched
41 changes: 41 additions & 0 deletions example/segmentation/stitching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import napari

from elf.segmentation.stitching import stitch_segmentation
from skimage.data import binary_blobs
from skimage.measure import label


def connected_components(input_, block_id=None):
segmentation = label(input_)
return segmentation.astype("uint32")


def create_test_data(size=1024):
data = binary_blobs(size, blob_size_fraction=0.1, volume_fraction=0.2)
return data


def main():
data = create_test_data(size=1024)

# compute the segmentation using tiling and stitching
tile_shape = (256, 256)
tile_overlap = (32, 32)
seg_stitched, seg_tiles = stitch_segmentation(
data, connected_components, tile_shape, tile_overlap, return_before_stitching=True
)

# compute the segmentation based on connected components without any stitching
seg_full = connected_components(data)

# check the results visually
v = napari.Viewer()
v.add_image(data, name="image")
v.add_labels(seg_full, name="segmentation")
v.add_labels(seg_stitched, name="stitched segmentation")
v.add_labels(seg_tiles, name="segmented tiles")
napari.run()


if __name__ == "__main__":
main()
48 changes: 48 additions & 0 deletions test/segmentation/test_stitching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import unittest

from elf.evaluation import rand_index
from skimage.data import binary_blobs
from skimage.measure import label


class TestStitching(unittest.TestCase):
def get_data(self, size=1024, ndim=2):
data = binary_blobs(size, blob_size_fraction=0.1, volume_fraction=0.2, n_dim=ndim)
return data

def test_stitch_segmentation(self):
from elf.segmentation.stitching import stitch_segmentation

def _segment(input_, block_id=None):
segmentation = label(input_)
return segmentation.astype("uint32")

tile_overlap = (32, 32)
tile_shapes = [(128, 128), (256, 256), (128, 256)]
for tile_shape in tile_shapes:
for _ in range(3): # test 3 times with different data
data = self.get_data()
expected_segmentation = _segment(data)
segmentation = stitch_segmentation(data, _segment, tile_shape, tile_overlap, verbose=False)
are, _ = rand_index(segmentation, expected_segmentation)
self.assertTrue(are < 0.05)

def test_stitch_segmentation_3d(self):
from elf.segmentation.stitching import stitch_segmentation

def _segment(input_, block_id=None):
segmentation = label(input_)
return segmentation.astype("uint32")

tile_overlap = (16, 16, 16)
tile_shapes = [(32, 32, 32), (64, 64, 64), (32, 64, 24)]
for tile_shape in tile_shapes:
data = self.get_data(256, ndim=3)
expected_segmentation = _segment(data)
segmentation = stitch_segmentation(data, _segment, tile_shape, tile_overlap, verbose=False)
are, _ = rand_index(segmentation, expected_segmentation)
self.assertTrue(are < 0.05)


if __name__ == "__main__":
unittest.main()

0 comments on commit 50a270d

Please sign in to comment.