diff --git a/.gitignore b/.gitignore index 6500cd24..d13c4505 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ __pycache__/ *.py[cod] *$py.class - +**.DS_Store # C extensions *.so diff --git a/atommic/collections/common/data/mri_loader.py b/atommic/collections/common/data/mri_loader.py index fac98ee1..b6b80ab0 100644 --- a/atommic/collections/common/data/mri_loader.py +++ b/atommic/collections/common/data/mri_loader.py @@ -226,7 +226,9 @@ def __init__( # noqa: MC0001 self.examples = [ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols] self.indices_to_log = np.random.choice( - len(self.examples), int(log_images_rate * len(self.examples)), replace=False # type: ignore + [example[1] for example in self.examples], + int(log_images_rate * len(self.examples)), # type: ignore + replace=False, ) def _retrieve_metadata(self, fname: Union[str, Path]) -> Tuple[Dict, int]: diff --git a/atommic/collections/common/losses/__init__.py b/atommic/collections/common/losses/__init__.py index bdf483a2..e1057c33 100644 --- a/atommic/collections/common/losses/__init__.py +++ b/atommic/collections/common/losses/__init__.py @@ -5,4 +5,10 @@ from atommic.collections.common.losses.wasserstein import SinkhornDistance # noqa: F401 VALID_RECONSTRUCTION_LOSSES = ["l1", "mse", "ssim", "noise_aware", "wasserstein"] -VALID_SEGMENTATION_LOSSES = ["cross_entropy", "dice"] +VALID_SEGMENTATION_LOSSES = [ + "categorical_cross_entropy", + "dice", + "binary_cross_entropy", + "generalized_dice", + "focal_loss", +] diff --git a/atommic/collections/multitask/rs/nn/base.py b/atommic/collections/multitask/rs/nn/base.py index 5a0b506a..56400924 100644 --- a/atommic/collections/multitask/rs/nn/base.py +++ b/atommic/collections/multitask/rs/nn/base.py @@ -37,8 +37,10 @@ from atommic.collections.reconstruction.losses.na import NoiseAwareLoss from atommic.collections.reconstruction.losses.ssim import SSIMLoss from atommic.collections.reconstruction.metrics import mse, nmse, psnr, ssim -from atommic.collections.segmentation.losses.cross_entropy import CrossEntropyLoss -from atommic.collections.segmentation.losses.dice import Dice +from atommic.collections.segmentation.losses.cross_entropy import BinaryCrossEntropyLoss, CategoricalCrossEntropyLoss +from atommic.collections.segmentation.losses.dice import Dice, GeneralisedDice +from atommic.collections.segmentation.losses.focal import FocalLoss +from atommic.collections.segmentation.losses.utils import one_hot __all__ = ["BaseMRIReconstructionSegmentationModel"] @@ -155,8 +157,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # noqa: MC0001 self.log_multiple_modalities = cfg_dict.get("log_multiple_modalities", False) # Set threshold for segmentation classes. If None, no thresholding is applied. + self.segmentation_type = cfg_dict.get("segmentation_type", "MLS") self.segmentation_classes_thresholds = cfg_dict.get("segmentation_classes_thresholds", None) self.segmentation_activation = cfg_dict.get("segmentation_activation", None) + self.segmentation_output_mode = cfg_dict.get("segmentation_output_mode", "binary") # Initialize loss related parameters. 
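# Editor's sketch (illustrative, not part of this patch): the two flags set
# above are only exercised later in this diff, where "MCS" triggers
# softmax + argmax + one_hot and "MLS" triggers per-channel thresholding.
# Assuming "MLS" = multi-label and "MCS" = multi-class segmentation, the
# difference in one standalone example:
import torch

logits = torch.randn(1, 4, 8, 8)  # (batch, classes, H, W)

# "MCS": mutually exclusive classes; softmax over the class dim, then
# argmax + one-hot when segmentation_output_mode == "binary".
labels = torch.softmax(logits, dim=1).argmax(dim=1, keepdim=True)
mcs_binary = torch.zeros_like(logits).scatter_(1, labels, 1.0)

# "MLS": possibly overlapping classes; sigmoid per channel, 0.5 threshold.
mls_binary = torch.where(torch.sigmoid(logits) > 0.5, 1, 0).float()

assert mcs_binary.sum(dim=1).eq(1).all()  # exactly one class per pixel
# mls_binary.sum(dim=1) may activate anywhere from 0 to 4 classes per pixel.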
self.segmentation_losses = {} @@ -182,20 +186,37 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # noqa: MC0001 segmentation_losses_ = {k: v / total_weight for k, v in segmentation_losses_.items()} for name in VALID_SEGMENTATION_LOSSES: if name in segmentation_losses_: - if name == "cross_entropy": - cross_entropy_loss_classes_weight = torch.tensor( - cfg_dict.get("cross_entropy_loss_classes_weight", 0.0) - ) - self.segmentation_losses[name] = CrossEntropyLoss( + if name == "categorical_cross_entropy": + self.segmentation_losses[name] = CategoricalCrossEntropyLoss( num_samples=cfg_dict.get("cross_entropy_loss_num_samples", 50), ignore_index=cfg_dict.get("cross_entropy_loss_ignore_index", -100), - reduction=cfg_dict.get("cross_entropy_loss_reduction", "none"), + reduction=cfg_dict.get("cross_entropy_loss_reduction", "mean"), label_smoothing=cfg_dict.get("cross_entropy_loss_label_smoothing", 0.0), - weight=cross_entropy_loss_classes_weight, + weight=cfg_dict.get("cross_entropy_loss_classes_weight", None), + to_onehot_y=cfg_dict.get("cross_entropy_loss_to_onehot_y", True), + include_background=cfg_dict.get("cross_entropy_loss_include_background", True), + ) + elif name == "binary_cross_entropy": + self.segmentation_losses[name] = BinaryCrossEntropyLoss( + num_samples=cfg_dict.get("cross_entropy_loss_num_samples", 50), + include_background=cfg_dict.get("cross_entropy_loss_include_background", True), + reduction=cfg_dict.get("cross_entropy_loss_reduction", "mean"), + weight=cfg_dict.get("cross_entropy_loss_classes_weight", None), + to_onehot_y=cfg_dict.get("cross_entropy_loss_to_onehot_y", True), + ) + elif name == "focal_loss": + self.segmentation_losses[name] = FocalLoss( + reduction=cfg_dict.get("focal_loss_reduction", "mean"), + weight=cfg_dict.get("focal_loss_classes_weight", None), + alpha=cfg_dict.get("focal_loss_alpha", None), + gamma=cfg_dict.get("focal_loss_gamma", 2.0), + use_softmax=cfg_dict.get("focal_loss_use_softmax", True), + to_onehot_y=cfg_dict.get("focal_loss_to_onehot_y", False), + include_background=cfg_dict.get("focal_loss_include_background", True), ) elif name == "dice": self.segmentation_losses[name] = Dice( - include_background=cfg_dict.get("dice_loss_include_background", False), + include_background=cfg_dict.get("dice_loss_include_background", True), to_onehot_y=cfg_dict.get("dice_loss_to_onehot_y", False), sigmoid=cfg_dict.get("dice_loss_sigmoid", True), softmax=cfg_dict.get("dice_loss_softmax", False), @@ -206,8 +227,22 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # noqa: MC0001 reduction=cfg_dict.get("dice_loss_reduction", "mean"), smooth_nr=cfg_dict.get("dice_loss_smooth_nr", 1e-5), smooth_dr=cfg_dict.get("dice_loss_smooth_dr", 1e-5), - batch=cfg_dict.get("dice_loss_batch", False), + batch=cfg_dict.get("dice_loss_batch", True), ) + elif name == "generalized_dice": + self.segmentation_losses[name] = GeneralisedDice( + include_background=cfg_dict.get("dice_loss_include_background", True), + to_onehot_y=cfg_dict.get("dice_loss_to_onehot_y", False), + sigmoid=cfg_dict.get("dice_loss_sigmoid", True), + softmax=cfg_dict.get("dice_loss_softmax", False), + other_act=cfg_dict.get("dice_loss_other_act", None), + reduction=cfg_dict.get("dice_loss_reduction", "mean"), + w_type=cfg_dict.get("dice_loss_w_type", "square"), + smooth_nr=cfg_dict.get("dice_loss_smooth_nr", 1e-5), + smooth_dr=cfg_dict.get("dice_loss_smooth_dr", 1e-5), + batch=cfg_dict.get("dice_loss_batch", True), + ) + self.segmentation_losses = {f"loss_{i+1}": v for i, v in 
enumerate(self.segmentation_losses.values())} self.total_segmentation_losses = len(self.segmentation_losses) self.total_segmentation_loss_weight = cfg_dict.get("total_segmentation_loss_weight", 1.0) @@ -218,6 +253,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # noqa: MC0001 cross_entropy_metric_reduction = cfg_dict.get("cross_entropy_metric_reduction", "none") cross_entropy_metric_label_smoothing = cfg_dict.get("cross_entropy_metric_label_smoothing", 0.0) cross_entropy_metric_classes_weight = torch.tensor(cfg_dict.get("cross_entropy_metric_classes_weight", 0.0)) + cross_entropy_metric_to_onehot_y = cfg_dict.get("cross_entropy_loss_to_onehot_y", True) + cross_entropy_metric_include_background = cfg_dict.get("cross_entropy_loss_include_background", True) + dice_metric_include_background = cfg_dict.get("dice_metric_include_background", False) dice_metric_to_onehot_y = cfg_dict.get("dice_metric_to_onehot_y", False) dice_metric_sigmoid = cfg_dict.get("dice_metric_sigmoid", True) @@ -264,16 +302,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # noqa: MC0001 self.ssim_vals_reconstruction: Dict = defaultdict(dict) self.psnr_vals_reconstruction: Dict = defaultdict(dict) - if not is_none(cross_entropy_metric_classes_weight) and cross_entropy_metric_classes_weight != 0.0: - self.cross_entropy_metric = CrossEntropyLoss( + if ( + not is_none(cross_entropy_metric_classes_weight) and cross_entropy_metric_classes_weight != 0.0 + ): # TODO: Cross-entropy is not really a metric used in papers I would remove it + self.cross_entropy_metric = CategoricalCrossEntropyLoss( num_samples=cross_entropy_metric_num_samples, ignore_index=cross_entropy_metric_ignore_index, reduction=cross_entropy_metric_reduction, label_smoothing=cross_entropy_metric_label_smoothing, weight=cross_entropy_metric_classes_weight, + to_onehot_y=cross_entropy_metric_to_onehot_y, + include_background=cross_entropy_metric_include_background, ) else: self.cross_entropy_metric = None # type: ignore + self.dice_metric = Dice( include_background=dice_metric_include_background, to_onehot_y=dice_metric_to_onehot_y, @@ -354,30 +397,42 @@ def __unnormalize_for_loss_or_log__( target = unnormalize( target, { - "min": attrs["prediction_min"][batch_idx] - if "prediction_min" in attrs - else attrs[f"prediction_min_{r}"][batch_idx], - "max": attrs["prediction_max"][batch_idx] - if "prediction_max" in attrs - else attrs[f"prediction_max_{r}"][batch_idx], - "mean": attrs["prediction_mean"][batch_idx] - if "prediction_mean" in attrs - else attrs[f"prediction_mean_{r}"][batch_idx], - "std": attrs["prediction_std"][batch_idx] - if "prediction_std" in attrs - else attrs[f"prediction_std_{r}"][batch_idx], + "min": ( + attrs["prediction_min"][batch_idx] + if "prediction_min" in attrs + else attrs[f"prediction_min_{r}"][batch_idx] + ), + "max": ( + attrs["prediction_max"][batch_idx] + if "prediction_max" in attrs + else attrs[f"prediction_max_{r}"][batch_idx] + ), + "mean": ( + attrs["prediction_mean"][batch_idx] + if "prediction_mean" in attrs + else attrs[f"prediction_mean_{r}"][batch_idx] + ), + "std": ( + attrs["prediction_std"][batch_idx] + if "prediction_std" in attrs + else attrs[f"prediction_std_{r}"][batch_idx] + ), }, self.normalization_type, ) prediction = unnormalize( prediction, { - "min": attrs["noise_prediction_min"][batch_idx] - if "noise_prediction_min" in attrs - else attrs[f"noise_prediction_min_{r}"][batch_idx], - "max": attrs["noise_prediction_max"][batch_idx] - if "noise_prediction_max" in attrs - else 
attrs[f"noise_prediction_max_{r}"][batch_idx], + "min": ( + attrs["noise_prediction_min"][batch_idx] + if "noise_prediction_min" in attrs + else attrs[f"noise_prediction_min_{r}"][batch_idx] + ), + "max": ( + attrs["noise_prediction_max"][batch_idx] + if "noise_prediction_max" in attrs + else attrs[f"noise_prediction_max_{r}"][batch_idx] + ), attrs["noise_prediction_mean"][batch_idx] if "noise_prediction_mean" in attrs else "mean": attrs[f"noise_prediction_mean_{r}"][batch_idx], @@ -391,36 +446,52 @@ def __unnormalize_for_loss_or_log__( target = unnormalize( target, { - "min": attrs["target_min"][batch_idx] - if "target_min" in attrs - else attrs[f"target_min_{r}"][batch_idx], - "max": attrs["target_max"][batch_idx] - if "target_max" in attrs - else attrs[f"target_max_{r}"][batch_idx], - "mean": attrs["target_mean"][batch_idx] - if "target_mean" in attrs - else attrs[f"target_mean_{r}"][batch_idx], - "std": attrs["target_std"][batch_idx] - if "target_std" in attrs - else attrs[f"target_std_{r}"][batch_idx], + "min": ( + attrs["target_min"][batch_idx] + if "target_min" in attrs + else attrs[f"target_min_{r}"][batch_idx] + ), + "max": ( + attrs["target_max"][batch_idx] + if "target_max" in attrs + else attrs[f"target_max_{r}"][batch_idx] + ), + "mean": ( + attrs["target_mean"][batch_idx] + if "target_mean" in attrs + else attrs[f"target_mean_{r}"][batch_idx] + ), + "std": ( + attrs["target_std"][batch_idx] + if "target_std" in attrs + else attrs[f"target_std_{r}"][batch_idx] + ), }, self.normalization_type, ) prediction = unnormalize( prediction, { - "min": attrs["prediction_min"][batch_idx] - if "prediction_min" in attrs - else attrs[f"prediction_min_{r}"][batch_idx], - "max": attrs["prediction_max"][batch_idx] - if "prediction_max" in attrs - else attrs[f"prediction_max_{r}"][batch_idx], - "mean": attrs["prediction_mean"][batch_idx] - if "prediction_mean" in attrs - else attrs[f"prediction_mean_{r}"][batch_idx], - "std": attrs["prediction_std"][batch_idx] - if "prediction_std" in attrs - else attrs[f"prediction_std_{r}"][batch_idx], + "min": ( + attrs["prediction_min"][batch_idx] + if "prediction_min" in attrs + else attrs[f"prediction_min_{r}"][batch_idx] + ), + "max": ( + attrs["prediction_max"][batch_idx] + if "prediction_max" in attrs + else attrs[f"prediction_max_{r}"][batch_idx] + ), + "mean": ( + attrs["prediction_mean"][batch_idx] + if "prediction_mean" in attrs + else attrs[f"prediction_mean_{r}"][batch_idx] + ), + "std": ( + attrs["prediction_std"][batch_idx] + if "prediction_std" in attrs + else attrs[f"prediction_std_{r}"][batch_idx] + ), }, self.normalization_type, ) @@ -556,7 +627,7 @@ def process_segmentation_loss(self, target: torch.Tensor, prediction: torch.Tens loss = loss_func(target, prediction) if isinstance(loss, tuple): # In case of the dice loss, the loss is a tuple of the form (dice, dice loss) - loss = loss[1] + loss = loss[1][0] losses[name] = loss return self.total_segmentation_loss(**losses) * self.total_segmentation_loss_weight @@ -734,11 +805,6 @@ def __compute_and_log_metrics_and_outputs__( # noqa: MC0001 batch_idx=_batch_idx_, ) - output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu() - output_target_reconstruction = output_target_reconstruction.detach().cpu() - output_target_segmentation = output_target_segmentation.detach().cpu() - output_predictions_segmentation = output_predictions_segmentation.detach().cpu() - # Normalize target and predictions to [0, 1] for logging. 
if torch.is_complex(output_target_reconstruction) and output_target_reconstruction.shape[-1] != 2: output_target_reconstruction = torch.view_as_real(output_target_reconstruction) @@ -747,7 +813,6 @@ def __compute_and_log_metrics_and_outputs__( # noqa: MC0001 output_target_reconstruction = output_target_reconstruction / torch.max( torch.abs(output_target_reconstruction) ) - output_target_reconstruction = output_target_reconstruction.detach().cpu() if ( torch.is_complex(output_predictions_reconstruction) @@ -759,7 +824,24 @@ def __compute_and_log_metrics_and_outputs__( # noqa: MC0001 output_predictions_reconstruction = output_predictions_reconstruction / torch.max( torch.abs(output_predictions_reconstruction) ) - output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu() + output_predictions_reconstruction = output_predictions_reconstruction.detach().cpu().float() + output_target_reconstruction = output_target_reconstruction.detach().cpu().float() + output_target_segmentation = output_target_segmentation.detach().cpu().float() + output_predictions_segmentation = output_predictions_segmentation.detach().cpu().float() + + if self.segmentation_type == 'MCS': + output_predictions_segmentation = torch.softmax(output_predictions_segmentation, dim=0).float() + if self.segmentation_output_mode == "binary": + output_predictions_segmentation = output_predictions_segmentation.argmax(dim=0, keepdim=True) + output_predictions_segmentation = one_hot( + output_predictions_segmentation, num_classes=self.segmentation_module_output_channels, dim=0 + ) + else: + # When using wandb plots needs to be between [0,1]. When using "MLS" with/without thresholding the + # outputs are logits and exceed this range. + output_predictions_segmentation = output_predictions_segmentation.clamp(0, 1).float() + if self.segmentation_output_mode == "binary": + output_predictions_segmentation = torch.where(output_predictions_segmentation > 0.5, 1, 0).float() # Log target and predictions, if log_image is True for this slice. 
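# Editor's sketch (illustrative, not part of this patch): the "MCS" branch
# above collapses per-class probabilities to labels and re-expands them with
# `one_hot(..., dim=0)` on an unbatched (classes, H, W) tensor. A standalone
# illustration of that round trip:
import torch

probs = torch.softmax(torch.randn(3, 4, 4), dim=0)  # (classes, H, W)
labels = probs.argmax(dim=0, keepdim=True)          # (1, H, W) class indices
onehot = torch.zeros(3, 4, 4).scatter_(0, labels.long(), 1.0)
assert onehot.sum(dim=0).eq(1).all()                # one active class per pixel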
if attrs["log_image"][_batch_idx_]: @@ -828,7 +910,6 @@ def __compute_and_log_metrics_and_outputs__( # noqa: MC0001 output_target_segmentation.to(self.device), output_predictions_segmentation.to(self.device), ) - dice_score, _ = self.dice_metric(output_target_segmentation, output_predictions_segmentation) self.dice_vals[fname[_batch_idx_]][str(slice_idx[_batch_idx_].item())] = dice_score @@ -1131,7 +1212,7 @@ def inference_step( # noqa: MC0001 attrs["noise"], ) - if not is_none(self.segmentation_classes_thresholds): + if not is_none(self.segmentation_classes_thresholds) and self.segmentation_type == 'MLS': for class_idx, thres in enumerate(self.segmentation_classes_thresholds): if self.segmentation_activation == "sigmoid": if isinstance(predictions_segmentation, list): @@ -1475,6 +1556,16 @@ def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int): while isinstance(predictions_segmentation, list): predictions_segmentation = predictions_segmentation[-1] + if self.segmentation_type == 'MCS': + predictions_segmentation = torch.softmax(predictions_segmentation, dim=1).float() + if self.segmentation_output_mode == "binary": + predictions_segmentation = predictions_segmentation.argmax(dim=1, keepdim=True) + predictions_segmentation = one_hot( + predictions_segmentation, num_classes=self.segmentation_module_output_channels, dim=1 + ) + elif self.segmentation_output_mode == "binary": + predictions_segmentation = torch.where(predictions_segmentation > 0.5, 1, 0).float() + predictions_segmentation = predictions_segmentation.detach().cpu().numpy() if self.use_reconstruction_module: @@ -1835,6 +1926,7 @@ def _setup_dataloader_from_config(cfg: DictConfig) -> DataLoader: coil_dim=cfg.get("coil_dim", 1), consecutive_slices=cfg.get("consecutive_slices", 1), use_seed=cfg.get("use_seed", True), + include_background_label=cfg.get("include_background_label", False), ), segmentations_root=cfg.get("segmentations_path"), segmentation_classes=cfg.get("segmentation_classes", 2), diff --git a/atommic/collections/multitask/rs/nn/idslr.py b/atommic/collections/multitask/rs/nn/idslr.py index e6a988c4..5c8a2488 100644 --- a/atommic/collections/multitask/rs/nn/idslr.py +++ b/atommic/collections/multitask/rs/nn/idslr.py @@ -43,7 +43,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if self.input_channels == 0: raise ValueError("Segmentation module input channels cannot be 0.") reconstruction_out_chans = cfg_dict.get("reconstruction_module_output_channels", 2) - self.segmentation_out_chans = cfg_dict.get("segmentation_module_output_channels", 1) + self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1) chans = cfg_dict.get("channels", 32) num_pools = cfg_dict.get("num_pools", 4) drop_prob = cfg_dict.get("drop_prob", 0.0) @@ -76,7 +76,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.segmentation_decoder = UnetDecoder( chans=chans, num_pools=num_pools, - out_chans=self.segmentation_out_chans, + out_chans=self.segmentation_module_output_channels, drop_prob=drop_prob, normalize=normalize, padding=padding, @@ -238,9 +238,13 @@ def process_final_segmentation(self, prediction: torch.Tensor) -> torch.Tensor: """ if prediction.shape[-1] == 2: prediction = torch.view_as_complex(prediction) - if prediction.shape[1] != self.segmentation_out_chans and prediction.shape[1] != 2 and prediction.dim() == 5: + if ( + prediction.shape[1] != self.segmentation_module_output_channels + and prediction.shape[1] != 2 + and prediction.dim() == 5 + ): 
prediction = prediction.squeeze(1) - if prediction.shape[1] != self.segmentation_out_chans: + if prediction.shape[1] != self.segmentation_module_output_channels: prediction = prediction.permute(0, 3, 1, 2) prediction = torch.abs(prediction) if self.normalize_segmentation_output: diff --git a/atommic/collections/multitask/rs/nn/idslr_unet.py b/atommic/collections/multitask/rs/nn/idslr_unet.py index 9dfcf375..7f61a881 100644 --- a/atommic/collections/multitask/rs/nn/idslr_unet.py +++ b/atommic/collections/multitask/rs/nn/idslr_unet.py @@ -45,7 +45,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if self.input_channels == 0: raise ValueError("Segmentation module input channels cannot be 0.") reconstruction_out_chans = cfg_dict.get("reconstruction_module_output_channels", 2) - segmentation_out_chans = cfg_dict.get("segmentation_module_output_channels", 1) + self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1) chans = cfg_dict.get("channels", 32) num_pools = cfg_dict.get("num_pools", 4) drop_prob = cfg_dict.get("drop_prob", 0.0) @@ -78,7 +78,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.segmentation_module = Unet( in_chans=reconstruction_out_chans, - out_chans=segmentation_out_chans, + out_chans=self.segmentation_module_output_channels, chans=chans, num_pool_layers=num_pools, drop_prob=drop_prob, diff --git a/atommic/collections/multitask/rs/nn/mtlrs.py b/atommic/collections/multitask/rs/nn/mtlrs.py index 81daa2d4..2b3107fe 100644 --- a/atommic/collections/multitask/rs/nn/mtlrs.py +++ b/atommic/collections/multitask/rs/nn/mtlrs.py @@ -284,19 +284,24 @@ def compute_reconstruction_loss(t, p, s): return loss_func(t, p) - if self.accumulate_predictions: + if self.reconstruction_module_accumulate_predictions: rs_cascades_weights = torch.logspace(-1, 0, steps=len(prediction)).to(target.device) rs_cascades_loss = [] for rs_cascade_pred in prediction: cascades_weights = torch.logspace(-1, 0, steps=len(rs_cascade_pred)).to(target.device) cascades_loss = [] for cascade_pred in rs_cascade_pred: - time_steps_weights = torch.logspace(-1, 0, steps=self.time_steps).to(target.device) + time_steps_weights = torch.logspace(-1, 0, steps=self.reconstruction_module_time_steps).to( + target.device + ) time_steps_loss = [ compute_reconstruction_loss(target, time_step_pred, sensitivity_maps) for time_step_pred in cascade_pred ] - cascade_loss = sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) / self.time_steps + cascade_loss = ( + sum(x * w for x, w in zip(time_steps_loss, time_steps_weights)) + / self.reconstruction_module_time_steps + ) cascades_loss.append(cascade_loss) rs_cascade_loss = sum(x * w for x, w in zip(cascades_loss, cascades_weights)) / len(rs_cascade_pred) rs_cascades_loss.append(rs_cascade_loss) diff --git a/atommic/collections/multitask/rs/nn/recseg_unet.py b/atommic/collections/multitask/rs/nn/recseg_unet.py index dc6b3fc9..f9425809 100644 --- a/atommic/collections/multitask/rs/nn/recseg_unet.py +++ b/atommic/collections/multitask/rs/nn/recseg_unet.py @@ -55,9 +55,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): drop_prob=cfg_dict.get("reconstruction_module_dropout", 0.0), ) + self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1) self.segmentation_module = Unet( in_chans=reconstruction_module_output_channels, - out_chans=cfg_dict.get("segmentation_module_output_channels", 1), + out_chans=self.segmentation_module_output_channels, 
chans=cfg_dict.get("segmentation_module_channels", 64), num_pool_layers=cfg_dict.get("segmentation_module_pooling_layers", 2), drop_prob=cfg_dict.get("segmentation_module_dropout", 0.0), diff --git a/atommic/collections/multitask/rs/nn/segnet.py b/atommic/collections/multitask/rs/nn/segnet.py index c6532336..4a1b5666 100644 --- a/atommic/collections/multitask/rs/nn/segnet.py +++ b/atommic/collections/multitask/rs/nn/segnet.py @@ -52,7 +52,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.input_channels = cfg_dict.get("input_channels", 2) reconstruction_out_chans = cfg_dict.get("reconstruction_module_output_channels", 2) - segmentation_out_chans = cfg_dict.get("segmentation_module_output_channels", 1) + self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1) chans = cfg_dict.get("channels", 32) num_pools = cfg_dict.get("num_pools", 4) drop_prob = cfg_dict.get("drop_prob", 0.0) @@ -97,7 +97,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): UnetDecoder( chans=chans, num_pools=num_pools, - out_chans=segmentation_out_chans, + out_chans=self.segmentation_module_output_channels, drop_prob=drop_prob, normalize=normalize, padding=padding, @@ -110,8 +110,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.segmentation_final_layer = torch.nn.Sequential( ConvNonlinear( - segmentation_out_chans * num_cascades, - segmentation_out_chans, + self.segmentation_module_output_channels * num_cascades, + self.segmentation_module_output_channels, conv_dim=cfg_dict.get("segmentation_final_layer_conv_dim", 2), kernel_size=cfg_dict.get("segmentation_final_layer_kernel_size", 3), dilation=cfg_dict.get("segmentation_final_layer_dilation", 1), diff --git a/atommic/collections/multitask/rs/nn/seranet.py b/atommic/collections/multitask/rs/nn/seranet.py index 86792a20..c2f2381b 100644 --- a/atommic/collections/multitask/rs/nn/seranet.py +++ b/atommic/collections/multitask/rs/nn/seranet.py @@ -97,10 +97,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): coil_combination_method=self.coil_combination_method, ) self.segmentation_module_input_channels = cfg_dict.get("segmentation_module_input_channels", 2) - segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1) + self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 1) self.segmentation_module = ConvLSTMNormUnet( in_chans=self.segmentation_module_input_channels, - out_chans=segmentation_module_output_channels, + out_chans=self.segmentation_module_output_channels, chans=cfg_dict.get("segmentation_module_channels", 64), num_pools=cfg_dict.get("segmentation_module_pooling_layers", 2), drop_prob=cfg_dict.get("segmentation_module_dropout", 0.0), @@ -109,12 +109,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): num_iterations=cfg_dict.get("recurrent_module_iterations", 3), attention_model=AttentionGate( in_chans_x=self.segmentation_module_input_channels * 2, - in_chans_g=segmentation_module_output_channels, - out_chans=segmentation_module_output_channels, + in_chans_g=self.segmentation_module_output_channels, + out_chans=self.segmentation_module_output_channels, ), unet_model=ConvLSTMNormUnet( in_chans=self.segmentation_module_input_channels * 2, - out_chans=segmentation_module_output_channels, + out_chans=self.segmentation_module_output_channels, chans=cfg_dict.get("recurrent_module_attention_channels", 64), 
num_pools=cfg_dict.get("recurrent_module_attention_pooling_layers", 2), drop_prob=cfg_dict.get("recurrent_module_attention_dropout", 0.0), diff --git a/atommic/collections/multitask/rs/nn/seranet_base/seranet_block.py b/atommic/collections/multitask/rs/nn/seranet_base/seranet_block.py index 0cecb5af..508613f1 100644 --- a/atommic/collections/multitask/rs/nn/seranet_base/seranet_block.py +++ b/atommic/collections/multitask/rs/nn/seranet_base/seranet_block.py @@ -123,7 +123,9 @@ def __init__( ) self.model_name = self.reconstruction_module[0].__class__.__name__.lower() if self.model_name == "modulelist": - self.model_name = self.reconstruction_module[0][0].__class__.__name__.lower() + self.model_name = self.reconstruction_module[0][ # pylint: disable=unsubscriptable-object + 0 + ].__class__.__name__.lower() self.fft_centered = fft_centered self.fft_normalization = fft_normalization diff --git a/atommic/collections/multitask/rs/parts/transforms.py b/atommic/collections/multitask/rs/parts/transforms.py index d10abda9..da5a949a 100644 --- a/atommic/collections/multitask/rs/parts/transforms.py +++ b/atommic/collections/multitask/rs/parts/transforms.py @@ -97,8 +97,9 @@ def __init__( fft_normalization: str = "backward", spatial_dims: Sequence[int] = None, coil_dim: int = 0, - consecutive_slices: int = 1, # pylint: disable=unused-argument + consecutive_slices: int = 1, use_seed: bool = True, + include_background_label: bool = False, ): """Inits :class:`RSMRIDataTransforms`. @@ -247,6 +248,8 @@ def __init__( Consecutive slices. Default is ``1``. use_seed : bool, optional Whether to use seed. Default is ``True``. + include_background_label: bool optional + Add an extra class to define the background label. Default is ``False``. """ self.complex_data = complex_data @@ -436,6 +439,8 @@ def __init__( self.normalization, # type: ignore ] ) + self.consecutive_slices = consecutive_slices + self.include_background_label = include_background_label self.cropping = Composer([self.cropping]) # type: ignore self.normalization = Composer([self.normalization]) # type: ignore @@ -574,6 +579,32 @@ def __call__( segmentation_labels = segmentation_labels.float() segmentation_labels = torch.abs(segmentation_labels) + if self.include_background_label: + if self.consecutive_slices > 1: + segmentation_labels_bg = torch.zeros( + (segmentation_labels.shape[0], segmentation_labels.shape[2], segmentation_labels.shape[3]) + ) + segmentation_labels_new = torch.zeros( + ( + segmentation_labels.shape[0], + segmentation_labels.shape[1] + 1, + segmentation_labels.shape[2], + segmentation_labels.shape[3], + ) + ) + for i in range(target_reconstruction.shape[0]): + idx_background = torch.where(torch.sum(segmentation_labels[i], dim=0) == 0) + segmentation_labels_bg[i][idx_background] = 1 + segmentation_labels_new[i] = torch.concat( + (segmentation_labels_bg[i].unsqueeze(0), segmentation_labels[i]), dim=0 + ) + segmentation_labels = segmentation_labels_new + else: + segmentation_labels_bg = torch.zeros((segmentation_labels.shape[-2], segmentation_labels.shape[-1])) + idx_background = torch.where(torch.sum(segmentation_labels, dim=0) == 0) + segmentation_labels_bg[idx_background] = 1 + segmentation_labels = torch.concat((segmentation_labels_bg.unsqueeze(0), segmentation_labels), dim=0) + attrs.update( self.__parse_normalization_vars__( kspace_pre_normalization_vars, diff --git a/atommic/collections/segmentation/losses/__init__.py b/atommic/collections/segmentation/losses/__init__.py index f780f27e..8ca5500c 100644 --- 
a/atommic/collections/segmentation/losses/__init__.py +++ b/atommic/collections/segmentation/losses/__init__.py
@@ -1,5 +1,8 @@
 # coding=utf-8
 __author__ = "Dimitris Karkalousos"
 
-from atommic.collections.segmentation.losses.cross_entropy import CrossEntropyLoss  # noqa: F401
+from atommic.collections.segmentation.losses.cross_entropy import BinaryCrossEntropyLoss  # noqa: F401
+from atommic.collections.segmentation.losses.cross_entropy import CategoricalCrossEntropyLoss  # noqa: F401
 from atommic.collections.segmentation.losses.dice import Dice  # noqa: F401
+from atommic.collections.segmentation.losses.dice import GeneralisedDice  # noqa: F401
+from atommic.collections.segmentation.losses.focal import FocalLoss  # noqa: F401
diff --git a/atommic/collections/segmentation/losses/cross_entropy.py b/atommic/collections/segmentation/losses/cross_entropy.py
index dd7ef2f3..bf82e19f 100644
--- a/atommic/collections/segmentation/losses/cross_entropy.py
+++ b/atommic/collections/segmentation/losses/cross_entropy.py
@@ -1,11 +1,16 @@
 # coding=utf-8
 __author__ = "Dimitris Karkalousos"
 
+import warnings
+
 import torch
-from torch import nn
+
+from atommic.collections.common.parts.utils import is_none
+from atommic.collections.segmentation.losses.utils import one_hot
+from atommic.core.classes.loss import Loss
 
-class CrossEntropyLoss(nn.Module):
+class CategoricalCrossEntropyLoss(Loss):
     """Wrapper around PyTorch's CrossEntropyLoss to support 2D and 3D inputs."""
 
     def __init__(
@@ -15,39 +20,146 @@ def __init__(
         reduction: str = "none",
         label_smoothing: float = 0.0,
         weight: torch.Tensor = None,
+        to_onehot_y: bool = False,
+        include_background: bool = True,
     ):
-        """Inits :class:`CrossEntropyLoss`.
+        """Inits :class:`CategoricalCrossEntropyLoss`.
 
         Parameters
         ----------
         num_samples : int, optional
-            Number of Monte Carlo samples, by default 50
+            Number of Monte Carlo samples. Default is ``50``.
         ignore_index : int, optional
-            Index to ignore, by default -100
+            Index to ignore. Default is ``-100``.
         reduction : str, optional
-            Reduction method, by default "none"
+            Reduction method. Default is ``none``.
         label_smoothing : float, optional
-            Label smoothing, by default 0.0
+            Label smoothing. Default is ``0.0``.
         weight : torch.Tensor, optional
-            Weight for each class, by default None
+            Weight for each class. Default is ``None``.
+        include_background : bool
+            Whether to include the computation on the first channel of the predicted output. Default is ``True``.
+        to_onehot_y : bool
+            Whether to convert `y` into the one-hot format. Default is ``False``.
         """
         super().__init__()
         self.mc_samples = num_samples
+        self.include_background = include_background
+        self.to_onehot_y = to_onehot_y
+        self.weight = weight
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+        self.label_smoothing = label_smoothing
+
+    def forward(
+        self, target: torch.Tensor, input: torch.Tensor, pred_log_var: torch.Tensor = None  # noqa: MC0001
+    ) -> torch.Tensor:
+        """Forward pass of :class:`CategoricalCrossEntropyLoss`.
+
+        Parameters
+        ----------
+        target : torch.Tensor
+            Target tensor. Shape: (batch_size, num_classes, *spatial_dims)
+        input : torch.Tensor
+            Prediction tensor. Shape: (batch_size, num_classes, *spatial_dims)
+        pred_log_var : torch.Tensor, optional
+            Prediction log variance tensor. Shape: (batch_size, num_classes, *spatial_dims). Default is ``None``.
+
+        Returns
+        -------
+        torch.Tensor
+            CategoricalCrossEntropy Loss
+        """
+        if input.dim() == 3:
+            input = input.unsqueeze(0)
+        if target.dim() == 3:
+            target = target.unsqueeze(0)
+
+        if not is_none(self.weight):
+            self.weight = torch.tensor(self.weight).clone().to(input)
+        else:
+            self.weight = None
+
         self.cross_entropy = torch.nn.CrossEntropyLoss(
-            weight=weight,
-            ignore_index=ignore_index,
-            reduction=reduction,
-            label_smoothing=label_smoothing,
+            weight=self.weight,
+            ignore_index=self.ignore_index,
+            reduction=self.reduction,
+            label_smoothing=self.label_smoothing,
         )
 
-    def forward(self, target: torch.Tensor, _input: torch.Tensor, pred_log_var: torch.Tensor = None) -> torch.Tensor:
-        """Forward pass of :class:`CrossEntropyLoss`.
+        n_pred_ch = input.shape[1]
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y = True` ignored.")
+            else:
+                target = one_hot(target, num_classes=n_pred_ch)
+
+        if not self.include_background:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `include_background = False` ignored.")
+            else:
+                # if skipping background, removing first channel
+                target = target[:, 1:]
+                input = input[:, 1:]
+
+        if self.mc_samples == 1 or pred_log_var is None:
+            return self.cross_entropy(input.float(), target)
+
+        pred_shape = [self.mc_samples, *input.shape]
+        noise = torch.randn(pred_shape, device=input.device)
+        noisy_pred = input.unsqueeze(0) + torch.sqrt(torch.exp(pred_log_var)).unsqueeze(0) * noise
+        noisy_pred = noisy_pred.view(-1, *input.shape[1:])
+        tiled_target = target.unsqueeze(0).tile((self.mc_samples,)).view(-1, *target.shape[1:])
+        loss = self.cross_entropy(noisy_pred, tiled_target).view(self.mc_samples, -1, *input.shape[-2:])
+        return loss
+
+
+class BinaryCrossEntropyLoss(Loss):
+    """Wrapper around PyTorch's ``BCEWithLogitsLoss`` to support 2D and 3D inputs."""
+
+    def __init__(
+        self,
+        num_samples: int = 50,
+        weight: torch.Tensor = None,
+        reduction: str = 'none',
+        include_background: bool = True,
+        to_onehot_y: bool = False,
+    ):
+        """Inits :class:`BinaryCrossEntropyLoss`.
+
+        Parameters
+        ----------
+        num_samples : int, optional
+            Number of Monte Carlo samples. Default is ``50``.
+        weight : torch.Tensor, optional
+            Weight for each class. Default is ``None``.
+        reduction : str, optional
+            Reduction method. Default is ``none``.
+        include_background : bool
+            Whether to include the computation on the first channel of the predicted output. Default is ``True``.
+        to_onehot_y : bool
+            Whether to convert `y` into the one-hot format. Default is ``False``.
+        """
+        super().__init__()
+        self.mc_samples = num_samples
+        self.weight = weight
+        self.reduction = reduction
+        self.to_onehot_y = to_onehot_y
+        self.include_background = include_background
+
+    def forward(
+        self, target: torch.Tensor, input: torch.Tensor, pred_log_var: torch.Tensor = None  # noqa: MC0001
+    ) -> torch.Tensor:
+        """Forward pass of :class:`BinaryCrossEntropyLoss`.
 
         Parameters
         ----------
         target : torch.Tensor
             Target tensor. Shape: (batch_size, num_classes, *spatial_dims)
-        _input : torch.Tensor
+        input : torch.Tensor
             Prediction tensor. Shape: (batch_size, num_classes, *spatial_dims)
         pred_log_var : torch.Tensor, optional
             Prediction log variance tensor. Shape: (batch_size, num_classes, *spatial_dims). Default is ``None``.
@@ -55,23 +167,38 @@
 
         Returns
         -------
         torch.Tensor
-            Loss tensor.
Shape: (batch_size, *spatial_dims) + BinaryCrossEntropy Loss """ # In case we do not have a batch dimension, add it - if _input.dim() == 3: - _input = _input.unsqueeze(0) + if input.dim() == 3: + input = input.unsqueeze(0) if target.dim() == 3: target = target.unsqueeze(0) + self.weight = self.weight.clone().to(input.device) if self.weight is not None else None + n_pred_ch = input.shape[1] + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y = True` ignored.") + else: + target = one_hot(target, num_classes=n_pred_ch) - self.cross_entropy.weight = self.cross_entropy.weight.clone().to(_input.device) + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background = False` ignored.") + else: + # if skipping background, removing first channel + target = target[:, 1:] + input = input[:, 1:] + self.binary_cross_entropy = torch.nn.BCEWithLogitsLoss(weight=self.weight, reduction=self.reduction) if self.mc_samples == 1 or pred_log_var is None: - return self.cross_entropy(_input.float(), target).mean() + return self.binary_cross_entropy(input.float(), target) - pred_shape = [self.mc_samples, *_input.shape] - noise = torch.randn(pred_shape, device=_input.device) - noisy_pred = _input.unsqueeze(0) + torch.sqrt(torch.exp(pred_log_var)).unsqueeze(0) * noise - noisy_pred = noisy_pred.view(-1, *_input.shape[1:]) + pred_shape = [self.mc_samples, *input.shape] + noise = torch.randn(pred_shape, device=input.device) + noisy_pred = input.unsqueeze(0) + torch.sqrt(torch.exp(pred_log_var)).unsqueeze(0) * noise + noisy_pred = noisy_pred.view(-1, *input.shape[1:]) tiled_target = target.unsqueeze(0).tile((self.mc_samples,)).view(-1, *target.shape[1:]) - loss = self.cross_entropy(noisy_pred, tiled_target).view(self.mc_samples, -1, *_input.shape[-2:]).mean(0) - return loss.mean() + loss = self.binary_cross_entropy(noisy_pred, tiled_target).view(self.mc_samples, -1, *input.shape[-2:]).mean(0) + return loss diff --git a/atommic/collections/segmentation/losses/dice.py b/atommic/collections/segmentation/losses/dice.py index f318e67c..a2dbd87a 100644 --- a/atommic/collections/segmentation/losses/dice.py +++ b/atommic/collections/segmentation/losses/dice.py @@ -11,7 +11,7 @@ from torch import Tensor from atommic.collections.common.parts.utils import is_none -from atommic.collections.segmentation.losses.utils import do_metric_reduction +from atommic.collections.segmentation.losses.utils import do_metric_reduction, one_hot from atommic.core.classes.loss import Loss @@ -70,7 +70,7 @@ def __init__( squared_pred: bool = False, jaccard: bool = False, flatten: bool = False, - reduction: str = "mean", + reduction: str = "none", smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = True, @@ -100,7 +100,7 @@ def __init__( 'none': no reduction will be applied. 'mean': the sum of the output will be divided by the number of elements in the output. 'sum': the output will be summed. - Default is ``mean``. + Default is ``none``. smooth_nr : float A small constant added to the numerator to avoid `nan` when all items are 0. Default is ``1e-5``. smooth_dr : float @@ -131,23 +131,30 @@ def __init__( self.smooth_dr = float(smooth_dr) self.batch = batch - def forward(self, target: torch.Tensor, _input: torch.Tensor) -> Tuple[Union[Tensor, Any], Tensor]: # noqa: MC0001 + def forward(self, target: torch.Tensor, input: torch.Tensor) -> Tuple[Union[Tensor, Any], Tensor]: # noqa: MC0001 """Forward pass of :class:`Dice`. 
        Parameters
         ----------
-        _input: torch.Tensor
-            Prediction of shape [BNHW[D]].
-        target: torch.Tensor
-            Ground truth of shape [BNHW[D]].
+        target : torch.Tensor
+            Target tensor. Shape: (batch_size, num_classes, *spatial_dims) or (batch_size, 1, *spatial_dims)
+        input : torch.Tensor
+            Prediction tensor. Shape: (batch_size, num_classes, *spatial_dims)
 
         Returns
         -------
-        torch.Tensor
-            Dice loss.
+        Tuple[torch.Tensor, torch.Tensor]
+            The Dice score and the Dice loss, i.e. ``1 - score``.
         """
-        if isinstance(_input, np.ndarray):
-            _input = torch.from_numpy(_input)
+        if input.dim() == 3:
+            input = input.unsqueeze(0)
+        if target.dim() == 3:
+            target = target.unsqueeze(0)
+
+        if isinstance(input, np.ndarray):
+            input = torch.from_numpy(input)
         if isinstance(target, np.ndarray):
             target = torch.from_numpy(target)
@@ -157,20 +164,20 @@
             else:
                 segmentation_classes_dim = 0
             target = target.reshape(target.shape[segmentation_classes_dim], 1, -1)
-            _input = _input.reshape(_input.shape[segmentation_classes_dim], 1, -1)
+            input = input.reshape(input.shape[segmentation_classes_dim], 1, -1)
 
         if self.sigmoid:
-            _input = torch.sigmoid(_input.float())
+            input = torch.sigmoid(input.float())
 
-        n_pred_ch = _input.shape[1]
+        n_pred_ch = input.shape[1]
         if self.softmax:
             if n_pred_ch == 1:
                 warnings.warn("single channel prediction, `softmax=True` ignored.")
             else:
-                _input = torch.softmax(_input.float(), 1).to(_input)
+                input = torch.softmax(input.float(), 1).to(input)
 
         if self.other_act is not None:
-            _input = self.other_act(_input)
+            input = self.other_act(input)
 
         if self.to_onehot_y:
             if n_pred_ch == 1:
@@ -184,22 +191,22 @@
             else:
                 # if skipping background, removing first channel
                 target = target[:, 1:]
-                _input = _input[:, 1:]
+                input = input[:, 1:]
 
-        if target.shape != _input.shape:
-            raise AssertionError(f"ground truth has different shape ({target.shape}) from _input ({_input.shape})")
+        if target.shape != input.shape:
+            raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
 
         # reducing only spatial dimensions (not batch nor channels)
-        reduce_axis: List[int] = torch.arange(2, len(_input.shape)).tolist()
-        if self.batch:
+        reduce_axis: List[int] = torch.arange(2, len(input.shape)).tolist()
+        if not self.batch:
+            # when batch=False, also include the batch dimension,
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        intersection = torch.sum(target * _input, dim=reduce_axis)
+        intersection = torch.sum(target * input, dim=reduce_axis)
 
         if self.squared_pred:
             target = torch.pow(target, 2)
-            _input = torch.pow(_input, 2)
+            input = torch.pow(input, 2)
 
         ground_o = torch.sum(target, dim=reduce_axis)
-        pred_o = torch.sum(_input, dim=reduce_axis)
+        pred_o = torch.sum(input, dim=reduce_axis)
 
         denominator = ground_o + pred_o
         if self.jaccard:
             denominator = 2.0 * (denominator - intersection)
@@ -210,39 +217,171 @@
         return dice_score, f
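# Editor's sketch (illustrative, not part of this patch): Dice.forward returns
# a (score, loss) tuple with loss = 1 - score, which is why
# process_segmentation_loss earlier in this diff takes loss[1] and indexes
# into it. Minimal usage, assuming a one-hot target, raw logits, and that
# do_metric_reduction with reduction="none" returns its input unchanged:
import torch

dice = Dice(include_background=True, sigmoid=True, reduction="none", batch=True)
target = torch.randint(0, 2, (2, 3, 16, 16)).float()  # (B, C, H, W)
logits = torch.randn(2, 3, 16, 16)
score, loss = dice(target, logits)  # each of shape (B, C) with these settings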
-def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
-    """Convert labels to one-hot representation.
-
-    Parameters
-    ----------
-    labels: torch.Tensor
-        the labels of shape [BNHW[D]].
-    num_classes: int
-        number of classes.
-    dtype: torch.dtype
-        the data type of the returned tensor.
-    dim: int
-        the dimension to expand the one-hot tensor.
+class GeneralisedDice(Loss):
+    """
+    Compute the Generalised Dice loss defined in:
 
-    Returns
-    -------
-    torch.Tensor
-        The one-hot representation of the labels.
+    Sudre, C. et al. (2017) Generalised Dice overlap as a deep learning
+    loss function for highly unbalanced segmentations. DLMIA 2017.
 
-    Examples
-    --------
-    >>> labels = torch.tensor([[[[0, 1, 2]]]])
-    >>> one_hot(labels, num_classes=3)
-    tensor([[[[1., 0., 0.],
-              [0., 1., 0.],
-              [0., 0., 1.]]]])
+    Adapted from:
+    https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L279
     """
-    # if `dim` is bigger, add singleton dim at the end
-    if labels.ndim < dim + 1:
-        shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape))
-        labels = torch.reshape(labels, shape)
-    sh = list(labels.shape)
-    sh[dim] = num_classes
-    o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
-    labels = o.scatter_(dim=dim, index=labels.long(), value=1)
-    return labels
+
+    def __init__(
+        self,
+        include_background: bool = True,
+        to_onehot_y: bool = False,
+        sigmoid: bool = True,
+        softmax: bool = False,
+        other_act: Callable | None = None,
+        w_type: str = "square",
+        reduction: str = 'none',
+        smooth_nr: float = 1e-5,
+        smooth_dr: float = 1e-5,
+        batch: bool = True,
+    ) -> None:
+        """
+        Inits :class:`GeneralisedDice`.
+
+        Parameters
+        ----------
+        include_background : bool
+            Whether to skip Dice computation on the first channel of the predicted output. Default is ``True``.
+        to_onehot_y : bool
+            Whether to convert `y` into the one-hot format. Default is ``False``.
+        sigmoid : bool
+            Whether to add sigmoid function to the input data. Default is ``True``.
+        softmax : bool
+            Whether to add softmax function to the input data. Default is ``False``.
+        other_act : Callable
+            Use this parameter if you want to apply another type of activation layer. Default is ``None``.
+        w_type: {``"square"``, ``"simple"``, ``"uniform"``}
+            Type of function to transform ground truth volume to a weight factor. Defaults to ``"square"``.
+        reduction : str
+            Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
+            'none': no reduction will be applied.
+            'mean': the sum of the output will be divided by the number of elements in the output.
+            'sum': the output will be summed.
+            Default is ``none``.
+        smooth_nr : float
+            A small constant added to the numerator to avoid `nan` when all items are 0. Default is ``1e-5``.
+        smooth_dr : float
+            A small constant added to the denominator to avoid `nan` when all items are 0. Default is ``1e-5``.
+        batch : bool
+            If True, compute Dice loss for each batch and return a tensor with shape (batch_size,).
+            If False, compute Dice loss for the whole batch and return a tensor with shape (1,).
+            Default is ``True``.
+        """
+        super().__init__()
+        other_act = None if is_none(other_act) else other_act
+        if other_act is not None and not callable(other_act):
+            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
+        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
+            raise ValueError(
+                "Incompatible values: more than 1 of [sigmoid = True, softmax = True, other_act is not None]."
+            )
+
+        self.include_background = include_background
+        self.to_onehot_y = to_onehot_y
+        self.sigmoid = sigmoid
+        self.softmax = softmax
+        self.other_act = other_act
+        self.reduction = reduction
+        self.w_type = w_type
+
+        self.smooth_nr = float(smooth_nr)
+        self.smooth_dr = float(smooth_dr)
+        self.batch = batch
+
+    def w_func(self, grnd):
+        if self.w_type == "simple":
+            return torch.reciprocal(grnd)
+        if self.w_type == "square":
+            return torch.reciprocal(grnd * grnd)
+        return torch.ones_like(grnd)
+
+    def forward(self, target: torch.Tensor, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:  # noqa: MC0001
+        """
+        Forward pass of :class:`GeneralisedDice`.
+
+        Parameters
+        ----------
+        target : torch.Tensor
+            Target tensor. Shape: (batch_size, num_classes, *spatial_dims) or (batch_size, 1, *spatial_dims)
+        input : torch.Tensor
+            Prediction tensor. Shape: (batch_size, num_classes, *spatial_dims)
+
+        Returns
+        -------
+        Tuple[torch.Tensor, torch.Tensor]
+            The Generalised Dice score and the Generalised Dice loss, i.e. ``1 - score``.
+        """
+        if input.dim() == 3:
+            input = input.unsqueeze(0)
+        if target.dim() == 3:
+            target = target.unsqueeze(0)
+        if self.sigmoid:
+            input = torch.sigmoid(input)
+        n_pred_ch = input.shape[1]
+        if self.softmax:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `softmax=True` ignored.")
+            else:
+                input = torch.softmax(input, 1)
+
+        if self.other_act is not None:
+            input = self.other_act(input)
+
+        if self.to_onehot_y:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                target = one_hot(target, num_classes=n_pred_ch)
+
+        if not self.include_background:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `include_background=False` ignored.")
+            else:
+                # if skipping background, removing first channel
+                target = target[:, 1:]
+                input = input[:, 1:]
+
+        if target.shape != input.shape:
+            raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
+
+        # reducing only spatial dimensions (not batch nor channels)
+        reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
+        if not self.batch:
+            reduce_axis = [0] + reduce_axis
+        intersection = torch.sum(target * input, reduce_axis)
+
+        ground_o = torch.sum(target, reduce_axis)
+        pred_o = torch.sum(input, reduce_axis)
+
+        denominator = ground_o + pred_o
+
+        w = self.w_func(ground_o.float())
+        infs = torch.isinf(w)
+        if not self.batch:
+            w[infs] = 0.0
+            w = w + infs * torch.max(w)
+        else:
+            w[infs] = 0.0
+            max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1)
+            w = w + infs * max_values
+
+        final_reduce_dim = 0 if not self.batch else 1
+        numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
+        denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
+        gendice_score = numer / denom
+        gendice_score = torch.where(denominator > 0, gendice_score, torch.tensor(1.0).to(pred_o.device))
+        gendice_score, _ = do_metric_reduction(gendice_score, reduction=self.reduction)
+        f: torch.Tensor = 1.0 - gendice_score
+        return gendice_score, f
diff --git a/atommic/collections/segmentation/losses/focal.py b/atommic/collections/segmentation/losses/focal.py
new file mode 100644
index 00000000..8b4c8bd4
--- /dev/null
+++ b/atommic/collections/segmentation/losses/focal.py
@@ -0,0 +1,236 @@
+# coding=utf-8
+__author__ = "Tim Paquaij"
+
+# Taken and adapted from:
+
https://github.com/Project-MONAI/MONAI/blob/46a5272196a6c2590ca2589029eed8e4d56ff008/monai/losses/focal_loss.py + +import warnings +from typing import Optional, Sequence +import torch +import torch.nn.functional as F + +from atommic.collections.segmentation.losses.utils import one_hot +from atommic.core.classes.loss import Loss + + +class FocalLoss(Loss): + """ + FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from + high confidence correct predictions. + + Reimplementation of the Focal Loss described in: + + - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017 + - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy", + Zhu et al., Medical Physics 2018 + + """ + + def __init__( + self, + gamma: float = 2.0, + alpha: float | None = None, + weight: Sequence[float] | float | int | torch.Tensor | None = None, + reduction: str = "none", + use_softmax: bool = False, + include_background: bool = True, + to_onehot_y: bool = False, + ) -> None: + """Inits :class:`FocalLoss` + + Parameters + ---------- + gamma : float, optional + Value of the exponent gamma in the definition of the Focal loss. Default is 2 + alpha : float, optional + Value of the alpha: [0,1] in the definition of the alpha-balanced Focal loss. Default is None + weight : torch.Tensor, optional + Weight for each class. Default is None + reduction : str, optional + Reduction method. Default is "none" + use_softmax : bool, optional + option to compute the focal loss as a categorical cross-entropy. Default is ``False`` + include_background : bool + whether to include the computation on the first channel of the predicted output. Default is ``True``. + to_onehot_y : bool + Whether to convert `y` into the one-hot format. Default is ``False``. + """ + super().__init__() + self.reduction = reduction + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.gamma = gamma + self.alpha = alpha + self.use_softmax = use_softmax + self.weight = torch.as_tensor(weight) if weight is not None else None + self.register_buffer("class_weight", self.weight) + + def forward(self, target: torch.Tensor, input: torch.Tensor) -> torch.Tensor: # noqa: MC0001 + """Forward pass of :class:`FocalLoss`. + + Parameters + ---------- + target : torch.Tensor + Target tensor. Shape: (batch_size, num_classes, *spatial_dims) + input : torch.Tensor + Prediction tensor. 
Shape: (batch_size, num_classes, *spatial_dims) + + Returns + ------- + torch.Tensor + Focal Loss + """ + if input.dim() == 3: + input = input.unsqueeze(0) + if target.dim() == 3: + target = target.unsqueeze(0) + + n_pred_ch = input.shape[1] + + if self.to_onehot_y: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `to_onehot_y = True` ignored.") + else: + target = one_hot(target, num_classes=n_pred_ch) + + if not self.include_background: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `include_background = False` ignored.") + else: + # if skipping background, removing first channel + target = target[:, 1:] + input = input[:, 1:] + + if target.shape != input.shape: + raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + + loss: Optional[torch.Tensor] = None + input = input.float() + target = target.float() + if self.use_softmax: + if not self.include_background and self.alpha is not None: + self.alpha = None + warnings.warn("`include_background = False`, `alpha` ignored when using softmax.") + loss = softmax_focal_loss(input, target, self.gamma, self.alpha) + else: + loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha) + + num_of_classes = target.shape[1] + if (self.class_weight is not None) and ( # type: ignore # pylint: disable=access-member-before-definition + num_of_classes != 1 + ): + # make sure the lengths of weights are equal to the number of classes + if self.class_weight.ndim == 0: # type: ignore # pylint: disable=access-member-before-definition + self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes) # type: ignore + else: + if self.class_weight.shape[0] != num_of_classes: + raise ValueError( + """the length of the `weight` sequence should be the same as the number of classes. + If `include_background = False`, the weight should not include + the background category class 0.""" + ) + if self.class_weight.min() < 0: + raise ValueError("the value/values of the `weight` should be no less than 0.") + # apply class_weight to loss + self.class_weight = self.class_weight.to(loss) + broadcast_dims = [-1] + [1] * len(target.shape[2:]) + self.class_weight = self.class_weight.view(broadcast_dims) + loss = self.class_weight * loss + + if self.reduction == "sum": + # Previously there was a mean over the last dimension, which did not + # return a compatible BCE loss. To maintain backwards compatible + # behavior we have a flag that performs this extra step, disable or + # parameterize if necessary. (Or justify why the mean should be there) + average_spatial_dims = True + if average_spatial_dims: + loss = loss.mean(dim=list(range(2, len(target.shape)))) + loss = loss.sum() + elif self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "none": + pass + else: + raise ValueError( + f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].' + ) + return loss + + +def softmax_focal_loss( + input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None # noqa: MC0001 +) -> torch.Tensor: + """Softmax operation for focal loss + + Parameters + ---------- + input : torch.Tensor + Prediction tensor. Shape: (batch_size, num_classes, *spatial_dims) + target : torch.Tensor + Target tensor. Shape: (batch_size, num_classes, *spatial_dims) + gamma : float + Value of the exponent gamma in the definition of the Focal loss. 
Default is 2
+    alpha : float, optional
+        Value of the alpha: [0,1] in the definition of the alpha-balanced Focal loss. Default is None
+
+    Returns
+    -------
+    torch.Tensor
+        Focal Loss
+    """
+    input_ls = input.log_softmax(1)
+    loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
+
+    if alpha is not None:
+        # (1-alpha) for the background class and alpha for the other classes
+        alpha_fac = torch.tensor([1 - alpha] + [alpha] * (target.shape[1] - 1)).to(loss)
+        broadcast_dims = [-1] + [1] * len(target.shape[2:])
+        alpha_fac = alpha_fac.view(broadcast_dims)
+        loss = alpha_fac * loss
+
+    return loss
+
+
+def sigmoid_focal_loss(
+    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None  # noqa: MC0001
+) -> torch.Tensor:
+    """Sigmoid operation for focal loss
+
+    Parameters
+    ----------
+    input : torch.Tensor
+        Prediction tensor. Shape: (batch_size, num_classes, *spatial_dims)
+    target : torch.Tensor
+        Target tensor. Shape: (batch_size, num_classes, *spatial_dims)
+    gamma : float
+        Value of the exponent gamma in the definition of the Focal loss. Default is 2
+    alpha : float, optional
+        Value of the alpha: [0,1] in the definition of the alpha-balanced Focal loss. Default is None
+
+    Returns
+    -------
+    torch.Tensor
+        Focal Loss
+    """
+    # computing binary cross entropy with logits
+    # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction = 'none')
+    # see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363
+    loss: torch.Tensor = input - input * target - F.logsigmoid(input)
+
+    # sigmoid(-i) if t == 1; sigmoid(i) if t == 0 <=>
+    # 1-sigmoid(i) if t == 1; sigmoid(i) if t == 0 <=>
+    # 1-p if t == 1; p if t == 0 <=>
+    # pfac, that is, the term (1 - pt)
+    invprobs = F.logsigmoid(-input * (target * 2 - 1))  # reduced chance of overflow
+    # (pfac.log() * gamma).exp() <=>
+    # pfac.log().exp() ^ gamma <=>
+    # pfac ^ gamma
+    loss = (invprobs * gamma).exp() * loss
+
+    if alpha is not None:
+        # alpha if t == 1; (1-alpha) if t == 0
+        alpha_factor = target * alpha + (1 - target) * (1 - alpha)
+        loss = alpha_factor * loss
+
+    return loss
diff --git a/atommic/collections/segmentation/losses/utils.py b/atommic/collections/segmentation/losses/utils.py
index d4ff6037..bca136c3 100644
--- a/atommic/collections/segmentation/losses/utils.py
+++ b/atommic/collections/segmentation/losses/utils.py
@@ -72,3 +72,41 @@ def do_metric_reduction(f: torch.Tensor, reduction: str = "mean") -> Tuple[Tenso
             '["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel" "none"].'
         )
     return f, not_nans
+
+
+def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
+    """Convert labels to one-hot representation.
+
+    Parameters
+    ----------
+    labels: torch.Tensor
+        the labels of shape [BNHW[D]].
+    num_classes: int
+        number of classes.
+    dtype: torch.dtype
+        the data type of the returned tensor.
+    dim: int
+        the dimension to expand the one-hot tensor.
+
+    Returns
+    -------
+    torch.Tensor
+        The one-hot representation of the labels.
+ + Examples + -------- + >>> labels = torch.tensor([[[[0, 1, 2]]]]) + >>> one_hot(labels, num_classes=3) + tensor([[[[1., 0., 0.]], + [[0., 1., 0.]], + [[0., 0., 1.]]]]) + """ + # if `labels` has fewer dims than `dim + 1`, append singleton dims at the end + if labels.ndim < dim + 1: + shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape)) + labels = torch.reshape(labels, shape) + sh = list(labels.shape) + sh[dim] = num_classes + o = torch.zeros(size=sh, dtype=dtype, device=labels.device) + labels = o.scatter_(dim=dim, index=labels.long(), value=1) + return labels diff --git a/atommic/collections/segmentation/metrics/segmentation_metrics.py b/atommic/collections/segmentation/metrics/segmentation_metrics.py index a5b520a4..e88bfc72 100644 --- a/atommic/collections/segmentation/metrics/segmentation_metrics.py +++ b/atommic/collections/segmentation/metrics/segmentation_metrics.py @@ -12,8 +12,7 @@ from torchmetrics import functional as F from atommic.collections.segmentation.losses import Dice -from atommic.collections.segmentation.losses.dice import one_hot -from atommic.collections.segmentation.losses.utils import do_metric_reduction +from atommic.collections.segmentation.losses.utils import do_metric_reduction, one_hot def asd(x, y, voxelspacing=None, connectivity=1): diff --git a/atommic/collections/segmentation/nn/base.py b/atommic/collections/segmentation/nn/base.py index f81c5d4e..043a7389 100644 --- a/atommic/collections/segmentation/nn/base.py +++ b/atommic/collections/segmentation/nn/base.py @@ -26,8 +26,10 @@ SegmentationMRIDataset, SKMTEASegmentationMRIDataset, ) -from atommic.collections.segmentation.losses.cross_entropy import CrossEntropyLoss -from atommic.collections.segmentation.losses.dice import Dice +from atommic.collections.segmentation.losses.cross_entropy import BinaryCrossEntropyLoss, CategoricalCrossEntropyLoss +from atommic.collections.segmentation.losses.dice import Dice, GeneralisedDice +from atommic.collections.segmentation.losses.focal import FocalLoss +from atommic.collections.segmentation.losses.utils import one_hot from atommic.collections.segmentation.parts.transforms import SegmentationMRIDataTransforms __all__ = ["BaseMRISegmentationModel"] @@ -59,6 +61,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if self.input_channels == 0: raise ValueError("Segmentation module input channels cannot be 0.") + # Set the output channels of the segmentation module. Necessary for multi-class segmentation. + self.segmentation_module_output_channels = cfg_dict.get("segmentation_module_output_channels", 2) + # Set type of data, i.e., magnitude only or complex valued. self.magnitude_input = cfg_dict.get("magnitude_input", True) # Refers to the type of the complex-valued data. It can be either "stacked" or "complex_abs" or @@ -77,6 +82,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Set threshold for segmentation classes. If None, no thresholding is applied. self.segmentation_classes_thresholds = cfg_dict.get("segmentation_classes_thresholds", None) self.segmentation_activation = cfg_dict.get("segmentation_activation", None) + self.segmentation_type = cfg_dict.get("segmentation_type", "MLS") + self.segmentation_output_mode = cfg_dict.get("segmentation_output_mode", "binary") # Initialize loss related parameters.
self.segmentation_losses = {} @@ -102,20 +109,37 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): segmentation_losses_ = {k: v / total_weight for k, v in segmentation_losses_.items()} for name in VALID_SEGMENTATION_LOSSES: if name in segmentation_losses_: - if name == "cross_entropy": - cross_entropy_loss_classes_weight = torch.tensor( - cfg_dict.get("cross_entropy_loss_classes_weight", 0.5) - ) - self.segmentation_losses[name] = CrossEntropyLoss( + if name == "categorical_cross_entropy": + self.segmentation_losses[name] = CategoricalCrossEntropyLoss( num_samples=cfg_dict.get("cross_entropy_loss_num_samples", 50), ignore_index=cfg_dict.get("cross_entropy_loss_ignore_index", -100), - reduction=cfg_dict.get("cross_entropy_loss_reduction", "none"), + reduction=cfg_dict.get("cross_entropy_loss_reduction", "mean"), label_smoothing=cfg_dict.get("cross_entropy_loss_label_smoothing", 0.0), - weight=cross_entropy_loss_classes_weight, + weight=cfg_dict.get("cross_entropy_loss_classes_weight", None), + to_onehot_y=cfg_dict.get("cross_entropy_loss_to_onehot_y", True), + include_background=cfg_dict.get("cross_entropy_loss_include_background", True), + ) + elif name == "binary_cross_entropy": + self.segmentation_losses[name] = BinaryCrossEntropyLoss( + num_samples=cfg_dict.get("cross_entropy_loss_num_samples", 50), + include_background=cfg_dict.get("cross_entropy_loss_include_background", True), + reduction=cfg_dict.get("cross_entropy_loss_reduction", "mean"), + weight=cfg_dict.get("cross_entropy_loss_classes_weight", None), + to_onehot_y=cfg_dict.get("cross_entropy_loss_to_onehot_y", True), + ) + elif name == "focal_loss": + self.segmentation_losses[name] = FocalLoss( + reduction=cfg_dict.get("focal_loss_reduction", "mean"), + weight=cfg_dict.get("focal_loss_classes_weight", None), + alpha=cfg_dict.get("focal_loss_alpha", None), + gamma=cfg_dict.get("focal_loss_gamma", 2.0), + use_softmax=cfg_dict.get("focal_loss_use_softmax", True), + to_onehot_y=cfg_dict.get("focal_loss_to_onehot_y", False), + include_background=cfg_dict.get("focal_loss_include_background", True), ) elif name == "dice": self.segmentation_losses[name] = Dice( - include_background=cfg_dict.get("dice_loss_include_background", False), + include_background=cfg_dict.get("dice_loss_include_background", True), to_onehot_y=cfg_dict.get("dice_loss_to_onehot_y", False), sigmoid=cfg_dict.get("dice_loss_sigmoid", True), softmax=cfg_dict.get("dice_loss_softmax", False), @@ -126,8 +150,22 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): reduction=cfg_dict.get("dice_loss_reduction", "mean"), smooth_nr=cfg_dict.get("dice_loss_smooth_nr", 1e-5), smooth_dr=cfg_dict.get("dice_loss_smooth_dr", 1e-5), - batch=cfg_dict.get("dice_loss_batch", False), + batch=cfg_dict.get("dice_loss_batch", True), + ) + elif name == "generalized_dice": + self.segmentation_losses[name] = GeneralisedDice( + include_background=cfg_dict.get("dice_loss_include_background", True), + to_onehot_y=cfg_dict.get("dice_loss_to_onehot_y", False), + sigmoid=cfg_dict.get("dice_loss_sigmoid", True), + softmax=cfg_dict.get("dice_loss_softmax", False), + other_act=cfg_dict.get("dice_loss_other_act", None), + reduction=cfg_dict.get("dice_loss_reduction", "mean"), + w_type=cfg_dict.get("dice_loss_w_type", "square"), + smooth_nr=cfg_dict.get("dice_loss_smooth_nr", 1e-5), + smooth_dr=cfg_dict.get("dice_loss_smooth_dr", 1e-5), + batch=cfg_dict.get("dice_loss_batch", True), ) + self.segmentation_losses = {f"loss_{i+1}": v for i, v in 
enumerate(self.segmentation_losses.values())} self.total_segmentation_losses = len(self.segmentation_losses) self.total_segmentation_loss_weight = cfg_dict.get("total_segmentation_loss_weight", 1.0) @@ -138,6 +176,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): cross_entropy_metric_reduction = cfg_dict.get("cross_entropy_metric_reduction", "none") cross_entropy_metric_label_smoothing = cfg_dict.get("cross_entropy_metric_label_smoothing", 0.0) cross_entropy_metric_classes_weight = cfg_dict.get("cross_entropy_metric_classes_weight", None) + cross_entropy_metric_to_onehot_y = cfg_dict.get("cross_entropy_loss_to_onehot_y", False) + cross_entropy_metric_include_background = cfg_dict.get("cross_entropy_loss_include_background", True) dice_metric_include_background = cfg_dict.get("dice_metric_include_background", False) dice_metric_to_onehot_y = cfg_dict.get("dice_metric_to_onehot_y", False) dice_metric_sigmoid = cfg_dict.get("dice_metric_sigmoid", True) @@ -154,17 +194,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Initialize the module super().__init__(cfg=cfg, trainer=trainer) - if not is_none(cross_entropy_metric_classes_weight): - cross_entropy_metric_classes_weight = torch.tensor(cross_entropy_metric_classes_weight) - self.cross_entropy_metric = CrossEntropyLoss( + if ( + not is_none(cross_entropy_metric_classes_weight) and cross_entropy_metric_classes_weight != 0.0 + ): # TODO: Cross-entropy is not really a metric used in papers; consider removing it. + self.cross_entropy_metric = CategoricalCrossEntropyLoss( num_samples=cross_entropy_metric_num_samples, ignore_index=cross_entropy_metric_ignore_index, reduction=cross_entropy_metric_reduction, label_smoothing=cross_entropy_metric_label_smoothing, weight=cross_entropy_metric_classes_weight, + to_onehot_y=cross_entropy_metric_to_onehot_y, + include_background=cross_entropy_metric_include_background, ) else: self.cross_entropy_metric = None # type: ignore + self.dice_metric = Dice( include_background=dice_metric_include_background, to_onehot_y=dice_metric_to_onehot_y, @@ -282,7 +326,7 @@ def process_segmentation_loss(self, target: torch.Tensor, prediction: torch.Tens loss = loss_func(target, prediction) if isinstance(loss, tuple): # In case of the dice loss, the loss is a tuple of the form (dice, dice loss) - loss = loss[1] + loss = loss[1][0] losses[name] = loss return self.total_segmentation_loss(**losses) * self.total_segmentation_loss_weight @@ -338,6 +382,20 @@ def __compute_and_log_metrics_and_outputs__( output_target_segmentation = target_segmentation[_batch_idx_] output_predictions = predictions[_batch_idx_] + if self.segmentation_type == 'MCS': + output_predictions = torch.softmax(output_predictions, dim=0).float() + if self.segmentation_output_mode == "binary": + output_predictions = output_predictions.argmax(dim=0, keepdim=True) + output_predictions = one_hot( + output_predictions, num_classes=self.segmentation_module_output_channels, dim=0 + ) + else: + # When using wandb, plots need to be in [0, 1]. + # With "MLS", with or without thresholding, the outputs are logits and can exceed this range. + output_predictions = output_predictions.clamp(0, 1).float() + if self.segmentation_output_mode == "binary": + output_predictions = torch.where(output_predictions > 0.5, 1, 0).float() + if self.unnormalize_log_outputs: # Unnormalize target and predictions with pre normalization values. This is only for logging purposes. # For the loss computation, the self.unnormalize_loss_inputs flag is used.
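For reference, the MCS/MLS logging branch above boils down to the following standalone sketch. This is a minimal illustration, not part of the diff; the single-sample (num_classes, H, W) shape and the fixed 0.5 threshold mirror the logging code above, everything else is assumed:

import torch

def binarize_predictions(preds: torch.Tensor, segmentation_type: str = "MLS") -> torch.Tensor:
    # preds: (num_classes, H, W) logits for one sample, as in the logging loop above.
    if segmentation_type == "MCS":
        # Multiclass: classes are mutually exclusive, so softmax over the class
        # dimension, take the argmax, and re-expand it to one binary mask per class.
        probs = torch.softmax(preds, dim=0)
        labels = probs.argmax(dim=0, keepdim=True)
        return torch.zeros_like(probs).scatter_(0, labels, 1.0)
    # Multilabel: each class is an independent binary mask; clamp to [0, 1]
    # (logged values must lie in this range) and threshold at 0.5.
    return (preds.clamp(0, 1) > 0.5).float()

preds = torch.randn(4, 8, 8)
assert binarize_predictions(preds, "MCS").sum(dim=0).eq(1).all()  # exactly one class per pixel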
@@ -433,7 +491,7 @@ def inference_step( if prediction.dim() == 5: prediction = prediction.reshape(batch_size * slices, *prediction.shape[2:]) - if not is_none(self.segmentation_classes_thresholds): + if not is_none(self.segmentation_classes_thresholds) and self.segmentation_type == 'MLS': for class_idx, thres in enumerate(self.segmentation_classes_thresholds): if self.segmentation_activation == "sigmoid": cond = torch.sigmoid(prediction[:, class_idx]) @@ -599,7 +657,6 @@ def validation_step(self, batch: Dict[float, torch.Tensor], batch_idx: int): target_segmentation = outputs["target"] predictions = outputs["predictions"] - # print memory usage for debugging val_loss = self.process_segmentation_loss(target_segmentation, predictions, attrs) # type: ignore # Compute metrics and log them and log outputs. @@ -685,6 +742,14 @@ def test_step(self, batch: Dict[float, torch.Tensor], batch_idx: int): fname, # type: ignore slice_idx, # type: ignore ) + if self.segmentation_type == 'MCS': + predictions = torch.softmax(predictions, dim=1).float() + if self.segmentation_output_mode == "binary": + predictions = predictions.argmax(dim=1, keepdim=True) + predictions = one_hot(predictions, num_classes=self.segmentation_module_output_channels, dim=1) + else: + if self.segmentation_output_mode == "binary": + predictions = torch.where(predictions > 0.5, 1, 0).float() # Get the file name. fname = attrs['fname'][0] # type: ignore diff --git a/projects/MTL/rs/SKMTEA/conf/train/mtlrs.yaml b/projects/MTL/rs/SKMTEA/conf/train/mtlrs.yaml index 3d3dd6c3..71ad2a85 100644 --- a/projects/MTL/rs/SKMTEA/conf/train/mtlrs.yaml +++ b/projects/MTL/rs/SKMTEA/conf/train/mtlrs.yaml @@ -49,6 +49,8 @@ model: reconstruction_module_keep_prediction: true reconstruction_module_accumulate_predictions: true segmentation_module: AttentionUNet + segmentation_type: MLS + segmentation_output_mode: binary segmentation_module_input_channels: 1 segmentation_module_output_channels: 4 segmentation_module_channels: 64 @@ -64,7 +66,7 @@ model: dice_loss_squared_pred: false dice_loss_jaccard: false dice_loss_flatten: false - dice_loss_reduction: mean_batch + dice_loss_reduction: mean dice_loss_smooth_nr: 1e-5 dice_loss_smooth_dr: 1e-5 dice_loss_batch: true @@ -76,7 +78,7 @@ model: dice_metric_squared_pred: false dice_metric_jaccard: false dice_metric_flatten: false - dice_metric_reduction: mean_batch + dice_metric_reduction: mean dice_metric_smooth_nr: 1e-5 dice_metric_smooth_dr: 1e-5 dice_metric_batch: true @@ -152,6 +154,7 @@ model: coil_dim: 1 use_seed: false segmentations_path: data_parent_dir/skm-tea/v1-release/segmentation_masks/raw-data-track + include_background_label: false segmentation_classes: 4 complex_data: true batch_size: 1 @@ -206,6 +209,7 @@ model: coil_dim: 1 use_seed: true segmentations_path: data_parent_dir/skm-tea/v1-release/segmentation_masks/raw-data-track + include_background_label: false segmentation_classes: 4 complex_data: true batch_size: 1 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 95865e4d..ba9f35a4 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,7 +1,7 @@ defusedxml>=0.7.1 einops>=0.5.0 h5py==3.9.0 -huggingface_hub +huggingface_hub<=0.20.3 hydra-core>1.3,<=1.3.2 nibabel==5.1.0 numba diff --git a/tests/collections/multitask/rs/models/test_mtlrs.py b/tests/collections/multitask/rs/models/test_mtlrs.py index aa88845d..93907424 100644 --- a/tests/collections/multitask/rs/models/test_mtlrs.py +++ b/tests/collections/multitask/rs/models/test_mtlrs.py 
@@ -222,6 +222,82 @@ "max_steps": -1, }, ), + ( + [1, 3, 32, 16, 2], + { + "use_reconstruction_module": True, + "task_adaption_type": "multi_task_learning", + "joint_reconstruction_segmentation_module_cascades": 5, + "reconstruction_module_recurrent_layer": "IndRNN", + "reconstruction_module_conv_filters": [64, 64, 2], + "reconstruction_module_conv_kernels": [5, 3, 3], + "reconstruction_module_conv_dilations": [1, 2, 1], + "reconstruction_module_conv_bias": [True, True, False], + "reconstruction_module_recurrent_filters": [64, 64, 0], + "reconstruction_module_recurrent_kernels": [1, 1, 0], + "reconstruction_module_recurrent_dilations": [1, 1, 0], + "reconstruction_module_recurrent_bias": [True, True, False], + "reconstruction_module_depth": 2, + "reconstruction_module_conv_dim": 2, + "reconstruction_module_time_steps": 8, + "reconstruction_module_num_cascades": 5, + "reconstruction_module_dimensionality": 2, + "reconstruction_module_accumulate_predictions": True, + "reconstruction_module_no_dc": True, + "reconstruction_module_keep_prediction": True, + "reconstruction_loss": {"l1": 1.0}, + "segmentation_module": "UNet", + "segmentation_module_input_channels": 2, + "segmentation_module_output_channels": 4, + "segmentation_module_channels": 64, + "segmentation_module_pooling_layers": 4, + "segmentation_module_dropout": 0.0, + "segmentation_loss": { + "focal_loss": 1.0, + "dice": 1.0, + "categorical_cross_entropy": 1.0, + "generalized_dice": 1.0, + }, + "segmentation_type": "MLS", # either "MCS" (multiclass segmentation) or "MLS" (multilabel segmentation) + "segmentation_output_mode": "probability", # either "binary" or "probability" + "dice_loss_include_background": False, + "dice_loss_to_onehot_y": False, + "dice_loss_sigmoid": True, + "dice_loss_softmax": False, + "dice_loss_other_act": None, + "dice_loss_squared_pred": False, + "dice_loss_jaccard": False, + "dice_loss_reduction": "mean", + "dice_loss_smooth_nr": 1, + "dice_loss_smooth_dr": 1, + "dice_loss_batch": True, + "consecutive_slices": 1, + "coil_combination_method": "SENSE", + "magnitude_input": False, + "use_sens_net": False, + "fft_centered": False, + "fft_normalization": "backward", + "spatial_dims": [-2, -1], + "coil_dim": 1, + "dimensionality": 2, + }, + [0.08], + [4], + 2, + 4, + { + "strategy": "ddp", + "accelerator": "cpu", + "num_nodes": 1, + "max_epochs": 20, + "precision": 32, + "enable_checkpointing": False, + "logger": False, + "log_every_n_steps": 50, + "check_val_every_n_epoch": -1, + "max_steps": -1, + }, + ), ], ) def test_mtlmrirs(shape, cfg, center_fractions, accelerations, dimensionality, segmentation_classes, trainer): diff --git a/tests/collections/segmentation/losses/__init__.py b/tests/collections/segmentation/losses/__init__.py new file mode 100644 index 00000000..3a328280 --- /dev/null +++ b/tests/collections/segmentation/losses/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +__author__ = "Tim Paquaij" diff --git a/tests/collections/segmentation/losses/test_ce.py b/tests/collections/segmentation/losses/test_ce.py new file mode 100644 index 00000000..ebe51a63 --- /dev/null +++ b/tests/collections/segmentation/losses/test_ce.py @@ -0,0 +1,183 @@ +# coding=utf-8 +__author__ = "Tim Paquaij" +import pytest +import torch + +from atommic.collections.segmentation.losses.cross_entropy import BinaryCrossEntropyLoss, CategoricalCrossEntropyLoss +from atommic.collections.segmentation.losses.focal import FocalLoss +from tests.collections.reconstruction.mri_data.conftest import create_input + + +@pytest.mark.parametrize( + "shape, cfg", + [ + ( + [2, 4, 10,
10], + { + "to_onehot_y": True, + "include_background": True, + "reduction": 'none', + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": False, + "reduction": 'none', + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": False, + "reduction": 'mean', + }, + ), + ], +) +def test_CCE_loss(shape, cfg): + """ + Test Categorical Cross-Entropy Loss + + Parameters + ---------- + shape : list of int + Shape of the input data + cfg : dict + Dictionary with the parameters of the loss function + """ + x = create_input(shape).requires_grad_() + y = create_input(shape) + y = torch.softmax(y, dim=1) + y = torch.argmax(y, dim=1, keepdim=True) + ce_loss = CategoricalCrossEntropyLoss( + to_onehot_y=cfg.get('to_onehot_y'), + include_background=cfg.get('include_background'), + reduction=cfg.get('reduction'), + weight=cfg.get('weight'), + ) + + result = ce_loss(y, x) + if cfg.get('include_background') and cfg.get('reduction') == 'none': + assert result.shape == torch.Size([2, 10, 10]) + if not cfg.get('include_background') and cfg.get('reduction') == 'none': + assert result.shape == torch.Size([2, 10, 10]) + if cfg.get('reduction') == 'mean': + assert torch.is_floating_point(result) + + +@pytest.mark.parametrize( + "shape, cfg", + [ + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": True, + "reduction": 'none', + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": False, + "reduction": 'none', + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": False, + "reduction": 'mean', + }, + ), + ], +) +def test_BCE_loss(shape, cfg): + """ + Test Binary Cross-Entropy Loss + + Parameters + ---------- + shape : list of int + Shape of the input data + cfg : dict + Dictionary with the parameters of the loss function + """ + x = create_input(shape).requires_grad_() + y = create_input(shape) + y = torch.softmax(y, dim=1) + y = torch.argmax(y, dim=1, keepdim=True) + ce_loss = BinaryCrossEntropyLoss(reduction=cfg.get('reduction'), to_onehot_y=cfg.get('to_onehot_y')) + result = ce_loss(y.float(), x) + if cfg.get('include_background') and cfg.get('reduction') == 'none': + assert result.shape == torch.Size([2, 4, 10, 10]) + if not cfg.get('include_background') and cfg.get('reduction') == 'none': + assert result.shape == torch.Size([2, 4, 10, 10]) + if cfg.get('reduction') == 'mean': + assert torch.is_floating_point(result) + + +@pytest.mark.parametrize( + "shape, cfg", + [ + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": True, + "reduction": 'none', + "softmax": False, + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": False, + "reduction": 'none', + "softmax": False, + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": False, + "reduction": 'mean', + "softmax": True, + }, + ), + ], +) +def test_focal_loss(shape, cfg): + """ + Test focal loss + + Parameters + ---------- + shape : list of int + Shape of the input data + cfg : dict + Dictionary with the parameters of the loss function + """ + x = create_input(shape).requires_grad_() + y = create_input(shape) + y = torch.softmax(y, dim=1) + y = torch.argmax(y, dim=1, keepdim=True) + focal_loss = FocalLoss( + to_onehot_y=cfg.get('to_onehot_y'), + include_background=cfg.get('include_background'), + reduction=cfg.get('reduction'), + ) + result = focal_loss(y, x) + if cfg.get('include_background') and cfg.get('reduction') == 'none': + assert result.shape ==
torch.Size([2, 4, 10, 10]) + if not cfg.get('include_background') and cfg.get('reduction') == 'none': + assert result.shape == torch.Size([2, 3, 10, 10]) + if cfg.get('reduction') == 'mean': + assert torch.is_floating_point(result) diff --git a/tests/collections/segmentation/losses/test_dice.py b/tests/collections/segmentation/losses/test_dice.py new file mode 100644 index 00000000..326bea28 --- /dev/null +++ b/tests/collections/segmentation/losses/test_dice.py @@ -0,0 +1,172 @@ +# coding=utf-8 +__author__ = "Tim Paquaij" +import pytest +import torch +from atommic.collections.segmentation.losses.dice import Dice, GeneralisedDice +from tests.collections.reconstruction.mri_data.conftest import create_input + + +@pytest.mark.parametrize( + "shape, cfg", + [ + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": True, + "reduction": 'none', + "softmax": True, + "sigmoid": False, + "batch": True, + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": True, + "reduction": 'mean_channel', + "softmax": True, + "sigmoid": False, + "batch": True, + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": False, + "reduction": 'none', + "softmax": True, + "sigmoid": False, + "batch": True, + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": True, + "reduction": 'none', + "softmax": True, + "sigmoid": False, + "batch": False, + }, + ), + ], +) +def test_dice_loss(shape, cfg): + """ + Test Dice Loss + + Parameters + ---------- + shape : list of int + Shape of the input data + cfg : dict + Dictionary with the parameters of the loss function + """ + x = create_input(shape) + y = create_input(shape) + y = torch.softmax(y, dim=1) + y = torch.argmax(y, dim=1, keepdim=True) + dice_loss = Dice( + include_background=cfg.get('include_background'), + to_onehot_y=cfg.get('to_onehot_y'), + batch=cfg.get('batch'), + reduction=cfg.get('reduction'), + softmax=cfg.get('softmax'), + sigmoid=cfg.get('sigmoid'), + ) + _, result = dice_loss(y, x) + if cfg.get('include_background') and cfg.get('batch') and cfg.get('reduction') == 'none': + assert result.shape[0] == 2 and result.shape[1] == x.shape[1] and result.dim() == 2 + elif cfg.get('include_background') and cfg.get('reduction') == 'mean_channel' and cfg.get('batch'): + assert result.shape[0] == 2 and result.dim() == 1 + elif not cfg.get('include_background') and cfg.get('reduction') == 'none' and cfg.get('batch'): + assert result.shape[0] == 2 and result.shape[1] == x.shape[1] - 1 and result.dim() == 2 + elif not cfg.get('batch') and cfg.get('reduction') == 'none': + assert result.shape[0] == 4 and result.dim() == 1 + + +@pytest.mark.parametrize( + "shape, cfg", + [ + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": True, + "reduction": 'none', + "softmax": True, + "sigmoid": False, + "batch": True, + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": True, + "reduction": 'mean_channel', + "softmax": True, + "sigmoid": False, + "batch": True, + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": False, + "reduction": 'none', + "softmax": True, + "sigmoid": False, + "batch": True, + }, + ), + ( + [2, 4, 10, 10], + { + "to_onehot_y": True, + "include_background": True, + "reduction": 'none', + "softmax": True, + "sigmoid": False, + "batch": False, + }, + ), + ], +) +def test_gendice_loss(shape, cfg): + """ + Test Generalised Dice Loss + + Parameters + ---------- + shape : list of int + Shape
of the input data + cfg : dict + Dictionary with the parameters of the loss function + """ + x = create_input(shape) + y = create_input(shape) + y = torch.softmax(y, dim=1) + y = torch.argmax(y, dim=1, keepdim=True) + gendice_loss = GeneralisedDice( + include_background=cfg.get('include_background'), + to_onehot_y=cfg.get('to_onehot_y'), + batch=cfg.get('batch'), + reduction=cfg.get('reduction'), + softmax=cfg.get('softmax'), + sigmoid=cfg.get('sigmoid'), + ) + _, result = gendice_loss(y, x) + if cfg.get('include_background') and cfg.get('batch') and cfg.get('reduction') == 'none': + assert result.shape[0] == 2 and result.shape[1] == x.shape[1] and result.dim() == 2 + elif cfg.get('include_background') and cfg.get('reduction') == 'mean_channel' and cfg.get('batch'): + assert result.shape[0] == 2 and result.dim() == 1 + elif not cfg.get('include_background') and cfg.get('reduction') == 'none' and cfg.get('batch'): + assert result.shape[0] == 2 and result.shape[1] == x.shape[1] - 1 and result.dim() == 2 + elif not cfg.get('batch') and cfg.get('reduction') == 'none': + assert result.shape[0] == 4 and result.dim() == 1
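As a final cross-check of the new focal loss code: the comment block in sigmoid_focal_loss relies on two algebraic identities, which can be verified numerically with the short standalone sketch below (not part of the diff; shapes and tolerances are arbitrary choices for the example):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, 8)                    # logits
t = torch.randint(0, 2, (2, 4, 8, 8)).float()  # binary targets

# Identity 1: x - x*t - logsigmoid(x) is the numerically stable form of
# binary cross-entropy with logits.
bce = x - x * t - F.logsigmoid(x)
assert torch.allclose(bce, F.binary_cross_entropy_with_logits(x, t, reduction="none"), atol=1e-5)

# Identity 2: logsigmoid(-x * (2t - 1)) == log(1 - p_t), so
# exp(gamma * logsigmoid(-x * (2t - 1))) == (1 - p_t) ** gamma.
p = torch.sigmoid(x)
pt = p * t + (1 - p) * (1 - t)
assert torch.allclose(F.logsigmoid(-x * (t * 2 - 1)).exp(), 1 - pt, atol=1e-5)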