Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev pr1 : add model and vocoder abstraction #4

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions TTS/tts/models/glow_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from TTS.tts.layers.glow_tts.decoder import Decoder
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
from TTS.tts.models.tts_abstract import TTSAbstract


class GlowTts(nn.Module):
class GlowTts(TTSAbstract):
"""Glow TTS models from https://arxiv.org/abs/2005.11129

Args:
Expand Down Expand Up @@ -179,7 +180,8 @@ def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur

@torch.no_grad()
def inference(self, x, x_lengths, g=None):
def inference(self, x, x_lengths, g=None, *args, **kwargs): # pylint: disable=unused-argument,keyword-arg-before-vararg

if g is not None:
if self.external_speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1)
Expand Down
51 changes: 26 additions & 25 deletions TTS/tts/models/speedy_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from TTS.tts.layers.speedy_speech.encoder import Encoder, PositionalEncoding
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.monotonic_align import generate_path
from TTS.tts.models.tts_abstract import TTSAbstract


class SpeedySpeech(nn.Module):
class SpeedySpeech(TTSAbstract):
"""Speedy Speech model
https://arxiv.org/abs/2008.03802

Expand Down Expand Up @@ -36,29 +37,29 @@ class SpeedySpeech(nn.Module):
# pylint: disable=dangerous-default-value

def __init__(
self,
num_chars,
out_channels,
hidden_channels,
positional_encoding=True,
length_scale=1,
encoder_type='residual_conv_bn',
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
decoder_type='residual_conv_bn',
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17
},
num_speakers=0,
external_c=False,
c_in_channels=0):
self,
num_chars,
out_channels,
hidden_channels,
positional_encoding=True,
length_scale=1,
encoder_type='residual_conv_bn',
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
decoder_type='residual_conv_bn',
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17
},
num_speakers=0,
external_c=False,
c_in_channels=0):

super().__init__()
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
Expand Down Expand Up @@ -174,7 +175,7 @@ def forward(self, x, x_lengths, y_lengths, dr, g=None): # pylint: disable=unuse
o_de, attn= self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
return o_de, o_dr_log.squeeze(1), attn

def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument
def inference(self, x, x_lengths, g=None, *args, **kwargs): # pylint: disable=unused-argument,keyword-arg-before-vararg
"""
Shapes:
x: [B, T_max]
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/tacotron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_
return decoder_outputs, postnet_outputs, alignments, stop_tokens

@torch.no_grad()
def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
def inference(self, text, *args, speaker_ids=None, style_mel=None, speaker_embeddings=None, **kwargs):
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)

Expand Down
17 changes: 5 additions & 12 deletions TTS/tts/models/tacotron_abstract.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import copy
from abc import ABC, abstractmethod

import torch
from torch import nn

from TTS.tts.utils.generic_utils import sequence_mask
from TTS.utils.io import AttrDict
from TTS.tts.models.tts_abstract import TTSAbstract


class TacotronAbstract(ABC, nn.Module):
class TacotronAbstract(TTSAbstract):
def __init__(self,
num_chars,
num_speakers,
Expand Down Expand Up @@ -71,6 +70,7 @@ def __init__(self,
self.encoder = None
self.decoder = None
self.postnet = None


# multispeaker
if self.speaker_embedding_dim is None:
Expand Down Expand Up @@ -113,15 +113,8 @@ def _init_coarse_decoder(self):
# CORE FUNCTIONS
#############################

@abstractmethod
def forward(self):
pass

@abstractmethod
def inference(self):
pass

def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, config: AttrDict, checkpoint_path: str, eval: bool = False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
self.decoder.set_r(state['r'])
Expand Down
26 changes: 26 additions & 0 deletions TTS/tts/models/tts_abstract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from TTS.utils.io import AttrDict
from torch import nn
from abc import ABC, abstractmethod


class TTSAbstract(ABC, nn.Module):
"""Abstract for tts model (tacotron, speedy_speech, glow_tts ...)

Heritance:
ABC: Abstract Base Class
nn.Module: pytorch nn.Module
"""

@abstractmethod
def forward(self):
pass

@abstractmethod
def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
pass

@abstractmethod
def load_checkpoint(self, config: AttrDict, checkpoint_path: str, eval: bool = False):
pass


13 changes: 8 additions & 5 deletions TTS/tts/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import Counter

from TTS.utils.generic_utils import check_argument
from TTS.tts.models.tts_abstract import TTSAbstract


def split_dataset(items):
Expand Down Expand Up @@ -44,12 +45,12 @@ def to_camel(text):
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)


def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None) -> TTSAbstract:
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
MyModel = getattr(MyModel, to_camel(c.model))
if c.model.lower() in "tacotron":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
Expand All @@ -76,7 +77,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
ddc_r=c.ddc_r,
speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "tacotron2":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio['num_mels'],
Expand All @@ -102,7 +103,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
ddc_r=c.ddc_r,
speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "glow_tts":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
hidden_channels_enc=c['hidden_channels_encoder'],
hidden_channels_dec=c['hidden_channels_decoder'],
hidden_channels_dp=c['hidden_channels_duration_predictor'],
Expand All @@ -123,7 +124,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
mean_only=True,
external_speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "speedy_speech":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
model: TTSAbstract = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio['num_mels'],
hidden_channels=c['hidden_channels'],
positional_encoding=c['positional_encoding'],
Expand All @@ -132,6 +133,8 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
decoder_type=c['decoder_type'],
decoder_params=c['decoder_params'],
c_in_channels=0)
else:
return BaseException("Model type is not allowed : ", c.model.lower())
return model

def is_tacotron(c):
Expand Down
4 changes: 2 additions & 2 deletions TTS/tts/utils/speakers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import json


def make_speakers_json_path(out_path):
def make_speakers_json_path(out_path: str) -> str:
"""Returns conventional speakers.json location."""
return os.path.join(out_path, "speakers.json")


def load_speaker_mapping(out_path):
def load_speaker_mapping(out_path: str) -> dict:
"""Loads speaker mapping if already present."""
try:
if os.path.splitext(out_path)[1] == '.json':
Expand Down
12 changes: 7 additions & 5 deletions TTS/tts/utils/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import numpy as np
from .text import text_to_sequence, phoneme_to_sequence

from TTS.utils.io import AttrDict
from TTS.tts.models.tts_abstract import TTSAbstract

def text_to_seqvec(text, CONFIG):
text_cleaner = [CONFIG.text_cleaner]
Expand Down Expand Up @@ -50,7 +52,7 @@ def compute_style_mel(style_wav, ap, cuda=False):
return style_mel


def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
def run_model_torch(model: TTSAbstract, inputs, CONFIG: AttrDict, truncated: bool, speaker_id=None, style_mel=None, speaker_embeddings=None):
if 'tacotron' in CONFIG.model.lower():
if CONFIG.use_gst:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
Expand Down Expand Up @@ -196,10 +198,10 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
return wavs


def synthesis(model,
text,
CONFIG,
use_cuda,
def synthesis(model: TTSAbstract,
text: str,
CONFIG: AttrDict,
use_cuda: bool,
ap,
speaker_id=None,
style_wav=None,
Expand Down
2 changes: 1 addition & 1 deletion TTS/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, *args, **kwargs):
self.__dict__ = self


def read_json_with_comments(json_path):
def read_json_with_comments(json_path: str) -> AttrDict:
# fallback to json
with open(json_path, "r", encoding = "utf-8") as f:
input_str = f.read()
Expand Down
3 changes: 2 additions & 1 deletion TTS/vocoder/models/melgan_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from torch import nn
from torch.nn.utils import weight_norm

from TTS.vocoder.models.vocoder_abstract import VocoderAbstract
from TTS.vocoder.layers.melgan import ResidualStack


class MelganGenerator(nn.Module):
class MelganGenerator(VocoderAbstract):
def __init__(self,
in_channels=80,
out_channels=1,
Expand Down
4 changes: 2 additions & 2 deletions TTS/vocoder/models/parallel_wavegan_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample
from TTS.vocoder.models.vocoder_abstract import VocoderAbstract


class ParallelWaveganGenerator(torch.nn.Module):
class ParallelWaveganGenerator(VocoderAbstract):
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
It is similar to WaveNet with no causal convolution.
It is conditioned on an aux feature (spectrogram) to generate
Expand Down
25 changes: 25 additions & 0 deletions TTS/vocoder/models/vocoder_abstract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from TTS.utils.io import AttrDict
from torch import nn
from abc import ABC, abstractmethod


class VocoderAbstract(ABC, nn.Module):
"""Abstract for vocoder model (melgan, wavernn, etc ...)

Heritance:
ABC: Abstract Base Class
nn.Module: pytorch nn.Module
"""
@abstractmethod
def forward(self):
pass

@abstractmethod
def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
pass

@abstractmethod
def load_checkpoint(self, config: AttrDict, checkpoint_path: str, eval: bool = False):
pass


3 changes: 2 additions & 1 deletion TTS/vocoder/models/wavegrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from torch.nn.utils import weight_norm

from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d
from TTS.vocoder.models.vocoder_abstract import VocoderAbstract


class Wavegrad(nn.Module):
class Wavegrad(VocoderAbstract):
# pylint: disable=dangerous-default-value
def __init__(self,
in_channels=80,
Expand Down
Loading