Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training adjustments #4

Merged
merged 14 commits into from
Apr 20, 2023
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ dependencies = [
"monai",
"pytorch-lightning",
"torch",
"imageio",
"scipy",
"mrcfile"
]

# extras
Expand Down
2 changes: 1 addition & 1 deletion src/membrain_seg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""membrane segmentation in 3D for cryo-ET"""
"""membrane segmentation in 3D for cryo-ET."""
from importlib.metadata import PackageNotFoundError, version

try:
Expand Down
1 change: 1 addition & 0 deletions src/membrain_seg/dataloading/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Empty init."""
93 changes: 93 additions & 0 deletions src/membrain_seg/dataloading/data_utils.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected an error here: The script was saving score maps instead of segmentation masks if --store_probabilities flag was set to True.

Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os

import mrcfile
import numpy as np
import SimpleITK as sitk


def load_data_for_inference(data_path, transforms, device):
    """Load and prepare a tomogram for inference.

    The tomogram is read from ``data_path`` with normalization enabled,
    given a leading channel axis, passed through ``transforms`` (most
    likely just conversion to a torch.Tensor), given a batch axis, and
    finally moved onto ``device`` (e.g. the GPU, if available).
    """
    tomogram = load_tomogram(data_path, normalize_data=True)
    tomogram = np.expand_dims(tomogram, axis=0)  # add channel dimension

    transformed = transforms(tomogram)
    batched = transformed.unsqueeze(0)  # add batch dimension
    return batched.to(device)


def store_segmented_tomograms(
    network_output, out_folder, orig_data_path, ckpt_token, store_probabilities=False
):
    """Store the network's segmentation of a single tomogram.

    The thresholded segmentation mask is written to
    ``os.path.join(out_folder, <basename>_<ckpt_token>_segmented.mrc)``.
    If ``store_probabilities`` is True, the raw (pre-threshold) network
    scores are additionally stored as ``<basename>_scores.mrc``.

    Parameters
    ----------
    network_output : sequence of torch.Tensor
        Model output; only the first element (shape (1, 1, z, y, x)) is used.
    out_folder : str
        Directory the output files are written into.
    orig_data_path : str
        Path of the input tomogram; its basename names the outputs.
    ckpt_token : str
        Identifier of the checkpoint, embedded in the output filename.
    store_probabilities : bool, optional
        Also store the un-thresholded score map. Default: False.
    """
    predictions = network_output[0]
    # Compute the numpy score map once and reuse it for both outputs
    # (the original recomputed the squeeze/cpu/numpy chain twice).
    predictions_np = predictions.squeeze(0).squeeze(0).cpu().numpy()
    # splitext is robust to extensions of any length, unlike slicing off
    # the last four characters.
    basename = os.path.splitext(os.path.basename(orig_data_path))[0]
    if store_probabilities:
        out_file = os.path.join(out_folder, basename + "_scores.mrc")
        store_tomogram(out_file, predictions_np)
    # Threshold the raw scores at 0.0 to obtain a binary mask
    # (presumably these are logits, so 0.0 ~ probability 0.5 — confirm).
    predictions_np_thres = predictions_np > 0.0
    out_file_thres = os.path.join(
        out_folder, basename + "_" + ckpt_token + "_segmented.mrc"
    )
    store_tomogram(out_file_thres, predictions_np_thres)
    print("MemBrain has finished segmenting your tomogram.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be nice to wrap print statements like this under some sort of verbose flag. For example, make verbose an input parameter and then

if verbose:
    print....



def read_nifti(nifti_file):
    """Read a NIfTI file and return its contents as a float numpy array.

    This will be redundant once we move to mrc files I guess?
    """
    image = sitk.ReadImage(nifti_file)
    return np.array(sitk.GetArrayFromImage(image), dtype=float)


def load_tomogram(filename, return_header=False, normalize_data=False):
    """
    Load an MRC tomogram and transpose it s.t. we have data in the form x,y,z.

    Parameters
    ----------
    filename : str
        Path to the MRC file; opened permissively to tolerate slightly
        malformed headers.
    return_header : bool, optional
        If True, also return a dict with the ``cella``, ``cellb``,
        ``origin`` and ``pixel_spacing`` header entries.
    normalize_data : bool, optional
        If True, normalize tomogram values to zero mean and unit std.

    Returns
    -------
    np.ndarray or (np.ndarray, dict)
        The (transposed) volume, optionally with the header dict.
    """
    with mrcfile.open(filename, permissive=True) as mrc:
        data = np.array(mrc.data)
        data = np.transpose(data, (2, 1, 0))
        cella = mrc.header.cella
        cellb = mrc.header.cellb
        origin = mrc.header.origin
        pixel_spacing = np.array([mrc.voxel_size.x, mrc.voxel_size.y, mrc.voxel_size.z])
        header_dict = {
            "cella": cella,
            "cellb": cellb,
            "origin": origin,
            "pixel_spacing": pixel_spacing,
        }
    if normalize_data:
        # Cast before normalizing: MRC volumes are frequently stored as
        # integers, and in-place subtraction of a float mean from an
        # integer array raises a numpy casting error.
        data = data.astype(np.float32)
        data -= np.mean(data)
        data /= np.std(data)
    if return_header:
        return data, header_dict
    return data


def store_tomogram(filename, tomogram, header_dict=None):
    """Store tomogram in specified path.

    The volume is transposed back to z,y,x order before writing. Any
    dtype other than int8 is converted to float32. If ``header_dict`` is
    given, its ``cella``, ``cellb`` and ``origin`` entries are copied
    into the new file's header.
    """
    volume = tomogram
    if volume.dtype != np.int8:
        volume = np.array(volume, dtype=np.float32)
    volume = np.transpose(volume, (2, 1, 0))
    with mrcfile.new(filename, overwrite=True) as mrc:
        mrc.set_data(volume)
        if header_dict is not None:
            for key in ("cella", "cellb", "origin"):
                setattr(mrc.header, key, header_dict[key])
257 changes: 257 additions & 0 deletions src/membrain_seg/dataloading/memseg_augmentation.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These augmentations are key to the network's performance in my opinion.
I tried to reproduce nnUNet's data augmentations, which are mainly based on the batchgenerators package. I wanted to not rely on either nnunet or batchgenerators package, so I tried implementing all the augmentations with MONAI.
Many of the augmentations were already available; others I translated from batchgenerators to MONAI (see also dataloading.transforms.py).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the testing functions for checking the effects of data augmentations

Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import numpy as np
import torch
from monai.transforms import (
Compose,
OneOf,
RandAxisFlipd,
RandGaussianNoised,
RandGaussianSmoothd,
RandRotate90d,
RandRotated,
RandZoomd,
ToTensor,
ToTensord,
)

from membrain_seg.dataloading.transforms import (
AxesShuffle,
BlankCuboidTransform,
BrightnessGradientAdditiveTransform,
DownsampleSegForDeepSupervisionTransform,
LocalGammaTransform,
MedianFilterd,
RandAdjustContrastWithInversionAndStats,
RandApplyTransform,
RandomBrightnessTransformd,
RandomContrastTransformd,
SharpeningTransformMONAI,
SimulateLowResolutionTransform,
)

### Hard-coded area
# Kernel sizes of the network's pooling operations: five 2x downsampling
# steps along every spatial axis.
pool_op_kernel_sizes = [
    [2, 2, 2],
    [2, 2, 2],
    [2, 2, 2],
    [2, 2, 2],
    [2, 2, 2],
]  # hard-coded
net_num_pool_op_kernel_sizes = pool_op_kernel_sizes
# Relative scales at which deep-supervision targets are generated: full
# resolution first, then the cumulative inverse of each pooling step;
# the last (coarsest) level is dropped.
deep_supervision_scales = [[1, 1, 1]] + [
    list(i) for i in 1 / np.cumprod(np.vstack(net_num_pool_op_kernel_sizes), axis=0)
][:-1]

# Random-rotation ranges in radians (+/- 30 degrees around each axis).
data_aug_params = {}
data_aug_params["rotation_x"] = (-30.0 / 360 * 2.0 * np.pi, 30.0 / 360 * 2.0 * np.pi)
data_aug_params["rotation_y"] = (-30.0 / 360 * 2.0 * np.pi, 30.0 / 360 * 2.0 * np.pi)
data_aug_params["rotation_z"] = (-30.0 / 360 * 2.0 * np.pi, 30.0 / 360 * 2.0 * np.pi)

# Axes eligible for mirroring in test-time augmentation.
mirror_axes = (0, 1, 2)


def get_mirrored_img(img, mirror_idx):
    """Get mirrored images.

    There are 8 possible mirror combinations of the three spatial axes,
    enumerated from 0 to 7; index 0 returns the input unchanged. This is
    used for test time augmentation. Flips are applied to dims 2-4, so
    ``img`` must have at least 5 dimensions (batch, channel, z, y, x).

    Raises
    ------
    ValueError
        If ``mirror_idx`` is outside [0, 8). (An explicit raise instead
        of ``assert``, which is stripped when Python runs with ``-O``.)
    """
    if not 0 <= mirror_idx < 8:
        raise ValueError(f"mirror_idx must be in [0, 8), got {mirror_idx}")
    if mirror_idx == 0:
        return img

    if mirror_idx == 1 and (2 in mirror_axes):
        return torch.flip(img, (4,))

    if mirror_idx == 2 and (1 in mirror_axes):
        return torch.flip(img, (3,))

    if mirror_idx == 3 and (2 in mirror_axes) and (1 in mirror_axes):
        return torch.flip(img, (4, 3))

    if mirror_idx == 4 and (0 in mirror_axes):
        return torch.flip(img, (2,))

    if mirror_idx == 5 and (0 in mirror_axes) and (2 in mirror_axes):
        return torch.flip(img, (4, 2))

    if mirror_idx == 6 and (0 in mirror_axes) and (1 in mirror_axes):
        return torch.flip(img, (3, 2))

    if (
        mirror_idx == 7
        and (0 in mirror_axes)
        and (1 in mirror_axes)
        and (2 in mirror_axes)
    ):
        return torch.flip(img, (4, 3, 2))

    # Only reachable if mirror_axes excludes an axis (never the case with
    # the module-level constant (0, 1, 2)); return the image unchanged
    # rather than an implicit None.
    return img


def get_training_transforms(prob_to_one=False, return_as_list=False):
    """Return the data augmentation transforms for the training phase.

    The sequence reproduces nnU-Net-style augmentations using MONAI
    built-ins plus the custom transforms from
    ``membrain_seg.dataloading.transforms``.

    Parameters
    ----------
    prob_to_one : bool, optional
        If True, apply every random transform with probability 1.0
        (useful for inspecting augmentation effects). Default: False.
    return_as_list : bool, optional
        If True, return the raw list of transforms instead of a
        ``Compose`` object. Default: False.
    """
    aug_sequence = [
        RandRotated(
            keys=("image", "label"),
            range_x=data_aug_params["rotation_x"],
            range_y=data_aug_params["rotation_y"],
            # Fixed: was data_aug_params["rotation_x"]; all three ranges
            # are currently identical, so behavior is unchanged.
            range_z=data_aug_params["rotation_z"],
            prob=(1.0 if prob_to_one else 0.75),
            mode=("bilinear", "nearest"),
        ),
        RandZoomd(
            keys=("image", "label"),
            prob=(1.0 if prob_to_one else 0.3),
            min_zoom=0.7,
            max_zoom=1.43,
            mode=("trilinear", "nearest-exact"),
            padding_mode=("constant", "constant"),
        ),  # TODO: Independent scale for each axis?
        RandRotate90d(
            keys=("image", "label"),
            prob=(1.0 if prob_to_one else 0.70),
            max_k=3,
            spatial_axes=(0, 1),
        ),
        RandRotate90d(
            keys=("image", "label"),
            prob=(1.0 if prob_to_one else 0.70),
            max_k=3,
            spatial_axes=(0, 2),
        ),
        RandRotate90d(
            keys=("image", "label"),
            prob=(1.0 if prob_to_one else 0.70),
            max_k=3,
            spatial_axes=(1, 2),
        ),
        AxesShuffle,
        # Exactly one of median filtering or Gaussian smoothing.
        OneOf(
            [
                RandApplyTransform(
                    transform=MedianFilterd(keys=["image"], radius=(2, 8)),
                    prob=(
                        1.0 if prob_to_one else 0.25
                    ),  # Changed range from 8 to 6 and prob to 15%
                    # for efficiency reasons
                ),
                RandGaussianSmoothd(
                    keys=["image"],
                    sigma_x=(0.3, 1.5),
                    sigma_y=(0.3, 1.5),
                    sigma_z=(0.3, 1.5),
                    prob=(1.0 if prob_to_one else 0.3),
                ),
            ]
        ),
        RandGaussianNoised(
            keys=["image"], prob=(1.0 if prob_to_one else 0.4), mean=0.0, std=0.7
        ),  # Changed std from 0.1 to 0.5 --> check visually
        RandomBrightnessTransformd(
            keys=["image"], mu=0.0, sigma=0.5, prob=(1.0 if prob_to_one else 0.30)
        ),
        # Contrast adjustment, either preserving the value range or not.
        OneOf(
            [
                RandomContrastTransformd(
                    keys=["image"],
                    contrast_range=(0.5, 2.0),
                    preserve_range=True,
                    prob=(1.0 if prob_to_one else 0.30),
                ),
                RandomContrastTransformd(
                    keys=["image"],
                    contrast_range=(0.5, 2.0),
                    preserve_range=False,
                    prob=(1.0 if prob_to_one else 0.30),
                ),
            ]
        ),
        RandApplyTransform(
            SimulateLowResolutionTransform(
                keys=["image"],
                downscale_factor_range=(0.25, 1.0),
                upscale_mode="trilinear",
                downscale_mode="nearest-exact",
            ),
            prob=(1.0 if prob_to_one else 0.35),
        ),
        # Two successive contrast-inversion passes, applied together.
        RandApplyTransform(
            Compose(
                [
                    RandAdjustContrastWithInversionAndStats(keys=["image"], prob=1.0),
                    RandAdjustContrastWithInversionAndStats(keys=["image"], prob=1.0),
                ]
            ),
            prob=(1.0 if prob_to_one else 0.25),
        ),
        RandAxisFlipd(keys=("image", "label"), prob=(1.0 if prob_to_one else 0.5)),
        BlankCuboidTransform(
            keys=["image"],
            prob=(1.0 if prob_to_one else 0.4),
            cuboid_area=(160 // 10, 160 // 3),
            is_3d=True,
            max_cuboids=5,
            replace_with="mean",
        ),  # patch size of 160 hard-coded. Should we make it flexible?
        RandApplyTransform(
            BrightnessGradientAdditiveTransform(
                keys=["image"],
                scale=lambda x, y: np.exp(
                    np.random.uniform(np.log(x[y] // 6), np.log(x[y]))
                ),
                loc=(-0.5, 1.5),
                max_strength=lambda x, y: np.random.uniform(-5, -1)
                if np.random.uniform() < 0.5
                else np.random.uniform(1, 5),
                mean_centered=False,
            ),
            prob=(1.0 if prob_to_one else 0.3),
        ),
        RandApplyTransform(
            LocalGammaTransform(
                keys=["image"],
                scale=lambda x, y: np.exp(
                    np.random.uniform(np.log(x[y] // 6), np.log(x[y]))
                ),
                loc=(-0.5, 1.5),
                gamma=lambda: np.random.uniform(0.01, 0.8)
                if np.random.uniform() < 0.5
                else np.random.uniform(1.5, 4),
            ),
            prob=(1.0 if prob_to_one else 0.3),
        ),
        SharpeningTransformMONAI(
            keys=["image"],
            strength=(0.1, 1),
            same_for_each_channel=False,
            p_per_channel=(1.0 if prob_to_one else 0.2),
        ),
        # Produce downsampled label copies for deep supervision.
        DownsampleSegForDeepSupervisionTransform(
            keys=["label"], ds_scales=deep_supervision_scales, order="nearest"
        ),
        ToTensord(keys=["image"], dtype=torch.float),
    ]
    if return_as_list:
        return aug_sequence
    return Compose(aug_sequence)


def get_validation_transforms(return_as_list=False):
    """Return the data transforms for the validation phase.

    (Docstring fixed: it previously claimed "training phase".) No random
    augmentation is applied here — only deep-supervision label
    downsampling and tensor conversion.

    Parameters
    ----------
    return_as_list : bool, optional
        If True, return the raw list of transforms instead of a
        ``Compose`` object. Default: False.
    """
    aug_sequence = [
        DownsampleSegForDeepSupervisionTransform(
            keys=["label"], ds_scales=deep_supervision_scales, order="nearest"
        ),
        ToTensord(keys=["image"], dtype=torch.float),
    ]
    if return_as_list:
        return aug_sequence
    return Compose(aug_sequence)


def get_prediction_transforms():
    """Return the minimal transforms for the prediction phase.

    Only tensor conversion is performed — no augmentation.
    """
    return Compose([ToTensor()])
Loading