feat: Add a new model called resnet_ebv #780

Open · wants to merge 7 commits into `main`
85 changes: 85 additions & 0 deletions configs/resnet_ebv/README.md
@@ -0,0 +1,85 @@
# ResNet_EBV
> [Equiangular Basis Vectors](https://arxiv.org/abs/2303.11637)

## Introduction

EBVs address classification over very large label spaces in resource-constrained settings. When the number of classes C is large (e.g., C > 100,000), the final linear layer of a standard ResNet-50 requires 2048 * C trainable parameters. EBVs instead assign each class a fixed d-dimensional basis vector (with d << C), constraining the pairwise angles between the basis vectors at initialization; the vectors are then frozen, so only the 2048 * d projection onto the basis space remains trainable. EBVs can also be extended to other architectures.[[1](#references)] A rough parameter count is sketched below.
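
For a rough sense of the savings, here is a back-of-the-envelope sketch (a minimal illustration; the concrete `d = 1000` is an assumed example, not a value fixed by the paper):

```python
# Head-parameter comparison for the numbers quoted above.
feat_dim = 2048        # ResNet-50 feature dimension
num_classes = 100_000  # C, a very large label space
d = 1000               # EBV dimensionality, d << C (assumed example)

linear_head = feat_dim * num_classes  # 2048 * C = 204,800,000 trainable weights
ebv_head = feat_dim * d               # 2048 * d = 2,048,000 (basis vectors are frozen)
print(f"linear head: {linear_head:,} vs. EBV head: {ebv_head:,}")
```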

<!-- <p align="center">
<img src="https://user-images.githubusercontent.com/53842165/223672204-8ac59c6c-cd8a-45c2-945f-7e556c383056.jpg" width=500 />
</p>
<p align="center">
<em>Figure 1. Comparisons between typical classification paradigms and Equiangular Basis Vectors (EBVs). [<a href="#references">1</a>] </em>
</p> -->

## Results

Our reproduced model performance on ImageNet-1K is reported as follows.

<div align="center">

| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
|------------|----------|-----------|-----------|------------|--------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|
| resnet50_ebv | D910x8-G | 78.12 | 93.80 | 27.55 | [yaml](./resnet50_ebv_ascend.yaml) | \ |


</div>

#### Notes

- Context: Training context denoted as {device}x{pieces}-{MS mode}, where the MindSpore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G denotes training on 8 Ascend 910 NPUs in graph mode.
- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K.

## Quick Start

### Preparation

#### Installation
Please refer to the [installation instruction](https://github.com/mindspore-ecosystem/mindcv#installation) in MindCV.

#### Dataset Preparation
Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation.

### Training

* Distributed Training

It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run

```shell
# distributed training on multiple GPU/Ascend devices
mpirun -n 8 python train.py --config configs/resnet_ebv/resnet50_ebv_ascend.yaml --data_dir /path/to/imagenet
```

> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.

Similarly, you can train the model on multiple GPU devices with the above `mpirun` command.

For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py).

**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction, or to scale the learning rate linearly with the new global batch size, as sketched below.
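
A minimal sketch of that linear scaling rule, using the recipe's values (`batch_size: 128` on 8 devices, `lr: 0.1`) as the reference; the "new" setup is a hypothetical example:

```python
# Linear LR scaling when the global batch size changes.
base_lr = 0.1
base_global_bs = 128 * 8  # recipe: batch_size x num_devices = 1024

new_global_bs = 64 * 8    # hypothetical: halved per-device batch size
new_lr = base_lr * new_global_bs / base_global_bs
print(new_lr)             # 0.05
```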

* Standalone Training

If you want to train or finetune the model on a smaller dataset without distributed training, please run:

```shell
# standalone training on a CPU/GPU/Ascend device
python train.py --config configs/resnet_ebv/resnet50_ebv_ascend.yaml --data_dir /path/to/dataset --distribute False
```

### Validation

To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`.

```shell
python validate.py -c configs/resnet_ebv/resnet50_ebv_ascend.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
```

### Deployment

Please refer to the [deployment tutorial](https://mindspore-lab.github.io/mindcv/tutorials/deployment/) in MindCV.

## References

[1] Shen Y, Sun X, Wei X S. Equiangular basis vectors[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023: 11755-11765.
58 changes: 58 additions & 0 deletions configs/resnet_ebv/resnet50_ebv_ascend.yaml
@@ -0,0 +1,58 @@
# system
mode: 0
distribute: True
num_parallel_workers: 8
val_while_train: True
val_interval: 1

# dataset
dataset: "imagenet"
data_dir: "/path/to/imagenet"
shuffle: True
dataset_download: False
batch_size: 128
drop_remainder: True

# augmentation
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
interpolation: "bilinear"
re_prob: 0.1
mixup: 0.2
cutmix: 1.0
cutmix_prob: 1.0
crop_pct: 0.875
color_jitter: [0., 0., 0.]
auto_augment: "trivialaugwide"

# model
model: "resnet50_ebv"
num_classes: 1000
pretrained: False
keep_checkpoint_max: 5
ckpt_save_policy: "top_k"
ckpt_save_dir: "./ckpt"
epoch_size: 205
dataset_sink_mode: True
amp_level: "O2"

# loss
loss: "CE"
label_smoothing: 0.1

# lr scheduler
scheduler: "cosine_decay"
lr: 0.1
min_lr: 0
warmup_epochs: 5
decay_epochs: 200
lr_epoch_stair: False

# optimizer
opt: "momentum"
momentum: 0.9
weight_decay: 0.00002
loss_scale: 1024
use_nesterov: False
3 changes: 3 additions & 0 deletions mindcv/models/__init__.py
@@ -40,6 +40,7 @@
res2net,
resnest,
resnet,
resnet_ebv,
resnetv2,
rexnet,
senet,
@@ -97,6 +98,7 @@
from .res2net import *
from .resnest import *
from .resnet import *
from .resnet_ebv import *
from .resnetv2 import *
from .rexnet import *
from .senet import *
@@ -171,3 +173,4 @@
__all__.extend(volo.__all__)
__all__.extend(["Xception", "xception"])
__all__.extend(xcit.__all__)
__all__.extend(resnet_ebv.__all__)
2 changes: 2 additions & 0 deletions mindcv/models/layers/__init__.py
@@ -3,6 +3,7 @@
activation,
conv_norm_act,
drop_path,
ebv,
format,
identity,
patch_dropout,
@@ -14,6 +15,7 @@
from .activation import *
from .conv_norm_act import *
from .drop_path import *
from .ebv import *
from .format import *
from .identity import *
from .patch_dropout import *
88 changes: 88 additions & 0 deletions mindcv/models/layers/ebv.py
@@ -0,0 +1,88 @@
"""EBV
Mindspore implementations of Equiangular Basis Vectors layer.
Papers:
Equiangular Basis Vectors (https://arxiv.org/pdf/2303.11637)
"""

import mindspore as ms
import mindspore.numpy as msnp
from mindspore import Tensor, nn, ops


class EBV(nn.Cell):
"""
Equiangular Basis Vectors layer
"""

def __init__(
self,
num_cls: int = 1000,
dim: int = 1000,
thre: float = 0.002,
slice_size: int = 130,
lr: float = 1e-3,
steps: int = 100000,
tau: float = 0.07
) -> None:
"""
Args:
num_cls (int): Number of categories, which can also be interpreted as the
number of basis vectors N that need to be generated, num_cls >= N. Default: 1000.
dim (int): Dimension for basis vectors. Default: 1000.
thre (float): The maximum value of the absolute cosine value
of the angle between any two basis vectors. Default: 0.002.
slice_size (int): Slicing optimization is required due to insufficient memory. Default: 130.
lr (float): Optimize learning rate. Default: 1e-3.
steps (int): Optimize step numbers. Default: 100000.
tau (float): Temperature parameter, less than
-num_cls/((num_cls-1) * log(exp(0.001) -1)/(N-1))). Default: 0.07
"""
super().__init__()
self.num_cls = num_cls
self.dim = dim
self.thre = thre
self.slice_size = slice_size
self.lr = lr
self.steps = steps
self.tau = tau
self.l2norm = ops.L2Normalize()
self.ebv = self._generate_ebv()
self.ebv.requires_grad = False

def _generate_ebv(self):
basis_vec = ms.Parameter(
ops.L2Normalize(1)(
ops.standard_normal((self.num_cls, self.dim))
), name='basis_vec', requires_grad=True)
optim = nn.SGD([basis_vec], learning_rate=self.lr)
matmul = ops.MatMul(transpose_b=True)

def forward_fn(a, b, e, thr):
m = matmul(a, b).abs() - e
loss = ops.relu(m - thr).sum()
return loss, m

grad_fn = ops.value_and_grad(forward_fn, 1, [basis_vec], has_aux=True)
for _ in range(self.steps):
basis_vec.set_data(ops.L2Normalize(1)(basis_vec.data))
mx = self.thre
grads = msnp.zeros_like(basis_vec)
for i in range((self.num_cls - 1) // self.slice_size + 1):
start = self.slice_size * i
end = min(self.slice_size * (i + 1), self.num_cls)
e = ops.one_hot(msnp.arange(start, end), self.num_cls, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32))
(loss, m), grads_partial = grad_fn(basis_vec[start:end], basis_vec, e, self.thre)
mx = max(mx, m.max().asnumpy().tolist())
grads = grads + grads_partial[0]

if mx <= self.thre + 0.0001:
return self.l2norm(basis_vec.data)
optim((grads,))

return self.l2norm(basis_vec.data)

def construct(self, x: Tensor) -> Tensor:
x = self.l2norm(x)
logits = ops.matmul(x, self.ebv.T / self.tau)

return logits
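
For context, a minimal usage sketch of the layer above. The toy sizes (`num_cls=10`, `dim=10`) are illustrative assumptions, and the import path relies on the `from .ebv import *` export added in this PR; in `resnet50_ebv` itself, backbone features are reduced to `dim` before reaching this head.

```python
import mindspore as ms
from mindspore import ops

from mindcv.models.layers import EBV  # exported via this PR's __init__ changes

# Toy configuration for illustration only; the recipe uses num_cls=1000.
head = EBV(num_cls=10, dim=10, steps=1000)
features = ops.standard_normal((4, 10))  # stand-in for projected backbone features
logits = head(features)                  # shape (4, 10): cosine similarity / tau

# Sanity check: off-diagonal |cosine| between the frozen basis vectors
# should not exceed `thre` (0.002 by default) by more than the tolerance.
cos = ops.matmul(head.ebv, head.ebv.T).abs()
off_diag_max = (cos - ops.eye(10, 10, ms.float32)).max()
print(logits.shape, off_diag_max)
```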