
Commit

initial commit weighted loss
**.DS_Store

Changed test script for accumulate

test_script check and is working

Update base.py process_segmentation_loss

Update mtlrs

update test_mtlrs.py

update requirements.txt

Update base.py

fixes style base.py
TimPaquaij committed Nov 26, 2024
1 parent d34fcad commit 66f9972
Showing 5 changed files with 159 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,7 +2,7 @@
__pycache__/
*.py[cod]
*$py.class

**.DS_Store
# C extensions
*.so

27 changes: 17 additions & 10 deletions atommic/collections/multitask/rs/nn/base.py
@@ -530,7 +530,9 @@ def compute_reconstruction_loss(t, p, s):

return compute_reconstruction_loss(target, prediction, sensitivity_maps)

def process_segmentation_loss(self, target: torch.Tensor, prediction: torch.Tensor, attrs: Dict) -> Dict:
def process_segmentation_loss(
self, target: torch.Tensor, prediction: torch.Tensor, attrs: Dict, loss_func: torch.nn.Module
) -> Dict:
"""Processes the segmentation loss.
Parameters
@@ -551,14 +553,11 @@ def process_segmentation_loss(self, target: torch.Tensor, prediction: torch.Tens
"""
if self.unnormalize_loss_inputs:
target, prediction, _ = self.__unnormalize_for_loss_or_log__(target, prediction, None, attrs, attrs["r"])
losses = {}
for name, loss_func in self.segmentation_losses.items():
loss = loss_func(target, prediction)
if isinstance(loss, tuple):
# In case of the dice loss, the loss is a tuple of the form (dice, dice loss)
loss = loss[1]
losses[name] = loss
return self.total_segmentation_loss(**losses) * self.total_segmentation_loss_weight
loss = loss_func(target, prediction)
if isinstance(loss, tuple):
# In case of the dice loss, the loss is a tuple of the form (dice, dice loss)
loss = loss[1]
return loss

def __compute_loss__(
self,
@@ -605,7 +604,15 @@ def __compute_loss__(
batch_size, slices = target_segmentation.shape[:2]
target_segmentation = target_segmentation.reshape(batch_size * slices, *target_segmentation.shape[2:])

segmentation_loss = self.process_segmentation_loss(target_segmentation, predictions_segmentation, attrs)
losses = {}
for name, loss_func in self.segmentation_losses.items():
losses[name] = self.process_segmentation_loss(
target_segmentation,
predictions_segmentation,
attrs,
loss_func,
)
segmentation_loss = self.total_segmentation_loss(**losses)

if self.use_reconstruction_module:
if predictions_reconstruction_n2r is not None and not attrs["n2r_supervised"]:
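
For context on the base.py change: the loop over the configured segmentation losses now lives in __compute_loss__, while process_segmentation_loss evaluates a single loss function and returns its value, and the named results are combined by the total-loss aggregator. A minimal standalone sketch of this pattern, with stand-in loss modules and assumed weights (the names and weights below are illustrative, not from the repository):

import torch

# Hypothetical stand-ins for the configured segmentation losses; the real model
# reads these from its config (e.g. cross-entropy and Dice).
segmentation_losses = {"l1": torch.nn.L1Loss(), "mse": torch.nn.MSELoss()}
loss_weights = {"l1": 0.5, "mse": 0.5}  # assumed per-loss weights

def process_segmentation_loss(target, prediction, loss_func):
    # Evaluate a single loss function; Dice-style losses return a (score, loss) tuple.
    loss = loss_func(target, prediction)
    if isinstance(loss, tuple):
        loss = loss[1]
    return loss

def total_segmentation_loss(**losses):
    # Weighted combination of the named loss values.
    return sum(loss_weights[name] * value for name, value in losses.items())

target = torch.rand(2, 4, 32, 16)       # [batch, nr_classes, n_x, n_y]
prediction = torch.rand(2, 4, 32, 16)

losses = {name: process_segmentation_loss(target, prediction, fn)
          for name, fn in segmentation_losses.items()}
segmentation_loss = total_segmentation_loss(**losses)
print(float(segmentation_loss))

Splitting the loop out this way lets a subclass override how a single loss function is applied, as mtlrs.py does below, without touching the aggregation.
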
63 changes: 55 additions & 8 deletions atommic/collections/multitask/rs/nn/mtlrs.py
@@ -1,6 +1,7 @@
# coding=utf-8
__author__ = "Dimitris Karkalousos"

__editor__ = "Tim Paquaij"
__editor__date__ = "2024-10-09"
from typing import Dict, List, Tuple, Union

import torch
@@ -48,6 +49,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.reconstruction_module_accumulate_predictions = cfg_dict.get(
"reconstruction_module_accumulate_predictions"
)
self.segmentation_module_accumulate_predictions = cfg_dict.get("segmentation_module_accumulate_predictions")
conv_dim = cfg_dict.get("reconstruction_module_conv_dim")
reconstruction_module_params = {
"num_cascades": self.reconstruction_module_num_cascades,
@@ -146,6 +148,7 @@ def forward(
Tuple containing the predicted reconstruction and segmentation.
"""
pred_reconstructions = []
pred_segmentations = []
for cascade in self.rs_module:
pred_reconstruction, pred_segmentation, hx = cascade(
y=y,
@@ -158,6 +161,7 @@
)
pred_reconstructions.append(pred_reconstruction)
init_reconstruction_pred = pred_reconstruction[-1][-1]
pred_segmentations.append(pred_segmentation)

if self.task_adaption_type == "multi_task_learning":
hidden_states = [
@@ -190,7 +194,7 @@ def forward(

init_reconstruction_pred = torch.view_as_real(init_reconstruction_pred)

return pred_reconstructions, pred_segmentation
return pred_reconstructions, pred_segmentations
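
With forward now collecting one segmentation per RS cascade, the return value is a nested list for the reconstructions and a flat list for the segmentations. A rough sketch of that structure and of taking the final outputs (all shapes and counts below are assumptions for illustration):

import torch

rs_cascades, cascades, time_steps = 5, 5, 8

# Assumed structure of the returned predictions: a nested list for the
# reconstruction (RS cascade -> cascade -> time step) and one segmentation
# per RS cascade. Real/imaginary parts are stacked in the last dimension.
pred_reconstructions = [
    [[torch.randn(1, 32, 16, 2) for _ in range(time_steps)] for _ in range(cascades)]
    for _ in range(rs_cascades)
]
pred_segmentations = [torch.rand(1, 4, 32, 16) for _ in range(rs_cascades)]

# Without accumulation, only the last prediction of the last cascade of the
# last RS cascade is kept.
final_reconstruction = pred_reconstructions[-1][-1][-1]
final_segmentation = pred_segmentations[-1]
print(final_reconstruction.shape, final_segmentation.shape)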

def process_reconstruction_loss( # noqa: MC0001
self,
@@ -284,25 +288,68 @@ def compute_reconstruction_loss(t, p, s):

return loss_func(t, p)

if self.accumulate_predictions:
if self.reconstruction_module_accumulate_predictions:
rs_cascades_weights = torch.logspace(-1, 0, steps=len(prediction)).to(target.device)
rs_cascades_loss = []
for rs_cascade_pred in prediction:
cascades_weights = torch.logspace(-1, 0, steps=len(rs_cascade_pred)).to(target.device)
cascades_loss = []
for cascade_pred in rs_cascade_pred:
time_steps_weights = torch.logspace(-1, 0, steps=self.time_steps).to(target.device)
time_steps_weights = torch.logspace(-1, 0, steps=len(cascade_pred)).to(target.device)
time_steps_loss = [
compute_reconstruction_loss(target, time_step_pred, sensitivity_maps)
for time_step_pred in cascade_pred
]
cascade_loss = sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / self.time_steps
cascades_loss.append(cascade_loss)
rs_cascade_loss = sum(x * w for x, w in zip(cascades_loss, cascades_weights)) / len(rs_cascade_pred)

cascade_loss = sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / sum(
time_steps_weights
)
cascades_loss.append(cascade_loss)
rs_cascade_loss = sum(x * w for x, w in zip(cascades_loss, cascades_weights)) / sum(cascades_weights)
rs_cascades_loss.append(rs_cascade_loss)
loss = sum(x * w for x, w in zip(rs_cascades_loss, rs_cascades_weights)) / len(prediction)
loss = sum(x * w for x, w in zip(rs_cascades_loss, rs_cascades_weights)) / sum(rs_cascades_weights)
else:
# keep the last prediction of the last cascade of the last rs cascade
prediction = prediction[-1][-1][-1]
loss = compute_reconstruction_loss(target, prediction, sensitivity_maps)
return loss
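
The accumulation branch now weights every level (RS cascades, cascades, time steps) with logarithmically spaced weights and divides by the sum of the weights rather than by the number of predictions. A small self-contained sketch of that weighting rule applied to a nested list of dummy per-prediction losses (the structure and values are made up for illustration):

import torch

def weighted_mean(values):
    # Logarithmically spaced weights from 0.1 to 1.0, normalized by their sum,
    # so later predictions contribute more than earlier ones.
    weights = torch.logspace(-1, 0, steps=len(values))
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)

# Dummy losses for 2 RS cascades x 3 cascades x 4 time steps.
losses = [[[torch.rand(()) for _ in range(4)] for _ in range(3)] for _ in range(2)]

rs_cascades_loss = [
    weighted_mean([weighted_mean(time_step_losses) for time_step_losses in cascade])
    for cascade in losses
]
loss = weighted_mean(rs_cascades_loss)
print(float(loss))

Normalizing by the sum of the weights keeps the accumulated loss on the same scale as a single-prediction loss, independent of how many cascades or time steps are configured.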

def process_segmentation_loss(
self, target: torch.Tensor, prediction: torch.Tensor, attrs: Dict, loss_func: torch.nn.Module
) -> Dict:
"""Processes the segmentation loss.
Parameters
----------
target : torch.Tensor
Target data of shape [batch_size, nr_classes, n_x, n_y].
prediction : torch.Tensor
Prediction of shape [batch_size, nr_classes, n_x, n_y].
attrs : Dict
Attributes of the data with pre normalization values.
Returns
-------
Dict
Dictionary containing the (multiple) loss values. For example, if the cross entropy loss and the dice loss
are used, the dictionary will contain the keys ``cross_entropy_loss``, ``dice_loss``, and
(combined) ``segmentation_loss``.
"""
if self.unnormalize_loss_inputs:
target, prediction, _ = self.__unnormalize_for_loss_or_log__(target, prediction, None, attrs, attrs["r"])

if self.segmentation_module_accumulate_predictions:
rs_cascades_weights = torch.logspace(-1, 0, steps=len(prediction)).to(target.device)
rs_cascades_loss = []
for pred in prediction:
loss = loss_func(target, pred)
if isinstance(loss, tuple):
loss = loss[1]
rs_cascades_loss.append(loss)
loss = sum(x * w for x, w in zip(rs_cascades_loss, rs_cascades_weights)) / sum(rs_cascades_weights)
else:
prediction = prediction[-1]
loss = loss_func(target, prediction)
if isinstance(loss, tuple):
loss = loss[1][0]
return loss
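
The overridden segmentation loss follows the same scheme: with segmentation_module_accumulate_predictions enabled, the per-RS-cascade losses are combined with normalized logspace weights; otherwise only the last cascade's segmentation is scored. A hedged sketch with a stand-in loss function (names, shapes, and the MSE stand-in are assumptions, not the repository's configuration):

import torch

loss_func = torch.nn.MSELoss()  # stand-in for a configured segmentation loss
target = torch.rand(2, 4, 32, 16)                            # [batch, nr_classes, n_x, n_y]
predictions = [torch.rand(2, 4, 32, 16) for _ in range(5)]   # one segmentation per RS cascade
accumulate_predictions = True  # segmentation_module_accumulate_predictions

if accumulate_predictions:
    weights = torch.logspace(-1, 0, steps=len(predictions))
    cascade_losses = []
    for pred in predictions:
        loss = loss_func(target, pred)
        if isinstance(loss, tuple):  # Dice-style losses return (score, loss)
            loss = loss[1]
        cascade_losses.append(loss)
    loss = sum(x * w for x, w in zip(cascade_losses, weights)) / sum(weights)
else:
    loss = loss_func(target, predictions[-1])
    if isinstance(loss, tuple):
        loss = loss[1]
print(float(loss))
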
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -1,7 +1,7 @@
defusedxml>=0.7.1
einops>=0.5.0
h5py==3.9.0
huggingface_hub
huggingface_hub<=0.20.3
hydra-core>1.3,<=1.3.2
nibabel==5.1.0
numba
95 changes: 85 additions & 10 deletions tests/collections/multitask/rs/models/test_mtlrs.py
@@ -222,6 +222,76 @@
"max_steps": -1,
},
),
(
[1, 3, 32, 16, 2],
{
"use_reconstruction_module": True,
"task_adaption_type": "multi_task_learning",
"joint_reconstruction_segmentation_module_cascades": 5,
"reconstruction_module_recurrent_layer": "IndRNN",
"reconstruction_module_conv_filters": [64, 64, 2],
"reconstruction_module_conv_kernels": [5, 3, 3],
"reconstruction_module_conv_dilations": [1, 2, 1],
"reconstruction_module_conv_bias": [True, True, False],
"reconstruction_module_recurrent_filters": [64, 64, 0],
"reconstruction_module_recurrent_kernels": [1, 1, 0],
"reconstruction_module_recurrent_dilations": [1, 1, 0],
"reconstruction_module_recurrent_bias": [True, True, False],
"reconstruction_module_depth": 2,
"reconstruction_module_conv_dim": 2,
"reconstruction_module_time_steps": 8,
"reconstruction_module_num_cascades": 5,
"reconstruction_module_dimensionality": 2,
"reconstruction_module_accumulate_predictions": True,
"reconstruction_module_no_dc": True,
"reconstruction_module_keep_prediction": True,
"reconstruction_loss": {"l1": 1.0},
"segmentation_module": "ConvLayer",
"segmentation_module_input_channels": 1,
"segmentation_module_output_channels": 4,
"segmentation_module_channels": 64,
"segmentation_module_pooling_layers": 4,
"segmentation_module_dropout": 0.0,
"segmentation_module_accumulate_predictions": True,
"segmentation_loss": {"dice": 1.0},
"dice_loss_include_background": False,
"dice_loss_to_onehot_y": False,
"dice_loss_sigmoid": True,
"dice_loss_softmax": False,
"dice_loss_other_act": None,
"dice_loss_squared_pred": False,
"dice_loss_jaccard": False,
"dice_loss_reduction": "mean",
"dice_loss_smooth_nr": 1,
"dice_loss_smooth_dr": 1,
"dice_loss_batch": True,
"consecutive_slices": 1,
"coil_combination_method": "SENSE",
"magnitude_input": True,
"use_sens_net": False,
"fft_centered": False,
"fft_normalization": "backward",
"spatial_dims": [-2, -1],
"coil_dim": 1,
"dimensionality": 2,
},
[0.08],
[4],
2,
4,
{
"strategy": "ddp",
"accelerator": "cpu",
"num_nodes": 1,
"max_epochs": 20,
"precision": 32,
"enable_checkpointing": False,
"logger": False,
"log_every_n_steps": 50,
"check_val_every_n_epoch": -1,
"max_steps": -1,
},
),
],
)
def test_mtlmrirs(shape, cfg, center_fractions, accelerations, dimensionality, segmentation_classes, trainer):
@@ -281,20 +351,25 @@ def test_mtlmrirs(shape, cfg, center_fractions, accelerations, dimensionality, s
output.sum(coil_dim),
)

if cfg.get("accumulate_predictions"):
try:
pred_reconstruction = next(pred_reconstruction)
except StopIteration:
pass
if cfg.get("reconstruction_module_accumulate_predictions"):
if len(pred_reconstruction) != cfg.get("joint_reconstruction_segmentation_module_cascades"):
raise AssertionError("Number of predictions are not equal to the number of cascades")
if cfg.get("reconstruction_module_keep_prediction"):
if len(pred_reconstruction[0]) != cfg.get("reconstruction_module_num_cascades"):
raise AssertionError("Number of predictions are not equal to the number of cascades")
if cfg.get("reconstruction_module_keep_prediction"):
if len(pred_reconstruction[0][0]) != cfg.get("reconstruction_module_time_steps"):
raise AssertionError("Number of predictions are not equal to the number of intermediate predictions")

if isinstance(pred_reconstruction, list):
pred_reconstruction = pred_reconstruction[-1]
if cfg.get("segmentation_module_accumulate_predictions"):
if len(pred_segmentation) != cfg.get("joint_reconstruction_segmentation_module_cascades"):
raise AssertionError(f"Number of segmentations are not equal to the number of cascades")

if isinstance(pred_reconstruction, list):
while isinstance(pred_reconstruction, list):
pred_reconstruction = pred_reconstruction[-1]

if isinstance(pred_reconstruction, list):
pred_reconstruction = pred_reconstruction[-1]
while isinstance(pred_segmentation, list):
pred_segmentation = pred_segmentation[-1]

if dimensionality == 3 or consecutive_slices > 1:
x = x.reshape([x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4], x.shape[5]])
