diff --git a/python/paddle/tests/test_vision_models.py b/python/paddle/tests/test_vision_models.py index 29e00e73e29b28..89dc73b7b83d25 100644 --- a/python/paddle/tests/test_vision_models.py +++ b/python/paddle/tests/test_vision_models.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +unit tests of all models in paddle.vision.models +""" import unittest import numpy as np @@ -20,18 +23,22 @@ class TestVisonModels(unittest.TestCase): + """ + unit tests of all models in paddle.vision.models + """ def models_infer(self, arch, pretrained=False, batch_norm=False): - + """ + arch: the name of the model to be tested + """ x = np.array(np.random.random((2, 3, 224, 224)), dtype=np.float32) if batch_norm: net = models.__dict__[arch](pretrained=pretrained, batch_norm=True) else: net = models.__dict__[arch](pretrained=pretrained) - + input = InputSpec([None, 3, 224, 224], 'float32', 'image') model = paddle.Model(net, input) model.prepare() - model.predict_batch(x) def test_mobilenetv2_pretrained(self): @@ -70,6 +77,12 @@ def test_resnet101(self): def test_resnet152(self): self.models_infer('resnet152') + def test_wide_resnet50_2(self): + self.models_infer('wide_resnet50_2') + + def test_wide_resnet101_2(self): + self.models_infer('wide_resnet101_2') + def test_densenet121(self): self.models_infer('densenet121') @@ -88,9 +101,6 @@ def test_densenet264(self): def test_alexnet(self): self.models_infer('alexnet') - def test_shufflenetv2_swish(self): - self.models_infer('shufflenet_v2_swish') - def test_resnext50_32x4d(self): self.models_infer('resnext50_32x4d') @@ -112,27 +122,6 @@ def test_resnext152_64x4d(self): def test_inception_v3(self): self.models_infer('inception_v3') - def test_googlenet(self): - self.models_infer('googlenet') - - def test_shufflenetv2_x0_25(self): - self.models_infer('shufflenet_v2_x0_25') - - def test_shufflenetv2_x0_33(self): - self.models_infer('shufflenet_v2_x0_33') - - def test_shufflenetv2_x0_5(self): - self.models_infer('shufflenet_v2_x0_5') - - def test_shufflenetv2_x1_0(self): - self.models_infer('shufflenet_v2_x1_0') - - def test_shufflenetv2_x1_5(self): - self.models_infer('shufflenet_v2_x1_5') - - def test_shufflenetv2_x2_0(self): - self.models_infer('shufflenet_v2_x2_0') - def test_vgg16_num_classes(self): vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10) diff --git a/python/paddle/vision/__init__.py b/python/paddle/vision/__init__.py index 22a42b6d312770..27f0814521e106 100644 --- a/python/paddle/vision/__init__.py +++ b/python/paddle/vision/__init__.py @@ -34,6 +34,8 @@ from .models import resnet50 # noqa: F401 from .models import resnet101 # noqa: F401 from .models import resnet152 # noqa: F401 +from .models import wide_resnet50_2 +from .models import wide_resnet101_2 from .models import MobileNetV1 # noqa: F401 from .models import mobilenet_v1 # noqa: F401 from .models import MobileNetV2 # noqa: F401 diff --git a/python/paddle/vision/models/__init__.py b/python/paddle/vision/models/__init__.py index a66d77fc888681..203473b97ad419 100644 --- a/python/paddle/vision/models/__init__.py +++ b/python/paddle/vision/models/__init__.py @@ -11,13 +11,14 @@ #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #See the License for the specific language governing permissions and #limitations under the License. - from .resnet import ResNet # noqa: F401 from .resnet import resnet18 # noqa: F401 from .resnet import resnet34 # noqa: F401 from .resnet import resnet50 # noqa: F401 from .resnet import resnet101 # noqa: F401 from .resnet import resnet152 # noqa: F401 +from .resnet import wide_resnet50_2 +from .resnet import wide_resnet101_2 from .mobilenetv1 import MobileNetV1 # noqa: F401 from .mobilenetv1 import mobilenet_v1 # noqa: F401 from .mobilenetv2 import MobileNetV2 # noqa: F401 @@ -63,6 +64,8 @@ 'resnet50', 'resnet101', 'resnet152', + 'wide_resnet50_2', + 'wide_resnet101_2', 'VGG', 'vgg11', 'vgg13', diff --git a/python/paddle/vision/models/resnet.py b/python/paddle/vision/models/resnet.py index 5be69c93e8b5f0..18844db28a09ef 100644 --- a/python/paddle/vision/models/resnet.py +++ b/python/paddle/vision/models/resnet.py @@ -1,367 +1,485 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import division -from __future__ import print_function - -import paddle -import paddle.nn as nn - -from paddle.utils.download import get_weights_path_from_url - -__all__ = [] - -model_urls = { - 'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams', - 'cf548f46534aa3560945be4b95cd11c4'), - 'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams', - '8d2275cf8706028345f78ac0e1d31969'), - 'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams', - 'ca6f485ee1ab0492d38f323885b0ad80'), - 'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams', - '02f35f034ca3858e1e54d4036443c92d'), - 'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams', - '7ad16a2f1e7333859ff986138630fd7a'), -} - - -class BasicBlock(nn.Layer): - expansion = 1 - - def __init__(self, - inplanes, - planes, - stride=1, - downsample=None, - groups=1, - base_width=64, - dilation=1, - norm_layer=None): - super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2D - - if dilation > 1: - raise NotImplementedError( - "Dilation > 1 not supported in BasicBlock") - - self.conv1 = nn.Conv2D( - inplanes, planes, 3, padding=1, stride=stride, bias_attr=False) - self.bn1 = norm_layer(planes) - self.relu = nn.ReLU() - self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class BottleneckBlock(nn.Layer): - - expansion = 4 - - def __init__(self, - inplanes, - planes, - stride=1, - downsample=None, - groups=1, - base_width=64, - dilation=1, - norm_layer=None): - super(BottleneckBlock, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2D - width = int(planes * (base_width / 64.)) * groups - - self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False) - self.bn1 = norm_layer(width) - - self.conv2 = nn.Conv2D( - width, - width, - 3, - padding=dilation, - stride=stride, - groups=groups, - dilation=dilation, - bias_attr=False) - self.bn2 = norm_layer(width) - - self.conv3 = nn.Conv2D( - width, planes * self.expansion, 1, bias_attr=False) - self.bn3 = norm_layer(planes * self.expansion) - self.relu = nn.ReLU() - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class ResNet(nn.Layer): - """ResNet model from - `"Deep Residual Learning for Image Recognition" `_ - - Args: - Block (BasicBlock|BottleneckBlock): block module of model. - depth (int): layers of resnet, default: 50. - num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer - will not be defined. Default: 1000. - with_pool (bool): use pool before the last fc layer or not. Default: True. - - Examples: - .. code-block:: python - - from paddle.vision.models import ResNet - from paddle.vision.models.resnet import BottleneckBlock, BasicBlock - - resnet50 = ResNet(BottleneckBlock, 50) - - resnet18 = ResNet(BasicBlock, 18) - - """ - - def __init__(self, block, depth, num_classes=1000, with_pool=True): - super(ResNet, self).__init__() - layer_cfg = { - 18: [2, 2, 2, 2], - 34: [3, 4, 6, 3], - 50: [3, 4, 6, 3], - 101: [3, 4, 23, 3], - 152: [3, 8, 36, 3] - } - layers = layer_cfg[depth] - self.num_classes = num_classes - self.with_pool = with_pool - self._norm_layer = nn.BatchNorm2D - - self.inplanes = 64 - self.dilation = 1 - - self.conv1 = nn.Conv2D( - 3, - self.inplanes, - kernel_size=7, - stride=2, - padding=3, - bias_attr=False) - self.bn1 = self._norm_layer(self.inplanes) - self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2) - if with_pool: - self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) - - if num_classes > 0: - self.fc = nn.Linear(512 * block.expansion, num_classes) - - def _make_layer(self, block, planes, blocks, stride=1, dilate=False): - norm_layer = self._norm_layer - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2D( - self.inplanes, - planes * block.expansion, - 1, - stride=stride, - bias_attr=False), - norm_layer(planes * block.expansion), ) - - layers = [] - layers.append( - block(self.inplanes, planes, stride, downsample, 1, 64, - previous_dilation, norm_layer)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, norm_layer=norm_layer)) - - return nn.Sequential(*layers) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - if self.with_pool: - x = self.avgpool(x) - - if self.num_classes > 0: - x = paddle.flatten(x, 1) - x = self.fc(x) - - return x - - -def _resnet(arch, Block, depth, pretrained, **kwargs): - model = ResNet(Block, depth, **kwargs) - if pretrained: - assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( - arch) - weight_path = get_weights_path_from_url(model_urls[arch][0], - model_urls[arch][1]) - - param = paddle.load(weight_path) - model.set_dict(param) - - return model - - -def resnet18(pretrained=False, **kwargs): - """ResNet 18-layer model - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. code-block:: python - - from paddle.vision.models import resnet18 - - # build model - model = resnet18() - - # build model and load imagenet pretrained weight - # model = resnet18(pretrained=True) - """ - return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs) - - -def resnet34(pretrained=False, **kwargs): - """ResNet 34-layer model - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. code-block:: python - - from paddle.vision.models import resnet34 - - # build model - model = resnet34() - - # build model and load imagenet pretrained weight - # model = resnet34(pretrained=True) - """ - return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs) - - -def resnet50(pretrained=False, **kwargs): - """ResNet 50-layer model - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. code-block:: python - - from paddle.vision.models import resnet50 - - # build model - model = resnet50() - - # build model and load imagenet pretrained weight - # model = resnet50(pretrained=True) - """ - return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs) - - -def resnet101(pretrained=False, **kwargs): - """ResNet 101-layer model - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. code-block:: python - - from paddle.vision.models import resnet101 - - # build model - model = resnet101() - - # build model and load imagenet pretrained weight - # model = resnet101(pretrained=True) - """ - return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs) - - -def resnet152(pretrained=False, **kwargs): - """ResNet 152-layer model - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - - Examples: - .. code-block:: python - - from paddle.vision.models import resnet152 - - # build model - model = resnet152() - - # build model and load imagenet pretrained weight - # model = resnet152(pretrained=True) - """ - return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs) +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is adapted from torchvision.models.resnet +from __future__ import division +from __future__ import print_function + +from typing import Type, Any, Callable, Union, List, Optional +import paddle +import paddle.nn as nn +from paddle import Tensor +from paddle.utils.download import get_weights_path_from_url + +__all__ = [] + +model_urls = { + 'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams', + 'cf548f46534aa3560945be4b95cd11c4'), + 'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams', + '8d2275cf8706028345f78ac0e1d31969'), + 'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams', + 'ca6f485ee1ab0492d38f323885b0ad80'), + 'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams', + '02f35f034ca3858e1e54d4036443c92d'), + 'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams', + '7ad16a2f1e7333859ff986138630fd7a'), +} + + +def conv3x3(in_planes: int, + out_planes: int, + stride: int=1, + groups: int=1, + dilation: int=1) -> nn.Conv2D: + """3x3 convolution with padding.""" + return nn.Conv2D( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + dilation=dilation, + bias_attr=False) + + +def conv1x1(in_planes: int, out_planes: int, stride: int=1) -> nn.Conv2D: + """1x1 convolution""" + return nn.Conv2D( + in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False) + + +class BasicBlock(nn.Layer): + expansion: int = 1 + + def __init__(self, + inplanes: int, + planes: int, + stride: int=1, + downsample: Optional[nn.Layer]=None, + groups: int=1, + base_width: int=64, + dilation: int=1, + norm_layer: Optional[Callable[..., nn.Layer]]=None) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2D + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class BottleneckBlock(nn.Layer): + """ + BottleneckBlock places the stride for downsampling at 3x3 convolution(self.conv2) + while original implementation places the stride at the first 1x1 convolution(self.conv1) + according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + This variant is also known as ResNet V1.5 and improves accuracy according to + https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + """ + expansion: int = 4 + + def __init__(self, + inplanes: int, + planes: int, + stride: int=1, + downsample: Optional[nn.Layer]=None, + groups: int=1, + base_width: int=64, + dilation: int=1, + norm_layer: Optional[Callable[..., nn.Layer]]=None) -> None: + super(BottleneckBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2D + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Layer): + """ResNet model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + Block (BasicBlock|BottleneckBlock): block module of model. + depth (int): layers of resnet, default: 50. + num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer + will not be defined. Default: 1000. + with_pool (bool): use pool before the last fc layer or not. Default: True. + + Examples: + .. code-block:: python + + from paddle.vision.models import ResNet + from paddle.vision.models.resnet import BottleneckBlock, BasicBlock + + resnet50 = ResNet(BottleneckBlock, 50) + + resnet18 = ResNet(BasicBlock, 18) + + """ + + def __init__(self, + block: Type[Union[BasicBlock, BottleneckBlock]], + depth: int, + num_classes: int=1000, + groups: int=1, + width_per_group: int=64, + replace_stride_with_dilation: Optional[List[bool]]=None, + norm_layer: Optional[Callable[..., nn.Layer]]=None, + with_pool: bool=True) -> None: + super(ResNet, self).__init__() + layer_cfg = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3] + } + layers = layer_cfg[depth] + self.num_classes = num_classes + self.with_pool = with_pool + if norm_layer is None: + norm_layer = nn.BatchNorm2D + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if not isinstance(replace_stride_with_dilation, ( + tuple, list)) or len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple/list, got {}".format( + replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2D( + 3, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias_attr=False) + self.bn1 = self._norm_layer(self.inplanes) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + if with_pool: + self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) + if num_classes > 0: + self.fc = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, + block: Type[Union[BasicBlock, BottleneckBlock]], + planes: int, + blocks: int, + stride: int=1, + dilate: bool=False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if self.with_pool: + x = self.avgpool(x) + + if self.num_classes > 0: + x = paddle.flatten(x, 1) + x = self.fc(x) + + return x + + +def _resnet(arch: str, + block: Type[Union[BasicBlock, BottleneckBlock]], + depth: int, + pretrained: bool, + **kwargs: Any) -> ResNet: + model = ResNet(block, depth, **kwargs) + if pretrained: + assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( + arch) + weights_path = get_weights_path_from_url(model_urls[arch][0], + model_urls[arch][1]) + param = paddle.load(weights_path) + model.set_dict(param) + + return model + + +def resnet18(pretrained: bool=False, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet18 + + # build model + model = resnet18() + + # build model and load imagenet pretrained weight + # model = resnet18(pretrained=True) + """ + return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs) + + +def resnet34(pretrained: bool=False, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet34 + + # build model + model = resnet34() + + # build model and load imagenet pretrained weight + # model = resnet34(pretrained=True) + """ + return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs) + + +def resnet50(pretrained: bool=False, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet50 + + # build model + model = resnet50() + + # build model and load imagenet pretrained weight + # model = resnet50(pretrained=True) + """ + return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs) + + +def resnet101(pretrained: bool=False, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet101 + + # build model + model = resnet101() + + # build model and load imagenet pretrained weight + model = resnet101(pretrained=True) + """ + return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs) + + +def resnet152(pretrained: bool=False, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import resnet152 + + # build model + model = resnet152() + + # build model and load imagenet pretrained weight + model = resnet152(pretrained=True) + """ + return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs) + + +def wide_resnet50_2(pretrained: bool=False, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, return a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import wide_resnet50_2 + + # build model + model = wide_resnet50_2() + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', BottleneckBlock, 50, pretrained, **kwargs) + + +def wide_resnet101_2(pretrained: bool=False, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Examples: + .. code-block:: python + + from paddle.vision.models import wide_resnet101_2 + + # build model + model = wide_resnet101_2() + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', BottleneckBlock, 101, pretrained, + **kwargs)