Application of Convolutional Vision Transformers for the Classification of Infectious Diseases in Chest Radiological Images
Background:
- Global Threats: COVID-19, TB, pneumonia.
- Importance of X-rays: Essential non-invasive tools; however, manual analysis is inconsistent and often inaccurate.
Deep Learning Evolution:
- Last decade: Rise of CNNs for chest X-rays. Key studies: Rajpurkar et al., Wang et al.
- Emergence of Transformers in imaging, comparable to CNNs (Dosovitskiy et al.)
Personal Drive:
- Initial goal: Thesis on deep learning in computer-aided diagnosis.
- Discovery of the chest X-ray dataset during research.
- Influenced by "Do Preprocessing and Class Imbalance Matter to the Deep Image Classifiers for COVID-19 Detection?".
- Decision to tackle this significant issue for my master's thesis solidified.
Project Objective:
- Venture beyond CNNs; explore transformers for chest X-rays.
- Use pre-trained models such as VGG19 and ResNet50 alongside a Vision Transformer for disease classification: COVID-19, TB, lung opacity, etc.
- Goal: High accuracy and comprehensive understanding of the model's diagnostic skills.
- Collection: Sourced a comprehensive chest X-ray dataset from open-source platforms.
- Image Preprocessing:
- Zooming: Emphasis on regions of clinical interest.
- CLAHE: Adaptive contrast enhancement preserving essential details.
- Sharpening: Enhancing edges and fine details.
- Scaling: Uniformity in image size and resolution.
- Zero-Centering: Neutralize mean pixel value for faster model convergence.
- Shuffling & Batching: Stratified sampling for balanced class exposure during training.
- Global Class-Based Weights: Counter dataset imbalance and prevent model bias.
- Architectures:
- Established architectures: VGG19, ResNet50.
- Introduce Custom CNN and Vision Transformer.
- Evaluation Metrics: Analyze using accuracy, precision, recall, F1-score, and ROC-AUC.
- Interpretability:
- Visualization methods: convolutional activation visualization & attention maps.
- Goal: Understand model decision-making and identify regions of importance.
- Dataset link: https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/WNQ3GI
- A collection of chest radiological images, primarily X-rays, publicly released on 25 September 2021.
- Creators: Researchers from Qatar University, Doha, and Dhaka University, with contributions from Pakistan and Malaysia.
- Expertise: Continuous consultations with medical experts to ensure precision.
- Sources: Mainly from the COVID-19 Radiography Database on Kaggle, augmented with additional Pneumonia and COVID-19 images from other platforms.
- Categories: Originally four - COVID-19, Lung Opacity, Normal, and Viral Pneumonia. Tuberculosis was added to enhance scope and relevance.
- Image Type: Mainly frontal view X-rays. Lateral views excluded for consistency.
Disease Name | Samples | Label |
---|---|---|
COVID-19 | 4,189 | 0 |
Lung Opacity | 6,012 | 1 |
Normal | 10,192 | 2 |
Viral Pneumonia | 7,397 | 3 |
Tuberculosis | 4,897 | 4 |
Total | 32,687 | |
- The dataset serves as a cornerstone for evaluating the potential of deep learning models (pre-trained ConvNets, Vision Transformers, or custom convolutional neural network architectures) in chest X-ray classification.
- It is a fairly recent, comprehensive dataset for disease diagnosis and classification from chest radiology images (chest X-rays).
- The authors of this dataset split it into training, testing, and validation sets.
- Black-box Nature of Deep Learning Models: Deep models like VGG19 and ResNet50 often lack transparency in decision-making, impacting medical practitioners' trust.
- Overfitting of Models: Indications from loss-accuracy curves suggest prevalent overfitting in some studies, hindering real-world reliability.
- Inadequate Performance Metrics: Some studies might have used unsuitable metrics given the class imbalances in chest X-ray datasets.
- Binary Classification Emphasis: Many studies, including the dataset's authors, narrowly focused on binary classification, possibly overlooking key diagnostic nuances.
- Local Feature Concentration: Traditional CNNs might not capture the broader context in medical images effectively.
- Limited Disease Diversity: Some studies focused primarily on COVID-19, neglecting other respiratory diseases with similar X-ray features.
- Class Imbalance: The datasets used in most studies were imbalanced, yet no steps were taken to address this during training; such skewed representation can induce biased model predictions.
- Lack of Interpretability Tools: There is a marked absence of tools to illustrate model decisions, essential for clinician trust.
- Holistic Model Comparison: An exhaustive evaluation of models ranging from VGG19 and ResNet50 to Vision Transformers and custom CNNs will be undertaken.
- Mitigating Overfitting: This work will use techniques like regularization and data augmentation to combat overfitting.
- Comprehensive Performance Metrics: Given the dataset imbalances, appropriate metrics will be employed for a genuine model assessment: precision, recall, F1-score, and the area under the receiver operating characteristic curve (ROC-AUC); see the sketch after this list.
- Embracing Multi-class Classification: Shifting focus from binary to multi-class classification to capture the intricacies of respiratory diseases.
- Incorporating Global Context with Transformers: Vision Transformers could address the limitation of traditional CNNs by recognizing broader image contexts.
- Enhanced Interpretability with Custom CNN: The custom CNN will emphasize diagnostic regions, balancing performance with transparency through its use of spatial convolutional attention.
- Efficiency and Performance Equilibrium: The custom CNN, with features like depthwise separable convolutions, aims for a balance between efficiency and accuracy.
- Addressing Black-box Models via Model Interpretability: This research will illustrate the decision-making of deep classifiers using gradient or activation visualization tools such as Grad-CAM and Guided Grad-CAM, as well as attention map visualization.
- Effective Handling of Class Imbalance: Class imbalance will be handled during training by using global class-based weighting.
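As a concrete reference for these metrics, here is a minimal scikit-learn sketch; the random `y_true`/`y_prob` arrays are placeholders for real labels and softmax outputs:

```python
import numpy as np
from sklearn.metrics import classification_report, roc_auc_score

# Dummy predictions stand in for real model outputs so the snippet runs on its own:
# y_true holds integer labels (0-4), y_prob holds per-class softmax scores.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 5, size=1000)
y_prob = rng.dirichlet(np.ones(5), size=1000)
y_pred = y_prob.argmax(axis=1)

# Per-class precision, recall, and F1, plus macro and weighted averages.
print(classification_report(y_true, y_pred, digits=2))

# One-vs-rest ROC-AUC, macro-averaged across the five classes.
print("ROC-AUC:", roc_auc_score(y_true, y_prob, multi_class="ovr", average="macro"))
```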
1. How does the performance of Convolutional Vision Transformers compare with traditional Convolutional Neural Networks and other deep learning models, such as VGG19 and ResNet50, in the classification of chest radiological images of infectious respiratory diseases?
2. What insights can be derived from activation visualization or attention map visualization about the decision-making processes of the pre-trained VGG19 and ResNet50 models, the custom CNN, and the Vision Transformer in predicting respiratory diseases from chest radiological images?
- The Vision Transformer model and the custom CNN model were trained from scratch.
- For the pre-trained models, a two-step training protocol is adopted (a minimal Keras sketch follows this list):
- Transfer Learning:
- This involves harnessing the learned features from a previously trained model to accelerate and enhance the training for a new, related task.
- Initially, only the top layers (often the fully connected layers) are retrained to adapt to the new task while the base layers are kept frozen.
- At this stage, the pre-trained model effectively acts as a feature extractor, with its learned patterns benefiting the new task.
- Fine-Tuning:
- This phase dives deeper, adjusting not just the top layers but also some or all of the deeper layers of the pre-trained model.
- By doing so, the model is further specialized, tailoring its learned features more closely to the nuances of the specific task at hand.
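A minimal Keras sketch of this two-step protocol, using VGG19 as the base; the head layout mirrors the VGG19 summary later in this document, while the learning rates, dropout rate, and epoch counts are illustrative assumptions:

```python
import tensorflow as tf

NUM_CLASSES = 5

# Step 1 -- transfer learning: freeze the ImageNet-trained base; only the new head learns.
base = tf.keras.applications.VGG19(weights="imagenet", include_top=False,
                                   input_shape=(224, 224, 3))
base.trainable = False  # the base acts purely as a feature extractor

inputs = tf.keras.Input(shape=(224, 224, 3))
x = base(inputs, training=False)                       # keep the base in inference mode
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(4096, activation="relu")(x)  # head mirrors the VGG19 summary below
x = tf.keras.layers.Dropout(0.5)(x)
x = tf.keras.layers.Dense(4096, activation="relu")(x)
outputs = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss="categorical_crossentropy", metrics=["accuracy"])
# model.fit(train_ds, validation_data=val_ds, class_weight=class_weights, epochs=10)

# Step 2 -- fine-tuning: unfreeze the base and retrain end-to-end at a much lower rate,
# so the pre-trained features are adjusted gently rather than overwritten.
base.trainable = True
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss="categorical_crossentropy", metrics=["accuracy"])
# model.fit(train_ds, validation_data=val_ds, class_weight=class_weights, epochs=10)
```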
- Class Weights (see the computation sketch below):

Label | Weight |
---|---|
0 | 1.5613057324840764 |
1 | 1.0873301912947047 |
2 | 0.6413736713000817 |
3 | 0.8837314105452907 |
4 | 1.334921715452689 |
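These weights appear consistent with the common "balanced" heuristic, weight_c = n_samples / (n_classes × n_samples_c), applied to the training split. A self-contained sketch; the random `y_train` stands in for the real training labels, whose exact per-class counts are not listed here:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# y_train holds the integer labels (0-4) of the training split; a random
# stand-in is generated here so the snippet runs on its own.
rng = np.random.default_rng(123)
y_train = rng.integers(0, 5, size=20_000)

# "Balanced" heuristic: weight_c = n_samples / (n_classes * n_samples_c),
# so under-represented classes (e.g. COVID-19) are weighted above 1.
weights = compute_class_weight(class_weight="balanced",
                               classes=np.arange(5), y=y_train)
class_weights = dict(enumerate(weights))

# Passed to Keras at training time:
# model.fit(train_ds, class_weight=class_weights, ...)
```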
- Variation Across Classes:
- Different diseases exhibit unique pixel intensity distributions, reflecting their distinct radiological manifestations.
- These variations are indicative of the specific pathophysiological changes associated with each condition.
- Significance of Outliers:
- Outliers in intensity distributions can stem from imaging artifacts, data inconsistencies, or rare case presentations.
- Their presence demands careful preprocessing, considering their potential to either introduce noise or represent crucial data points.
- Potential for Model Discrimination:
- The disparities in intensity histograms among most classes bolster the claim that deep learning models can effectively differentiate between these classes.
- Challenges with 'Lung Opacity' and 'Normal' Classes:
- These two classes display nearly identical intensity distributions, highlighting possible challenges in model differentiation.
- Such overlap underscores the importance of advanced feature extraction or specific modeling techniques.
- Class-specific Observations (a histogram-plotting sketch follows this list):
- COVID-19: Consistent imaging or disease presentation is hinted at by its near-normal distribution.
- Lung Opacity: Multiple radiological patterns might be present, as suggested by the secondary peak post-mean.
- Normal: Despite its resemblance to 'Lung Opacity', a notable secondary peak implies potential overlap with pathological conditions.
- Viral Pneumonia: Its histogram suggests a mix of typical and atypical disease presentations in the dataset.
- Tuberculosis: Two primary radiological patterns seem to dominate, as indicated by the twin peaks centered around the mean.
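A minimal matplotlib sketch of how such per-class intensity histograms can be produced; `images_by_class` (a mapping from class name to a list of grayscale image arrays) is a hypothetical structure assumed for illustration:

```python
import numpy as np
import matplotlib.pyplot as plt

CLASS_NAMES = ["COVID-19", "Lung Opacity", "Normal", "Viral Pneumonia", "Tuberculosis"]

def plot_intensity_histograms(images_by_class):
    """images_by_class: dict mapping class name -> list of grayscale (uint8) images."""
    fig, axes = plt.subplots(1, len(CLASS_NAMES), figsize=(20, 3), sharey=True)
    for ax, name in zip(axes, CLASS_NAMES):
        # Pool every pixel of every image in the class into one distribution.
        pixels = np.concatenate([img.ravel() for img in images_by_class[name]])
        ax.hist(pixels, bins=64, range=(0, 255), density=True)
        ax.axvline(pixels.mean(), color="red", linestyle="--")  # mark the mean intensity
        ax.set_title(name)
        ax.set_xlabel("Pixel intensity")
    axes[0].set_ylabel("Density")
    plt.tight_layout()
    plt.show()
```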
- Zoom
- CLAHE
- Image Sharpening
- Rescaling
- Zero-Centering
Basic Sharpening Kernel:
0 | -1 | 0 |
-1 | 5 | -1 |
0 | -1 | 0 |
Laplacian Sharpening Kernel:
0 | 1 | 0 |
1 | -4 | 1 |
0 | 1 | 0 |
Diagonal Edge Sharpening:
-1 | 0 | -1 |
0 | 8 | 0 |
-1 | 0 | -1 |
Exaggerated Sharpening:
-1 | -1 | -1 |
-1 | 9 | -1 |
-1 | -1 | -1 |
Another variant of Laplacian:
1 | 1 | 1 |
1 | -7 | 1 |
1 | 1 | 1 |
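A short OpenCV sketch showing how any of these kernels is applied; `cv2.filter2D` performs the 2D convolution, and the file name is a placeholder:

```python
import cv2
import numpy as np

# The basic sharpening kernel from the table above; swap in any of the variants.
basic_sharpen = np.array([[ 0, -1,  0],
                          [-1,  5, -1],
                          [ 0, -1,  0]], dtype=np.float32)

def sharpen(image, kernel=basic_sharpen):
    """Convolve a grayscale X-ray with a sharpening kernel.
    ddepth=-1 keeps the output dtype equal to the input's."""
    return cv2.filter2D(image, ddepth=-1, kernel=kernel)

# Usage (placeholder path):
# sharpened = sharpen(cv2.imread("xray.png", cv2.IMREAD_GRAYSCALE))
```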
Combining multiple techniques offers a synergistic effect for X-ray image optimization. A stepwise methodology was employed to highlight the significance of each stage:
- Zoom (Factor of 0.83):
- Purpose: Crop out potential artifacts and noise along the image borders.
- Benefits: Standardizes the primary region of interest due to diverse radiological image presentations from various sources.
- CLAHE (Clip Limit of 2 and Grid Size of 15):
- Purpose: Enhance image contrast.
- Benefits: Especially effective for X-rays with dark backgrounds, it provides optimized contrast while preserving genuine image features.
- Basic Sharpening Kernel:
- Purpose: Accentuate the edges and details within the image.
- Benefits: Sharpens the image without adding unnecessary noise or artifacts, making the structures within the chest clearly visible.
- Rescaling:
- Purpose: Normalize pixel values to a specific range, commonly between 0 and 1.
- Benefits: Ensures consistent input scale for deep learning models, which can aid in faster and more stable convergence during training.
- Zero-Centering:
- Purpose: Shift pixel values so that the mean is zero.
- Benefits: Reduces the dominance of particular pixel intensity ranges, allowing models to learn features more effectively.
Key Observations from Visualization:
- Zooming: Effectively eliminates extraneous details, offering a clearer view of the primary chest region.
- CLAHE: Achieves uniform and enhanced contrast, which makes minute details more discernible.
- Sharpening: Further emphasizes the edges and contours, clearly outlining structures within the chest.
Conclusion: The comprehensive approach of Zoom, CLAHE, Image Sharpening, Rescaling, and Zero-Centering offers a robust and efficient pre-processing pipeline for chest X-rays. Each step complements the others, ensuring the transformed image is medically relevant, standardized, and optimized for deep learning applications.
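A consolidated OpenCV sketch of this pipeline under the stated parameters (zoom factor 0.83, CLAHE clip limit 2 with a 15×15 grid, the basic sharpening kernel); treating the zoom factor as a central crop and zero-centering with the per-image mean are assumptions:

```python
import cv2
import numpy as np

def preprocess_xray(path, zoom=0.83, size=224):
    """Zoom -> CLAHE -> sharpen -> rescale -> zero-center, as described above."""
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)

    # Zoom: central crop to `zoom` of each dimension, trimming border artifacts.
    h, w = img.shape
    ch, cw = int(h * zoom), int(w * zoom)
    top, left = (h - ch) // 2, (w - cw) // 2
    img = img[top:top + ch, left:left + cw]

    # CLAHE: adaptive contrast with clip limit 2 and a 15x15 tile grid.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(15, 15))
    img = clahe.apply(img)

    # Sharpening with the basic 3x3 kernel.
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=np.float32)
    img = cv2.filter2D(img, -1, kernel)

    # Rescale to [0, 1], resize to the model input size, then zero-center
    # (per-image mean here; a dataset-wide mean works the same way).
    img = cv2.resize(img, (size, size)).astype(np.float32) / 255.0
    img -= img.mean()
    return img
```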
- Shuffled, Batched, and One-Hot Encoded.
- Type: tf.data.Dataset (MapDataset, PrefetchDataset)
- Efficiently handles large datasets
- Seamless data feeding for GPU/TPU
- On-the-fly transformations with map()
- Overlaps preprocessing & execution with prefetch() (see the pipeline sketch after this list)
- Efficient batching and shuffling
- Memory-efficient & integrated with TensorFlow ecosystem
- Seed: 123
- Batch Size: 128
- Image Size: 224
- Classes: 5
- Set random seeds for both numpy and TensorFlow
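A minimal sketch of this input pipeline; the PNG decoding, shuffle buffer size, and plain (non-stratified) shuffling are simplifying assumptions:

```python
import numpy as np
import tensorflow as tf

SEED, BATCH_SIZE, IMAGE_SIZE, NUM_CLASSES = 123, 128, 224, 5
np.random.seed(SEED)          # seed numpy ...
tf.random.set_seed(SEED)      # ... and TensorFlow for reproducibility

def build_dataset(image_paths, labels, training=True):
    ds = tf.data.Dataset.from_tensor_slices((image_paths, labels))

    def load(path, label):
        img = tf.io.decode_png(tf.io.read_file(path), channels=3)
        img = tf.image.resize(img, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
        return img, tf.one_hot(label, NUM_CLASSES)   # one-hot encode the label

    ds = ds.map(load, num_parallel_calls=tf.data.AUTOTUNE)  # on-the-fly transforms
    if training:
        ds = ds.shuffle(buffer_size=1024, seed=SEED)
    # Batch, then overlap preprocessing with model execution.
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
```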
- Histograms for Training, Testing, and Validation sets.
- Maintained class ratio from original to final pre-processed batches.
- Preliminary simple ConvNet model to ensure the final datasets (train, validation, and testing) were transformed and compiled correctly.
- VGG19
- ResNet50
- Vision Transformer
- Custom ConvNet architecture based on skip connections, depthwise separable conv layers, and spatial convolutional attention, dubbed CustomCNN
- Comprehensive Performance Analysis: Using a mix of traditional architectures like VGG19 and ResNet50, alongside modern models such as Vision Transformers and a tailored architecture, provides a holistic evaluation of their efficacy in chest X-ray classification. This diversity allows for a rigorous cross-architecture comparison.
- Building Upon Previous Work: Traditional CNNs like VGG19 and ResNet50 have established successes in image classification. Assessing their fit for chest X-rays helps understand their specific advantages and challenges in this domain.
- Exploring New Paradigms: Vision Transformers, diverging from the traditional CNN approach, raise a question: can segmenting images into patches and viewing them as sequences enhance medically relevant feature extraction?
- Customized Approach: CustomCNN, emphasizing skip connections, depthwise separable conv layers, and spatial convolutional attention, is crafted to better discern the nuances in chest X-rays. It tackles specific challenges like overlapping structures and subtle anomalies, which may be missed by general architectures.
- Addressing Interpretability: Deep learning models can often be "black boxes", which is especially concerning in medical contexts. The custom architecture aims to bolster model transparency, spotlighting diagnostically significant regions.
- Flexibility and Adaptability: Evaluating a variety of models offers flexibility. Should one model underperform or face unexpected challenges, alternatives are at the ready to ensure project goals are met.
In summary, the chosen models represent a balance of reliable methods and innovative paradigms, all geared towards optimal chest X-ray classification.
- Trained a basic ConvNet on the pre-processed chest X-ray dataset.
- Objective: Verify correct image pre-processing & achieve good performance metrics.
- This step preceded transfer learning and fine-tuning with VGG19 and ResNet50.
Layer Type | Output Shape | Param # |
---|---|---|
InputLayer | (None, 224, 224, 3) | 0 |
Conv2D | (None, 224, 224, 64) | 1,792 |
Activation | (None, 224, 224, 64) | 0 |
Conv2D | (None, 224, 224, 64) | 36,928 |
Activation | (None, 224, 224, 64) | 0 |
MaxPooling2D | (None, 112, 112, 64) | 0 |
Dropout | (None, 112, 112, 64) | 0 |
... | ... | ... |
Activation | (None, 256) | 0 |
Dropout | (None, 256) | 0 |
Dense | (None, 5) | 1,285 |
Total params: 104,038,981
Trainable params: 104,038,981
Non-trainable params: 0
Label | Precision | Recall | F1-Score | Support |
---|---|---|---|---|
0 | 0.88 | 0.95 | 0.91 | 838 |
1 | 0.88 | 0.80 | 0.84 | 1203 |
2 | 0.89 | 0.92 | 0.90 | 2039 |
3 | 0.96 | 0.96 | 0.96 | 1480 |
4 | 0.98 | 0.96 | 0.97 | 980 |
- Accuracy: 0.92
- Macro Avg: Precision: 0.92, Recall: 0.92, F1-Score: 0.92, Support: 6540
- Weighted Avg: Precision: 0.92, Recall: 0.92, F1-Score: 0.92, Support: 6540
Layer Type | Output Shape | Param # |
---|---|---|
InputLayer | (None, 224, 224, 3) | 0 |
block1_conv1 (Conv2D) | (None, 224, 224, 64) | 1,792 |
block1_conv2 (Conv2D) | (None, 224, 224, 64) | 36,928 |
block1_pool (MaxPooling2D) | (None, 112, 112, 64) | 0 |
block2_conv1 (Conv2D) | (None, 112, 112, 128) | 73,856 |
block2_conv2 (Conv2D) | (None, 112, 112, 128) | 147,584 |
block2_pool (MaxPooling2D) | (None, 56, 56, 128) | 0 |
block3_conv1 (Conv2D) | (None, 56, 56, 256) | 295,168 |
block3_conv2 (Conv2D) | (None, 56, 56, 256) | 590,080 |
block3_conv3 (Conv2D) | (None, 56, 56, 256) | 590,080 |
block3_conv4 (Conv2D) | (None, 56, 56, 256) | 590,080 |
block3_pool (MaxPooling2D) | (None, 28, 28, 256) | 0 |
block4_conv1 (Conv2D) | (None, 28, 28, 512) | 1,180,160 |
block4_conv2 (Conv2D) | (None, 28, 28, 512) | 2,359,808 |
block4_conv3 (Conv2D) | (None, 28, 28, 512) | 2,359,808 |
block4_conv4 (Conv2D) | (None, 28, 28, 512) | 2,359,808 |
block4_pool (MaxPooling2D) | (None, 14, 14, 512) | 0 |
block5_conv1 (Conv2D) | (None, 14, 14, 512) | 2,359,808 |
block5_conv2 (Conv2D) | (None, 14, 14, 512) | 2,359,808 |
block5_conv3 (Conv2D) | (None, 14, 14, 512) | 2,359,808 |
block5_conv4 (Conv2D) | (None, 14, 14, 512) | 2,359,808 |
block5_pool (MaxPooling2D) | (None, 7, 7, 512) | 0 |
global_average_pooling2d (GlobalAveragePooling2D) | (None, 512) | 0 |
dense (Dense) | (None, 4096) | 2,101,248 |
dropout (Dropout) | (None, 4096) | 0 |
dense_1 (Dense) | (None, 4096) | 16,781,312 |
dense_2 (Dense) | (None, 5) | 20,485 |
- Total params: 38,927,429
- Trainable params: 38,927,429
- Non-trainable params: 0
Label | Precision | Recall | F1-Score | Support |
---|---|---|---|---|
0-COVID-19 | 0.94 | 0.95 | 0.95 | 838 |
1-Lung Opacity | 0.89 | 0.92 | 0.90 | 1203 |
2-Normal | 0.94 | 0.93 | 0.94 | 2039 |
3-Viral Pneumonia | 0.98 | 0.96 | 0.97 | 1480 |
4-Tuberculosis | 0.97 | 0.98 | 0.98 | 980 |
Accuracy | | | 0.95 | 6540 |
Macro Avg | 0.95 | 0.95 | 0.95 | 6540 |
Weighted Avg | 0.95 | 0.95 | 0.95 | 6540 |
Layer (type) | Output Shape | Param # | Connected to |
---|---|---|---|
input_1 (InputLayer) | [(None, 224, 224, 3)] | 0 | |
conv1_pad (ZeroPadding2D) | (None, 230, 230, 3) | 0 | input_1[0][0] |
conv1_conv (Conv2D) | (None, 112, 112, 64) | 9472 | conv1_pad[0][0] |
conv1_bn (BatchNormalization) | (None, 112, 112, 64) | 256 | conv1_conv[0][0] |
conv1_relu (Activation) | (None, 112, 112, 64) | 0 | conv1_bn[0][0] |
pool1_pad (ZeroPadding2D) | (None, 114, 114, 64) | 0 | conv1_relu[0][0] |
pool1_pool (MaxPooling2D) | (None, 56, 56, 64) | 0 | pool1_pad[0][0] |
conv2_block1_1_conv (Conv2D) | (None, 56, 56, 64) | 4160 | pool1_pool[0][0] |
... | ... | ... | ... |
... | ... | ... | ... |
... | ... | ... | ... |
conv5_block3_1_conv (Conv2D) | (None, 7, 7, 512) | 1049088 | conv5_block2_out[0][0] |
conv5_block3_1_bn (BatchNormalization) | (None, 7, 7, 512) | 2048 | conv5_block3_1_conv[0][0] |
conv5_block3_1_relu (Activation) | (None, 7, 7, 512) | 0 | conv5_block3_1_bn[0][0] |
conv5_block3_2_conv (Conv2D) | (None, 7, 7, 512) | 2359808 | conv5_block3_1_relu[0][0] |
conv5_block3_2_bn (BatchNormalization) | (None, 7, 7, 512) | 2048 | conv5_block3_2_conv[0][0] |
conv5_block3_2_relu (Activation) | (None, 7, 7, 512) | 0 | conv5_block3_2_bn[0][0] |
conv5_block3_3_conv (Conv2D) | (None, 7, 7, 2048) | 1050624 | conv5_block3_2_relu[0][0] |
conv5_block3_3_bn (BatchNormalization) | (None, 7, 7, 2048) | 8192 | conv5_block3_3_conv[0][0] |
conv5_block3_add (Add) | (None, 7, 7, 2048) | 0 | conv5_block2_out[0][0], conv5_block3_3_bn[0][0] |
conv5_block3_out (Activation) | (None, 7, 7, 2048) | 0 | conv5_block3_add[0][0] |
global_average_pooling2d (GlobalAveragePooling2D) | (None, 2048) | 0 | conv5_block3_out[0][0] |
dense (Dense) | (None, 5) | 10245 | global_average_pooling2d[0][0] |
Total params: 23,597,957
Trainable params: 23,544,837
Non-trainable params: 53,120
Label | Precision | Recall | F1-Score | Support |
---|---|---|---|---|
0 | 0.97 | 0.91 | 0.94 | 838 |
1 | 0.94 | 0.87 | 0.90 | 1203 |
2 | 0.91 | 0.97 | 0.94 | 2039 |
3 | 0.93 | 0.99 | 0.95 | 1480 |
4 | 0.99 | 0.89 | 0.94 | 980 |
Accuracy | | | 0.94 | 6540 |
Macro Avg | 0.95 | 0.93 | 0.94 | 6540 |
Weighted Avg | 0.94 | 0.94 | 0.94 | 6540 |
Layer (type) | Output Shape | Param # | Connected to |
---|---|---|---|
input_layer (InputLayer) | (None, 224, 224, 3) | 0 | [] |
spatialAttention_conv_preVGG_Spatial_Attention (Conv2D) | (None, 224, 224, 1) | 28 | ['input_layer[0][0]'] |
spatialAttention_multiply_preVGG_Spatial_Attention (Multiply) | (None, 224, 224, 3) | 0 | ['input_layer[0][0]', 'spatialAttention_conv_preVGG_Spatial_Attention[0][0]'] |
vggBlock_1_conv_1 (Conv2D) | (None, 224, 224, 32) | 896 | ['spatialAttention_multiply_preVGG_Spatial_Attention[0][0]'] |
vggBlock_1_bn_1 (BatchNormalization) | (None, 224, 224, 32) | 128 | ['vggBlock_1_conv_1[0][0]'] |
vggBlock_1_act_1 (Activation) | (None, 224, 224, 32) | 0 | ['vggBlock_1_bn_1[0][0]'] |
vggBlock_1_conv_2 (Conv2D) | (None, 224, 224, 32) | 9248 | ['vggBlock_1_act_1[0][0]'] |
vggBlock_1_bn_2 (BatchNormalization) | (None, 224, 224, 32) | 128 | ['vggBlock_1_conv_2[0][0]'] |
vggBlock_1_act_2 (Activation) | (None, 224, 224, 32) | 0 | ['vggBlock_1_bn_2[0][0]'] |
spatialAttention_conv_prePoolVGG_1 (Conv2D) | (None, 224, 224, 1) | 289 | ['vggBlock_1_act_2[0][0]'] |
resBlock_1_adjust_conv (Conv2D) | (None, 224, 224, 32) | 128 | ['spatialAttention_multiply_preVGG_Spatial_Attention[0][0]'] |
spatialAttention_multiply_prePoolVGG_1 (Multiply) | (None, 224, 224, 32) | 0 | ['vggBlock_1_act_2[0][0]', 'spatialAttention_conv_prePoolVGG_1[0][0]'] |
resBlock_1_adjust_bn (BatchNormalization) | (None, 224, 224, 32) | 128 | ['resBlock_1_adjust_conv[0][0]'] |
vggBlock_1_pool (MaxPooling2D) | (None, 112, 112, 32) | 0 | ['spatialAttention_multiply_prePoolVGG_1[0][0]'] |
resBlock_1_adjust_pool (MaxPooling2D) | (None, 112, 112, 32) | 0 | ['resBlock_1_adjust_bn[0][0]'] |
resBlock_1_add (Add) | (None, 112, 112, 32) | 0 | ['vggBlock_1_pool[0][0]', 'resBlock_1_adjust_pool[0][0]'] |
... | ... | ... | ... |
... | ... | ... | ... |
... | ... | ... | ... |
vggBlock_4_conv_1 (Conv2D) | (None, 28, 28, 256) | 295168 | ['resBlock_3_add[0][0]'] |
vggBlock_4_bn_1 (BatchNormalization) | (None, 28, 28, 256) | 1024 | ['vggBlock_4_conv_1[0][0]'] |
vggBlock_4_act_1 (Activation) | (None, 28, 28, 256) | 0 | ['vggBlock_4_bn_1[0][0]'] |
vggBlock_4_conv_2 (Conv2D) | (None, 28, 28, 256) | 590080 | ['vggBlock_4_act_1[0][0]'] |
vggBlock_4_bn_2 (BatchNormalization) | (None, 28, 28, 256) | 1024 | ['vggBlock_4_conv_2[0][0]'] |
vggBlock_4_act_2 (Activation) | (None, 28, 28, 256) | 0 | ['vggBlock_4_bn_2[0][0]'] |
spatialAttention_conv_prePoolVGG_4 (Conv2D) | (None, 28, 28, 1) | 2305 | ['vggBlock_4_act_2[0][0]'] |
resBlock_4_adjust_conv (Conv2D) | (None, 28, 28, 256) | 33024 | ['resBlock_3_add[0][0]'] |
spatialAttention_multiply_prePoolVGG_4 (Multiply) | (None, 28, 28, 256) | 0 | ['vggBlock_4_act_2[0][0]', 'spatialAttention_conv_prePoolVGG_4[0][0]'] |
resBlock_4_adjust_bn (BatchNormalization) | (None, 28, 28, 256) | 1024 | ['resBlock_4_adjust_conv[0][0]'] |
vggBlock_4_pool (MaxPooling2D) | (None, 14, 14, 256) | 0 | ['spatialAttention_multiply_prePoolVGG_4[0][0]'] |
resBlock_4_adjust_pool (MaxPooling2D) | (None, 14, 14, 256) | 0 | ['resBlock_4_adjust_bn[0][0]'] |
resBlock_4_add (Add) | (None, 14, 14, 256) | 0 | ['vggBlock_4_pool[0][0]', 'resBlock_4_adjust_pool[0][0]'] |
global_avg_pool (GlobalAveragePooling2D) | (None, 256) | 0 | ['resBlock_4_add[0][0]'] |
dense_1 (Dense) | (None, 1024) | 263168 | ['global_avg_pool[0][0]'] |
dropout_1 (Dropout) | (None, 1024) | 0 | ['dense_1[0][0]'] |
dense_2 (Dense) | (None, 1024) | 1049600 | ['dropout_1[0][0]'] |
output_layer (Dense) | (None, 5) | 5125 | ['dense_2[0][0]'] |
Total params: 2,543,845
Trainable params: 2,540,965
Non-trainable params: 2,880
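A sketch of the spatial-attention and VGG/residual pairing visible in the summary above; the 3×3 attention kernel follows from the listed parameter counts (e.g. 28 = 3·3·3 + 1), while the sigmoid gating activation is an assumption, and plain Conv2D layers are used here even though the full design also employs depthwise separable convolutions:

```python
import tensorflow as tf
from tensorflow.keras import layers

def spatial_attention(x, name):
    """Spatial convolutional attention as named in the summary above:
    a 1-channel 3x3 conv produces a per-pixel weight map (sigmoid gating
    assumed) that rescales the input feature maps."""
    attn = layers.Conv2D(1, kernel_size=3, padding="same", activation="sigmoid",
                         name=f"spatialAttention_conv_{name}")(x)
    return layers.Multiply(name=f"spatialAttention_multiply_{name}")([x, attn])

def vgg_res_block(x, filters, block_id):
    """One VGG-style double conv with a parallel 1x1 'adjust' shortcut (Add),
    mirroring the vggBlock_*/resBlock_* pairing in the table."""
    y = layers.Conv2D(filters, 3, padding="same", name=f"vggBlock_{block_id}_conv_1")(x)
    y = layers.BatchNormalization(name=f"vggBlock_{block_id}_bn_1")(y)
    y = layers.Activation("relu", name=f"vggBlock_{block_id}_act_1")(y)
    y = layers.Conv2D(filters, 3, padding="same", name=f"vggBlock_{block_id}_conv_2")(y)
    y = layers.BatchNormalization(name=f"vggBlock_{block_id}_bn_2")(y)
    y = layers.Activation("relu", name=f"vggBlock_{block_id}_act_2")(y)
    y = spatial_attention(y, name=f"prePoolVGG_{block_id}")        # attend before pooling
    y = layers.MaxPooling2D(name=f"vggBlock_{block_id}_pool")(y)

    # Skip path: 1x1 conv adjusts channel count, then matches the pooled resolution.
    shortcut = layers.Conv2D(filters, 1, name=f"resBlock_{block_id}_adjust_conv")(x)
    shortcut = layers.BatchNormalization(name=f"resBlock_{block_id}_adjust_bn")(shortcut)
    shortcut = layers.MaxPooling2D(name=f"resBlock_{block_id}_adjust_pool")(shortcut)
    return layers.Add(name=f"resBlock_{block_id}_add")([y, shortcut])
```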
Label | Precision | Recall | F1-Score | Support |
---|---|---|---|---|
0 | 0.91 | 0.89 | 0.90 | 838 |
1 | 0.87 | 0.86 | 0.86 | 1203 |
2 | 0.92 | 0.92 | 0.92 | 2039 |
3 | 0.96 | 0.97 | 0.97 | 1480 |
4 | 0.97 | 0.97 | 0.97 | 980 |
Accuracy | | | 0.92 | 6540 |
Macro Avg | 0.92 | 0.92 | 0.92 | 6540 |
Weighted Avg | 0.92 | 0.92 | 0.92 | 6540 |
Layer (type) | Output Shape | Param # | Connected to |
---|---|---|---|
input_1 (InputLayer) | (None, 224, 224, 3) | 0 | - |
tf.image.extract_patches (TFOpLambda) | (None, 11, 11, 1200) | 0 | input_1[0][0] |
reshape (Reshape) | (None, 121, 1200) | 0 | tf.image.extract_patches[0][0] |
dense (Dense) | (None, 121, 768) | 922,368 | reshape[0][0] |
token_and_position_embedding | (None, 122, 768) | 94,464 | dense[0][0] |
layer_normalization (LayerNormalization) | (None, 122, 768) | 1,536 | token_and_position_embedding[0][0] |
multi_head_attention (MultiHeadAttention) | (None, 122, 768) | 11,808,768 | layer_normalization[0][0], layer_normalization[0][0] |
add (Add) | (None, 122, 768) | 0 | multi_head_attention[0][0], token_and_position_embedding[0][0] |
layer_normalization_1 (LayerNormalization) | (None, 122, 768) | 1,536 | add[0][0] |
dense_1 (Dense) | (None, 122, 3072) | 2,362,368 | layer_normalization_1[0][0] |
dropout (Dropout) | (None, 122, 3072) | 0 | dense_1[0][0] |
dense_2 (Dense) | (None, 122, 768) | 2,360,064 | dropout[0][0] |
add_1 (Add) | (None, 122, 768) | 0 | dense_2[0][0], layer_normalization_1[0][0] |
layer_normalization_2 (LayerNormalization) | (None, 122, 768) | 1,536 | add_1[0][0] |
multi_head_attention_1 (MultiHeadAttention) | (None, 122, 768) | 11,808,768 | layer_normalization_2[0][0], layer_normalization_2[0][0] |
add_2 (Add) | (None, 122, 768) | 0 | multi_head_attention_1[0][0], add_1[0][0] |
layer_normalization_3 (LayerNormalization) | (None, 122, 768) | 1,536 | add_2[0][0] |
... | ... | ... | ... |
... | ... | ... | ... |
... | ... | ... | ... |
dense_9 (Dense) | (None, 122, 3072) | 2,362,368 | layer_normalization_9[0][0] |
dropout_4 (Dropout) | (None, 122, 3072) | 0 | dense_9[0][0] |
dense_10 (Dense) | (None, 122, 768) | 2,360,064 | dropout_4[0][0] |
add_9 (Add) | (None, 122, 768) | 0 | dense_10[0][0], layer_normalization_9[0][0] |
dense_11 (Dense) | (None, 122, 1) | 769 | add_9[0][0] |
tf.math.multiply (TFOpLambda) | (None, 122, 768) | 0 | add_9[0][0], dense_11[0][0] |
tf.math.reduce_sum (TFOpLambda) | (None, 768) | 0 | tf.math.multiply[0][0] |
dense_12 (Dense) | (None, 2096) | 1,611,824 | tf.math.reduce_sum[0][0] |
dropout_5 (Dropout) | (None, 2096) | 0 | dense_12[0][0] |
dense_13 (Dense) | (None, 2096) | 4,395,312 | dropout_5[0][0] |
global_average_pooling1d (GlobalAveragePooling1D) | (None, 768) | 0 | add_9[0][0] |
concatenate (Concatenate) | (None, 2864) | 0 | dense_13[0][0], global_average_pooling1d[0][0] |
dense_14 (Dense) | (None, 5) | 14,325 | concatenate[0][0] |
Total params: 89,710,422
Trainable params: 89,710,422
Non-trainable params: 0
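A sketch of the patch-embedding front end and one encoder block implied by the summary above: 20×20 patches give the (None, 11, 11, 1200) shape, and the 1200→768 projection reproduces the 922,368-parameter Dense row. The `num_heads`/`key_dim` pair is chosen so the attention layer's parameter count matches the 11,808,768 in the table (heads × key_dim = 3840); GELU and the dropout rate are assumptions:

```python
import tensorflow as tf

PATCH, DIM = 20, 768  # 224/20 -> 11x11 = 121 patches, each 20*20*3 = 1200 values

inputs = tf.keras.Input(shape=(224, 224, 3))

# Non-overlapping patch extraction; matches the (None, 11, 11, 1200) row above.
patches = tf.image.extract_patches(inputs,
                                   sizes=[1, PATCH, PATCH, 1],
                                   strides=[1, PATCH, PATCH, 1],
                                   rates=[1, 1, 1, 1], padding="VALID")
tokens = tf.keras.layers.Reshape((121, PATCH * PATCH * 3))(patches)
tokens = tf.keras.layers.Dense(DIM)(tokens)  # linear projection: 1200 -> 768 (922,368 params)
# A learnable class token plus position embeddings would extend this to 122 tokens,
# as in the token_and_position_embedding row above.

# One pre-norm Transformer encoder block (LayerNorm -> MHA -> Add -> MLP -> Add).
x = tf.keras.layers.LayerNormalization()(tokens)
x = tf.keras.layers.MultiHeadAttention(num_heads=12, key_dim=320)(x, x)
x = tf.keras.layers.Add()([x, tokens])
y = tf.keras.layers.LayerNormalization()(x)
y = tf.keras.layers.Dense(4 * DIM, activation="gelu")(y)  # MLP expansion, 768 -> 3072
y = tf.keras.layers.Dropout(0.1)(y)
y = tf.keras.layers.Dense(DIM)(y)
x = tf.keras.layers.Add()([y, x])
```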
Label | Precision | Recall | F1-Score | Support |
---|---|---|---|---|
0 | 0.59 | 0.67 | 0.63 | 838 |
1 | 0.74 | 0.73 | 0.73 | 1203 |
2 | 0.82 | 0.80 | 0.81 | 2039 |
3 | 0.90 | 0.86 | 0.88 | 1480 |
4 | 0.81 | 0.81 | 0.81 | 980 |
Accuracy | | | 0.79 | 6540 |
Macro Avg | 0.77 | 0.77 | 0.77 | 6540 |
Weighted Avg | 0.79 | 0.79 | 0.79 | 6540 |
Research Question 1: How does the performance of Convolutional Vision Transformers compare with traditional Convolutional Neural Networks and other deep learning models, such as VGG19 and ResNet50, in the classification of chest radiological images of infectious respiratory diseases?
Comparing the performance of Convolutional Vision Transformers (ViTs) with traditional Convolutional Neural Networks (CNNs) like VGG19, ResNet50, and the custom CNN architecture on chest radiological images yields several insights.
Model Name | Loss | Test Accuracy | Precision | Recall | AUC | F1 Score | Top-2-Accuracy |
---|---|---|---|---|---|---|---|
VGG19 | 0.159846 | 0.945260 | 0.946319 | 0.943425 | 0.994903 | 0.944870 | Not Available |
CustomCNN | 0.208691 | 0.970000 | 0.928472 | 0.920948 | 0.992963 | 0.924695 | Not Available |
ResNet50 | 0.222941 | 0.936697 | 0.939757 | 0.935015 | 0.990741 | 0.937380 | Not Available |
VIT | 0.662379 | 0.786544 | 0.811604 | 0.761468 | 0.951327 | 0.785737 | 0.936544 |
ViTs, especially those adapted from ViT-B/16, shine when trained on vast datasets. The study's dataset might not be expansive enough to fully leverage ViTs, and their computational demands could hinder optimal training, as indicated by the elevated loss and weaker accuracy metrics during training.
CNNs like VGG19 and ResNet50, optimized for image tasks over the years, display superior performance. The custom CNN, with its specialized layers, further underlines the effectiveness of task-tailored CNNs.
While Vision Transformers have made significant advances in computer vision, for classifying chest radiological images with the current dataset and computational resources they may be outperformed by established CNNs like VGG19, ResNet50, and the custom CNN model.
- Research Question 2: What insights can be derived from activation visualization or attention map visualization about the decision-making processes of the pre-trained VGG19 and ResNet50 models, the custom CNN, and the Vision Transformer in predicting respiratory diseases from chest radiological images?
- Insights from Activation and Attention Map Visualizations in Chest Radiological Image Classification
- Activation and attention map visualizations offer transparency into neural networks, highlighting areas deemed significant during predictions. Such insights are crucial for medical imaging applications, where understanding a model's focus can aid validation and trust.
- Using gradient-based visualization, ResNet50 pinpoints distinct chest X-ray features indicative of respiratory diseases, like opacities. However, its attention sometimes extends to noise and image edges. This could be influenced by the zoom preprocessing, emphasizing both critical regions and noise.
Gradient visualizations for VGG19 revealed its attention predominantly on image noise and artifacts, potentially due to the absence of a zoom preprocessing step. This highlights the importance of diligent preprocessing in medical imaging.
Spatial attention maps from the custom CNN delineate its focus from broad chest X-ray regions to specific areas like the lungs and then potential pathological zones. This layered attention mirrors a radiologist's diagnostic process. The integration of spatial attention mechanisms boosts the model's ability to focus on vital image areas, essential for medical imaging tasks.
ViTs, with their self-attention mechanisms, process images as sequences of non-overlapping patches. Though the attention maps from our ViT experiments exhibited inconsistencies, ideally, ViTs would emphasize diagnostic-relevant patches, providing a comprehensive image interpretation. However, their efficiency is deeply tied to training data quality and size.
While visualizations grant insights into model decision-making, they should be interpreted with care. Sole reliance on network focus isn't conclusive proof of correct predictions. Clinical validation remains vital. Nevertheless, these visualization tools undoubtedly foster understanding and collaboration between neural networks and clinicians.
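For the gradient-based visualizations referenced above, a minimal Grad-CAM sketch for a Keras functional model; `conv5_block3_out` from the ResNet50 summary earlier is one sensible target layer:

```python
import numpy as np
import tensorflow as tf

def grad_cam(model, image, layer_name, class_index=None):
    """Minimal Grad-CAM: weight the chosen conv layer's feature maps by the
    gradient of the target class score, then sum, ReLU, and normalize."""
    grad_model = tf.keras.Model(model.inputs,
                                [model.get_layer(layer_name).output, model.output])
    with tf.GradientTape() as tape:
        conv_out, preds = grad_model(image[None, ...])   # add a batch dimension
        if class_index is None:
            class_index = int(tf.argmax(preds[0]))       # default: predicted class
        score = preds[:, class_index]
    grads = tape.gradient(score, conv_out)               # d(score)/d(feature maps)
    weights = tf.reduce_mean(grads, axis=(1, 2))         # global-average-pool the grads
    cam = tf.nn.relu(tf.reduce_sum(conv_out * weights[:, None, None, :], axis=-1))[0]
    cam = cam / (tf.reduce_max(cam) + 1e-8)              # normalize to [0, 1]
    return cam.numpy()

# Usage (hypothetical): heatmap = grad_cam(resnet50_model, x, "conv5_block3_out")
```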
- Deep Learning in Medical Imaging
- Application: Both opportunities and challenges observed.
- Focus on:
- CNNs: Traditional architectures like VGG19 and ResNet50.
- ViTs: Emerging model in the landscape.
- Task: Classify chest radiological images of infectious respiratory diseases.
- Additional Investigation:
- Custom convolution network architecture with features:
- Spatial attention
- Depth-wise convolutions
- VGG and ResNet inspired blocks.
- Outcomes
- Vision Transformers
- Advancements recognized in general computer vision domain.
- Limitations in medical imaging, especially given dataset size and computational constraints.
- Traditional CNNs
- Demonstrated superior performance.
- Benefit: Spatial hierarchies tailored for image data.
- Custom CNN: Highlighted advantages of task-specific models.
- Model Interpretability
- Emphasis using:
- Gradient-based visualization
- Attention map methodologies.
- Purpose: Transparent insight into model decisions.
- Opportunity: Enhanced collaboration between ML models and medical professionals.
Future Work:
- Dataset Expansion: ViTs typically perform better on larger datasets.
- Incorporation of Transfer Learning: Incorporate other pre-trained models, such as Inception and Xception, for a more robust and comprehensive comparison.
- Model Fusion and Ensembling: Combine the strengths of CNNs and ViTs.
- Clinical Integration and Validation: Collaborative efforts with radiologists.
- Enhanced Model Interpretability: Beyond current methods, explore advanced interpretability techniques.
- Real-time Application: Integrate models into real-time diagnostic platforms for radiologists.