Skip to content

Commit

Permalink
Merge pull request #49 from Hendrik-code/poi_prediction
Browse files Browse the repository at this point in the history
Poi prediction
  • Loading branch information
Hendrik-code authored Oct 25, 2024
2 parents 1d5ba11 + c11ccab commit 5312854
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 15 deletions.
18 changes: 17 additions & 1 deletion TPTBox/core/nii_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
np_calc_convex_hull,
np_calc_overlapping_labels,
np_center_of_mass,
np_compute_surface,
np_connected_components,
np_dilate_msk,
np_erode_msk,
np_fill_holes,
np_get_connected_components_center_of_mass,
np_get_largest_k_connected_components,
np_map_labels,
np_point_coordinates,
np_unique,
np_unique_withoutzero,
np_volume,
Expand Down Expand Up @@ -1164,6 +1166,20 @@ def get_largest_k_segmentation_connected_components(self, k: int | None, labels:
"""
return self.set_array(np_get_largest_k_connected_components(self.get_seg_array(), k=k, label_ref=labels, connectivity=connectivity, return_original_labels=return_original_labels))

def compute_surface_mask(self, connectivity: int, dilated_surface: bool = False):
""" Removes everything but surface voxels
Args:
connectivity (int): Connectivity for surface calculation
dilated_surface (bool): If False, will return msk - eroded mask. If true, will return dilated msk - msk
"""
return self.set_array(np_compute_surface(self.get_seg_array(), connectivity=connectivity, dilated_surface=dilated_surface))


def compute_surface_points(self, connectivity: int, dilated_surface: bool = False):
surface = self.compute_surface_mask(connectivity, dilated_surface)
return np_point_coordinates(surface)


def get_segmentation_difference_to(self, mask_gt: Self, ignore_background_tp: bool = False) -> Self:
"""Calculates an NII that represents the segmentation difference between self and given groundtruth mask
Expand Down Expand Up @@ -1275,8 +1291,8 @@ def save(self,file:str|Path|bids_files.BIDS_FILE,make_parents=True,verbose:loggi
out.set_data_dtype(np.uint16)
else:
out.set_data_dtype(np.int32)
log.print(f"Save {file} as {out.get_data_dtype()}",verbose=verbose,ltype=Log_Type.SAVE)
nib.save(out, file) #type: ignore
log.print(f"Save {file} as {out.get_data_dtype()}",verbose=verbose,ltype=Log_Type.SAVE)
def __str__(self) -> str:
return f"shp={self.shape}; ori={self.orientation}, zoom={tuple(np.around(self.zoom, decimals=2))}, seg={self.seg}" # type: ignore
def __repr__(self)-> str:
Expand Down
79 changes: 79 additions & 0 deletions TPTBox/core/np_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def np_extract_label(arr: np.ndarray, label: int, to_label: int = 1, inplace: bo
if not inplace:
arr = arr.copy()

# TODO support label being a sequence

if label != 0:
arr[arr != label] = 0
arr[arr == label] = to_label
Expand Down Expand Up @@ -334,6 +336,10 @@ def np_map_labels(arr: UINTARRAY, label_map: LABEL_MAP) -> np.ndarray:
k = np.array(list(label_map.keys()))
v = np.array(list(label_map.values()))

assert len(k) == len(v)
if len(k) == 0:
return arr

max_value = max(arr.max(), *k, *v) + 1

mapping_ar = np.arange(max_value, dtype=arr.dtype)
Expand Down Expand Up @@ -426,6 +432,25 @@ def np_bbox_binary(img: np.ndarray, px_dist: int | Sequence[int] | np.ndarray =


def np_center_of_bbox_binary(img: np.ndarray, px_dist: int | Sequence[int] | np.ndarray = 0):
"""Calculates the center coordinates of the bounding box around non-zero regions in a binary image.
This function determines the bounding box of non-zero regions in a binary image,
optionally expanding it by a specified pixel distance. It then computes and returns
the center coordinates of each dimension of the bounding box.
Args:
img (np.ndarray): A binary image represented as a NumPy array, where non-zero values indicate
points of interest.
px_dist (int | Sequence[int] | np.ndarray, optional): The pixel distance by which to expand
the bounding box in each dimension. Can be a single integer or a sequence of integers
corresponding to each dimension. Default is 0, meaning no expansion.
Returns:
list[int]: A list of center coordinates for each dimension of the bounding box.
Raises:
ValueError: If the input image is empty or not a valid binary array.
"""
bbox_nd = np_bbox_binary(img, px_dist=px_dist)
ctd_bbox = []
for i in range(len(bbox_nd)):
Expand Down Expand Up @@ -484,6 +509,60 @@ def np_find_index_of_k_max_values(arr: np.ndarray, k: int = 2) -> list[int]:
return list(indices)


def np_compute_surface(
arr: UINTARRAY,
connectivity: int = 3,
dilated_surface: bool = False,
):
"""Computes the surface of a binary array based on connectivity and dilation options.
This function identifies the surface voxels of a binary array. If `dilated_surface`
is True, it computes a dilated surface by expanding the array and subtracting the
original. Otherwise, it computes a contracted surface by eroding the array and
subtracting the result from the original.
Args:
arr (UINTARRAY): A binary array representing the segmentation or mask.
connectivity (int, optional): The connectivity used to define neighbors for
surface computation, where 1 represents face-connectivity, and 3 represents
full connectivity. Default is 3.
dilated_surface (bool, optional): Whether to compute a dilated surface. If True,
expands the surface; if False, contracts the surface. Default is False.
Returns:
UINTARRAY: An array representing the computed surface voxels.
"""
assert 1 <= connectivity <= 3, f"expected connectivity in [1,3], but got {connectivity}"
if dilated_surface:
return np_dilate_msk(arr, mm=1, connectivity=connectivity) - arr
else:
return arr - np_erode_msk(arr, mm=1, connectivity=connectivity)


def np_point_coordinates(
arr: UINTARRAY,
):
"""Extracts the coordinates of non-zero points from a 3D binary array.
This function locates all non-zero voxels within a 3D binary array and returns
their coordinates as a list of tuples.
Args:
arr (UINTARRAY): A 3-dimensional binary array representing the segmentation or mask.
Returns:
list[tuple[int, int, int]]: A list of (X, Y, Z) coordinate tuples for each non-zero
point in the array.
Raises:
AssertionError: If the input array does not have three dimensions.
"""
assert arr.ndim == 3, arr.ndim
x, y, z = np.where(arr)
surface_points = [(x[i], y[i], z[i]) for i in range(len(x))]
return surface_points


def np_connected_components(
arr: UINTARRAY,
connectivity: int = 3,
Expand Down
37 changes: 37 additions & 0 deletions TPTBox/core/poi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,3 +1717,40 @@ def calc_centroids(
else:
ctd_list[int(i), subreg_id] = tuple(round(x, decimals) for x in ctr_mass)
return POI(ctd_list, orientation=axc, **msk_nii._extract_affine(rm_key=["orientation"]))


######## Utility #######


def calc_poi_average(
pois: list[POI],
keep_points_not_present_in_all_pois: bool = False,
) -> POI:
"""Calculates average of POI across list of POIs and removes all points that are not fully present in all given POIs
Args:
pois (list[POI]): _description_
Returns:
POI: _description_
"""
# Get the keys that are present in all POIs
keys = set(pois[0].keys())
for ctd in pois:
keys = keys.union(set(ctd.keys())) if keep_points_not_present_in_all_pois else keys.intersection(set(ctd.keys()))
keys = list(keys)

# Make average array
ctd = {}
for key in keys:
ctd[key] = tuple(np.array([reg_ctd[key] for reg_ctd in pois if key in reg_ctd]).mean(axis=0))

# Sort the new ctd by keys
ctd = dict(sorted(ctd.items()))
return POI(
centroids=ctd,
orientation=pois[0].orientation,
zoom=pois[0].zoom,
shape=pois[0].shape,
rotation=pois[0].rotation,
)
61 changes: 60 additions & 1 deletion TPTBox/core/poi_fun/pixel_based_point_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from TPTBox import NII, POI, Logger_Interface, Print_Logger
from TPTBox.core.poi_fun._help import to_local_np
from TPTBox.core.poi_fun.vertebra_direction import _get_sub_array_by_direction, get_direction, get_vert_direction_matrix
from TPTBox.core.vert_constants import DIRECTIONS, Location
from TPTBox.core.vert_constants import COORDINATE, DIRECTIONS, Location

_log = Print_Logger()

Expand Down Expand Up @@ -163,3 +163,62 @@ def get_extreme_point_by_vert_direction(

idx = np.argmax(sum(a))
return pc[:, idx]


def project_pois_onto_segmentation_surface(
poi: POI,
seg: NII,
connectivity: int = 1,
dilated_surface: bool = False,
) -> POI:
"""Projects points of interest (POI) onto a segmentation surface.
This function computes the surface points of a segmentation volume and
projects the given points of interest (POI) onto these computed surface points.
Args:
poi (POI): The points of interest to be projected.
seg (NII): A segmentation volume object containing the target surface.
connectivity (int, optional): The connectivity level for defining the surface.
Default is 1, where 1 denotes face-connectivity.
dilated_surface (bool, optional): Whether to compute a dilated version of the
surface, expanding the surface area. Default is False.
Returns:
POI: The points of interest projected onto the segmentation surface.
"""
point_set = seg.compute_surface_points(
connectivity=connectivity,
dilated_surface=dilated_surface,
)
return project_pois_onto_set_of_points(poi, point_set)


def project_pois_onto_set_of_points(poi: POI, point_set: list[COORDINATE]) -> POI:
"""Projects points of interest (POI) onto the nearest points in a given set.
For each point in the POI, this function finds the closest point in the
provided point set and updates the POI coordinates to align with the nearest points.
Args:
poi (POI): The points of interest to be projected.
point_set (list[COORDINATE]): A list of coordinates representing the target points
for projection.
Returns:
POI: The updated POI with coordinates projected onto the nearest points in
the provided point set.
"""
poi_n = poi.copy()
point_arr = np.asarray(point_set)

for r, s, c in poi.items():
distance_to_point = cdist_to_point(c, point_arr)
new_coord = point_arr[np.argmin(distance_to_point)]
poi_n[r, s] = new_coord

return poi_n


def cdist_to_point(point, a):
return cdist([point], a)[0]
16 changes: 16 additions & 0 deletions TPTBox/core/vert_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,22 @@ def vert_subreg_labels(with_border: bool = True) -> list[Location]:

# fmt: on

"""
Abbreviations:
- SSL: Supraspinous Ligament
- ALL: Anterior Longitudinal Ligament
- PLL: Posterior Longitudinal Ligament
- FL: Flavum Ligament
- ISL: Interspinous Ligament
- ITL: Intertransverse Ligament
- CR: Cranial / Superior
- CA: Caudal / Inferior
- S: Sinistra / Left
- D: Dextra / Right
"""
# TODO clean this shit up (some values not defined in Location, some different values I think)
conversion_poi = {
"SSL": 81, # this POI is not included in our POI list
"ALL_CR_S": 109,
Expand Down
46 changes: 33 additions & 13 deletions TPTBox/mesh3D/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def load(cls, filepath: str | Path):
return Mesh3D(mesh)

def show(self):
print("1")
pv.start_xvfb()
pl = pv.Plotter()
pl.set_background("black", top=None)
Expand All @@ -58,7 +57,6 @@ def show(self):
pv.global_theme.interactive = True

pl.add_mesh(self.mesh)
print(2)
pl.show()


Expand All @@ -75,24 +73,36 @@ def __init__(self, int_arr: np.ndarray | Image_Reference) -> None:
if np.issubdtype(int_arr.dtype, np.floating):
print("input is of type float, converting to int")
int_arr.astype(np.uint16)
print("calculate bounding box cutout")
# calculate bounding box cutout
bbox_crop = np_bbox_binary(int_arr, px_dist=2)
x1, y1, z1 = bbox_crop[0].start, bbox_crop[1].start, bbox_crop[2].start
arr_cropped = int_arr[bbox_crop]

vertices, faces, normals, values = marching_cubes(arr_cropped, gradient_direction="ascent", step_size=1)
self._faces = faces
self._normals = normals
self._values = values
self._x1 = x1
self._y1 = y1
self._z1 = z1
# make vertices
vertices += (x1, y1, z1) # so it has correct relative coordinates (not world coordinates!)
print("column_stack")

self._vertices = vertices
vfaces = np.column_stack((np.ones(len(faces)) * 3, faces)).astype(int)

mesh = pv.PolyData(vertices, vfaces)
mesh = pv.PolyData(self._vertices, vfaces)
mesh["Normals"] = normals
mesh["values"] = values
print("mesh")
self.mesh = mesh

def get_mesh_with_offset(self, offset: tuple[float, float, float]):
vertices = self._vertices + offset
vfaces = np.column_stack((np.ones(len(self._faces)) * 3, self._faces)).astype(int)

mesh = pv.PolyData(vertices, vfaces)
mesh["Normals"] = self._normals
mesh["values"] = self._values
return mesh

@classmethod
def from_segmentation_nii(cls, seg_nii: NII, rescale_to_iso: bool = True):
assert seg_nii.seg, "NII is not a segmentation"
Expand All @@ -110,7 +120,7 @@ def __init__(
rescale_to_iso: bool = True,
regions: list[int] | None = None,
subregions: list[int] | None = None,
size_factor: int = 5,
size_factor: float = 5,
) -> None:
poi.reorient_()
if rescale_to_iso:
Expand All @@ -122,18 +132,28 @@ def __init__(
if subregions is None:
subregions = poi.keys_subregion()

poi_extracted: list[COORDINATE] = []
self.poi_extracted: list[COORDINATE] = []
self.size_factor = size_factor

for r_id, s_id, coord in poi.items():
if r_id in regions and s_id in subregions:
poi_extracted.append(coord)
self.poi_extracted.append(coord)

n = pv.PolyData(poi_extracted)
n["radius"] = np.ones(shape=len(poi_extracted)) * size_factor
assert len(self.poi_extracted) > 0, "no POIs present"
n = pv.PolyData(self.poi_extracted)
n["radius"] = np.ones(shape=len(self.poi_extracted)) * size_factor
geom = pv.Sphere(theta_resolution=8, phi_resolution=8)
glyphed = n.glyph(scale="radius", geom=geom, progress_bar=False, orient=False)
self.mesh = glyphed

def get_mesh_with_offset(self, offset: tuple[float, float, float]):
pois_shifted = [(x + offset[0], y + offset[1], z + offset[2]) for x, y, z in self.poi_extracted]
n = pv.PolyData(pois_shifted)
n["radius"] = np.ones(shape=len(pois_shifted)) * self.size_factor
geom = pv.Sphere(theta_resolution=8, phi_resolution=8)
glyphed = n.glyph(scale="radius", geom=geom, progress_bar=False, orient=False)
return glyphed


if __name__ == "__main__":
p = "/media/hendrik/be5e95dd-27c8-4c31-adc5-7b75f8ebd5c5/data/hendrik/test_samples/bids_mesh/"
Expand Down
Loading

0 comments on commit 5312854

Please sign in to comment.