From 65f13e00dc78abff622ca906803965973fc4582c Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 15 Jul 2024 21:39:23 -0700 Subject: [PATCH] adding segmentation functions CZYX for neuromast and cells --- viscy/analysis/segmentation.py | 118 +++++++++++++++++++++++ viscy/analysis/settings/segmentation.yml | 24 +++++ 2 files changed, 142 insertions(+) create mode 100644 viscy/analysis/segmentation.py create mode 100644 viscy/analysis/settings/segmentation.yml diff --git a/viscy/analysis/segmentation.py b/viscy/analysis/segmentation.py new file mode 100644 index 00000000..b455bcf3 --- /dev/null +++ b/viscy/analysis/segmentation.py @@ -0,0 +1,118 @@ +import torch +import numpy as np +import click +from cellpose import models +from skimage.exposure import rescale_intensity, equalize_adapthist +from skimage.util import invert +from numpy.typing import ArrayLike + + +def nuc_mem_segmentation_cellposemodel_3D( + czyx_data: ArrayLike, zyx_slicing: tuple[slice, slice, slice], **cellpose_kwargs +): + """ + Segment nuclei and membranes using Cellpose 3D model. + + """ + + Z_slice = zyx_slicing[0] + Y_slice = zyx_slicing[1] + X_slice = zyx_slicing[2] + czyx_data = czyx_data[:, Z_slice, Y_slice, X_slice] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + segmentation_stack = np.zeros_like(czyx_data) + click.echo(f"Segmentation Stack shape {segmentation_stack.shape}") + cellpose_params = cellpose_kwargs["cellpose_kwargs"] + c_idx = 0 + if "nucleus_kwargs" in cellpose_params: + click.echo("Segmenting Nuclei") + nuc_seg_kwargs = cellpose_params["nucleus_kwargs"] + + model_nucleus_3D = models.CellposeModel( + model_type=cellpose_params["nuc_model_path"], + # net_avg=True, #Note removed CP3.0 + gpu=True, + device=torch.device(device), + ) + nuc_segmentation, _, _ = model_nucleus_3D.eval(czyx_data, **nuc_seg_kwargs) + segmentation_stack[c_idx] = nuc_segmentation.astype(np.uint16) + c_idx += 1 + if "membrane_kwargs" in cellpose_params: + click.echo("Segmenting Membrane") + mem_seg_kwargs = cellpose_params["membrane_kwargs"] + + model_membrane_3D = models.CellposeModel( + model_type=cellpose_params["mem_model_path"], + # net_avg=True, + gpu=True, + device=torch.device(device), + ) + c_idx_mem, c_idx_nuc = mem_seg_kwargs["channels"] + mem_segmentation, _, _ = model_membrane_3D.eval(czyx_data, **mem_seg_kwargs) + segmentation_stack[c_idx] = mem_segmentation.astype(np.uint16) + + return segmentation_stack + + +def nuc_mem_cp_segmentation_clahe_3D( + czyx_data: ArrayLike, zyx_slicing: tuple, clahe_kwargs, **cellpose_kwargs +): +""" + Segment nuclei and membranes using Cellpose 3D model with CLAHE applied to the input data. +""" + + Z_slice = zyx_slicing[0] + Y_slice = zyx_slicing[1] + X_slice = zyx_slicing[2] + czyx_data = czyx_data[:, Z_slice, Y_slice, X_slice] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + segmentation_stack = np.zeros_like(czyx_data, dtype=np.uint16) + click.echo(f"Segmentation Stack shape {segmentation_stack.shape}") + cellpose_params = cellpose_kwargs["cellpose_kwargs"] + # clahe_kwargs = clahe_kwargs['clahe'] + c_idx = 0 + if "nucleus_kwargs" in cellpose_params: + click.echo("Segmenting Nuclei") + nuc_seg_kwargs = cellpose_params["nucleus_kwargs"] + + model_nucleus_3D = models.CellposeModel( + model_type=cellpose_params["nuc_model_path"], + # net_avg=True, #Note removed CP3.0 + gpu=True, + device=torch.device(device), + ) + # Apply CLAHE before cellpose + if "clahe_nuc" in clahe_kwargs: + click.echo("Applying CLAHE to Nuclei") + nuc_clahe = clahe_kwargs["clahe_nuc"] + czyx_data[c_idx] = rescale_intensity(czyx_data[c_idx], out_range=(0.0, 1.0)) + czyx_data[c_idx] = equalize_adapthist(czyx_data[c_idx], **nuc_clahe) + nuc_segmentation, _, _ = model_nucleus_3D.eval(czyx_data, **nuc_seg_kwargs) + segmentation_stack[c_idx] = nuc_segmentation.astype(np.uint16) + c_idx += 1 + if "membrane_kwargs" in cellpose_params: + click.echo("Segmenting Membrane") + mem_seg_kwargs = cellpose_params["membrane_kwargs"] + + if "clahe_mem" in clahe_kwargs: + click.echo("Applying CLAHE to Membrane") + mem_clahe = clahe_kwargs["clahe_mem"] + czyx_data[c_idx] = rescale_intensity( + invert(czyx_data[c_idx]), out_range=(0.0, 1.0) + ) + czyx_data[c_idx] = equalize_adapthist(czyx_data[c_idx], **mem_clahe) + model_membrane_3D = models.CellposeModel( + model_type=cellpose_params["mem_model_path"], + # net_avg=True, + gpu=True, + device=torch.device(device), + ) + c_idx_mem, c_idx_nuc = mem_seg_kwargs["channels"] + mem_segmentation, _, _ = model_membrane_3D.eval(czyx_data, **mem_seg_kwargs) + segmentation_stack[c_idx] = mem_segmentation.astype(np.uint16) + + return segmentation_stack diff --git a/viscy/analysis/settings/segmentation.yml b/viscy/analysis/settings/segmentation.yml new file mode 100644 index 00000000..3bd46872 --- /dev/null +++ b/viscy/analysis/settings/segmentation.yml @@ -0,0 +1,24 @@ +mem_model_path: "/hpc/projects/jacobo_group/Code/timelapse_seg_tracking_pipeline/3_segmentation/membrane/cellpose_2Chan_scratch_2024_04_30_11_12_00" +membrane_kwargs: + diameter: 65 + channels: + - 2 + - 1 + cellprob_threshold: 0.4 + invert: false + do_3D: true + anisotropy: 3.26 + min_size: 8000 + +nuc_model_path: "/hpc/projects/jacobo_group/projects/cellpose/Nuclei/Deconvolved/Fine_Tune/models/cellpose_Slices_decon_nuclei_nuclei_v7_2023_06_28_16_54" +nucleus_kwargs: + diameter: 60 + channels: + - 1 + - 0 + cellprob_threshold: 0.0 + invert: false + do_3D: true + anisotropy: 3.26 + min_size: 8000 +