Skip to content

Commit

Permalink
Merge pull request #11 from AllenCell/feature/alignment
Browse files Browse the repository at this point in the history
Feature/alignment
  • Loading branch information
vianamp committed Apr 6, 2024
2 parents d12c5a2 + 15c02df commit 04128cf
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 112 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-and-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c
with:
python-version: 3.8
python-version: 3.9
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
Expand Down
Empty file.
149 changes: 149 additions & 0 deletions aicscytoparam/alignment/generic_2d_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy import spatial as spspatial
from skimage import transform as sktrans
from skimage import measure as skmeasure


class Generic2DShape:
"""
Generic class for 2D shapes
"""

def __init__():
pass

def _compute_contour(self):
"""
Compute the contour of the shape
"""
cont = skmeasure.find_contours(self._polygon)[0]
cx, cy = cont[:, 1], cont[:, 0]
self.cx = cx - cx.mean()
self.cy = cy - cy.mean()
return

def show(self, ax=None):
if ax is None:
fig, ax = plt.subplots()
ax.plot(self.cx, self.cy)
if ax is None:
plt.show()

def find_angle_that_minimizes_countour_distance(self, cx, cy):
"""
Find the angle that minimizes the distance between the
shape and the contour (cx, cy)
Parameters
----------
cx: np.ndarray
x coordinates of the contour
cy: np.ndarray
y coordinates of the contour
Returns
-------
angle: int
angle that minimizes the distance
dist: float
minimum distance
"""
# Assumes cx and cy are centered at origin
dists = []
for theta in range(180):
cx2rot, cy2rot = Generic2DShape.rotate_contour(cx, cy, theta)
X = np.c_[self.cx, self.cy]
Y = np.c_[cx2rot, cy2rot]
D = spspatial.distance.cdist(X, Y)
dist_min = D.min(axis=0).mean() + D.min(axis=1).mean()
dists.append(dist_min)
return np.argmin(dists), np.min(dists)

@staticmethod
def rotate_contour(cx, cy, theta):
"""
Rotate a contour around the origin
Parameters
----------
cx: np.ndarray
x coordinates of the contour
cy: np.ndarray
y coordinates of the contour
theta: float
angle of rotation
Returns
-------
cxrot: np.ndarray
x coordinates of the rotated contour
cyrot: np.ndarray
y coordinates of the rotated contour
"""
cxrot = cx * np.cos(np.deg2rad(theta)) - cy * np.sin(np.deg2rad(theta))
cyrot = cx * np.sin(np.deg2rad(theta)) + cy * np.cos(np.deg2rad(theta))
return cxrot, cyrot

@staticmethod
def get_contour_from_3d_image(image, pad=5, center=True):
"""
Get the contour of a 3D image
Parameters
----------
image: np.ndarray
3D image
pad: int
padding
center: bool
center the contour
Returns
-------
cx: np.ndarray
x coordinates of the contour
cy: np.ndarray
y coordinates of the contour
"""
mip = image.max(axis=0)
y, x = np.where(mip > 0)
mip = np.pad(mip, ((pad, pad), (pad, pad)))
cont = skmeasure.find_contours(mip > 0)[0]
cx, cy = cont[:, 1], cont[:, 0]
if center:
cx = cx - cx.mean()
cy = cy - cy.mean()
return (cx, cy)


class ElongatedHexagonalShape(Generic2DShape):
"""
Elongated hexagonal shape
"""

def __init__(self, base, elongation, pad=5):
self._pad = pad
self._base = base
self._height = int(self._base / np.sqrt(2))
self._elongation_factor = elongation
self._create()
self._compute_contour()

def _create(self):
"""
Create the elongated hexagonal shape
"""
pad = self._pad
triangle = np.tril(np.ones((self._height, self._base)))
triangle = sktrans.rotate(triangle, angle=-15, center=(0, 0), order=0)
rectangle = np.ones((self._height, self._base))
for _ in range(self._elongation_factor):
rectangle = np.concatenate([rectangle, rectangle[:, :1]], axis=1)
upper_half = np.concatenate([triangle[:, ::-1], rectangle, triangle], axis=1)
hexagon = np.concatenate([upper_half, upper_half[::-1]], axis=0)
hexagon = np.pad(hexagon, ((pad, pad), (pad, pad)))
self._polygon = hexagon
return

@staticmethod
def get_default_parameters_as_dict(elongation=8, base_ini=24, base_end=64):
params = []
for wid, w in enumerate(np.linspace(base_ini, base_end, elongation)):
for fid, f in enumerate(np.linspace(0, w, elongation)):
params.append({"base": int(w), "elongation": int(f)})
return params
89 changes: 89 additions & 0 deletions aicscytoparam/alignment/shape_library_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import numpy as np
import matplotlib.pyplot as plt
from aicscytoparam.alignment.generic_2d_shape import Generic2DShape


class ShapeLibrary2D:
"""
Define a library of 2D shapes
"""

def __init__(self):
pass

def set_base_shape(self, polygon):
"""
Set the base shape for the library
Parameters
----------
polygon: Generic2DShape
base shape for the library
"""
self._polygon = polygon

def set_parameters_range(self, params_dict):
"""
Set the parameters range for the library
Parameters
----------
params_dict: dict
dictionary with the parameters range
"""
self._params = params_dict

def find_best_match(self, cx, cy):
"""
Find the best match between the contour (cx, cy) and the shapes in the library
Parameters
----------
cx: np.ndarray
x coordinates of the contour
cy: np.ndarray
y coordinates of the contour
Returns
-------
idx: int
index of the best match
params: dict
parameters of the best match
angle: float
angle that minimizes the distance
"""
angles, dists = [], []
for p in self._params:
poly = self._polygon(**p)
a, d = poly.find_angle_that_minimizes_countour_distance(cx, cy)
angles.append(a)
dists.append(d)
idx = np.argmin(dists)
return idx, self._params[idx], angles[idx]

def display(self, xlim=[-150, 150], ylim=[-50, 50], contours_to_match=None):
"""
Display the shapes in the library
Parameters
----------
xlim: list
x limits of the plot
ylim: list
y limits of the plot
contours_to_match: list of tuples
list of tuples with the contours to match
"""
n = int(np.sqrt(len(self._params)))
fig, axs = plt.subplots(n, n, figsize=(3 * n, 1 * n))
for pid, p in enumerate(self._params):
j, i = pid // n, pid % n
poly = self._polygon(**p)
axs[j, i].plot(poly.cx, poly.cy, lw=7, color="k", alpha=0.2)
axs[j, i].axis("off")
axs[j, i].set_aspect("equal")
axs[j, i].set_xlim(xlim[0], xlim[1])
axs[j, i].set_ylim(ylim[0], ylim[1])
if contours_to_match is not None:
for cx, cy in contours_to_match:
pid, p, angle = self.find_best_match(cx, cy)
cxrot, cyrot = Generic2DShape.rotate_contour(cx, cy, angle)
axs[j, i].plot(cxrot, cyrot, color="magenta")
plt.tight_layout()
plt.show()
112 changes: 1 addition & 111 deletions aicscytoparam/cytoparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import warnings
import numpy as np
from aicsimageio import AICSImage
from typing import Optional, List, Dict
from aicsshparam import shparam, shtools
from scipy import interpolate as spinterp
from typing import Optional, List, Dict, Tuple
from vtk.util.numpy_support import vtk_to_numpy, numpy_to_vtk


Expand Down Expand Up @@ -419,116 +419,6 @@ def get_intensity_representation(polydata: vtk.vtkPolyData, images_to_probe: Lis
return representation


def voxelize_mesh(
imagedata: vtk.vtkImageData, shape: Tuple, mesh: vtk.vtkPolyData, origin: List
):
"""
Voxelize a triangle mesh into an image.
Parameters
--------------------
imagedata: vtkImageData
Imagedata that will be uses as support for voxelization.
shape: tuple
Shape that imagedata scalars will take after
voxelization.
mesh: vtkPolyData
Mesh to be voxelized
origin: List
xyz specifying the lower left corner of the mesh.
Returns
-------
img: np.array
Binary array.
"""

pol2stenc = vtk.vtkPolyDataToImageStencil()
pol2stenc.SetInputData(mesh)
pol2stenc.SetOutputOrigin(origin)
pol2stenc.SetOutputWholeExtent(imagedata.GetExtent())
pol2stenc.Update()

imgstenc = vtk.vtkImageStencil()
imgstenc.SetInputData(imagedata)
imgstenc.SetStencilConnection(pol2stenc.GetOutputPort())
imgstenc.ReverseStencilOff()
imgstenc.SetBackgroundValue(0)
imgstenc.Update()

# Convert scalars from vtkImageData back to numpy
scalars = imgstenc.GetOutput().GetPointData().GetScalars()
img = vtk_to_numpy(scalars).reshape(shape)

return img


def voxelize_meshes(meshes: List):
"""
List of meshes to be voxelized into an image. Usually
the input corresponds to the cell membrane and nuclear
shell meshes.
Parameters
--------------------
meshes: List
List of vtkPolydatas representing the meshes to
be voxelized into an image.
Returns
-------
img: np.array
3D image where voxels with value i represent are
those found in the interior of the i-th mesh in
the input list. If a voxel is interior to one or
more meshes form the input list, it will take the
value of the right most mesh in the list.
origin:
Origin of the meshes in the voxelized image.
"""

# 1st mesh is used as reference (cell) and it should be
# the larger than the 2nd one (nucleus).
mesh = meshes[0]

# Find mesh coordinates
coords = vtk_to_numpy(mesh.GetPoints().GetData())

# Find bounds of the mesh
rmin = (coords.min(axis=0) - 0.5).astype(int)
rmax = (coords.max(axis=0) + 0.5).astype(int)

# Width, height and depth
w = int(2 + (rmax[0] - rmin[0]))
h = int(2 + (rmax[1] - rmin[1]))
d = int(2 + (rmax[2] - rmin[2]))

# Create image data
imagedata = vtk.vtkImageData()
imagedata.SetDimensions([w, h, d])
imagedata.SetExtent(0, w - 1, 0, h - 1, 0, d - 1)
imagedata.SetOrigin(rmin)
imagedata.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, 1)

# Set all values to 1
imagedata.GetPointData().GetScalars().FillComponent(0, 1)

# Create an empty 3D numpy array to sum up
# voxelization of all meshes
img = np.zeros((d, h, w), dtype=np.uint8)

# Voxelize one mesh at the time
for mid, mesh in enumerate(meshes):
seg = voxelize_mesh(
imagedata=imagedata, shape=(d, h, w), mesh=mesh, origin=rmin
)
img[seg > 0] = mid + 1

# Origin of the reference system in the image
origin = rmin.reshape(1, 3)

return img, origin


def morph_representation_on_shape(
img: np.array, param_img_coords: np.array, representation: np.array
):
Expand Down
Loading

0 comments on commit 04128cf

Please sign in to comment.