Commit 46f6465

[TTS] Add period discriminator and feature matching loss to codec recipe (#7884)

* [TTS] Add period discriminator and feature matching loss to codec recipe

Signed-off-by: Ryan <rlangman@nvidia.com>

* [TTS] Update docs for period discriminator

Signed-off-by: Ryan <rlangman@nvidia.com>

---------

Signed-off-by: Ryan <rlangman@nvidia.com>
rlangman authored Jan 19, 2024
1 parent e329575 commit 46f6465
Showing 3 changed files with 216 additions and 3 deletions.
51 changes: 51 additions & 0 deletions nemo/collections/tts/losses/audio_codec_loss.py
@@ -331,7 +331,58 @@ def forward(self, audio_real, audio_gen, audio_len):
return loss


class FeatureMatchingLoss(Loss):
"""
Standard feature matching loss measuring the difference in the internal discriminator layer outputs
(usually leaky ReLU activations) between real and generated audio, scaled down by the total number of
discriminators and layers.
"""

def __init__(self):
super(FeatureMatchingLoss, self).__init__()

@property
def input_types(self):
return {
"fmaps_real": [[NeuralType(elements_type=VoidType())]],
"fmaps_gen": [[NeuralType(elements_type=VoidType())]],
}

@property
def output_types(self):
return {
"loss": NeuralType(elements_type=LossType()),
}

@typecheck()
def forward(self, fmaps_real, fmaps_gen):
loss = 0.0
for fmap_real, fmap_gen in zip(fmaps_real, fmaps_gen):
# [B, ..., time]
for feat_real, feat_gen in zip(fmap_real, fmap_gen):
# [B, ...]
diff = torch.abs(feat_real - feat_gen)
feat_loss = torch.mean(diff) / len(fmap_real)
loss += feat_loss

loss /= len(fmaps_real)

return loss
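For reference, the quantity computed by this forward pass can be written as follows (a sketch in standard notation, not taken from the patch itself: K is the number of discriminators, L_k the number of layers in discriminator k, and D_k^(l) its l-th feature map):

    \mathcal{L}_{\mathrm{FM}} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{L_k} \sum_{l=1}^{L_k} \operatorname{mean}\left( \left| D_k^{(l)}(x) - D_k^{(l)}(\hat{x}) \right| \right)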


class RelativeFeatureMatchingLoss(Loss):
"""
Relative feature matching loss as described in https://arxiv.org/pdf/2210.13438.pdf.
This is similar to the standard feature matching loss, but it scales the loss by the absolute value of
each feature averaged across time. This may differ slightly from the paper, which says the
"mean is computed over all dimensions" and so could imply averaging across both time and features.
Args:
div_guard: Value to add when dividing by mean to avoid large/NaN values.
"""

def __init__(self, div_guard=1e-3):
super(RelativeFeatureMatchingLoss, self).__init__()
self.div_guard = div_guard
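The forward pass is truncated in this diff, but based on the docstring and the div_guard argument, the intended quantity is presumably (a sketch only, with \varepsilon = div_guard and the normalizing mean taken over time; the actual implementation may differ in detail):

    \mathcal{L}_{\mathrm{rel}} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{L_k} \sum_{l=1}^{L_k} \operatorname{mean}\left( \frac{\left| D_k^{(l)}(x) - D_k^{(l)}(\hat{x}) \right|}{\operatorname{mean}_t \left| D_k^{(l)}(x) \right| + \varepsilon} \right)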
10 changes: 9 additions & 1 deletion nemo/collections/tts/models/audio_codec.py
@@ -25,6 +25,7 @@
from pytorch_lightning import Trainer

from nemo.collections.tts.losses.audio_codec_loss import (
FeatureMatchingLoss,
MultiResolutionMelLoss,
MultiResolutionSTFTLoss,
RelativeFeatureMatchingLoss,
@@ -130,7 +131,14 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.feature_loss_scale = cfg.get("feature_loss_scale", 1.0)
self.gen_loss_fn = instantiate(cfg.generator_loss)
self.disc_loss_fn = instantiate(cfg.discriminator_loss)
-        self.feature_loss_fn = RelativeFeatureMatchingLoss()
+
+        feature_loss_type = cfg.get("feature_loss_type", "relative")
+        if feature_loss_type == "relative":
+            self.feature_loss_fn = RelativeFeatureMatchingLoss()
+        elif feature_loss_type == "absolute":
+            self.feature_loss_fn = FeatureMatchingLoss()
+        else:
+            raise ValueError(f'Unknown feature loss type {feature_loss_type}.')

# Codebook loss setup
if self.vector_quantizer:
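A minimal usage sketch of the new config switch (the config fragment and the toy feature-map shapes are hypothetical; only the feature_loss_type key and the two loss classes come from this commit):

import torch
from omegaconf import OmegaConf

from nemo.collections.tts.losses.audio_codec_loss import (
    FeatureMatchingLoss,
    RelativeFeatureMatchingLoss,
)

cfg = OmegaConf.create({"feature_loss_type": "absolute"})

# Mirrors the dispatch above; the default remains "relative" when the key is unset.
if cfg.get("feature_loss_type", "relative") == "relative":
    feature_loss_fn = RelativeFeatureMatchingLoss()
else:
    feature_loss_fn = FeatureMatchingLoss()

# Toy nested feature maps: 2 discriminators x 2 layers, each [B, C, T].
fmaps_real = [[torch.randn(4, 32, 100), torch.randn(4, 64, 50)] for _ in range(2)]
fmaps_gen = [[torch.randn(4, 32, 100), torch.randn(4, 64, 50)] for _ in range(2)]
loss = feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen)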
158 changes: 156 additions & 2 deletions nemo/collections/tts/modules/audio_codec_modules.py
@@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from nemo.collections.asr.parts.utils.activations import Snake
from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
from nemo.core.classes.common import typecheck
from nemo.core.classes.module import NeuralModule
-from nemo.core.neural_types.elements import EncodedRepresentation, Index, LengthsType, VoidType
+from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType, VoidType
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging

Expand Down Expand Up @@ -186,6 +187,159 @@ def forward(self, inputs):
return self.conv(inputs)


class PeriodDiscriminator(NeuralModule):
"""
Period discriminator introduced in HiFi-GAN https://arxiv.org/abs/2010.05646 which attempts to
discriminate phase information by looking at equally spaced audio samples.
Args:
period: Spacing between audio sample inputs.
lrelu_slope: Slope to use for the activation. A leaky ReLU with a slope of 0.1 or 0.2 is recommended for
the stability of the feature matching loss.
"""

def __init__(self, period, lrelu_slope=0.1):
super().__init__()
self.period = period
self.activation = nn.LeakyReLU(lrelu_slope)
self.conv_layers = nn.ModuleList(
[
Conv2dNorm(1, 32, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(32, 128, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(128, 512, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(512, 1024, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(1024, 1024, kernel_size=(5, 1), stride=(1, 1)),
]
)
self.conv_post = Conv2dNorm(1024, 1, kernel_size=(3, 1))

@property
def input_types(self):
return {
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
}

@property
def output_types(self):
return {
"score": NeuralType(('B', 'C', 'T_out'), VoidType()),
"fmap": [NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())],
}

@typecheck()
def forward(self, audio):

batch_size, time = audio.shape
out = rearrange(audio, 'B T -> B 1 T')
# Pad audio so that it is divisible by the period
if time % self.period != 0:
n_pad = self.period - (time % self.period)
out = F.pad(out, (0, n_pad), "reflect")
time = time + n_pad
# [batch, 1, (time / period), period]
out = out.view(batch_size, 1, time // self.period, self.period)

fmap = []
for conv in self.conv_layers:
# [batch, filters, (time / period / stride), period]
out = conv(inputs=out)
out = self.activation(out)
fmap.append(out)
# [batch, 1, (time / period / strides), period]
score = self.conv_post(inputs=out)
fmap.append(score)
score = rearrange(score, "B 1 T C -> B C T")

return score, fmap
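To make the pad-and-view step above concrete, a toy shape trace (plain PyTorch, independent of this module's weights):

import torch
import torch.nn.functional as F

period = 3
audio = torch.randn(2, 16000)  # [B, T]
out = audio.unsqueeze(1)       # [B, 1, T]
time = audio.shape[1]
if time % period != 0:
    n_pad = period - (time % period)        # here: 2 samples of reflect padding
    out = F.pad(out, (0, n_pad), "reflect")
    time = time + n_pad
out = out.view(2, 1, time // period, period)
print(out.shape)  # torch.Size([2, 1, 5334, 3])

The (5, 1) kernels with (3, 1) strides then downsample only the time axis, so samples exactly one period apart stay in the same column throughout the stack.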


class MultiPeriodDiscriminator(NeuralModule):
"""
Wrapper class to aggregate results of multiple period discriminators.
The periods are expected to be increasing prime numbers in order to maximize coverage and minimize overlap.
"""

def __init__(self, periods: Iterable[int] = (2, 3, 5, 7, 11), lrelu_slope=0.1):
super().__init__()
self.discriminators = nn.ModuleList(
[PeriodDiscriminator(period=period, lrelu_slope=lrelu_slope) for period in periods]
)

@property
def input_types(self):
return {
"audio_real": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()),
}

@property
def output_types(self):
return {
"scores_real": [NeuralType(('B', 'C', 'T_out'), VoidType())],
"scores_gen": [NeuralType(('B', 'C', 'T_out'), VoidType())],
"fmaps_real": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
"fmaps_gen": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
}

@typecheck()
def forward(self, audio_real, audio_gen):
scores_real = []
scores_gen = []
fmaps_real = []
fmaps_gen = []
for discriminator in self.discriminators:
score_real, fmap_real = discriminator(audio=audio_real)
score_gen, fmap_gen = discriminator(audio=audio_gen)
scores_real.append(score_real)
fmaps_real.append(fmap_real)
scores_gen.append(score_gen)
fmaps_gen.append(fmap_gen)

return scores_real, scores_gen, fmaps_real, fmaps_gen
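A short usage sketch (the batch size and audio length are illustrative):

import torch

from nemo.collections.tts.modules.audio_codec_modules import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator()  # default periods (2, 3, 5, 7, 11)
audio_real = torch.randn(4, 16000)
audio_gen = torch.randn(4, 16000)
scores_real, scores_gen, fmaps_real, fmaps_gen = mpd(audio_real=audio_real, audio_gen=audio_gen)
assert len(scores_real) == 5    # one score per period
assert len(fmaps_real[0]) == 6  # 5 conv layer activations plus the post-conv output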


class Discriminator(NeuralModule):
"""
Wrapper class which takes a list of discriminators and aggregates the results across them.
"""

def __init__(self, discriminators: Iterable[NeuralModule]):
super().__init__()
self.discriminators = nn.ModuleList(discriminators)

@property
def input_types(self):
return {
"audio_real": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_gen": NeuralType(('B', 'T_audio'), AudioSignal()),
}

@property
def output_types(self):
return {
"scores_real": [NeuralType(('B', 'C', 'T_out'), VoidType())],
"scores_gen": [NeuralType(('B', 'C', 'T_out'), VoidType())],
"fmaps_real": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
"fmaps_gen": [[NeuralType(('B', 'D', 'T_layer', 'C'), VoidType())]],
}

@typecheck()
def forward(self, audio_real, audio_gen):
scores_real = []
scores_gen = []
fmaps_real = []
fmaps_gen = []
for discriminator in self.discriminators:
score_real, score_gen, fmap_real, fmap_gen = discriminator(audio_real=audio_real, audio_gen=audio_gen)
scores_real += score_real
fmaps_real += fmap_real
scores_gen += score_gen
fmaps_gen += fmap_gen

return scores_real, scores_gen, fmaps_real, fmaps_gen
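Because the wrapper concatenates with += rather than appending, the per-discriminator score and feature-map lists are flattened into one flat list each. A sketch composing it with the new multi-period discriminator (any other module exposing the same (audio_real, audio_gen) -> (scores, scores, fmaps, fmaps) interface could be added to the list):

import torch

from nemo.collections.tts.modules.audio_codec_modules import Discriminator, MultiPeriodDiscriminator

disc = Discriminator(discriminators=[MultiPeriodDiscriminator()])
audio_real = torch.randn(2, 8000)
audio_gen = torch.randn(2, 8000)
scores_real, scores_gen, fmaps_real, fmaps_gen = disc(audio_real=audio_real, audio_gen=audio_gen)
assert len(scores_real) == 5  # the five period scores arrive in one flat list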


class FiniteScalarQuantizer(NeuralModule):
"""This quantizer is based on the Finite Scalar Quantization (FSQ) method.
It quantizes each element of the input vector independently into a number of levels.
