diff --git a/espnet2/gan_tts/vits/vits.py b/espnet2/gan_tts/vits/vits.py
index 3f906b96374..e08c486d3fe 100644
--- a/espnet2/gan_tts/vits/vits.py
+++ b/espnet2/gan_tts/vits/vits.py
@@ -3,6 +3,8 @@
 
 """VITS module for GAN-TTS task."""
 
+from contextlib import contextmanager
+from distutils.version import LooseVersion
 from typing import Any
 from typing import Dict
 from typing import Optional
@@ -37,6 +39,14 @@
     "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator,  # NOQA
 }
 
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):  # NOQA
+        yield
+
 
 class VITS(AbsGANTTS):
     """VITS module (generator + discriminator).
@@ -398,18 +408,19 @@ def _forward_generator(
             p = self.discriminator(speech_)
 
         # calculate losses
-        mel_loss = self.mel_loss(speech_hat_, speech_)
-        kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
-        dur_loss = torch.sum(dur_nll.float())
-        adv_loss = self.generator_adv_loss(p_hat)
-        feat_match_loss = self.feat_match_loss(p_hat, p)
-
-        mel_loss = mel_loss * self.lambda_mel
-        kl_loss = kl_loss * self.lambda_kl
-        dur_loss = dur_loss * self.lambda_dur
-        adv_loss = adv_loss * self.lambda_adv
-        feat_match_loss = feat_match_loss * self.lambda_feat_match
-        loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
+        with autocast(enabled=False):
+            mel_loss = self.mel_loss(speech_hat_, speech_)
+            kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
+            dur_loss = torch.sum(dur_nll.float())
+            adv_loss = self.generator_adv_loss(p_hat)
+            feat_match_loss = self.feat_match_loss(p_hat, p)
+
+            mel_loss = mel_loss * self.lambda_mel
+            kl_loss = kl_loss * self.lambda_kl
+            dur_loss = dur_loss * self.lambda_dur
+            adv_loss = adv_loss * self.lambda_adv
+            feat_match_loss = feat_match_loss * self.lambda_feat_match
+            loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
 
         stats = dict(
             generator_loss=loss.item(),
@@ -504,8 +515,9 @@ def _forward_discrminator(
         p = self.discriminator(speech_)
 
         # calculate losses
-        real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
-        loss = real_loss + fake_loss
+        with autocast(enabled=False):
+            real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
+            loss = real_loss + fake_loss
 
         stats = dict(
             discriminator_loss=loss.item(),