From 08b441af62c281d52d5e84b4056a453a830d7c9b Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Wed, 20 Dec 2017 15:51:47 +0800 Subject: [PATCH 1/2] Add xception model for image classification --- image_classification/infer.py | 7 +- image_classification/train.py | 7 +- image_classification/xception.py | 193 +++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+), 2 deletions(-) create mode 100644 image_classification/xception.py diff --git a/image_classification/infer.py b/image_classification/infer.py index 659c4f2a8e..1ae5da2c86 100644 --- a/image_classification/infer.py +++ b/image_classification/infer.py @@ -26,7 +26,10 @@ def main(): parser.add_argument( 'model', help='The model for image classification', - choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet']) + choices=[ + 'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet', + 'xception' + ]) parser.add_argument( 'params_path', help='The file which stores the parameters') args = parser.parse_args() @@ -49,6 +52,8 @@ def main(): out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) elif args.model == 'googlenet': out, _, _ = googlenet.googlenet(image, class_dim=CLASS_DIM) + elif args.model == 'xception': + out = xception.xception(image, class_dim=CLASS_DIM) # load parameters with gzip.open(args.params_path, 'r') as f: diff --git a/image_classification/train.py b/image_classification/train.py index 12a582db3a..c45eed7727 100644 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -19,7 +19,10 @@ def main(): parser.add_argument( 'model', help='The model for image classification', - choices=['alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet']) + choices=[ + 'alexnet', 'vgg13', 'vgg16', 'vgg19', 'resnet', 'googlenet', + 'xception' + ]) args = parser.parse_args() # PaddlePaddle init @@ -52,6 +55,8 @@ def main(): input=out2, label=lbl, coeff=0.3) paddle.evaluator.classification_error(input=out2, label=lbl) extra_layers = [loss1, loss2] + elif 
args.model == 'xception': + out = xception.xception(image, class_dim=CLASS_DIM) cost = paddle.layer.classification_cost(input=out, label=lbl) diff --git a/image_classification/xception.py b/image_classification/xception.py new file mode 100644 index 0000000000..41c11b8353 --- /dev/null +++ b/image_classification/xception.py @@ -0,0 +1,193 @@ +import paddle.v2 as paddle + +__all__ = ['xception'] + + +def img_separable_conv_bn(name, input, num_channels, num_out_channels, + filter_size, stride, padding, act): + conv = paddle.networks.img_separable_conv( + name=name, + input=input, + num_channels=num_channels, + num_out_channels=num_out_channels, + filter_size=filter_size, + stride=stride, + padding=padding, + act=paddle.activation.Linear()) + norm = paddle.layer.batch_norm(name=name + '_norm', input=conv, act=act) + return norm + + +def img_conv_bn(name, input, num_channels, num_filters, filter_size, stride, + padding, act): + conv = paddle.layer.img_conv( + name=name, + input=input, + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + act=paddle.activation.Linear()) + norm = paddle.layer.batch_norm(name=name + '_norm', input=conv, act=act) + return norm + + +def conv_block0(input, + group, + num_channels, + num_filters, + num_filters2=None, + filter_size=3, + pool_padding=0, + entry_relu=True): + if num_filters2 is None: + num_filters2 = num_filters + + if entry_relu: + act_input = paddle.layer.mixed( + input=paddle.layer.identity_projection(input=input), + act=paddle.activation.Relu()) + else: + act_input = input + conv0 = img_separable_conv_bn( + name='xception_block{0}_conv0'.format(group), + input=act_input, + num_channels=num_channels, + num_out_channels=num_filters, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) / 2, + act=paddle.activation.Relu()) + conv1 = img_separable_conv_bn( + name='xception_block{0}_conv1'.format(group), + input=conv0, + num_channels=num_filters, + 
num_out_channels=num_filters2, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) / 2, + act=paddle.activation.Linear()) + pool0 = paddle.layer.img_pool( + name='xception_block{0}_pool'.format(group), + input=conv1, + pool_size=3, + stride=2, + padding=pool_padding, + num_channels=num_filters2, + pool_type=paddle.pooling.CudnnMax()) + + shortcut = img_conv_bn( + name='xception_block{0}_shortcut'.format(group), + input=input, + num_channels=num_channels, + num_filters=num_filters2, + filter_size=1, + stride=2, + padding=0, + act=paddle.activation.Linear()) + + return paddle.layer.addto( + input=[pool0, shortcut], act=paddle.activation.Linear()) + + +def conv_block1(input, group, num_channels, num_filters, filter_size=3): + act_input = paddle.layer.mixed( + input=paddle.layer.identity_projection(input=input), + act=paddle.activation.Relu()) + conv0 = img_separable_conv_bn( + name='xception_block{0}_conv0'.format(group), + input=act_input, + num_channels=num_channels, + num_out_channels=num_filters, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) / 2, + act=paddle.activation.Relu()) + conv1 = img_separable_conv_bn( + name='xception_block{0}_conv1'.format(group), + input=conv0, + num_channels=num_filters, + num_out_channels=num_filters, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) / 2, + act=paddle.activation.Relu()) + conv2 = img_separable_conv_bn( + name='xception_block{0}_conv2'.format(group), + input=conv1, + num_channels=num_filters, + num_out_channels=num_filters, + filter_size=filter_size, + stride=1, + padding=(filter_size - 1) / 2, + act=paddle.activation.Linear()) + + shortcut = input + return paddle.layer.addto( + input=[conv2, shortcut], act=paddle.activation.Linear()) + + +def xception(input, class_dim): + conv = img_conv_bn( + name='xception_conv0', + input=input, + num_channels=3, + num_filters=32, + filter_size=3, + stride=2, + padding=1, + act=paddle.activation.Relu()) + conv = img_conv_bn( + 
name='xception_conv1', + input=conv, + num_channels=32, + num_filters=64, + filter_size=3, + stride=1, + padding=1, + act=paddle.activation.Relu()) + conv = conv_block0( + input=conv, group=2, num_channels=64, num_filters=128, entry_relu=False) + conv = conv_block0(input=conv, group=3, num_channels=128, num_filters=256) + conv = conv_block0(input=conv, group=4, num_channels=256, num_filters=728) + for group in range(5, 13): + conv = conv_block1( + input=conv, group=group, num_channels=728, num_filters=728) + conv = conv_block0( + input=conv, + group=13, + num_channels=728, + num_filters=728, + num_filters2=1024) + conv = img_separable_conv_bn( + name='xception_conv14', + input=conv, + num_channels=1024, + num_out_channels=1536, + filter_size=3, + stride=1, + padding=1, + act=paddle.activation.Relu()) + conv = img_separable_conv_bn( + name='xception_conv15', + input=conv, + num_channels=1536, + num_out_channels=2048, + filter_size=3, + stride=1, + padding=1, + act=paddle.activation.Relu()) + pool = paddle.layer.img_pool( + name='xception_global_pool', + input=conv, + pool_size=7, + stride=1, + num_channels=2048, + pool_type=paddle.pooling.CudnnAvg()) + out = paddle.layer.fc( + name='xception_fc', + input=pool, + size=class_dim, + act=paddle.activation.Softmax()) + return out From 882ec2f94c21a8fe92a5de917299b195881b0c00 Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Wed, 20 Dec 2017 15:52:10 +0800 Subject: [PATCH 2/2] Add docs for xception model --- README.cn.md | 3 ++- README.md | 3 ++- image_classification/README.md | 10 +++++++++- image_classification/infer.py | 1 + image_classification/train.py | 1 + 5 files changed, 15 insertions(+), 3 deletions(-) diff --git a/README.cn.md b/README.cn.md index 9491690e3d..87bc575c95 100644 --- a/README.cn.md +++ b/README.cn.md @@ -98,12 +98,13 @@ PaddlePaddle提供了丰富的运算单元,帮助大家以模块化的方式 
图像相比文字能够提供更加生动、容易理解及更具艺术感的信息,是人们转递与交换信息的重要来源。图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉中重要的基础问题,也是图像检测、图像分割、物体跟踪、行为分析等其他高层视觉任务的基础,在许多领域都有着广泛的应用。如:安防领域的人脸识别和智能视频分析等,交通领域的交通场景识别,互联网领域基于内容的图像检索和相册自动归类,医学领域的图像识别等。 -在图像分类任务中,我们向大家介绍如何训练AlexNet、VGG、GoogLeNet和ResNet模型。同时提供了一个够将Caffe训练好的模型文件转换为PaddlePaddle模型文件的模型转换工具。 +在图像分类任务中,我们向大家介绍如何训练AlexNet、VGG、GoogLeNet、ResNet和Xception模型。同时提供了一个能够将Caffe训练好的模型文件转换为PaddlePaddle模型文件的模型转换工具。 - 11.1 [将Caffe模型文件转换为PaddlePaddle模型文件](https://github.com/PaddlePaddle/models/tree/develop/image_classification/caffe2paddle) - 11.2 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 11.3 [VGG](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 11.4 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 11.5 [Xception](https://github.com/PaddlePaddle/models/tree/develop/image_classification) ## 12. 目标检测 diff --git a/README.md b/README.md index 8b938a30dc..9fe7121a55 100644 --- a/README.md +++ b/README.md @@ -72,11 +72,12 @@ As an example for sequence-to-sequence learning, we take the machine translation ## 9. Image classification -For the example of image classification, we show you how to train AlexNet, VGG, GoogLeNet and ResNet models in PaddlePaddle. It also provides a model conversion tool that converts Caffe trained model files into PaddlePaddle model files. +For the example of image classification, we show you how to train AlexNet, VGG, GoogLeNet, ResNet and Xception models in PaddlePaddle. It also provides a model conversion tool that converts Caffe trained model files into PaddlePaddle model files.
- 9.1 [convert Caffe model file to PaddlePaddle model file](https://github.com/PaddlePaddle/models/tree/develop/image_classification/caffe2paddle) - 9.2 [AlexNet](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 9.3 [VGG](https://github.com/PaddlePaddle/models/tree/develop/image_classification) - 9.4 [Residual Network](https://github.com/PaddlePaddle/models/tree/develop/image_classification) +- 9.5 [Xception](https://github.com/PaddlePaddle/models/tree/develop/image_classification) This tutorial is contributed by [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) and licensed under the [Apache-2.0 license](LICENSE). diff --git a/image_classification/README.md b/image_classification/README.md index 843d683c00..caf77ca2cb 100644 --- a/image_classification/README.md +++ b/image_classification/README.md @@ -1,7 +1,7 @@ 图像分类 ======================= -这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet和ResNet模型进行图像分类。图像分类问题的描述和这四种模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 +这里将介绍如何在PaddlePaddle下使用AlexNet、VGG、GoogLeNet、ResNet和Xception模型进行图像分类。图像分类问题的描述和这些模型的介绍可以参考[PaddlePaddle book](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification)。 ## 训练模型 @@ -86,6 +86,14 @@ ResNet模型可以通过下面的代码获取: out = resnet.resnet_imagenet(image, class_dim=CLASS_DIM) ``` +5. 
使用Xception模型 + +Xception模型可以通过下面的代码获取: + +```python +out = xception.xception(image, class_dim=CLASS_DIM) +``` + ### 定义损失函数 ```python diff --git a/image_classification/infer.py b/image_classification/infer.py index 1ae5da2c86..de072e3ae0 100644 --- a/image_classification/infer.py +++ b/image_classification/infer.py @@ -5,6 +5,7 @@ import resnet import alexnet import googlenet +import xception import argparse import os from PIL import Image diff --git a/image_classification/train.py b/image_classification/train.py index c45eed7727..dc374e9a54 100644 --- a/image_classification/train.py +++ b/image_classification/train.py @@ -6,6 +6,7 @@ import resnet import alexnet import googlenet +import xception import argparse DATA_DIM = 3 * 224 * 224