Initial commit multi-task attention modules

Added attention modules within MTLRS script

Updated test file for multi-task (Passed)

added accumulate_segmentation and refine test file

added kernel_size and padding to cfg

Fixed error during validation step

Update after fixing errors in training loop

Training and validation works for Task Attention

Deleted propagation_module, not used in own research

added option to remove or include background for attention mechanism

Update docstrings and spacing

Update docstrings

update requirements

update doc

update mtlrs

Update branch to prevent issues when merging

update mtlrs.py

style update

style base.py
TimPaquaij committed Nov 26, 2024
1 parent d34fcad commit c99000e
Showing 7 changed files with 595 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,7 +2,7 @@
__pycache__/
*.py[cod]
*$py.class

**.DS_Store
# C extensions
*.so

114 changes: 103 additions & 11 deletions atommic/collections/multitask/rs/nn/mtlrs.py
@@ -2,7 +2,7 @@
__author__ = "Dimitris Karkalousos"

from typing import Dict, List, Tuple, Union

import warnings
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
@@ -12,6 +12,8 @@
from atommic.collections.multitask.rs.nn.base import BaseMRIReconstructionSegmentationModel
from atommic.collections.multitask.rs.nn.mtlrs_base.mtlrs_block import MTLRSBlock
from atommic.core.classes.common import typecheck
from atommic.collections.multitask.rs.nn.mtlrs_base.task_attention_module import TaskAttentionalModule
from atommic.collections.multitask.rs.nn.mtlrs_base.spatially_adaptive_semantic_guidance_module import SASG

__all__ = ["MTLRS"]

@@ -108,8 +110,51 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
)

self.task_adaption_type = cfg_dict.get("task_adaption_type", "multi_task_learning")
self.attention_module = cfg_dict.get("attention_module", False)
self.attention_module_kernel_size = cfg_dict.get("attention_module_kernel_size", 3)
self.attention_module_padding = cfg_dict.get("attention_module_padding", 1)
self.attention_module_remove_background = cfg_dict.get("attention_module_remove_background", True)
self.attention_module_segmentation_input = (
self.segmentation_module_output_channels - 1
if self.attention_module_remove_background
else self.segmentation_module_output_channels
)
if self.attention_module == "SemanticGuidanceModule" and self.task_adaption_type == "multi_task_learning":
self.task_adaption_type = "multi_task_learning_softmax"
warnings.warn(
"Logits can not be used with Spatially Semantic Guidance attention modules."
" Logits will be transfomed into probabilities with softmax."
f"task_adaption_type: {self.task_adaption_type}"
)
if self.attention_module == "TaskAttentionModule":
self.attention_module_block = torch.nn.ModuleList(
[
TaskAttentionalModule(
channels_in=cfg_dict.get("reconstruction_module_conv_filters")[0],
kernel_size=self.attention_module_kernel_size,
padding=self.attention_module_padding,
)
for _ in range(
self.rs_cascades - 1
) # TODO: range is set to (rs_cascades - 1) since the last cascade does not need an attention module, as its output is not used.
]
)
if self.attention_module == "SemanticGuidanceModule":
self.attention_module_block = torch.nn.ModuleList(
[
SASG(
channels_rec=cfg_dict.get("reconstruction_module_conv_filters")[0],
channels_seg=self.attention_module_segmentation_input,
kernel_size=self.attention_module_kernel_size,
padding=self.attention_module_padding,
)
for _ in range(
self.rs_cascades - 1
) # TODO: range is set to (rs_cascades - 1) since the last cascade does not need an attention module, as its output is not used.
]
)
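
For reference, a minimal sketch of how the configuration keys read above might be set (the key names are taken from the cfg_dict.get(...) calls in this hunk; all values are illustrative only):

from omegaconf import OmegaConf

# Illustrative values only; key names come from this commit.
cfg = OmegaConf.create(
    {
        "task_adaption_type": "multi_task_learning",
        "attention_module": "SemanticGuidanceModule",  # or "TaskAttentionModule", or False to disable
        "attention_module_kernel_size": 3,
        "attention_module_padding": 1,
        "attention_module_remove_background": True,
    }
)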

# pylint: disable=arguments-differ
# pylint: disable = arguments-differ
@typecheck()
def forward(
self,
@@ -146,7 +191,7 @@ def forward(
Tuple containing the predicted reconstruction and segmentation.
"""
pred_reconstructions = []
for cascade in self.rs_module:
for c, cascade in enumerate(self.rs_module):
pred_reconstruction, pred_segmentation, hx = cascade(
y=y,
sensitivity_maps=sensitivity_maps,
@@ -158,8 +203,7 @@
)
pred_reconstructions.append(pred_reconstruction)
init_reconstruction_pred = pred_reconstruction[-1][-1]

if self.task_adaption_type == "multi_task_learning":
if self.task_adaption_type == "multi_task_learning" and c != self.rs_cascades - 1:
hidden_states = [
torch.cat(
[torch.abs(init_reconstruction_pred.unsqueeze(self.coil_dim) * pred_segmentation)]
@@ -173,8 +217,49 @@
if self.consecutive_slices > 1:
hx = [x.unsqueeze(1) for x in hx]

if self.task_adaption_type == "multi_task_learning_softmax" and c != self.rs_cascades - 1:
if self.consecutive_slices > 1:
pred_segmentation_soft = torch.softmax(pred_segmentation, dim=2)
else:
pred_segmentation_soft = torch.softmax(pred_segmentation, dim=1)
if self.attention_module == "SemanticGuidanceModule" and self.attention_module_remove_background:
hidden_states = [
pred_segmentation_soft[:, 1:] for _ in self.reconstruction_module_recurrent_filters
]
elif self.attention_module == "SemanticGuidanceModule" and not self.attention_module_remove_background:
hidden_states = [pred_segmentation_soft for _ in self.reconstruction_module_recurrent_filters]
elif self.attention_module_remove_background:
hidden_states = [
torch.cat(
[
torch.abs(init_reconstruction_pred.unsqueeze(self.coil_dim))
* torch.sum(pred_segmentation_soft[..., 1:, :, :], dim=1, keepdim=True)
]
* f,
dim=self.coil_dim,
)
for f in self.reconstruction_module_recurrent_filters
if f != 0
]
else:
hidden_states = [
torch.cat(
[
torch.abs(init_reconstruction_pred.unsqueeze(self.coil_dim))
* torch.sum(pred_segmentation_soft, dim=1, keepdim=True)
]
* f,
dim=self.coil_dim,
)
for f in self.reconstruction_module_recurrent_filters
if f != 0
]

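The branches above all reduce to the same idea: build guidance maps from softmaxed class probabilities, optionally discarding the background class. A condensed sketch of the background-removal case, assuming 2D inputs with coil_dim=1 and class 0 as the background class:

probs = torch.softmax(pred_segmentation, dim=1)            # [batch, classes, H, W]
foreground = torch.sum(probs[:, 1:], dim=1, keepdim=True)  # collapse all non-background classes
magnitude = torch.abs(init_reconstruction_pred.unsqueeze(1))
guidance = torch.cat([magnitude * foreground] * f, dim=1)  # f channels, one guidance map per RNN filter size
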
# Check if the concatenated hidden states are the same size as the hidden state of the RNN
if hidden_states[0].shape[self.coil_dim] != hx[0].shape[self.coil_dim]:
if (
hidden_states[0].shape[self.coil_dim] != hx[0].shape[self.coil_dim]
and not self.attention_module == "SemanticGuidanceModule"
):
prev_hidden_states = hidden_states
hidden_states = []
for hs in prev_hidden_states:
@@ -185,8 +270,10 @@
dim=self.coil_dim,
)
hidden_states.append(new_hidden_state)

hx = [hx[i] + hidden_states[i] for i in range(len(hx))]
if self.attention_module:
hx = [self.attention_module_block[c](hx[i], hidden_states[i]) for i in range(len(hx))]
else:
hx = [hx[i] + hidden_states[i] for i in range(len(hx))]

init_reconstruction_pred = torch.view_as_real(init_reconstruction_pred)

@@ -284,19 +371,24 @@ def compute_reconstruction_loss(t, p, s):

return loss_func(t, p)

if self.accumulate_predictions:
if self.reconstruction_module_accumulate_predictions:
rs_cascades_weights = torch.logspace(-1, 0, steps=len(prediction)).to(target.device)
rs_cascades_loss = []
for rs_cascade_pred in prediction:
cascades_weights = torch.logspace(-1, 0, steps=len(rs_cascade_pred)).to(target.device)
cascades_loss = []
for cascade_pred in rs_cascade_pred:
time_steps_weights = torch.logspace(-1, 0, steps=self.time_steps).to(target.device)
time_steps_weights = torch.logspace(-1, 0, steps=self.reconstruction_module_time_steps).to(
target.device
)
time_steps_loss = [
compute_reconstruction_loss(target, time_step_pred, sensitivity_maps)
for time_step_pred in cascade_pred
]
cascade_loss = sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / self.time_steps
cascade_loss = (
sum(x * w for x, w in zip(time_steps_loss, time_steps_weights))
/ self.reconstruction_module_time_steps
)
cascades_loss.append(cascade_loss)
rs_cascade_loss = sum(x * w for x, w in zip(cascades_loss, cascades_weights)) / len(rs_cascade_pred)
rs_cascades_loss.append(rs_cascade_loss)
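
The weighting above is applied at three nested levels (time steps, cascades, RS cascades), each time via torch.logspace(-1, 0, steps=N), which yields N weights growing geometrically from 0.1 to 1.0 so that later predictions contribute more to the loss. For example:

weights = torch.logspace(-1, 0, steps=4)  # tensor([0.1000, 0.2154, 0.4642, 1.0000])
losses = [0.9, 0.7, 0.5, 0.4]             # illustrative per-time-step losses
weighted = sum(x * w for x, w in zip(losses, weights)) / 4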
111 changes: 111 additions & 0 deletions atommic/collections/multitask/rs/nn/mtlrs_base/spatially_adaptive_semantic_guidance_module.py
@@ -0,0 +1,111 @@
# coding=utf-8
__author__ = "Tim Paquaij"
import torch
import torch.nn as nn


class SASG(nn.Module):
"""Spatial-Adapted Semantic Guidance
An attention module that uses segmentation probabilities to enhance reconstruction features,
based on the SASG module described in:
https://www.sciencedirect.com/science/article/abs/pii/S0169260724000415?via%3Dihub
"""

def __init__(
self,
channels_rec: int,
channels_seg: int,
kernel_size: int | tuple = (1, 1),
padding: int | tuple = 0,
):
"""Inits :class:`SASG`.
Parameters
----------
channels_rec : int
Number of reconstruction feature channels.
channels_seg : int
Number of segmentation classes.
kernel_size : int | tuple, optional
Size of the convolutional kernel. Default is (1, 1).
padding : int | tuple, optional
Padding around all four sides of the input. Default is 0.
"""
super().__init__()
self.conv = nn.Conv2d(channels_rec, channels_rec, kernel_size=kernel_size, padding=padding)
self.spade = SPADE(channels_rec, channels_seg, kernel_size, padding)
self.act = nn.LeakyReLU(negative_slope=0.2)

def forward(self, rec_features: torch.Tensor, seg_features: torch.Tensor) -> torch.Tensor:
"""Forward :class:`SASG`.
Parameters
----------
rec_features : torch.Tensor
Tensor of the reconstruction features with size [batch_size, feature_channels, height, width]
seg_features : torch.Tensor
Tensor of the segmentation features with size [batch_size, nr_classes, height, width]
Returns
-------
new_rec_features : torch.Tensor
Tensor of the optimised reconstruction features with size [batch_size, feature_channels, height, width]
"""

hidden_layers_features_s = self.spade(rec_features, seg_features)
hidden_layers_features_s = self.conv(self.act(hidden_layers_features_s))
hidden_layers_features_s = self.spade(hidden_layers_features_s, seg_features)
hidden_layers_features_s = self.conv(self.act(hidden_layers_features_s))
new_rec_features = hidden_layers_features_s + rec_features
return new_rec_features
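
A minimal smoke test of :class:`SASG` as defined above (channel counts and shapes are illustrative only):

# Illustrative: 64 reconstruction feature channels, 3 foreground classes.
sasg = SASG(channels_rec=64, channels_seg=3, kernel_size=(3, 3), padding=1)
rec_features = torch.randn(2, 64, 32, 32)
seg_probs = torch.softmax(torch.randn(2, 3, 32, 32), dim=1)  # probabilities, not logits
out = sasg(rec_features, seg_probs)
assert out.shape == rec_features.shape  # the residual connection preserves the feature shape
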


class SPADE(nn.Module):
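"""Spatially-Adaptive (De)Normalization block.
Modulates instance-normalised reconstruction features with maps derived from segmentation
probabilities; used as the building block of :class:`SASG`.
"""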

def __init__(
self,
channels_rec: int,
channels_seg: int,
kernel_size: int | tuple = (1, 1),
padding: int | tuple = 0,
):
"""Inits :class:`SPADE`.
Parameters
----------
channels_rec : int
Number of reconstruction feature channels.
channels_seg : int
Number of segmentation classes.
kernel_size : int | tuple, optional
Size of the convolutional kernel. Default is (1, 1).
padding : int | tuple, optional
Padding around all four sides of the input. Default is 0.
"""
super().__init__()
self.conv_1 = nn.Conv2d(channels_seg, channels_seg, kernel_size=kernel_size, padding=padding)
self.conv_2 = nn.Conv2d(channels_seg, channels_rec, kernel_size=kernel_size, padding=padding)
self.instance = nn.InstanceNorm2d(channels_rec)
self.act = nn.LeakyReLU(negative_slope=0.2)

def forward(self, rec_features: torch.Tensor, seg_features: torch.Tensor) -> torch.Tensor:
"""Forward :class:`SPADE`.
Parameters
----------
rec_features : torch.Tensor
Tensor of the reconstruction features with size [batch_size, feature_channels, height, width]
seg_features : torch.Tensor
Tensor of the segmentation features with size [batch_size, nr_classes, height, width]
Returns
-------
new_rec_features : torch.Tensor
Tensor of the optimised reconstruction features with size [batch_size, feature_channels, height, width]
"""
hidden_layers_features = self.instance(rec_features)
segmentation_prob = self.act(self.conv_1(seg_features))
new_rec_features = torch.mul(self.conv_2(segmentation_prob), hidden_layers_features) + self.conv_2(
segmentation_prob
)
return new_rec_features
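
In compact form, SPADE as written computes new_rec_features = gamma(seg) * InstanceNorm(rec) + gamma(seg), with gamma(seg) = conv_2(LeakyReLU(conv_1(seg))); note that the multiplicative and additive terms share the same conv_2 projection. A quick shape check (illustrative values):

spade = SPADE(channels_rec=64, channels_seg=3, kernel_size=(3, 3), padding=1)
out = spade(torch.randn(2, 64, 32, 32), torch.randn(2, 3, 32, 32))
assert out.shape == (2, 64, 32, 32)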
