Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement VitsAudioConfig #1556

Merged
merged 13 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ recursive-include TTS *.md
recursive-include TTS *.py
recursive-include TTS *.pyx
recursive-include images *.png

recursive-exclude tests *
prune tests*
6 changes: 5 additions & 1 deletion TTS/tts/configs/vits_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits import VitsArgs
from TTS.tts.models.vits import VitsArgs, VitsAudioConfig


@dataclass
Expand All @@ -16,6 +16,9 @@ class VitsConfig(BaseTTSConfig):
model_args (VitsArgs):
Model architecture arguments. Defaults to `VitsArgs()`.

audio (VitsAudioConfig):
Audio processing configuration. Defaults to `VitsAudioConfig()`.

grad_clip (List):
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.

Expand Down Expand Up @@ -94,6 +97,7 @@ class VitsConfig(BaseTTSConfig):
model: str = "vits"
# model specific params
model_args: VitsArgs = field(default_factory=VitsArgs)
audio: VitsAudioConfig = VitsAudioConfig()

# optimizer
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def forward(self, y_hat, y, length):

if ssim_loss.item() < 0.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
ssim_loss = torch.tensor([0.0])
ssim_loss = torch.tensor([0.0])

return ssim_loss

Expand Down
16 changes: 16 additions & 0 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,22 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
return spec


#############################
# CONFIGS
#############################


@dataclass
class VitsAudioConfig(Coqpit):
fft_size: int = 1024
sample_rate: int = 22050
win_length: int = 1024
hop_length: int = 256
num_mels: int = 80
mel_fmin: int = 0
mel_fmax: int = None


##############################
# DATASET
##############################
Expand Down
4 changes: 2 additions & 2 deletions TTS/tts/utils/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
"""
if reduction == "none":
return x
elif reduction == "mean":
if reduction == "mean":
return x.mean(dim=0)
elif reduction == "sum":
if reduction == "sum":
return x.sum(dim=0)
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")

Expand Down
2 changes: 1 addition & 1 deletion TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def tts(
waveform = waveform.squeeze()

# trim silence
if self.tts_config.audio["do_trim_silence"] is True:
if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
waveform = trim_silence(waveform, self.tts_model.ap)

wavs += list(waveform)
Expand Down
1 change: 0 additions & 1 deletion recipes/ljspeech/fast_pitch/train_fast_pitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
print_step=50,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
Expand Down
1 change: 0 additions & 1 deletion recipes/ljspeech/fast_speech/train_fast_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
print_step=50,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
Expand Down
1 change: 0 additions & 1 deletion recipes/ljspeech/speedy_speech/train_speedy_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
print_step=50,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
print_step=25,
print_eval=True,
mixed_precision=False,
sort_by_audio_len=True,
seq_len_norm=True,
output_path=output_path,
datasets=[dataset_config],
Expand Down
23 changes: 5 additions & 18 deletions recipes/ljspeech/vits_tts/train_vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,19 @@

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
audio_config = BaseAudioConfig(
sample_rate=22050,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=45,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
audio_config = VitsAudioConfig(
sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
)

config = VitsConfig(
Expand All @@ -37,7 +23,7 @@
batch_size=32,
eval_batch_size=16,
batch_group_size=5,
num_loader_workers=0,
num_loader_workers=8,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
Expand All @@ -52,6 +38,7 @@
mixed_precision=True,
output_path=output_path,
datasets=[dataset_config],
cudnn_benchmark=False,
)

# INITIALIZE THE AUDIO PROCESSOR
Expand Down
15 changes: 2 additions & 13 deletions recipes/multilingual/vits_tts/train_vits_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
Expand All @@ -22,22 +21,13 @@
for path in dataset_paths
]

audio_config = BaseAudioConfig(
audio_config = VitsAudioConfig(
sample_rate=16000,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=False,
trim_db=23.0,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=True,
do_amp_to_db_linear=False,
resample=False,
)

vitsArgs = VitsArgs(
Expand Down Expand Up @@ -69,7 +59,6 @@
use_language_weighted_sampler=True,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
min_audio_len=32 * 256 * 4,
max_audio_len=160000,
output_path=output_path,
Expand Down
1 change: 0 additions & 1 deletion recipes/thorsten_DE/speedy_speech/train_speedy_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
"Dieser Kuchen ist großartig. Er ist so lecker und feucht.",
"Vor dem 22. November 1963.",
],
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
Expand Down
13 changes: 2 additions & 11 deletions recipes/thorsten_DE/vits_tts/train_vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.downloaders import download_thorsten_de
Expand All @@ -21,21 +20,13 @@
print("Downloading dataset")
download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0])

audio_config = BaseAudioConfig(
audio_config = VitsAudioConfig(
sample_rate=22050,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=45,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
)

config = VitsConfig(
Expand Down
22 changes: 4 additions & 18 deletions recipes/vctk/vits/train_vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
Expand All @@ -17,22 +16,8 @@
)


audio_config = BaseAudioConfig(
sample_rate=22050,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=23.0,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
resample=True,
audio_config = VitsAudioConfig(
sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
)

vitsArgs = VitsArgs(
Expand Down Expand Up @@ -62,6 +47,7 @@
max_text_len=325, # change this if you have a larger VRAM than 16GB
output_path=output_path,
datasets=[dataset_config],
cudnn_benchmark=False,
)

# INITIALIZE THE AUDIO PROCESSOR
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def pip_install(package_name):
# ext_modules=find_cython_extensions(),
# package
include_package_data=True,
packages=find_packages(include=["TTS*"]),
packages=find_packages(include=["TTS"], exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"]),
package_data={
"TTS": [
"VERSION",
Expand Down
28 changes: 24 additions & 4 deletions tests/tts_tests/test_vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
from TTS.tts.models.vits import (
Vits,
VitsArgs,
VitsAudioConfig,
amp_to_db,
db_to_amp,
load_audio,
spec_to_mel,
wav_to_mel,
wav_to_spec,
)
from TTS.tts.utils.speakers import SpeakerManager

LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
Expand Down Expand Up @@ -421,16 +431,18 @@ def test_train_step(self):
self._check_parameter_changes(model, model_ref)

def test_train_step_upsampling(self):
"""Upsampling by the decoder upsampling layers"""
# setup the model
with torch.autograd.set_detect_anomaly(True):
audio_config = VitsAudioConfig(sample_rate=22050)
model_args = VitsArgs(
num_chars=32,
spec_segment_size=10,
encoder_sample_rate=11025,
interpolate_z=False,
upsample_rates_decoder=[8, 8, 4, 2],
)
config = VitsConfig(model_args=model_args)
config = VitsConfig(model_args=model_args, audio=audio_config)
model = Vits(config).to(device)
model.train()
# model to train
Expand Down Expand Up @@ -459,10 +471,18 @@ def test_train_step_upsampling(self):
self._check_parameter_changes(model, model_ref)

def test_train_step_upsampling_interpolation(self):
"""Upsampling by interpolation"""
# setup the model
with torch.autograd.set_detect_anomaly(True):
model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True)
config = VitsConfig(model_args=model_args)
audio_config = VitsAudioConfig(sample_rate=22050)
model_args = VitsArgs(
num_chars=32,
spec_segment_size=10,
encoder_sample_rate=11025,
interpolate_z=True,
upsample_rates_decoder=[8, 8, 2, 2],
)
config = VitsConfig(model_args=model_args, audio=audio_config)
model = Vits(config).to(device)
model.train()
# model to train
Expand Down