diff --git a/configs/resnet_ebv/README.md b/configs/resnet_ebv/README.md new file mode 100644 index 000000000..9e3381db7 --- /dev/null +++ b/configs/resnet_ebv/README.md @@ -0,0 +1,85 @@ +# ResNet_EBV +> [Equiangular Basis Vectors](https://arxiv.org/abs/2303.11637) + +## Introduction + +EBVs provide a solution to the problem of classification with a large number of classes in resource-constrained environments. When the number of classes is C (e.g., C > 100,000), the number of trainable parameters in the final linear layer of a traditional ResNet-50 increases to 2048 * C. In contrast, EBVs reduce this by using fixed basis vectors for different classes, where the dimensionality is d (with d << C), and constraining the angles between these basis vectors during initialization. After that, EBVs are fixed, which reduces the number of trainable parameters to 2048 * d. EBVs can also be extended to other architectures.[[1](#references)] + + + +## Results + +Our reproduced model performance on ImageNet-1K is reported as follows. + +
+ +| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download | +|------------|----------|-----------|-----------|------------|--------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------| +| resnet50_ebv | D910x8-G | 78.12 | 93.80 | 27.55 | [yaml](./resnet50_ebv_ascend.yaml) | \ | + + +
+ +#### Notes + +- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode. +- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K. + +## Quick Start + +### Preparation + +#### Installation +Please refer to the [installation instruction](https://github.com/mindspore-ecosystem/mindcv#installation) in MindCV. + +#### Dataset Preparation +Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation. + +### Training + +* Distributed Training + +It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run + +```shell +# distributed training on multiple GPU/Ascend devices +mpirun -n 8 python train.py --config configs/resnet/resnet50_ebv_ascend.yaml --data_dir /path/to/imagenet +``` + +> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`. + +Similarly, you can train the model on multiple GPU devices with the above `mpirun` command. + +For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py). + +**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size. + +* Standalone Training + +If you want to train or finetune the model on a smaller dataset without distributed training, please run: + +```shell +# standalone training on a CPU/GPU/Ascend device +python train.py --config configs/resnet/resnet50_ebv_ascend.yaml --data_dir /path/to/dataset --distribute False +``` + +### Validation + +To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`. + +```shell +python validate.py -c configs/resnet/resnet50_ebv_ascend.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt +``` + +### Deployment + +Please refer to the [deployment tutorial](https://mindspore-lab.github.io/mindcv/tutorials/deployment/) in MindCV. + +## References + +[1] Shen Y, Sun X, Wei X S. Equiangular basis vectors[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023: 11755-11765. diff --git a/configs/resnet_ebv/resnet50_ebv_ascend.yaml b/configs/resnet_ebv/resnet50_ebv_ascend.yaml new file mode 100644 index 000000000..dda95c8dd --- /dev/null +++ b/configs/resnet_ebv/resnet50_ebv_ascend.yaml @@ -0,0 +1,58 @@ +# system +mode: 0 +distribute: True +num_parallel_workers: 8 +val_while_train: True +val_interval: 1 + +# dataset +dataset: "imagenet" +data_dir: "/path/to/imagenet" +shuffle: True +dataset_download: False +batch_size: 128 +drop_remainder: True + +# augmentation +image_resize: 224 +scale: [0.08, 1.0] +ratio: [0.75, 1.333] +hflip: 0.5 +interpolation: "bilinear" +re_prob: 0.1 +mixup: 0.2 +cutmix: 1.0 +cutmix_prob: 1.0 +crop_pct: 0.875 +color_jitter: [0., 0., 0.] +auto_augment: "trivialaugwide" + +# model +model: "resnet50_ebv" +num_classes: 1000 +pretrained: False +keep_checkpoint_max: 5 +ckpt_save_policy: "top_k" +ckpt_save_dir: "./ckpt" +epoch_size: 205 +dataset_sink_mode: True +amp_level: "O2" + +# loss +loss: "CE" +label_smoothing: 0.1 + +# lr scheduler +scheduler: "cosine_decay" +lr: 0.1 +min_lr: 0 +warmup_epochs: 5 +decay_epochs: 200 +lr_epoch_stair: False + +# optimizer +opt: "momentum" +momentum: 0.9 +weight_decay: 0.00002 +loss_scale: 1024 +use_nesterov: False diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py index f3395796c..cadb0909c 100644 --- a/mindcv/models/__init__.py +++ b/mindcv/models/__init__.py @@ -40,6 +40,7 @@ res2net, resnest, resnet, + resnet_ebv, resnetv2, rexnet, senet, @@ -97,6 +98,7 @@ from .res2net import * from .resnest import * from .resnet import * +from .resnet_ebv import * from .resnetv2 import * from .rexnet import * from .senet import * @@ -171,3 +173,4 @@ __all__.extend(volo.__all__) __all__.extend(["Xception", "xception"]) __all__.extend(xcit.__all__) +__all__.extend(resnet_ebv.__all__) diff --git a/mindcv/models/layers/__init__.py b/mindcv/models/layers/__init__.py index c3e4de210..702c2a8cf 100644 --- a/mindcv/models/layers/__init__.py +++ b/mindcv/models/layers/__init__.py @@ -3,6 +3,7 @@ activation, conv_norm_act, drop_path, + ebv, format, identity, patch_dropout, @@ -14,6 +15,7 @@ from .activation import * from .conv_norm_act import * from .drop_path import * +from .ebv import * from .format import * from .identity import * from .patch_dropout import * diff --git a/mindcv/models/layers/ebv.py b/mindcv/models/layers/ebv.py new file mode 100644 index 000000000..8bffbd3fc --- /dev/null +++ b/mindcv/models/layers/ebv.py @@ -0,0 +1,88 @@ +"""EBV +Mindspore implementations of Equiangular Basis Vectors layer. +Papers: +Equiangular Basis Vectors (https://arxiv.org/pdf/2303.11637) +""" + +import mindspore as ms +import mindspore.numpy as msnp +from mindspore import Tensor, nn, ops + + +class EBV(nn.Cell): + """ + Equiangular Basis Vectors layer + """ + + def __init__( + self, + num_cls: int = 1000, + dim: int = 1000, + thre: float = 0.002, + slice_size: int = 130, + lr: float = 1e-3, + steps: int = 100000, + tau: float = 0.07 + ) -> None: + """ + Args: + num_cls (int): Number of categories, which can also be interpreted as the + number of basis vectors N that need to be generated, num_cls >= N. Default: 1000. + dim (int): Dimension for basis vectors. Default: 1000. + thre (float): The maximum value of the absolute cosine value + of the angle between any two basis vectors. Default: 0.002. + slice_size (int): Slicing optimization is required due to insufficient memory. Default: 130. + lr (float): Optimize learning rate. Default: 1e-3. + steps (int): Optimize step numbers. Default: 100000. + tau (float): Temperature parameter, less than + -num_cls/((num_cls-1) * log(exp(0.001) -1)/(N-1))). Default: 0.07 + """ + super().__init__() + self.num_cls = num_cls + self.dim = dim + self.thre = thre + self.slice_size = slice_size + self.lr = lr + self.steps = steps + self.tau = tau + self.l2norm = ops.L2Normalize() + self.ebv = self._generate_ebv() + self.ebv.requires_grad = False + + def _generate_ebv(self): + basis_vec = ms.Parameter( + ops.L2Normalize(1)( + ops.standard_normal((self.num_cls, self.dim)) + ), name='basis_vec', requires_grad=True) + optim = nn.SGD([basis_vec], learning_rate=self.lr) + matmul = ops.MatMul(transpose_b=True) + + def forward_fn(a, b, e, thr): + m = matmul(a, b).abs() - e + loss = ops.relu(m - thr).sum() + return loss, m + + grad_fn = ops.value_and_grad(forward_fn, 1, [basis_vec], has_aux=True) + for _ in range(self.steps): + basis_vec.set_data(ops.L2Normalize(1)(basis_vec.data)) + mx = self.thre + grads = msnp.zeros_like(basis_vec) + for i in range((self.num_cls - 1) // self.slice_size + 1): + start = self.slice_size * i + end = min(self.slice_size * (i + 1), self.num_cls) + e = ops.one_hot(msnp.arange(start, end), self.num_cls, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)) + (loss, m), grads_partial = grad_fn(basis_vec[start:end], basis_vec, e, self.thre) + mx = max(mx, m.max().asnumpy().tolist()) + grads = grads + grads_partial[0] + + if mx <= self.thre + 0.0001: + return self.l2norm(basis_vec.data) + optim((grads,)) + + return self.l2norm(basis_vec.data) + + def construct(self, x: Tensor) -> Tensor: + x = self.l2norm(x) + logits = ops.matmul(x, self.ebv.T / self.tau) + + return logits diff --git a/mindcv/models/resnet_ebv.py b/mindcv/models/resnet_ebv.py new file mode 100644 index 000000000..023d38014 --- /dev/null +++ b/mindcv/models/resnet_ebv.py @@ -0,0 +1,419 @@ +""" +MindSpore implementation of `ResNetEBV`. +Refer to Equiangular Basis Vectors. +""" + +from typing import List, Optional, Type, Union + +import mindspore.common.initializer as init +from mindspore import Tensor, nn + +from .helpers import build_model_with_cfg +from .layers import EBV +from .layers.pooling import GlobalAvgPooling +from .registry import register_model + +__all__ = [ + "ResNetEBV", + "resnet18_ebv", + "resnet34_ebv", + "resnet50_ebv", + "resnet101_ebv", + "resnet152_ebv", + "resnext50_32x4d_ebv", + "resnext101_32x4d_ebv", + "resnext101_64x4d_ebv", + "resnext152_64x4d_ebv", +] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'first_conv': '', 'classifier': '', + **kwargs + } + + +default_cfgs = { + "resnet18_ebv": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet18-1e65cd21.ckpt"), + "resnet34_ebv": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet34-f297d27e.ckpt"), + "resnet50_ebv": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet50-e0733ab8.ckpt"), + "resnet101_ebv": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet101-689c5e77.ckpt"), + "resnet152_ebv": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet152-beb689d8.ckpt"), + "resnext50_32x4d_ebv": _cfg( + url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext50_32x4d-af8aba16.ckpt"), + "resnext101_32x4d_ebv": _cfg( + url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext101_32x4d-3c1e9c51.ckpt" + ), + "resnext101_64x4d": _cfg( + url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext101_64x4d-8929255b.ckpt" + ), + "resnext152_64x4d": _cfg( + url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext152_64x4d-3aba275c.ckpt" + ), +} + + +class BasicBlock(nn.Cell): + """define the basic block of resnet""" + expansion: int = 1 + + def __init__( + self, + in_channels: int, + channels: int, + stride: int = 1, + groups: int = 1, + base_width: int = 64, + norm: Optional[nn.Cell] = None, + down_sample: Optional[nn.Cell] = None, + ) -> None: + super().__init__() + if norm is None: + norm = nn.BatchNorm2d + assert groups == 1, "BasicBlock only supports groups=1" + assert base_width == 64, "BasicBlock only supports base_width=64" + + self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, + stride=stride, padding=1, pad_mode="pad") + self.bn1 = norm(channels) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, + stride=1, padding=1, pad_mode="pad") + self.bn2 = norm(channels) + self.down_sample = down_sample + + def construct(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.down_sample is not None: + identity = self.down_sample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Cell): + """ + Bottleneck here places the stride for downsampling at 3x3 convolution(self.conv2) as torchvision does, + while original implementation places the stride at the first 1x1 convolution(self.conv1) + """ + expansion: int = 4 + + def __init__( + self, + in_channels: int, + channels: int, + stride: int = 1, + groups: int = 1, + base_width: int = 64, + norm: Optional[nn.Cell] = None, + down_sample: Optional[nn.Cell] = None, + ) -> None: + super().__init__() + if norm is None: + norm = nn.BatchNorm2d + + width = int(channels * (base_width / 64.0)) * groups + + self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1) + self.bn1 = norm(width) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, + padding=1, pad_mode="pad", group=groups) + self.bn2 = norm(width) + self.conv3 = nn.Conv2d(width, channels * self.expansion, + kernel_size=1, stride=1) + self.bn3 = norm(channels * self.expansion) + self.relu = nn.ReLU() + self.down_sample = down_sample + + def construct(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.down_sample is not None: + identity = self.down_sample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNetEBV(nn.Cell): + r"""ResNet model class with EBV layer, based on + `"Deep Residual Learning for Image Recognition" ` and + `Equiangular Basis Vectors ` + + Args: + block: block of resnet. + layers: number of layers of each stage. + num_classes: number of classification classes. Default: 1000. + in_channels: number the channels of the input. Default: 3. + groups: number of groups for group conv in blocks. Default: 1. + base_width: base width of pre group hidden channel in blocks. Default: 64. + norm: normalization layer in blocks. Default: None. + dim (int): Dimension for basis vectors. Default: 1000. + thre (float): The maximum value of the absolute cosine value + of the angle between any two basis vectors. Default: 0.002. + slice_size (int): Slicing optimization is required due to insufficient memory. Default: 130. + lr (float): Optimize learning rate. Default: 1e-3. + steps (int): Optimize step numbers. Default: 100000. + tau (float): Temperature parameter, less than + -num_cls/((num_cls-1) * log(exp(0.001) -1)/(N-1))). Default: 0.07 + """ + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + in_channels: int = 3, + groups: int = 1, + base_width: int = 64, + norm: Optional[nn.Cell] = None, + dim: int = 1000, + thre: float = 0.002, + slice_size: int = 130, + lr: float = 1e-3, + steps: int = 100000, + tau: float = 0.07 + ) -> None: + super().__init__() + if norm is None: + norm = nn.BatchNorm2d + + self.norm: nn.Cell = norm # add type hints to make pylint happy + self.input_channels = 64 + self.groups = groups + self.base_with = base_width + + self.conv1 = nn.Conv2d(in_channels, self.input_channels, kernel_size=7, + stride=2, pad_mode="pad", padding=3) + self.bn1 = norm(self.input_channels) + self.relu = nn.ReLU() + self.feature_info = [dict(chs=self.input_channels, reduction=2, name="relu")] + self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.layer1 = self._make_layer(block, 64, layers[0]) + self.feature_info.append(dict(chs=block.expansion * 64, reduction=4, name="layer1")) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.feature_info.append(dict(chs=block.expansion * 128, reduction=8, name="layer2")) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.feature_info.append(dict(chs=block.expansion * 256, reduction=16, name="layer3")) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.feature_info.append(dict(chs=block.expansion * 512, reduction=32, name="layer4")) + + self.pool = GlobalAvgPooling() + self.num_features = 512 * block.expansion + self.classifier = nn.Dense(self.num_features, dim) + self.ebv = EBV(num_classes, dim, thre, slice_size, lr, steps, tau) + + self._initialize_weights() + + def _initialize_weights(self) -> None: + """Initialize weights for cells.""" + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.set_data( + init.initializer(init.HeNormal(mode='fan_out', nonlinearity='relu'), + cell.weight.shape, cell.weight.dtype)) + if cell.bias is not None: + cell.bias.set_data( + init.initializer('zeros', cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, nn.BatchNorm2d): + cell.gamma.set_data(init.initializer('ones', cell.gamma.shape, cell.gamma.dtype)) + cell.beta.set_data(init.initializer('zeros', cell.beta.shape, cell.beta.dtype)) + elif isinstance(cell, nn.Dense): + cell.weight.set_data( + init.initializer(init.HeUniform(mode='fan_in', nonlinearity='sigmoid'), + cell.weight.shape, cell.weight.dtype)) + if cell.bias is not None: + cell.bias.set_data(init.initializer('zeros', cell.bias.shape, cell.bias.dtype)) + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + channels: int, + block_nums: int, + stride: int = 1, + ) -> nn.SequentialCell: + """build model depending on cfgs""" + down_sample = None + + if stride != 1 or self.input_channels != channels * block.expansion: + down_sample = nn.SequentialCell([ + nn.Conv2d(self.input_channels, channels * block.expansion, kernel_size=1, stride=stride), + self.norm(channels * block.expansion) + ]) + + layers = [] + layers.append( + block( + self.input_channels, + channels, + stride=stride, + down_sample=down_sample, + groups=self.groups, + base_width=self.base_with, + norm=self.norm, + ) + ) + self.input_channels = channels * block.expansion + + for _ in range(1, block_nums): + layers.append( + block( + self.input_channels, + channels, + groups=self.groups, + base_width=self.base_with, + norm=self.norm + ) + ) + + return nn.SequentialCell(layers) + + def forward_features(self, x: Tensor) -> Tensor: + """Network forward feature extraction.""" + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.max_pool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward_head(self, x: Tensor) -> Tensor: + x = self.pool(x) + x = self.classifier(x) + x = self.ebv(x) + return x + + def construct(self, x: Tensor) -> Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_resnet_ebv(pretrained=False, **kwargs): + return build_model_with_cfg(ResNetEBV, pretrained, **kwargs) + + +@register_model +def resnet18_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + """Get 18 layers ResNet model with ebv layer. + Refer to the base class `models.ResNetEBV` for more details. + """ + default_cfg = default_cfgs["resnet18_ebv"] + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], num_classes=num_classes, in_channels=in_channels, + **kwargs) + return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args)) + + +@register_model +def resnet34_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + """Get 34 layers ResNet model ebv layer. + Refer to the base class `models.ResNetEBV` for more details. + """ + default_cfg = default_cfgs["resnet34_ebv"] + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], num_classes=num_classes, in_channels=in_channels, + **kwargs) + return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args)) + + +@register_model +def resnet50_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + """Get 50 layers ResNet model ebv layer. + Refer to the base class `models.ResNetEBV` for more details. + """ + default_cfg = default_cfgs["resnet50_ebv"] + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], num_classes=num_classes, in_channels=in_channels, + **kwargs) + return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args)) + + +@register_model +def resnet101_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + """Get 101 layers ResNet model ebv layer. + Refer to the base class `models.ResNetEBV` for more details. + """ + default_cfg = default_cfgs["resnet101_ebv"] + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], num_classes=num_classes, in_channels=in_channels, + **kwargs) + return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args)) + + +@register_model +def resnet152_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + """Get 152 layers ResNet model ebv layer. + Refer to the base class `models.ResNetEBV` for more details. + """ + default_cfg = default_cfgs["resnet152_ebv"] + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], num_classes=num_classes, in_channels=in_channels, + **kwargs) + return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args)) + + +@register_model +def resnext50_32x4d_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + """Get 50 layers ResNeXt model with 32 groups of GPConv and ebv layer. + Refer to the base class `models.ResNetEBV` for more details. + """ + default_cfg = default_cfgs["resnext50_32x4d_ebv"] + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], groups=32, base_width=4, num_classes=num_classes, + in_channels=in_channels, **kwargs) + return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args)) + + +@register_model +def resnext101_32x4d_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + """Get 101 layers ResNeXt model with 32 groups of GPConv and ebv layer. + Refer to the base class `models.ResNetEBV` for more details. + """ + default_cfg = default_cfgs["resnext101_32x4d_ebv"] + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], groups=32, base_width=4, num_classes=num_classes, + in_channels=in_channels, **kwargs) + return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args)) + + +@register_model +def resnext101_64x4d_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + """Get 101 layers ResNeXt model with 64 groups of GPConv and ebv layer. + Refer to the base class `models.ResNetEBV` for more details. + """ + default_cfg = default_cfgs["resnext101_64x4d_ebv"] + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], groups=64, base_width=4, num_classes=num_classes, + in_channels=in_channels, **kwargs) + return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args)) + + +@register_model +def resnext152_64x4d_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs): + default_cfg = default_cfgs["resnext152_64x4d_ebv"] + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], groups=64, base_width=4, num_classes=num_classes, + in_channels=in_channels, **kwargs) + return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args)) diff --git a/tests/modules/test_models.py b/tests/modules/test_models.py index 49607d87d..ee3a03b0b 100644 --- a/tests/modules/test_models.py +++ b/tests/modules/test_models.py @@ -58,6 +58,7 @@ "visformer_tiny", "vit_b_32_224", "xception", + "resnet18_ebv" ] check_loss_decrease = False