
Fastspeech2 #2073

Merged (41 commits, Jan 15, 2023)
41 commits, all by manmay-nakhashi:

- 94ae04c: added EnergyDataset (Oct 12, 2022)
- c760048: add energy to Dataset (Oct 12, 2022)
- 8788ef1: add comupte_energy (Oct 12, 2022)
- 450d92f: added energy params (Oct 12, 2022)
- cebd03b: added energy to forward_tts (Oct 12, 2022)
- 83cc206: added plot_avg_energy for visualisation (Oct 12, 2022)
- 17bb411: Update forward_tts.py (Oct 12, 2022)
- 8d8d0b5: create file (Oct 12, 2022)
- 6d70846: added fastspeech2 recipe (Oct 12, 2022)
- 61b332d: add fastspeech2 config (Oct 12, 2022)
- c5c38ac: removed energy from fast pitch (Oct 12, 2022)
- 1ad399b: add energy loss to forward tts (Oct 12, 2022)
- 593b49e: Update fastspeech2_config.py (Oct 12, 2022)
- ff5eaec: change run_name (Oct 12, 2022)
- 9da6bfd: Update numpy_transforms.py (Oct 12, 2022)
- bac63cb: fix typo (Oct 12, 2022)
- 873f769: fix typo (Oct 12, 2022)
- ade6ee3: fix typo (Oct 12, 2022)
- af0e97c: linting issues (Oct 12, 2022)
- c54556f: use_energy default value --> False (Oct 12, 2022)
- 30ebf13: Update numpy_transforms.py (Oct 12, 2022)
- 7efb1ad: linting fixes (Oct 12, 2022)
- 7f6a17c: fix typo (Oct 12, 2022)
- 55ecdb3: liniting_fix (Oct 13, 2022)
- 33b5ab0: liniting_fix (Oct 13, 2022)
- e6ba384: fix (Oct 13, 2022)
- cf2f1a2: fixes (Oct 13, 2022)
- 24dee75: fixes (Oct 13, 2022)
- a061536: resolve conflict (Dec 9, 2022)
- 2bf9a72: lint fix (Dec 9, 2022)
- 1129665: lint fixws (Dec 25, 2022)
- e4af27b: added training test (Dec 31, 2022)
- 39edcd0: wrong import (Jan 3, 2023)
- edc6d49: wrong import (Jan 3, 2023)
- f365ea5: trailing whitespace (Jan 3, 2023)
- dd818e3: style fix (Jan 3, 2023)
- d4fb7cc: changed class name because of error (Jan 3, 2023)
- 6694b34: class name change (Jan 3, 2023)
- ccaee51: class name change (Jan 3, 2023)
- d586469: change class name (Jan 3, 2023)
- 731a9d6: fixed styles (Jan 4, 2023)
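The "add comupte_energy" commit adds per-utterance energy extraction, the extra prosody feature FastSpeech2 predicts alongside pitch and duration. The repo's helper (added to `numpy_transforms.py`) presumably works on NumPy arrays over STFT frames; the pure-Python sketch below only illustrates the per-frame RMS idea, and the `frame_length`/`hop_length` parameter names are assumptions:

```python
import math
from typing import List, Sequence


def compute_energy(y: Sequence[float], frame_length: int = 1024, hop_length: int = 256) -> List[float]:
    """Per-frame RMS energy: one value per analysis frame, hopping through the waveform."""
    n_frames = 1 + max(0, len(y) - frame_length) // hop_length
    energy = []
    for i in range(n_frames):
        frame = y[i * hop_length : i * hop_length + frame_length]
        # Root-mean-square amplitude of this frame.
        energy.append(math.sqrt(sum(s * s for s in frame) / len(frame)))
    return energy
```

The resulting sequence is frame-aligned with the spectrogram, which is what lets the energy predictor be trained against it frame by frame.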
7 changes: 7 additions & 0 deletions TTS/tts/configs/fast_pitch_config.py
@@ -100,6 +100,13 @@ class FastPitchConfig(BaseTTSConfig):

max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.

# dataset configs
compute_f0 (bool):
Compute pitch. Defaults to True.

f0_cache_path (str):
Pitch cache path. Defaults to None.
"""

model: str = "fast_pitch"
198 changes: 198 additions & 0 deletions TTS/tts/configs/fastspeech2_config.py
@@ -0,0 +1,198 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs


@dataclass
class Fastspeech2Config(BaseTTSConfig):
"""Configure `ForwardTTS` as a FastSpeech2 model.

Example:

>>> from TTS.tts.configs.fastspeech2_config import FastSpeech2Config
>>> config = FastSpeech2Config()

Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fastspeech2`.

base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.

model_args (Coqpit):
Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`.

data_dep_init_steps (int):
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
Activation Normalization that pre-computes normalization stats at the beginning and use the same values
for the rest. Defaults to 10.

speakers_file (str):
Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to
speaker names. Defaults to `None`.

use_speaker_embedding (bool):
Enable/disable speaker embeddings for multi-speaker models. If set to True, the model operates in
multi-speaker mode. Defaults to False.

use_d_vector_file (bool):
Enable/disable using external speaker embeddings in place of the learned embeddings. Defaults to False.

d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.

d_vector_dim (int):
Dimension of the external speaker embeddings. Defaults to 0.

optimizer (str):
Name of the model optimizer. Defaults to `Adam`.

optimizer_params (dict):
Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.

lr_scheduler (str):
Name of the learning rate scheduler. Defaults to `Noam`.

lr_scheduler_params (dict):
Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.

lr (float):
Initial learning rate. Defaults to `1e-4`.

grad_clip (float):
Gradient norm clipping value. Defaults to `5.0`.

spec_loss_type (str):
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.

duration_loss_type (str):
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.

use_ssim_loss (bool):
Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.

wd (float):
Weight decay coefficient. Defaults to `1e-7`.

ssim_loss_alpha (float):
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.

dur_loss_alpha (float):
Weight for the duration predictor's loss. If set 0, disables the duration loss. Defaults to 1.0.

spec_loss_alpha (float):
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.

pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.

energy_loss_alpha (float):
Weight for the energy predictor's loss. If set 0, disables the energy predictor. Defaults to 1.0.

binary_align_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.

binary_loss_warmup_epochs (int):
Number of epochs to gradually increase the binary loss impact. Defaults to 150.

min_seq_len (int):
Minimum input sequence length to be used at training.

max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.

# dataset configs
compute_f0 (bool):
Compute pitch. Defaults to True.

f0_cache_path (str):
Pitch cache path. Defaults to None.

compute_energy (bool):
Compute energy. Defaults to True.

energy_cache_path (str):
Energy cache path. Defaults to None.
"""

model: str = "fastspeech2"
base_model: str = "forward_tts"

# model specific params
model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs)

# multi-speaker settings
num_speakers: int = 0
speakers_file: str = None
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_dim: int = 0

# optimizer parameters
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
lr: float = 1e-4
grad_clip: float = 5.0

# loss params
spec_loss_type: str = "mse"
duration_loss_type: str = "mse"
use_ssim_loss: bool = True
ssim_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
pitch_loss_alpha: float = 0.1
energy_loss_alpha: float = 0.1
dur_loss_alpha: float = 0.1
binary_align_loss_alpha: float = 0.1
binary_loss_warmup_epochs: int = 150

# overrides
min_seq_len: int = 13
max_seq_len: int = 200
r: int = 1 # DO NOT CHANGE

# dataset configs
compute_f0: bool = True
f0_cache_path: str = None

# dataset configs
compute_energy: bool = True
energy_cache_path: str = None

# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)

def __post_init__(self):
# Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
if self.num_speakers > 0:
self.model_args.num_speakers = self.num_speakers

# speaker embedding settings
if self.use_speaker_embedding:
self.model_args.use_speaker_embedding = True
if self.speakers_file:
self.model_args.speakers_file = self.speakers_file

# d-vector settings
if self.use_d_vector_file:
self.model_args.use_d_vector_file = True
if self.d_vector_dim is not None and self.d_vector_dim > 0:
self.model_args.d_vector_dim = self.d_vector_dim
if self.d_vector_file:
self.model_args.d_vector_file = self.d_vector_file
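The `__post_init__` hook above copies the top-level multi-speaker options into `model_args`, since `ForwardTTS.init_multispeaker()` reads them from there rather than from the config root. A self-contained sketch of that propagation pattern, using simplified stand-in classes rather than the actual TTS types:

```python
from dataclasses import dataclass, field


@dataclass
class ModelArgs:
    num_speakers: int = 0
    use_speaker_embedding: bool = False


@dataclass
class Config:
    # default_factory gives each Config its own ModelArgs instance.
    model_args: ModelArgs = field(default_factory=ModelArgs)
    num_speakers: int = 0
    use_speaker_embedding: bool = False

    def __post_init__(self):
        # Propagate top-level settings so the model finds them in model_args.
        if self.num_speakers > 0:
            self.model_args.num_speakers = self.num_speakers
        if self.use_speaker_embedding:
            self.model_args.use_speaker_embedding = True


cfg = Config(num_speakers=4, use_speaker_embedding=True)
# cfg.model_args now carries num_speakers=4 and use_speaker_embedding=True
```

Note the `field(default_factory=ModelArgs)` default: a bare `ModelArgs()` default would share one mutable instance across every config, and recent Python versions reject mutable dataclass defaults outright.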
4 changes: 4 additions & 0 deletions TTS/tts/configs/shared_configs.py
@@ -217,6 +217,9 @@ class BaseTTSConfig(BaseTrainingConfig):
compute_f0 (bool):
(Not in use yet).

compute_energy (bool):
(Not in use yet).

compute_linear_spec (bool):
If True data loader computes and returns linear spectrograms alongside the other data.

@@ -305,6 +308,7 @@ class BaseTTSConfig(BaseTrainingConfig):
min_text_len: int = 1
max_text_len: int = float("inf")
compute_f0: bool = False
compute_energy: bool = False
compute_linear_spec: bool = False
precompute_num_workers: int = 0
use_noise_augment: bool = False
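The new `compute_energy` flag (together with `energy_cache_path` in the FastSpeech2 config) presumably gates dataset-side energy precomputation the same way `compute_f0`/`f0_cache_path` gate pitch. A sketch of that cache-or-compute pattern; the function name and on-disk layout here are assumptions, not the actual `TTSDataset` code:

```python
import os
import pickle


def get_energy(item_id: str, cache_dir: str, compute_fn) -> list:
    """Load per-frame energy for an utterance from cache if present,
    otherwise compute it once and cache the result."""
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = os.path.join(cache_dir, f"{item_id}_energy.pkl")
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            return pickle.load(f)
    energy = compute_fn(item_id)
    with open(cache_file, "wb") as f:
        pickle.dump(energy, f)
    return energy
```

With this shape, the expensive extraction runs only on the first epoch (or a precompute pass); later epochs hit the cache, which is why the config exposes a cache path alongside the boolean flag.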