diff --git a/README.md b/README.md
index a83aab2b..3b2e475e 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,7 @@ pip3 install wespeakerruntime
 ```

 ## 🔥 News
+* 2023.06.30: Support the [SphereFace2](https://ieeexplore.ieee.org/abstract/document/10094954) loss function, with better performance and noise robustness in comparison with ArcMargin Softmax, see [#173](https://github.com/wenet-e2e/wespeaker/pull/173).
 * 2023.04.27: Support the [CAM++](https://arxiv.org/abs/2303.00332) model, with better performance and single-thread inference rtf in comparison with the ResNet34 model, see [#153](https://github.com/wenet-e2e/wespeaker/pull/153).
@@ -75,6 +76,7 @@ pip3 install wespeakerruntime
   - [x] [Add_Margin (AM-Softmax)](https://arxiv.org/pdf/1801.05599.pdf)
   - [x] [Arc_Margin (AAM-Softmax)](https://arxiv.org/pdf/1801.07698v1.pdf)
   - [x] [Arc_Margin+Inter-topk+Sub-center](https://arxiv.org/pdf/2110.05042.pdf)
+  - [x] [SphereFace2](https://ieeexplore.ieee.org/abstract/document/10094954)
 * Scoring
   - [x] Cosine
   - [x] PLDA
diff --git a/examples/voxceleb/v2/README.md b/examples/voxceleb/v2/README.md
index 79b98367..49b78a01 100644
--- a/examples/voxceleb/v2/README.md
+++ b/examples/voxceleb/v2/README.md
@@ -13,6 +13,7 @@
 | ResNet34-TSTP-emb256 | 6.63M | × | 0.941 | 1.114 | 2.026 |
 |                      |       | √ | 0.899 | 1.064 | 1.856 |

+* 🔥 UPDATE 2023.6.30: We support the SphereFace2 loss function and obtain better and more robust performance, see [#173](https://github.com/wenet-e2e/wespeaker/pull/173).
 * 🔥 UPDATE 2022.07.19: We apply the same setups as the winning system of CNSRC 2022 (see [cnceleb](https://github.com/wenet-e2e/wespeaker/tree/master/examples/cnceleb/v2) recipe for details), and obtain significant performance improvement compared with our previous implementation.
   * LR scheduler warmup from 0
@@ -50,7 +51,7 @@
 * 🔥 UPDATE 2022.11.30: We support arc_margin_intertopk_subcenter loss function and Multi-query Multi-head Attentive Statistics Pooling, and obtain better performance especially on hard trials [VoxSRC](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/competition2021.html).
-  * See [#103](https://github.com/wenet-e2e/wespeaker/pull/103).
+  * See [#115](https://github.com/wenet-e2e/wespeaker/pull/115).

 ## PLDA results
diff --git a/examples/voxceleb/v2/conf/resnet.yaml b/examples/voxceleb/v2/conf/resnet.yaml
index 0b727c2f..d0e774e3 100644
--- a/examples/voxceleb/v2/conf/resnet.yaml
+++ b/examples/voxceleb/v2/conf/resnet.yaml
@@ -46,7 +46,7 @@ model_args:
   pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
   two_emb_layer: False
 projection_args:
-  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, sphereface2, softmax, arc_margin_intertopk_subcenter
   scale: 32.0
   easy_margin: False
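
For reference, the new head is selected through the same `projection_args` mechanism as the existing margin losses. Below is a minimal sketch of a conf dict that `get_projection` (see the projections.py hunk that follows) would accept; the key names mirror the new `elif` branch, while the concrete values (`embed_dim`, `num_class`, etc.) are illustrative only. `t`, `lanbuda`, and `margin_type` are optional and fall back to the defaults shown via `conf.get()`.

```python
# Hypothetical conf dict showing how get_projection() would dispatch to
# the new SphereFace2 head; values are examples, not recipe defaults.
conf = {
    'project_type': 'sphereface2',  # instead of "arc_margin"
    'embed_dim': 256,               # e.g. ResNet34-TSTP-emb256
    'num_class': 5994,              # e.g. number of VoxCeleb2 dev speakers
    'scale': 32.0,
    # optional keys, read via conf.get() with these defaults:
    't': 3,                         # adjusts the score distribution
    'lanbuda': 0.7,                 # positive/negative pair weight
    'margin_type': 'C',             # 'C' = cosface-style, 'A' = arcface-style
}

# from wespeaker.models.projections import get_projection
# projection = get_projection(conf)  # -> SphereFace2(...)
```

The same keys would go under `projection_args` in the recipe yaml when switching `project_type` to `"sphereface2"`.
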
diff --git a/wespeaker/models/projections.py b/wespeaker/models/projections.py
index 80aeaa2e..5045ea8b 100644
--- a/wespeaker/models/projections.py
+++ b/wespeaker/models/projections.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)
 #               2021 Zhengyang Chen (chenzhengyang117@gmail.com)
 #               2022 Hongji Wang (jijijiang77@gmail.com)
+#               2023 Bing Han (hanbing97@sjtu.edu.cn)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -48,12 +49,123 @@ def get_projection(conf):
         projection = SphereProduct(conf['embed_dim'],
                                    conf['num_class'],
                                    margin=4)
+    elif conf['project_type'] == 'sphereface2':
+        projection = SphereFace2(conf['embed_dim'],
+                                 conf['num_class'],
+                                 scale=conf['scale'],
+                                 margin=0.0,
+                                 t=conf.get('t', 3),
+                                 lanbuda=conf.get('lanbuda', 0.7),
+                                 margin_type=conf.get('margin_type', 'C'))
     else:
         projection = Linear(conf['embed_dim'], conf['num_class'])

     return projection

+
+class SphereFace2(nn.Module):
+    r"""Implementation of SphereFace2 for speaker verification:
+        Reference:
+            [1] Exploring Binary Classification Loss for Speaker Verification
+            https://ieeexplore.ieee.org/abstract/document/10094954
+            [2] SphereFace2: Binary Classification is All You Need
+            for Deep Face Recognition
+            https://arxiv.org/pdf/2108.01513
+        Args:
+            in_features: size of each input sample
+            out_features: size of each output sample
+            scale: norm of the input feature
+            margin: margin applied to the cosine score
+            lanbuda: weight balancing the positive and negative pair terms
+            t: parameter to adjust the score distribution
+            margin_type: A:cos(theta+margin) or C:cos(theta)-margin
+        Recommended margins:
+            training: 0.2 for C and 0.15 for A
+            LMF (large margin fine-tuning): 0.3 for C and 0.25 for A
+    """
+
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 scale=32.0,
+                 margin=0.2,
+                 lanbuda=0.7,
+                 t=3,
+                 margin_type='C'):
+        super(SphereFace2, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.scale = scale
+        self.weight = nn.Parameter(torch.FloatTensor(out_features,
+                                                     in_features))
+        nn.init.xavier_uniform_(self.weight)
+        self.bias = nn.Parameter(torch.zeros(1, 1))
+        self.t = t
+        self.lanbuda = lanbuda
+        self.margin_type = margin_type
+
+        # precomputed margin constants,
+        # refreshed by update() when the margin is scheduled during training
+        self.margin = margin
+        self.cos_m = math.cos(margin)
+        self.sin_m = math.sin(margin)
+        self.th = math.cos(math.pi - margin)
+        self.mm = math.sin(math.pi - margin)
+        self.mmm = 1.0 + math.cos(math.pi - margin)
+
+    def update(self, margin=0.2):
+        self.margin = margin
+        self.cos_m = math.cos(margin)
+        self.sin_m = math.sin(margin)
+        self.th = math.cos(math.pi - margin)
+        self.mm = math.sin(math.pi - margin)
+        self.mmm = 1.0 + math.cos(math.pi - margin)
+
+    def fun_g(self, z, t: int):  # score mapping g(z) = 2*((z+1)/2)^t - 1
+        gz = 2 * torch.pow((z + 1) / 2, t) - 1
+        return gz
+
+    def forward(self, input, label):
+        # compute similarity
+        cos = F.linear(F.normalize(input), F.normalize(self.weight))
+
+        if self.margin_type == 'A':  # arcface type
+            sin = torch.sqrt(1.0 - torch.pow(cos, 2))
+            cos_m_theta_p = self.scale * self.fun_g(
+                torch.where(cos > self.th, cos * self.cos_m - sin * self.sin_m,
+                            cos - self.mmm), self.t) + self.bias[0][0]
+            cos_m_theta_n = self.scale * self.fun_g(
+                cos * self.cos_m + sin * self.sin_m, self.t) + self.bias[0][0]
+            cos_p_theta = self.lanbuda * torch.log(
+                1 + torch.exp(-1.0 * cos_m_theta_p))
+            cos_n_theta = (
+                1 - self.lanbuda) * torch.log(1 + torch.exp(cos_m_theta_n))
+        else:  # cosface type
+            cos_m_theta_p = self.scale * (self.fun_g(cos, self.t) -
+                                          self.margin) + self.bias[0][0]
+            cos_m_theta_n = self.scale * (self.fun_g(cos, self.t) +
+                                          self.margin) + self.bias[0][0]
+            cos_p_theta = self.lanbuda * torch.log(
+                1 + torch.exp(-1.0 * cos_m_theta_p))
+            cos_n_theta = (
+                1 - self.lanbuda) * torch.log(1 + torch.exp(cos_m_theta_n))
+
+        target_mask = input.new_zeros(cos.size())
+        target_mask.scatter_(1, label.view(-1, 1).long(), 1.0)
+        nontarget_mask = 1 - target_mask
+        cos1 = (cos - self.margin) * target_mask + cos * nontarget_mask
+        output = self.scale * cos1  # for computing the accuracy
+        loss = (target_mask * cos_p_theta +
+                nontarget_mask * cos_n_theta).sum(1).mean()
+        return output, loss
+
+    def extra_repr(self):
+        return '''in_features={}, out_features={}, scale={}, lanbuda={},
+            margin={}, t={}, margin_type={}'''.format(
+            self.in_features, self.out_features, self.scale, self.lanbuda,
+            self.margin, self.t, self.margin_type)
+
+
 class ArcMarginProduct(nn.Module):
     r"""Implement of large margin arc distance: :
         Args:
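
A minimal smoke test for the head added above, assuming wespeaker is on `PYTHONPATH` so the new class is importable from `wespeaker.models.projections`; the dimensions and batch values are arbitrary examples.

```python
import torch

from wespeaker.models.projections import SphereFace2

# cosface-style ('C') variant with the recommended training margin
head = SphereFace2(in_features=256, out_features=1000,
                   scale=32.0, margin=0.2, margin_type='C')

embeds = torch.randn(4, 256)           # a batch of 4 speaker embeddings
labels = torch.randint(0, 1000, (4,))  # ground-truth speaker ids

# Unlike the other projection heads, forward() returns (logits, loss)
# directly, so no external criterion is needed.
logits, loss = head(embeds, labels)
print(logits.shape, loss.item())       # torch.Size([4, 1000]), a scalar

# For large margin fine-tuning, the margin can be raised in place:
head.update(margin=0.3)
```
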
diff --git a/wespeaker/utils/executor.py b/wespeaker/utils/executor.py
index 6ef23e5f..61eb8137 100644
--- a/wespeaker/utils/executor.py
+++ b/wespeaker/utils/executor.py
@@ -61,8 +61,10 @@ def run_epoch(dataloader,
             outputs = model(features)  # (embed_a,embed_b) in most cases
             embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
             outputs = model.module.projection(embeds, targets)
-
-            loss = criterion(outputs, targets)
+            if isinstance(outputs, tuple):
+                outputs, loss = outputs
+            else:
+                loss = criterion(outputs, targets)

             # loss, acc
             loss_meter.add(loss.item())
diff --git a/wespeaker/utils/executor_deprecated.py b/wespeaker/utils/executor_deprecated.py
index 5641f0e1..41770160 100644
--- a/wespeaker/utils/executor_deprecated.py
+++ b/wespeaker/utils/executor_deprecated.py
@@ -52,8 +52,10 @@ def run_epoch(dataloader,
             outputs = model(features)  # (embed_a,embed_b) in most cases
             embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
             outputs = model.module.projection(embeds, targets)
-
-            loss = criterion(outputs, targets)
+            if isinstance(outputs, tuple):
+                outputs, loss = outputs
+            else:
+                loss = criterion(outputs, targets)

             # loss, acc
             loss_meter.add(loss.item())
             acc_meter.add(outputs.cpu().detach().numpy(),
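
The executor hunks above hinge on one convention: heads that compute their own loss (currently only SphereFace2) return a `(logits, loss)` tuple, while the margin-softmax heads return bare logits that still go through the external criterion. A hypothetical standalone helper distilling that dispatch pattern, not part of the patch itself:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # the executor's usual criterion

def resolve_loss(outputs, targets):
    """Mirror of the executor.py change: unpack a (logits, loss) tuple
    from a self-contained head, or fall back to the external criterion."""
    if isinstance(outputs, tuple):
        logits, loss = outputs      # projection already computed the loss
    else:
        logits = outputs
        loss = criterion(logits, targets)
    return logits, loss
```

Keeping the dispatch in the training loop (rather than wrapping every head) leaves the existing projections untouched and lets `loss_meter` and `acc_meter` consume the same `outputs`/`loss` pair in both cases.
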