[PaddlePaddle Hackathon] add GoogLeNet #36034

Merged (1 commit, Nov 1, 2021)
2 changes: 1 addition & 1 deletion python/paddle/tests/test_pretrained_model.py
@@ -54,7 +54,7 @@ def infer(self, arch):
def test_models(self):
arches = [
'mobilenet_v1', 'mobilenet_v2', 'resnet18', 'vgg16', 'alexnet',
'resnext50_32x4d', 'inception_v3', 'densenet121'
'resnext50_32x4d', 'inception_v3', 'densenet121', 'googlenet'
]
for arch in arches:
self.infer(arch)
3 changes: 3 additions & 0 deletions python/paddle/tests/test_vision_models.py
@@ -109,6 +109,9 @@ def test_resnext152_64x4d(self):
def test_inception_v3(self):
self.models_infer('inception_v3')

def test_googlenet(self):
self.models_infer('googlenet')

def test_vgg16_num_classes(self):
vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)

2 changes: 2 additions & 0 deletions python/paddle/vision/__init__.py
@@ -61,6 +61,8 @@
from .models import resnext152_64x4d # noqa: F401
from .models import InceptionV3 # noqa: F401
from .models import inception_v3 # noqa: F401
from .models import GoogLeNet # noqa: F401
from .models import googlenet # noqa: F401
from .transforms import BaseTransform # noqa: F401
from .transforms import Compose # noqa: F401
from .transforms import Resize # noqa: F401
6 changes: 5 additions & 1 deletion python/paddle/vision/models/__init__.py
@@ -45,6 +45,8 @@
from .resnext import resnext152_64x4d # noqa: F401
from .inceptionv3 import InceptionV3 # noqa: F401
from .inceptionv3 import inception_v3 # noqa: F401
from .googlenet import GoogLeNet # noqa: F401
from .googlenet import googlenet # noqa: F401

__all__ = [ #noqa
'ResNet',
@@ -79,5 +81,7 @@
'resnext152_32x4d',
'resnext152_64x4d',
'InceptionV3',
'inception_v3'
'inception_v3',
'GoogLeNet',
'googlenet',
]
254 changes: 254 additions & 0 deletions python/paddle/vision/models/googlenet.py
@@ -0,0 +1,254 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddle.nn import Conv2D, Linear, Dropout
from paddle.nn import MaxPool2D, AvgPool2D, AdaptiveAvgPool2D
from paddle.nn.initializer import Uniform
from paddle.fluid.param_attr import ParamAttr
from paddle.utils.download import get_weights_path_from_url

__all__ = []

model_urls = {
"googlenet":
("https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GoogLeNet_pretrained.pdparams",
"80c06f038e905c53ab32c40eca6e26ae")
}


def xavier(channels, filter_size):
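    # Xavier/Glorot-style uniform bound: stdv = sqrt(3 / fan_in),
    # where fan_in = filter_size**2 * channels.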
stdv = (3.0 / (filter_size**2 * channels))**0.5
param_attr = ParamAttr(initializer=Uniform(-stdv, stdv))
return param_attr


class ConvLayer(nn.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1):
super(ConvLayer, self).__init__()

self._conv = Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=False)

def forward(self, inputs):
y = self._conv(inputs)
return y


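# Inception block: four parallel branches (1x1 conv; 1x1 -> 3x3 conv; 1x1 -> 5x5 conv;
# 3x3 max-pool -> 1x1 projection) concatenated along the channel axis, followed by a
# single ReLU on the concatenated result.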
class Inception(nn.Layer):
def __init__(self, input_channels, output_channels, filter1, filter3R,
filter3, filter5R, filter5, proj):
super(Inception, self).__init__()

self._conv1 = ConvLayer(input_channels, filter1, 1)
self._conv3r = ConvLayer(input_channels, filter3R, 1)
self._conv3 = ConvLayer(filter3R, filter3, 3)
self._conv5r = ConvLayer(input_channels, filter5R, 1)
self._conv5 = ConvLayer(filter5R, filter5, 5)
self._pool = MaxPool2D(kernel_size=3, stride=1, padding=1)

self._convprj = ConvLayer(input_channels, proj, 1)

def forward(self, inputs):
conv1 = self._conv1(inputs)

conv3r = self._conv3r(inputs)
conv3 = self._conv3(conv3r)

conv5r = self._conv5r(inputs)
conv5 = self._conv5(conv5r)

pool = self._pool(inputs)
convprj = self._convprj(pool)

cat = paddle.concat([conv1, conv3, conv5, convprj], axis=1)
cat = F.relu(cat)
return cat


class GoogLeNet(nn.Layer):
"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <https://arxiv.org/pdf/1409.4842.pdf>`_

Args:
        num_classes (int): output dimension of the last fc layer. If num_classes <= 0, the last fc
                            layer will not be defined. Default: 1000.
        with_pool (bool, optional): whether to use a pooling layer before the last fc layer. Default: True.

Examples:
    [Review comment from a Contributor: the English docstring example should stay consistent with the Chinese documentation example.]
.. code-block:: python

import paddle
from paddle.vision.models import GoogLeNet

# build model
model = GoogLeNet()

x = paddle.rand([1, 3, 224, 224])
out, out1, out2 = model(x)

print(out.shape)
"""

def __init__(self, num_classes=1000, with_pool=True):
super(GoogLeNet, self).__init__()
self.num_classes = num_classes
self.with_pool = with_pool

self._conv = ConvLayer(3, 64, 7, 2)
self._pool = MaxPool2D(kernel_size=3, stride=2)
self._conv_1 = ConvLayer(64, 64, 1)
self._conv_2 = ConvLayer(64, 192, 3)

self._ince3a = Inception(192, 192, 64, 96, 128, 16, 32, 32)
self._ince3b = Inception(256, 256, 128, 128, 192, 32, 96, 64)

self._ince4a = Inception(480, 480, 192, 96, 208, 16, 48, 64)
self._ince4b = Inception(512, 512, 160, 112, 224, 24, 64, 64)
self._ince4c = Inception(512, 512, 128, 128, 256, 24, 64, 64)
self._ince4d = Inception(512, 512, 112, 144, 288, 32, 64, 64)
self._ince4e = Inception(528, 528, 256, 160, 320, 32, 128, 128)

self._ince5a = Inception(832, 832, 256, 160, 320, 32, 128, 128)
self._ince5b = Inception(832, 832, 384, 192, 384, 48, 128, 128)

if with_pool:
# out
self._pool_5 = AdaptiveAvgPool2D(1)
# out1
self._pool_o1 = AvgPool2D(kernel_size=5, stride=3)
# out2
self._pool_o2 = AvgPool2D(kernel_size=5, stride=3)

if num_classes > 0:
# out
self._drop = Dropout(p=0.4, mode="downscale_in_infer")
self._fc_out = Linear(
1024, num_classes, weight_attr=xavier(1024, 1))

# out1
self._conv_o1 = ConvLayer(512, 128, 1)
self._fc_o1 = Linear(1152, 1024, weight_attr=xavier(2048, 1))
self._drop_o1 = Dropout(p=0.7, mode="downscale_in_infer")
self._out1 = Linear(1024, num_classes, weight_attr=xavier(1024, 1))

# out2
self._conv_o2 = ConvLayer(528, 128, 1)
self._fc_o2 = Linear(1152, 1024, weight_attr=xavier(2048, 1))
self._drop_o2 = Dropout(p=0.7, mode="downscale_in_infer")
self._out2 = Linear(1024, num_classes, weight_attr=xavier(1024, 1))

def forward(self, inputs):
x = self._conv(inputs)
x = self._pool(x)
x = self._conv_1(x)
x = self._conv_2(x)
x = self._pool(x)

x = self._ince3a(x)
x = self._ince3b(x)
x = self._pool(x)

ince4a = self._ince4a(x)
x = self._ince4b(ince4a)
x = self._ince4c(x)
ince4d = self._ince4d(x)
x = self._ince4e(ince4d)
x = self._pool(x)

x = self._ince5a(x)
ince5b = self._ince5b(x)

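        # the main head taps the last inception block; the two auxiliary heads
        # tap the outputs of inception blocks 4a and 4d, as in the original paper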
out, out1, out2 = ince5b, ince4a, ince4d

if self.with_pool:
out = self._pool_5(out)
out1 = self._pool_o1(out1)
out2 = self._pool_o2(out2)

if self.num_classes > 0:
out = self._drop(out)
out = paddle.squeeze(out, axis=[2, 3])
out = self._fc_out(out)

out1 = self._conv_o1(out1)
out1 = paddle.flatten(out1, start_axis=1, stop_axis=-1)
out1 = self._fc_o1(out1)
out1 = F.relu(out1)
out1 = self._drop_o1(out1)
out1 = self._out1(out1)

out2 = self._conv_o2(out2)
out2 = paddle.flatten(out2, start_axis=1, stop_axis=-1)
out2 = self._fc_o2(out2)
out2 = self._drop_o2(out2)
out2 = self._out2(out2)

return [out, out1, out2]


def googlenet(pretrained=False, **kwargs):
"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <https://arxiv.org/pdf/1409.4842.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet

Examples:
    [Review comment from a Contributor: the English docstring example should stay consistent with the Chinese documentation example.
    Reply from the Member Author: updated.]

.. code-block:: python

import paddle
from paddle.vision.models import googlenet

# build model
model = googlenet()

# build model and load imagenet pretrained weight
# model = googlenet(pretrained=True)

x = paddle.rand([1, 3, 224, 224])
out, out1, out2 = model(x)

print(out.shape)
"""
model = GoogLeNet(**kwargs)
arch = "googlenet"
if pretrained:
assert (
arch in model_urls
), "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])

param = paddle.load(weight_path)
model.set_dict(param)
return model
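
For readers of this diff: the model's forward pass returns the main output together with the two auxiliary classifier outputs. Below is a minimal, illustrative training-loss sketch (not part of this PR), assuming dummy data and the 0.3 auxiliary-loss weighting used in the original paper:

import paddle
import paddle.nn.functional as F
from paddle.vision.models import googlenet

# dummy batch and integer class labels, for illustration only
model = googlenet()
x = paddle.rand([8, 3, 224, 224])
labels = paddle.randint(0, 1000, [8])

out, out1, out2 = model(x)  # main head plus the two auxiliary heads
loss = (F.cross_entropy(out, labels)
        + 0.3 * F.cross_entropy(out1, labels)
        + 0.3 * F.cross_entropy(out2, labels))
loss.backward()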