Skip to content

Commit

Permalink
Add cucim.skimage.morphology.medial_axis (#342)
Browse files Browse the repository at this point in the history
closes #336

This PR adds a function for skeletonization of 2D images via the medial axis transform. 

It should be reviewed after #318 is merged. The new commits only start from 19a6fed.

There is one sequential component to this algorithm that still must be run on the CPU, but the majority of the computations are on the GPU and acceleration is good. I will add benchmark results here soon.

Authors:
  - Gregory Lee (https://github.com/grlee77)
  - https://github.com/jakirkham

Approvers:
  - https://github.com/jakirkham
  - Gigon Bae (https://github.com/gigony)

URL: #342
  • Loading branch information
grlee77 authored Aug 3, 2022
1 parent 553418b commit c834039
Show file tree
Hide file tree
Showing 4 changed files with 413 additions and 27 deletions.
3 changes: 2 additions & 1 deletion python/cucim/src/cucim/skimage/morphology/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._skeletonize import thin
from ._skeletonize import medial_axis, thin
from .binary import (binary_closing, binary_dilation, binary_erosion,
binary_opening)
from .footprints import (ball, cube, diamond, disk, octagon, octahedron,
Expand Down Expand Up @@ -32,4 +32,5 @@
"remove_small_objects",
"remove_small_holes",
"thin",
"medial_axis",
]
67 changes: 67 additions & 0 deletions python/cucim/src/cucim/skimage/morphology/_medial_axis_lookup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np

# medial axis lookup tables (independent of image content)
#
# Note: lookup table generated using scikit-image code from
# https://github.com/scikit-image/scikit-image/blob/38b595d60befe3a0b4c0742995b9737200a079c6/skimage/morphology/_skeletonize.py#L449-L458 # noqa

lookup_table = np.array(
[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1,
0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0
],
dtype=bool,
)


cornerness_table = np.array(
[
9, 8, 8, 7, 8, 7, 7, 6, 8, 7, 7, 6, 7, 6, 6, 5, 8, 7, 7, 6, 7, 6,
6, 5, 7, 6, 6, 5, 6, 5, 5, 4, 8, 7, 7, 6, 7, 6, 6, 5, 7, 6, 6, 5,
6, 5, 5, 4, 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3, 8, 7,
7, 6, 7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4, 7, 6, 6, 5, 6, 5, 5, 4,
6, 5, 5, 4, 5, 4, 4, 3, 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4,
4, 3, 6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, 8, 7, 7, 6,
7, 6, 6, 5, 7, 6, 6, 5, 6, 5, 5, 4, 7, 6, 6, 5, 6, 5, 5, 4, 6, 5,
5, 4, 5, 4, 4, 3, 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3,
6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, 7, 6, 6, 5, 6, 5,
5, 4, 6, 5, 5, 4, 5, 4, 4, 3, 6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3,
4, 3, 3, 2, 6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, 5, 4,
4, 3, 4, 3, 3, 2, 4, 3, 3, 2, 3, 2, 2, 1, 8, 7, 7, 6, 7, 6, 6, 5,
7, 6, 6, 5, 6, 5, 5, 4, 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4,
4, 3, 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4, 5, 4, 4, 3, 6, 5, 5, 4,
5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, 7, 6, 6, 5, 6, 5, 5, 4, 6, 5,
5, 4, 5, 4, 4, 3, 6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2,
6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, 5, 4, 4, 3, 4, 3,
3, 2, 4, 3, 3, 2, 3, 2, 2, 1, 7, 6, 6, 5, 6, 5, 5, 4, 6, 5, 5, 4,
5, 4, 4, 3, 6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, 6, 5,
5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 3, 2, 5, 4, 4, 3, 4, 3, 3, 2,
4, 3, 3, 2, 3, 2, 2, 1, 6, 5, 5, 4, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3,
3, 2, 5, 4, 4, 3, 4, 3, 3, 2, 4, 3, 3, 2, 3, 2, 2, 1, 5, 4, 4, 3,
4, 3, 3, 2, 4, 3, 3, 2, 3, 2, 2, 1, 4, 3, 3, 2, 3, 2, 2, 1, 3, 2,
2, 1, 2, 1, 1, 0
],
dtype=np.uint8,
)
226 changes: 224 additions & 2 deletions python/cucim/src/cucim/skimage/morphology/_skeletonize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import warnings

import cupy as cp
import cucim.skimage._vendored.ndimage as ndi
import numpy as np

from cucim.core.operations.morphology import distance_transform_edt

from .._shared.utils import check_nD, deprecate_kwarg
from ._medial_axis_lookup import \
cornerness_table as _medial_axis_cornerness_table
from ._medial_axis_lookup import lookup_table as _medial_axis_lookup_table

# --------- Skeletonization and thinning based on Guo and Hall 1989 ---------

Expand Down Expand Up @@ -39,7 +46,7 @@
0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=bool)


@deprecate_kwarg({'max_iter': 'max_num_iter'}, removed_version="23.02.00",
@deprecate_kwarg({"max_iter": "max_num_iter"}, removed_version="23.02.00",
deprecated_version="22.02.00")
def thin(image, max_num_iter=None):
"""
Expand Down Expand Up @@ -131,7 +138,7 @@ def thin(image, max_num_iter=None):
# perform the two "subiterations" described in the paper
for lut in [G123_LUT, G123P_LUT]:
# correlate image with neighborhood mask
N = ndi.correlate(skel, mask, mode='constant')
N = ndi.correlate(skel, mask, mode="constant")
# take deletion decision from this subiteration's LUT
D = cp.take(lut, N)
# perform deletion
Expand All @@ -141,3 +148,218 @@ def thin(image, max_num_iter=None):
num_iter += 1

return skel.astype(bool)


# --------- Skeletonization by medial axis transform --------


def _get_tiebreaker(n, random_seed):
# CuPy generator doesn't currently have the permutation method, so
# fall back to cp.random.permutation instead.
cp.random.seed(random_seed)
if n < 2 << 31:
dtype = np.int32
else:
dtype = np.intp
tiebreaker = cp.random.permutation(cp.arange(n, dtype=dtype))
return tiebreaker


def medial_axis(image, mask=None, return_distance=False, *, random_state=None):
"""Compute the medial axis transform of a binary image.
Parameters
----------
image : binary ndarray, shape (M, N)
The image of the shape to be skeletonized.
mask : binary ndarray, shape (M, N), optional
If a mask is given, only those elements in `image` with a true
value in `mask` are used for computing the medial axis.
return_distance : bool, optional
If true, the distance transform is returned as well as the skeleton.
random_state : {None, int, `numpy.random.Generator`}, optional
If `random_state` is None the `numpy.random.Generator` singleton is
used.
If `random_state` is an int, a new ``Generator`` instance is used,
seeded with `random_state`.
If `random_state` is already a ``Generator`` instance then that
instance is used.
.. versionadded:: 0.19
Returns
-------
out : ndarray of bools
Medial axis transform of the image
dist : ndarray of ints, optional
Distance transform of the image (only returned if `return_distance`
is True)
See Also
--------
skeletonize
Notes
-----
This algorithm computes the medial axis transform of an image
as the ridges of its distance transform.
The different steps of the algorithm are as follows
* A lookup table is used, that assigns 0 or 1 to each configuration of
the 3x3 binary square, whether the central pixel should be removed
or kept. We want a point to be removed if it has more than one neighbor
and if removing it does not change the number of connected components.
* The distance transform to the background is computed, as well as
the cornerness of the pixel.
* The foreground (value of 1) points are ordered by
the distance transform, then the cornerness.
* A cython function is called to reduce the image to its skeleton. It
processes pixels in the order determined at the previous step, and
removes or maintains a pixel according to the lookup table. Because
of the ordering, it is possible to process all pixels in only one
pass.
Examples
--------
>>> square = np.zeros((7, 7), dtype=np.uint8)
>>> square[1:-1, 2:-2] = 1
>>> square
array([[0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
>>> medial_axis(square).astype(np.uint8)
array([[0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0]], dtype=uint8)
"""
try:
from skimage.morphology._skeletonize_cy import _skeletonize_loop
except ImportError as e:
warnings.warn(
"Could not find required private skimage Cython function:\n"
"\tskimage.morphology._skeletonize_cy._skeletonize_loop\n"
)
raise e

if mask is None:
# masked_image is modified in-place later so make a copy of the input
masked_image = image.astype(bool, copy=True)
else:
masked_image = image.astype(bool, copy=True)
masked_image[~mask] = False

# Load precomputed lookup table based on three conditions:
# 1. Keep only positive pixels
# AND
# 2. Keep if removing the pixel results in a different connectivity
# (if the number of connected components is different with and
# without the central pixel)
# OR
# 3. Keep if # pixels in neighborhood is 2 or less
# Note that this table is independent of the image
table = _medial_axis_lookup_table

# Build distance transform
distance = distance_transform_edt(masked_image)
if return_distance:
store_distance = distance.copy()

# Corners
# The processing order along the edge is critical to the shape of the
# resulting skeleton: if you process a corner first, that corner will
# be eroded and the skeleton will miss the arm from that corner. Pixels
# with fewer neighbors are more "cornery" and should be processed last.
# We use a cornerness_table lookup table where the score of a
# configuration is the number of background (0-value) pixels in the
# 3x3 neighborhood
cornerness_table = cp.asarray(_medial_axis_cornerness_table)
corner_score = _table_lookup(masked_image, cornerness_table)

# Define arrays for inner loop
distance = distance[masked_image]
i, j = cp.where(masked_image)

# Determine the order in which pixels are processed.
# We use a random # for tiebreaking. Assign each pixel in the image a
# predictable, random # so that masking doesn't affect arbitrary choices
# of skeletons
tiebreaker = _get_tiebreaker(n=distance.size, random_seed=random_state)
order = cp.lexsort(
cp.stack(
(tiebreaker, corner_score[masked_image], distance),
axis=0
)
)

# Call _skeletonize_loop on the CPU. It requies a single pass over the
# full array using a specific pixel order, so cannot be run multithreaded!
order = cp.asnumpy(order.astype(cp.int32, copy=False))
table = cp.asnumpy(table.astype(cp.uint8, copy=False))
i = cp.asnumpy(i).astype(dtype=np.intp, copy=False)
j = cp.asnumpy(j).astype(dtype=np.intp, copy=False)
result = cp.asnumpy(masked_image)
# Remove pixels not belonging to the medial axis
_skeletonize_loop(result.view(np.uint8), i, j, order, table)
result = cp.asarray(result.view(bool), dtype=bool)

if mask is not None:
result[~mask] = image[~mask]
if return_distance:
return result, store_distance
else:
return result


def _table_lookup(image, table):
"""
Perform a morphological transform on an image, directed by its
neighbors
Parameters
----------
image : ndarray
A binary image
table : ndarray
A 512-element table giving the transform of each pixel given
the values of that pixel and its 8-connected neighbors.
Returns
-------
result : ndarray of same shape as `image`
Transformed image
Notes
-----
The pixels are numbered like this::
0 1 2
3 4 5
6 7 8
The index at a pixel is the sum of 2**<pixel-number> for pixels
that evaluate to true.
"""
#
# We accumulate into the indexer to get the index into the table
# at each point in the image
#
# max possible value of indexer is 512, so just use int16 dtype
kernel = cp.array(
[[256, 128, 64], [32, 16, 8], [4, 2, 1]],
dtype=cp.int16
)
indexer = ndi.convolve(image, kernel, output=np.int16, mode="constant")
image = table[indexer]
return image
Loading

0 comments on commit c834039

Please sign in to comment.