Skip to content

Commit

Permalink
COCO dataset for SSD and update README.md (#844)
Browse files Browse the repository at this point in the history
* ready to coco_reader

* complete coco_reader.py & coco_train.py

* complete coco reader

* rename file

* use argparse instead of explicit assignment

* fix

* fix reader bug for some gray image in coco data

* ready to train coco

* fix bug in test()

* fix bug in test()

* change coco dataset to coco2017 dataset

* change dataset from coco to coco2017

* change learning rate

* fix bug in gt label (category id 2 label)

* fix bug in background label

* save model when train finished

* use coco map

* adding coco year version args: 2014 or 2017

* add coco dataset download, and README.md

* fix

* fix image truncted IOError, map version error

* add test config

* add eval.py for evaluate trained model 

* fix

* fix bug when cocoMAP

* updata READEME.md

* fix cocoMAP bug

* find strange with test_program = fluid.default_main_program().clone(for_test=True)

* add inference and visualize, awa, README.md

* upload infer&visual example image

* refine image

* refine

* fix bug after merge

* follow yapf

* follow comments

* fix bug after separate eval and eval_cocoMAP

* follow yapf

* follow comments

* follow yapf

* follow yapf
  • Loading branch information
sefira authored and qingqing01 committed Apr 27, 2018
1 parent 806cff7 commit 496ff37
Show file tree
Hide file tree
Showing 13 changed files with 462 additions and 141 deletions.
43 changes: 36 additions & 7 deletions fluid/object_detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The minimum PaddlePaddle version needed for the code sample in this directory is

### Introduction

[Single Shot MultiBox Detector (SSD)](https://arxiv.org/abs/1512.02325) framework for object detection is based on a feed-forward convolutional network. The early network is a standard convolutional architecture for image classification, such as VGG, ResNet, or MobileNet, which is als called base network. In this tutorial we used [MobileNet](https://arxiv.org/abs/1704.04861).
[Single Shot MultiBox Detector (SSD)](https://arxiv.org/abs/1512.02325) framework for object detection is based on a feed-forward convolutional network. The early network is a standard convolutional architecture for image classification, such as VGG, ResNet, or MobileNet, which is also called base network. In this tutorial we used [MobileNet](https://arxiv.org/abs/1704.04861).

### Data Preparation

Expand Down Expand Up @@ -52,39 +52,68 @@ Declaration: the MobileNet-v1 SSD model is converted by [TensorFlow model](https
#### Train on PASCAL VOC
- Train on one device (/GPU).
```python
env CUDA_VISIABLE_DEVICES=0 python -u train.py --parallel=False --data='pascalvoc' --pretrained_model='pretrained/ssd_mobilenet_v1_coco/'
env CUDA_VISIABLE_DEVICES=0 python -u train.py --parallel=False --dataset='pascalvoc' --pretrained_model='pretrained/ssd_mobilenet_v1_coco/'
```
- Train on multi devices (/GPUs).

```python
env CUDA_VISIABLE_DEVICES=0,1 python -u train.py --batch_size=64 --data='pascalvoc' --pretrained_model='pretrained/ssd_mobilenet_v1_coco/'
env CUDA_VISIABLE_DEVICES=0,1 python -u train.py --batch_size=64 --dataset='pascalvoc' --pretrained_model='pretrained/ssd_mobilenet_v1_coco/'
```

#### Train on MS-COCO
- Train on one device (/GPU).
```python
env CUDA_VISIABLE_DEVICES=0 python -u train.py --parallel=False --data='coco' --pretrained_model='pretrained/mobilenet_imagenet/'
env CUDA_VISIABLE_DEVICES=0 python -u train.py --parallel=False --dataset='coco2014' --pretrained_model='pretrained/mobilenet_imagenet/'
```
- Train on multi devices (/GPUs).
```python
env CUDA_VISIABLE_DEVICES=0,1 python -u train.py --batch_size=64 --data='coco' --pretrained_model='pretrained/mobilenet_imagenet/'
env CUDA_VISIABLE_DEVICES=0,1 python -u train.py --batch_size=64 --dataset='coco2014' --pretrained_model='pretrained/mobilenet_imagenet/'
```

TBD

### Evaluate

You can evaluate your trained model in different metric like 11point, integral on both PASCAL VOC and COCO dataset. Moreover, we provide eval_coco_map.py which uses a COCO-specific mAP metric defined by [COCO committee](http://cocodataset.org/#detections-eval). To use this eval_coco_map.py, [cocoapi](https://github.com/cocodataset/cocoapi) is needed.
Install the cocoapi:
```
# COCOAPI=/path/to/clone/cocoapi
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
cd $COCOAPI/PythonAPI
# Install into global site-packages
make install
# Alternatively, if you do not have permissions or prefer
# not to install the COCO API into global site-packages
python2 setup.py install --user
```
Note we set the defualt test list to the dataset's test/val list, you can use your own test list by setting test_list args.

#### Evaluate on PASCAL VOC
```python
env CUDA_VISIABLE_DEVICES=0 python eval.py --dataset='pascalvoc' --model_dir='train_pascal_model/90' --data_dir='data/pascalvoc' --test_list='test.txt' --ap_version='11point'
```

#### Evaluate on MS-COCO
```python
env CUDA_VISIABLE_DEVICES=0 python eval.py --model='model/90' --test_list=''
env CUDA_VISIABLE_DEVICES=0 python eval.py --dataset='coco2014' --nms_threshold=0.5 --model_dir='train_coco_model/40' --test_list='annotations/instances_minival2014.json' --ap_version='integral'
env CUDA_VISIABLE_DEVICES=0 python eval_coco_map.py --dataset='coco2017' --nms_threshold=0.5 --model_dir='train_coco_model/40' --test_list='annotations/instances_minival2017.json'
```

TBD

### Infer and Visualize

```python
env CUDA_VISIABLE_DEVICES=0 python infer.py --batch_size=2 --model='model/90' --test_list=''
env CUDA_VISIABLE_DEVICES=0 python infer.py --model_dir='train_coco_model/20' --image_path='./data/coco/val2014/COCO_val2014_000000000139.jpg'
```
Below is the examples after running python infer.py to inference and visualize the model result.
<p align="center">
<img src="images/COCO_val2014_000000000139.jpg" height=300 width=400 hspace='10'/>
<img src="images/COCO_val2014_000000000785.jpg" height=300 width=400 hspace='10'/>
<img src="images/COCO_val2014_000000142324.jpg" height=300 width=400 hspace='10'/>
<img src="images/COCO_val2014_000000144003.jpg" height=300 width=400 hspace='10'/> <br />
MobileNet-SSD300x300 Visualization Examples
</p>

TBD

Expand Down
20 changes: 20 additions & 0 deletions fluid/object_detection/data/coco/download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"

# Download the data.
echo "Downloading..."
wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
# Extract the data.
echo "Extractint..."
unzip train2014.tar
unzip val2014.tar
unzip train2017.tar
unzip val2017.tar
unzip annotations_trainval2014.tar
unzip annotations_trainval2017.tar

102 changes: 59 additions & 43 deletions fluid/object_detection/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,27 @@
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.")
add_arg('dataset', str, 'pascalvoc', "coco2014, coco2017, and pascalvoc.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('data_dir', str, '', "The data root path.")
add_arg('test_list', str, '', "The testing data lists.")
add_arg('label_file', str, '', "The label file, which save the real name and is only used for Pascal VOC.")
add_arg('model_dir', str, '', "The model path.")
add_arg('ap_version', str, '11point', "11point or integral")
add_arg('resize_h', int, 300, "The resized image height.")
add_arg('resize_w', int, 300, "The resized image width.")
add_arg('mean_value_B', float, 127.5, "mean value for B channel which will be subtracted") #123.68
add_arg('mean_value_G', float, 127.5, "mean value for G channel which will be subtracted") #116.78
add_arg('mean_value_R', float, 127.5, "mean value for R channel which will be subtracted") #103.94
add_arg('model_dir', str, '', "The model path.")
add_arg('nms_threshold', float, 0.45, "NMS threshold.")
add_arg('ap_version', str, '11point', "integral, 11point.")
add_arg('resize_h', int, 300, "The resized image height.")
add_arg('resize_w', int, 300, "The resized image height.")
add_arg('mean_value_B', float, 127.5, "Mean value for B channel which will be subtracted.") #123.68
add_arg('mean_value_G', float, 127.5, "Mean value for G channel which will be subtracted.") #116.78
add_arg('mean_value_R', float, 127.5, "Mean value for R channel which will be subtracted.") #103.94
# yapf: enable


def eval(args, data_args, test_list, batch_size, model_dir=None):
image_shape = [3, data_args.resize_h, data_args.resize_w]
if data_args.dataset == 'coco':
num_classes = 81
elif data_args.dataset == 'pascalvoc':
if 'coco' in data_args.dataset:
num_classes = 91
elif 'pascalvoc' in data_args.dataset:
num_classes = 21

image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
Expand All @@ -46,61 +46,77 @@ def eval(args, data_args, test_list, batch_size, model_dir=None):

locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
locs, confs, box, box_var, nms_threshold=args.nms_threshold)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box, box_var)
loss = fluid.layers.reduce_sum(loss)

test_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(test_program):
map_eval = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
gt_box,
difficult,
num_classes,
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version=args.ap_version)

place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)

# yapf: disable
if model_dir:

def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))

fluid.io.load_vars(exe, model_dir, predicate=if_exist)

# yapf: enable
test_reader = paddle.batch(
reader.test(data_args, test_list), batch_size=batch_size)
feeder = fluid.DataFeeder(
place=place, feed_list=[image, gt_box, gt_label, difficult])

_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
for idx, data in enumerate(test_reader()):
test_map = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
if idx % 50 == 0:
print("Batch {0}, map {1}".format(idx, test_map[0]))
print("Test model {0}, map {1}".format(model_dir, test_map[0]))
def test():
test_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(test_program):
map_eval = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
gt_box,
difficult,
num_classes,
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version=args.ap_version)

_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
for batch_id, data in enumerate(test_reader()):
test_map = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
if batch_id % 20 == 0:
print("Batch {0}, map {1}".format(batch_id, test_map[0]))
print("Test model {0}, map {1}".format(model_dir, test_map[0]))

test()


if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)

data_dir = 'data/pascalvoc'
test_list = 'test.txt'
label_file = 'label_list'
if 'coco' in args.dataset:
data_dir = './data/coco'
if '2014' in args.dataset:
test_list = 'annotations/instances_minival2014.json'
elif '2017' in args.dataset:
test_list = 'annotations/instances_val2017.json'

data_args = reader.Settings(
dataset=args.dataset,
data_dir=args.data_dir,
label_file=args.label_file,
data_dir=args.data_dir if len(args.data_dir) > 0 else data_dir,
label_file=label_file,
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R])
mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
apply_distort=False,
apply_expand=False,
ap_version=args.ap_version,
toy=0)
eval(
args,
test_list=args.test_list,
data_args=data_args,
test_list=args.test_list if len(args.test_list) > 0 else test_list,
batch_size=args.batch_size,
model_dir=args.model_dir)
Loading

0 comments on commit 496ff37

Please sign in to comment.