[loss] add sphereface2 loss function #173

Merged: 2 commits, Jun 30, 2023
README.md: 2 additions & 0 deletions
@@ -35,6 +35,7 @@ pip3 install wespeakerruntime
```

## 🔥 News
* 2023.06.30: Support the [SphereFace2](https://ieeexplore.ieee.org/abstract/document/10094954) loss function, with better performance and noise robustness in comparison with ArcMargin Softmax, see [#173](https://github.com/wenet-e2e/wespeaker/pull/173).

* 2023.04.27: Support the [CAM++](https://arxiv.org/abs/2303.00332) model, with better performance and lower single-thread inference RTF in comparison with the ResNet34 model, see [#153](https://github.com/wenet-e2e/wespeaker/pull/153).

@@ -75,6 +76,7 @@ pip3 install wespeakerruntime
- [x] [Add_Margin (AM-Softmax)](https://arxiv.org/pdf/1801.05599.pdf)
- [x] [Arc_Margin (AAM-Softmax)](https://arxiv.org/pdf/1801.07698v1.pdf)
- [x] [Arc_Margin+Inter-topk+Sub-center](https://arxiv.org/pdf/2110.05042.pdf)
- [x] [SphereFace2](https://ieeexplore.ieee.org/abstract/document/10094954)
* Scoring
- [x] Cosine
- [x] PLDA
examples/voxceleb/v2/README.md: 2 additions & 1 deletion
@@ -13,6 +13,7 @@
| ResNet34-TSTP-emb256 | 6.63M | × | 0.941 | 1.114 | 2.026 |
| | | √ | 0.899 | 1.064 | 1.856 |

* 🔥 UPDATE 2023.06.30: We support the SphereFace2 loss function and obtain better and more robust performance, see [#173](https://github.com/wenet-e2e/wespeaker/pull/173).

* 🔥 UPDATE 2022.07.19: We apply the same setups as the winning system of CNSRC 2022 (see [cnceleb](https://github.com/wenet-e2e/wespeaker/tree/master/examples/cnceleb/v2) recipe for details), and obtain significant performance improvement compared with our previous implementation.
* LR scheduler warmup from 0
@@ -50,7 +51,7 @@


* 🔥 UPDATE 2022.11.30: We support the arc_margin_intertopk_subcenter loss function and Multi-query Multi-head Attentive Statistics Pooling, and obtain better performance, especially on the hard trials of [VoxSRC](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/competition2021.html).
- * See [#103](https://github.com/wenet-e2e/wespeaker/pull/103).
+ * See [#115](https://github.com/wenet-e2e/wespeaker/pull/115).


## PLDA results
examples/voxceleb/v2/conf/resnet.yaml: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ model_args:
pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
two_emb_layer: False
projection_args:
project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
project_type: "arc_margin" # add_margin, arc_margin, sphere, sphereface2, softmax, arc_margin_intertopk_subcenter
scale: 32.0
easy_margin: False
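Not part of the diff: a minimal sketch of how the keys above map onto the new `sphereface2` branch of `get_projection()` in `wespeaker/models/projections.py` (shown further down). The numeric values here are placeholders, not recipe defaults.

```python
# Hypothetical config dict: only the keys read by get_projection() for
# project_type == 'sphereface2' are included; the numbers are placeholders.
from wespeaker.models.projections import get_projection

conf = {
    'project_type': 'sphereface2',
    'embed_dim': 256,       # speaker-embedding dimension of the backbone
    'num_class': 5994,      # number of training speakers
    'scale': 32.0,          # same scale as in this yaml
    't': 3,                 # optional, defaults to 3
    'lanbuda': 0.7,         # optional, defaults to 0.7
    'margin_type': 'C',     # optional, 'C' (cosface-like) or 'A' (arcface-like)
}
head = get_projection(conf)  # returns a SphereFace2 head built with margin=0.0
```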

wespeaker/models/projections.py: 112 additions & 0 deletions
@@ -1,6 +1,7 @@
# Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)
# 2021 Zhengyang Chen (chenzhengyang117@gmail.com)
# 2022 Hongji Wang (jijijiang77@gmail.com)
# 2023 Bing Han (hanbing97@sjtu.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -48,12 +49,123 @@ def get_projection(conf):
projection = SphereProduct(conf['embed_dim'],
conf['num_class'],
margin=4)
elif conf['project_type'] == 'sphereface2':
projection = SphereFace2(conf['embed_dim'],
conf['num_class'],
scale=conf['scale'],
margin=0.0,
t=conf.get('t', 3),
lanbuda=conf.get('lanbuda', 0.7),
margin_type=conf.get('margin_type', 'C'))
else:
projection = Linear(conf['embed_dim'], conf['num_class'])

return projection
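`get_projection()` builds the head with `margin=0.0`; the `update()` method defined below suggests the margin is meant to be ramped up during training toward the recommended final values in the docstring. A hedged sketch of such a ramp (the helper name and schedule are illustrative assumptions, not part of the PR):

```python
# Hypothetical helper: linearly raise the SphereFace2 margin from 0.0 to the
# recommended 0.2 ('C' type) over the first warmup epochs, then hold it.
# Only head.update() comes from the code in this diff.
def ramp_margin(head, epoch, final_margin=0.2, warmup_epochs=15):
    margin = final_margin * min(1.0, epoch / warmup_epochs)
    head.update(margin=margin)
```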


class SphereFace2(nn.Module):
r"""Implement of sphereface2 for speaker verification:
Reference:
[1] Exploring Binary Classification Loss for Speaker Verification
https://ieeexplore.ieee.org/abstract/document/10094954
[2] Sphereface2: Binary classification is all you need
for deep face recognition
https://arxiv.org/pdf/2108.01513
Args:
in_features: size of each input sample
out_features: size of each output sample
scale: norm of input feature
margin: margin
lanbuda: weight of positive and negative pairs
t: parameter for adjust score distribution
margin_type: A:cos(theta+margin) or C:cos(theta)-margin
Recommend margin:
training: 0.2 for C and 0.15 for A
LMF: 0.3 for C and 0.25 for A
"""

def __init__(self,
in_features,
out_features,
scale=32.0,
margin=0.2,
lanbuda=0.7,
t=3,
margin_type='C'):
super(SphereFace2, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.scale = scale
self.weight = nn.Parameter(torch.FloatTensor(out_features,
in_features))
nn.init.xavier_uniform_(self.weight)
self.bias = nn.Parameter(torch.zeros(1, 1))
self.t = t
self.lanbuda = lanbuda
self.margin_type = margin_type

########
self.margin = margin
self.cos_m = math.cos(margin)
self.sin_m = math.sin(margin)
self.th = math.cos(math.pi - margin)
self.mm = math.sin(math.pi - margin)
self.mmm = 1.0 + math.cos(math.pi - margin)
########

def update(self, margin=0.2):
self.margin = margin
self.cos_m = math.cos(margin)
self.sin_m = math.sin(margin)
self.th = math.cos(math.pi - margin)
self.mm = math.sin(math.pi - margin)
self.mmm = 1.0 + math.cos(math.pi - margin)

def fun_g(self, z, t: int):
gz = 2 * torch.pow((z + 1) / 2, t) - 1
return gz

def forward(self, input, label):
# compute similarity
cos = F.linear(F.normalize(input), F.normalize(self.weight))

if self.margin_type == 'A': # arcface type
sin = torch.sqrt(1.0 - torch.pow(cos, 2))
cos_m_theta_p = self.scale * self.fun_g(
torch.where(cos > self.th, cos * self.cos_m - sin * self.sin_m,
cos - self.mmm), self.t) + self.bias[0][0]
cos_m_theta_n = self.scale * self.fun_g(
cos * self.cos_m + sin * self.sin_m, self.t) + self.bias[0][0]
cos_p_theta = self.lanbuda * torch.log(
1 + torch.exp(-1.0 * cos_m_theta_p))
cos_n_theta = (
1 - self.lanbuda) * torch.log(1 + torch.exp(cos_m_theta_n))
else: # cosface type
cos_m_theta_p = self.scale * (self.fun_g(cos, self.t) -
self.margin) + self.bias[0][0]
cos_m_theta_n = self.scale * (self.fun_g(cos, self.t) +
self.margin) + self.bias[0][0]
cos_p_theta = self.lanbuda * torch.log(
1 + torch.exp(-1.0 * cos_m_theta_p))
cos_n_theta = (
1 - self.lanbuda) * torch.log(1 + torch.exp(cos_m_theta_n))

target_mask = input.new_zeros(cos.size())
target_mask.scatter_(1, label.view(-1, 1).long(), 1.0)
nontarget_mask = 1 - target_mask
cos1 = (cos - self.margin) * target_mask + cos * nontarget_mask
output = self.scale * cos1 # for computing the accuracy
loss = (target_mask * cos_p_theta +
nontarget_mask * cos_n_theta).sum(1).mean()
return output, loss

def extra_repr(self):
return '''in_features={}, out_features={}, scale={}, lanbuda={},
margin={}, t={}, margin_type={}'''.format(
self.in_features, self.out_features, self.scale, self.lanbuda,
self.margin, self.t, self.margin_type)
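Reading the cosface-style (`margin_type='C'`) branch of `forward()` back into a formula (a paraphrase of the code above, not a quotation from the paper), with $s$ the scale, $m$ the margin, $b$ the learnable bias, $\lambda$ the code's `lanbuda`, and $g_t$ the code's `fun_g`:

$$
\mathcal{L} \;=\; \frac{1}{N}\sum_{i=1}^{N}\Big[\,\lambda\,\log\!\big(1+e^{-\left(s\,(g_t(\cos\theta_{i,y_i})-m)+b\right)}\big)\;+\;(1-\lambda)\sum_{j\neq y_i}\log\!\big(1+e^{\,s\,(g_t(\cos\theta_{i,j})+m)+b}\big)\Big],
\qquad g_t(z)=2\Big(\tfrac{z+1}{2}\Big)^{t}-1 .
$$

Here $g_t$ is a monotone remapping of $[-1,1]$ onto itself that sharpens the score distribution as $t$ grows, and the arcface-style `'A'` branch essentially replaces $g_t(\cos\theta)\mp m$ with $g_t(\cos(\theta+m))$ on the target term and $g_t(\cos(\theta-m))$ on the non-target terms. The returned `output` ($s\,(\cos\theta-m)$ for the target class, $s\,\cos\theta$ otherwise) is only used for accuracy bookkeeping.

A quick smoke test of the head outside the training loop (shapes and speaker counts are made up; only `SphereFace2` itself comes from this file):

```python
# Stand-alone sketch: unlike the other projection heads in this file,
# SphereFace2 computes its loss inside forward() and returns (logits, loss).
import torch
from wespeaker.models.projections import SphereFace2

head = SphereFace2(in_features=256, out_features=1000,
                   scale=32.0, margin=0.2, margin_type='C')
embeds = torch.randn(4, 256)             # dummy (batch, embed_dim) embeddings
labels = torch.randint(0, 1000, (4,))    # dummy speaker ids
logits, loss = head(embeds, labels)
assert logits.shape == (4, 1000) and loss.dim() == 0
loss.backward()                          # gradients flow to weight and bias
```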


class ArcMarginProduct(nn.Module):
r"""Implement of large margin arc distance: :
Args:
wespeaker/utils/executor.py: 4 additions & 2 deletions
@@ -61,8 +61,10 @@ def run_epoch(dataloader,
outputs = model(features) # (embed_a,embed_b) in most cases
embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
outputs = model.module.projection(embeds, targets)

- loss = criterion(outputs, targets)
+ if isinstance(outputs, tuple):
+     outputs, loss = outputs
+ else:
+     loss = criterion(outputs, targets)

# loss, acc
loss_meter.add(loss.item())
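The tuple check above is what lets the new head drop in without touching the recipes: `sphereface2` computes its binary-classification loss inside `forward()` and returns `(logits, loss)`, while the existing softmax/margin heads still return bare logits and go through the external `criterion` as before.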
wespeaker/utils/executor_deprecated.py: 4 additions & 2 deletions
@@ -52,8 +52,10 @@ def run_epoch(dataloader,
outputs = model(features) # (embed_a,embed_b) in most cases
embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
outputs = model.module.projection(embeds, targets)

- loss = criterion(outputs, targets)
+ if isinstance(outputs, tuple):
+     outputs, loss = outputs
+ else:
+     loss = criterion(outputs, targets)
# loss, acc
loss_meter.add(loss.item())
acc_meter.add(outputs.cpu().detach().numpy(),