From 66f9972cf2aa80d86a9f5d2d260f215753051ab1 Mon Sep 17 00:00:00 2001 From: Tim Paquaij <71269471+TimPaquaij@users.noreply.github.com> Date: Wed, 9 Oct 2024 10:35:25 +0200 Subject: [PATCH] initial commit weighted loss **.DS_Store Changed test script for accumulate test_script check and is working Update base.py process_segmentation_loss Update mtlrs update test_mtlrs.py update requirements.txt Update base.py fixes style base.py --- .gitignore | 2 +- atommic/collections/multitask/rs/nn/base.py | 27 ++++-- atommic/collections/multitask/rs/nn/mtlrs.py | 63 ++++++++++-- requirements/requirements.txt | 2 +- .../multitask/rs/models/test_mtlrs.py | 95 +++++++++++++++++-- 5 files changed, 159 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 6500cd24..d13c4505 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class - +**.DS_Store # C extensions *.so diff --git a/atommic/collections/multitask/rs/nn/base.py b/atommic/collections/multitask/rs/nn/base.py index 5a0b506a..e12a100f 100644 --- a/atommic/collections/multitask/rs/nn/base.py +++ b/atommic/collections/multitask/rs/nn/base.py @@ -530,7 +530,9 @@ def compute_reconstruction_loss(t, p, s): return compute_reconstruction_loss(target, prediction, sensitivity_maps) - def process_segmentation_loss(self, target: torch.Tensor, prediction: torch.Tensor, attrs: Dict) -> Dict: + def process_segmentation_loss( + self, target: torch.Tensor, prediction: torch.Tensor, attrs: Dict, loss_func: torch.nn.Module + ) -> Dict: """Processes the segmentation loss. Parameters @@ -551,14 +553,11 @@ def process_segmentation_loss(self, target: torch.Tensor, prediction: torch.Tens """ if self.unnormalize_loss_inputs: target, prediction, _ = self.__unnormalize_for_loss_or_log__(target, prediction, None, attrs, attrs["r"]) - losses = {} - for name, loss_func in self.segmentation_losses.items(): - loss = loss_func(target, prediction) - if isinstance(loss, tuple): - # In case of the dice loss, the loss is a tuple of the form (dice, dice loss) - loss = loss[1] - losses[name] = loss - return self.total_segmentation_loss(**losses) * self.total_segmentation_loss_weight + loss = loss_func(target, prediction) + if isinstance(loss, tuple): + # In case of the dice loss, the loss is a tuple of the form (dice, dice loss) + loss = loss[1] + return loss def __compute_loss__( self, @@ -605,7 +604,15 @@ def __compute_loss__( batch_size, slices = target_segmentation.shape[:2] target_segmentation = target_segmentation.reshape(batch_size * slices, *target_segmentation.shape[2:]) - segmentation_loss = self.process_segmentation_loss(target_segmentation, predictions_segmentation, attrs) + losses = {} + for name, loss_func in self.segmentation_losses.items(): + losses[name] = self.process_segmentation_loss( + target_segmentation, + predictions_segmentation, + attrs, + loss_func, + ) + segmentation_loss = self.total_segmentation_loss(**losses) if self.use_reconstruction_module: if predictions_reconstruction_n2r is not None and not attrs["n2r_supervised"]: diff --git a/atommic/collections/multitask/rs/nn/mtlrs.py b/atommic/collections/multitask/rs/nn/mtlrs.py index 81daa2d4..8bde5d4c 100644 --- a/atommic/collections/multitask/rs/nn/mtlrs.py +++ b/atommic/collections/multitask/rs/nn/mtlrs.py @@ -1,6 +1,7 @@ # coding=utf-8 __author__ = "Dimitris Karkalousos" - +__editor__ = "Tim Paquaij" +__editor__date__ = "2024-10-09" from typing import Dict, List, Tuple, Union import torch @@ -48,6 +49,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.reconstruction_module_accumulate_predictions = cfg_dict.get( "reconstruction_module_accumulate_predictions" ) + self.segmentation_module_accumulate_predictions = cfg_dict.get("segmentation_module_accumulate_predictions") conv_dim = cfg_dict.get("reconstruction_module_conv_dim") reconstruction_module_params = { "num_cascades": self.reconstruction_module_num_cascades, @@ -146,6 +148,7 @@ def forward( Tuple containing the predicted reconstruction and segmentation. """ pred_reconstructions = [] + pred_segmentations = [] for cascade in self.rs_module: pred_reconstruction, pred_segmentation, hx = cascade( y=y, @@ -158,6 +161,7 @@ def forward( ) pred_reconstructions.append(pred_reconstruction) init_reconstruction_pred = pred_reconstruction[-1][-1] + pred_segmentations.append(pred_segmentation) if self.task_adaption_type == "multi_task_learning": hidden_states = [ @@ -190,7 +194,7 @@ def forward( init_reconstruction_pred = torch.view_as_real(init_reconstruction_pred) - return pred_reconstructions, pred_segmentation + return pred_reconstructions, pred_segmentations def process_reconstruction_loss( # noqa: MC0001 self, @@ -284,25 +288,68 @@ def compute_reconstruction_loss(t, p, s): return loss_func(t, p) - if self.accumulate_predictions: + if self.reconstruction_module_accumulate_predictions: rs_cascades_weights = torch.logspace(-1, 0, steps=len(prediction)).to(target.device) rs_cascades_loss = [] for rs_cascade_pred in prediction: cascades_weights = torch.logspace(-1, 0, steps=len(rs_cascade_pred)).to(target.device) cascades_loss = [] for cascade_pred in rs_cascade_pred: - time_steps_weights = torch.logspace(-1, 0, steps=self.time_steps).to(target.device) + time_steps_weights = torch.logspace(-1, 0, steps=len(cascade_pred)).to(target.device) time_steps_loss = [ compute_reconstruction_loss(target, time_step_pred, sensitivity_maps) for time_step_pred in cascade_pred ] - cascade_loss = sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / self.time_steps - cascades_loss.append(cascade_loss) - rs_cascade_loss = sum(x * w for x, w in zip(cascades_loss, cascades_weights)) / len(rs_cascade_pred) + + cascade_loss = sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / sum( + time_steps_weights + ) + cascades_loss.append(cascade_loss) + rs_cascade_loss = sum(x * w for x, w in zip(cascades_loss, cascades_weights)) / sum(cascades_weights) rs_cascades_loss.append(rs_cascade_loss) - loss = sum(x * w for x, w in zip(rs_cascades_loss, rs_cascades_weights)) / len(prediction) + loss = sum(x * w for x, w in zip(rs_cascades_loss, rs_cascades_weights)) / sum(rs_cascades_weights) else: # keep the last prediction of the last cascade of the last rs cascade prediction = prediction[-1][-1][-1] loss = compute_reconstruction_loss(target, prediction, sensitivity_maps) return loss + + def process_segmentation_loss( + self, target: torch.Tensor, prediction: torch.Tensor, attrs: Dict, loss_func: torch.nn.Module + ) -> Dict: + """Processes the segmentation loss. + + Parameters + ---------- + target : torch.Tensor + Target data of shape [batch_size, nr_classes, n_x, n_y]. + prediction : torch.Tensor + Prediction of shape [batch_size, nr_classes, n_x, n_y]. + attrs : Dict + Attributes of the data with pre normalization values. + + Returns + ------- + Dict + Dictionary containing the (multiple) loss values. For example, if the cross entropy loss and the dice loss + are used, the dictionary will contain the keys ``cross_entropy_loss``, ``dice_loss``, and + (combined) ``segmentation_loss``. + """ + if self.unnormalize_loss_inputs: + target, prediction, _ = self.__unnormalize_for_loss_or_log__(target, prediction, None, attrs, attrs["r"]) + + if self.segmentation_module_accumulate_predictions: + rs_cascades_weights = torch.logspace(-1, 0, steps=len(prediction)).to(target.device) + rs_cascades_loss = [] + for pred in prediction: + loss = loss_func(target, pred) + if isinstance(loss, tuple): + loss = loss[1] + rs_cascades_loss.append(loss) + loss = sum(x * w for x, w in zip(rs_cascades_loss, rs_cascades_weights)) / sum(rs_cascades_weights) + else: + prediction = prediction[-1] + loss = loss_func(target, prediction) + if isinstance(loss, tuple): + loss = loss[1][0] + return loss diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 95865e4d..ba9f35a4 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,7 +1,7 @@ defusedxml>=0.7.1 einops>=0.5.0 h5py==3.9.0 -huggingface_hub +huggingface_hub<=0.20.3 hydra-core>1.3,<=1.3.2 nibabel==5.1.0 numba diff --git a/tests/collections/multitask/rs/models/test_mtlrs.py b/tests/collections/multitask/rs/models/test_mtlrs.py index aa88845d..af719d6d 100644 --- a/tests/collections/multitask/rs/models/test_mtlrs.py +++ b/tests/collections/multitask/rs/models/test_mtlrs.py @@ -222,6 +222,76 @@ "max_steps": -1, }, ), + ( + [1, 3, 32, 16, 2], + { + "use_reconstruction_module": True, + "task_adaption_type": "multi_task_learning", + "joint_reconstruction_segmentation_module_cascades": 5, + "reconstruction_module_recurrent_layer": "IndRNN", + "reconstruction_module_conv_filters": [64, 64, 2], + "reconstruction_module_conv_kernels": [5, 3, 3], + "reconstruction_module_conv_dilations": [1, 2, 1], + "reconstruction_module_conv_bias": [True, True, False], + "reconstruction_module_recurrent_filters": [64, 64, 0], + "reconstruction_module_recurrent_kernels": [1, 1, 0], + "reconstruction_module_recurrent_dilations": [1, 1, 0], + "reconstruction_module_recurrent_bias": [True, True, False], + "reconstruction_module_depth": 2, + "reconstruction_module_conv_dim": 2, + "reconstruction_module_time_steps": 8, + "reconstruction_module_num_cascades": 5, + "reconstruction_module_dimensionality": 2, + "reconstruction_module_accumulate_predictions": True, + "reconstruction_module_no_dc": True, + "reconstruction_module_keep_prediction": True, + "reconstruction_loss": {"l1": 1.0}, + "segmentation_module": "ConvLayer", + "segmentation_module_input_channels": 1, + "segmentation_module_output_channels": 4, + "segmentation_module_channels": 64, + "segmentation_module_pooling_layers": 4, + "segmentation_module_dropout": 0.0, + "segmentation_module_accumulate_predictions": True, + "segmentation_loss": {"dice": 1.0}, + "dice_loss_include_background": False, + "dice_loss_to_onehot_y": False, + "dice_loss_sigmoid": True, + "dice_loss_softmax": False, + "dice_loss_other_act": None, + "dice_loss_squared_pred": False, + "dice_loss_jaccard": False, + "dice_loss_reduction": "mean", + "dice_loss_smooth_nr": 1, + "dice_loss_smooth_dr": 1, + "dice_loss_batch": True, + "consecutive_slices": 1, + "coil_combination_method": "SENSE", + "magnitude_input": True, + "use_sens_net": False, + "fft_centered": False, + "fft_normalization": "backward", + "spatial_dims": [-2, -1], + "coil_dim": 1, + "dimensionality": 2, + }, + [0.08], + [4], + 2, + 4, + { + "strategy": "ddp", + "accelerator": "cpu", + "num_nodes": 1, + "max_epochs": 20, + "precision": 32, + "enable_checkpointing": False, + "logger": False, + "log_every_n_steps": 50, + "check_val_every_n_epoch": -1, + "max_steps": -1, + }, + ), ], ) def test_mtlmrirs(shape, cfg, center_fractions, accelerations, dimensionality, segmentation_classes, trainer): @@ -281,20 +351,25 @@ def test_mtlmrirs(shape, cfg, center_fractions, accelerations, dimensionality, s output.sum(coil_dim), ) - if cfg.get("accumulate_predictions"): - try: - pred_reconstruction = next(pred_reconstruction) - except StopIteration: - pass + if cfg.get("reconstruction_module_accumulate_predictions"): + if len(pred_reconstruction) != cfg.get("joint_reconstruction_segmentation_module_cascades"): + raise AssertionError("Number of predictions are not equal to the number of cascades") + if cfg.get("reconstruction_module_keep_prediction"): + if len(pred_reconstruction[0]) != cfg.get("reconstruction_module_num_cascades"): + raise AssertionError("Number of predictions are not equal to the number of cascades") + if cfg.get("reconstruction_module_keep_prediction"): + if len(pred_reconstruction[0][0]) != cfg.get("reconstruction_module_time_steps"): + raise AssertionError("Number of predictions are not equal to the number of intermediate predictions") - if isinstance(pred_reconstruction, list): - pred_reconstruction = pred_reconstruction[-1] + if cfg.get("segmentation_module_accumulate_predictions"): + if len(pred_segmentation) != cfg.get("joint_reconstruction_segmentation_module_cascades"): + raise AssertionError(f"Number of segmentations are not equal to the number of cascades") - if isinstance(pred_reconstruction, list): + while isinstance(pred_reconstruction, list): pred_reconstruction = pred_reconstruction[-1] - if isinstance(pred_reconstruction, list): - pred_reconstruction = pred_reconstruction[-1] + while isinstance(pred_segmentation, list): + pred_segmentation = pred_segmentation[-1] if dimensionality == 3 or consecutive_slices > 1: x = x.reshape([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4], x.shape[5]])