Skip to content

Commit

Permalink
Merge pull request espnet#4356 from kan-bayashi/fix_mixed_precision_vits
Browse files Browse the repository at this point in the history
Fix `loss = NaN` in VITS when training with mixed precision
  • Loading branch information
kan-bayashi authored May 11, 2022
2 parents beb3360 + ec7e2b0 commit 2dde773
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions espnet2/gan_tts/vits/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

"""VITS module for GAN-TTS task."""

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Any
from typing import Dict
from typing import Optional
Expand Down Expand Up @@ -37,6 +39,14 @@
"hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
}

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True): # NOQA
yield


class VITS(AbsGANTTS):
"""VITS module (generator + discriminator).
Expand Down Expand Up @@ -398,18 +408,19 @@ def _forward_generator(
p = self.discriminator(speech_)

# calculate losses
mel_loss = self.mel_loss(speech_hat_, speech_)
kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
dur_loss = torch.sum(dur_nll.float())
adv_loss = self.generator_adv_loss(p_hat)
feat_match_loss = self.feat_match_loss(p_hat, p)

mel_loss = mel_loss * self.lambda_mel
kl_loss = kl_loss * self.lambda_kl
dur_loss = dur_loss * self.lambda_dur
adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match
loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
with autocast(enabled=False):
mel_loss = self.mel_loss(speech_hat_, speech_)
kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
dur_loss = torch.sum(dur_nll.float())
adv_loss = self.generator_adv_loss(p_hat)
feat_match_loss = self.feat_match_loss(p_hat, p)

mel_loss = mel_loss * self.lambda_mel
kl_loss = kl_loss * self.lambda_kl
dur_loss = dur_loss * self.lambda_dur
adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match
loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss

stats = dict(
generator_loss=loss.item(),
Expand Down Expand Up @@ -504,8 +515,9 @@ def _forward_discrminator(
p = self.discriminator(speech_)

# calculate losses
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss
with autocast(enabled=False):
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss

stats = dict(
discriminator_loss=loss.item(),
Expand Down

0 comments on commit 2dde773

Please sign in to comment.