diff --git a/examples/slurmkit_example/slurmkit_cellpose_segmentation.py b/examples/slurmkit_example/slurmkit_cellpose_segmentation.py new file mode 100644 index 00000000..0ec4168b --- /dev/null +++ b/examples/slurmkit_example/slurmkit_cellpose_segmentation.py @@ -0,0 +1,105 @@ +# %% +from iohub import open_ome_zarr +import numpy as np +from mantis.cli import utils +import glob +from natsort import natsorted +from pathlib import Path +from slurmkit import SlurmParams, slurm_function, submit_function +from mantis.analysis.AnalysisSettings import CellposeSegmentationSettings +import click +import datetime + + +# %% +input_dataset_paths = "/hpc/projects/comp.micro/mantis/2023_09_22_A549_0.52NA_illum/4.1-virutal-staining-v2/A549_MitoViewGreen_LysoTracker_W3_FOV5_1_phase_VS.zarr/0/FOV0/0" +config_file = "/hpc/projects/comp.micro/mantis/2023_09_22_A549_0.52NA_illum/4.2-segmentation/segmentation_config.yml" +output_data_path = ( + "./A549_MitoViewGreen_LysoTracker_W3_FOV5_1_phase_VS_segmentation_2.zarr" +) + +# sbatch and resource parameters +partition = "gpu" +cpus_per_task = 16 +mem_per_cpu = "18G" +time = 300 # minutes +simultaneous_processes_per_node = 12 + +input_paths = [Path(path) for path in natsorted(glob.glob(input_dataset_paths))] +output_data_path = Path(output_data_path) +click.echo(f"in: {input_paths}, out: {output_data_path}") +slurm_out_path = str(output_data_path.parent / f"slurm_output/segment2-%j.out") + +settings = utils.yaml_to_model(config_file, CellposeSegmentationSettings) +kwargs = {"cellpose_kwargs": settings.dict()} +print(f"Using settings: {kwargs}") +# %% +with open_ome_zarr(input_paths[0]) as dataset: + T, C, Z, Y, X = dataset.data.shape + channel_names = dataset.channel_names +chunk_zyx_shape = None +channel_names = ["label_nuc", "label_mem"] + +output_metadata = { + "shape": (T, len(channel_names), Z, Y, X), + "chunks": None, + "scale": dataset.scale, + "channel_names": channel_names, + "dtype": np.float32, +} + +utils.create_empty_hcs_zarr( + store_path=output_data_path, + position_keys=[p.parts[-3:] for p in input_paths], + **output_metadata, +) + +# prepare slurm parameters +params = SlurmParams( + partition=partition, + gpus=1, + cpus_per_task=cpus_per_task, + mem_per_cpu=mem_per_cpu, + time=datetime.timedelta(minutes=time), + output=slurm_out_path, +) + +# wrap our utils.process_single_position() function with slurmkit +slurm_process_single_position = slurm_function(utils.process_single_position_v2) +segmentation_func = slurm_process_single_position( + func=utils.nuc_mem_segmentation, + time_indices=list(range(T)), + input_channel_idx=[0, 1], # chanesl in the input dataset + output_channel_idx=[0, 1], # channels in the output dataset + num_processes=simultaneous_processes_per_node, + **kwargs, +) + +# Making batches of jobs to avoid IO overload +slurmkit_array_chunk = 20 +segment_jobs = [] +for i in range(0, len(input_paths), slurmkit_array_chunk): + chunk_input_paths = input_paths[i : i + slurmkit_array_chunk] + + if i == 0: + segment_jobs = [ + submit_function( + segmentation_func, + slurm_params=params, + input_data_path=in_path, + output_path=output_data_path, + ) + for in_path in chunk_input_paths + ] + + else: + segment_jobs = [ + submit_function( + segmentation_func, + slurm_params=params, + input_data_path=in_path, + output_path=output_data_path, + dependencies=segment_jobs, + ) + for in_path in chunk_input_paths + ] diff --git a/mantis/analysis/AnalysisSettings.py b/mantis/analysis/AnalysisSettings.py index ec469072..800c5bb5 100644 --- a/mantis/analysis/AnalysisSettings.py +++ b/mantis/analysis/AnalysisSettings.py @@ -98,3 +98,18 @@ def check_affine_transform_list(cls, v): raise ValueError("Each element in affine_transform_list must be a 4x4 ndarray") return v + + +class Segmentation(MyBaseModel): + diameter: int = None + flow_threshold: float = None + channels: list[int] + do_3D: bool = None + + +class CellposeSegmentationSettings(MyBaseModel): + z_idx: int + mem_model_path: str + membrane_segmentation: Segmentation + nuc_model_path: str + nucleus_segmentation: Segmentation diff --git a/mantis/analysis/settings/example_cellpose_segmentation_settings.yml b/mantis/analysis/settings/example_cellpose_segmentation_settings.yml new file mode 100644 index 00000000..a513fc0d --- /dev/null +++ b/mantis/analysis/settings/example_cellpose_segmentation_settings.yml @@ -0,0 +1,10 @@ +z_idx: 30 # infocus slice +mem_model_path: "cyto2" # default "ctyto2" model works well for cells +membrane_segmentation: + diameter: 200 + flow_threshold: 0.4 + channels: [1, 0] # [membrane, nucleus] membrane channel has to be first + do_3D: false # 3D segmentations only +nuc_model_path: "/hpc/projects/comp.micro/virtual_staining/models/cellpose_models/CP_20220902_NuclFL" +nucleus_segmentation: + channels: [0, 0] \ No newline at end of file diff --git a/mantis/cli/utils.py b/mantis/cli/utils.py index 6a73f449..533d4da1 100644 --- a/mantis/cli/utils.py +++ b/mantis/cli/utils.py @@ -930,6 +930,47 @@ def _check_nan_n_zeros(input_array): return False +def nuc_mem_segmentation(czyx_data, **cellpose_kwargs) -> np.ndarray: + """Segment nuclei and membranes using cellpose""" + + from cellpose import models + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Get the key/values under this dictionary + # cellpose_params = cellpose_params.get('cellpose_params', {}) + cellpose_params = cellpose_kwargs['cellpose_kwargs'] + Z_center_slice = slice(int(cellpose_params['z_idx']), int(cellpose_params['z_idx']) + 1) + Z_slice = slice(int(cellpose_params['z_idx'])-3, int(cellpose_params['z_idx']) + 3) + C, Z, Y, X = czyx_data.shape + + czyx_data_mip = np.zeros((C, 1, Y, X)) + for c in range(C): + czyx_data_mip[c, 0] = np.max(czyx_data[c , Z_slice], axis=0) + cyx_data = czyx_data_mip[:, 0] + + if "nucleus_segmentation" in cellpose_params: + nuc_seg_kwargs = cellpose_params["nucleus_segmentation"] + if "membrane_segmentation" in cellpose_params: + mem_seg_kwargs = cellpose_params["membrane_segmentation"] + + # Initialize Cellpose models + cyto_model = models.Cellpose(gpu=True, model_type=cellpose_params["mem_model_path"]) + nuc_model = models.CellposeModel( + model_type=cellpose_params["nuc_model_path"], device=torch.device(device) + ) + + nuc_masks = nuc_model.eval(cyx_data[0], **nuc_seg_kwargs)[0] + mem_masks, _, _, _ = cyto_model.eval(cyx_data[1], **mem_seg_kwargs) + + # Save + segmentation_stack = np.zeros_like(czyx_data_mip) + zyx_mask = np.stack((nuc_masks, mem_masks)) + segmentation_stack[:, 0:1] = zyx_mask[:, np.newaxis] + + return segmentation_stack + + ## NOTE WIP def apply_transform_to_zyx_and_save_v2( func,