
Commit

* adds new segmentation approaches
* minor style & pylint changes
TimPaquaij authored and wdika committed Nov 13, 2024
1 parent d34fcad commit bf96578
Showing 25 changed files with 1,382 additions and 195 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,7 +2,7 @@
__pycache__/
*.py[cod]
*$py.class

**.DS_Store
# C extensions
*.so

4 changes: 3 additions & 1 deletion atommic/collections/common/data/mri_loader.py
@@ -226,7 +226,9 @@ def __init__( # noqa: MC0001
self.examples = [ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols]

self.indices_to_log = np.random.choice(
len(self.examples), int(log_images_rate * len(self.examples)), replace=False # type: ignore
[example[1] for example in self.examples],
int(log_images_rate * len(self.examples)), # type: ignore
replace=False,
)

def _retrieve_metadata(self, fname: Union[str, Path]) -> Tuple[Dict, int]:
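Note: a minimal standalone sketch (assuming `self.examples` holds `(fname, slice_idx, metadata)` tuples, as in the loader) contrasting the old and new sampling for logged images:

import numpy as np

# Hypothetical stand-in for self.examples: (fname, slice_idx, metadata) tuples.
examples = [("subj_a.h5", 0, {}), ("subj_a.h5", 1, {}), ("subj_b.h5", 0, {})]
log_images_rate = 0.5
n_logged = int(log_images_rate * len(examples))

# Before this commit: sample positional indices into the examples list.
logged_positions = np.random.choice(len(examples), n_logged, replace=False)

# After this commit: sample from the per-example slice indices (example[1]) instead.
logged_slices = np.random.choice([example[1] for example in examples], n_logged, replace=False)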
8 changes: 7 additions & 1 deletion atommic/collections/common/losses/__init__.py
@@ -5,4 +5,10 @@
from atommic.collections.common.losses.wasserstein import SinkhornDistance # noqa: F401

VALID_RECONSTRUCTION_LOSSES = ["l1", "mse", "ssim", "noise_aware", "wasserstein"]
VALID_SEGMENTATION_LOSSES = ["cross_entropy", "dice"]
VALID_SEGMENTATION_LOSSES = [
"categorical_cross_entropy",
"dice",
"binary_cross_entropy",
"generalized_dice",
"focal_loss",
]
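Note: a hypothetical check against the expanded list; the exact validation logic in atommic may differ:

from atommic.collections.common.losses import VALID_SEGMENTATION_LOSSES

# Hypothetical config: weighted combination of two of the new segmentation losses.
segmentation_losses = {"dice": 0.5, "focal_loss": 0.5}
for name in segmentation_losses:
    if name not in VALID_SEGMENTATION_LOSSES:
        raise ValueError(f"Unknown segmentation loss '{name}'. Expected one of {VALID_SEGMENTATION_LOSSES}.")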
222 changes: 157 additions & 65 deletions atommic/collections/multitask/rs/nn/base.py

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions atommic/collections/multitask/rs/nn/idslr.py
@@ -43,7 +43,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if self.input_channels == 0:
raise ValueError("Segmentation module input channels cannot be 0.")
reconstruction_out_chans = cfg_dict.get("reconstruction_module_output_channels", 2)
self.segmentation_out_chans = cfg_dict.get("segmentation_module_output_channels", 1)
self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1)
chans = cfg_dict.get("channels", 32)
num_pools = cfg_dict.get("num_pools", 4)
drop_prob = cfg_dict.get("drop_prob", 0.0)
@@ -76,7 +76,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.segmentation_decoder = UnetDecoder(
chans=chans,
num_pools=num_pools,
out_chans=self.segmentation_out_chans,
out_chans=self.segmentation_module_output_channels,
drop_prob=drop_prob,
normalize=normalize,
padding=padding,
@@ -238,9 +238,13 @@ def process_final_segmentation(self, prediction: torch.Tensor) -> torch.Tensor:
"""
if prediction.shape[-1] == 2:
prediction = torch.view_as_complex(prediction)
if prediction.shape[1] != self.segmentation_out_chans and prediction.shape[1] != 2 and prediction.dim() == 5:
if (
prediction.shape[1] != self.segmentation_module_output_channels
and prediction.shape[1] != 2
and prediction.dim() == 5
):
prediction = prediction.squeeze(1)
if prediction.shape[1] != self.segmentation_out_chans:
if prediction.shape[1] != self.segmentation_module_output_channels:
prediction = prediction.permute(0, 3, 1, 2)
prediction = torch.abs(prediction)
if self.normalize_segmentation_output:
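Note: a toy illustration (with assumed shapes) of the permute branch above, which moves a channels-last prediction to channels-first:

import torch

segmentation_module_output_channels = 3
# Channels-last prediction: (batch, height, width, classes).
prediction = torch.rand(1, 64, 64, segmentation_module_output_channels)

if prediction.shape[1] != segmentation_module_output_channels:
    prediction = prediction.permute(0, 3, 1, 2)
print(prediction.shape)  # torch.Size([1, 3, 64, 64])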
4 changes: 2 additions & 2 deletions atommic/collections/multitask/rs/nn/idslr_unet.py
@@ -45,7 +45,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if self.input_channels == 0:
raise ValueError("Segmentation module input channels cannot be 0.")
reconstruction_out_chans = cfg_dict.get("reconstruction_module_output_channels", 2)
segmentation_out_chans = cfg_dict.get("segmentation_module_output_channels", 1)
self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1)
chans = cfg_dict.get("channels", 32)
num_pools = cfg_dict.get("num_pools", 4)
drop_prob = cfg_dict.get("drop_prob", 0.0)
@@ -78,7 +78,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self.segmentation_module = Unet(
in_chans=reconstruction_out_chans,
out_chans=segmentation_out_chans,
out_chans=self.segmentation_module_output_channels,
chans=chans,
num_pool_layers=num_pools,
drop_prob=drop_prob,
11 changes: 8 additions & 3 deletions atommic/collections/multitask/rs/nn/mtlrs.py
@@ -284,19 +284,24 @@ def compute_reconstruction_loss(t, p, s):

return loss_func(t, p)

if self.accumulate_predictions:
if self.reconstruction_module_accumulate_predictions:
rs_cascades_weights = torch.logspace(-1, 0, steps=len(prediction)).to(target.device)
rs_cascades_loss = []
for rs_cascade_pred in prediction:
cascades_weights = torch.logspace(-1, 0, steps=len(rs_cascade_pred)).to(target.device)
cascades_loss = []
for cascade_pred in rs_cascade_pred:
time_steps_weights = torch.logspace(-1, 0, steps=self.time_steps).to(target.device)
time_steps_weights = torch.logspace(-1, 0, steps=self.reconstruction_module_time_steps).to(
target.device
)
time_steps_loss = [
compute_reconstruction_loss(target, time_step_pred, sensitivity_maps)
for time_step_pred in cascade_pred
]
cascade_loss = sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / self.time_steps
cascade_loss = (
sum(x * w for x, w in zip(time_steps_loss, time_steps_weights))
/ self.reconstruction_module_time_steps
)
cascades_loss.append(cascade_loss)
rs_cascade_loss = sum(x * w for x, w in zip(cascades_loss, cascades_weights)) / len(rs_cascade_pred)
rs_cascades_loss.append(rs_cascade_loss)
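Note: a standalone sketch of the logspace weighting used above, with toy per-time-step losses (assumed values):

import torch

reconstruction_module_time_steps = 4
# Later time steps receive larger weights: 0.1 ... 1.0.
time_steps_weights = torch.logspace(-1, 0, steps=reconstruction_module_time_steps)
time_steps_loss = [torch.tensor(1.0) for _ in range(reconstruction_module_time_steps)]

cascade_loss = (
    sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / reconstruction_module_time_steps
)
print(cascade_loss)  # sum of the weighted per-step losses, divided by the number of time steps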
3 changes: 2 additions & 1 deletion atommic/collections/multitask/rs/nn/recseg_unet.py
@@ -55,9 +55,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
drop_prob=cfg_dict.get("reconstruction_module_dropout", 0.0),
)

self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1)
self.segmentation_module = Unet(
in_chans=reconstruction_module_output_channels,
out_chans=cfg_dict.get("segmentation_module_output_channels", 1),
out_chans=self.segmentation_module_output_channels,
chans=cfg_dict.get("segmentation_module_channels", 64),
num_pool_layers=cfg_dict.get("segmentation_module_pooling_layers", 2),
drop_prob=cfg_dict.get("segmentation_module_dropout", 0.0),
8 changes: 4 additions & 4 deletions atommic/collections/multitask/rs/nn/segnet.py
@@ -52,7 +52,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self.input_channels = cfg_dict.get("input_channels", 2)
reconstruction_out_chans = cfg_dict.get("reconstruction_module_output_channels", 2)
segmentation_out_chans = cfg_dict.get("segmentation_module_output_channels", 1)
self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1)
chans = cfg_dict.get("channels", 32)
num_pools = cfg_dict.get("num_pools", 4)
drop_prob = cfg_dict.get("drop_prob", 0.0)
@@ -97,7 +97,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
UnetDecoder(
chans=chans,
num_pools=num_pools,
out_chans=segmentation_out_chans,
out_chans=self.segmentation_module_output_channels,
drop_prob=drop_prob,
normalize=normalize,
padding=padding,
@@ -110,8 +110,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self.segmentation_final_layer = torch.nn.Sequential(
ConvNonlinear(
segmentation_out_chans * num_cascades,
segmentation_out_chans,
self.segmentation_module_output_channels * num_cascades,
self.segmentation_module_output_channels,
conv_dim=cfg_dict.get("segmentation_final_layer_conv_dim", 2),
kernel_size=cfg_dict.get("segmentation_final_layer_kernel_size", 3),
dilation=cfg_dict.get("segmentation_final_layer_dilation", 1),
10 changes: 5 additions & 5 deletions atommic/collections/multitask/rs/nn/seranet.py
@@ -97,10 +97,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
coil_combination_method=self.coil_combination_method,
)
self.segmentation_module_input_channels = cfg_dict.get("segmentation_module_input_channels", 2)
segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1)
self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1)
self.segmentation_module = ConvLSTMNormUnet(
in_chans=self.segmentation_module_input_channels,
out_chans=segmentation_module_output_channels,
out_chans=self.segmentation_module_output_channels,
chans=cfg_dict.get("segmentation_module_channels", 64),
num_pools=cfg_dict.get("segmentation_module_pooling_layers", 2),
drop_prob=cfg_dict.get("segmentation_module_dropout", 0.0),
@@ -109,12 +109,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
num_iterations=cfg_dict.get("recurrent_module_iterations", 3),
attention_model=AttentionGate(
in_chans_x=self.segmentation_module_input_channels * 2,
in_chans_g=segmentation_module_output_channels,
out_chans=segmentation_module_output_channels,
in_chans_g=self.segmentation_module_output_channels,
out_chans=self.segmentation_module_output_channels,
),
unet_model=ConvLSTMNormUnet(
in_chans=self.segmentation_module_input_channels * 2,
out_chans=segmentation_module_output_channels,
out_chans=self.segmentation_module_output_channels,
chans=cfg_dict.get("recurrent_module_attention_channels", 64),
num_pools=cfg_dict.get("recurrent_module_attention_pooling_layers", 2),
drop_prob=cfg_dict.get("recurrent_module_attention_dropout", 0.0),
@@ -123,7 +123,9 @@ def __init__(
)
self.model_name = self.reconstruction_module[0].__class__.__name__.lower()
if self.model_name == "modulelist":
self.model_name = self.reconstruction_module[0][0].__class__.__name__.lower()
self.model_name = self.reconstruction_module[0][ # pylint: disable=unsubscriptable-object
0
].__class__.__name__.lower()

self.fft_centered = fft_centered
self.fft_normalization = fft_normalization
33 changes: 32 additions & 1 deletion atommic/collections/multitask/rs/parts/transforms.py
@@ -97,8 +97,9 @@ def __init__(
fft_normalization: str = "backward",
spatial_dims: Sequence[int] = None,
coil_dim: int = 0,
consecutive_slices: int = 1, # pylint: disable=unused-argument
consecutive_slices: int = 1,
use_seed: bool = True,
include_background_label: bool = False,
):
"""Inits :class:`RSMRIDataTransforms`.
@@ -247,6 +248,8 @@ def __init__(
Consecutive slices. Default is ``1``.
use_seed : bool, optional
Whether to use seed. Default is ``True``.
include_background_label : bool, optional
Whether to add an extra class for the background label. Default is ``False``.
"""
self.complex_data = complex_data

@@ -436,6 +439,8 @@ def __init__(
self.normalization, # type: ignore
]
)
self.consecutive_slices = consecutive_slices
self.include_background_label = include_background_label

self.cropping = Composer([self.cropping]) # type: ignore
self.normalization = Composer([self.normalization]) # type: ignore
@@ -574,6 +579,32 @@ def __call__(
segmentation_labels = segmentation_labels.float()
segmentation_labels = torch.abs(segmentation_labels)

if self.include_background_label:
if self.consecutive_slices > 1:
segmentation_labels_bg = torch.zeros(
(segmentation_labels.shape[0], segmentation_labels.shape[2], segmentation_labels.shape[3])
)
segmentation_labels_new = torch.zeros(
(
segmentation_labels.shape[0],
segmentation_labels.shape[1] + 1,
segmentation_labels.shape[2],
segmentation_labels.shape[3],
)
)
for i in range(target_reconstruction.shape[0]):
idx_background = torch.where(torch.sum(segmentation_labels[i], dim=0) == 0)
segmentation_labels_bg[i][idx_background] = 1
segmentation_labels_new[i] = torch.concat(
(segmentation_labels_bg[i].unsqueeze(0), segmentation_labels[i]), dim=0
)
segmentation_labels = segmentation_labels_new
else:
segmentation_labels_bg = torch.zeros((segmentation_labels.shape[-2], segmentation_labels.shape[-1]))
idx_background = torch.where(torch.sum(segmentation_labels, dim=0) == 0)
segmentation_labels_bg[idx_background] = 1
segmentation_labels = torch.concat((segmentation_labels_bg.unsqueeze(0), segmentation_labels), dim=0)

attrs.update(
self.__parse_normalization_vars__(
kspace_pre_normalization_vars,
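Note: a toy example (assumed shapes, single-slice case) of the background-label construction above:

import torch

# One-hot segmentation with two foreground classes on a 4x4 grid.
segmentation_labels = torch.zeros(2, 4, 4)
segmentation_labels[0, :2, :2] = 1  # class 1
segmentation_labels[1, 2:, 2:] = 1  # class 2

# Background channel: 1 wherever no foreground class is present.
segmentation_labels_bg = torch.zeros(segmentation_labels.shape[-2], segmentation_labels.shape[-1])
idx_background = torch.where(torch.sum(segmentation_labels, dim=0) == 0)
segmentation_labels_bg[idx_background] = 1
segmentation_labels = torch.concat((segmentation_labels_bg.unsqueeze(0), segmentation_labels), dim=0)

print(segmentation_labels.shape)  # torch.Size([3, 4, 4])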
5 changes: 4 additions & 1 deletion atommic/collections/segmentation/losses/__init__.py
@@ -1,5 +1,8 @@
# coding=utf-8
__author__ = "Dimitris Karkalousos"

from atommic.collections.segmentation.losses.cross_entropy import CrossEntropyLoss # noqa: F401
from atommic.collections.segmentation.losses.cross_entropy import BinaryCrossEntropyLoss # noqa F401
from atommic.collections.segmentation.losses.cross_entropy import CategoricalCrossEntropyLoss # noqa: F401
from atommic.collections.segmentation.losses.dice import Dice # noqa: F401
from atommic.collections.segmentation.losses.dice import GeneralisedDice # noqa: F401
from atommic.collections.segmentation.losses.focal import FocalLoss # noqa: F401
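Note: after this commit the new loss classes are importable from the package; constructor arguments are defined in the respective modules and not shown here:

from atommic.collections.segmentation.losses import (
    BinaryCrossEntropyLoss,
    CategoricalCrossEntropyLoss,
    Dice,
    FocalLoss,
    GeneralisedDice,
)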