Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distance label trafos #177

Merged
merged 3 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions test/transform/test_label_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def affs_brute_force_with_mask(seg, offsets, mask_bg_transition=True):
class TestLabelTransforms(unittest.TestCase):
def get_labels(self, with_zero):
shape = (64, 64)
# shape = (6, 6)
labels = np.random.randint(1, 6, size=shape).astype("uint64")
if with_zero:
bg_prob = 0.25
Expand Down Expand Up @@ -132,14 +131,14 @@ def test_distance_transform(self):
self.assertTrue((tnew >= 0).all())
self.assertTrue((tnew <= 5).all())

trafo = DistanceTransform(normalize=False, vector_distances=True)
trafo = DistanceTransform(normalize=False, directed_distances=True)
tnew = trafo(target)
self.assertEqual(tnew.shape, (3,) + target.shape)
distances, vector_distances = tnew[0], tnew[1:]
abs_dist = np.linalg.norm(vector_distances, axis=0)
self.assertTrue(np.allclose(distances, abs_dist))

trafo = DistanceTransform(normalize=True, vector_distances=True)
trafo = DistanceTransform(normalize=True, directed_distances=True)
tnew = trafo(target)
self.assertEqual(tnew.shape, (3,) + target.shape)
self.assertTrue((tnew >= -1).all())
Expand Down Expand Up @@ -169,6 +168,45 @@ def test_distance_transform_empty_labels(self):
tnew = trafo(target)
self.assertTrue(np.allclose(tnew, 1.0))

def test_per_object_distance_transform(self):
from torch_em.transform.label import PerObjectDistanceTransform
from skimage.data import binary_blobs
from skimage.measure import label

labels = label(binary_blobs(256, volume_fraction=0.25))

trafo = PerObjectDistanceTransform(
distances=True, boundary_distances=False, directed_distances=False, foreground=False,
)
result = trafo(labels)
self.assertEqual(result.shape, (1,) + labels.shape)
self.assertGreaterEqual(result.min(), 0)
self.assertLessEqual(result.max(), 1)

trafo = PerObjectDistanceTransform(
distances=False, boundary_distances=True, directed_distances=False, foreground=False,
)
result = trafo(labels)
self.assertEqual(result.shape, (1,) + labels.shape)
self.assertGreaterEqual(result.min(), 0)
self.assertLessEqual(result.max(), 1)

trafo = PerObjectDistanceTransform(
distances=False, boundary_distances=False, directed_distances=True, foreground=False,
)
result = trafo(labels)
self.assertEqual(result.shape, (2,) + labels.shape)
self.assertGreaterEqual(result.min(), -1)
self.assertLessEqual(result.max(), 1)

trafo = PerObjectDistanceTransform(
distances=True, boundary_distances=True, directed_distances=False, foreground=True,
)
result = trafo(labels)
self.assertEqual(result.shape, (3,) + labels.shape)
self.assertGreaterEqual(result.min(), 0)
self.assertLessEqual(result.max(), 1)


if __name__ == "__main__":
unittest.main()
229 changes: 189 additions & 40 deletions torch_em/transform/label.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import skimage.measure
import skimage.segmentation
from scipy.ndimage import distance_transform_edt
import vigra

from ..util import ensure_array, ensure_spatial_array

Expand Down Expand Up @@ -192,25 +192,32 @@ def __call__(self, labels):


class DistanceTransform:
"""Compute distances to foreground.
"""
eps = 1e-7

def __init__(
self,
distances=True, vector_distances=False,
normalize=True, max_distance=None,
foreground_id=1, invert=False, func=None
distances=True,
directed_distances=False,
normalize=True,
max_distance=None,
foreground_id=1,
invert=False,
func=None
):
if sum((distances, vector_distances)) == 0:
raise ValueError("At least one of 'distances' or 'vector_distances' must be set to 'True'")
self.vector_distances = vector_distances
if sum((distances, directed_distances)) == 0:
raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
self.directed_distances = directed_distances
self.distances = distances
self.normalize = normalize
self.max_distance = max_distance
self.foreground_id = foreground_id
self.invert = invert
self.func = func

def _compute_distances(self, distances):
def _compute_distances(self, directed_distances):
distances = np.linalg.norm(directed_distances, axis=0)
if self.max_distance is not None:
distances = np.clip(distances, 0, self.max_distance)
if self.normalize:
Expand All @@ -221,54 +228,196 @@ def _compute_distances(self, distances):
distances = self.func(distances)
return distances

def _compute_vector_distances(self, indices):
coordinates = np.indices(indices.shape[1:]).astype("float32")
vector_distances = indices - coordinates
def _compute_directed_distances(self, directed_distances):
if self.max_distance is not None:
vector_distances = np.clip(vector_distances, -self.max_distance, self.max_distance)
directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance)
if self.normalize:
vector_distances /= (np.abs(vector_distances).max(axis=(1, 2), keepdims=True) + self.eps)
directed_distances /= (np.abs(directed_distances).max(axis=(1, 2), keepdims=True) + self.eps)
if self.invert:
vector_distances = vector_distances.max(axis=(1, 2), keepdims=True) - vector_distances
directed_distances = directed_distances.max(axis=(1, 2), keepdims=True) - directed_distances
if self.func is not None:
vector_distances = self.func(vector_distances)
return vector_distances
directed_distances = self.func(directed_distances)
return directed_distances

def _get_distances_for_empty_labels(self, labels):
shape = labels.shape
fill_value = 0.0 if self.invert else np.linalg.norm(list(shape))
if self.distances and self.vector_distances:
data = (np.full(shape, fill_value), np.full((labels.ndim,) + shape, fill_value))
elif self.distances:
data = np.full(shape, fill_value)
elif self.vector_distances:
data = np.full((labels.ndim,) + shape, fill_value)
else:
raise RuntimeError
fill_value = 0.0 if self.invert else np.sqrt(np.linalg.norm(list(shape)) ** 2 / 2)
data = np.full((labels.ndim,) + shape, fill_value)
return data

def __call__(self, labels):
distance_mask = labels != self.foreground_id
distance_mask = (labels == self.foreground_id).astype("uint32")
# the distances are not computed corrected if they are all zero
# so this case needs to be handled separately
if distance_mask.sum() == distance_mask.size:
data = self._get_distances_for_empty_labels(labels)
if distance_mask.sum() == 0:
directed_distances = self._get_distances_for_empty_labels(labels)
else:
data = distance_transform_edt(distance_mask,
return_distances=self.distances,
return_indices=self.vector_distances)
ndim = distance_mask.ndim
to_channel_first = (ndim,) + tuple(range(ndim))
directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first)

if self.distances:
distances = data[0] if self.vector_distances else data
distances = self._compute_distances(distances)
distances = self._compute_distances(directed_distances)

if self.vector_distances:
indices = data[1] if self.distances else data
vector_distances = self._compute_vector_distances(indices)
if self.directed_distances:
directed_distances = self._compute_directed_distances(directed_distances)

if self.distances and self.vector_distances:
return np.concatenate((distances[None], vector_distances), axis=0)
if self.distances and self.directed_distances:
return np.concatenate((distances[None], directed_distances), axis=0)
if self.distances:
return distances
if self.vector_distances:
return vector_distances
if self.directed_distances:
return directed_distances


class PerObjectDistanceTransform:
"""Compute normalized distances per object in a segmentation.
"""
eps = 1e-7

def __init__(
self,
distances=True,
boundary_distances=True,
directed_distances=False,
foreground=True,
apply_label=True,
correct_centers=True,
min_size=0,
distance_fill_value=1.0,
):
if sum([distances, directed_distances, boundary_distances]) == 0:
raise ValueError("At least one of distances or directed distances has to be passed.")
self.distances = distances
self.boundary_distances = boundary_distances
self.directed_distances = directed_distances
self.foreground = foreground

self.apply_label = apply_label
self.correct_centers = correct_centers
self.min_size = min_size
self.distance_fill_value = distance_fill_value

def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
# Crop the mask and generate array with the correct center.
cropped_mask = mask[bb]
cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))

# The centroid might not be inside of the object.
# In this case we correct the center by taking the maximum of the distance to the boundary.
# Note: the centroid is still the best estimate for the center, as long as it's in the object.
correct_center = not cropped_mask[cropped_center]

# Compute the boundary distances if necessary.
# (Either if we need to correct the center, or compute the boundary distances anyways.)
if correct_center or self.boundary_distances:
# Crop the boundary mask and compute the boundary distances.
cropped_boundary_mask = boundaries[bb]
boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
boundary_distances[~cropped_mask] = 0
max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)

# Set the crop center to the max dist point
if correct_center:
# Find the center (= maximal distance from the boundaries).
cropped_center = max_dist_point

cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
cropped_center_mask[cropped_center] = 1

# Compute the directed distances,
if self.distances or self.directed_distances:
this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
else:
this_distances = None

# Keep only the specified distances:
if self.distances and self.directed_distances: # all distances
# Compute the undirected ditacnes from directed distances and concatenate,
undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
this_distances = np.concatenate([undir, this_distances], axis=-1)

elif self.distances: # only undirected distances
# Compute the undirected distances from directed distances and keep only them.
this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)

elif self.directed_distances: # only directed distances
pass # We don't have to do anything becasue the directed distances are already computed.

# Add an extra channel for the boundary distances if specified.
if self.boundary_distances:
boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
if this_distances is None:
this_distances = boundary_distances
else:
this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)

# Set distances outside of the mask to zero.
this_distances[~cropped_mask] = 0

# Normalize the distances.
spatial_axes = tuple(range(mask.ndim))
this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)

# Set the distance values in the global result.
distances[bb][cropped_mask] = this_distances[cropped_mask]

return distances

def __call__(self, labels):
# Apply label (connected components) if specified.
if self.apply_label:
labels = skimage.measure.label(labels).astype("uint32")
else: # Otherwise just relabel the segmentation.
labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

# Filter out small objects if min_size is specified.
if self.min_size > 0:
ids, sizes = np.unique(labels, return_counts=True)
discard_ids = ids[sizes < self.min_size]
labels[np.isin(labels, discard_ids)] = 0
labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

# Compute the boundaries. They will be used to determine the most central point,
# and if 'self.boundary_distances is True' to add the boundary distances.
boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32")

# Compute region properties to derive bounding boxes and centers.
ndim = labels.ndim
props = skimage.measure.regionprops(labels)
bounding_boxes = {
prop.label: tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim))
for prop in props
}

# Compute the object centers from centroids.
centers = {prop.label: np.round(prop.centroid).astype("int") for prop in props}

# Compute how many distance channels we have.
n_channels = 0
if self.distances: # We need one channel for the overall distances.
n_channels += 1
if self.boundary_distances: # We need one channel for the boundary distances.
n_channels += 1
if self.directed_distances: # And ndim channels for directed distances.
n_channels += ndim

# Compute the per object distances.
distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
for prop in props:
label_id = prop.label
mask = labels == label_id
distances = self.compute_normalized_object_distances(
mask, boundaries, bounding_boxes[label_id], centers[label_id], distances
)

# Bring the distance channel to the first dimension.
to_channel_first = (ndim,) + tuple(range(ndim))
distances = distances.transpose(to_channel_first)

# Add the foreground mask as first channel if specified.
if self.foreground:
binary_labels = (labels > 0).astype("float32")
distances = np.concatenate([binary_labels[None], distances], axis=0)

return distances
Loading