diff --git a/configs/resnet_ebv/README.md b/configs/resnet_ebv/README.md
new file mode 100644
index 000000000..9e3381db7
--- /dev/null
+++ b/configs/resnet_ebv/README.md
@@ -0,0 +1,85 @@
+# ResNet_EBV
+> [Equiangular Basis Vectors](https://arxiv.org/abs/2303.11637)
+
+## Introduction
+
+EBVs provide a solution to the problem of classification with a large number of classes in resource-constrained environments. When the number of classes is C (e.g., C > 100,000), the number of trainable parameters in the final linear layer of a traditional ResNet-50 increases to 2048 * C. In contrast, EBVs reduce this by using fixed basis vectors for different classes, where the dimensionality is d (with d << C), and constraining the angles between these basis vectors during initialization. After that, EBVs are fixed, which reduces the number of trainable parameters to 2048 * d. EBVs can also be extended to other architectures.[[1](#references)]
+
+
+
+## Results
+
+Our reproduced model performance on ImageNet-1K is reported as follows.
+
+
+
+| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
+|------------|----------|-----------|-----------|------------|--------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|
+| resnet50_ebv | D910x8-G | 78.12 | 93.80 | 27.55 | [yaml](./resnet50_ebv_ascend.yaml) | \ |
+
+
+
+
+#### Notes
+
+- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode.
+- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K.
+
+## Quick Start
+
+### Preparation
+
+#### Installation
+Please refer to the [installation instruction](https://github.com/mindspore-ecosystem/mindcv#installation) in MindCV.
+
+#### Dataset Preparation
+Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation.
+
+### Training
+
+* Distributed Training
+
+It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run
+
+```shell
+# distributed training on multiple GPU/Ascend devices
+mpirun -n 8 python train.py --config configs/resnet/resnet50_ebv_ascend.yaml --data_dir /path/to/imagenet
+```
+
+> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.
+
+Similarly, you can train the model on multiple GPU devices with the above `mpirun` command.
+
+For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py).
+
+**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size.
+
+* Standalone Training
+
+If you want to train or finetune the model on a smaller dataset without distributed training, please run:
+
+```shell
+# standalone training on a CPU/GPU/Ascend device
+python train.py --config configs/resnet/resnet50_ebv_ascend.yaml --data_dir /path/to/dataset --distribute False
+```
+
+### Validation
+
+To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`.
+
+```shell
+python validate.py -c configs/resnet/resnet50_ebv_ascend.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
+```
+
+### Deployment
+
+Please refer to the [deployment tutorial](https://mindspore-lab.github.io/mindcv/tutorials/deployment/) in MindCV.
+
+## References
+
+[1] Shen Y, Sun X, Wei X S. Equiangular basis vectors[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023: 11755-11765.
diff --git a/configs/resnet_ebv/resnet50_ebv_ascend.yaml b/configs/resnet_ebv/resnet50_ebv_ascend.yaml
new file mode 100644
index 000000000..dda95c8dd
--- /dev/null
+++ b/configs/resnet_ebv/resnet50_ebv_ascend.yaml
@@ -0,0 +1,58 @@
+# system
+mode: 0
+distribute: True
+num_parallel_workers: 8
+val_while_train: True
+val_interval: 1
+
+# dataset
+dataset: "imagenet"
+data_dir: "/path/to/imagenet"
+shuffle: True
+dataset_download: False
+batch_size: 128
+drop_remainder: True
+
+# augmentation
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+interpolation: "bilinear"
+re_prob: 0.1
+mixup: 0.2
+cutmix: 1.0
+cutmix_prob: 1.0
+crop_pct: 0.875
+color_jitter: [0., 0., 0.]
+auto_augment: "trivialaugwide"
+
+# model
+model: "resnet50_ebv"
+num_classes: 1000
+pretrained: False
+keep_checkpoint_max: 5
+ckpt_save_policy: "top_k"
+ckpt_save_dir: "./ckpt"
+epoch_size: 205
+dataset_sink_mode: True
+amp_level: "O2"
+
+# loss
+loss: "CE"
+label_smoothing: 0.1
+
+# lr scheduler
+scheduler: "cosine_decay"
+lr: 0.1
+min_lr: 0
+warmup_epochs: 5
+decay_epochs: 200
+lr_epoch_stair: False
+
+# optimizer
+opt: "momentum"
+momentum: 0.9
+weight_decay: 0.00002
+loss_scale: 1024
+use_nesterov: False
diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py
index f3395796c..cadb0909c 100644
--- a/mindcv/models/__init__.py
+++ b/mindcv/models/__init__.py
@@ -40,6 +40,7 @@
res2net,
resnest,
resnet,
+ resnet_ebv,
resnetv2,
rexnet,
senet,
@@ -97,6 +98,7 @@
from .res2net import *
from .resnest import *
from .resnet import *
+from .resnet_ebv import *
from .resnetv2 import *
from .rexnet import *
from .senet import *
@@ -171,3 +173,4 @@
__all__.extend(volo.__all__)
__all__.extend(["Xception", "xception"])
__all__.extend(xcit.__all__)
+__all__.extend(resnet_ebv.__all__)
diff --git a/mindcv/models/layers/__init__.py b/mindcv/models/layers/__init__.py
index c3e4de210..702c2a8cf 100644
--- a/mindcv/models/layers/__init__.py
+++ b/mindcv/models/layers/__init__.py
@@ -3,6 +3,7 @@
activation,
conv_norm_act,
drop_path,
+ ebv,
format,
identity,
patch_dropout,
@@ -14,6 +15,7 @@
from .activation import *
from .conv_norm_act import *
from .drop_path import *
+from .ebv import *
from .format import *
from .identity import *
from .patch_dropout import *
diff --git a/mindcv/models/layers/ebv.py b/mindcv/models/layers/ebv.py
new file mode 100644
index 000000000..8bffbd3fc
--- /dev/null
+++ b/mindcv/models/layers/ebv.py
@@ -0,0 +1,88 @@
+"""EBV
+Mindspore implementations of Equiangular Basis Vectors layer.
+Papers:
+Equiangular Basis Vectors (https://arxiv.org/pdf/2303.11637)
+"""
+
+import mindspore as ms
+import mindspore.numpy as msnp
+from mindspore import Tensor, nn, ops
+
+
+class EBV(nn.Cell):
+ """
+ Equiangular Basis Vectors layer
+ """
+
+ def __init__(
+ self,
+ num_cls: int = 1000,
+ dim: int = 1000,
+ thre: float = 0.002,
+ slice_size: int = 130,
+ lr: float = 1e-3,
+ steps: int = 100000,
+ tau: float = 0.07
+ ) -> None:
+ """
+ Args:
+ num_cls (int): Number of categories, which can also be interpreted as the
+ number of basis vectors N that need to be generated, num_cls >= N. Default: 1000.
+ dim (int): Dimension for basis vectors. Default: 1000.
+ thre (float): The maximum value of the absolute cosine value
+ of the angle between any two basis vectors. Default: 0.002.
+ slice_size (int): Slicing optimization is required due to insufficient memory. Default: 130.
+ lr (float): Optimize learning rate. Default: 1e-3.
+ steps (int): Optimize step numbers. Default: 100000.
+ tau (float): Temperature parameter, less than
+ -num_cls/((num_cls-1) * log(exp(0.001) -1)/(N-1))). Default: 0.07
+ """
+ super().__init__()
+ self.num_cls = num_cls
+ self.dim = dim
+ self.thre = thre
+ self.slice_size = slice_size
+ self.lr = lr
+ self.steps = steps
+ self.tau = tau
+ self.l2norm = ops.L2Normalize()
+ self.ebv = self._generate_ebv()
+ self.ebv.requires_grad = False
+
+ def _generate_ebv(self):
+ basis_vec = ms.Parameter(
+ ops.L2Normalize(1)(
+ ops.standard_normal((self.num_cls, self.dim))
+ ), name='basis_vec', requires_grad=True)
+ optim = nn.SGD([basis_vec], learning_rate=self.lr)
+ matmul = ops.MatMul(transpose_b=True)
+
+ def forward_fn(a, b, e, thr):
+ m = matmul(a, b).abs() - e
+ loss = ops.relu(m - thr).sum()
+ return loss, m
+
+ grad_fn = ops.value_and_grad(forward_fn, 1, [basis_vec], has_aux=True)
+ for _ in range(self.steps):
+ basis_vec.set_data(ops.L2Normalize(1)(basis_vec.data))
+ mx = self.thre
+ grads = msnp.zeros_like(basis_vec)
+ for i in range((self.num_cls - 1) // self.slice_size + 1):
+ start = self.slice_size * i
+ end = min(self.slice_size * (i + 1), self.num_cls)
+ e = ops.one_hot(msnp.arange(start, end), self.num_cls, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32))
+ (loss, m), grads_partial = grad_fn(basis_vec[start:end], basis_vec, e, self.thre)
+ mx = max(mx, m.max().asnumpy().tolist())
+ grads = grads + grads_partial[0]
+
+ if mx <= self.thre + 0.0001:
+ return self.l2norm(basis_vec.data)
+ optim((grads,))
+
+ return self.l2norm(basis_vec.data)
+
+ def construct(self, x: Tensor) -> Tensor:
+ x = self.l2norm(x)
+ logits = ops.matmul(x, self.ebv.T / self.tau)
+
+ return logits
diff --git a/mindcv/models/resnet_ebv.py b/mindcv/models/resnet_ebv.py
new file mode 100644
index 000000000..023d38014
--- /dev/null
+++ b/mindcv/models/resnet_ebv.py
@@ -0,0 +1,419 @@
+"""
+MindSpore implementation of `ResNetEBV`.
+Refer to Equiangular Basis Vectors.
+"""
+
+from typing import List, Optional, Type, Union
+
+import mindspore.common.initializer as init
+from mindspore import Tensor, nn
+
+from .helpers import build_model_with_cfg
+from .layers import EBV
+from .layers.pooling import GlobalAvgPooling
+from .registry import register_model
+
+__all__ = [
+ "ResNetEBV",
+ "resnet18_ebv",
+ "resnet34_ebv",
+ "resnet50_ebv",
+ "resnet101_ebv",
+ "resnet152_ebv",
+ "resnext50_32x4d_ebv",
+ "resnext101_32x4d_ebv",
+ "resnext101_64x4d_ebv",
+ "resnext152_64x4d_ebv",
+]
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000,
+ 'first_conv': '', 'classifier': '',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ "resnet18_ebv": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet18-1e65cd21.ckpt"),
+ "resnet34_ebv": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet34-f297d27e.ckpt"),
+ "resnet50_ebv": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet50-e0733ab8.ckpt"),
+ "resnet101_ebv": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet101-689c5e77.ckpt"),
+ "resnet152_ebv": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/resnet/resnet152-beb689d8.ckpt"),
+ "resnext50_32x4d_ebv": _cfg(
+ url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext50_32x4d-af8aba16.ckpt"),
+ "resnext101_32x4d_ebv": _cfg(
+ url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext101_32x4d-3c1e9c51.ckpt"
+ ),
+ "resnext101_64x4d": _cfg(
+ url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext101_64x4d-8929255b.ckpt"
+ ),
+ "resnext152_64x4d": _cfg(
+ url="https://download.mindspore.cn/toolkits/mindcv/resnext/resnext152_64x4d-3aba275c.ckpt"
+ ),
+}
+
+
+class BasicBlock(nn.Cell):
+ """define the basic block of resnet"""
+ expansion: int = 1
+
+ def __init__(
+ self,
+ in_channels: int,
+ channels: int,
+ stride: int = 1,
+ groups: int = 1,
+ base_width: int = 64,
+ norm: Optional[nn.Cell] = None,
+ down_sample: Optional[nn.Cell] = None,
+ ) -> None:
+ super().__init__()
+ if norm is None:
+ norm = nn.BatchNorm2d
+ assert groups == 1, "BasicBlock only supports groups=1"
+ assert base_width == 64, "BasicBlock only supports base_width=64"
+
+ self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3,
+ stride=stride, padding=1, pad_mode="pad")
+ self.bn1 = norm(channels)
+ self.relu = nn.ReLU()
+ self.conv2 = nn.Conv2d(channels, channels, kernel_size=3,
+ stride=1, padding=1, pad_mode="pad")
+ self.bn2 = norm(channels)
+ self.down_sample = down_sample
+
+ def construct(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.down_sample is not None:
+ identity = self.down_sample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Cell):
+ """
+ Bottleneck here places the stride for downsampling at 3x3 convolution(self.conv2) as torchvision does,
+ while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ """
+ expansion: int = 4
+
+ def __init__(
+ self,
+ in_channels: int,
+ channels: int,
+ stride: int = 1,
+ groups: int = 1,
+ base_width: int = 64,
+ norm: Optional[nn.Cell] = None,
+ down_sample: Optional[nn.Cell] = None,
+ ) -> None:
+ super().__init__()
+ if norm is None:
+ norm = nn.BatchNorm2d
+
+ width = int(channels * (base_width / 64.0)) * groups
+
+ self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1)
+ self.bn1 = norm(width)
+ self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
+ padding=1, pad_mode="pad", group=groups)
+ self.bn2 = norm(width)
+ self.conv3 = nn.Conv2d(width, channels * self.expansion,
+ kernel_size=1, stride=1)
+ self.bn3 = norm(channels * self.expansion)
+ self.relu = nn.ReLU()
+ self.down_sample = down_sample
+
+ def construct(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.down_sample is not None:
+ identity = self.down_sample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNetEBV(nn.Cell):
+ r"""ResNet model class with EBV layer, based on
+ `"Deep Residual Learning for Image Recognition" ` and
+ `Equiangular Basis Vectors `
+
+ Args:
+ block: block of resnet.
+ layers: number of layers of each stage.
+ num_classes: number of classification classes. Default: 1000.
+ in_channels: number the channels of the input. Default: 3.
+ groups: number of groups for group conv in blocks. Default: 1.
+ base_width: base width of pre group hidden channel in blocks. Default: 64.
+ norm: normalization layer in blocks. Default: None.
+ dim (int): Dimension for basis vectors. Default: 1000.
+ thre (float): The maximum value of the absolute cosine value
+ of the angle between any two basis vectors. Default: 0.002.
+ slice_size (int): Slicing optimization is required due to insufficient memory. Default: 130.
+ lr (float): Optimize learning rate. Default: 1e-3.
+ steps (int): Optimize step numbers. Default: 100000.
+ tau (float): Temperature parameter, less than
+ -num_cls/((num_cls-1) * log(exp(0.001) -1)/(N-1))). Default: 0.07
+ """
+
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ in_channels: int = 3,
+ groups: int = 1,
+ base_width: int = 64,
+ norm: Optional[nn.Cell] = None,
+ dim: int = 1000,
+ thre: float = 0.002,
+ slice_size: int = 130,
+ lr: float = 1e-3,
+ steps: int = 100000,
+ tau: float = 0.07
+ ) -> None:
+ super().__init__()
+ if norm is None:
+ norm = nn.BatchNorm2d
+
+ self.norm: nn.Cell = norm # add type hints to make pylint happy
+ self.input_channels = 64
+ self.groups = groups
+ self.base_with = base_width
+
+ self.conv1 = nn.Conv2d(in_channels, self.input_channels, kernel_size=7,
+ stride=2, pad_mode="pad", padding=3)
+ self.bn1 = norm(self.input_channels)
+ self.relu = nn.ReLU()
+ self.feature_info = [dict(chs=self.input_channels, reduction=2, name="relu")]
+ self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.feature_info.append(dict(chs=block.expansion * 64, reduction=4, name="layer1"))
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.feature_info.append(dict(chs=block.expansion * 128, reduction=8, name="layer2"))
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.feature_info.append(dict(chs=block.expansion * 256, reduction=16, name="layer3"))
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.feature_info.append(dict(chs=block.expansion * 512, reduction=32, name="layer4"))
+
+ self.pool = GlobalAvgPooling()
+ self.num_features = 512 * block.expansion
+ self.classifier = nn.Dense(self.num_features, dim)
+ self.ebv = EBV(num_classes, dim, thre, slice_size, lr, steps, tau)
+
+ self._initialize_weights()
+
+ def _initialize_weights(self) -> None:
+ """Initialize weights for cells."""
+ for _, cell in self.cells_and_names():
+ if isinstance(cell, nn.Conv2d):
+ cell.weight.set_data(
+ init.initializer(init.HeNormal(mode='fan_out', nonlinearity='relu'),
+ cell.weight.shape, cell.weight.dtype))
+ if cell.bias is not None:
+ cell.bias.set_data(
+ init.initializer('zeros', cell.bias.shape, cell.bias.dtype))
+ elif isinstance(cell, nn.BatchNorm2d):
+ cell.gamma.set_data(init.initializer('ones', cell.gamma.shape, cell.gamma.dtype))
+ cell.beta.set_data(init.initializer('zeros', cell.beta.shape, cell.beta.dtype))
+ elif isinstance(cell, nn.Dense):
+ cell.weight.set_data(
+ init.initializer(init.HeUniform(mode='fan_in', nonlinearity='sigmoid'),
+ cell.weight.shape, cell.weight.dtype))
+ if cell.bias is not None:
+ cell.bias.set_data(init.initializer('zeros', cell.bias.shape, cell.bias.dtype))
+
+ def _make_layer(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ channels: int,
+ block_nums: int,
+ stride: int = 1,
+ ) -> nn.SequentialCell:
+ """build model depending on cfgs"""
+ down_sample = None
+
+ if stride != 1 or self.input_channels != channels * block.expansion:
+ down_sample = nn.SequentialCell([
+ nn.Conv2d(self.input_channels, channels * block.expansion, kernel_size=1, stride=stride),
+ self.norm(channels * block.expansion)
+ ])
+
+ layers = []
+ layers.append(
+ block(
+ self.input_channels,
+ channels,
+ stride=stride,
+ down_sample=down_sample,
+ groups=self.groups,
+ base_width=self.base_with,
+ norm=self.norm,
+ )
+ )
+ self.input_channels = channels * block.expansion
+
+ for _ in range(1, block_nums):
+ layers.append(
+ block(
+ self.input_channels,
+ channels,
+ groups=self.groups,
+ base_width=self.base_with,
+ norm=self.norm
+ )
+ )
+
+ return nn.SequentialCell(layers)
+
+ def forward_features(self, x: Tensor) -> Tensor:
+ """Network forward feature extraction."""
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.max_pool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ return x
+
+ def forward_head(self, x: Tensor) -> Tensor:
+ x = self.pool(x)
+ x = self.classifier(x)
+ x = self.ebv(x)
+ return x
+
+ def construct(self, x: Tensor) -> Tensor:
+ x = self.forward_features(x)
+ x = self.forward_head(x)
+ return x
+
+
+def _create_resnet_ebv(pretrained=False, **kwargs):
+ return build_model_with_cfg(ResNetEBV, pretrained, **kwargs)
+
+
+@register_model
+def resnet18_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ """Get 18 layers ResNet model with ebv layer.
+ Refer to the base class `models.ResNetEBV` for more details.
+ """
+ default_cfg = default_cfgs["resnet18_ebv"]
+ model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], num_classes=num_classes, in_channels=in_channels,
+ **kwargs)
+ return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args))
+
+
+@register_model
+def resnet34_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ """Get 34 layers ResNet model ebv layer.
+ Refer to the base class `models.ResNetEBV` for more details.
+ """
+ default_cfg = default_cfgs["resnet34_ebv"]
+ model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], num_classes=num_classes, in_channels=in_channels,
+ **kwargs)
+ return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args))
+
+
+@register_model
+def resnet50_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ """Get 50 layers ResNet model ebv layer.
+ Refer to the base class `models.ResNetEBV` for more details.
+ """
+ default_cfg = default_cfgs["resnet50_ebv"]
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], num_classes=num_classes, in_channels=in_channels,
+ **kwargs)
+ return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args))
+
+
+@register_model
+def resnet101_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ """Get 101 layers ResNet model ebv layer.
+ Refer to the base class `models.ResNetEBV` for more details.
+ """
+ default_cfg = default_cfgs["resnet101_ebv"]
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], num_classes=num_classes, in_channels=in_channels,
+ **kwargs)
+ return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args))
+
+
+@register_model
+def resnet152_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ """Get 152 layers ResNet model ebv layer.
+ Refer to the base class `models.ResNetEBV` for more details.
+ """
+ default_cfg = default_cfgs["resnet152_ebv"]
+ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], num_classes=num_classes, in_channels=in_channels,
+ **kwargs)
+ return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args))
+
+
+@register_model
+def resnext50_32x4d_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ """Get 50 layers ResNeXt model with 32 groups of GPConv and ebv layer.
+ Refer to the base class `models.ResNetEBV` for more details.
+ """
+ default_cfg = default_cfgs["resnext50_32x4d_ebv"]
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], groups=32, base_width=4, num_classes=num_classes,
+ in_channels=in_channels, **kwargs)
+ return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args))
+
+
+@register_model
+def resnext101_32x4d_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ """Get 101 layers ResNeXt model with 32 groups of GPConv and ebv layer.
+ Refer to the base class `models.ResNetEBV` for more details.
+ """
+ default_cfg = default_cfgs["resnext101_32x4d_ebv"]
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], groups=32, base_width=4, num_classes=num_classes,
+ in_channels=in_channels, **kwargs)
+ return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args))
+
+
+@register_model
+def resnext101_64x4d_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ """Get 101 layers ResNeXt model with 64 groups of GPConv and ebv layer.
+ Refer to the base class `models.ResNetEBV` for more details.
+ """
+ default_cfg = default_cfgs["resnext101_64x4d_ebv"]
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], groups=64, base_width=4, num_classes=num_classes,
+ in_channels=in_channels, **kwargs)
+ return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args))
+
+
+@register_model
+def resnext152_64x4d_ebv(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs):
+ default_cfg = default_cfgs["resnext152_64x4d_ebv"]
+ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], groups=64, base_width=4, num_classes=num_classes,
+ in_channels=in_channels, **kwargs)
+ return _create_resnet_ebv(pretrained, **dict(default_cfg=default_cfg, **model_args))
diff --git a/tests/modules/test_models.py b/tests/modules/test_models.py
index 49607d87d..ee3a03b0b 100644
--- a/tests/modules/test_models.py
+++ b/tests/modules/test_models.py
@@ -58,6 +58,7 @@
"visformer_tiny",
"vit_b_32_224",
"xception",
+ "resnet18_ebv"
]
check_loss_decrease = False