diff --git a/.gitignore b/.gitignore index 6500cd24..d13c4505 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class - +**.DS_Store # C extensions *.so diff --git a/atommic/collections/multitask/rs/nn/mtlrs.py b/atommic/collections/multitask/rs/nn/mtlrs.py index 81daa2d4..56f92e51 100644 --- a/atommic/collections/multitask/rs/nn/mtlrs.py +++ b/atommic/collections/multitask/rs/nn/mtlrs.py @@ -2,7 +2,7 @@ __author__ = "Dimitris Karkalousos" from typing import Dict, List, Tuple, Union - +import warnings import torch from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer @@ -12,6 +12,8 @@ from atommic.collections.multitask.rs.nn.base import BaseMRIReconstructionSegmentationModel from atommic.collections.multitask.rs.nn.mtlrs_base.mtlrs_block import MTLRSBlock from atommic.core.classes.common import typecheck +from atommic.collections.multitask.rs.nn.mtlrs_base.task_attention_module import TaskAttentionalModule +from atommic.collections.multitask.rs.nn.mtlrs_base.spatially_adaptive_semantic_guidance_module import SASG __all__ = ["MTLRS"] @@ -108,8 +110,51 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): ) self.task_adaption_type = cfg_dict.get("task_adaption_type", "multi_task_learning") + self.attention_module = cfg_dict.get("attention_module", False) + self.attention_module_kernel_size = cfg_dict.get("attention_module_kernel_size", 3) + self.attention_module_padding = cfg_dict.get("attention_module_padding", 1) + self.attention_module_remove_background = cfg_dict.get("attention_module_remove_background", True) + self.attention_module_segmentation_input = ( + self.segmentation_module_output_channels - 1 + if self.attention_module_remove_background + else self.segmentation_module_output_channels + ) + if self.attention_module == "SemanticGuidanceModule" and self.task_adaption_type == "multi_task_learning": + self.task_adaption_type = "multi_task_learning_softmax" + warnings.warn( + "Logits cannot be used with the Spatially-Adaptive Semantic Guidance attention module." + " Logits will be transformed into probabilities with softmax." + f" task_adaption_type: {self.task_adaption_type}" + ) + if self.attention_module == "TaskAttentionModule": + self.attention_module_block = torch.nn.ModuleList( + [ + TaskAttentionalModule( + channels_in=cfg_dict.get("reconstruction_module_conv_filters")[0], + kernel_size=self.attention_module_kernel_size, + padding=self.attention_module_padding, + ) + for _ in range( + self.rs_cascades - 1 + ) # TODO: range is set to (rs_cascades-1) since the last cascade does not need an attention module, as its output is not used. + ] + ) + if self.attention_module == "SemanticGuidanceModule": + self.attention_module_block = torch.nn.ModuleList( + [ + SASG( + channels_rec=cfg_dict.get("reconstruction_module_conv_filters")[0], + channels_seg=self.attention_module_segmentation_input, + kernel_size=self.attention_module_kernel_size, + padding=self.attention_module_padding, + ) + for _ in range( + self.rs_cascades - 1 + ) # TODO: range is set to (rs_cascades-1) since the last cascade does not need an attention module, as its output is not used. + ] + ) - # pylint: disable=arguments-differ + # pylint: disable = arguments-differ @typecheck() def forward( self, @@ -146,7 +191,7 @@ def forward( Tuple containing the predicted reconstruction and segmentation.
""" pred_reconstructions = [] - for cascade in self.rs_module: + for c, cascade in enumerate(self.rs_module): pred_reconstruction, pred_segmentation, hx = cascade( y=y, sensitivity_maps=sensitivity_maps, @@ -158,8 +203,7 @@ def forward( ) pred_reconstructions.append(pred_reconstruction) init_reconstruction_pred = pred_reconstruction[-1][-1] - - if self.task_adaption_type == "multi_task_learning": + if self.task_adaption_type == "multi_task_learning" and c != self.rs_cascades - 1: hidden_states = [ torch.cat( [torch.abs(init_reconstruction_pred.unsqueeze(self.coil_dim) * pred_segmentation)] @@ -173,8 +217,49 @@ def forward( if self.consecutive_slices > 1: hx = [x.unsqueeze(1) for x in hx] + if self.task_adaption_type == "multi_task_learning_softmax" and c != self.rs_cascades - 1: + if self.consecutive_slices > 1: + pred_segmentation_soft = torch.softmax(pred_segmentation, dim=2) + else: + pred_segmentation_soft = torch.softmax(pred_segmentation, dim=1) + if self.attention_module == "SemanticGuidanceModule" and self.attention_module_remove_background: + hidden_states = [ + pred_segmentation_soft[:, 1:] for _ in self.reconstruction_module_recurrent_filters + ] + elif self.attention_module == "SemanticGuidanceModule" and not self.attention_module_remove_background: + hidden_states = [pred_segmentation_soft for _ in self.reconstruction_module_recurrent_filters] + elif self.attention_module_remove_background: + hidden_states = [ + torch.cat( + [ + torch.abs(init_reconstruction_pred.unsqueeze(self.coil_dim)) + * torch.sum(pred_segmentation_soft[..., 1:, :, :], dim=1, keepdim=True) + ] + * f, + dim=self.coil_dim, + ) + for f in self.reconstruction_module_recurrent_filters + if f != 0 + ] + else: + hidden_states = [ + torch.cat( + [ + torch.abs(init_reconstruction_pred.unsqueeze(self.coil_dim)) + * torch.sum(pred_segmentation_soft, dim=1, keepdim=True) + ] + * f, + dim=self.coil_dim, + ) + for f in self.reconstruction_module_recurrent_filters + if f != 0 + ] + # Check if the concatenated hidden states are the same size as the hidden state of the RNN - if hidden_states[0].shape[self.coil_dim] != hx[0].shape[self.coil_dim]: + if ( + hidden_states[0].shape[self.coil_dim] != hx[0].shape[self.coil_dim] + and not self.attention_module == "SemanticGuidanceModule" + ): prev_hidden_states = hidden_states hidden_states = [] for hs in prev_hidden_states: @@ -185,8 +270,10 @@ def forward( dim=self.coil_dim, ) hidden_states.append(new_hidden_state) - - hx = [hx[i] + hidden_states[i] for i in range(len(hx))] + if self.attention_module: + hx = [self.attention_module_block[c](hx[i], hidden_states[i]) for i in range(len(hx))] + else: + hx = [hx[i] + hidden_states[i] for i in range(len(hx))] init_reconstruction_pred = torch.view_as_real(init_reconstruction_pred) @@ -284,19 +371,24 @@ def compute_reconstruction_loss(t, p, s): return loss_func(t, p) - if self.accumulate_predictions: + if self.reconstruction_module_accumulate_predictions: rs_cascades_weights = torch.logspace(-1, 0, steps=len(prediction)).to(target.device) rs_cascades_loss = [] for rs_cascade_pred in prediction: cascades_weights = torch.logspace(-1, 0, steps=len(rs_cascade_pred)).to(target.device) cascades_loss = [] for cascade_pred in rs_cascade_pred: - time_steps_weights = torch.logspace(-1, 0, steps=self.time_steps).to(target.device) + time_steps_weights = torch.logspace(-1, 0, steps=self.reconstruction_module_time_steps).to( + target.device + ) time_steps_loss = [ compute_reconstruction_loss(target, time_step_pred, sensitivity_maps) for 
time_step_pred in cascade_pred ] - cascade_loss = sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / self.time_steps + cascade_loss = ( + sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) + / self.reconstruction_module_time_steps + ) cascades_loss.append(cascade_loss) rs_cascade_loss = sum(x * w for x, w in zip(cascades_loss, cascades_weights)) / len(rs_cascade_pred) rs_cascades_loss.append(rs_cascade_loss) diff --git a/atommic/collections/multitask/rs/nn/mtlrs_base/spatially_adaptive_semantic_guidance_module.py b/atommic/collections/multitask/rs/nn/mtlrs_base/spatially_adaptive_semantic_guidance_module.py new file mode 100644 index 00000000..b2089c05 --- /dev/null +++ b/atommic/collections/multitask/rs/nn/mtlrs_base/spatially_adaptive_semantic_guidance_module.py @@ -0,0 +1,111 @@ +# coding=utf-8 +__author__ = "Tim Paquaij" +import torch +import torch.nn as nn + + +class SASG(nn.Module): + """Spatially-Adaptive Semantic Guidance (SASG). + An attention module that uses segmentation probabilities to enhance reconstruction features. + Built based on the SASG module described in: + https://www.sciencedirect.com/science/article/abs/pii/S0169260724000415?via%3Dihub + """ + + def __init__( + self, + channels_rec: int, + channels_seg: int, + kernel_size: int | tuple = (1, 1), + padding: int | tuple = 0, + ): + """Inits :class:`SASG`. + + Parameters + ---------- + channels_rec : int + Number of reconstruction feature channels. + channels_seg : int + Number of segmentation classes. + kernel_size : int | tuple, optional + Size of the convolutional kernel. Default is (1, 1). + padding : int | tuple, optional + Padding around all four sides of the input. Default is 0. + """ + super().__init__() + self.conv = nn.Conv2d(channels_rec, channels_rec, kernel_size=kernel_size, padding=padding) + self.spade = SPADE(channels_rec, channels_seg, kernel_size, padding) + self.act = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, rec_features: torch.Tensor, seg_features: torch.Tensor) -> torch.Tensor: + """Forward :class:`SASG`. + + Parameters + ---------- + rec_features : torch.Tensor + Tensor of the reconstruction features with size [batch_size, feature_channels, height, width] + seg_features : torch.Tensor + Tensor of the segmentation features with size [batch_size, nr_classes, height, width] + + Returns + ------- + new_rec_features : torch.Tensor + Tensor of the optimised reconstruction features with size [batch_size, feature_channels, height, width] + """ + + hidden_layers_features_s = self.spade(rec_features, seg_features) + hidden_layers_features_s = self.conv(self.act(hidden_layers_features_s)) + hidden_layers_features_s = self.spade(hidden_layers_features_s, seg_features) + hidden_layers_features_s = self.conv(self.act(hidden_layers_features_s)) + new_rec_features = hidden_layers_features_s + rec_features + return new_rec_features + + +class SPADE(nn.Module): + + def __init__( + self, + channels_rec: int, + channels_seg: int, + kernel_size: int | tuple = (1, 1), + padding: int | tuple = 0, + ): + """Inits :class:`SPADE`. + + Parameters + ---------- + channels_rec : int + Number of reconstruction feature channels. + channels_seg : int + Number of segmentation classes. + kernel_size : int | tuple, optional + Size of the convolutional kernel. Default is (1, 1). + padding : int | tuple, optional + Padding around all four sides of the input. Default is 0.
+ """ + super().__init__() + self.conv_1 = nn.Conv2d(channels_seg, channels_seg, kernel_size=kernel_size, padding=padding) + self.conv_2 = nn.Conv2d(channels_seg, channels_rec, kernel_size=kernel_size, padding=padding) + self.instance = nn.InstanceNorm2d(channels_rec) + self.act = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, rec_features, seg_features) -> torch.Tensor: + """Forward :class:`SPADE`. + + Parameters + ---------- + rec_features : torch.Tensor + Tensor of the reconstruction features with size [batch_size, feature_channels, height, width] + seg_features : torch.Tensor + Tensor of the segmentation features with size [batch_size, nr_classes, height, width] + + Returns + ------- + new_rec_features : torch.Tensor + Tensor of the optimised reconstruction features with size [batch_size, feature_channels, height, width] + """ + hidden_layers_features = self.instance(rec_features) + segmentation_prob = self.act(self.conv_1(seg_features)) + new_rec_features = torch.mul(self.conv_2(segmentation_prob), hidden_layers_features) + self.conv_2( + segmentation_prob + ) + return new_rec_features diff --git a/atommic/collections/multitask/rs/nn/mtlrs_base/task_attention_module.py b/atommic/collections/multitask/rs/nn/mtlrs_base/task_attention_module.py new file mode 100644 index 00000000..6875b85d --- /dev/null +++ b/atommic/collections/multitask/rs/nn/mtlrs_base/task_attention_module.py @@ -0,0 +1,130 @@ +# coding=utf-8 +__author__ = "Tim Paquaij" +import torch +import torch.nn as nn + + +class TaskAttentionalModule(nn.Module): + """TaskAttentionalModule + An attention model that utilises two tensor with identical number of feature channels to enhance the common features. + + Built based on the TAM module described in: + https://openaccess.thecvf.com/content_ECCV_2018/papers/Zhenyu_Zhang_Joint_Task-Recursive_Learning_ECCV_2018_paper.pdf + """ + + def __init__( + self, + channels_in: int, + kernel_size: int | tuple = (1, 1), + padding: int | tuple = 0, + ): + """Inits :class:`TaskAttentionalModule`. + + Parameters + ---------- + channels_in : int + Number of feature channels. + kernel_size : int | tuple, optional + Size of the convolutional kernel. Default is 1 + padding : int | tuple, optional + Padding around all four sizes of the input. Default is 0. + """ + super().__init__() + self.balance_conv1 = nn.Conv2d( + int(channels_in * 2), int(channels_in), kernel_size=kernel_size, padding=padding + ) + self.balance_conv2 = nn.Conv2d(int(channels_in), int(channels_in), kernel_size=kernel_size, padding=padding) + self.residual_block = ResidualBlock(int(channels_in), int(channels_in)) + self.fc = nn.Conv2d(int(channels_in * 2), channels_in, kernel_size=kernel_size, padding=padding) + + def forward(self, rec_features: torch.Tensor, seg_features: torch.Tensor) -> torch.Tensor: + """Forward :class:`TaskAttentionModule`. 
+ + Parameters + ---------- + rec_features : torch.Tensor + Tensor of the reconstruction features with size [batch_size, feature_channels, height, width] + seg_features : torch.Tensor + Tensor of the segmentation features with size [batch_size, feature_channels, height, width] + + Returns + ------- + new_rec_features : torch.Tensor + Tensor of the optimised reconstruction features with size [batch_size, feature_channels, height, width] + """ + # Balance unit + concat_features = torch.cat((rec_features, seg_features), dim=1) + balance_tensor = torch.sigmoid(self.balance_conv1(concat_features)) + balanced_output = self.balance_conv2(balance_tensor * rec_features + (1 - balance_tensor) * seg_features) + # Residual block followed by a sigmoid to produce the spatial attention map + res_block = torch.sigmoid(self.residual_block(balanced_output)) + # Generate gated features + gated_rec_features = (1 + res_block) * rec_features + gated_segmentation_features = (1 + res_block) * seg_features + # Concatenate and apply convolutional layer + concatenated_features = torch.cat((gated_rec_features, gated_segmentation_features), dim=1) + output = self.fc(concatenated_features) + return output + + +class ResidualBlock(nn.Module): + """ResidualBlock + A residual block with batch normalization and ReLU activation functions. + Copied and adapted from: + https://github.com/tengshaofeng/ResidualAttentionNetwork-pytorch/blob/master/Residual-Attention-Network/model/basic_layers.py + """ + + def __init__(self, input_channels: int, output_channels: int, stride: int | tuple = 1): + """Inits :class:`ResidualBlock`. + + Parameters + ---------- + input_channels : int + Input number of feature channels. + output_channels : int + Output number of feature channels. + stride : int | tuple, optional + Stride of the convolution. Default is 1. + """ + super().__init__() + self.input_channels = input_channels + self.output_channels = output_channels + self.stride = stride + self.bn1 = nn.BatchNorm2d(input_channels) + self.relu = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(input_channels, int(output_channels / 4), 1, 1, bias=False) + self.bn2 = nn.BatchNorm2d(int(output_channels / 4)) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(int(output_channels / 4), int(output_channels / 4), 3, stride, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(int(output_channels / 4)) + self.relu = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(int(output_channels / 4), output_channels, 1, 1, bias=False) + self.conv4 = nn.Conv2d(input_channels, output_channels, 1, stride, bias=False) + + def forward(self, input_features) -> torch.Tensor: + """Forward :class:`ResidualBlock`.
+ + Parameters + ---------- + input_features : torch.Tensor + Tensor of the combined features with size [batch_size, feature_channels, height, width] + + Returns + ------- + output_features : torch.Tensor + Tensor of the optimised combined features with size [batch_size, feature_channels, height, width] + """ + residual = input_features + out = self.bn1(input_features) + out1 = self.relu(out) + out = self.conv1(out1) + out = self.bn2(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn3(out) + out = self.relu(out) + out = self.conv3(out) + if (self.input_channels != self.output_channels) or (self.stride != 1): + residual = self.conv4(out1) + output_features = out + residual + return output_features diff --git a/docs/source/mri/collections.rst b/docs/source/mri/collections.rst index 1dc4641e..a551d699 100644 --- a/docs/source/mri/collections.rst +++ b/docs/source/mri/collections.rst @@ -251,6 +251,90 @@ Example configuration: normalize_segmentation_output: true unnormalize_loss_inputs: false unnormalize_log_outputs: false + +Multi-Task Learning for MRI Reconstruction and Segmentation with attention module +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Multi-Task Learning for MRI Reconstruction and Segmentation with attention module +(:class:`~atommic.collections.multitask.rs.nn.mtlrs.MTLRS`) + +Example configuration: + +.. code-block:: bash + + model: + model_name: MTLRS + joint_reconstruction_segmentation_module_cascades: 5 + task_adaption_type: multi_task_learning_softmax + use_reconstruction_module: true + reconstruction_module_recurrent_layer: IndRNN + reconstruction_module_conv_filters: + - 64 + - 64 + - 2 + reconstruction_module_conv_kernels: + - 5 + - 3 + - 3 + reconstruction_module_conv_dilations: + - 1 + - 2 + - 1 + reconstruction_module_conv_bias: + - true + - true + - false + reconstruction_module_recurrent_filters: + - 64 + - 64 + - 0 + reconstruction_module_recurrent_kernels: + - 1 + - 1 + - 0 + reconstruction_module_recurrent_dilations: + - 1 + - 1 + - 0 + reconstruction_module_recurrent_bias: + - true + - true + - false + reconstruction_module_depth: 2 + reconstruction_module_time_steps: 8 + reconstruction_module_conv_dim: 2 + reconstruction_module_num_cascades: 1 + reconstruction_module_dimensionality: 2 + reconstruction_module_no_dc: true + reconstruction_module_keep_prediction: true + reconstruction_module_accumulate_predictions: true + segmentation_module: AttentionUNet + segmentation_module_input_channels: 1 + segmentation_module_output_channels: 2 + segmentation_module_channels: 64 + segmentation_module_pooling_layers: 2 + segmentation_module_dropout: 0.0 + attention_module: SemanticGuidanceModule + attention_module_kernel_size: 3 + attention_module_padding: 1 + attention_module_remove_background: true + # task & dataset related parameters + coil_combination_method: SENSE + coil_dim: 1 + complex_valued_type: stacked # stacked, complex_abs, complex_sqrt_abs + complex_data: true + consecutive_slices: 1 + dimensionality: 2 + estimate_coil_sensitivity_maps_with_nn: false + fft_centered: false + fft_normalization: backward + spatial_dims: + - -2 + - -1 + magnitude_input: true + normalization_type: minmax + normalize_segmentation_output: true + unnormalize_loss_inputs: false + unnormalize_log_outputs: false Reconstruction Segmentation method using UNet ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 95865e4d..ba9f35a4 100644 --- a/requirements/requirements.txt +++
b/requirements/requirements.txt @@ -1,7 +1,7 @@ defusedxml>=0.7.1 einops>=0.5.0 h5py==3.9.0 -huggingface_hub +huggingface_hub<=0.20.3 hydra-core>1.3,<=1.3.2 nibabel==5.1.0 numba diff --git a/tests/collections/multitask/rs/models/test_mtlrs.py b/tests/collections/multitask/rs/models/test_mtlrs.py index aa88845d..311fc2aa 100644 --- a/tests/collections/multitask/rs/models/test_mtlrs.py +++ b/tests/collections/multitask/rs/models/test_mtlrs.py @@ -222,8 +222,157 @@ "max_steps": -1, }, ), + ( + [1, 3, 32, 16, 2], + { + "use_reconstruction_module": True, + "task_adaption_type": "multi_task_learning_softmax", + "joint_reconstruction_segmentation_module_cascades": 5, + "reconstruction_module_recurrent_layer": "IndRNN", + "reconstruction_module_conv_filters": [64, 64, 2], + "reconstruction_module_conv_kernels": [5, 3, 3], + "reconstruction_module_conv_dilations": [1, 2, 1], + "reconstruction_module_conv_bias": [True, True, False], + "reconstruction_module_recurrent_filters": [64, 64, 0], + "reconstruction_module_recurrent_kernels": [1, 1, 0], + "reconstruction_module_recurrent_dilations": [1, 1, 0], + "reconstruction_module_recurrent_bias": [True, True, False], + "reconstruction_module_depth": 2, + "reconstruction_module_conv_dim": 2, + "reconstruction_module_time_steps": 8, + "reconstruction_module_num_cascades": 1, + "reconstruction_module_dimensionality": 2, + "reconstruction_module_accumulate_predictions": True, + "reconstruction_module_no_dc": True, + "reconstruction_module_keep_prediction": True, + "reconstruction_loss": {"l1": 1.0}, + "segmentation_module": "ConvLayer", + "segmentation_module_input_channels": 1, + "segmentation_module_output_channels": 4, + "segmentation_module_channels": 64, + "segmentation_module_pooling_layers": 4, + "segmentation_module_dropout": 0.0, + "segmentation_loss": {"dice": 1.0}, + "reconstruction_module_accumulate_predictions": True, + "attention_module": "SemanticGuidanceModule", + "attention_module_kernel_size": 3, + "attention_module_padding": 1, + "attention_module_remove_background": True, + "dice_loss_include_background": False, + "dice_loss_to_onehot_y": False, + "dice_loss_sigmoid": True, + "dice_loss_softmax": False, + "dice_loss_other_act": None, + "dice_loss_squared_pred": False, + "dice_loss_jaccard": False, + "dice_loss_reduction": "mean", + "dice_loss_smooth_nr": 1, + "dice_loss_smooth_dr": 1, + "dice_loss_batch": True, + "consecutive_slices": 1, + "coil_combination_method": "SENSE", + "magnitude_input": True, + "use_sens_net": False, + "fft_centered": False, + "fft_normalization": "backward", + "spatial_dims": [-2, -1], + "coil_dim": 1, + "dimensionality": 2, + }, + [0.08], + [4], + 2, + 4, + { + "strategy": "ddp", + "accelerator": "cpu", + "num_nodes": 1, + "max_epochs": 20, + "precision": 32, + "enable_checkpointing": False, + "logger": False, + "log_every_n_steps": 50, + "check_val_every_n_epoch": -1, + "max_steps": -1, + }, + ), + ( + [1, 3, 32, 16, 2], + { + "use_reconstruction_module": True, + "task_adaption_type": "multi_task_learning", + "joint_reconstruction_segmentation_module_cascades": 5, + "reconstruction_module_recurrent_layer": "IndRNN", + "reconstruction_module_conv_filters": [64, 64, 2], + "reconstruction_module_conv_kernels": [5, 3, 3], + "reconstruction_module_conv_dilations": [1, 2, 1], + "reconstruction_module_conv_bias": [True, True, False], + "reconstruction_module_recurrent_filters": [64, 64, 0], + "reconstruction_module_recurrent_kernels": [1, 1, 0], + "reconstruction_module_recurrent_dilations": [1, 1, 0], + 
"reconstruction_module_recurrent_bias": [True, True, False], + "reconstruction_module_depth": 2, + "reconstruction_module_conv_dim": 2, + "reconstruction_module_time_steps": 8, + "reconstruction_module_num_cascades": 1, + "reconstruction_module_dimensionality": 2, + "reconstruction_module_accumulate_predictions": True, + "reconstruction_module_no_dc": True, + "reconstruction_module_keep_prediction": True, + "reconstruction_loss": {"l1": 1.0}, + "segmentation_module": "ConvLayer", + "segmentation_module_input_channels": 1, + "segmentation_module_output_channels": 4, + "segmentation_module_channels": 64, + "segmentation_module_pooling_layers": 4, + "segmentation_module_dropout": 0.0, + "segmentation_loss": {"dice": 1.0}, + "reconstruction_module_accumulate_predictions": True, + "attention_module": "TaskAttentionModule", + "attention_module_kernel_size": 3, + "attention_module_padding": 1, + "attention_module_remove_background": True, + "dice_loss_include_background": False, + "dice_loss_to_onehot_y": False, + "dice_loss_sigmoid": True, + "dice_loss_softmax": False, + "dice_loss_other_act": None, + "dice_loss_squared_pred": False, + "dice_loss_jaccard": False, + "dice_loss_reduction": "mean", + "dice_loss_smooth_nr": 1, + "dice_loss_smooth_dr": 1, + "dice_loss_batch": True, + "consecutive_slices": 1, + "coil_combination_method": "SENSE", + "magnitude_input": True, + "use_sens_net": False, + "fft_centered": False, + "fft_normalization": "backward", + "spatial_dims": [-2, -1], + "coil_dim": 1, + "dimensionality": 2, + }, + [0.08], + [4], + 2, + 4, + { + "strategy": "ddp", + "accelerator": "cpu", + "num_nodes": 1, + "max_epochs": 20, + "precision": 32, + "enable_checkpointing": False, + "logger": False, + "log_every_n_steps": 50, + "check_val_every_n_epoch": -1, + "max_steps": -1, + }, + ), ], ) + def test_mtlmrirs(shape, cfg, center_fractions, accelerations, dimensionality, segmentation_classes, trainer): """ Test MultiTask Learning for accelerated-MRI Reconstruction & Segmentation with different parameters. 
@@ -281,20 +430,25 @@ def test_mtlmrirs(shape, cfg, center_fractions, accelerations, dimensionality, s output.sum(coil_dim), ) - if cfg.get("accumulate_predictions"): - try: - pred_reconstruction = next(pred_reconstruction) - except StopIteration: - pass - - if isinstance(pred_reconstruction, list): - pred_reconstruction = pred_reconstruction[-1] + if cfg.get("reconstruction_module_accumulate_predictions"): + if len(pred_reconstruction) != cfg.get("joint_reconstruction_segmentation_module_cascades"): + raise AssertionError("Number of predictions is not equal to the number of cascades") + if cfg.get("reconstruction_module_keep_prediction"): + if len(pred_reconstruction[0]) != cfg.get("reconstruction_module_num_cascades"): + raise AssertionError("Number of predictions is not equal to the number of cascades") + if cfg.get("reconstruction_module_keep_prediction"): + if len(pred_reconstruction[0][0]) != cfg.get("reconstruction_module_time_steps"): + raise AssertionError("Number of predictions is not equal to the number of intermediate predictions") + + if cfg.get("segmentation_module_accumulate_predictions"): + if len(pred_segmentation) != cfg.get("joint_reconstruction_segmentation_module_cascades"): + raise AssertionError("Number of segmentations is not equal to the number of cascades") - if isinstance(pred_reconstruction, list): + while isinstance(pred_reconstruction, list): pred_reconstruction = pred_reconstruction[-1] - if isinstance(pred_reconstruction, list): - pred_reconstruction = pred_reconstruction[-1] + while isinstance(pred_segmentation, list): + pred_segmentation = pred_segmentation[-1] if dimensionality == 3 or consecutive_slices > 1: x = x.reshape([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4], x.shape[5]])
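Reviewer note: the snippet below is a minimal, hypothetical usage sketch and is not part of the diff. It exercises the two new attention modules standalone, with tensor shapes matching what `MTLRS.forward` feeds them; the channel and class counts (64 reconstruction feature channels, 4 segmentation classes) are assumptions borrowed from the test configurations above.

```python
# Hypothetical sketch: standalone forward passes through SASG and TaskAttentionalModule.
import torch

from atommic.collections.multitask.rs.nn.mtlrs_base.spatially_adaptive_semantic_guidance_module import SASG
from atommic.collections.multitask.rs.nn.mtlrs_base.task_attention_module import TaskAttentionalModule

batch, height, width = 1, 32, 16
rec_channels = 64         # assumed: reconstruction_module_recurrent_filters[0]
seg_classes = 4           # assumed: segmentation_module_output_channels
remove_background = True  # attention_module_remove_background

rec_hidden = torch.randn(batch, rec_channels, height, width)  # RNN hidden state (hx)
seg_logits = torch.randn(batch, seg_classes, height, width)   # segmentation prediction
seg_probs = torch.softmax(seg_logits, dim=1)
if remove_background:
    seg_probs = seg_probs[:, 1:]  # drop the background class, as in MTLRS.forward

# SemanticGuidanceModule path: segmentation probabilities modulate the hidden state via SPADE.
sasg = SASG(channels_rec=rec_channels, channels_seg=seg_probs.shape[1], kernel_size=3, padding=1)
assert sasg(rec_hidden, seg_probs).shape == rec_hidden.shape

# TaskAttentionModule path: both inputs must carry the same number of feature channels,
# so the segmentation-derived hidden state is built with rec_channels channels as well.
seg_hidden = torch.randn(batch, rec_channels, height, width)
tam = TaskAttentionalModule(channels_in=rec_channels, kernel_size=3, padding=1)
assert tam(rec_hidden, seg_hidden).shape == rec_hidden.shape
```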