TillmannRheude/SegHiResGrad_CAM


Seg-HiRes-Grad CAM

Convolutional neural networks (CNNs) currently achieve leading results in segmentation tasks and represent the state of the art in image-based analysis. However, the exact decision-making process of a CNN remains poorly understood. The research area of explainable artificial intelligence (xAI) primarily revolves around understanding and interpreting this black-box behavior. One way of interpreting a CNN is to use class activation maps (CAMs), heatmaps that indicate how important each image area is for the CNN's prediction. For classification tasks, a variety of CAM algorithms exist, but for segmentation tasks, only one CAM algorithm for interpreting the output of a CNN exists. We propose a transfer between existing classification- and segmentation-based methods to obtain more detailed, explainable, and consistent results that show salient pixels in semantic segmentation tasks. The resulting Seg-HiRes-Grad CAM extends the segmentation-based Seg-Grad CAM by transferring it to the classification-based HiRes CAM. Our method thereby improves on Seg-Grad CAM by adapting it to recently published classification-based methods. Especially for medical image segmentation, this transfer resolves existing explainability disadvantages.

Demo

Requirements

Python 3.9.12, PyTorch 1.11.0, NumPy 1.21.5, Pillow (PIL), torchvision 0.12.0

Data and Model Preparation

Please structure the data and model folders like those in this repository, for example:

├──data
    ├──dataset_1 (e.g. cityscapes)
        ├──test
            ├──images
                ├──files.png
            ├──masks
                ├──files.png
    ├──dataset_2
        ├──...
├──results
    ├──modelname (e.g. vanilla_unert)
        ├──models
            ├──dataset_1 (e.g. cityscapes)
                ├──modelfile
    ├──modelname
        ├──...
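
As a quick sanity check of the layout above, the following minimal sketch collects image/mask pairs from `data/<dataset>/test`. The helper name `list_test_pairs` is hypothetical and not part of this repository, and the sketch assumes that each mask has the same file name as its image, which the layout above does not state explicitly.

```python
from pathlib import Path


def list_test_pairs(data_root, dataset):
    """Return sorted (image_path, mask_path) pairs from <data_root>/<dataset>/test.

    Assumption: a mask shares the file name of its image.
    """
    test_dir = Path(data_root) / dataset / "test"
    pairs = []
    for image_path in sorted((test_dir / "images").glob("*.png")):
        mask_path = test_dir / "masks" / image_path.name
        if not mask_path.is_file():
            raise FileNotFoundError(f"No mask found for {image_path.name}")
        pairs.append((image_path, mask_path))
    return pairs
```

Calling `list_test_pairs("data", "dataset_1")` would raise early if a mask is missing, instead of failing later during evaluation.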

If you use a convolutional neural network other than the U-Net included in this repository, adjust the model so that it works out of the box with main.py, e.g. following this minimal example:

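import torch.nn as nn


class Vanilla_UNet_2d(nn.Module):
    def __init__(self, grad_cam=False):
        super().__init__()

        # Set to True when the model is run in .eval() mode to compute CAMs.
        self.grad_cam = grad_cam
        # Filled by activations_hook during the backward pass.
        self.gradients = None

        # Define self.encoder and the decoder here. get_act() below expects
        # each encoder block to return (x_pre_pool, x_post_pool).

    def activations_hook(self, grad):
        # For Grad-CAM: store the gradient of the hooked tensor.
        self.gradients = grad

    def forward(self, x):
        # do something

        # encoder

        # This is the layer to be explained (for Grad-CAM): register a hook
        # on its output so that the gradient is captured during backward().
        if self.grad_cam:
            self.h = x.register_hook(self.activations_hook)

        # decoder

        return x

    def get_act_grad(self):
        # For Grad-CAM: gradients captured by activations_hook.
        return self.gradients

    def get_act(self, x):
        # For Grad-CAM: recompute the activations of the explained layer.
        for block in self.encoder:
            # Each block returns the encoded tensor before and after max pooling.
            x_pre_pool, x_post_pool = block(x)
            # The pooled tensor is the input of the next block.
            x = x_post_pool
        # Return the last block's output before max pooling.
        x = x_pre_pool

        return x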
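
For illustration, here is a small runnable sketch of a model that follows the interface above. `TinyEncoderBlock` and `TinySegNet` are hypothetical toy classes, not the U-Net from this repository, and the single upsampling layer only stands in for a real decoder. The sketch shows how the hook captures the gradient of a summed class score with respect to the explained layer.

```python
import torch
import torch.nn as nn


class TinyEncoderBlock(nn.Module):
    """Hypothetical block returning features before and after max pooling."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU())
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x_pre_pool = self.conv(x)
        return x_pre_pool, self.pool(x_pre_pool)


class TinySegNet(nn.Module):
    """Toy stand-in for a U-Net that exposes the Grad-CAM interface."""

    def __init__(self, in_ch=1, n_classes=2, grad_cam=False):
        super().__init__()
        self.grad_cam = grad_cam
        self.gradients = None
        self.encoder = nn.ModuleList(
            [TinyEncoderBlock(in_ch, 4), TinyEncoderBlock(4, 8)]
        )
        self.up = nn.Upsample(scale_factor=2, mode="nearest")  # placeholder decoder
        self.head = nn.Conv2d(8, n_classes, kernel_size=1)

    def activations_hook(self, grad):
        self.gradients = grad

    def get_act(self, x):
        # Activations of the explained layer: last block, before pooling.
        for block in self.encoder:
            x_pre_pool, x = block(x)
        return x_pre_pool

    def get_act_grad(self):
        return self.gradients

    def forward(self, x):
        x = self.get_act(x)
        if self.grad_cam:
            # Capture d(score)/d(activations) during backward().
            x.register_hook(self.activations_hook)
        return self.head(self.up(x))


torch.manual_seed(0)
model = TinySegNet(grad_cam=True).eval()
image = torch.randn(1, 1, 16, 16)

logits = model(image)          # shape (1, n_classes, 16, 16)
score = logits[:, 1].sum()     # class-1 logits summed over all pixels
score.backward()

activations = model.get_act(image).detach()
gradients = model.get_act_grad()
```

Here the score sums the class logits over the whole image; restricting the sum to a region of interest only requires masking `logits[:, 1]` before calling `.sum()`.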

To run Seg(-HiRes)-Grad CAM, use main.py, which provides the relevant arguments. Alternatively, (Seg-HiRes-)Grad CAM can be implemented with PyTorch's hooks.
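
Once activations and gradients of the explained layer are available, the two heatmaps differ only in how they combine them. The following NumPy sketch reflects our reading of the Seg-Grad CAM and HiRes CAM formulations and is not taken from main.py. The function names are hypothetical, the final ReLU is an assumption, and `gradients` is assumed to hold the derivative of the summed class logits (over the pixels of interest) with respect to each feature map.

```python
import numpy as np


def seg_grad_cam(activations, gradients):
    """Weight each feature map by its spatially averaged gradient, then sum.

    Both inputs have shape (channels, height, width).
    """
    weights = gradients.mean(axis=(1, 2), keepdims=True)  # one scalar per channel
    return np.maximum((weights * activations).sum(axis=0), 0.0)


def seg_hires_grad_cam(activations, gradients):
    """Multiply gradients and activations element-wise, then sum over channels."""
    return np.maximum((gradients * activations).sum(axis=0), 0.0)


# Tiny hand-made example with 2 channels of size 2x2.
activations = np.array([[[1.0, 2.0], [3.0, 4.0]],
                        [[1.0, 1.0], [1.0, 1.0]]])
gradients = np.array([[[1.0, -1.0], [1.0, -1.0]],
                      [[0.0, 0.0], [0.0, 2.0]]])

cam_avg = seg_grad_cam(activations, gradients)          # 0.5 at every pixel
cam_hires = seg_hires_grad_cam(activations, gradients)  # [[1, 0], [3, 0]]
```

In this toy case the averaged weights flatten the map to a constant, while the element-wise product keeps the per-pixel sign and magnitude of the gradient.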

References

Kira Vinogradova, Alexandr Dibrov, et al. “Towards Interpretable Semantic Segmentation via Gradient-Weighted Class Activation Mapping (Student Abstract)”. In: Proceedings of the AAAI Conference on Artificial Intelligence 34.10 (Apr. 2020), pp. 13943–13944. DOI: 10.1609/aaai.v34i10.7244.

Marius Cordts, Mohamed Omran, et al. “The Cityscapes Dataset for Semantic Urban Scene Understanding”. In: 2016 IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2016, Las Vegas, NV, USA, June 27-30, 2016. IEEE Computer Society, 2016, pp. 3213–3223. DOI: 10.1109/CVPR.2016.350.

N. Heller et al., “The state of the art in kidney and kidney tumor segmentation in contrast-enhanced CT imaging: Results of the KiTS19 Challenge,” Medical Image Analysis, p. 101821, 2020.

Citation

Please cite the work with the following information (bibtex format):

@article{melba:2024:023:rheude,
    title = "Leveraging CAM Algorithms for Explaining Medical Semantic Segmentation",
    author = "Rheude, Tillmann and Wirtz, Andreas and Kuijper, Arjan and Wesarg, Stefan",
    journal = "Machine Learning for Biomedical Imaging",
    volume = "2",
    issue = "iMIMIC 2023 special issue",
    year = "2024",
    pages = "2089--2102",
    issn = "2766-905X",
    doi = "https://doi.org/10.59275/j.melba.2024-ebd3",
    url = "https://melba-journal.org/2024:023"
}
