Skip to content

Commit

Permalink
Add FP16 Support.
Browse files Browse the repository at this point in the history
Merge PartialFC pytorch into arcface_torch.
Pytorch1.6+ is all you need
  • Loading branch information
anxiangsir committed Mar 11, 2021
1 parent 43ec930 commit d873824
Show file tree
Hide file tree
Showing 30 changed files with 153 additions and 2,987 deletions.
25 changes: 23 additions & 2 deletions recognition/arcface_torch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,32 @@ More details see [eval.md](docs/eval.md) in docs.
| MS1MV3-Arcface | r18 | 92.08 | 94.68 |97.65 |97.63 |99.73|
| MS1MV3-Arcface | r34 | | | | | |
| MS1MV3-Arcface | r50 | 94.79 | 96.43 |98.28 |98.89 |99.85|
| MS1MV3-Arcface | r50-amp | 94.72 | 96.41 |98.30 |99.06 |99.85|
| MS1MV3-Arcface | r100 | 95.22 | 96.87 |98.45 |99.19 |99.85|

### Glint360k
| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) |agedb30|cfp_fp|lfw |
| :---: | :--- | :--- | :--- |:--- |:--- |:--- |
| Glint360k-Cosface | r100 | - | - |- |- |- |
| Glint360k-Cosface | r100 | 96.19 | 97.39 |98.52 |99.26 |99.83|

For more details, see [modelzoo.md](docs/modelzoo.md) in docs.
For more details, see [modelzoo.md](docs/modelzoo.md) in docs.



## Citation
```
@inproceedings{deng2019arcface,
title={Arcface: Additive angular margin loss for deep face recognition},
author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={4690--4699},
year={2019}
}
@inproceedings{an2020partical_fc,
title={Partial FC: Training 10 Million Identities on a Single Machine},
author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
Zhang, Debing and Fu, Ying},
booktitle={arXiv preprint arXiv:2010.05222},
year={2020}
}
```
53 changes: 18 additions & 35 deletions recognition/arcface_torch/backbones/iresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,50 +35,36 @@ def __init__(self, inplanes, planes, stride=1, downsample=None,
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

self.bn1 = nn.BatchNorm2d(
inplanes,
eps=1e-05,
)
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
self.conv1 = conv3x3(inplanes, planes)
self.bn2 = nn.BatchNorm2d(
planes,
eps=1e-05,
)
self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
self.prelu = nn.PReLU(planes)
self.conv2 = conv3x3(planes, planes, stride)
self.bn3 = nn.BatchNorm2d(
planes,
eps=1e-05,
)
self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
self.downsample = downsample
self.stride = stride

def forward(self, x):
    """Apply the residual block to ``x`` and return the result.

    The main path feeds ``x`` through ``bn1 -> conv1 -> bn2 -> prelu ->
    conv2 -> bn3`` in that order. The shortcut is ``x`` itself, or
    ``self.downsample(x)`` when a downsample module is configured; it is
    added in place to the main-path output.
    """
    main_path = (self.bn1, self.conv1, self.bn2, self.prelu, self.conv2, self.bn3)
    out = x
    for layer in main_path:
        out = layer(out)
    # The shortcut is computed after the main path, as in the original ordering.
    shortcut = x if self.downsample is None else self.downsample(x)
    out += shortcut
    return out


class IResNet(nn.Module):
fc_scale = 7 * 7

def __init__(self,
block, layers, dropout=0, num_features=512, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None):
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
super(IResNet, self).__init__()
self.fp16 = fp16
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
Expand Down Expand Up @@ -109,8 +95,7 @@ def __init__(self,
dilate=replace_stride_with_dilation[2])
self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
self.dropout = nn.Dropout(p=dropout, inplace=True)
self.fc = nn.Linear(512 * block.expansion * self.fc_scale,
num_features)
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
nn.init.constant_(self.features.weight, 1.0)
self.features.weight.requires_grad = False
Expand Down Expand Up @@ -154,21 +139,19 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
return nn.Sequential(*layers)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.bn2(x)
x = torch.flatten(x, 1)
x = self.dropout(x)
x = self.fc(x)
with torch.cuda.amp.autocast(self.fp16):
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.bn2(x)
x = torch.flatten(x, 1)
x = self.dropout(x)
x = self.fc(x.float() if self.fp16 else x)
x = self.features(x)

return x


Expand Down
2 changes: 1 addition & 1 deletion recognition/arcface_torch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
config.weight_decay = 5e-4
config.batch_size = 64
config.lr = 0.1 # batch size is 512
config.output = "ms1mv3_r50_arcface"
config.output = "ms1mv3_arcface_r50"

if config.dataset == "emore":
config.rec = "/train_tmp/faces_emore"
Expand Down
2 changes: 1 addition & 1 deletion recognition/arcface_torch/partial_fc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class PartialFC(Module):
"""
Author: {Yang Xiao, Xiang An, XuHan Zhu} in DeepGlint,
Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
Partial FC: Training 10 Million Identities on a Single Machine
See the original paper:
https://arxiv.org/abs/2010.05222
Expand Down
24 changes: 17 additions & 7 deletions recognition/arcface_torch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from partial_fc import PartialFC
from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
from utils.utils_logging import AverageMeter, init_logging
from utils.utils_amp import MaxClipGradScaler

torch.backends.cudnn.benchmark = True

Expand All @@ -42,7 +43,7 @@ def main(args):
sampler=train_sampler, num_workers=0, pin_memory=True, drop_last=True)

dropout = 0.4 if cfg.dataset is "webface" else 0
backbone = eval("backbones.{}".format(args.network))(False, dropout=dropout).to(local_rank)
backbone = eval("backbones.{}".format(args.network))(False, dropout=dropout, fp16=cfg.fp16).to(local_rank)

if args.resume:
try:
Expand Down Expand Up @@ -81,30 +82,39 @@ def main(args):

start_epoch = 0
total_step = int(len(trainset) / cfg.batch_size / world_size * cfg.num_epoch)
if rank is 0:
logging.info("Total Step is: %d" % total_step)
if rank is 0: logging.info("Total Step is: %d" % total_step)

callback_verification = CallBackVerification(2000, rank, cfg.val_targets, cfg.rec)
callback_logging = CallBackLogging(50, rank, total_step, cfg.batch_size, world_size, None)
callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)

loss = AverageMeter()
global_step = 0
grad_scaler = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
for epoch in range(start_epoch, cfg.num_epoch):
train_sampler.set_epoch(epoch)
for step, (img, label) in enumerate(train_loader):
global_step += 1
features = F.normalize(backbone(img))
x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
features.backward(x_grad)
clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
opt_backbone.step()

if cfg.fp16:
features.backward(grad_scaler.scale(x_grad))
grad_scaler.unscale_(opt_backbone)
clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
grad_scaler.step(opt_backbone)
grad_scaler.update()
else:
features.backward(x_grad)
clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
opt_backbone.step()

opt_pfc.step()
module_partial_fc.update()
opt_backbone.zero_grad()
opt_pfc.zero_grad()
loss.update(loss_v, 1)
callback_logging(global_step, loss, epoch)
callback_logging(global_step, loss, epoch, cfg.fp16, grad_scaler)
callback_verification(global_step, backbone)
callback_checkpoint(global_step, backbone, module_partial_fc)
scheduler_backbone.step()
Expand Down
81 changes: 81 additions & 0 deletions recognition/arcface_torch/utils/utils_amp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import collections.abc
from typing import Dict, List

import torch
from torch.cuda.amp import GradScaler

try:
    # Only older PyTorch releases expose container_abcs from the private torch._six module.
    from torch._six import container_abcs
except ImportError:
    container_abcs = collections.abc


class _MultiDeviceReplicator(object):
"""
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
"""

def __init__(self, master_tensor: torch.Tensor) -> None:
assert master_tensor.is_cuda
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

def get(self, device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval
return retval


class MaxClipGradScaler(GradScaler):
    """
    ``GradScaler`` whose scale factor is kept from growing past ``max_scale``.

    Every :meth:`scale` call first runs :meth:`scale_clip`, which compares the
    current scale with ``max_scale``: below the cap the growth factor is set
    to 2; at the cap it is set to 1; above the cap the scale is reset to
    ``max_scale`` and the growth factor is set to 1.
    """

    def __init__(self, init_scale, max_scale: float, growth_interval=100):
        """
        Arguments:
            init_scale: Initial scale factor, forwarded to ``GradScaler``.
            max_scale (float): Upper bound enforced by :meth:`scale_clip`.
            growth_interval: Forwarded to ``GradScaler``.
        """
        GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
        self.max_scale = max_scale

    def scale_clip(self):
        """Clamp the current scale to ``max_scale`` and set the matching growth factor."""
        # Read the scale once; the PyTorch docs warn that get_scale() incurs a CPU-GPU sync.
        current = self.get_scale()
        if current == self.max_scale:
            self.set_growth_factor(1)
        elif current < self.max_scale:
            self.set_growth_factor(2)
        elif current > self.max_scale:
            if self._scale is None:
                # The scale tensor is created lazily on the first scale() call, so
                # there is nothing to fill yet. Cap the value it will be created
                # from instead of dereferencing None.
                self._init_scale = self.max_scale
            else:
                self._scale.fill_(self.max_scale)
            self.set_growth_factor(1)

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.
        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified. The scale is clipped to ``max_scale`` (see :meth:`scale_clip`) before it is applied.
        Arguments:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        Raises:
            ValueError: If ``outputs`` (or a nested element) is neither a Tensor nor an iterable.
        """
        if not self._enabled:
            return outputs
        self.scale_clip()
        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, collections.abc.Iterable):
                iterable = map(apply_scale, val)
                # Lists and tuples are rebuilt as their own type; any other
                # iterable is returned as a lazy map object.
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
        return apply_scale(outputs)
16 changes: 10 additions & 6 deletions recognition/arcface_torch/utils/utils_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,13 @@ def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=No
self.init = False
self.tic = 0

def __call__(self, global_step, loss: AverageMeter, epoch: int):
def __call__(self, global_step, loss: AverageMeter, epoch: int, fp16: bool, grad_scaler: torch.cuda.amp.GradScaler):
if self.rank is 0 and global_step > 0 and global_step % self.frequent == 0:
if self.init:
try:
speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
speed_total = speed * self.world_size
except ZeroDivisionError:
speed = float('inf')
speed_total = float('inf')

time_now = (time.time() - self.time_start) / 3600
Expand All @@ -78,10 +77,15 @@ def __call__(self, global_step, loss: AverageMeter, epoch: int):
if self.writer is not None:
self.writer.add_scalar('time_for_end', time_for_end, global_step)
self.writer.add_scalar('loss', loss.avg, global_step)

msg = "Speed %.2f samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %1.f hours" % (
speed_total, loss.avg, epoch, global_step, time_for_end
)
if fp16:
msg = "Speed %.2f samples/sec Loss %.4f Epoch: %d Global Step: %d "\
"Fp16 Grad Scale: %2.f Required: %1.f hours" % (
speed_total, loss.avg, epoch, global_step, grad_scaler.get_scale(), time_for_end
)
else:
msg = "Speed %.2f samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %1.f hours" % (
speed_total, loss.avg, epoch, global_step, time_for_end
)
logging.info(msg)
loss.reset()
self.tic = time.time()
Expand Down
32 changes: 0 additions & 32 deletions recognition/partial_fc/pytorch/IJB/IJBC_img2array.py

This file was deleted.

Loading

0 comments on commit d873824

Please sign in to comment.