diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml
index 83797be17b..fc990826a7 100644
--- a/.github/workflows/pypi-release.yml
+++ b/.github/workflows/pypi-release.yml
@@ -42,16 +42,18 @@ jobs:
       - uses: actions/setup-python@v2
         with:
           python-version: ${{ matrix.python-version }}
-      - run: |
+      - name: Install pip requirements
+        run: |
           python -m pip install -U pip setuptools wheel build
-      - run: |
-          python -m build
-      - run: |
-          python -m pip install dist/*.whl
+          python -m pip install -r requirements.txt
+      - name: Setup and install manylinux1_x86_64 wheel
+        run: |
+          python setup.py bdist_wheel --plat-name=manylinux1_x86_64
+          python -m pip install dist/*-manylinux*.whl
       - uses: actions/upload-artifact@v2
         with:
           name: wheel-${{ matrix.python-version }}
-          path: dist/*.whl
+          path: dist/*-manylinux*.whl
   publish-artifacts:
     runs-on: ubuntu-20.04
     needs: [build-sdist, build-wheels]
diff --git a/MANIFEST.in b/MANIFEST.in
index 82ecadcb48..321d3999c1 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -11,4 +11,5 @@ recursive-include TTS *.md
 recursive-include TTS *.py
 recursive-include TTS *.pyx
 recursive-include images *.png
-
+recursive-exclude tests *
+prune tests*
diff --git a/README.md b/README.md
index 8ed67c308f..1ca585d7ba 100644
--- a/README.md
+++ b/README.md
@@ -130,7 +130,7 @@ pip install -e .[all,dev,notebooks]  # Select the relevant extras
 If you are on Ubuntu (Debian), you can also run following commands for installation.
 
 ```bash
-$ make system-deps  # intended to be used on Ubuntu (Debian). Let us know if you have a diffent OS.
+$ make system-deps  # intended to be used on Ubuntu (Debian). Let us know if you have a different OS.
 $ make install
 ```
 
@@ -145,25 +145,61 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht
 ```
 $ tts --list_models
 ```
-
+- Get model info (for both tts_models and vocoder_models):
+  - Query by type/name:
+    The --model_info_by_name argument uses the model name as it appears in the output of --list_models.
+    ```
+    $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
+    ```
+    For example:
+
+    ```
+    $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
+    ```
+    ```
+    $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
+    ```
+  - Query by type/idx:
+    The --model_info_by_idx argument uses the corresponding index from --list_models.
+    ```
+    $ tts --model_info_by_idx "<model_type>/<model_query_idx>"
+    ```
+    For example:
+
+    ```
+    $ tts --model_info_by_idx tts_models/3
+    ```
+
 - Run TTS with default models:
     ```
-    $ tts --text "Text for TTS"
+    $ tts --text "Text for TTS" --out_path output/path/speech.wav
     ```
 - Run a TTS model with its default vocoder model:
     ```
-    $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>
+    $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
+    ```
+    For example:
+
+    ```
+    $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
     ```
 - Run with specific TTS and vocoder models from the list:
     ```
-    $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
+    $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
+    ```
+    For example:
+
+    ```
+    $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
+    ```
+
 
 - Run your own TTS model (Using Griffin-Lim Vocoder):
     ```
diff --git a/TTS/.models.json b/TTS/.models.json
index 93d9f417be..e12e7a0b31 100644
--- a/TTS/.models.json
+++ b/TTS/.models.json
@@ -215,6 +215,14 @@
                 "author": "@thorstenMueller",
                 "license": "apache 2.0",
                 "commit": "unknown"
+            },
+            "tacotron2-DDC": {
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--de--thorsten--tacotron2-DDC.zip",
+                "default_vocoder": "vocoder_models/de/thorsten/hifigan_v1",
+                "description": "Thorsten-Dec2021-22k-DDC",
+                "author": "@thorstenMueller",
+                "license": "apache 2.0",
+                "commit": "unknown"
             }
         }
     },
@@ -460,6 +468,13 @@
                 "author": "@thorstenMueller",
                 "license": "apache 2.0",
                 "commit": "unknown"
+            },
+            "hifigan_v1": {
+                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/vocoder_models--de--thorsten--hifigan_v1.zip",
+                "description": "HifiGAN vocoder model for Thorsten Neutral Dec2021 22k Samplerate Tacotron2 DDC model",
+                "author": "@thorstenMueller",
+                "license": "apache 2.0",
+                "commit": "unknown"
             }
         }
     },
diff --git a/TTS/VERSION b/TTS/VERSION
index 7deb86fee4..8adc70fdd9 100644
--- a/TTS/VERSION
+++ b/TTS/VERSION
@@ -1 +1 @@
-0.7.1
\ No newline at end of file
+0.8.0
\ No newline at end of file
diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 7c6098909b..9a95651ade 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -60,13 +60,13 @@ def main():
     - Run a TTS model with its default vocoder model:
 
     ```
-    $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>"
+    $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>"
     ```
 
     - Run with specific TTS and vocoder models from the list:
 
     ```
-    $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
+    $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --output_path
     ```
 
     - Run your own TTS model (Using Griffin-Lim Vocoder):
diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index d28f188e75..f2e7779c0c 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -13,13 +13,13 @@
 from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
-from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
 from TTS.utils.io import copy_model_files
+from TTS.utils.samplers import PerfectBatchSampler
 from TTS.utils.training import check_update
 
 torch.backends.cudnn.enabled = True
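Note: `PerfectBatchSampler` moves from the encoder-local `TTS.encoder.utils.samplers` to the shared `TTS.utils.samplers` module. For readers new to it, below is a minimal sketch of the batching idea it implements — every batch holds an equal number of utterances per class; the function and argument names are illustrative, not the library API.

```python
import random
from collections import defaultdict

def class_balanced_batches(labels, num_classes_in_batch, num_utter_per_class, num_batches):
    """Yield index batches with `num_utter_per_class` samples from each of
    `num_classes_in_batch` randomly drawn classes (illustrative sketch only)."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    classes = list(by_class)
    for _ in range(num_batches):
        batch = []
        for cls in random.sample(classes, num_classes_in_batch):
            # sample with replacement so small classes still fill their quota
            batch.extend(random.choices(by_class[cls], k=num_utter_per_class))
        yield batch
```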
diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py
index a31d6c4548..09582cea7c 100644
--- a/TTS/bin/tune_wavegrad.py
+++ b/TTS/bin/tune_wavegrad.py
@@ -1,4 +1,4 @@
-"""Search a good noise schedule for WaveGrad for a given number of inferece iterations"""
+"""Search a good noise schedule for WaveGrad for a given number of inference iterations"""
 import argparse
 from itertools import product as cartesian_product
 
@@ -7,94 +7,97 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from TTS.config import load_config
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.io import load_config
 from TTS.vocoder.datasets.preprocess import load_wav_data
 from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
-from TTS.vocoder.utils.generic_utils import setup_generator
+from TTS.vocoder.models import setup_model
 
-parser = argparse.ArgumentParser()
-parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
-parser.add_argument("--config_path", type=str, help="Path to model config file.")
-parser.add_argument("--data_path", type=str, help="Path to data directory.")
-parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
-parser.add_argument(
-    "--num_iter", type=int, help="Number of model inference iterations that you like to optimize noise schedule for."
-)
-parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.")
-parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
-parser.add_argument(
-    "--search_depth",
-    type=int,
-    default=3,
-    help="Search granularity. Increasing this increases the run-time exponentially.",
-)
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
+    parser.add_argument("--config_path", type=str, help="Path to model config file.")
+    parser.add_argument("--data_path", type=str, help="Path to data directory.")
+    parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
+    parser.add_argument(
+        "--num_iter",
+        type=int,
+        help="Number of model inference iterations that you like to optimize noise schedule for.",
+    )
+    parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
+    parser.add_argument("--num_samples", type=int, default=1, help="Number of data samples used for inference.")
+    parser.add_argument(
+        "--search_depth",
+        type=int,
+        default=3,
+        help="Search granularity. Increasing this increases the run-time exponentially.",
+    )
 
-# load config
-args = parser.parse_args()
-config = load_config(args.config_path)
+    # load config
+    args = parser.parse_args()
+    config = load_config(args.config_path)
 
-# setup audio processor
-ap = AudioProcessor(**config.audio)
+    # setup audio processor
+    ap = AudioProcessor(**config.audio)
 
-# load dataset
-_, train_data = load_wav_data(args.data_path, 0)
-train_data = train_data[: args.num_samples]
-dataset = WaveGradDataset(
-    ap=ap,
-    items=train_data,
-    seq_len=-1,
-    hop_len=ap.hop_length,
-    pad_short=config.pad_short,
-    conv_pad=config.conv_pad,
-    is_training=True,
-    return_segments=False,
-    use_noise_augment=False,
-    use_cache=False,
-    verbose=True,
-)
-loader = DataLoader(
-    dataset,
-    batch_size=1,
-    shuffle=False,
-    collate_fn=dataset.collate_full_clips,
-    drop_last=False,
-    num_workers=config.num_loader_workers,
-    pin_memory=False,
-)
+    # load dataset
+    _, train_data = load_wav_data(args.data_path, 0)
+    train_data = train_data[: args.num_samples]
+    dataset = WaveGradDataset(
+        ap=ap,
+        items=train_data,
+        seq_len=-1,
+        hop_len=ap.hop_length,
+        pad_short=config.pad_short,
+        conv_pad=config.conv_pad,
+        is_training=True,
+        return_segments=False,
+        use_noise_augment=False,
+        use_cache=False,
+        verbose=True,
+    )
+    loader = DataLoader(
+        dataset,
+        batch_size=1,
+        shuffle=False,
+        collate_fn=dataset.collate_full_clips,
+        drop_last=False,
+        num_workers=config.num_loader_workers,
+        pin_memory=False,
+    )
 
-# setup the model
-model = setup_generator(config)
-if args.use_cuda:
-    model.cuda()
+    # setup the model
+    model = setup_model(config)
+    if args.use_cuda:
+        model.cuda()
 
-# setup optimization parameters
-base_values = sorted(10 * np.random.uniform(size=args.search_depth))
-print(base_values)
-exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
-best_error = float("inf")
-best_schedule = None
-total_search_iter = len(base_values) ** args.num_iter
-for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
-    beta = exponents * base
-    model.compute_noise_level(beta)
-    for data in loader:
-        mel, audio = data
-        y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
+    # setup optimization parameters
+    base_values = sorted(10 * np.random.uniform(size=args.search_depth))
+    print(f" > base values: {base_values}")
+    exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
+    best_error = float("inf")
+    best_schedule = None  # pylint: disable=C0103
+    total_search_iter = len(base_values) ** args.num_iter
+    for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
+        beta = exponents * base
+        model.compute_noise_level(beta)
+        for data in loader:
+            mel, audio = data
+            y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
 
-        if args.use_cuda:
-            y_hat = y_hat.cpu()
-        y_hat = y_hat.numpy()
+            if args.use_cuda:
+                y_hat = y_hat.cpu()
+            y_hat = y_hat.numpy()
 
-        mel_hat = []
-        for i in range(y_hat.shape[0]):
-            m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
-            mel_hat.append(torch.from_numpy(m))
+            mel_hat = []
+            for i in range(y_hat.shape[0]):
+                m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
+                mel_hat.append(torch.from_numpy(m))
 
-        mel_hat = torch.stack(mel_hat)
-        mse = torch.sum((mel - mel_hat) ** 2).mean()
-        if mse.item() < best_error:
-            best_error = mse.item()
-            best_schedule = {"beta": beta}
-            print(f" > Found a better schedule. - MSE: {mse.item()}")
-            np.save(args.output_path, best_schedule)
+            mel_hat = torch.stack(mel_hat)
+            mse = torch.sum((mel - mel_hat) ** 2).mean()
+            if mse.item() < best_error:
+                best_error = mse.item()
+                best_schedule = {"beta": beta}
+                print(f" > Found a better schedule. - MSE: {mse.item()}")
+                np.save(args.output_path, best_schedule)
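Note: the refactor above only wraps the script in a `__main__` guard and switches to `setup_model`/`load_config`; the search itself is unchanged. It enumerates `search_depth ** num_iter` candidate schedules, each scaling a fixed log-spaced exponent grid by a tuple of base values. A standalone sketch of just the candidate generation (small numbers so the grid stays tiny):

```python
from itertools import product as cartesian_product

import numpy as np

search_depth, num_iter = 3, 4
base_values = sorted(10 * np.random.uniform(size=search_depth))
exponents = 10 ** np.linspace(-6, -1, num=num_iter)  # log-spaced per-step exponents
total_search_iter = len(base_values) ** num_iter     # grows exponentially with num_iter
for base in cartesian_product(base_values, repeat=num_iter):
    beta = exponents * np.array(base)  # one candidate noise schedule of length num_iter
    # each candidate is then scored by vocoder inference + mel-domain MSE, as in the script
```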
- MSE: {mse.item()}") - np.save(args.output_path, best_schedule) + mel_hat = torch.stack(mel_hat) + mse = torch.sum((mel - mel_hat) ** 2).mean() + if mse.item() < best_error: + best_error = mse.item() + best_schedule = {"beta": beta} + print(f" > Found a better schedule. - MSE: {mse.item()}") + np.save(args.output_path, best_schedule) diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index 6b0778c5a7..067c32d97d 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -62,7 +62,7 @@ def _process_model_name(config_dict: Dict) -> str: return model_name -def load_config(config_path: str) -> None: +def load_config(config_path: str) -> Coqpit: """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name to find the corresponding Config class. Then initialize the Config. diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py index 84e9967f84..e75ab6c463 100644 --- a/TTS/encoder/models/resnet.py +++ b/TTS/encoder/models/resnet.py @@ -1,7 +1,7 @@ import torch from torch import nn -# from TTS.utils.audio import TorchSTFT +# from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.encoder.models.base_encoder import BaseEncoder diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 4704687c26..e1ea8be3da 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -200,9 +200,6 @@ class BaseTTSConfig(BaseTrainingConfig): loss_masking (bool): enable / disable masking loss values against padded segments of samples in a batch. - sort_by_audio_len (bool): - If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`. - min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0. @@ -303,7 +300,6 @@ class BaseTTSConfig(BaseTrainingConfig): batch_group_size: int = 0 loss_masking: bool = None # dataloading - sort_by_audio_len: bool = False min_audio_len: int = 1 max_audio_len: int = float("inf") min_text_len: int = 1 diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py index e25609ffcf..350b5ea996 100644 --- a/TTS/tts/configs/tacotron_config.py +++ b/TTS/tts/configs/tacotron_config.py @@ -53,7 +53,7 @@ class TacotronConfig(BaseTTSConfig): enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True. stopnet_pos_weight (float): Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with - datasets with longer sentences. Defaults to 10. + datasets with longer sentences. Defaults to 0.2. max_decoder_steps (int): Max number of steps allowed for the decoder. Defaults to 50. 
diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index a8c7f91dcd..3469f701ae 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -2,7 +2,7 @@
 from typing import List
 
 from TTS.tts.configs.shared_configs import BaseTTSConfig
-from TTS.tts.models.vits import VitsArgs
+from TTS.tts.models.vits import VitsArgs, VitsAudioConfig
 
 
 @dataclass
@@ -16,6 +16,9 @@ class VitsConfig(BaseTTSConfig):
         model_args (VitsArgs):
             Model architecture arguments. Defaults to `VitsArgs()`.
 
+        audio (VitsAudioConfig):
+            Audio processing configuration. Defaults to `VitsAudioConfig()`.
+
         grad_clip (List):
             Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
 
@@ -67,6 +70,18 @@ class VitsConfig(BaseTTSConfig):
         compute_linear_spec (bool):
             If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
 
+        use_weighted_sampler (bool):
+            If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
+
+        weighted_sampler_attrs (dict):
+            Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
+            by overweighting `root_path` by 2.0. Defaults to `{}`.
+
+        weighted_sampler_multipliers (dict):
+            Weight each unique value of a key returned by the formatter for weighted sampling.
+            For example `{"root_path": {"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/": 1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`.
+            It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
+
         r (int):
             Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
 
@@ -94,6 +109,7 @@ class VitsConfig(BaseTTSConfig):
     model: str = "vits"
     # model specific params
     model_args: VitsArgs = field(default_factory=VitsArgs)
+    audio: VitsAudioConfig = VitsAudioConfig()
 
     # optimizer
     grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
@@ -120,6 +136,11 @@ class VitsConfig(BaseTTSConfig):
     return_wav: bool = True
     compute_linear_spec: bool = True
 
+    # sampler params
+    use_weighted_sampler: bool = False  # TODO: move it to the base config
+    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
+    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
+
     # overrides
     r: int = 1  # DO NOT CHANGE
     add_blank: bool = True
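Note: a sketch of how these new sampler options might be set on a `VitsConfig`; the attribute weights mirror the docstring above and the dataset paths are placeholders:

```python
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig(
    use_weighted_sampler=True,
    # balance across datasets (keyed by "root_path") and speakers
    weighted_sampler_attrs={"root_path": 2.0, "speaker_name": 1.0},
    # optionally rescale individual dataset roots (placeholder paths)
    weighted_sampler_multipliers={"root_path": {"/data/set_a/": 1.0, "/data/set_b/": 0.5}},
)
```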
diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index ef05ea7c7a..a4be2b33ef 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -34,6 +34,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
                 "audio_file": audio_path,
                 "speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
                 "emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
+                "root_path": root_path,
             }
         )
     if not_found_counter > 0:
@@ -53,7 +54,7 @@ def tweb(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
         cols = line.split("\t")
         wav_file = os.path.join(root_path, cols[0] + ".wav")
         text = cols[1]
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
 
 
@@ -68,7 +69,7 @@ def mozilla(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
         wav_file = cols[1].strip()
         text = cols[0].strip()
         wav_file = os.path.join(root_path, "wavs", wav_file)
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
 
 
@@ -84,7 +85,7 @@ def mozilla_de(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
         text = cols[1].strip()
         folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
         wav_file = os.path.join(root_path, folder_name, wav_file)
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
 
 
@@ -130,7 +131,9 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                 wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
                 if os.path.isfile(wav_file):
                     text = cols[1].strip()
-                    items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+                    items.append(
+                        {"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}
+                    )
                 else:
                     # M-AI-Labs have some missing samples, so just print the warning
                     print("> File %s does not exist!" % (wav_file))
@@ -148,7 +151,7 @@ def ljspeech(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
         cols = line.split("|")
         wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
         text = cols[2]
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
 
 
@@ -166,7 +169,9 @@ def ljspeech_test(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
         cols = line.split("|")
         wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
         text = cols[2]
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
+        items.append(
+            {"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path}
+        )
     return items
 
 
@@ -181,7 +186,7 @@ def thorsten(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
         cols = line.split("|")
         wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
         text = cols[1]
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
 
 
@@ -198,7 +203,7 @@ def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
         if not os.path.exists(wav_file):
             print(f" [!] {wav_file} in metafile does not exist. Skipping...")
             continue
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
 
 
@@ -213,7 +218,7 @@ def ruslan(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
         cols = line.split("|")
         wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
         text = cols[1]
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
 
 
@@ -261,7 +266,9 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
             if speaker_name in ignored_speakers:
                 continue
         wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
+        items.append(
+            {"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path}
+        )
     return items
 
 
@@ -288,7 +295,14 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_name in ignored_speakers:
                     continue
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
+            items.append(
+                {
+                    "text": text,
+                    "audio_file": wav_file,
+                    "speaker_name": f"LTTS_{speaker_name}",
+                    "root_path": root_path,
+                }
+            )
     for item in items:
         assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}"
     return items
 
 
@@ -307,7 +321,7 @@ def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             skipped_files.append(wav_file)
             continue
         text = cols[1].strip()
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
They don't exist...") return items @@ -329,7 +343,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_id in ignored_speakers: continue - items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id}) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path}) return items @@ -372,7 +386,9 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic else: wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}") if os.path.exists(wav_file): - items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id}) + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path} + ) else: print(f" [!] wav files don't exist - {wav_file}") return items @@ -392,7 +408,9 @@ def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=Non with open(meta_file, "r", encoding="utf-8") as file_text: text = file_text.readlines()[0] wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") - items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id}) + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path} + ) return items @@ -411,7 +429,7 @@ def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-ar if os.path.exists(txt_file) and os.path.exists(wav_file): with open(txt_file, "r", encoding="utf-8") as file_text: text = file_text.readlines()[0] - items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) return items @@ -433,7 +451,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, igno if ignore_digits_sentences and any(map(str.isdigit, text)): continue wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac") - items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id}) + items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path}) return items @@ -450,7 +468,9 @@ def mls(root_path, meta_files=None, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker in ignored_speakers: continue - items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker}) + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path} + ) return items @@ -520,7 +540,9 @@ def emotion(root_path, meta_file, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_id in ignored_speakers: continue - items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id}) + items.append( + {"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path} + ) return items @@ -540,7 +562,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin for line in ttf: wav_name, text = line.rstrip("\n").split("|") wav_path = os.path.join(root_path, "clips_22", wav_name) - items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name}) + items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path}) return items @@ -554,5 
diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py
index aeb45c7bcd..613ad19d59 100644
--- a/TTS/tts/layers/generic/wavenet.py
+++ b/TTS/tts/layers/generic/wavenet.py
@@ -67,9 +67,14 @@ def __init__(
         for i in range(num_layers):
             dilation = dilation_rate**i
             padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = torch.nn.Conv1d(
-                hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
-            )
+            if i == 0:
+                in_layer = torch.nn.Conv1d(
+                    in_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
+                )
+            else:
+                in_layer = torch.nn.Conv1d(
+                    hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
+                )
             in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
             self.in_layers.append(in_layer)
 
diff --git a/TTS/tts/layers/glow_tts/decoder.py b/TTS/tts/layers/glow_tts/decoder.py
index f57c37314c..61c5174ac5 100644
--- a/TTS/tts/layers/glow_tts/decoder.py
+++ b/TTS/tts/layers/glow_tts/decoder.py
@@ -29,11 +29,11 @@ def squeeze(x, x_mask=None, num_sqz=2):
 
 
 def unsqueeze(x, x_mask=None, num_sqz=2):
-    """GlowTTS unsqueeze operation
+    """GlowTTS unsqueeze operation (revert the squeeze)
 
     Note:
         each 's' is a n-dimensional vector.
-        ``[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]``
+        ``[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5, s2, s4, s6]]``
     """
     b, c, t = x.size()
 
diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py
index ff1b99e8ec..3b745018a2 100644
--- a/TTS/tts/layers/glow_tts/glow.py
+++ b/TTS/tts/layers/glow_tts/glow.py
@@ -197,7 +197,7 @@ def __init__(
         end.bias.data.zero_()
         self.end = end
         # coupling layers
-        self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p)
+        self.wn = WN(hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p)
 
     def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):  # pylint: disable=unused-argument
         """
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 1f0961b303..9933df6bea 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -7,8 +7,8 @@
 from torch.nn import functional
 
 from TTS.tts.utils.helpers import sequence_mask
-from TTS.tts.utils.ssim import ssim
-from TTS.utils.audio import TorchSTFT
+from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
+from TTS.utils.audio.torch_transforms import TorchSTFT
 
 
 # pylint: disable=abstract-method
@@ -91,30 +91,55 @@ class for each corresponding step.
         return loss
 
 
+def sample_wise_min_max(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Min-Max normalize tensor through first dimension
+    Shapes:
+        - x: :math:`[B, D1, D2]`
+        - m: :math:`[B, D1, 1]`
+    """
+    maximum = torch.amax(x.masked_fill(~mask, 0), dim=(1, 2), keepdim=True)
+    minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True)
+    return (x - minimum) / (maximum - minimum + 1e-8)
+
+
 class SSIMLoss(torch.nn.Module):
-    """SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity"""
+    """SSIM loss as (1 - SSIM)
+    SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity
+    """
 
     def __init__(self):
         super().__init__()
-        self.loss_func = ssim
+        self.loss_func = _SSIMLoss()
 
-    def forward(self, y_hat, y, length=None):
+    def forward(self, y_hat, y, length):
         """
         Args:
             y_hat (tensor): model prediction values.
             y (tensor): target values.
-            length (tensor): length of each sample in a batch.
+            length (tensor): length of each sample in a batch for masking.
+
         Shapes:
             y_hat: B x T X D
             y: B x T x D
             length: B
+
         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
-        if length is not None:
-            m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device)
-            y_hat, y = y_hat * m, y * m
-        return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1))
+        mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2)
+        y_norm = sample_wise_min_max(y, mask)
+        y_hat_norm = sample_wise_min_max(y_hat, mask)
+        ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1))
+
+        if ssim_loss.item() > 1.0:
+            print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it to 1.0")
+            ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
+
+        if ssim_loss.item() < 0.0:
+            print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it to 0.0")
+            ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
+
+        return ssim_loss
 
 
 class AttentionEntropyLoss(nn.Module):
@@ -123,9 +148,6 @@ def forward(self, align):
         """
         Forces attention to be more decisive by penalizing
        soft attention weights
-
-        TODO: arguments
-        TODO: unit_test
         """
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
@@ -133,9 +155,17 @@
 
 
 class BCELossMasked(nn.Module):
-    def __init__(self, pos_weight):
+    """BCE loss with masking.
+
+    Used mainly for stopnet in autoregressive models.
+
+    Args:
+        pos_weight (float): weight for positive samples. If set < 1, penalize early stopping. Defaults to None.
+    """
+
+    def __init__(self, pos_weight: float = None):
         super().__init__()
-        self.pos_weight = pos_weight
+        self.pos_weight = nn.Parameter(torch.tensor([pos_weight]), requires_grad=False)
 
     def forward(self, x, target, length):
         """
@@ -155,16 +185,17 @@ class for each corresponding step.
         Returns:
             loss: An average loss value in range [0, 1] masked by the length.
         """
-        # mask: (batch, max_len, 1)
         target.requires_grad = False
         if length is not None:
-            mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
-            x = x * mask
-            target = target * mask
+            # mask: (batch, max_len, 1)
+            mask = sequence_mask(sequence_length=length, max_len=target.size(1))
             num_items = mask.sum()
+            loss = functional.binary_cross_entropy_with_logits(
+                x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum"
+            )
         else:
+            loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
             num_items = torch.numel(x)
-        loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
         loss = loss / num_items
         return loss
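Note: since SSIM assumes inputs in a bounded range, the rewritten loss first maps each spectrogram to [0, 1] with `sample_wise_min_max` before scoring. A quick self-contained check of that normalization, mirroring the function above:

```python
import torch

x = torch.randn(2, 50, 80)                     # [B, T, D] spectrogram batch
mask = torch.ones(2, 50, 1, dtype=torch.bool)  # [B, T, 1]; all frames valid here
maximum = torch.amax(x.masked_fill(~mask, 0), dim=(1, 2), keepdim=True)
minimum = torch.amin(x.masked_fill(~mask, float("inf")), dim=(1, 2), keepdim=True)
x_norm = (x - minimum) / (maximum - minimum + 1e-8)
assert x_norm.min() >= 0.0 and x_norm.max() <= 1.0  # per-sample min-max range
```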
""" - # mask: (batch, max_len, 1) target.requires_grad = False if length is not None: - mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float() - x = x * mask - target = target * mask + # mask: (batch, max_len, 1) + mask = sequence_mask(sequence_length=length, max_len=target.size(1)) num_items = mask.sum() + loss = functional.binary_cross_entropy_with_logits( + x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum" + ) else: + loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum") num_items = torch.numel(x) - loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum") loss = loss / num_items return loss diff --git a/TTS/tts/layers/tacotron/capacitron_layers.py b/TTS/tts/layers/tacotron/capacitron_layers.py index 56fe44bc33..68321358d2 100644 --- a/TTS/tts/layers/tacotron/capacitron_layers.py +++ b/TTS/tts/layers/tacotron/capacitron_layers.py @@ -53,6 +53,7 @@ def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=Non text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device) enc_out = torch.cat([enc_out, text_summary_out], dim=-1) if speaker_embedding is not None: + speaker_embedding = torch.squeeze(speaker_embedding) enc_out = torch.cat([enc_out, speaker_embedding], dim=-1) # Feed the output of the ref encoder and information about text/speaker into diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index c86bd391b4..df64429d8f 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -137,7 +137,7 @@ def get_aux_input_from_test_setences(self, sentence_info): if hasattr(self, "speaker_manager"): if config.use_d_vector_file: if speaker_name is None: - d_vector = self.speaker_manager.get_random_embeddings() + d_vector = self.speaker_manager.get_random_embedding() else: d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name) elif config.use_speaker_embedding: diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 7c0f95e151..cc241c43b8 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -514,7 +514,7 @@ def preprocess(self, y, y_lengths, y_max_length, attn=None): y = y[:, :, :y_max_length] if attn is not None: attn = attn[:, :, :, :y_max_length] - y_lengths = (y_lengths // self.num_squeeze) * self.num_squeeze + y_lengths = torch.div(y_lengths, self.num_squeeze, rounding_mode="floor") * self.num_squeeze return y, y_lengths, y_max_length, attn def store_inverse(self): diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index a6b1c74332..15fa297ff0 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -4,6 +4,7 @@ from itertools import chain from typing import Dict, List, Tuple, Union +import numpy as np import torch import torch.distributed as dist import torchaudio @@ -13,6 +14,8 @@ from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.tts.configs.shared_configs import CharactersConfig @@ -29,6 +32,8 @@ from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import 
+from TTS.utils.io import load_fsspec
+from TTS.utils.samplers import BucketBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
 
@@ -200,11 +205,51 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax):
     return spec
 
 
+#############################
+# CONFIGS
+#############################
+
+
+@dataclass
+class VitsAudioConfig(Coqpit):
+    fft_size: int = 1024
+    sample_rate: int = 22050
+    win_length: int = 1024
+    hop_length: int = 256
+    num_mels: int = 80
+    mel_fmin: int = 0
+    mel_fmax: int = None
+
+
 ##############################
 # DATASET
 ##############################
 
 
+def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
+    """Create inverse frequency weights for balancing the dataset.
+    Use `multi_dict` to scale relative weights."""
+    attr_names_samples = np.array([item[attr_name] for item in items])
+    unique_attr_names = np.unique(attr_names_samples).tolist()
+    attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
+    attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
+    weight_attr = 1.0 / attr_count
+    dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
+    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
+    if multi_dict is not None:
+        # check if all keys are in the multi_dict
+        for k in multi_dict:
+            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
+        # scale weights
+        multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
+        dataset_samples_weight *= multiplier_samples
+    return (
+        torch.from_numpy(dataset_samples_weight).float(),
+        unique_attr_names,
+        np.unique(dataset_samples_weight).tolist(),
+    )
+
+
 class VitsDataset(TTSDataset):
     def __init__(self, model_args, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -786,7 +831,7 @@ def on_init_end(self, trainer):  # pylint: disable=W0613
             print(" > Text Encoder was reinit.")
 
     def get_aux_input(self, aux_input: Dict):
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, _ = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
 
     def _freeze_layers(self):
@@ -817,7 +862,7 @@ def _freeze_layers(self):
     @staticmethod
     def _set_cond_input(aux_input: Dict):
         """Set the speaker conditioning input based on the multi-speaker mode."""
-        sid, g, lid = None, None, None
+        sid, g, lid, durations = None, None, None, None
         if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
             sid = aux_input["speaker_ids"]
             if sid.ndim == 0:
@@ -832,7 +877,10 @@ def _set_cond_input(aux_input: Dict):
             if lid.ndim == 0:
                 lid = lid.unsqueeze_(0)
 
-        return sid, g, lid
+        if "durations" in aux_input and aux_input["durations"] is not None:
+            durations = aux_input["durations"]
+
+        return sid, g, lid, durations
 
     def _set_speaker_input(self, aux_input: Dict):
         d_vectors = aux_input.get("d_vectors", None)
@@ -946,7 +994,7 @@ def forward(  # pylint: disable=dangerous-default-value
             - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
         """
         outputs = {}
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, _ = self._set_cond_input(aux_input)
         # speaker embedding
         if self.args.use_speaker_embedding and sid is not None:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
@@ -1028,7 +1076,9 @@ def _set_x_lengths(x, aux_input):
 
     @torch.no_grad()
     def inference(
-        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None} + self, + x, + aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None}, ): # pylint: disable=dangerous-default-value """ Note: @@ -1048,7 +1098,7 @@ def inference( - m_p: :math:`[B, C, T_dec]` - logs_p: :math:`[B, C, T_dec]` """ - sid, g, lid = self._set_cond_input(aux_input) + sid, g, lid, durations = self._set_cond_input(aux_input) x_lengths = self._set_x_lengths(x, aux_input) # speaker embedding @@ -1062,21 +1112,25 @@ def inference( x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) - if self.args.use_sdp: - logw = self.duration_predictor( - x, - x_mask, - g=g if self.args.condition_dp_on_speaker else None, - reverse=True, - noise_scale=self.inference_noise_scale_dp, - lang_emb=lang_emb, - ) + if durations is None: + if self.args.use_sdp: + logw = self.duration_predictor( + x, + x_mask, + g=g if self.args.condition_dp_on_speaker else None, + reverse=True, + noise_scale=self.inference_noise_scale_dp, + lang_emb=lang_emb, + ) + else: + logw = self.duration_predictor( + x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb + ) + w = torch.exp(logw) * x_mask * self.length_scale else: - logw = self.duration_predictor( - x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb - ) + assert durations.shape[-1] == x.shape[-1] + w = durations.unsqueeze(0) - w = torch.exp(logw) * x_mask * self.length_scale w_ceil = torch.ceil(w) y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec] @@ -1341,7 +1395,7 @@ def get_aux_input_from_test_sentences(self, sentence_info): if hasattr(self, "speaker_manager"): if config.use_d_vector_file: if speaker_name is None: - d_vector = self.speaker_manager.get_random_embeddings() + d_vector = self.speaker_manager.get_random_embedding() else: d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) elif config.use_speaker_embedding: @@ -1485,6 +1539,42 @@ def format_batch_on_device(self, batch): batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1) return batch + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False): + weights = None + data_items = dataset.samples + if getattr(config, "use_weighted_sampler", False): + for attr_name, alpha in config.weighted_sampler_attrs.items(): + print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) + print(multi_dict) + weights, attr_names, attr_weights = get_attribute_balancer_weights( + attr_name=attr_name, items=data_items, multi_dict=multi_dict + ) + weights = weights * alpha + print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + + # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items] + + if weights is not None: + w_sampler = WeightedRandomSampler(weights, len(weights)) + batch_sampler = BucketBatchSampler( + w_sampler, + data=data_items, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + sort_key=lambda x: os.path.getsize(x["audio_file"]), + drop_last=True, + ) + else: + batch_sampler = None + # sampler for DDP + if batch_sampler is None: + batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined 
diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py
index c2e7f56146..b62004c8ed 100644
--- a/TTS/tts/utils/helpers.py
+++ b/TTS/tts/utils/helpers.py
@@ -76,7 +76,7 @@ def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False):
         index_start = segment_indices[i]
         index_end = index_start + segment_size
         x_i = x[i]
-        if pad_short and index_end > x.size(2):
+        if pad_short and index_end >= x.size(2):
             # pad the sample if it is shorter than the segment size
             x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2)))
         segments[i] = x_i[:, index_start:index_end]
@@ -107,16 +107,16 @@ def rand_segments(
         T = segment_size
     if _x_lenghts is None:
         _x_lenghts = T
-    len_diff = _x_lenghts - segment_size + 1
+    len_diff = _x_lenghts - segment_size
     if let_short_samples:
         _x_lenghts[len_diff < 0] = segment_size
-        len_diff = _x_lenghts - segment_size + 1
+        len_diff = _x_lenghts - segment_size
     else:
         assert all(
             len_diff > 0
        ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}"
-    segment_indices = (torch.rand([B]).type_as(x) * len_diff).long()
-    ret = segment(x, segment_indices, segment_size)
+    segment_indices = (torch.rand([B]).type_as(x) * (len_diff + 1)).long()
+    ret = segment(x, segment_indices, segment_size, pad_short=pad_short)
     return ret, segment_indices
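Note: the `+ 1` moves from the length difference into the index draw, so a segment may now start exactly at `T - segment_size`, and `pad_short` propagates into `segment()`. A small check of the fixed behavior, assuming the updated helpers and a `rand_segments(x, x_lengths, segment_size)` calling convention:

```python
import torch

from TTS.tts.utils.helpers import rand_segments

x = torch.randn(1, 80, 40)  # [B, C, T]
segments, indices = rand_segments(x, torch.tensor([40]), segment_size=32)
assert segments.shape == (1, 80, 32)
assert 0 <= int(indices) <= 8  # start index can now reach T - segment_size exactly
```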
\n {_x_lenghts}" - segment_indices = (torch.rand([B]).type_as(x) * len_diff).long() - ret = segment(x, segment_indices, segment_size) + segment_indices = (torch.rand([B]).type_as(x) * (len_diff + 1)).long() + ret = segment(x, segment_indices, segment_size, pad_short=pad_short) return ret, segment_indices diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py index ab2c69914e..4bc3befc5b 100644 --- a/TTS/tts/utils/ssim.py +++ b/TTS/tts/utils/ssim.py @@ -1,73 +1,383 @@ -# taken from https://github.com/Po-Hsun-Su/pytorch-ssim +# Adopted from https://github.com/photosynthesis-team/piq -from math import exp +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F -from torch.autograd import Variable +from torch.nn.modules.loss import _Loss -def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) - return gauss / gauss.sum() +def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor: + r"""Reduce input in batch dimension if needed. + Args: + x: Tensor with shape (N, *). + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'`` + """ + if reduction == "none": + return x + if reduction == "mean": + return x.mean(dim=0) + if reduction == "sum": + return x.sum(dim=0) + raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}") -def create_window(window_size, channel): - _1D_window = gaussian(window_size, 1.5).unsqueeze(1) - _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) - window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) - return window +def _validate_input( + tensors: List[torch.Tensor], + dim_range: Tuple[int, int] = (0, -1), + data_range: Tuple[float, float] = (0.0, -1.0), + # size_dim_range: Tuple[float, float] = (0., -1.), + size_range: Optional[Tuple[int, int]] = None, +) -> None: + r"""Check that input(-s) satisfies the requirements + Args: + tensors: Tensors to check + dim_range: Allowed number of dimensions. (min, max) + data_range: Allowed range of values in tensors. (min, max) + size_range: Dimensions to include in size comparison. 
+    """
+    if not __debug__:
+        return
 
-def _ssim(img1, img2, window, window_size, channel, size_average=True):
-    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
-    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+    x = tensors[0]
 
-    # TODO: check if you need AMP disabled
-    # with torch.cuda.amp.autocast(enabled=False):
-    mu1_sq = mu1.float().pow(2)
-    mu2_sq = mu2.float().pow(2)
-    mu1_mu2 = mu1 * mu2
+    for t in tensors:
+        assert torch.is_tensor(t), f"Expected torch.Tensor, got {type(t)}"
+        assert t.device == x.device, f"Expected tensors to be on {x.device}, got {t.device}"
 
-    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
-    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
-    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+        if size_range is None:
+            assert t.size() == x.size(), f"Expected tensors with same size, got {t.size()} and {x.size()}"
+        else:
+            assert (
+                t.size()[size_range[0] : size_range[1]] == x.size()[size_range[0] : size_range[1]]
+            ), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}"
+
+        if dim_range[0] == dim_range[1]:
+            assert t.dim() == dim_range[0], f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}"
+        elif dim_range[0] < dim_range[1]:
+            assert (
+                dim_range[0] <= t.dim() <= dim_range[1]
+            ), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}"
+
+        if data_range[0] < data_range[1]:
+            assert data_range[0] <= t.min(), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}"
+            assert t.max() <= data_range[1], f"Expected values to be lower or equal to {data_range[1]}, got {t.max()}"
+
+
+def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor:
+    r"""Returns 2D Gaussian kernel N(0,`sigma`^2)
+    Args:
+        size: Size of the kernel
+        sigma: Std of the distribution
+    Returns:
+        gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size)
+    """
+    coords = torch.arange(kernel_size, dtype=torch.float32)
+    coords -= (kernel_size - 1) / 2.0
+
+    g = coords**2
+    g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp()
+
+    g /= g.sum()
+    return g.unsqueeze(0)
+
+
+def ssim(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    kernel_size: int = 11,
+    kernel_sigma: float = 1.5,
+    data_range: Union[int, float] = 1.0,
+    reduction: str = "mean",
+    full: bool = False,
+    downsample: bool = True,
+    k1: float = 0.01,
+    k2: float = 0.03,
+) -> List[torch.Tensor]:
+    r"""Interface of Structural Similarity (SSIM) index.
+    Inputs supposed to be in range ``[0, data_range]``.
+    To match performance with skimage and tensorflow set ``'downsample' = True``.
+
+    Args:
+        x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+        y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+        kernel_size: The side-length of the sliding window used in comparison. Must be an odd value.
+        kernel_sigma: Sigma of normal distribution.
+        data_range: Maximum value range of images (usually 1.0 or 255).
+        reduction: Specifies the reduction type:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
+        full: Return cs map or not.
+        downsample: Perform average pool before SSIM computation. Default: True
+        k1: Algorithm parameter, K1 (small constant).
+        k2: Algorithm parameter, K2 (small constant).
+            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+
+    Returns:
+        Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned
+        as a tensor of size 2.
+
+    References:
+        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
+        Image quality assessment: From error visibility to structural similarity.
+        IEEE Transactions on Image Processing, 13, 600-612.
+        https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
+        DOI: `10.1109/TIP.2003.819861`
+    """
+    assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
+    _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range))
+
+    x = x / float(data_range)
+    y = y / float(data_range)
+
+    # Averagepool image if the size is large enough
+    f = max(1, round(min(x.size()[-2:]) / 256))
+    if (f > 1) and downsample:
+        x = F.avg_pool2d(x, kernel_size=f)
+        y = F.avg_pool2d(y, kernel_size=f)
+
+    kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y)
+    _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel
+    ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, k1=k1, k2=k2)
+    ssim_val = ssim_map.mean(1)
+    cs = cs_map.mean(1)
+
+    ssim_val = _reduce(ssim_val, reduction)
+    cs = _reduce(cs, reduction)
+
+    if full:
+        return [ssim_val, cs]
+
+    return ssim_val
+
+
+class SSIMLoss(_Loss):
+    r"""Creates a criterion that measures the structural similarity index error between
+    each element in the input :math:`x` and target :math:`y`.
+
+    To match performance with skimage and tensorflow set ``'downsample' = True``.
+
+    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
+
+    .. math::
+        SSIM = \{ssim_1,\dots,ssim_{N \times C}\}\\
+        ssim_{l}(x, y) = \frac{(2 \mu_x \mu_y + c_1) (2 \sigma_{xy} + c_2)}
+        {(\mu_x^2 +\mu_y^2 + c_1)(\sigma_x^2 +\sigma_y^2 + c_2)},
+
+    where :math:`N` is the batch size, `C` is the channel size. If :attr:`reduction` is not ``'none'``
+    (default ``'mean'``), then:
+
+    .. math::
+        SSIMLoss(x, y) =
+        \begin{cases}
+            \operatorname{mean}(1 - SSIM), & \text{if reduction} = \text{'mean';}\\
+            \operatorname{sum}(1 - SSIM),  & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
+    of :math:`n` elements each.
-    C1 = 0.01**2
-    C2 = 0.03**2
+    The sum operation still operates over all the elements, and divides by :math:`n`.
+    The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
+    In case of 5D input tensors, complex value is returned as a tensor of size 2.
-    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+    Args:
+        kernel_size: By default, the mean and covariance of a pixel is obtained
+            by convolution with given filter_size.
+        kernel_sigma: Standard deviation for Gaussian kernel.
+        k1: Coefficient related to c1 in the above equation.
+        k2: Coefficient related to c2 in the above equation.
+        downsample: Perform average pool before SSIM computation. Default: True
+        reduction: Specifies the reduction type:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'``
+        data_range: Maximum value range of images (usually 1.0 or 255).
-    if size_average:
-        return ssim_map.mean()
-    return ssim_map.mean(1).mean(1).mean(1)
+    Examples:
+        >>> loss = SSIMLoss()
+        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
+        >>> y = torch.rand(3, 3, 256, 256)
+        >>> output = loss(x, y)
+        >>> output.backward()
 
+    References:
+        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
+        Image quality assessment: From error visibility to structural similarity.
+        IEEE Transactions on Image Processing, 13, 600-612.
+        https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
+        DOI:`10.1109/TIP.2003.819861`
+    """
+    __constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"]
 
-class SSIM(torch.nn.Module):
-    def __init__(self, window_size=11, size_average=True):
+    def __init__(
+        self,
+        kernel_size: int = 11,
+        kernel_sigma: float = 1.5,
+        k1: float = 0.01,
+        k2: float = 0.03,
+        downsample: bool = True,
+        reduction: str = "mean",
+        data_range: Union[int, float] = 1.0,
+    ) -> None:
         super().__init__()
-        self.window_size = window_size
-        self.size_average = size_average
-        self.channel = 1
-        self.window = create_window(window_size, self.channel)
 
-    def forward(self, img1, img2):
-        (_, channel, _, _) = img1.size()
+        # Generic loss parameters.
+        self.reduction = reduction
 
-        if channel == self.channel and self.window.data.type() == img1.data.type():
-            window = self.window
-        else:
-            window = create_window(self.window_size, channel)
-            window = window.type_as(img1)
+        # Loss-specific parameters.
+        self.kernel_size = kernel_size
+
+        # This check might look redundant because kernel size is checked within the ssim function anyway.
+        # However, this check allows to fail fast when the loss is being initialised and training has not been started.
+        assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]"
+        self.kernel_sigma = kernel_sigma
+        self.k1 = k1
+        self.k2 = k2
+        self.downsample = downsample
+        self.data_range = data_range
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        r"""Computation of Structural Similarity (SSIM) index as a loss function.
+
+        Args:
+            x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+            y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`.
+
+        Returns:
+            Value of SSIM loss to be minimized, i.e ``1 - ssim`` in [0, 1] range. In case of 5D input tensors,
+            complex value is returned as a tensor of size 2.
+        """
+
+        score = ssim(
+            x=x,
+            y=y,
+            kernel_size=self.kernel_size,
+            kernel_sigma=self.kernel_sigma,
+            downsample=self.downsample,
+            data_range=self.data_range,
+            reduction=self.reduction,
+            full=False,
+            k1=self.k1,
+            k2=self.k2,
+        )
+        return torch.ones_like(score) - score
+
+
+def _ssim_per_channel(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    kernel: torch.Tensor,
+    k1: float = 0.01,
+    k2: float = 0.03,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    r"""Calculate Structural Similarity (SSIM) index for X and Y per channel.
+
+    Args:
+        x: An input tensor. Shape :math:`(N, C, H, W)`.
+        y: A target tensor. Shape :math:`(N, C, H, W)`.
+        kernel: 2D Gaussian kernel.
+        k1: Algorithm parameter, K1 (small constant, see [1]).
+        k2: Algorithm parameter, K2 (small constant, see [1]).
+            Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
+
+    Returns:
+        Full Value of Structural Similarity (SSIM) index.
+    """
+    if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2):
+        raise ValueError(
+            f"Kernel size can't be greater than actual input size. Input size: {x.size()}. "
" + f"Kernel size: {kernel.size()}" + ) + + c1 = k1**2 + c2 = k2**2 + n_channels = x.size(1) + mu_x = F.conv2d(x, weight=kernel, stride=1, padding=0, groups=n_channels) + mu_y = F.conv2d(y, weight=kernel, stride=1, padding=0, groups=n_channels) + + mu_xx = mu_x**2 + mu_yy = mu_y**2 + mu_xy = mu_x * mu_y + + sigma_xx = F.conv2d(x**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx + sigma_yy = F.conv2d(y**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy + sigma_xy = F.conv2d(x * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xy + + # Contrast sensitivity (CS) with alpha = beta = gamma = 1. + cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2) + + # Structural similarity (SSIM) + ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs + + ssim_val = ss.mean(dim=(-1, -2)) + cs = cs.mean(dim=(-1, -2)) + return ssim_val, cs + + +def _ssim_per_channel_complex( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor, + k1: float = 0.01, + k2: float = 0.03, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W, 2)`. + y: A target tensor. Shape :math:`(N, C, H, W, 2)`. + kernel: 2-D gauss kernel. + k1: Algorithm parameter, K1 (small constant, see [1]). + k2: Algorithm parameter, K2 (small constant, see [1]). + Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + + Returns: + Full Value of Complex Structural Similarity (SSIM) index. + """ + n_channels = x.size(1) + if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2): + raise ValueError( + f"Kernel size can't be greater than actual input size. Input size: {x.size()}. " + f"Kernel size: {kernel.size()}" + ) + + c1 = k1**2 + c2 = k2**2 + + x_real = x[..., 0] + x_imag = x[..., 1] + y_real = y[..., 0] + y_imag = y[..., 1] + + mu1_real = F.conv2d(x_real, weight=kernel, stride=1, padding=0, groups=n_channels) + mu1_imag = F.conv2d(x_imag, weight=kernel, stride=1, padding=0, groups=n_channels) + mu2_real = F.conv2d(y_real, weight=kernel, stride=1, padding=0, groups=n_channels) + mu2_imag = F.conv2d(y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) + + mu1_sq = mu1_real.pow(2) + mu1_imag.pow(2) + mu2_sq = mu2_real.pow(2) + mu2_imag.pow(2) + mu1_mu2_real = mu1_real * mu2_real - mu1_imag * mu2_imag + mu1_mu2_imag = mu1_real * mu2_imag + mu1_imag * mu2_real + + compensation = 1.0 - self.window = window - self.channel = channel + x_sq = x_real.pow(2) + x_imag.pow(2) + y_sq = y_real.pow(2) + y_imag.pow(2) + x_y_real = x_real * y_real - x_imag * y_imag + x_y_imag = x_real * y_imag + x_imag * y_real - return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + sigma1_sq = F.conv2d(x_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_sq + sigma2_sq = F.conv2d(y_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu2_sq + sigma12_real = F.conv2d(x_y_real, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_real + sigma12_imag = F.conv2d(x_y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_imag + sigma12 = torch.stack((sigma12_imag, sigma12_real), dim=-1) + mu1_mu2 = torch.stack((mu1_mu2_real, mu1_mu2_imag), dim=-1) + # Set alpha = beta = gamma = 1. 
+ cs_map = (sigma12 * 2 + c2 * compensation) / (sigma1_sq.unsqueeze(-1) + sigma2_sq.unsqueeze(-1) + c2 * compensation) + ssim_map = (mu1_mu2 * 2 + c1 * compensation) / (mu1_sq.unsqueeze(-1) + mu2_sq.unsqueeze(-1) + c1 * compensation) + ssim_map = ssim_map * cs_map + ssim_val = ssim_map.mean(dim=(-2, -3)) + cs = cs_map.mean(dim=(-2, -3)) -def ssim(img1, img2, window_size=11, size_average=True): - (_, channel, _, _) = img1.size() - window = create_window(window_size, channel).type_as(img1) - window = window.type_as(img1) - return _ssim(img1, img2, window, window_size, channel, size_average) + return ssim_val, cs
diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py index 024f79c6be..281da22162 100644 --- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -1,4 +1,5 @@ import logging +import re import subprocess from typing import Dict, List @@ -163,6 +164,13 @@ def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str: # dealing with the conditions described above ph_decoded = ph_decoded[:1].replace("_", "") + ph_decoded[1:] + + # espeak-ng backend can add language flags that need to be removed: + # "sɛʁtˈɛ̃ mˈo kɔm (en)fˈʊtbɔːl(fr) ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ." + # phonemize needs to remove the language flags of the returned text: + # "sɛʁtˈɛ̃ mˈo kɔm fˈʊtbɔːl ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ." + ph_decoded = re.sub(r"\(.+?\)", "", ph_decoded) + phonemes += ph_decoded.strip() return phonemes.replace("_", separator)
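Editor's note: the effect of the new `re.sub` can be checked in isolation; the sample string below is the one quoted in the comment above (a sketch, not part of the diff).

```python
import re

ph = "sɛʁtˈɛ̃ mˈo kɔm (en)fˈʊtbɔːl(fr) ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ."
print(re.sub(r"\(.+?\)", "", ph))
# -> "sɛʁtˈɛ̃ mˈo kɔm fˈʊtbɔːl ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ."
```

The pattern is non-greedy, so each parenthesised flag is stripped separately; note that any other parenthesised fragment in the phonemized string would be removed as well.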
diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py index b2a058bb07..8d199cc545 100644 --- a/TTS/tts/utils/text/punctuation.py +++ b/TTS/tts/utils/text/punctuation.py @@ -137,7 +137,7 @@ def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statemen # nothing has been phonemized, return the puncs alone if not text: - return ["".join(m.mark for m in puncs)] + return ["".join(m.punc for m in puncs)] current = puncs[0]
diff --git a/TTS/utils/audio/__init__.py b/TTS/utils/audio/__init__.py new file mode 100644 index 0000000000..f18f221999 --- /dev/null +++ b/TTS/utils/audio/__init__.py @@ -0,0 +1 @@ +from TTS.utils.audio.processor import AudioProcessor
diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py new file mode 100644 index 0000000000..f6f0385542 --- /dev/null +++ b/TTS/utils/audio/numpy_transforms.py @@ -0,0 +1,425 @@ +from typing import Tuple + +import librosa +import numpy as np +import pyworld as pw +import scipy +import soundfile as sf + +# For using kwargs +# pylint: disable=unused-argument + + +def build_mel_basis( + *, + sample_rate: int = None, + fft_size: int = None, + num_mels: int = None, + mel_fmax: int = None, + mel_fmin: int = None, + **kwargs, +) -> np.ndarray: + """Build melspectrogram basis. + + Returns: + np.ndarray: melspectrogram basis. + """ + if mel_fmax is not None: + assert mel_fmax <= sample_rate // 2 + assert mel_fmax - mel_fmin > 0 + return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax) + + +def millisec_to_length( + *, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs +) -> Tuple[int, int]: + """Compute window and hop length from milliseconds. + + Returns: + Tuple[int, int]: window length and hop length for STFT. + """ + factor = frame_length_ms / frame_shift_ms + assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" + win_length = int(frame_length_ms / 1000.0 * sample_rate) + hop_length = int(win_length / float(factor)) + return win_length, hop_length + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) + + +def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert amplitude values to decibels. + + Args: + x (np.ndarray): Amplitude spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Decibels spectrogram. + """ + assert (x < 0).sum() == 0, " [!] Input values must be non-negative." + return gain * _log(np.maximum(1e-8, x), base) + + +# pylint: disable=no-self-use +def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert decibels spectrogram to amplitude spectrogram. + + Args: + x (np.ndarray): Decibels spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Amplitude spectrogram. + """ + return _exp(x / gain, base) + + +def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + if coef == 0: + raise RuntimeError(" [!] Preemphasis is set to 0.0.") + return scipy.signal.lfilter([1, -coef], [1], x) + + +def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray: + """Reverse pre-emphasis.""" + if coef == 0: + raise RuntimeError(" [!] Preemphasis is set to 0.0.") + return scipy.signal.lfilter([1], [1, -coef], x) + + +def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + spec (np.ndarray): Normalized full scale linear spectrogram. + + Shapes: + - spec: :math:`[C, T]` + + Returns: + np.ndarray: Normalized melspectrogram. + """ + return np.dot(mel_basis, spec) + + +def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Convert a melspectrogram to full scale spectrogram.""" + assert (mel < 0).sum() == 0, " [!] Input values must be non-negative." + inv_mel_basis = np.linalg.pinv(mel_basis) + return np.maximum(1e-10, np.dot(inv_mel_basis, mel)) + + +def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + wav (np.ndarray): Waveform. Shape :math:`[T_wav,]` + + Returns: + np.ndarray: Spectrogram. Shape :math:`[C, T_spec]`.
:math:`T_spec == T_wav / hop_length` + """ + D = stft(y=wav, **kwargs) + S = np.abs(D) + return S.astype(np.float32) + + +def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + D = stft(y=wav, **kwargs) + S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs) + return S.astype(np.float32) + + +def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray: + """Convert a spectrogram to a waveform using the Griffin-Lim vocoder.""" + S = spec.copy() + return griffin_lim(spec=S**power, **kwargs) + + +def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray: + """Convert a melspectrogram to a waveform using the Griffin-Lim vocoder.""" + S = mel.copy() + S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear + return griffin_lim(spec=S**power, **kwargs) + + +### STFT and ISTFT ### +def stft( + *, + y: np.ndarray = None, + fft_size: int = None, + hop_length: int = None, + win_length: int = None, + pad_mode: str = "reflect", + window: str = "hann", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Librosa STFT wrapper. + + Check http://librosa.org/doc/main/generated/librosa.stft.html argument details. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.stft( + y=y, + n_fft=fft_size, + hop_length=hop_length, + win_length=win_length, + pad_mode=pad_mode, + window=window, + center=center, + ) + + +def istft( + *, + y: np.ndarray = None, + fft_size: int = None, + hop_length: int = None, + win_length: int = None, + window: str = "hann", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Librosa iSTFT wrapper. + + Check http://librosa.org/doc/main/generated/librosa.istft.html argument details. + + Returns: + np.ndarray: Time-domain waveform. + """ + return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window) + + +def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray: + """Griffin-Lim phase reconstruction: iteratively re-estimate the phase of a magnitude spectrogram and invert it with iSTFT.""" + angles = np.exp(2j * np.pi * np.random.rand(*spec.shape)) + S_complex = np.abs(spec).astype(complex) + y = istft(y=S_complex * angles, **kwargs) + if not np.isfinite(y).all(): + print(" [!] Waveform is not finite everywhere. Skipping the GL.") + return np.array([0.0]) + for _ in range(num_iter): + angles = np.exp(1j * np.angle(stft(y=y, **kwargs))) + y = istft(y=S_complex * angles, **kwargs) + return y + + +def compute_stft_paddings( + *, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs +) -> Tuple[int, int]: + """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding + (first and final frames)""" + pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0] + if not pad_two_sides: + return 0, pad + return pad // 2, pad // 2 + pad % 2 + + +def compute_f0( + *, x: np.ndarray = None, pitch_fmax: float = None, hop_length: int = None, sample_rate: int = None, **kwargs +) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. Shape :math:`[T_wav,]` + + Returns: + np.ndarray: Pitch. Shape :math:`[T_pitch,]`.
:math:`T_pitch == T_wav / hop_length` + + Examples: + >>> WAV_FILE = filename = librosa.util.example_audio_file() + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio.processor import AudioProcessor + >>> conf = BaseAudioConfig(pitch_fmax=8000) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] + >>> pitch = ap.compute_f0(wav) + """ + assert pitch_fmax is not None, " [!] Set `pitch_fmax` before calling `compute_f0`." + + f0, t = pw.dio( + x.astype(np.double), + fs=sample_rate, + f0_ceil=pitch_fmax, + frame_period=1000 * hop_length / sample_rate, + ) + f0 = pw.stonemask(x.astype(np.double), f0, t, sample_rate) + return f0 + + +### Audio Processing ### +def find_endpoint( + *, + wav: np.ndarray = None, + trim_db: float = -40, + sample_rate: int = None, + min_silence_sec=0.8, + gain: float = None, + base: int = None, + **kwargs, +) -> int: + """Find the last point without silence at the end of an audio signal. + + Args: + wav (np.ndarray): Audio signal. + trim_db (float, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter than this in secs. Defaults to 0.8. + gain (float, optional): Gain to be used to convert trim_db to trim_amp. Defaults to None. + base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10. + + Returns: + int: Last point without silence. + """ + window_length = int(sample_rate * min_silence_sec) + hop_length = int(window_length / 4) + threshold = db_to_amp(x=-trim_db, gain=gain, base=base) + for x in range(hop_length, len(wav) - window_length, hop_length): + if np.max(wav[x : x + window_length]) < threshold: + return x + hop_length + return len(wav) + + +def trim_silence( + *, + wav: np.ndarray = None, + sample_rate: int = None, + trim_db: float = None, + win_length: int = None, + hop_length: int = None, + **kwargs, +) -> np.ndarray: + """Trim silent parts with a threshold and 0.01 sec margin""" + margin = int(sample_rate * 0.01) + wav = wav[margin:-margin] + return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0] + + +def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray: + """Normalize the volume of an audio signal. + + Args: + x (np.ndarray): Raw waveform. + coef (float): Coefficient to rescale the maximum value. Defaults to 0.95. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return x / abs(x).max() * coef + + +def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray: + r = 10 ** (db_level / 20) + a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2)) + return wav * a + + +def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray: + """Normalize the volume based on RMS of the signal. + + Args: + x (np.ndarray): Raw waveform. + db_level (float): Target dB level in RMS. Defaults to -27.0. + + Returns: + np.ndarray: RMS normalized waveform. + """ + assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0" + wav = rms_norm(wav=x, db_level=db_level) + return wav
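Editor's note (not part of the diff): the helpers above are keyword-only and tolerate extra keys via `**kwargs`, so a single config dict can be threaded through all of them. A minimal end-to-end sketch; the file path and parameter values are placeholders.

```python
import librosa
from TTS.utils.audio import numpy_transforms as np_transforms

conf = dict(sample_rate=22050, fft_size=1024, num_mels=80, mel_fmin=0, mel_fmax=8000,
            hop_length=256, win_length=1024, pitch_fmax=450)

wav, _ = librosa.load("example.wav", sr=conf["sample_rate"])  # placeholder path
wav = np_transforms.volume_norm(x=wav)

mel_basis = np_transforms.build_mel_basis(**conf)
mel = np_transforms.wav_to_mel(wav=wav, mel_basis=mel_basis, **conf)
f0 = np_transforms.compute_f0(x=wav, **conf)  # same hop, so frames line up with the mel
```

Because every helper absorbs unused keys, `conf` needs no filtering before each call; this is the purpose of the `# pylint: disable=unused-argument` at the top of the module.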
+ + +def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Resampling slows down loading the file significantly. Therefore it is recommended to resample the file beforehand. + + Args: + filename (str): Path to the wav file. + sample_rate (int, optional): Sampling rate used for resampling. Defaults to None. + resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False. + + Returns: + np.ndarray: Loaded waveform. + """ + if resample: + # loading with resampling. It is significantly slower. + x, _ = librosa.load(filename, sr=sample_rate) + else: + # SF is faster than librosa for loading files + x, _ = sf.read(filename) + return x + + +def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, **kwargs) -> None: + """Save float waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform with float values in range [-1, 1] to save. + path (str): Path to an output file. + sample_rate (int, optional): Sampling rate used for saving to the file. Defaults to None. + """ + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + scipy.io.wavfile.write(path, sample_rate, wav_norm.astype(np.int16)) + + +def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray: + """Mu-law compand a [-1, 1] waveform and quantize it to ``2**mulaw_qc`` levels.""" + mu = 2**mulaw_qc - 1 + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) + signal = (signal + 1) / 2 * mu + 0.5 + return np.floor( + signal, + ) + + +def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray: + """Recovers waveform from quantized values.""" + mu = 2**mulaw_qc - 1 + x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) + return x + + +def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray: + """Convert a [-1, 1] float waveform to int16.""" + return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) + + +def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray: + """Quantize a waveform to a given number of bits. + + Args: + x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. + quantize_bits (int): Number of quantization bits. + + Returns: + np.ndarray: Quantized waveform. + """ + return (x + 1.0) * (2**quantize_bits - 1) / 2 + + +def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray: + """Dequantize a waveform from the given number of bits.""" + return 2 * x / (2**quantize_bits - 1) - 1
diff --git a/TTS/utils/audio.py b/TTS/utils/audio/processor.py similarity index 84% rename from TTS/utils/audio.py rename to TTS/utils/audio/processor.py index fc9d194201..5a63b44455 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio/processor.py @@ -6,179 +6,14 @@ import scipy.io.wavfile import scipy.signal import soundfile as sf -import torch -from torch import nn from TTS.tts.utils.helpers import StandardScaler - -class TorchSTFT(nn.Module): # pylint: disable=abstract-method - """Some of the audio processing funtions using Torch for faster batch processing. - - TODO: Merge this with audio.py - - Args: - - n_fft (int): - FFT window size for STFT. - - hop_length (int): - number of frames between STFT columns. - - win_length (int, optional): - STFT window length. - - pad_wav (bool, optional): - If True pad the audio with (n_fft - hop_length) / 2). Defaults to False. - - window (str, optional): - The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window" - - sample_rate (int, optional): - target audio sampling rate. Defaults to None. - - mel_fmin (int, optional): - minimum filter frequency for computing melspectrograms. Defaults to None. - - mel_fmax (int, optional): - maximum filter frequency for computing melspectrograms. Defaults to None. - - n_mels (int, optional): - number of melspectrogram dimensions. Defaults to None.
- - use_mel (bool, optional): - If True compute the melspectrograms otherwise. Defaults to False. - - do_amp_to_db_linear (bool, optional): - enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False. - - spec_gain (float, optional): - gain applied when converting amplitude to DB. Defaults to 1.0. - - power (float, optional): - Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None. - - use_htk (bool, optional): - Use HTK formula in mel filter instead of Slaney. - - mel_norm (None, 'slaney', or number, optional): - If 'slaney', divide the triangular mel weights by the width of the mel band - (area normalization). - - If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm. - See `librosa.util.normalize` for a full description of supported norm values - (including `+-np.inf`). - - Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney". - """ - - def __init__( - self, - n_fft, - hop_length, - win_length, - pad_wav=False, - window="hann_window", - sample_rate=None, - mel_fmin=0, - mel_fmax=None, - n_mels=80, - use_mel=False, - do_amp_to_db=False, - spec_gain=1.0, - power=None, - use_htk=False, - mel_norm="slaney", - ): - super().__init__() - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.pad_wav = pad_wav - self.sample_rate = sample_rate - self.mel_fmin = mel_fmin - self.mel_fmax = mel_fmax - self.n_mels = n_mels - self.use_mel = use_mel - self.do_amp_to_db = do_amp_to_db - self.spec_gain = spec_gain - self.power = power - self.use_htk = use_htk - self.mel_norm = mel_norm - self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) - self.mel_basis = None - if use_mel: - self._build_mel_basis() - - def __call__(self, x): - """Compute spectrogram frames by torch based stft. - - Args: - x (Tensor): input waveform - - Returns: - Tensor: spectrogram frames. - - Shapes: - x: [B x T] or [:math:`[B, 1, T]`] - """ - if x.ndim == 2: - x = x.unsqueeze(1) - if self.pad_wav: - padding = int((self.n_fft - self.hop_length) / 2) - x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") - # B x D x T x 2 - o = torch.stft( - x.squeeze(1), - self.n_fft, - self.hop_length, - self.win_length, - self.window, - center=True, - pad_mode="reflect", # compatible with audio.py - normalized=False, - onesided=True, - return_complex=False, - ) - M = o[:, :, :, 0] - P = o[:, :, :, 1] - S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) - - if self.power is not None: - S = S**self.power - - if self.use_mel: - S = torch.matmul(self.mel_basis.to(x), S) - if self.do_amp_to_db: - S = self._amp_to_db(S, spec_gain=self.spec_gain) - return S - - def _build_mel_basis(self): - mel_basis = librosa.filters.mel( - self.sample_rate, - self.n_fft, - n_mels=self.n_mels, - fmin=self.mel_fmin, - fmax=self.mel_fmax, - htk=self.use_htk, - norm=self.mel_norm, - ) - self.mel_basis = torch.from_numpy(mel_basis).float() - - @staticmethod - def _amp_to_db(x, spec_gain=1.0): - return torch.log(torch.clamp(x, min=1e-5) * spec_gain) - - @staticmethod - def _db_to_amp(x, spec_gain=1.0): - return torch.exp(x) / spec_gain +# pylint: disable=too-many-public-methods -# pylint: disable=too-many-public-methods class AudioProcessor(object): - """Audio Processor for TTS used by all the data pipelines. - - TODO: Make this a dataclass to replace `BaseAudioConfig`. + """Audio Processor for TTS. 
Note: All the class arguments are set to default values to enable a flexible initialization
diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py new file mode 100644 index 0000000000..d4523ad08e --- /dev/null +++ b/TTS/utils/audio/torch_transforms.py @@ -0,0 +1,163 @@ +import librosa +import torch +from torch import nn + + +class TorchSTFT(nn.Module): # pylint: disable=abstract-method + """Some of the audio processing functions using Torch for faster batch processing. + + Args: + + n_fft (int): + FFT window size for STFT. + + hop_length (int): + number of frames between STFT columns. + + win_length (int, optional): + STFT window length. + + pad_wav (bool, optional): + If True, pad the audio with (n_fft - hop_length) / 2. Defaults to False. + + window (str, optional): + The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window". + + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + n_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + use_mel (bool, optional): + If True, compute melspectrograms instead of linear spectrograms. Defaults to False. + + do_amp_to_db (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False. + + spec_gain (float, optional): + gain applied when converting amplitude to DB. Defaults to 1.0. + + power (float, optional): + Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None. + + use_htk (bool, optional): + Use HTK formula in mel filter instead of Slaney. + + mel_norm (None, 'slaney', or number, optional): + If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). + + If numeric, use `librosa.util.normalize` to normalize each filter to unit l_p norm. + See `librosa.util.normalize` for a full description of supported norm values + (including `+-np.inf`). + + Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney". + """ + + def __init__( + self, + n_fft, + hop_length, + win_length, + pad_wav=False, + window="hann_window", + sample_rate=None, + mel_fmin=0, + mel_fmax=None, + n_mels=80, + use_mel=False, + do_amp_to_db=False, + spec_gain=1.0, + power=None, + use_htk=False, + mel_norm="slaney", + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.pad_wav = pad_wav + self.sample_rate = sample_rate + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.n_mels = n_mels + self.use_mel = use_mel + self.do_amp_to_db = do_amp_to_db + self.spec_gain = spec_gain + self.power = power + self.use_htk = use_htk + self.mel_norm = mel_norm + self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) + self.mel_basis = None + if use_mel: + self._build_mel_basis() + + def __call__(self, x): + """Compute spectrogram frames by torch based stft. + + Args: + x (Tensor): input waveform + + Returns: + Tensor: spectrogram frames.
+ + Shapes: + x: :math:`[B, T]` or :math:`[B, 1, T]` + """ + if x.ndim == 2: + x = x.unsqueeze(1) + if self.pad_wav: + padding = int((self.n_fft - self.hop_length) / 2) + x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") + # B x D x T x 2 + o = torch.stft( + x.squeeze(1), + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + pad_mode="reflect", # compatible with audio.py + normalized=False, + onesided=True, + return_complex=False, + ) + M = o[:, :, :, 0] + P = o[:, :, :, 1] + S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) + + if self.power is not None: + S = S**self.power + + if self.use_mel: + S = torch.matmul(self.mel_basis.to(x), S) + if self.do_amp_to_db: + S = self._amp_to_db(S, spec_gain=self.spec_gain) + return S + + def _build_mel_basis(self): + mel_basis = librosa.filters.mel( + self.sample_rate, + self.n_fft, + n_mels=self.n_mels, + fmin=self.mel_fmin, + fmax=self.mel_fmax, + htk=self.use_htk, + norm=self.mel_norm, + ) + self.mel_basis = torch.from_numpy(mel_basis).float() + + @staticmethod + def _amp_to_db(x, spec_gain=1.0): + return torch.log(torch.clamp(x, min=1e-5) * spec_gain) + + @staticmethod + def _db_to_amp(x, spec_gain=1.0): + return torch.exp(x) / spec_gain
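Editor's note: a short sketch of the relocated `TorchSTFT` in its mel mode, using the constructor arguments documented above (illustrative values, not part of the diff).

```python
import torch
from TTS.utils.audio.torch_transforms import TorchSTFT

torch_stft = TorchSTFT(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    sample_rate=22050,
    use_mel=True,       # builds the mel basis at init
    n_mels=80,
    do_amp_to_db=True,  # log-scale the output
)
batch = torch.rand(4, 22050)   # [B, T] one second of audio per item
mels = torch_stft(batch)       # [B, 80, T_frames] log-scaled mel frames
```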
diff --git a/TTS/utils/capacitron_optimizer.py b/TTS/utils/capacitron_optimizer.py index c9f075afac..fac7d8a06d 100644 --- a/TTS/utils/capacitron_optimizer.py +++ b/TTS/utils/capacitron_optimizer.py @@ -34,6 +34,8 @@ def first_step(self): self.primary_optimizer.zero_grad() def step(self): + # Update param groups to display the correct learning rate + self.param_groups = self.primary_optimizer.param_groups self.primary_optimizer.step() def zero_grad(self): diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 281e5af02a..5eed6683e6 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -1,4 +1,3 @@ -import io import json import os import zipfile @@ -7,6 +6,7 @@ from typing import Dict, Tuple import requests +from tqdm import tqdm from TTS.config import load_config from TTS.utils.generic_utils import get_user_data_dir @@ -337,11 +337,20 @@ def _update_path(field_name, new_path, config_path): def _download_zip_file(file_url, output_folder): """Download the github releases""" # download the file - r = requests.get(file_url) + r = requests.get(file_url, stream=True) # extract the file try: - with zipfile.ZipFile(io.BytesIO(r.content)) as z: + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) + with open(temp_zip_name, "wb") as file: + for data in r.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + with zipfile.ZipFile(temp_zip_name) as z: z.extractall(output_folder) + os.remove(temp_zip_name) # delete zip after extract except zipfile.BadZipFile: print(f" > Error: Bad zip file - {file_url}") raise zipfile.BadZipFile # pylint: disable=raise-missing-from diff --git a/TTS/encoder/utils/samplers.py b/TTS/utils/samplers.py similarity index 55% rename from TTS/encoder/utils/samplers.py rename to TTS/utils/samplers.py index 08256b347d..df5d4185eb 100644 --- a/TTS/encoder/utils/samplers.py +++ b/TTS/utils/samplers.py @@ -1,6 +1,8 @@ +import math import random +from typing import Callable, List, Union -from torch.utils.data.sampler import Sampler, SubsetRandomSampler +from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler class SubsetSampler(Sampler): @@ -112,3 +114,89 @@ def __iter__(self): def __len__(self): class_batch_size = self._batch_size // self._num_classes_in_batch return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers) + + +def identity(x): + return x + + +class SortedSampler(Sampler): + """Samples elements sequentially, always in the same order. + + Taken from https://github.com/PetrochukM/PyTorch-NLP + + Args: + data (iterable): Iterable data. + sort_key (callable): Specifies a function of one argument that is used to extract a + numerical comparison key from each list element. + + Example: + >>> list(SortedSampler(range(10), sort_key=lambda i: -i)) + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] + + """ + + def __init__(self, data, sort_key: Callable = identity): + super().__init__(data) + self.data = data + self.sort_key = sort_key + zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)] + zip_ = sorted(zip_, key=lambda r: r[1]) + self.sorted_indexes = [item[0] for item in zip_] + + def __iter__(self): + return iter(self.sorted_indexes) + + def __len__(self): + return len(self.data) + + +class BucketBatchSampler(BatchSampler): + """Bucket batch sampler + + Adapted from https://github.com/PetrochukM/PyTorch-NLP + + Args: + sampler (torch.utils.data.sampler.Sampler): + batch_size (int): Size of mini-batch. + drop_last (bool): If `True` the sampler will drop the last batch if its size would be less + than `batch_size`. + data (list): List of data samples. + sort_key (callable, optional): Callable to specify a comparison key for sorting. + bucket_size_multiplier (int, optional): Buckets are of size + `batch_size * bucket_size_multiplier`. + + Example: + >>> sampler = WeightedRandomSampler(weights, len(weights)) + >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True) + """ + + def __init__( + self, + sampler, + data, + batch_size, + drop_last, + sort_key: Union[Callable, List] = identity, + bucket_size_multiplier=100, + ): + super().__init__(sampler, batch_size, drop_last) + self.data = data + self.sort_key = sort_key + _bucket_size = batch_size * bucket_size_multiplier + if hasattr(sampler, "__len__"): + _bucket_size = min(_bucket_size, len(sampler)) + self.bucket_sampler = BatchSampler(sampler, _bucket_size, False) + + def __iter__(self): + for idxs in self.bucket_sampler: + bucket_data = [self.data[idx] for idx in idxs] + sorted_sampler = SortedSampler(bucket_data, self.sort_key) + for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))): + sorted_idxs = [idxs[i] for i in batch_idx] + yield sorted_idxs + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + return math.ceil(len(self.sampler) / self.batch_size) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 2f31980972..170bb22355 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -307,7 +307,7 @@ def tts( waveform = waveform.squeeze() # trim silence - if self.tts_config.audio["do_trim_silence"] is True: + if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]: waveform = trim_silence(waveform, self.tts_model.ap) wavs += list(waveform) diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py index 05e0fae887..d941eab33e 100644 --- a/TTS/vocoder/datasets/wavegrad_dataset.py +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -149,4 +149,4 @@ def collate_full_clips(batch): mels[idx, :, :
mel.shape[1]] = mel audios[idx, : audio.shape[0]] = audio - return audios, mels + return mels, audios diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index 848e292b83..befc43cca6 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -4,7 +4,7 @@ from torch import nn from torch.nn import functional as F -from TTS.utils.audio import TorchSTFT +from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss ################################# diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index ed5b26dd93..a3803f7714 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -185,8 +185,7 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[ outputs = {"model_outputs": self.y_hat_g} return outputs, loss_dict - @staticmethod - def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]: + def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]: """Logging shared by the training and evaluation. Args: @@ -198,7 +197,7 @@ def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dic Returns: Tuple[Dict, Dict]: log figures and audio samples. """ - y_hat = outputs[0]["model_outputs"] + y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"] y = batch["waveform"] figures = plot_results(y_hat, y, ap, name) sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() diff --git a/TTS/vocoder/models/univnet_discriminator.py b/TTS/vocoder/models/univnet_discriminator.py index d6b0e5d52c..34e2d1c276 100644 --- a/TTS/vocoder/models/univnet_discriminator.py +++ b/TTS/vocoder/models/univnet_discriminator.py @@ -3,7 +3,7 @@ from torch import nn from torch.nn.utils import spectral_norm, weight_norm -from TTS.utils.audio import TorchSTFT +from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator LRELU_SLOPE = 0.1 diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 6686db45dd..e0a25e324b 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -233,6 +233,7 @@ class of models has however remained an elusive problem. 
With a focus on text-to else: raise RuntimeError("Unknown model mode value - ", self.args.mode) + self.ap = AudioProcessor(**config.audio.to_dict()) self.aux_dims = self.args.res_out_dims // 4 if self.args.use_upsample_net: @@ -571,7 +572,7 @@ def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: def test( self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument ) -> Tuple[Dict, Dict]: - ap = assets["audio_processor"] + ap = self.ap figures = {} audios = {} samples = test_loader.dataset.load_test_samples(1) @@ -587,8 +588,16 @@ def test( } ) audios.update({f"test_{idx}/audio": y_hat}) + # audios.update({f"real_{idx}/audio": y_hat}) return figures, audios + def test_log( + self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + figures, audios = outputs + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + @staticmethod def format_batch(batch: Dict) -> Dict: waveform = batch[0] @@ -605,7 +614,7 @@ def get_data_loader( # pylint: disable=no-self-use verbose: bool, num_gpus: int, ): - ap = assets["audio_processor"] + ap = self.ap dataset = WaveRNNDataset( ap=ap, items=samples, diff --git a/notebooks/dataset_analysis/AnalyzeDataset.ipynb b/notebooks/dataset_analysis/AnalyzeDataset.ipynb index 519638472c..7fc51a3adf 100644 --- a/notebooks/dataset_analysis/AnalyzeDataset.ipynb +++ b/notebooks/dataset_analysis/AnalyzeDataset.ipynb @@ -45,7 +45,7 @@ "source": [ "NUM_PROC = 8\n", "DATASET_CONFIG = BaseDatasetConfig(\n", - " name=\"ljspeech\", meta_file_train=\"metadata.csv\", path=\"/home/ubuntu/TTS/depot/data/male_dataset1_44k/\"\n", + " name=\"ljspeech\", meta_file_train=\"metadata.csv\", path=\"/absolute/path/to/your/dataset/\"\n", ")" ] }, @@ -58,13 +58,13 @@ "def formatter(root_path, meta_file, **kwargs): # pylint: disable=unused-argument\n", " txt_file = os.path.join(root_path, meta_file)\n", " items = []\n", - " speaker_name = \"maledataset1\"\n", + " speaker_name = \"myspeaker\"\n", " with open(txt_file, \"r\", encoding=\"utf-8\") as ttf:\n", " for line in ttf:\n", " cols = line.split(\"|\")\n", - " wav_file = os.path.join(root_path, \"wavs\", cols[0])\n", + " wav_file = os.path.join(root_path, \"wavs\", cols[0] + \".wav\") \n", " text = cols[1]\n", - " items.append([text, wav_file, speaker_name])\n", + " items.append({\"text\": text, \"audio_file\": wav_file, \"speaker_name\": speaker_name})\n", " return items" ] }, @@ -78,7 +78,10 @@ "source": [ "# use your own preprocessor at this stage - TTS/datasets/preprocess.py\n", "train_samples, eval_samples = load_tts_samples(DATASET_CONFIG, eval_split=True, formatter=formatter)\n", - "items = train_samples + eval_samples\n", + "if eval_samples is not None:\n", + " items = train_samples + eval_samples\n", + "else:\n", + " items = train_samples\n", "print(\" > Number of audio files: {}\".format(len(items)))\n", "print(items[1])" ] @@ -94,7 +97,7 @@ "# check wavs if exist\n", "wav_files = []\n", "for item in items:\n", - " wav_file = item[1].strip()\n", + " wav_file = item[\"audio_file\"].strip()\n", " wav_files.append(wav_file)\n", " if not os.path.exists(wav_file):\n", " print(wav_file)" @@ -131,8 +134,8 @@ "outputs": [], "source": [ "def load_item(item):\n", - " text = item[0].strip()\n", - " file_name = item[1].strip()\n", + " text = item[\"text\"].strip()\n", + " file_name = item[\"audio_file\"].strip()\n", " audio, sr = librosa.load(file_name, sr=None)\n", " audio_len =
len(audio) / sr\n", " text_len = len(text)\n", @@ -416,7 +419,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.9.12" } }, "nbformat": 4, diff --git a/notebooks/dataset_analysis/PhonemeCoverage.ipynb b/notebooks/dataset_analysis/PhonemeCoverage.ipynb index 2b7f5d670f..d481ed2910 100644 --- a/notebooks/dataset_analysis/PhonemeCoverage.ipynb +++ b/notebooks/dataset_analysis/PhonemeCoverage.ipynb @@ -37,7 +37,7 @@ "# set some vars\n", "# TTS_PATH = \"/home/thorsten/___dev/tts/mozilla/TTS\"\n", "CONFIG_FILE = \"/path/to/config/config.json\"\n", - "CHARS_TO_REMOVE = \".,:!?'\"" + "CHARS_TO_REMOVE = \".,:!?'\"\n" ] }, { @@ -59,7 +59,8 @@ "# extra imports that might not be included in requirements.txt\n", "import collections\n", "import operator\n", - "\n" + "\n", + "%matplotlib inline" ] }, { @@ -75,7 +76,7 @@ "CONFIG = load_config(CONFIG_FILE)\n", "\n", "# Load some properties from config.json\n", - "CONFIG_METADATA = sorted(load_tts_samples(CONFIG.datasets)[0])\n", + "CONFIG_METADATA = load_tts_samples(CONFIG.datasets)[0]\n", "CONFIG_METADATA = CONFIG_METADATA\n", "CONFIG_DATASET = CONFIG.datasets[0]\n", "CONFIG_PHONEME_LANGUAGE = CONFIG.phoneme_language\n", @@ -84,7 +85,10 @@ "\n", "# Will be printed on generated output graph\n", "CONFIG_RUN_NAME = CONFIG.run_name\n", - "CONFIG_RUN_DESC = CONFIG.run_description" + "CONFIG_RUN_DESC = CONFIG.run_description\n", + "\n", + "# Needed to convert text to phonemes and phonemes to ids\n", + "tokenizer, config = TTSTokenizer.init_from_config(CONFIG)" ] }, { @@ -112,12 +116,13 @@ "source": [ "def get_phoneme_from_sequence(text):\n", " temp_list = []\n", - " if len(text[0]) > 0:\n", - " temp_text = text[0].rstrip('\\n')\n", + " if len(text[\"text\"]) > 0:\n", + " #temp_text = text[0].rstrip('\\n')\n", + " temp_text = text[\"text\"].rstrip('\\n')\n", " for rm_bad_chars in CHARS_TO_REMOVE:\n", " temp_text = temp_text.replace(rm_bad_chars,\"\")\n", - " seq = phoneme_to_sequence(temp_text, [CONFIG_TEXT_CLEANER], CONFIG_PHONEME_LANGUAGE, CONFIG_ENABLE_EOS_BOS_CHARS)\n", - " text = sequence_to_phoneme(seq)\n", + " seq = tokenizer.text_to_ids(temp_text)\n", + " text = tokenizer.ids_to_text(seq)\n", " text = text.replace(\" \",\"\")\n", " temp_list.append(text)\n", " return temp_list" @@ -229,7 +234,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -243,7 +248,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.9.12" } }, "nbformat": 4, diff --git a/recipes/blizzard2013/tacotron1-Capacitron/train_capacitron_t1.py b/recipes/blizzard2013/tacotron1-Capacitron/train_capacitron_t1.py index 52c6098fa2..060cb9d4df 100644 --- a/recipes/blizzard2013/tacotron1-Capacitron/train_capacitron_t1.py +++ b/recipes/blizzard2013/tacotron1-Capacitron/train_capacitron_t1.py @@ -48,7 +48,6 @@ precompute_num_workers=24, run_eval=True, test_delay_epochs=5, - ga_alpha=0.0, r=2, optimizer="CapacitronOptimizer", optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}}, @@ -68,16 +67,15 @@ datasets=[dataset_config], lr=1e-3, lr_scheduler="StepwiseGradualLR", - lr_scheduler_params={"gradual_learning_rates": [[0, 1e-3], [2e4, 5e-4], [4e5, 3e-4], [6e4, 1e-4], [8e4, 5e-5]]}, + lr_scheduler_params={"gradual_learning_rates": [[0, 1e-3], [2e4, 5e-4], [4e4, 3e-4], [6e4, 1e-4], [8e4, 5e-5]]}, 
scheduler_after_epoch=False, # scheduler doesn't work without this flag - # Need to experiment with these below for capacitron loss_masking=False, decoder_loss_alpha=1.0, postnet_loss_alpha=1.0, - postnet_diff_spec_alpha=0.0, - decoder_diff_spec_alpha=0.0, - decoder_ssim_alpha=0.0, - postnet_ssim_alpha=0.0, + postnet_diff_spec_alpha=1.0, + decoder_diff_spec_alpha=1.0, + decoder_ssim_alpha=1.0, + postnet_ssim_alpha=1.0, ) ap = AudioProcessor(**config.audio.to_dict()) diff --git a/recipes/blizzard2013/tacotron2-Capacitron/train_capacitron_t2.py b/recipes/blizzard2013/tacotron2-Capacitron/train_capacitron_t2.py index cf27b9dfd1..1bd2a03630 100644 --- a/recipes/blizzard2013/tacotron2-Capacitron/train_capacitron_t2.py +++ b/recipes/blizzard2013/tacotron2-Capacitron/train_capacitron_t2.py @@ -52,7 +52,6 @@ precompute_num_workers=24, run_eval=True, test_delay_epochs=5, - ga_alpha=0.0, r=2, optimizer="CapacitronOptimizer", optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}}, @@ -77,23 +76,20 @@ "gradual_learning_rates": [ [0, 1e-3], [2e4, 5e-4], - [4e5, 3e-4], + [4e4, 3e-4], [6e4, 1e-4], [8e4, 5e-5], ] }, scheduler_after_epoch=False, # scheduler doesn't work without this flag - # dashboard_logger='wandb', - # sort_by_audio_len=True, seq_len_norm=True, - # Need to experiment with these below for capacitron loss_masking=False, decoder_loss_alpha=1.0, postnet_loss_alpha=1.0, - postnet_diff_spec_alpha=0.0, - decoder_diff_spec_alpha=0.0, - decoder_ssim_alpha=0.0, - postnet_ssim_alpha=0.0, + postnet_diff_spec_alpha=1.0, + decoder_diff_spec_alpha=1.0, + decoder_ssim_alpha=1.0, + postnet_ssim_alpha=1.0, ) ap = AudioProcessor(**config.audio.to_dict()) diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index a84658f35f..1c0e470226 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -54,7 +54,6 @@ print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, max_seq_len=500000, output_path=output_path, datasets=[dataset_config], diff --git a/recipes/ljspeech/fast_speech/train_fast_speech.py b/recipes/ljspeech/fast_speech/train_fast_speech.py index 0245dd938b..ab7e884104 100644 --- a/recipes/ljspeech/fast_speech/train_fast_speech.py +++ b/recipes/ljspeech/fast_speech/train_fast_speech.py @@ -53,7 +53,6 @@ print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, max_seq_len=500000, output_path=output_path, datasets=[dataset_config], diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py index 1ab3db1c2e..fd3c8679d9 100644 --- a/recipes/ljspeech/speedy_speech/train_speedy_speech.py +++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py @@ -46,7 +46,6 @@ print_step=50, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, max_seq_len=500000, output_path=output_path, datasets=[dataset_config], diff --git a/recipes/ljspeech/tacotron2-Capacitron/train_capacitron_t2.py b/recipes/ljspeech/tacotron2-Capacitron/train_capacitron_t2.py index 6bb0aed782..a1882451db 100644 --- a/recipes/ljspeech/tacotron2-Capacitron/train_capacitron_t2.py +++ b/recipes/ljspeech/tacotron2-Capacitron/train_capacitron_t2.py @@ -68,7 +68,6 @@ print_step=25, print_eval=True, mixed_precision=False, - sort_by_audio_len=True, seq_len_norm=True, output_path=output_path, datasets=[dataset_config], diff --git 
a/recipes/ljspeech/vits_tts/train_vits.py b/recipes/ljspeech/vits_tts/train_vits.py index c070b3f1ce..94e230a1da 100644 --- a/recipes/ljspeech/vits_tts/train_vits.py +++ b/recipes/ljspeech/vits_tts/train_vits.py @@ -2,11 +2,10 @@ from trainer import Trainer, TrainerArgs -from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.models.vits import Vits +from TTS.tts.models.vits import Vits, VitsAudioConfig from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor @@ -14,21 +13,8 @@ dataset_config = BaseDatasetConfig( name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/") ) -audio_config = BaseAudioConfig( - sample_rate=22050, - win_length=1024, - hop_length=256, - num_mels=80, - preemphasis=0.0, - ref_level_db=20, - log_func="np.log", - do_trim_silence=True, - trim_db=45, - mel_fmin=0, - mel_fmax=None, - spec_gain=1.0, - signal_norm=False, - do_amp_to_db_linear=False, +audio_config = VitsAudioConfig( + sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None ) config = VitsConfig( @@ -37,7 +23,7 @@ batch_size=32, eval_batch_size=16, batch_group_size=5, - num_loader_workers=0, + num_loader_workers=8, num_eval_loader_workers=4, run_eval=True, test_delay_epochs=-1, @@ -52,6 +38,7 @@ mixed_precision=True, output_path=output_path, datasets=[dataset_config], + cudnn_benchmark=False, ) # INITIALIZE THE AUDIO PROCESSOR diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py index 0e650ade8e..0a9cced418 100644 --- a/recipes/multilingual/vits_tts/train_vits_tts.py +++ b/recipes/multilingual/vits_tts/train_vits_tts.py @@ -3,11 +3,10 @@ from trainer import Trainer, TrainerArgs -from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs +from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer @@ -22,22 +21,13 @@ for path in dataset_paths ] -audio_config = BaseAudioConfig( +audio_config = VitsAudioConfig( sample_rate=16000, win_length=1024, hop_length=256, num_mels=80, - preemphasis=0.0, - ref_level_db=20, - log_func="np.log", - do_trim_silence=False, - trim_db=23.0, mel_fmin=0, mel_fmax=None, - spec_gain=1.0, - signal_norm=True, - do_amp_to_db_linear=False, - resample=False, ) vitsArgs = VitsArgs( @@ -69,7 +59,6 @@ use_language_weighted_sampler=True, print_eval=False, mixed_precision=False, - sort_by_audio_len=True, min_audio_len=32 * 256 * 4, max_audio_len=160000, output_path=output_path, diff --git a/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py b/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py index 1a4c8ec86d..8f2413066f 100644 --- a/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py +++ b/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py @@ -60,7 +60,6 @@ "Dieser Kuchen ist großartig. Er ist so lecker und feucht.", "Vor dem 22. 
November 1963.", ], - sort_by_audio_len=True, max_seq_len=500000, output_path=output_path, datasets=[dataset_config], diff --git a/recipes/thorsten_DE/vits_tts/train_vits.py b/recipes/thorsten_DE/vits_tts/train_vits.py index 86a7dfe68a..25c57b64fc 100644 --- a/recipes/thorsten_DE/vits_tts/train_vits.py +++ b/recipes/thorsten_DE/vits_tts/train_vits.py @@ -2,11 +2,10 @@ from trainer import Trainer, TrainerArgs -from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.models.vits import Vits +from TTS.tts.models.vits import Vits, VitsAudioConfig from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.downloaders import download_thorsten_de @@ -21,21 +20,13 @@ print("Downloading dataset") download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0]) -audio_config = BaseAudioConfig( +audio_config = VitsAudioConfig( sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, - preemphasis=0.0, - ref_level_db=20, - log_func="np.log", - do_trim_silence=True, - trim_db=45, mel_fmin=0, mel_fmax=None, - spec_gain=1.0, - signal_norm=False, - do_amp_to_db_linear=False, ) config = VitsConfig( diff --git a/recipes/vctk/download_vctk.sh b/recipes/vctk/download_vctk.sh index c0cea74382..d08a53c61c 100644 --- a/recipes/vctk/download_vctk.sh +++ b/recipes/vctk/download_vctk.sh @@ -2,7 +2,7 @@ # take the scripts's parent's directory to prefix all the output paths. RUN_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" echo $RUN_DIR -# download LJSpeech dataset +# download VCTK dataset wget https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip -O VCTK-Corpus-0.92.zip # extract mkdir VCTK diff --git a/recipes/vctk/vits/train_vits.py b/recipes/vctk/vits/train_vits.py index 88fd7de9a1..814d0989e6 100644 --- a/recipes/vctk/vits/train_vits.py +++ b/recipes/vctk/vits/train_vits.py @@ -2,11 +2,10 @@ from trainer import Trainer, TrainerArgs -from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.models.vits import Vits, VitsArgs +from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor @@ -17,22 +16,8 @@ ) -audio_config = BaseAudioConfig( - sample_rate=22050, - win_length=1024, - hop_length=256, - num_mels=80, - preemphasis=0.0, - ref_level_db=20, - log_func="np.log", - do_trim_silence=True, - trim_db=23.0, - mel_fmin=0, - mel_fmax=None, - spec_gain=1.0, - signal_norm=False, - do_amp_to_db_linear=False, - resample=True, +audio_config = VitsAudioConfig( + sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None ) vitsArgs = VitsArgs( @@ -62,6 +47,7 @@ max_text_len=325, # change this if you have a larger VRAM than 16GB output_path=output_path, datasets=[dataset_config], + cudnn_benchmark=False, ) # INITIALIZE THE AUDIO PROCESSOR diff --git a/requirements.txt b/requirements.txt index b3acfeca4e..a9416112e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,15 @@ # core deps -numpy==1.21.6 +numpy==1.21.6;python_version<"3.10" +numpy==1.22.4;python_version=="3.10" 
cython==0.29.28 scipy>=1.4.0 torch>=1.7 torchaudio soundfile librosa==0.8.0 -numba==0.55.1 -inflect +numba==0.55.1;python_version<"3.10" +numba==0.55.2;python_version=="3.10" +inflect==5.6.0 tqdm anyascii pyyaml diff --git a/run_bash_tests.sh b/run_bash_tests.sh index feb9082bd3..2f5ba88934 100755 --- a/run_bash_tests.sh +++ b/run_bash_tests.sh @@ -4,5 +4,4 @@ TF_CPP_MIN_LOG_LEVEL=3 # runtime bash based tests # TODO: move these to python ./tests/bash_tests/test_demo_server.sh && \ -./tests/bash_tests/test_resample.sh && \ ./tests/bash_tests/test_compute_statistics.sh diff --git a/setup.py b/setup.py index 3c8609499d..f95d79f14c 100644 --- a/setup.py +++ b/setup.py @@ -90,7 +90,7 @@ def pip_install(package_name): # ext_modules=find_cython_extensions(), # package include_package_data=True, - packages=find_packages(include=["TTS*"]), + packages=find_packages(include=["TTS"], exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"]), package_data={ "TTS": [ "VERSION", diff --git a/tests/aux_tests/test_audio_processor.py b/tests/aux_tests/test_audio_processor.py index 566116923b..d01aeffa3f 100644 --- a/tests/aux_tests/test_audio_processor.py +++ b/tests/aux_tests/test_audio_processor.py @@ -3,7 +3,7 @@ from tests import get_tests_input_path, get_tests_output_path, get_tests_path from TTS.config import BaseAudioConfig -from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.processor import AudioProcessor TESTS_PATH = get_tests_path() OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests") diff --git a/tests/aux_tests/test_numpy_transforms.py b/tests/aux_tests/test_numpy_transforms.py new file mode 100644 index 0000000000..0c1836b9b1 --- /dev/null +++ b/tests/aux_tests/test_numpy_transforms.py @@ -0,0 +1,105 @@ +import math +import os +import unittest +from dataclasses import dataclass + +import librosa +import numpy as np +from coqpit import Coqpit + +from tests import get_tests_input_path, get_tests_output_path, get_tests_path +from TTS.utils.audio import numpy_transforms as np_transforms + +TESTS_PATH = get_tests_path() +OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests") +WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") + +os.makedirs(OUT_PATH, exist_ok=True) + + +# pylint: disable=no-self-use + + +class TestNumpyTransforms(unittest.TestCase): + def setUp(self) -> None: + @dataclass + class AudioConfig(Coqpit): + sample_rate: int = 22050 + fft_size: int = 1024 + num_mels: int = 256 + mel_fmax: int = 1800 + mel_fmin: int = 0 + hop_length: int = 256 + win_length: int = 1024 + pitch_fmax: int = 450 + trim_db: int = -1 + min_silence_sec: float = 0.01 + gain: float = 1.0 + base: float = 10.0 + + self.config = AudioConfig() + self.sample_wav, _ = librosa.load(WAV_FILE, sr=self.config.sample_rate) + + def test_build_mel_basis(self): + """Check if the mel basis is correctly built""" + print(" > Testing mel basis building.") + mel_basis = np_transforms.build_mel_basis(**self.config) + self.assertEqual(mel_basis.shape, (self.config.num_mels, self.config.fft_size // 2 + 1)) + + def test_millisec_to_length(self): + """Check if the conversion from milliseconds to length is correct""" + print(" > Testing millisec to length conversion.") + win_len, hop_len = np_transforms.millisec_to_length( + frame_length_ms=1000, frame_shift_ms=12.5, sample_rate=self.config.sample_rate + ) + self.assertEqual(hop_len, int(12.5 / 1000.0 * self.config.sample_rate)) + self.assertEqual(win_len, self.config.sample_rate) + + def test_amplitude_db_conversion(self): + di = 
diff --git a/tests/aux_tests/test_numpy_transforms.py b/tests/aux_tests/test_numpy_transforms.py
new file mode 100644
index 0000000000..0c1836b9b1
--- /dev/null
+++ b/tests/aux_tests/test_numpy_transforms.py
@@ -0,0 +1,105 @@
+import math
+import os
+import unittest
+from dataclasses import dataclass
+
+import librosa
+import numpy as np
+from coqpit import Coqpit
+
+from tests import get_tests_input_path, get_tests_output_path, get_tests_path
+from TTS.utils.audio import numpy_transforms as np_transforms
+
+TESTS_PATH = get_tests_path()
+OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
+WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
+
+os.makedirs(OUT_PATH, exist_ok=True)
+
+
+# pylint: disable=no-self-use
+
+
+class TestNumpyTransforms(unittest.TestCase):
+    def setUp(self) -> None:
+        @dataclass
+        class AudioConfig(Coqpit):
+            sample_rate: int = 22050
+            fft_size: int = 1024
+            num_mels: int = 256
+            mel_fmax: int = 1800
+            mel_fmin: int = 0
+            hop_length: int = 256
+            win_length: int = 1024
+            pitch_fmax: int = 450
+            trim_db: int = -1
+            min_silence_sec: float = 0.01
+            gain: float = 1.0
+            base: float = 10.0
+
+        self.config = AudioConfig()
+        self.sample_wav, _ = librosa.load(WAV_FILE, sr=self.config.sample_rate)
+
+    def test_build_mel_basis(self):
+        """Check if the mel basis is correctly built"""
+        print(" > Testing mel basis building.")
+        mel_basis = np_transforms.build_mel_basis(**self.config)
+        self.assertEqual(mel_basis.shape, (self.config.num_mels, self.config.fft_size // 2 + 1))
+
+    def test_millisec_to_length(self):
+        """Check if the conversion from milliseconds to length is correct"""
+        print(" > Testing millisec to length conversion.")
+        win_len, hop_len = np_transforms.millisec_to_length(
+            frame_length_ms=1000, frame_shift_ms=12.5, sample_rate=self.config.sample_rate
+        )
+        self.assertEqual(hop_len, int(12.5 / 1000.0 * self.config.sample_rate))
+        self.assertEqual(win_len, self.config.sample_rate)
+
+    def test_amplitude_db_conversion(self):
+        di = np.random.rand(11)
+        o1 = np_transforms.amp_to_db(x=di, gain=1.0, base=10)
+        o2 = np_transforms.db_to_amp(x=o1, gain=1.0, base=10)
+        np.testing.assert_almost_equal(di, o2, decimal=5)
+
+    def test_preemphasis_deemphasis(self):
+        di = np.random.rand(11)
+        o1 = np_transforms.preemphasis(x=di, coeff=0.95)
+        o2 = np_transforms.deemphasis(x=o1, coeff=0.95)
+        np.testing.assert_almost_equal(di, o2, decimal=5)
+
+    def test_spec_to_mel(self):
+        mel_basis = np_transforms.build_mel_basis(**self.config)
+        spec = np.random.rand(self.config.fft_size // 2 + 1, 20)  # [C, T]
+        mel = np_transforms.spec_to_mel(spec=spec, mel_basis=mel_basis)
+        self.assertEqual(mel.shape, (self.config.num_mels, 20))
+
+    def test_mel_to_spec(self):
+        mel_basis = np_transforms.build_mel_basis(**self.config)
+        mel = np.random.rand(self.config.num_mels, 20)  # [C, T]
+        spec = np_transforms.mel_to_spec(mel=mel, mel_basis=mel_basis)
+        self.assertEqual(spec.shape, (self.config.fft_size // 2 + 1, 20))
+
+    def test_wav_to_spec(self):
+        spec = np_transforms.wav_to_spec(wav=self.sample_wav, **self.config)
+        self.assertEqual(
+            spec.shape, (self.config.fft_size // 2 + 1, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
+        )
+
+    def test_wav_to_mel(self):
+        mel_basis = np_transforms.build_mel_basis(**self.config)
+        mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
+        self.assertEqual(
+            mel.shape, (self.config.num_mels, math.ceil(self.sample_wav.shape[0] / self.config.hop_length))
+        )
+
+    def test_compute_f0(self):
+        pitch = np_transforms.compute_f0(x=self.sample_wav, **self.config)
+        mel_basis = np_transforms.build_mel_basis(**self.config)
+        mel = np_transforms.wav_to_mel(wav=self.sample_wav, mel_basis=mel_basis, **self.config)
+        assert pitch.shape[0] == mel.shape[1]
+
+    def test_load_wav(self):
+        wav = np_transforms.load_wav(filename=WAV_FILE, resample=False, sample_rate=22050)
+        wav_resample = np_transforms.load_wav(filename=WAV_FILE, resample=True, sample_rate=16000)
+        self.assertEqual(wav.shape, (self.sample_wav.shape[0],))
+        self.assertNotEqual(wav_resample.shape, (self.sample_wav.shape[0],))
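Read together, these tests imply a small functional pipeline. A minimal sketch under the same keyword-only convention (the functions tolerate extra kwargs, which is why the tests can splat one config into every call; the file name and config values below are placeholders):

```python
from TTS.utils.audio import numpy_transforms as np_transforms

config = dict(
    sample_rate=22050, fft_size=1024, num_mels=80,
    mel_fmin=0, mel_fmax=8000, hop_length=256, win_length=1024,
)

# load -> mel basis -> mel spectrogram, mirroring test_load_wav and test_wav_to_mel
wav = np_transforms.load_wav(filename="speech.wav", resample=False, sample_rate=config["sample_rate"])
mel_basis = np_transforms.build_mel_basis(**config)  # [num_mels, fft_size // 2 + 1]
mel = np_transforms.wav_to_mel(wav=wav, mel_basis=mel_basis, **config)  # [num_mels, n_frames]
```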
diff --git a/tests/aux_tests/test_remove_silence_vad_script.py b/tests/aux_tests/test_remove_silence_vad_script.py
deleted file mode 100644
index c934e06551..0000000000
--- a/tests/aux_tests/test_remove_silence_vad_script.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import os
-import unittest
-
-import torch
-
-from tests import get_tests_input_path, get_tests_output_path, run_cli
-
-torch.manual_seed(1)
-
-# pylint: disable=protected-access
-class TestRemoveSilenceVAD(unittest.TestCase):
-    @staticmethod
-    def test():
-        # set paths
-        wav_path = os.path.join(get_tests_input_path(), "../data/ljspeech/wavs")
-        output_path = os.path.join(get_tests_output_path(), "output_wavs_removed_silence/")
-        output_resample_path = os.path.join(get_tests_output_path(), "output_ljspeech_16khz/")
-
-        # resample audios
-        run_cli(
-            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/resample.py --input_dir "{wav_path}" --output_dir "{output_resample_path}" --output_sr 16000'
-        )
-
-        # run test
-        run_cli(
-            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/remove_silence_using_vad.py --input_dir "{output_resample_path}" --output_dir "{output_path}"'
-        )
-        run_cli(f'rm -rf "{output_resample_path}"')
-        run_cli(f'rm -rf "{output_path}"')
diff --git a/tests/bash_tests/test_resample.sh b/tests/bash_tests/test_resample.sh
deleted file mode 100755
index ba87127298..0000000000
--- a/tests/bash_tests/test_resample.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/usr/bin/env bash
-set -xe
-BASEDIR=$(dirname "$0")
-TARGET_SR=16000
-echo "$BASEDIR"
-#run the resample script
-python TTS/bin/resample.py --input_dir $BASEDIR/../data/ljspeech --output_dir $BASEDIR/outputs/resample_tests --output_sr $TARGET_SR
-#check samplerate of output
-OUT_SR=$( (echo "import librosa" ; echo "y, sr = librosa.load('"$BASEDIR"/outputs/resample_tests/wavs/LJ001-0012.wav', sr=None)" ; echo "print(sr)") | python )
-OUT_SR=$(($OUT_SR + 0))
-if [[ $OUT_SR -ne $TARGET_SR ]]; then
-    echo "Missmatch between target and output sample rates"
-    exit 1
-fi
-#cleaning up
-rm -rf $BASEDIR/outputs/resample_tests
diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py
index b85e0ec4b3..730d0d8bd1 100644
--- a/tests/data_tests/test_samplers.py
+++ b/tests/data_tests/test_samplers.py
@@ -5,11 +5,11 @@
 import torch
 
 from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.data import get_length_balancer_weights
 from TTS.tts.utils.languages import get_language_balancer_weights
 from TTS.tts.utils.speakers import get_speaker_balancer_weights
+from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler
 
 # Fixing random state to avoid random fails
 torch.manual_seed(0)
@@ -163,3 +163,31 @@ def test_length_weighted_random_sampler(self):  # pylint: disable=no-self-use
             else:
                 len2 += 1
         assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"
+
+    def test_bucket_batch_sampler(self):
+        bucket_size_multiplier = 2
+        sampler = range(len(train_samples))
+        sampler = BucketBatchSampler(
+            sampler,
+            data=train_samples,
+            batch_size=7,
+            drop_last=True,
+            sort_key=lambda x: len(x["text"]),
+            bucket_size_multiplier=bucket_size_multiplier,
+        )
+
+        # check if the samples are sorted by text length while bucketing
+        min_text_len_in_bucket = 0
+        bucket_items = []
+        for batch_idx, batch in enumerate(list(sampler)):
+            if (batch_idx + 1) % bucket_size_multiplier == 0:
+                for bucket_item in bucket_items:
+                    self.assertLessEqual(min_text_len_in_bucket, len(train_samples[bucket_item]["text"]))
+                    min_text_len_in_bucket = len(train_samples[bucket_item]["text"])
+                min_text_len_in_bucket = 0
+                bucket_items = []
+            else:
+                bucket_items += batch
+
+        # check sampler length
+        self.assertEqual(len(sampler), len(train_samples) // 7)
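The new test encodes the bucketing contract: indices are collected into buckets of `batch_size * bucket_size_multiplier`, each bucket is sorted by `sort_key`, and batches are cut from the sorted bucket, with `drop_last` trimming ragged tails. A simplified stand-in for illustration, not the actual `TTS.utils.samplers.BucketBatchSampler`:

```python
def bucket_batches(indices, data, batch_size, bucket_size_multiplier, sort_key, drop_last=True):
    """Simplified sketch: sort indices inside fixed-size buckets, then cut batches."""
    idxs = list(indices)
    bucket_size = batch_size * bucket_size_multiplier
    for start in range(0, len(idxs), bucket_size):
        bucket = sorted(idxs[start : start + bucket_size], key=lambda i: sort_key(data[i]))
        for j in range(0, len(bucket), batch_size):
            batch = bucket[j : j + batch_size]
            if len(batch) == batch_size or not drop_last:
                yield batch
```

Sorting only inside a bucket keeps each batch roughly length-homogeneous (less padding) while preserving the upstream sampler's shuffling across buckets.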
diff --git a/tests/text_tests/test_tokenizer.py b/tests/text_tests/test_tokenizer.py
index 908952ea20..6e95c0ad81 100644
--- a/tests/text_tests/test_tokenizer.py
+++ b/tests/text_tests/test_tokenizer.py
@@ -30,6 +30,13 @@ def test_text_to_ids_phonemes(self):
         test_hat = self.tokenizer_ph.ids_to_text(ids)
         self.assertEqual(text_ph, test_hat)
 
+    def test_text_to_ids_phonemes_punctuation(self):
+        text = "..."
+        text_ph = self.ph.phonemize(text, separator="")
+        ids = self.tokenizer_ph.text_to_ids(text)
+        test_hat = self.tokenizer_ph.ids_to_text(ids)
+        self.assertEqual(text_ph, test_hat)
+
     def test_text_to_ids_phonemes_with_eos_bos(self):
         text = "Bu bir Örnek."
         self.tokenizer_ph.use_eos_bos = True
diff --git a/tests/tts_tests/test_losses.py b/tests/tts_tests/test_losses.py
new file mode 100644
index 0000000000..522b7bb17c
--- /dev/null
+++ b/tests/tts_tests/test_losses.py
@@ -0,0 +1,239 @@
+import unittest
+
+import torch as T
+
+from TTS.tts.layers.losses import BCELossMasked, L1LossMasked, MSELossMasked, SSIMLoss
+from TTS.tts.utils.helpers import sequence_mask
+
+
+class L1LossMaskedTests(unittest.TestCase):
+    def test_in_out(self):  # pylint: disable=no-self-use
+        # test input == target
+        layer = L1LossMasked(seq_len_norm=False)
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.ones(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 0.0
+
+        # test input != target
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+
+        # test if padded values of input make any difference
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+
+        dummy_input = T.rand(4, 8, 128).float()
+        dummy_target = dummy_input.detach()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 0, "0 vs {}".format(output.item())
+
+        # seq_len_norm = True
+        # test input == target
+        layer = L1LossMasked(seq_len_norm=True)
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.ones(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 0.0
+
+        # test input != target
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+
+        # test if padded values of input make any difference
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
+
+        dummy_input = T.rand(4, 8, 128).float()
+        dummy_target = dummy_input.detach()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 0, "0 vs {}".format(output.item())
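The padding checks above lean on one trick worth making explicit: `sequence_mask` is 1.0 on valid frames and 0.0 on padding, so `(mask - 1.0) * 100.0` is zero on valid frames and -100.0 on padded ones. Adding it corrupts only the padded region, so any loss that honors the length argument must be unchanged:

```python
import torch as T
from TTS.tts.utils.helpers import sequence_mask

lengths = T.arange(5, 9)                     # a batch with lengths 5..8, padded to 8
mask = sequence_mask(lengths).float()        # [4, 8]: 1.0 on valid steps, 0.0 on padding
noise = ((mask - 1.0) * 100.0).unsqueeze(2)  # 0.0 on valid steps, -100.0 on padding
# `x + noise` differs from `x` only at padded positions, so a correctly masked
# loss must return identical values for both inputs.
```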
+
+
+class MSELossMaskedTests(unittest.TestCase):
+    def test_in_out(self):  # pylint: disable=no-self-use
+        # test input == target
+        layer = MSELossMasked(seq_len_norm=False)
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.ones(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 0.0
+
+        # test input != target
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+
+        # test if padded values of input make any difference
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+
+        dummy_input = T.rand(4, 8, 128).float()
+        dummy_target = dummy_input.detach()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 0, "0 vs {}".format(output.item())
+
+        # seq_len_norm = True
+        # test input == target
+        layer = MSELossMasked(seq_len_norm=True)
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.ones(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 0.0
+
+        # test input != target
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+
+        # test if padded values of input make any difference
+        dummy_input = T.ones(4, 8, 128).float()
+        dummy_target = T.zeros(4, 8, 128).float()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
+
+        dummy_input = T.rand(4, 8, 128).float()
+        dummy_target = dummy_input.detach()
+        dummy_length = (T.arange(5, 9)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 0, "0 vs {}".format(output.item())
+
+
+class SSIMLossTests(unittest.TestCase):
+    def test_in_out(self):  # pylint: disable=no-self-use
+        # test input == target
+        layer = SSIMLoss()
+        dummy_input = T.ones(4, 57, 128).float()
+        dummy_target = T.ones(4, 57, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 0.0
+
+        # test input != target
+        dummy_input = T.arange(0, 4 * 57 * 128)
+        dummy_input = dummy_input.reshape(4, 57, 128).float()
+        dummy_target = T.arange(-4 * 57 * 128, 0)
+        dummy_target = dummy_target.reshape(4, 57, 128).float()
+        dummy_target = -dummy_target
+
+        dummy_length = (T.ones(4) * 58).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() >= 1.0, "1.0 vs {}".format(output.item())
+
+        # test if padded values of input make any difference
+        dummy_input = T.ones(4, 57, 128).float()
+        dummy_target = T.zeros(4, 57, 128).float()
+        dummy_length = (T.arange(54, 58)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 0.0
+
+        dummy_input = T.rand(4, 57, 128).float()
+        dummy_target = dummy_input.detach()
+        dummy_length = (T.arange(54, 58)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 0, "0 vs {}".format(output.item())
+
+        # seq_len_norm = True (SSIMLoss takes no such flag; these checks reuse L1LossMasked)
+        # test input == target
+        layer = L1LossMasked(seq_len_norm=True)
+        dummy_input = T.ones(4, 57, 128).float()
+        dummy_target = T.ones(4, 57, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 0.0
+
+        # test input != target
+        dummy_input = T.ones(4, 57, 128).float()
+        dummy_target = T.zeros(4, 57, 128).float()
+        dummy_length = (T.ones(4) * 8).long()
+        output = layer(dummy_input, dummy_target, dummy_length)
+        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
+
+        # test if padded values of input make any difference
+        dummy_input = T.ones(4, 57, 128).float()
+        dummy_target = T.zeros(4, 57, 128).float()
+        dummy_length = (T.arange(54, 58)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
+
+        dummy_input = T.rand(4, 57, 128).float()
+        dummy_target = dummy_input.detach()
+        dummy_length = (T.arange(54, 58)).long()
+        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
+        output = layer(dummy_input + mask, dummy_target, dummy_length)
+        assert output.item() == 0, "0 vs {}".format(output.item())
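`BCELossMasked` is exercised below as a stop-token loss, and the target construction is compact enough to miss: for a sequence of `length` frames padded to `max_len`, `1.0 - sequence_mask(length - 1, max_len)` is 0 while decoding should continue and 1 from the last valid frame onward. A small worked example with the same helper:

```python
import torch as T
from TTS.tts.utils.helpers import sequence_mask

length = T.tensor([5])
target = 1.0 - sequence_mask(length - 1, 7).float()
# -> tensor([[0., 0., 0., 0., 1., 1., 1.]]): "stop" fires at the last valid frame
logits = target * 200 - 100  # saturated logits: -100 (keep decoding) / +100 (stop)
```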
+
+
+class BCELossTest(unittest.TestCase):
+    def test_in_out(self):  # pylint: disable=no-self-use
+        layer = BCELossMasked(pos_weight=5.0)
+
+        length = T.tensor([95])
+        target = (
+            1.0 - sequence_mask(length - 1, 100).float()
+        )  # [0, 0, ... 1, 1] where the first 1 is the last mel frame
+        true_x = target * 200 - 100  # creates logits of [-100, -100, ... 100, 100] corresponding to target
+        zero_x = T.zeros(target.shape) - 100.0  # simulate logits if it never stops decoding
+        early_x = -200.0 * sequence_mask(length - 3, 100).float() + 100.0  # simulate logits on early stopping
+        late_x = -200.0 * sequence_mask(length + 1, 100).float() + 100.0  # simulate logits on late stopping
+
+        loss = layer(true_x, target, length)
+        self.assertEqual(loss.item(), 0.0)
+
+        # early stopping flips two "keep decoding" frames to +100; with saturated
+        # logits each costs ~100, averaged over the 95 valid frames: 200 / 95 ≈ 2.1053
+        loss = layer(early_x, target, length)
+        self.assertAlmostEqual(loss.item(), 2.1053, places=4)
+
+        # late (or never) stopping misses the positive stop frame, which is scaled
+        # by pos_weight=5: 5 * 100 / 95 ≈ 5.2632
+        loss = layer(late_x, target, length)
+        self.assertAlmostEqual(loss.item(), 5.2632, places=4)
+
+        loss = layer(zero_x, target, length)
+        self.assertAlmostEqual(loss.item(), 5.2632, places=4)
+
+        # pos_weight should be < 1 to penalize early stopping
+        layer = BCELossMasked(pos_weight=0.2)
+        loss = layer(true_x, target, length)
+        self.assertEqual(loss.item(), 0.0)
+
+        # with pos_weight < 1, the missed stop frame is discounted, so stopping
+        # early costs more than stopping late
+        loss_early = layer(early_x, target, length)
+        loss_late = layer(late_x, target, length)
+        self.assertGreater(loss_early.item(), loss_late.item())
diff --git a/tests/tts_tests/test_tacotron_layers.py b/tests/tts_tests/test_tacotron_layers.py
index fdce75ddc7..43e72417c2 100644
--- a/tests/tts_tests/test_tacotron_layers.py
+++ b/tests/tts_tests/test_tacotron_layers.py
@@ -2,9 +2,7 @@
 import torch as T
 
-from TTS.tts.layers.losses import L1LossMasked, SSIMLoss
 from TTS.tts.layers.tacotron.tacotron import CBHG, Decoder, Encoder, Prenet
-from TTS.tts.utils.helpers import sequence_mask
 
 # pylint: disable=unused-variable
@@ -85,131 +83,3 @@ def test_in_out(self):  # pylint: disable=no-self-use
         assert output.shape[0] == 4
         assert output.shape[1] == 8
         assert output.shape[2] == 256  # 128 * 2 BiRNN
-
-
-class L1LossMaskedTests(unittest.TestCase):
-    def test_in_out(self):  # pylint: disable=no-self-use
-        # test input == target
-        layer = L1LossMasked(seq_len_norm=False)
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.ones(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 0.0
-
-        # test input != target
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
-
-        # test if padded values of input makes any difference
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
-
-        dummy_input = T.rand(4, 8, 128).float()
-        dummy_target = dummy_input.detach()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 0, "0 vs {}".format(output.item())
-
-        # seq_len_norm = True
-        # test input == target
-        layer = L1LossMasked(seq_len_norm=True)
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.ones(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 0.0
-
-        # test input != target
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
-
-        # test if padded values of input makes any difference
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
-
-        dummy_input = T.rand(4, 8, 128).float()
-        dummy_target = dummy_input.detach()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 0, "0 vs {}".format(output.item())
-
-
-class SSIMLossTests(unittest.TestCase):
-    def test_in_out(self):  # pylint: disable=no-self-use
-        # test input == target
-        layer = SSIMLoss()
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.ones(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 0.0
-
-        # test input != target
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())
-
-        # test if padded values of input makes any difference
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())
-
-        dummy_input = T.rand(4, 8, 128).float()
-        dummy_target = dummy_input.detach()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 0, "0 vs {}".format(output.item())
-
-        # seq_len_norm = True
-        # test input == target
-        layer = L1LossMasked(seq_len_norm=True)
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.ones(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 0.0
-
-        # test input != target
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.ones(4) * 8).long()
-        output = layer(dummy_input, dummy_target, dummy_length)
-        assert output.item() == 1.0, "1.0 vs {}".format(output.item())
-
-        # test if padded values of input makes any difference
-        dummy_input = T.ones(4, 8, 128).float()
-        dummy_target = T.zeros(4, 8, 128).float()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item())
-
-        dummy_input = T.rand(4, 8, 128).float()
-        dummy_target = dummy_input.detach()
-        dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
-        output = layer(dummy_input + mask, dummy_target, dummy_length)
-        assert output.item() == 0, "0 vs {}".format(output.item())
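One plausible reading of the two upsampling tests that follow, assuming the default VITS hop length of 256 used in the recipes above: with `encoder_sample_rate=11025` and 22050 Hz output audio, the model must recover a factor of 2 somewhere. When `interpolate_z=False` the decoder upsampling stack itself provides it, and when `interpolate_z=True` the latent is interpolated first, so the standard rates suffice:

```python
import math

hop_length = 256              # assumed VITS default, as in the recipes above
rate_factor = 22050 // 11025  # encoder runs at half the output rate -> 2

# interpolate_z=False: the decoder stack realizes the extra factor
assert math.prod([8, 8, 4, 2]) == rate_factor * hop_length  # 512 == 2 * 256

# interpolate_z=True: z is interpolated to the output rate first
assert math.prod([8, 8, 2, 2]) == hop_length  # 256
```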
diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index b9cebb5a65..7d474c2092 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -9,7 +9,17 @@
 from TTS.config import load_config
 from TTS.encoder.utils.generic_utils import setup_encoder_model
 from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
+from TTS.tts.models.vits import (
+    Vits,
+    VitsArgs,
+    VitsAudioConfig,
+    amp_to_db,
+    db_to_amp,
+    load_audio,
+    spec_to_mel,
+    wav_to_mel,
+    wav_to_spec,
+)
 from TTS.tts.utils.speakers import SpeakerManager
 
 LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
@@ -421,8 +431,10 @@ def test_train_step(self):
         self._check_parameter_changes(model, model_ref)
 
     def test_train_step_upsampling(self):
+        """Upsampling by the decoder upsampling layers"""
         # setup the model
         with torch.autograd.set_detect_anomaly(True):
+            audio_config = VitsAudioConfig(sample_rate=22050)
             model_args = VitsArgs(
                 num_chars=32,
                 spec_segment_size=10,
@@ -430,7 +442,7 @@
                 encoder_sample_rate=11025,
                 interpolate_z=False,
                 upsample_rates_decoder=[8, 8, 4, 2],
             )
-            config = VitsConfig(model_args=model_args)
+            config = VitsConfig(model_args=model_args, audio=audio_config)
             model = Vits(config).to(device)
             model.train()  # model to train
@@ -459,10 +471,18 @@
         self._check_parameter_changes(model, model_ref)
 
     def test_train_step_upsampling_interpolation(self):
+        """Upsampling by interpolation"""
         # setup the model
         with torch.autograd.set_detect_anomaly(True):
-            model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True)
-            config = VitsConfig(model_args=model_args)
+            audio_config = VitsAudioConfig(sample_rate=22050)
+            model_args = VitsArgs(
+                num_chars=32,
+                spec_segment_size=10,
+                encoder_sample_rate=11025,
+                interpolate_z=True,
+                upsample_rates_decoder=[8, 8, 2, 2],
+            )
+            config = VitsConfig(model_args=model_args, audio=audio_config)
             model = Vits(config).to(device)
             model.train()  # model to train