diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 2f9b6f9be2..8e1fd410e6 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -7,9 +7,10 @@ from TTS.tts.layers.glow_tts.decoder import Decoder from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path +from TTS.tts.models.tts_abstract import TTSAbstract -class GlowTts(nn.Module): +class GlowTts(TTSAbstract): """Glow TTS models from https://arxiv.org/abs/2005.11129 Args: @@ -179,7 +180,8 @@ def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None): return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur @torch.no_grad() - def inference(self, x, x_lengths, g=None): + def inference(self, x, x_lengths, g=None, *args, **kwargs): # pylint: disable=unused-argument,keyword-arg-before-vararg + if g is not None: if self.external_speaker_embedding_dim: g = F.normalize(g).unsqueeze(-1) diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py index 93496d59a6..e722c5741d 100644 --- a/TTS/tts/models/speedy_speech.py +++ b/TTS/tts/models/speedy_speech.py @@ -5,9 +5,10 @@ from TTS.tts.layers.speedy_speech.encoder import Encoder, PositionalEncoding from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.layers.glow_tts.monotonic_align import generate_path +from TTS.tts.models.tts_abstract import TTSAbstract -class SpeedySpeech(nn.Module): +class SpeedySpeech(TTSAbstract): """Speedy Speech model https://arxiv.org/abs/2008.03802 @@ -36,29 +37,29 @@ class SpeedySpeech(nn.Module): # pylint: disable=dangerous-default-value def __init__( - self, - num_chars, - out_channels, - hidden_channels, - positional_encoding=True, - length_scale=1, - encoder_type='residual_conv_bn', - encoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 13 - }, - decoder_type='residual_conv_bn', - decoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4, 8] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 17 - }, - num_speakers=0, - external_c=False, - c_in_channels=0): + self, + num_chars, + out_channels, + hidden_channels, + positional_encoding=True, + length_scale=1, + encoder_type='residual_conv_bn', + encoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + }, + decoder_type='residual_conv_bn', + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17 + }, + num_speakers=0, + external_c=False, + c_in_channels=0): super().__init__() self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale @@ -174,7 +175,7 @@ def forward(self, x, x_lengths, y_lengths, dr, g=None): # pylint: disable=unuse o_de, attn= self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g) return o_de, o_dr_log.squeeze(1), attn - def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument + def inference(self, x, x_lengths, g=None, *args, **kwargs): # pylint: disable=unused-argument,keyword-arg-before-vararg """ Shapes: x: [B, T_max] diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index e56e4ca069..26240cb86f 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -186,7 +186,7 @@ def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ return decoder_outputs, postnet_outputs, alignments, stop_tokens 
@torch.no_grad() - def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None): + def inference(self, text, *args, speaker_ids=None, style_mel=None, speaker_embeddings=None, **kwargs): embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py index 1095326966..fb1ab764b1 100644 --- a/TTS/tts/models/tacotron_abstract.py +++ b/TTS/tts/models/tacotron_abstract.py @@ -1,13 +1,12 @@ import copy -from abc import ABC, abstractmethod import torch -from torch import nn from TTS.tts.utils.generic_utils import sequence_mask +from TTS.utils.io import AttrDict +from TTS.tts.models.tts_abstract import TTSAbstract - -class TacotronAbstract(ABC, nn.Module): +class TacotronAbstract(TTSAbstract): def __init__(self, num_chars, num_speakers, @@ -71,6 +70,7 @@ def __init__(self, self.encoder = None self.decoder = None self.postnet = None + # multispeaker if self.speaker_embedding_dim is None: @@ -113,15 +113,8 @@ def _init_coarse_decoder(self): # CORE FUNCTIONS ############################# - @abstractmethod - def forward(self): - pass - - @abstractmethod - def inference(self): - pass - - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin + def load_checkpoint(self, config: AttrDict, checkpoint_path: str, eval: bool = False): # pylint: disable=unused-argument, redefined-builtin state = torch.load(checkpoint_path, map_location=torch.device('cpu')) self.load_state_dict(state['model']) self.decoder.set_r(state['r']) diff --git a/TTS/tts/models/tts_abstract.py b/TTS/tts/models/tts_abstract.py new file mode 100644 index 0000000000..65f971f34f --- /dev/null +++ b/TTS/tts/models/tts_abstract.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from torch import nn +from TTS.utils.io import AttrDict + + +class TTSAbstract(ABC, nn.Module): + """Abstract base class for TTS models (tacotron, speedy_speech, glow_tts, ...). + + Inheritance: + ABC: Abstract Base Class + nn.Module: PyTorch nn.Module + """ + + @abstractmethod + def forward(self): + pass + + @abstractmethod + def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None): + pass + + @abstractmethod + def load_checkpoint(self, config: AttrDict, checkpoint_path: str, eval: bool = False): + pass + +
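A note on the two changes above: `TTSAbstract` mixes `ABC` into `nn.Module` (this works because `ABCMeta` is a subclass of `type`), so models stay ordinary PyTorch modules but can no longer be instantiated without `forward`, `inference`, and `load_checkpoint`. The widened `inference(..., *args, **kwargs)` signatures let one uniform call site pass model-specific extras that other models simply ignore. A minimal, self-contained sketch of both mechanisms; the `Abstract`/`DummyTTS` classes and the call at the bottom are hypothetical, not part of this PR:

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class Abstract(ABC, nn.Module):
    """Simplified stand-in for TTSAbstract."""

    @abstractmethod
    def forward(self):
        pass

    @abstractmethod
    def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
        pass


class DummyTTS(Abstract):
    """Hypothetical model: unused extras are absorbed by *args / **kwargs."""

    def __init__(self):
        super().__init__()  # reaches nn.Module.__init__ through the MRO
        self.proj = nn.Linear(4, 2)

    def forward(self, x):
        return self.proj(x)

    def inference(self, text, *args, speaker_ids=None, **kwargs):
        return self.forward(text)


# Abstract() raises TypeError (abstract methods missing); DummyTTS does not.
model = DummyTTS()
out = model.inference(torch.ones(1, 4), speaker_ids=None, style_mel=None)
print(out.shape)  # torch.Size([1, 2])
```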
diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index d898aebd72..787cd3dbf6 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -5,6 +5,7 @@ from collections import Counter from TTS.utils.generic_utils import check_argument +from TTS.tts.models.tts_abstract import TTSAbstract def split_dataset(items): @@ -44,12 +45,12 @@ def to_camel(text): return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) -def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): +def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None) -> TTSAbstract: print(" > Using model: {}".format(c.model)) MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower()) MyModel = getattr(MyModel, to_camel(c.model)) if c.model.lower() in "tacotron": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), + model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), num_speakers=num_speakers, r=c.r, postnet_output_dim=int(c.audio['fft_size'] / 2 + 1), @@ -76,7 +77,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): ddc_r=c.ddc_r, speaker_embedding_dim=speaker_embedding_dim) elif c.model.lower() == "tacotron2": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), + model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), num_speakers=num_speakers, r=c.r, postnet_output_dim=c.audio['num_mels'], @@ -102,7 +103,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): ddc_r=c.ddc_r, speaker_embedding_dim=speaker_embedding_dim) elif c.model.lower() == "glow_tts": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), + model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), hidden_channels_enc=c['hidden_channels_encoder'], hidden_channels_dec=c['hidden_channels_decoder'], hidden_channels_dp=c['hidden_channels_duration_predictor'], @@ -123,7 +124,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): mean_only=True, external_speaker_embedding_dim=speaker_embedding_dim) elif c.model.lower() == "speedy_speech": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), + model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), out_channels=c.audio['num_mels'], hidden_channels=c['hidden_channels'], positional_encoding=c['positional_encoding'], @@ -132,6 +133,8 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): decoder_type=c['decoder_type'], decoder_params=c['decoder_params'], c_in_channels=0) + else: + raise ValueError("Model type is not supported: {}".format(c.model.lower())) return model def is_tacotron(c): diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 43bb1f6a0d..788c4af9ad 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -2,12 +2,12 @@ import json -def make_speakers_json_path(out_path): +def make_speakers_json_path(out_path: str) -> str: """Returns conventional speakers.json location.""" return os.path.join(out_path, "speakers.json") -def load_speaker_mapping(out_path): +def load_speaker_mapping(out_path: str) -> dict: """Loads speaker mapping if already present.""" try: if os.path.splitext(out_path)[1] == '.json':
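`setup_model` resolves the class dynamically: the lower-cased model name from the config is the module name under `TTS.tts.models`, and `to_camel` maps it to the class name. With the `else` branch added above, an unknown model name now raises a `ValueError` instead of falling through. A small sketch of that lookup; `resolve_model_class` is an illustrative helper, not part of the codebase (the asserts run standalone, while the dynamic import itself needs the TTS package installed):

```python
import importlib
import re


def to_camel(text: str) -> str:
    # same regex as TTS.tts.utils.generic_utils.to_camel
    text = text.capitalize()
    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)


def resolve_model_class(model_name: str):
    """Hypothetical helper mirroring the lookup inside setup_model."""
    module = importlib.import_module('TTS.tts.models.' + model_name.lower())
    return getattr(module, to_camel(model_name))


# the name mapping used by setup_model / setup_generator:
assert to_camel('glow_tts') == 'GlowTts'
assert to_camel('speedy_speech') == 'SpeedySpeech'
assert to_camel('multiband_melgan_generator') == 'MultibandMelganGenerator'
```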
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index adbd0d209b..58ea1a7daf 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -8,6 +8,8 @@ import numpy as np from .text import text_to_sequence, phoneme_to_sequence +from TTS.utils.io import AttrDict +from TTS.tts.models.tts_abstract import TTSAbstract def text_to_seqvec(text, CONFIG): text_cleaner = [CONFIG.text_cleaner] @@ -50,7 +52,7 @@ def compute_style_mel(style_wav, ap, cuda=False): return style_mel -def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None): +def run_model_torch(model: TTSAbstract, inputs, CONFIG: AttrDict, truncated: bool, speaker_id=None, style_mel=None, speaker_embeddings=None): if 'tacotron' in CONFIG.model.lower(): if CONFIG.use_gst: decoder_output, postnet_output, alignments, stop_tokens = model.inference( @@ -196,10 +198,10 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap): return wavs -def synthesis(model, - text, - CONFIG, - use_cuda, +def synthesis(model: TTSAbstract, + text: str, + CONFIG: AttrDict, + use_cuda: bool, ap, speaker_id=None, style_wav=None, diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 30b7b7e27f..17acf5ad9c 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -20,7 +20,7 @@ def __init__(self, *args, **kwargs): self.__dict__ = self -def read_json_with_comments(json_path): +def read_json_with_comments(json_path: str) -> AttrDict: # fallback to json with open(json_path, "r", encoding = "utf-8") as f: input_str = f.read() diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py index 3070eac779..c783f4063a 100644 --- a/TTS/vocoder/models/melgan_generator.py +++ b/TTS/vocoder/models/melgan_generator.py @@ -2,10 +2,11 @@ from torch import nn from torch.nn.utils import weight_norm +from TTS.vocoder.models.vocoder_abstract import VocoderAbstract from TTS.vocoder.layers.melgan import ResidualStack -class MelganGenerator(nn.Module): +class MelganGenerator(VocoderAbstract): def __init__(self, in_channels=80, out_channels=1, diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index 1d1bcdcbf8..6b236fe1ae 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -4,9 +4,9 @@ from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.upsample import ConvUpsample +from TTS.vocoder.models.vocoder_abstract import VocoderAbstract - -class ParallelWaveganGenerator(torch.nn.Module): +class ParallelWaveganGenerator(VocoderAbstract): """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf. It is similar to WaveNet with no causal convolution. It is conditioned on an aux feature (spectrogram) to generate
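The `AttrDict` used in the annotations above is the helper from `TTS/utils/io.py` shown in this hunk: a `dict` subclass that assigns `self.__dict__ = self`, so config entries are reachable both as keys and as attributes. A minimal restatement of the trick, with illustrative values:

```python
class AttrDict(dict):
    """dict whose items are also attributes, as in TTS.utils.io."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # after this line, attribute lookup and key lookup share one namespace
        self.__dict__ = self


c = AttrDict(model="glow_tts", use_gst=False)
assert c.model == c["model"] == "glow_tts"
c.r = 2               # attribute writes are dict writes, too
assert c["r"] == 2
```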
diff --git a/TTS/vocoder/models/vocoder_abstract.py b/TTS/vocoder/models/vocoder_abstract.py new file mode 100644 index 0000000000..d065cbaefe --- /dev/null +++ b/TTS/vocoder/models/vocoder_abstract.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod +from torch import nn +from TTS.utils.io import AttrDict + + +class VocoderAbstract(ABC, nn.Module): + """Abstract base class for vocoder models (melgan, wavernn, ...). + + Inheritance: + ABC: Abstract Base Class + nn.Module: PyTorch nn.Module + """ + @abstractmethod + def forward(self): + pass + + @abstractmethod + def inference(self, c, *args, **kwargs): + pass + + @abstractmethod + def load_checkpoint(self, config: AttrDict, checkpoint_path: str, eval: bool = False): + pass + + diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index f4a5faa3a0..5a4821bbcb 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -4,9 +4,10 @@ from torch.nn.utils import weight_norm from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d +from TTS.vocoder.models.vocoder_abstract import VocoderAbstract -class Wavegrad(nn.Module): +class Wavegrad(VocoderAbstract): # pylint: disable=dangerous-default-value def __init__(self, in_channels=80, diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 0d532063ff..11146eef37 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -5,6 +5,8 @@ from matplotlib import pyplot as plt from TTS.tts.utils.visual import plot_spectrogram +from TTS.utils.io import AttrDict +from TTS.vocoder.models.vocoder_abstract import VocoderAbstract def interpolate_vocoder_input(scale_factor, spec): @@ -61,12 +63,25 @@ def plot_results(y_hat, y, ap, global_step, name_prefix): return figures -def to_camel(text): +def to_camel(text: str) -> str: text = text.capitalize() return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) -def setup_generator(c): +def setup_generator(vocoder_config: AttrDict) -> VocoderAbstract: + """Instantiate the vocoder generator model described by the config. + + Args: + vocoder_config (AttrDict): vocoder config as an AttrDict object + + Raises: + ValueError: if the requested model has been renamed or removed + NotImplementedError: if the requested model is not implemented yet + + Returns: + VocoderAbstract: a vocoder model that inherits from VocoderAbstract + """ + c = vocoder_config print(" > Generator Model: {}".format(c.generator_model)) MyModel = importlib.import_module('TTS.vocoder.models.'
+ c.generator_model.lower()) @@ -76,7 +91,7 @@ def setup_generator(c): else: MyModel = getattr(MyModel, to_camel(c.generator_model)) if c.generator_model.lower() in 'wavernn': - model = MyModel( + model: VocoderAbstract = MyModel( rnn_dims=c.wavernn_model_params['rnn_dims'], fc_dims=c.wavernn_model_params['fc_dims'], mode=c.mode, @@ -92,7 +107,7 @@ def setup_generator(c): hop_length=c.audio["hop_length"], sample_rate=c.audio["sample_rate"],) elif c.generator_model.lower() in 'melgan_generator': - model = MyModel( + model: VocoderAbstract = MyModel( in_channels=c.audio['num_mels'], out_channels=1, proj_kernel=7, @@ -104,7 +119,7 @@ def setup_generator(c): raise ValueError( 'melgan_fb_generator is now fullband_melgan_generator') elif c.generator_model.lower() in 'multiband_melgan_generator': - model = MyModel( + model: VocoderAbstract = MyModel( in_channels=c.audio['num_mels'], out_channels=4, proj_kernel=7, @@ -113,7 +128,7 @@ def setup_generator(c): res_kernel=3, num_res_blocks=c.generator_model_params['num_res_blocks']) elif c.generator_model.lower() in 'fullband_melgan_generator': - model = MyModel( + model: VocoderAbstract = MyModel( in_channels=c.audio['num_mels'], out_channels=1, proj_kernel=7, @@ -122,7 +137,7 @@ def setup_generator(c): res_kernel=3, num_res_blocks=c.generator_model_params['num_res_blocks']) elif c.generator_model.lower() in 'parallel_wavegan_generator': - model = MyModel( + model: VocoderAbstract = MyModel( in_channels=1, out_channels=1, kernel_size=3, @@ -137,7 +152,7 @@ def setup_generator(c): use_weight_norm=True, upsample_factors=c.generator_model_params['upsample_factors']) elif c.generator_model.lower() in 'wavegrad': - model = MyModel( + model: VocoderAbstract = MyModel( in_channels=c['audio']['num_mels'], out_channels=1, use_weight_norm=c['model_params']['use_weight_norm'], @@ -210,8 +225,3 @@ def setup_discriminator(c): bias=True ) return model - - -# def check_config(c): -# c = None -# pass
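With the typed factory and the shared `load_checkpoint` contract, a caller can treat every generator uniformly. A hedged usage sketch; the file paths are placeholders, the snippet assumes the TTS package is installed, and it relies only on functions shown in this diff:

```python
# hypothetical driver script; paths below are placeholders
from TTS.utils.io import read_json_with_comments
from TTS.vocoder.utils.generic_utils import setup_generator

c = read_json_with_comments("vocoder_config.json")  # -> AttrDict, per this PR

try:
    vocoder = setup_generator(c)
except ValueError as err:
    # e.g. the renamed 'melgan_fb_generator' model
    raise SystemExit("Unsupported generator: {}".format(err))

# every generator now shares the VocoderAbstract checkpoint contract
vocoder.load_checkpoint(c, "checkpoint.pth.tar", eval=True)
```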