Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segmentation (no CLI) #106

Closed
wants to merge 8 commits into from
105 changes: 105 additions & 0 deletions examples/slurmkit_example/slurmkit_cellpose_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# %%
from iohub import open_ome_zarr
import numpy as np
from mantis.cli import utils
import glob
from natsort import natsorted
from pathlib import Path
from slurmkit import SlurmParams, slurm_function, submit_function
from mantis.analysis.AnalysisSettings import CellposeSegmentationSettings
import click
import datetime


# %%
# Input position(s) to segment. The string is expanded with glob below, so it
# may contain wildcards to select several positions.
input_dataset_paths = "/hpc/projects/comp.micro/mantis/2023_09_22_A549_0.52NA_illum/4.1-virutal-staining-v2/A549_MitoViewGreen_LysoTracker_W3_FOV5_1_phase_VS.zarr/0/FOV0/0"
# YAML file parsed into CellposeSegmentationSettings below.
config_file = "/hpc/projects/comp.micro/mantis/2023_09_22_A549_0.52NA_illum/4.2-segmentation/segmentation_config.yml"
output_data_path = (
    "./A549_MitoViewGreen_LysoTracker_W3_FOV5_1_phase_VS_segmentation_2.zarr"
)

# sbatch and resource parameters
partition = "gpu"
cpus_per_task = 16
mem_per_cpu = "18G"
time = 300  # minutes
simultaneous_processes_per_node = 12

# Natural sort keeps numbered positions in numeric order
input_paths = [Path(path) for path in natsorted(glob.glob(input_dataset_paths))]
output_data_path = Path(output_data_path)
click.echo(f"in: {input_paths}, out: {output_data_path}")

# Slurm log files are written next to the output store. The directory is
# created up front because nothing else in this script creates it.
slurm_out_dir = output_data_path.parent / "slurm_output"
slurm_out_dir.mkdir(parents=True, exist_ok=True)
# "%j" is kept verbatim in the path; it is Slurm's job-ID filename pattern.
slurm_out_path = str(slurm_out_dir / "segment2-%j.out")

# The settings are passed to the segmentation function as a plain dict under
# the "cellpose_kwargs" key.
settings = utils.yaml_to_model(config_file, CellposeSegmentationSettings)
kwargs = {"cellpose_kwargs": settings.dict()}
print(f"Using settings: {kwargs}")
# %%
# Fail early with a readable message instead of an IndexError on input_paths[0]
if not input_paths:
    raise FileNotFoundError(f"No input positions match {input_dataset_paths!r}")

# Shape and scale are read from the first position only and reused for the
# whole output store.
with open_ome_zarr(input_paths[0]) as dataset:
    T, C, Z, Y, X = dataset.data.shape
    # Captured while the store is open so nothing below touches `dataset`
    scale = dataset.scale

# Channels of the output store (the input channel names are not reused)
channel_names = ["label_nuc", "label_mem"]

output_metadata = {
    "shape": (T, len(channel_names), Z, Y, X),
    "chunks": None,
    "scale": scale,
    "channel_names": channel_names,
    "dtype": np.float32,
}

# One output position per input position, keyed by the last three components
# of the input path.
utils.create_empty_hcs_zarr(
    store_path=output_data_path,
    position_keys=[p.parts[-3:] for p in input_paths],
    **output_metadata,
)

# Resources requested for every submitted job
params = SlurmParams(
    output=slurm_out_path,
    time=datetime.timedelta(minutes=time),
    mem_per_cpu=mem_per_cpu,
    cpus_per_task=cpus_per_task,
    gpus=1,
    partition=partition,
)

# Wrap utils.process_single_position_v2() with slurmkit, then bind the
# arguments that are the same for every position.
slurm_process_single_position = slurm_function(utils.process_single_position_v2)
shared_call_args = {
    "func": utils.nuc_mem_segmentation,
    "time_indices": list(range(T)),
    "input_channel_idx": [0, 1],  # channels in the input dataset
    "output_channel_idx": [0, 1],  # channels in the output dataset
    "num_processes": simultaneous_processes_per_node,
}
segmentation_func = slurm_process_single_position(**shared_call_args, **kwargs)

# Jobs are submitted in batches to avoid IO overload
slurmkit_array_chunk = 20
segment_jobs = []
for i in range(0, len(input_paths), slurmkit_array_chunk):
    chunk_input_paths = input_paths[i : i + slurmkit_array_chunk]

    # The first batch is submitted without dependencies; every later batch is
    # submitted with the previous batch's jobs as its dependencies.
    dependency_kwargs = {} if i == 0 else {"dependencies": segment_jobs}

    segment_jobs = [
        submit_function(
            segmentation_func,
            slurm_params=params,
            input_data_path=in_path,
            output_path=output_data_path,
            **dependency_kwargs,
        )
        for in_path in chunk_input_paths
    ]
15 changes: 15 additions & 0 deletions mantis/analysis/AnalysisSettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,18 @@ def check_affine_transform_list(cls, v):
raise ValueError("Each element in affine_transform_list must be a 4x4 ndarray")

return v


class Segmentation(MyBaseModel):
    """Options for one Cellpose segmentation pass.

    In ``nuc_mem_segmentation`` the dict form of this model is unpacked as
    keyword arguments into a Cellpose model's ``eval()`` call, so the field
    names presumably mirror ``eval()`` parameters (confirm against the
    installed Cellpose version).
    """

    # NOTE(review): the three defaulted fields are annotated with
    # non-optional types but default to None; confirm the pydantic version in
    # use accepts an explicit None for them.
    diameter: int = None
    flow_threshold: float = None
    # Required: the only field without a default
    channels: list[int]
    do_3D: bool = None


class CellposeSegmentationSettings(MyBaseModel):
    """Settings consumed by ``nuc_mem_segmentation``.

    All fields are required. The example script converts an instance to a
    dict and passes it under the ``"cellpose_kwargs"`` key.
    """

    # Z index around which a few slices are max-projected before segmentation
    z_idx: int
    # Passed as ``model_type`` when the membrane Cellpose model is created
    mem_model_path: str
    # Keyword arguments for the membrane model's ``eval()`` call
    membrane_segmentation: Segmentation
    # Passed as ``model_type`` when the nucleus Cellpose model is created
    nuc_model_path: str
    # Keyword arguments for the nucleus model's ``eval()`` call
    nucleus_segmentation: Segmentation
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
z_idx: 30 # in-focus slice
mem_model_path: "cyto2" # the default "cyto2" model works well for cells
membrane_segmentation:
diameter: 200
flow_threshold: 0.4
channels: [1, 0] # [membrane, nucleus] membrane channel has to be first
do_3D: false # 3D segmentations only
nuc_model_path: "/hpc/projects/comp.micro/virtual_staining/models/cellpose_models/CP_20220902_NuclFL"
nucleus_segmentation:
channels: [0, 0]
41 changes: 41 additions & 0 deletions mantis/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,47 @@ def _check_nan_n_zeros(input_array):
return False


def nuc_mem_segmentation(czyx_data, **cellpose_kwargs) -> np.ndarray:
    """Segment nuclei and membranes of one CZYX volume using cellpose.

    Every channel is reduced to a single plane by a maximum-intensity
    projection over the Z slices ``z_idx - 3`` to ``z_idx + 2`` (clipped to
    the volume). Channel 0 of the projection is given to the nucleus model and
    channel 1 to the membrane model.

    Parameters
    ----------
    czyx_data : np.ndarray
        Array of shape (C, Z, Y, X). At least two channels are read; the
        output assignment only succeeds when C == 2.
    **cellpose_kwargs
        Must contain the key ``"cellpose_kwargs"`` holding a dict with:

        - ``z_idx``: Z index at the centre of the projection window
        - ``mem_model_path``: ``model_type`` of the membrane model
        - ``nuc_model_path``: ``model_type`` of the nucleus model
        - ``nucleus_segmentation`` / ``membrane_segmentation`` (optional):
          dicts unpacked as keyword arguments into the respective ``eval()``
          call; a missing or empty entry means no extra arguments

    Returns
    -------
    np.ndarray
        float64 array of shape (C, 1, Y, X); channel 0 holds the nucleus
        masks and channel 1 the membrane masks.

    Raises
    ------
    KeyError
        If ``"cellpose_kwargs"`` or one of its required entries is missing.
    ValueError
        If the projection window around ``z_idx`` does not overlap the volume.
    """
    cellpose_params = cellpose_kwargs["cellpose_kwargs"]
    z_idx = int(cellpose_params["z_idx"])
    C, Z, Y, X = czyx_data.shape

    # Clamp the window to [0, Z). A negative start would otherwise be read as
    # an index from the end of the axis and select the wrong (or no) slices
    # whenever z_idx < 3.
    half_window = 3
    z_start = max(z_idx - half_window, 0)
    z_stop = min(z_idx + half_window, Z)
    if z_start >= z_stop:
        raise ValueError(
            f"z_idx={z_idx} selects no slices from a volume with {Z} Z slices"
        )

    # Maximum-intensity projection of all channels at once, stored as float64
    czyx_data_mip = np.zeros((C, 1, Y, X))
    czyx_data_mip[:, 0] = np.max(czyx_data[:, z_start:z_stop], axis=1)
    cyx_data = czyx_data_mip[:, 0]

    # Missing per-model options fall back to no extra eval() arguments
    nuc_seg_kwargs = cellpose_params.get("nucleus_segmentation") or {}
    mem_seg_kwargs = cellpose_params.get("membrane_segmentation") or {}

    # Imported only after the configuration has been checked, so a bad config
    # fails before the models are loaded.
    from cellpose import models

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize Cellpose models
    # NOTE(review): the membrane model is created with gpu=True regardless of
    # `device`; confirm Cellpose falls back to the CPU when CUDA is missing.
    cyto_model = models.Cellpose(gpu=True, model_type=cellpose_params["mem_model_path"])
    nuc_model = models.CellposeModel(
        model_type=cellpose_params["nuc_model_path"], device=device
    )

    # The first element of each eval() result is used as the mask image
    nuc_masks = nuc_model.eval(cyx_data[0], **nuc_seg_kwargs)[0]
    mem_masks, _, _, _ = cyto_model.eval(cyx_data[1], **mem_seg_kwargs)

    # Pack both masks into the (C, 1, Y, X) output: nucleus first, membrane second
    segmentation_stack = np.zeros_like(czyx_data_mip)
    zyx_mask = np.stack((nuc_masks, mem_masks))
    segmentation_stack[:, 0:1] = zyx_mask[:, np.newaxis]

    return segmentation_stack


## NOTE WIP
def apply_transform_to_zyx_and_save_v2(
func,
Expand Down