Given images of blood samples, predict whether a patient has malaria using the TensorFlow Object Detection API and transfer learning.
Blood Sample Image |
After Detection Image |
Data came from three different labs’ ex vivo samples of P. vivax infected patients in Manaus, Brazil, and Thailand. The Manaus and Thailand data were used for training and validation while the Brazil data were left out as our test set. Blood smears were stained with Giemsa reagent, which attaches to DNA and allows experts to inspect infected cells and determine their stage.
The dataset can be downloaded from Kaggle.
The data is highly imbalanced: about 97% of the cells are red blood cells (RBCs). I therefore framed the problem as a two-stage classification task.
Label every cell other than an RBC as NON-RBC, then build an object detector that only decides whether a cell is RBC or NON-RBC. For this task I used the TensorFlow Object Detection API to fine-tune a Faster R-CNN model pre-trained on the COCO dataset.
For the NON-RBC cells, build a CNN classifier that assigns each cell to one of the categories trophozoite, schizont, ring, leukocyte, or gametocyte. For this task I fine-tuned a VGG-16 model pre-trained on ImageNet.
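The relabeling used for the first stage can be sketched in a few lines. This is a minimal illustration, not the repository's code: the function name `to_stage_one_label` and the exact label strings are assumptions.

```python
def to_stage_one_label(category: str) -> str:
    """Collapse a fine-grained cell category into the detector's two classes.

    Every category other than "red blood cell" becomes NON-RBC, matching the
    first-stage labeling described above. The label strings are illustrative.
    """
    name = category.strip().lower()
    return "RBC" if name == "red blood cell" else "NON-RBC"


categories = ["red blood cell", "ring", "leukocyte", "trophozoite"]
stage_one_labels = [to_stage_one_label(c) for c in categories]
print(stage_one_labels)
```

The original category is still needed later, because the second stage trains on the fine-grained labels of the NON-RBC cells.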
Combining the two stages gives a complete two-stage classification model. Given a blood sample image, the Faster R-CNN detector returns bounding boxes for the RBC and NON-RBC cells. Each NON-RBC cell is then cropped from the image and passed to the VGG-16 classifier, which predicts its category.
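The control flow of the combined model can be sketched as follows. This is a dependency-free illustration under stated assumptions: `detect` and `classify` are hypothetical stand-ins for the Faster R-CNN detector and the VGG-16 classifier, the image is a nested list rather than an array, and the 0.5 score threshold is an arbitrary choice, not a value taken from the repository.

```python
def crop(image, ymin, xmin, ymax, xmax):
    """Crop a row-major image stored as nested lists.

    With a NumPy array the equivalent would be image[ymin:ymax, xmin:xmax].
    """
    return [row[xmin:xmax] for row in image[ymin:ymax]]


def two_stage_predict(image, detect, classify, min_score=0.5):
    """Run the detector, then classify every NON-RBC crop.

    detect(image)  -> iterable of (label, score, (ymin, xmin, ymax, xmax))
    classify(cell) -> category name for a cropped NON-RBC cell
    Both callables are hypothetical placeholders for the trained models.
    """
    predictions = []
    for label, score, (ymin, xmin, ymax, xmax) in detect(image):
        if score < min_score:
            continue  # drop low-confidence detections
        if label == "NON-RBC":
            # Second stage: refine NON-RBC into a specific category.
            label = classify(crop(image, ymin, xmin, ymax, xmax))
        else:
            label = "red blood cell"
        predictions.append((label, score, (ymin, xmin, ymax, xmax)))
    return predictions


# Fake models so the sketch runs without TensorFlow.
def fake_detect(image):
    return [
        ("RBC", 0.98, (0, 0, 10, 10)),
        ("NON-RBC", 0.91, (5, 5, 20, 25)),
        ("NON-RBC", 0.20, (1, 1, 4, 4)),  # below the threshold, ignored
    ]


def fake_classify(cell):
    return "ring" if len(cell) < 32 else "trophozoite"


image = [[0] * 64 for _ in range(64)]
predictions = two_stage_predict(image, fake_detect, fake_classify)
for label, score, box in predictions:
    print(label, score, box)
```

Passing the two models in as callables keeps the orchestration independent of how each model is loaded.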
Cell Type | Ground Truth Count | Model Predicted Count |
red blood cell | 77420 | 71040 |
trophozoite | 1473 | 1513 |
ring | 353 | 311 |
schizont | 179 | 178 |
gametocyte | 144 | 157 |
leukocyte | 103 | 87 |
Cell Type | Ground Truth Count | Model Predicted Count |
red blood cell | 5614 | 4551 |
trophozoite | 111 | 146 |
ring | 169 | 95 |
schizont | 11 | 6 |
gametocyte | 12 | 7 |
leukocyte | 0 | 0 |
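One way to read these tables is to compute, per category, the ratio of the model's count to the ground-truth count. The sketch below does this for the second table; the numbers are copied from it, and the helper name `count_ratio` is an assumption for illustration.

```python
# (ground truth count, model count) copied from the second table above.
counts = {
    "red blood cell": (5614, 4551),
    "trophozoite": (111, 146),
    "ring": (169, 95),
    "schizont": (11, 6),
    "gametocyte": (12, 7),
    "leukocyte": (0, 0),
}


def count_ratio(ground_truth, predicted):
    """Return predicted / ground_truth, or None when there is no ground truth."""
    if ground_truth == 0:
        return None
    return predicted / ground_truth


for name, (truth, predicted) in counts.items():
    ratio = count_ratio(truth, predicted)
    shown = "n/a" if ratio is None else f"{ratio:.2f}"
    print(f"{name:15s} truth={truth:5d} model={predicted:5d} ratio={shown}")
```

A ratio near 1 only means the totals agree. It does not show that each individual cell received the correct label, so these counts are a coarse check rather than a per-cell accuracy measure.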
Clone the repository and run the following commands from the terminal.
Go through this blog to set up TFOD (the TensorFlow Object Detection API) on your system.
pip install -r requirements.txt
python TFODRecordCreator.py
In the experiment folder, place the Faster R-CNN ResNet-101 COCO model in the training subfolder, then run the following command from the TFOD folder:
python object_detection/model_main.py --pipeline_config_path faster_rcnn_malaria.config --model_dir experiment/training --num_train_steps 20000 --sample_1_of_n_eval_examples 10 --alsologtostderr
After training finishes, run the following command from the TFOD folder to export the inference graph:
python object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path faster_rcnn_malaria.config --trained_checkpoint_prefix experiment/training/model.ckpt-20000 --output_directory output/models
python CellExtractor.py
python VGG16Trainer.py
python inference.py --imagePath "test.png"
preprocessing.py
The annotations are provided as JSON, so this script creates the annotated CSV files required by the TensorFlow Object Detection API.
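The conversion can be sketched as below. The JSON layout is an assumption (a list of records, each with an `image.pathname` and a list of `objects` carrying a `bounding_box` with `minimum`/`maximum` corners in row `r` and column `c`), the CSV column names are hypothetical, and the sample values are made up for illustration. The actual script may differ.

```python
import csv
import io
import json

FIELDS = ["filename", "xmin", "ymin", "xmax", "ymax", "category"]


def annotations_to_rows(records):
    """Flatten per-image annotation records into one row per cell."""
    rows = []
    for record in records:
        filename = record["image"]["pathname"]
        for obj in record["objects"]:
            box = obj["bounding_box"]
            rows.append({
                "filename": filename,
                "xmin": box["minimum"]["c"],  # column index -> x
                "ymin": box["minimum"]["r"],  # row index -> y
                "xmax": box["maximum"]["c"],
                "ymax": box["maximum"]["r"],
                "category": obj["category"],
            })
    return rows


# Made-up sample in the assumed layout.
sample_json = """
[
  {
    "image": {"pathname": "/images/example.png"},
    "objects": [
      {"bounding_box": {"minimum": {"r": 10, "c": 20},
                        "maximum": {"r": 110, "c": 140}},
       "category": "red blood cell"},
      {"bounding_box": {"minimum": {"r": 200, "c": 300},
                        "maximum": {"r": 320, "c": 410}},
       "category": "ring"}
    ]
  }
]
"""

rows = annotations_to_rows(json.loads(sample_json))

# Write to an in-memory buffer; a real script would open a file instead.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=FIELDS)
writer.writeheader()
writer.writerows(rows)
csv_text = buffer.getvalue()
print(csv_text)
```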
TFAnnotation.py
This class defines the data format required by the TensorFlow Object Detection API.
TFODRecordCreator.py
This script uses the TFAnnotation class to create the .record and classes files that the TensorFlow Object Detection API requires to train the Faster R-CNN model.
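The record format used by the TensorFlow Object Detection API stores box coordinates normalized by the image width and height. A minimal sketch of that normalization step, without TensorFlow, is shown below; the function name `normalize_box` and the validation rules are assumptions, not the repository's code.

```python
def normalize_box(xmin, ymin, xmax, ymax, width, height):
    """Scale pixel box coordinates to the [0, 1] range.

    Raises ValueError for an empty image or a box that is degenerate or
    falls outside the image, since such boxes would corrupt the record.
    """
    if width <= 0 or height <= 0:
        raise ValueError("image width and height must be positive")
    if not (0 <= xmin < xmax <= width and 0 <= ymin < ymax <= height):
        raise ValueError("box is degenerate or outside the image")
    return (xmin / width, ymin / height, xmax / width, ymax / height)


# A 1600x1200 image with one box given in pixels.
print(normalize_box(400, 300, 800, 600, width=1600, height=1200))
# -> (0.25, 0.25, 0.5, 0.5)
```

Validating boxes before writing them makes bad annotations fail early rather than during training.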
plotting.py
Script for plotting the RBC and NON-RBC bounding boxes on images.
CellExtractor.py
Script for cropping the NON-RBC cells out of the main image using their bounding boxes. The crops are used to train the VGG-16 model.
VGG16Trainer.py
Script for fine-tuning the pretrained VGG-16 model on the cropped NON-RBC images.
trainDataInference.py, testDataInference.py
Scripts for evaluating the two-stage classifier on the training and test images.
inference.py
Script that runs the two-stage model on a blood sample image and displays the input image with the label and bounding box of each cell.
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.