This project is inspired by @aarcosg's traffic sign detection project. Its purpose is to demonstrate how to use the TensorFlow Object Detection API to train a model on your own dataset, and how to run inference with the trained model.
The main steps are:
- Set up the environment.
- Prepare the dataset.
- Train the model.
- Run inference.
- If you have used the TensorFlow Object Detection API before, you are all set.
- Otherwise, download the TensorFlow Object Detection API from here, then follow the instructions here to finish the installation.
- Download the faster_rcnn_resnet50_coco pre-trained model from the model zoo and put it under this project.
- Then modify configs/faster_rcnn_resnet50_coco_autti.config to update the weights location.
The dataset used in this example is from Autti; you can download it from here.
The dataset's annotations come as a .csv file. You first have to convert it into a gt.txt file. I have written a script to do this.
python scripts/create_gttxt.py
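For readers adapting this step to another dataset, the conversion can be sketched as below. This is a minimal sketch and not the code in scripts/create_gttxt.py: the CSV column names (frame, xmin, ymin, xmax, ymax, label), the semicolon-separated output layout, and the class_ids mapping are all assumptions made for illustration.

```python
import csv
import io


def csv_rows_to_gt_lines(csv_text, class_ids):
    """Convert annotation CSV text into semicolon-separated gt.txt lines.

    Assumed (hypothetical) CSV columns: frame, xmin, ymin, xmax, ymax, label.
    Assumed output layout: frame;xmin;ymin;xmax;ymax;class_id
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    lines = []
    for row in reader:
        label = row["label"]
        if label not in class_ids:
            continue  # skip labels that have no class id assigned
        lines.append(";".join([
            row["frame"],
            row["xmin"], row["ymin"], row["xmax"], row["ymax"],
            str(class_ids[label]),
        ]))
    return lines


# Tiny in-memory example; a real run would read the dataset's .csv file.
sample = (
    "frame,xmin,ymin,xmax,ymax,label\n"
    "a.jpg,10,20,30,40,car\n"
    "a.jpg,5,6,7,8,unknown\n"
)
gt_lines = csv_rows_to_gt_lines(sample, {"car": 1})
print("\n".join(gt_lines))  # prints: a.jpg;10;20;30;40;1
```

Rows whose label is missing from class_ids are dropped silently here; for a real dataset you may prefer to log or count them.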
More scripts are provided in the scripts folder. Currently supported datasets:
- Autti
- LISA traffic light dataset
- Tsinghua-Tencent 100K traffic sign dataset
The TensorFlow Object Detection API trains from TFRecord files, which bundle the annotations and the image files together. scripts/create_tfrecord.py does this conversion. Please adjust the settings in that file to match your own setup. After running it, you will find train.record and val.record in your dataset folder. Update configs/faster_rcnn_resnet50_coco_autti.config again to point to the new locations.
python scripts/create_tfrecord.py --data_dir dataset/autti --output_dir dataset/autti --label_map_path configs/autti_label_map.pbtxt
Pay attention to these settings in create_tfrecord.py.
num_total = 12
num_train = 11
num_classes = 5
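The file passed via --label_map_path above maps each class name to an integer id, and the number of entries should agree with num_classes. As a minimal sketch, a label map in the API's text format could be generated like this. The five class names are assumed for illustration and must be replaced with the labels actually used in your dataset; build_label_map is a hypothetical helper, not part of this project or of the API.

```python
def build_label_map(class_names):
    """Render class names as label-map text.

    Ids start at 1 because id 0 is reserved for the background class.
    """
    items = []
    for idx, name in enumerate(class_names, start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (idx, name))
    return "\n".join(items)


# Assumed class names, for illustration only.
classes = ["car", "truck", "pedestrian", "biker", "trafficLight"]
label_map_text = build_label_map(classes)
print(label_map_text)
```

Keeping the class list in one place and generating the file from it makes it harder for num_classes and the label map to drift apart.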
You now have everything needed to train the network. The first step is to train the model with the API's train.py script. Change the path to train.py in the command below to match your own installation.
python ~/Github/models/research/object_detection/legacy/train.py --logtostderr --pipeline_config_path configs/faster_rcnn_resnet50_coco_autti.config --train_dir dataset/autti/ckpt/
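After that, you will find checkpoint files in the ckpt folder. The last step is to export a checkpoint to usable weight files. Make sure the checkpoint index in --trained_checkpoint_prefix (1905 in the command below) matches a checkpoint that actually exists in your ckpt folder.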
python ~/Github/models/research/object_detection/export_inference_graph.py --input_type image_tensor --pipeline_config_path configs/faster_rcnn_resnet50_coco_autti.config --trained_checkpoint_prefix dataset/autti/ckpt/model.ckpt-1905 --output_directory dataset/autti/exported
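To avoid typing a checkpoint index that does not exist, the highest saved step can be looked up before exporting. The following is a minimal sketch, assuming checkpoints are saved as model.ckpt-N.index files alongside the matching .meta and .data files; latest_checkpoint_step is a hypothetical helper, not part of the API.

```python
import os
import re
import tempfile


def latest_checkpoint_step(ckpt_dir):
    """Return the highest N among files named model.ckpt-N.index, or None."""
    pattern = re.compile(r"^model\.ckpt-(\d+)\.index$")
    steps = []
    for name in os.listdir(ckpt_dir):
        match = pattern.match(name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else None


# Demo on a throwaway folder; point this at dataset/autti/ckpt in practice.
demo_dir = tempfile.mkdtemp()
for fname in ["checkpoint", "model.ckpt-300.index",
              "model.ckpt-1905.index", "model.ckpt-1905.meta"]:
    open(os.path.join(demo_dir, fname), "w").close()

step = latest_checkpoint_step(demo_dir)
print("model.ckpt-%d" % step)  # prints: model.ckpt-1905
```

The printed prefix can then be appended to the checkpoint folder path and passed as --trained_checkpoint_prefix.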
I have put everything related to inference in the ref.py file. Please modify the paths in the script first, then just run it.
python ref.py
TensorBoard logging is enabled by default. Use the following commands to start TensorBoard.
tensorboard --logdir=tf-log
tensorboard --logdir=dataset/ckpt
Congratulations, you have finished all the work!
First things first: read the instructions again here.
When running create_tfrecord.py, you get the error ImportError: cannot import name 'string_int_label_map_pb2'
Make sure you have finished the protobuf-compiler installation and have compiled the API's .proto files with it.
When running create_tfrecord.py, you get the error ImportError: No module named 'object_detection'
Make sure you have configured PYTHONPATH so that it includes the API's research folder.
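A quick way to check the second problem before running any script is to ask Python whether it can locate the package. This is a minimal sketch; is_importable is a hypothetical helper name.

```python
import importlib.util


def is_importable(module_name):
    """Return True if Python can locate the given top-level module."""
    return importlib.util.find_spec(module_name) is not None


# False here means PYTHONPATH does not yet point at the API's research folder.
print("object_detection importable:", is_importable("object_detection"))
```

The helper only checks top-level names such as object_detection; it does not verify that the compiled protobuf modules inside the package exist.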
This project is inspired by @aarcosg's traffic sign detection project; some of the code in ref.py and scripts/create_tfrecord.py is taken directly from their work. Thanks!