Feature/new segmenter #13

Open · wants to merge 11 commits into base: main

4 changes: 2 additions & 2 deletions segmenter_model_zoo/quilt_utils.py
@@ -8,7 +8,7 @@


 class QuiltModelZoo:
-    """ download model from quilt """
+    """download model from quilt"""
 
     def __init__(self):
         """connect to model zoo on quilt3"""
@@ -59,7 +59,7 @@ def download_model(


 def validate_model(model_name, save_path):
-    """ check if model exists, otherwise download it """
+    """check if model exists, otherwise download it"""
     model_path = save_path + os.sep + model_name + ".pth"
     if not os.path.exists(model_path):
         zoo_client = QuiltModelZoo()
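For context, `validate_model` is the download-on-demand helper the zoo module relies on: it looks for `<save_path>/<model_name>.pth` and only pulls the file from quilt3 when it is missing. A minimal usage sketch (the save path is illustrative; the model name is one the zoo maps to quilt):

    from segmenter_model_zoo.quilt_utils import validate_model

    # Download LMNB1_seed_production.pth into ./models only if it is not
    # already cached there (save path is illustrative).
    validate_model("LMNB1_seed_production", "./models")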
168 changes: 98 additions & 70 deletions segmenter_model_zoo/zoo.py
@@ -7,12 +7,13 @@
 import importlib
 
 import torch
-from torch.autograd import Variable
 from aicsmlsegment.utils import input_normalization
 from scipy.ndimage import zoom
 from aicsimageio import AICSImage
 
 from segmenter_model_zoo.quilt_utils import validate_model
+from aicsmlsegment.multichannel_sliding_window import sliding_window_inference
+from aicsmlsegment.fnet_prediction_torch import predict_piecewise

###############################################################################

@@ -43,6 +44,13 @@
"nchannel": 1,
"OutputCh": [0, 1],
},
"unet_xy_zoom_0pad": {
"size_in": [32, 360, 360],
"size_out": [32, 360, 360],
"nclass": [2, 2, 2],
"nchannel": 1,
"OutputCh": 1,
},
}
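Worth noting: this new entry's `size_in` equals its `size_out` ([32, 360, 360]), which is exactly the condition the reworked `apply_on_single_zstack` below uses to route prediction through `predict_piecewise` instead of the padded sliding window. A condensed sketch of that relationship, assuming the defaults flow through unchanged:

    size_in = MODEL_DEF["unet_xy_zoom_0pad"]["size_in"]    # [32, 360, 360]
    size_out = MODEL_DEF["unet_xy_zoom_0pad"]["size_out"]  # [32, 360, 360]
    assert size_in == size_out  # routed to the predict_piecewise path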

# a record of current basic models
@@ -83,6 +91,12 @@
"path": "quilt",
"default_cutoff": 0.5,
},
"LMNB1_fill_production_v2": {
"model_type": "unet_xy_zoom_0pad",
"norm": 15,
"path": "//allen/aics/assay-dev/users/Benji/problem3/late_tp_addition/run_7/checkpoint_epoch=259.ckpt", # noqa E501
"default_cutoff": 0.4,
},
"LMNB1_seed_production": {
"model_type": "unet_xy_zoom",
"norm": 15,
@@ -214,6 +228,7 @@ def load_train(

         else:
             model_type = CHECKPOINT_PATH_MAPPING[checkpoint_name]["model_type"]
+            self.model_name = model_type

         # load default model parameters or from model_param
         if "size_in" in model_param:
@@ -243,12 +258,15 @@

         # define the model
         if model_type == "unet_xy":
-            from aicsmlsegment.Net3D.unet_xy import UNet3D as DNN
+            from aicsmlsegment.NetworkArchitecture.unet_xy import UNet3D as DNN
 
             model = DNN(self.nchannel, self.nclass)
             self.softmax = model.final_activation
-        elif model_type == "unet_xy_zoom":
-            from aicsmlsegment.Net3D.unet_xy_enlarge import UNet3D as DNN
+        elif model_type in ["unet_xy_zoom", "unet_xy_zoom_0pad"]:
+            module = importlib.import_module(
+                "aicsmlsegment.NetworkArchitecture." + model_type
+            )
+            DNN = getattr(module, "UNet3D")
 
             if "zoom_ratio" in model_param:
                 zoom_ratio = model_param["zoom_ratio"]

Review comment on the importlib.import_module call:

    Would importlib also work for unet_xy?

Collaborator (Author) replied:

    good catch! I need to fix this
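As the author acknowledges, the same importlib pattern could absorb the `unet_xy` branch too. A minimal sketch of that consolidation (hypothetical, not part of this PR, and assuming every module under `aicsmlsegment.NetworkArchitecture` exposes a `UNet3D` class):

    # Hypothetical unified branch covering all three architectures.
    if model_type in ["unet_xy", "unet_xy_zoom", "unet_xy_zoom_0pad"]:
        module = importlib.import_module(
            "aicsmlsegment.NetworkArchitecture." + model_type
        )
        DNN = getattr(module, "UNet3D")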
@@ -272,8 +290,20 @@
             model_path = CHECKPOINT_PATH_MAPPING[checkpoint_name]["path"]
 
         state = torch.load(model_path, map_location=torch.device("cpu"))
+        print(state.keys())
         if "model_state_dict" in state:
             self.model.load_state_dict(state["model_state_dict"])
+        elif "state_dict" in state:
+            try:
+                self.model.load_state_dict(state["state_dict"])
+            except Exception:
+                from collections import OrderedDict
+
+                model_state = state["state_dict"]
+                model_state_adjusted = OrderedDict()
+                for key, value in model_state.items():
+                    model_state_adjusted[key[6:]] = value
+                self.model.load_state_dict(model_state_adjusted)
         else:
             self.model.load_state_dict(state)
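A note on the `key[6:]` slice above: the `.ckpt` extension and the "state_dict" key suggest a PyTorch Lightning checkpoint, whose parameter names are prefixed by the attribute that held the network, presumably `model.` (six characters) here. A more explicit version of the same adjustment, under that assumption:

    from collections import OrderedDict

    def strip_prefix(state_dict, prefix="model."):
        # Assumption: keys look like "model.encoder.0.weight", written by a
        # LightningModule that stored the network as self.model.
        adjusted = OrderedDict()
        for key, value in state_dict.items():
            adjusted[key[len(prefix):] if key.startswith(prefix) else key] = value
        return adjusted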

@@ -304,6 +334,8 @@ def apply_on_single_zstack(
         already_normalized: bool = False,
         cutoff: float = None,
         inference_param: Dict = {},
+        size_in: List = None,
+        size_out: List = None,
     ) -> np.ndarray:
         """
         Apply a trained model on an image
@@ -334,6 +366,10 @@
             only one parameter is allowed: "ResizeRatio" (a list of three
             float numbers to indicate the ResizeRatio to apply on ZYX axes).
             More parameters may be added in the future.
+        size_in: List
+            the input patch size, to overwrite the default
+        size_out: List
+            the output patch size, to overwrite the default
 
         Return:
         -------------
@@ -383,73 +419,65 @@
         model = self.model
         model.eval()
 
-        # do padding on input
-        padding = [(x - y) // 2 for x, y in zip(self.size_in, self.size_out)]
-        img_pad0 = np.pad(
-            input_img,
-            ((0, 0), (0, 0), (padding[1], padding[1]), (padding[2], padding[2])),
-            "symmetric",
-        )
-        img_pad = np.pad(
-            img_pad0, ((0, 0), (padding[0], padding[0]), (0, 0), (0, 0)), "constant"
-        )
-
-        # we only support single output image in model zoo
-        # other outputs are only supported in full segmenter prediction so far
-        assert len(self.OutputCh) == 2
-        output_img = np.zeros(input_img.shape)
-
-        # loop through the image patch by patch
-        num_step_z = int(np.ceil(input_img.shape[1] / self.size_out[0]))
-        num_step_y = int(np.ceil(input_img.shape[2] / self.size_out[1]))
-        num_step_x = int(np.ceil(input_img.shape[3] / self.size_out[2]))
-        with torch.no_grad():
-            for ix in range(num_step_x):
-                if ix < num_step_x - 1:
-                    xa = ix * self.size_out[2]
-                else:
-                    xa = input_img.shape[3] - self.size_out[2]
-
-                for iy in range(num_step_y):
-                    if iy < num_step_y - 1:
-                        ya = iy * self.size_out[1]
-                    else:
-                        ya = input_img.shape[2] - self.size_out[1]
-
-                    for iz in range(num_step_z):
-                        if iz < num_step_z - 1:
-                            za = iz * self.size_out[0]
-                        else:
-                            za = input_img.shape[1] - self.size_out[0]
-
-                        input_patch = img_pad[
-                            :,
-                            za : (za + self.size_in[0]),
-                            ya : (ya + self.size_in[1]),
-                            xa : (xa + self.size_in[2]),
-                        ]
-                        input_img_tensor = torch.from_numpy(input_patch)
-                        tmp_out = model(Variable(input_img_tensor.cuda()).unsqueeze(0))
-                        assert len(self.OutputCh) // 2 <= len(
-                            tmp_out
-                        ), "the parameter OutputCh not compatible with output tensors"
-
-                        label = tmp_out[self.OutputCh[0]]
-                        prob = self.softmax(label)
-                        out_flat_tensor = prob.cpu().data
-                        out_tensor = out_flat_tensor.view(
-                            self.size_out[0],
-                            self.size_out[1],
-                            self.size_out[2],
-                            self.nclass[0],
-                        )
-                        out_nda = out_tensor.numpy()
-                        output_img[
-                            0,
-                            za : (za + self.size_out[0]),
-                            ya : (ya + self.size_out[1]),
-                            xa : (xa + self.size_out[2]),
-                        ] = out_nda[:, :, :, self.OutputCh[1]]
+        # check if need to use default size_in and size_out
+        if size_in is None:
+            size_in = self.size_in
+        if size_out is None:
+            size_out = self.size_out
+
+        if size_in == size_out:
+            dims_max = [0] + size_in
+            overlaps = [int(0.1 * dim) for dim in dims_max]
+            input_tensor = torch.from_numpy(input_img).to("cuda:0")
+            with torch.no_grad():
+                output_tensor = predict_piecewise(
+                    model,
+                    input_tensor,
+                    dims_max=dims_max,
+                    overlaps=overlaps,
+                )
+        else:
+            # do padding on input
+            padding = [(x - y) // 2 for x, y in zip(size_in, size_out)]
+            img_pad0 = np.pad(
+                input_img,
+                ((0, 0), (0, 0), (padding[1], padding[1]), (padding[2], padding[2])),
+                "symmetric",
+            )
+            img_pad = np.pad(
+                img_pad0, ((0, 0), (padding[0], padding[0]), (0, 0), (0, 0)), "constant"
+            )
+
+            # pad the extra batch dimension
+            img_pad = np.expand_dims(img_pad, axis=0)
+
+            # run sliding window inference
+            with torch.no_grad():
+                output_tensor, _ = sliding_window_inference(
+                    inputs=torch.from_numpy(img_pad).float().cuda(),
+                    roi_size=size_in,
+                    out_size=size_out,
+                    original_image_size=input_img.shape[-3:],
+                    sw_batch_size=1,
+                    predictor=model.forward,
+                    overlap=0.25,
+                    mode="gaussian",
+                    model_name=self.model_name,
+                )
+
+        output_tensor = torch.nn.Softmax(dim=1)(output_tensor)
+        output_img = output_tensor.cpu().data.numpy()
+        if self.OutputCh:
+            # old models, only take the output from the highest resolution
+            if type(self.OutputCh) == list:
+                # if it is [v1, v2], the second value is which channel to take from
+                # the highest resolution output
+                if len(self.OutputCh) >= 2:
+                    self.OutputCh = self.OutputCh[1]
+                else:
+                    # just convert list to integer
+                    self.OutputCh = self.OutputCh[0]
+            output_img = output_img[:, self.OutputCh, :, :, :]
 
         torch.cuda.empty_cache()

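Taken together, the rework gives `apply_on_single_zstack` two inference paths: when `size_in == size_out` (as for the new `unet_xy_zoom_0pad` definition) the whole stack goes through `predict_piecewise`; otherwise the input is mirror-padded and fed to `sliding_window_inference`. A usage sketch under stated assumptions: the class name `ZooModel` is a placeholder (the class defining `load_train` is not visible in this diff), the exact call signature may differ, and a CUDA device is required by both paths:

    import numpy as np

    model = ZooModel()  # placeholder for the zoo class defining load_train
    model.load_train("LMNB1_fill_production_v2", {})

    # A single CZYX z-stack; shape and content are illustrative.
    img = np.random.rand(1, 32, 360, 360).astype(np.float32)

    # Equal patch sizes route the call to the new predict_piecewise path.
    seg = model.apply_on_single_zstack(
        img, size_in=[32, 360, 360], size_out=[32, 360, 360]
    )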
4 changes: 2 additions & 2 deletions setup.py
@@ -41,9 +41,9 @@
 requirements = [
     'PyYAML',
     'aicsimageio>3.3.0',
-    'aicsmlsegment>0.0.5'
+    'aicsmlsegment>0.0.5',
     'scikit-image',
-    "quilt3",
+    'quilt3',
 ]

extra_requirements = {
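The added comma fixes a real bug, not just style: adjacent string literals in Python concatenate implicitly, so the old list silently merged two requirements into one unsatisfiable specifier. A quick illustration:

    # What the old setup.py actually produced:
    requirements = [
        'aicsmlsegment>0.0.5'   # missing comma
        'scikit-image',
    ]
    print(requirements)  # ['aicsmlsegment>0.0.5scikit-image']

The quote change on 'quilt3' simply makes the list's quoting style consistent.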