Skip to content

Commit

Permalink
Adding neural HMM TTS Model (#2272)
Browse files Browse the repository at this point in the history
* Adding neural HMM TTS

* Adding tests

* Adding neural hmm on readme

* renaming training recipe

* Removing overflow\s decoder parameters from the config

* Update the Trainer requirement version for a compatible one (#2276)

* Bump up to v0.10.2

* Adding neural HMM TTS

* Adding tests

* Adding neural hmm on readme

* renaming training recipe

* Removing overflow\s decoder parameters from the config

* fixing documentation

Co-authored-by: Edresson Casanova <edresson1@gmail.com>
Co-authored-by: Eren Gölge <erogol@hotmail.com>
  • Loading branch information
3 people authored Jan 23, 2023
1 parent 497f22b commit d83ee8f
Show file tree
Hide file tree
Showing 7 changed files with 753 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- SC-GlowTTS: [paper](https://arxiv.org/abs/2104.05557)
- Capacitron: [paper](https://arxiv.org/abs/1906.03402)
- OverFlow: [paper](https://arxiv.org/abs/2211.06892)
- Neural HMM TTS: [paper](https://arxiv.org/abs/2108.13320)

### End-to-End Models
- VITS: [paper](https://arxiv.org/pdf/2106.06103)
Expand Down
170 changes: 170 additions & 0 deletions TTS/tts/configs/neuralhmm_tts_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig


@dataclass
class NeuralhmmTTSConfig(BaseTTSConfig):
"""
Define parameters for Neural HMM TTS model.
Example:
>>> from TTS.tts.configs.overflow_config import OverflowConfig
>>> config = OverflowConfig()
Args:
model (str):
Model name used to select the right model class to initilize. Defaults to `Overflow`.
run_eval_steps (int):
Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None.
save_step (int):
Save local checkpoint every save_step steps. Defaults to 500.
plot_step (int):
Plot training stats on the logger every plot_step steps. Defaults to 1.
model_param_stats (bool):
Log model parameters stats on the logger dashboard. Defaults to False.
force_generate_statistics (bool):
Force generate mel normalization statistics. Defaults to False.
mel_statistics_parameter_path (str):
Path to the mel normalization statistics.If the model doesn't finds a file there it will generate statistics.
Defaults to None.
num_chars (int):
Number of characters used by the model. It must be defined before initializing the model. Defaults to None.
state_per_phone (int):
Generates N states per phone. Similar, to `add_blank` parameter in GlowTTS but in Overflow it is upsampled by model's encoder. Defaults to 2.
encoder_in_out_features (int):
Channels of encoder input and character embedding tensors. Defaults to 512.
encoder_n_convolutions (int):
Number of convolution layers in the encoder. Defaults to 3.
out_channels (int):
Channels of the final model output. It must match the spectragram size. Defaults to 80.
ar_order (int):
Autoregressive order of the model. Defaults to 1. In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio.
sampling_temp (float):
Variation added to the sample from the latent space of neural HMM. Defaults to 0.334.
deterministic_transition (bool):
deterministic duration generation based on duration quantiles as defiend in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Medianbased generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True.
duration_threshold (float):
Threshold for duration quantiles. Defaults to 0.55. Tune this to change the speaking rate of the synthesis, where lower values defines a slower speaking rate and higher values defines a faster speaking rate.
use_grad_checkpointing (bool):
Use gradient checkpointing to save memory. In a multi-GPU setting currently pytorch does not supports gradient checkpoint inside a loop so we will have to turn it off then.Adjust depending on whatever get more batch size either by using a single GPU or multi-GPU. Defaults to True.
max_sampling_time (int):
Maximum sampling time while synthesising latents from neural HMM. Defaults to 1000.
prenet_type (str):
`original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the
Prenet. Defaults to `original`.
prenet_dim (int):
Dimension of the Prenet. Defaults to 256.
prenet_n_layers (int):
Number of layers in the Prenet. Defaults to 2.
prenet_dropout (float):
Dropout rate of the Prenet. Defaults to 0.5.
prenet_dropout_at_inference (bool):
Use dropout at inference time. Defaults to False.
memory_rnn_dim (int):
Dimension of the memory LSTM to process the prenet output. Defaults to 1024.
outputnet_size (list[int]):
Size of the output network inside the neural HMM. Defaults to [1024].
flat_start_params (dict):
Parameters for the flat start initialization of the neural HMM. Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`.
It will be recomputed when you pass the dataset.
std_floor (float):
Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint. Defaults to 0.01.
It is called `variance flooring` in standard HMM literature.
optimizer (str):
Optimizer to use for training. Defaults to `adam`.
optimizer_params (dict):
Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`.
grad_clip (float):
Gradient clipping threshold. Defaults to 40_000.
lr (float):
Learning rate. Defaults to 1e-3.
lr_scheduler (str):
Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
`TTS.utils.training`. Defaults to `None`.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""

model: str = "NeuralHMM_TTS"

# Training and Checkpoint configs
run_eval_steps: int = 100
save_step: int = 500
plot_step: int = 1
model_param_stats: bool = False

# data parameters
force_generate_statistics: bool = False
mel_statistics_parameter_path: str = None

# Encoder parameters
num_chars: int = None
state_per_phone: int = 2
encoder_in_out_features: int = 512
encoder_n_convolutions: int = 3

# HMM parameters
out_channels: int = 80
ar_order: int = 1
sampling_temp: float = 0
deterministic_transition: bool = True
duration_threshold: float = 0.43
use_grad_checkpointing: bool = True
max_sampling_time: int = 1000

## Prenet parameters
prenet_type: str = "original"
prenet_dim: int = 256
prenet_n_layers: int = 2
prenet_dropout: float = 0.5
prenet_dropout_at_inference: bool = True
memory_rnn_dim: int = 1024

## Outputnet parameters
outputnet_size: List[int] = field(default_factory=lambda: [1024])
flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14})
std_floor: float = 0.001

# optimizer parameters
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6})
grad_clip: float = 40000.0
lr: float = 1e-3
lr_scheduler: str = None

# overrides
min_text_len: int = 10
max_text_len: int = 500
min_audio_len: int = 512

# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"Be a voice, not an echo.",
]
)

# Extra needed config
r: int = 1
use_d_vector_file: bool = False
use_speaker_embedding: bool = False

def check_values(self):
"""Validate the hyperparameters.
Raises:
AssertionError: when the parameters network is not defined
AssertionError: transition probability is not between 0 and 1
"""
assert self.ar_order > 0, "AR order must be greater than 0 it is an autoregressive model."
assert (
len(self.outputnet_size) >= 1
), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}"
assert (
0 < self.flat_start_params["transition_p"] < 1
), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}"
15 changes: 10 additions & 5 deletions TTS/tts/layers/overflow/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,29 @@ def validate_numpy_array(value: Any):
return value


def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder):
def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder=None):
"""Get the most probable state means from the log_alpha_scaled.
Args:
log_alpha_scaled (torch.Tensor): Log alpha scaled values.
- Shape: :math:`(T, N)`
means (torch.Tensor): Means of the states.
- Shape: :math:`(N, T, D_out)`
decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram
decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram. Defaults to None.
"""
max_state_numbers = torch.max(log_alpha_scaled, dim=1)[1]
max_len = means.shape[0]
n_mel_channels = means.shape[2]
max_state_numbers = max_state_numbers.unsqueeze(1).unsqueeze(1).expand(max_len, 1, n_mel_channels)
means = torch.gather(means, 1, max_state_numbers).squeeze(1).to(log_alpha_scaled.dtype)
mel = (
decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0].squeeze(0).T
)
if decoder is not None:
mel = (
decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0]
.squeeze(0)
.T
)
else:
mel = means
return mel


Expand Down
Loading

0 comments on commit d83ee8f

Please sign in to comment.