A state-of-the-art Vision Transformer model for detecting pneumonia in chest X-ray images, achieving 91.10% test accuracy and 97.14% AUC.
| Metric | Score |
|---|---|
| Test Accuracy | 91.10% |
| Test Precision | 92.07% |
| Test Recall | 91.10% |
| Test F1-Score | 91.32% |
| Test AUC | 97.14% |
| Cohen's Kappa | 78.71% |
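For context, Cohen's Kappa measures agreement between predictions and ground truth beyond what chance alone would produce:

```latex
\kappa = \frac{p_o - p_e}{1 - p_e}
```

where p_o is the observed agreement (accuracy) and p_e the agreement expected by chance; 78.71% corresponds to "substantial" agreement on the commonly used Landis-Koch scale.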
Sample test-set predictions:

| Example | Predicted | Confidence | Normal Probability | Pneumonia Probability |
|---|---|---|---|---|
| 1 | Normal | 63.8% | 63.8% | 36.2% |
| 2 | Normal | 74.9% | 74.9% | 25.1% |
| 3 | Normal | 77.5% | 77.5% | 22.5% |
| 4 | Pneumonia | 84.4% | 15.6% | 84.4% |
| 5 | Pneumonia | 82.8% | 17.2% | 82.8% |
| 6 | Pneumonia | 82.6% | 17.4% | 82.6% |
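For reference, a minimal sketch of how per-image probabilities like those above can be computed from a trained model (the preprocessing and class order here are assumptions, not the repo's documented pipeline):

```python
import torch
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),   # matches the 224x224 input resolution
    transforms.ToTensor(),
])

@torch.no_grad()
def predict(model, image_path):
    x = preprocess(Image.open(image_path).convert("RGB")).unsqueeze(0)
    probs = model(x).softmax(dim=-1).squeeze(0)
    # Class order (Normal=0, Pneumonia=1) assumes ImageFolder's
    # alphabetical mapping of the NORMAL/PNEUMONIA folders.
    return {"Normal": float(probs[0]), "Pneumonia": float(probs[1])}
```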
# Clone the repository
git clone https://github.com/nithinyanna10/vision_transformers.git
cd vision_transformers
# Create virtual environment
python3 -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install dependencies
pip install -r requirements.txt

The dataset is not included in this repository due to size constraints. You need to download it separately:
Option 1: Using Kaggle API (Recommended)
# Install kagglehub if not already installed
pip install kagglehub
# Download the dataset
python download_dataset.py

Option 2: Manual Download
- Go to Chest X-Ray Images (Pneumonia) Dataset
- Download the dataset
- Extract it to the data/chest_xray/ directory
- Ensure the structure is:

data/chest_xray/
├── train/
│   ├── NORMAL/
│   └── PNEUMONIA/
├── val/
│   ├── NORMAL/
│   └── PNEUMONIA/
└── test/
    ├── NORMAL/
    └── PNEUMONIA/
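A quick sanity check (hypothetical helper, not part of the repo) that the dataset landed in the expected layout before training:

```python
from pathlib import Path

# Count images per split/class under the expected directory layout.
root = Path("data/chest_xray")
for split in ("train", "val", "test"):
    for cls in ("NORMAL", "PNEUMONIA"):
        n = len(list((root / split / cls).glob("*")))
        print(f"{split}/{cls}: {n} images")
```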
Train the model:

python train.py

Note: The trained model checkpoint is not included in this repository due to size constraints. You'll need to train the model yourself using the above command. Training typically takes 20-30 minutes on a modern GPU or 2-3 hours on CPU.
Launch the Streamlit demo:

streamlit run app.py

Model architecture:

- Vision Transformer (ViT) with 12 layers and 8 attention heads
- Embedding Dimension: 512
- Patch Size: 16x16
- Input Resolution: 224x224
- Parameters: ~41.7M
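src/model.py implements the architecture itself; as a rough, hypothetical equivalent of the configuration above using the timm library (the exact parameter count will differ somewhat from ~41.7M depending on MLP ratio and head design):

```python
from timm.models.vision_transformer import VisionTransformer

model = VisionTransformer(
    img_size=224,    # input resolution
    patch_size=16,   # 16x16 patches -> 14x14 = 196 tokens
    embed_dim=512,   # embedding dimension
    depth=12,        # transformer layers
    num_heads=8,     # attention heads per layer
    num_classes=2,   # Normal vs. Pneumonia
)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```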
- Advanced Data Augmentation: Medical-specific augmentations including elastic transforms, CLAHE, and gamma correction
- Focal Loss: Handles class imbalance effectively
- Mixed Precision Training: Optimized for Apple Silicon (MPS) and CUDA (see the sketch after this list)
- Comprehensive Evaluation: ROC curves, confusion matrices, Grad-CAM visualizations
- Interactive Web App: Streamlit-based demo with real-time predictions
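A minimal sketch of the mixed-precision training loop referenced above, assuming a standard PyTorch setup (function and variable names are illustrative):

```python
import torch

def train_epoch(model, loader, criterion, optimizer):
    # Pick the best available backend: Apple Silicon (MPS), CUDA, or CPU.
    device = ("mps" if torch.backends.mps.is_available()
              else "cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Gradient scaling is only needed for float16 on CUDA.
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        # Autocast on MPS requires a recent PyTorch; disabled on CPU here.
        with torch.autocast(device_type=device, enabled=(device != "cpu")):
            loss = criterion(model(images), labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
```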
- Source: Chest X-Ray Images (Pneumonia) - Kaggle
- Total Images: 5,840
- Classes: Normal (1,575), Pneumonia (4,265)
- Split: 80% Train, 10% Validation, 10% Test
- Epochs: 20
- Batch Size: 16
- Learning Rate: 0.0001 (with cosine scheduling)
- Optimizer: AdamW with weight decay
- Loss Function: Focal Loss (α=0.25, γ=2.0)
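A minimal sketch of focal loss as parameterized above (the standard formulation; the repo's implementation may differ in details):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # Focal loss down-weights easy examples so hard and minority-class
    # samples contribute more to the gradient. This sketch uses a
    # constant alpha; class-balanced variants weight alpha per class.
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce)                      # probability of the true class
    return (alpha * (1 - pt) ** gamma * ce).mean()
```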
Sample Accuracy: 91.7% on 12 test images
The model demonstrates excellent performance on both normal and pneumonia cases, with high confidence predictions and robust generalization.
vision_transformers/
├── src/
│ ├── model.py # Vision Transformer architecture
│ ├── trainer.py # Training loop and optimization
│ ├── evaluator.py # Model evaluation and visualization
│ └── data_loader.py # Data loading and augmentation
├── configs/
│ └── config.yaml # Model and training configuration
├── results/ # Generated test results and visualizations
├── models/ # Trained model checkpoints
├── app.py # Streamlit web application
├── train.py # Main training script
├── download_dataset.py # Dataset download script
└── requirements.txt # Python dependencies
Note: The data/ folder is not included in this repository. Download the dataset using the instructions above.
Our project includes comprehensive explainability features to understand how the AI makes decisions:

- Grad-CAM: shows which regions of the chest X-ray the model focuses on for its diagnosis
- Attention maps: reveal how the transformer processes different parts of the image (see the rollout sketch after this list)
- Feature importance: identifies the most important features that influence the prediction
- Combined views: merge multiple explanation methods for complete understanding
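One common way to turn transformer attention into a per-patch heatmap is attention rollout; this is a generic sketch of that technique, not necessarily the repo's exact method:

```python
import torch

def attention_rollout(attentions):
    # attentions: list of [heads, tokens, tokens] maps, one per layer.
    # Average heads, add the residual connection, renormalize, and
    # multiply layer by layer to trace attention back to input patches.
    result = torch.eye(attentions[0].size(-1))
    for attn in attentions:
        attn = attn.mean(dim=0)                       # average over heads
        attn = attn + torch.eye(attn.size(-1))        # residual connection
        attn = attn / attn.sum(dim=-1, keepdim=True)  # renormalize rows
        result = attn @ result
    return result[0, 1:]  # CLS-token attention to each image patch
```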
- Hybrid CNN-ViT (`hybrid_cnn_vit`):
  - ResNet50 backbone + Vision Transformer with cross-attention
  - 63M+ parameters for enhanced feature extraction
  - Cross-attention mechanisms between CNN and ViT features
- Multi-Scale ViT:
  - Parallel processing with different patch sizes (8x8, 16x16, 32x32)
  - 60M+ parameters for multi-resolution analysis
  - Scale-aware feature fusion for comprehensive understanding
- Ensemble (see the voting sketch after this list):
  - Weighted voting of multiple architectures
  - Robust performance through model diversity
  - Improved accuracy and generalization
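Weighted soft voting can be sketched as follows (hypothetical helper; the repo's ensemble code may differ):

```python
import torch

@torch.no_grad()
def ensemble_predict(models, weights, x):
    # Average class probabilities across models, weighted by (for
    # example) each model's validation accuracy.
    probs = torch.stack([w * m(x).softmax(dim=-1)
                         for m, w in zip(models, weights)])
    return probs.sum(dim=0) / sum(weights)
```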
Try our interactive explainability demo:
# Run the XAI demonstration
python demo_xai.py
# Launch interactive Streamlit app
streamlit run xai_demo.py

Compare different architectures with comprehensive metrics:
# Run model comparison
python compare_models.py
# Train advanced models
python train_advanced.py --model_type hybrid_cnn_vit

The advanced training pipeline also supports:

- Masked Autoencoder (MAE) pre-training for better feature learning
- Adversarial training to make the model robust against adversarial attacks
- Uncertainty quantification to understand model confidence and prediction reliability (see the Monte Carlo dropout sketch after the commands below)
# Train with all advanced techniques
python train_advanced_techniques.py --model_type hybrid_cnn_vit --use_mae --use_adversarial --use_uncertainty
# Train with specific techniques
python train_advanced_techniques.py --use_adversarial # Only adversarial training
python train_advanced_techniques.py --use_uncertainty # Only uncertainty quantification
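As a sketch of one standard uncertainty technique, Monte Carlo dropout (named generically here; the repo's --use_uncertainty implementation may differ):

```python
import torch

@torch.no_grad()
def mc_dropout_predict(model, x, n_samples=20):
    # Keep dropout active at inference and average several stochastic
    # passes; the spread estimates predictive uncertainty. Safe for
    # ViTs, which use LayerNorm (model.train() perturbs no BatchNorm
    # running statistics).
    model.train()
    probs = torch.stack([model(x).softmax(dim=-1) for _ in range(n_samples)])
    model.eval()
    return probs.mean(dim=0), probs.std(dim=0)
```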
Complete MLOps pipeline with experiment tracking, model versioning, and automated CI/CD.

Comprehensive data validation system ensuring data quality:
# Run data validation
python -c "from mlops.data_validation import DataValidator; DataValidator().validate_dataset()"
# Generate validation report
python mlops/data_validation.py

Real-time model monitoring with drift detection and alerting:
# Start model monitoring
python -c "from mlops.model_monitoring import ModelMonitor; monitor = ModelMonitor(); monitor.log_metrics({'accuracy': 0.92})"
# Generate monitoring dashboard
python mlops/model_monitoring.py

GitHub Actions workflow for automated training, evaluation, and deployment:
# .github/workflows/mlops-pipeline.yml
- Data validation and quality checks
- Automated model training and evaluation
- Model versioning and registry
- Performance monitoring and alerting
- Automated deployment to staging/production

- Experiment Tracking: Weights & Biases, MLflow integration (example after this list)
- Model Versioning: DVC for data versioning, model registry
- Data Validation: Comprehensive quality assurance
- Model Monitoring: Real-time performance tracking, drift detection
- Automated CI/CD: GitHub Actions pipeline
- Alerting System: Performance degradation alerts
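A minimal experiment-tracking sketch with Weights & Biases (assumes a configured W&B account; the project name and metric values are placeholders):

```python
import wandb

# Log per-epoch validation metrics to a W&B run.
run = wandb.init(project="medical-vit", config={"lr": 1e-4, "batch_size": 16})
for epoch, val_acc in enumerate([0.85, 0.89, 0.91]):
    wandb.log({"epoch": epoch, "val_accuracy": val_acc})
run.finish()
```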
Advanced model optimization for production deployment:
Production-ready API with comprehensive endpoints:
Complete containerization with multi-stage builds:
# Build and run with Docker
docker build -t medical-vit .
docker run -p 8000:8000 medical-vit
# Or use Docker Compose
docker-compose up -d

- POST /predict - Single image prediction
- POST /predict/batch - Batch processing (up to 10 images)
- POST /predict/file - File upload prediction
- GET /health - Health monitoring
- GET /model/info - Model information
- GET /stats - API statistics
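A hypothetical client call against the file-upload endpoint (the "file" form-field name and response shape are assumptions, not documented above):

```python
import requests

# Assumes the API is serving locally on port 8000, as in the Docker
# commands above.
with open("sample_xray.jpeg", "rb") as f:
    resp = requests.post("http://localhost:8000/predict/file",
                         files={"file": f})
print(resp.json())
```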
- INT8 Quantization: 73% size reduction, 1.9x speedup
- FP16 Conversion: 50% size reduction, GPU memory efficient
- ONNX Conversion: Cross-platform deployment, 2.5x speedup
- Performance Benchmarking: Comprehensive speed and size analysis
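A sketch of post-training dynamic INT8 quantization, shown on a stand-in module (deploy.py may use a different optimization path):

```python
import torch
import torch.nn as nn

# Dynamic quantization targets Linear layers, which dominate a ViT's
# parameter count and inference cost on CPU.
model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 2))
quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
print(quantized)  # Linear layers replaced by DynamicQuantizedLinear
```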
# Full deployment pipeline
python deploy.py --full
# Individual components
python deploy.py --optimize-only # Model optimization
python deploy.py --build-only # Docker build
python deploy.py --deploy-only # Docker Compose deploy
python deploy.py --test-only # Performance testing

- Multi-stage Docker builds for optimized images
- Health checks and monitoring
- Async processing for high throughput
- Error handling and logging
- API documentation with Swagger UI
- Performance testing suite
This project demonstrates the effectiveness of Vision Transformers in medical image analysis, achieving state-of-the-art performance on pneumonia detection. The model's high accuracy and AUC make it a promising candidate for clinical decision support systems.
- Advanced ViT architectures for medical imaging
- Comprehensive XAI framework for model interpretability
- Production-ready MLOps pipeline with experiment tracking
- Multi-scale feature extraction for enhanced accuracy
This project is for educational and research purposes. Please ensure compliance with medical data regulations when using in clinical settings.
Built with PyTorch, Vision Transformers, Streamlit, and advanced XAI techniques