diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst
index 323d0adfc8..d31bf63022 100644
--- a/docs/source/recipes/TTS/ljspeech/vits.rst
+++ b/docs/source/recipes/TTS/ljspeech/vits.rst
@@ -56,7 +56,8 @@ Training
       --start-epoch 1 \
       --use-fp16 1 \
       --exp-dir vits/exp \
-      --tokens data/tokens.txt
+      --tokens data/tokens.txt \
+      --model-type high \
       --max-duration 500
 
 .. note::
@@ -64,6 +65,11 @@ Training
    You can adjust the hyper-parameters to control the size of the VITS model and
    the training configurations. For more details, please run ``./vits/train.py --help``.
 
+.. warning::
+
+   If you want a model that runs faster on CPU, please use ``--model-type low``
+   or ``--model-type medium``.
+
 .. note::
 
    The training can take a long time (usually a couple of days).
@@ -95,8 +101,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
 Export models
 -------------
 
-Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
-``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
+Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
+``vits-epoch-*.onnx``.
 
 .. code-block:: bash
 
@@ -120,4 +126,7 @@ Download pretrained models
 If you don't want to train from scratch, you can download the pretrained models
 by visiting the following link:
 
-  - ``_
+  - ``--model-type=high``: ``_
+  - ``--model-type=medium``: ``_
+  - ``--model-type=low``: ``_
+
diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md
index 80be5a3155..7b112c12c8 100644
--- a/egs/ljspeech/TTS/README.md
+++ b/egs/ljspeech/TTS/README.md
@@ -1,10 +1,10 @@
 # Introduction
 
-This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
-A transcription is provided for each clip.
+This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books. 
+A transcription is provided for each clip. 
 Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours.
 
-The texts were published between 1884 and 1964, and are in the public domain.
+The texts were published between 1884 and 1964, and are in the public domain. 
 The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain.
 
 The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/).
@@ -35,4 +35,69 @@ To inference, use:
     --exp-dir vits/exp \
     --epoch 1000 \
     --tokens data/tokens.txt
-```
\ No newline at end of file
+```
+
+## Quality vs speed
+
+If you feel that the trained model is slow at runtime, you can specify the
+argument `--model-type` during training. Possible values are:
+
+  - `low`, means **low** quality. The resulting model is very small in file size
+    and runs very fast. The following is a wave file generated by a `low` quality model
+
+    https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633
+
+    The text is `Ask not what your country can do for you; ask what you can do for your country.`
+
+    The exported onnx model has a file size of ``26.8 MB`` (float32).
+
+  - `medium`, means **medium** quality.
+    The following is a wave file generated by a `medium` quality model
+
+    https://github.com/k2-fsa/icefall/assets/5284924/b199d960-3665-4d0d-9ae9-a1bb69cbc8ac
+
+    The text is `Ask not what your country can do for you; ask what you can do for your country.`
+
+    The exported onnx model has a file size of ``70.9 MB`` (float32).
+
+  - `high`, means **high** quality. This is the default value.
+
+    The following is a wave file generated by a `high` quality model
+
+    https://github.com/k2-fsa/icefall/assets/5284924/b39f3048-73a6-4267-bf95-df5abfdb28fc
+
+    The text is `Ask not what your country can do for you; ask what you can do for your country.`
+
+    The exported onnx model has a file size of ``113 MB`` (float32).
+
+
+A pre-trained `low` model trained using 4xV100 32GB GPUs with the following command can be found at
+
+
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+./vits/train.py \
+  --world-size 4 \
+  --num-epochs 1601 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir vits/exp \
+  --model-type low \
+  --max-duration 800
+```
+
+A pre-trained `medium` model trained using 4xV100 32GB GPUs with the following command can be found at
+
+```bash
+export CUDA_VISIBLE_DEVICES=4,5,6,7
+./vits/train.py \
+  --world-size 4 \
+  --num-epochs 1000 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir vits/exp-medium \
+  --model-type medium \
+  --max-duration 500
+
+# (Note: training was stopped after `epoch-820.pt`)
+```
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
index cbf27bd423..bded423ac9 100755
--- a/egs/ljspeech/TTS/prepare.sh
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -54,7 +54,7 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Prepare LJSpeech manifest"
   # We assume that you have downloaded the LJSpeech corpus
-  # to $dl_dir/LJSpeech
+  # to $dl_dir/LJSpeech-1.1
   mkdir -p data/manifests
   if [ ! -e data/manifests/.ljspeech.done ]; then
     lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
index 58b1663684..0740757c06 100755
--- a/egs/ljspeech/TTS/vits/export-onnx.py
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -25,9 +25,8 @@
   --exp-dir vits/exp \
   --tokens data/tokens.txt
 
-It will generate two files inside vits/exp:
+It will generate one file inside vits/exp:
   - vits-epoch-1000.onnx
-  - vits-epoch-1000.int8.onnx (quantizated model)
 
 See ./test_onnx.py for how to use the exported ONNX models.
 """
@@ -40,7 +39,6 @@
 import onnx
 import torch
 import torch.nn as nn
-from onnxruntime.quantization import QuantType, quantize_dynamic
 from tokenizer import Tokenizer
 from train import get_model, get_params
 
@@ -75,6 +73,16 @@ def get_parser():
         help="""Path to vocabulary.""",
     )
 
+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="high",
+        choices=["low", "medium", "high"],
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+ """, + ) + return parser @@ -136,7 +144,7 @@ def forward( Return a tuple containing: - audio, generated wavform tensor, (B, T_wav) """ - audio, _, _ = self.model.inference( + audio, _, _ = self.model.generator.inference( text=tokens, text_lengths=tokens_lens, noise_scale=noise_scale, @@ -198,6 +206,11 @@ def export_model_onnx( }, ) + if model.model.spks is None: + num_speakers = 1 + else: + num_speakers = model.model.spks + meta_data = { "model_type": "vits", "version": "1", @@ -206,8 +219,8 @@ def export_model_onnx( "language": "English", "voice": "en-us", # Choose your language appropriately "has_espeak": 1, - "n_speakers": 1, - "sample_rate": 22050, # Must match the real sample rate + "n_speakers": num_speakers, + "sample_rate": model.model.sampling_rate, # Must match the real sample rate } logging.info(f"meta_data: {meta_data}") @@ -233,14 +246,13 @@ def main(): load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model = model.generator model.to("cpu") model.eval() model = OnnxModel(model=model) num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"generator parameters: {num_param}") + logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M") suffix = f"epoch-{params.epoch}" @@ -256,18 +268,6 @@ def main(): ) logging.info(f"Exported generator to {model_filename}") - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" - quantize_dynamic( - model_input=model_filename, - model_output=model_filename_int8, - weight_type=QuantType.QUInt8, - ) - if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index 66c8cedb19..b9add9e828 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -189,7 +189,7 @@ def __init__( self.upsample_factor = int(np.prod(decoder_upsample_scales)) self.spks = None if spks is not None and spks > 1: - assert global_channels > 0 + assert global_channels > 0, global_channels self.spks = spks self.global_emb = torch.nn.Embedding(spks, global_channels) self.spk_embed_dim = None diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 9e7c71c6dc..7be76e3151 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -72,6 +72,16 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--model-type", + type=str, + default="high", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -94,6 +104,7 @@ def infer_dataset( tokenizer: Used to convert text to phonemes. """ + # Background worker save audios to disk. def _save_worker( batch_size: int, diff --git a/egs/ljspeech/TTS/vits/test_model.py b/egs/ljspeech/TTS/vits/test_model.py new file mode 100755 index 0000000000..1de10f012b --- /dev/null +++ b/egs/ljspeech/TTS/vits/test_model.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tokenizer import Tokenizer +from train import get_model, get_params +from vits import VITS + + +def test_model_type(model_type): + tokens = "./data/tokens.txt" + + params = get_params() + + tokenizer = Tokenizer(tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_type = model_type + + model = get_model(params) + generator = model.generator + + num_param = sum([p.numel() for p in generator.parameters()]) + print( + f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M" + ) + + +def main(): + test_model_type("high") # 35.63 M + test_model_type("low") # 7.55 M + test_model_type("medium") # 23.61 M + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 4f46e8e6c5..b3805fadb3 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -54,6 +54,20 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--text", + type=str, + default="Ask not what your country can do for you; ask what you can do for your country.", + help="Text to generate speech for", + ) + + parser.add_argument( + "--output-filename", + type=str, + default="test_onnx.wav", + help="Filename to save the generated wave file.", + ) + return parser @@ -61,7 +75,7 @@ class OnnxModel: def __init__(self, model_filename: str): session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 + session_opts.intra_op_num_threads = 1 self.session_opts = session_opts @@ -72,6 +86,9 @@ def __init__(self, model_filename: str): ) logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + metadata = self.model.get_modelmeta().custom_metadata_map + self.sample_rate = int(metadata["sample_rate"]) + def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: """ Args: @@ -101,13 +118,14 @@ def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Ten def main(): args = get_parser().parse_args() + logging.info(vars(args)) tokenizer = Tokenizer(args.tokens) logging.info("About to create onnx model") model = OnnxModel(args.model_filename) - text = "I went there to see the land, the people and how their system works, end quote." 
+ text = args.text tokens = tokenizer.texts_to_token_ids( [text], intersperse_blank=True, add_sos=True, add_eos=True ) @@ -115,8 +133,9 @@ def main(): tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) audio = model(tokens, tokens_lens) # (1, T') - torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) - logging.info("Saved to test_onnx.wav") + output_filename = args.output_filename + torchaudio.save(output_filename, audio, sample_rate=model.sample_rate) + logging.info(f"Saved to {output_filename}") if __name__ == "__main__": diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py index fcbae7103f..9b21ed9cb0 100644 --- a/egs/ljspeech/TTS/vits/text_encoder.py +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -92,9 +92,9 @@ def forward( x_lengths (Tensor): Length tensor (B,). Returns: - Tensor: Encoded hidden representation (B, attention_dim, T_text). - Tensor: Projected mean tensor (B, attention_dim, T_text). - Tensor: Projected scale tensor (B, attention_dim, T_text). + Tensor: Encoded hidden representation (B, embed_dim, T_text). + Tensor: Projected mean tensor (B, embed_dim, T_text). + Tensor: Projected scale tensor (B, embed_dim, T_text). Tensor: Mask tensor for input tensor (B, 1, T_text). """ @@ -108,6 +108,7 @@ def forward( # encoder assume the channel last (B, T_text, embed_dim) x = self.encoder(x, key_padding_mask=pad_mask) + # Note: attention_dim == embed_dim # convert the channel first (B, embed_dim, T_text) x = x.transpose(1, 2) diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 9a5a9090ec..8144ffe1eb 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -18,7 +18,15 @@ from typing import Dict, List import tacotron_cleaner.cleaners -from piper_phonemize import phonemize_espeak + +try: + from piper_phonemize import phonemize_espeak +except Exception as ex: + raise RuntimeError( + f"{ex}\nPlease follow instructions in " + "../prepare.sh to install piper-phonemize" + ) + from utils import intersperse diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 6589b75ff6..34b943765a 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -153,6 +153,16 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--model-type", + type=str, + default="high", + choices=["low", "medium", "high"], + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -189,15 +199,6 @@ def get_params() -> AttributeDict: - feature_dim: The model input dim. It has to match the one used in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. 
""" params = AttributeDict( { @@ -278,6 +279,7 @@ def get_model(params: AttributeDict) -> nn.Module: vocab_size=params.vocab_size, feature_dim=params.feature_dim, sampling_rate=params.sampling_rate, + model_type=params.model_type, mel_loss_params=mel_loss_params, lambda_adv=params.lambda_adv, lambda_mel=params.lambda_mel, @@ -363,7 +365,7 @@ def train_one_epoch( model.train() device = model.device if isinstance(model, DDP) else next(model.parameters()).device - # used to summary the stats over iterations in one epoch + # used to track the stats over iterations in one epoch tot_loss = MetricsTracker() saved_bad_model = False diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index b4f0c21e6d..0b9575cbde 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -5,6 +5,7 @@ """VITS module for GAN-TTS task.""" +import copy from typing import Any, Dict, Optional, Tuple import torch @@ -38,6 +39,36 @@ "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA } +LOW_CONFIG = { + "hidden_channels": 96, + "decoder_upsample_scales": (8, 8, 4), + "decoder_channels": 256, + "decoder_upsample_kernel_sizes": (16, 16, 8), + "decoder_resblock_kernel_sizes": (3, 5, 7), + "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)), + "text_encoder_cnn_module_kernel": 3, +} + +MEDIUM_CONFIG = { + "hidden_channels": 192, + "decoder_upsample_scales": (8, 8, 4), + "decoder_channels": 256, + "decoder_upsample_kernel_sizes": (16, 16, 8), + "decoder_resblock_kernel_sizes": (3, 5, 7), + "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)), + "text_encoder_cnn_module_kernel": 3, +} + +HIGH_CONFIG = { + "hidden_channels": 192, + "decoder_upsample_scales": (8, 8, 2, 2), + "decoder_channels": 512, + "decoder_upsample_kernel_sizes": (16, 16, 4, 4), + "decoder_resblock_kernel_sizes": (3, 7, 11), + "decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + "text_encoder_cnn_module_kernel": 5, +} + class VITS(nn.Module): """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`""" @@ -49,6 +80,7 @@ def __init__( feature_dim: int = 513, sampling_rate: int = 22050, generator_type: str = "vits_generator", + model_type: str = "", generator_params: Dict[str, Any] = { "hidden_channels": 192, "spks": None, @@ -155,12 +187,13 @@ def __init__( """Initialize VITS module. Args: - idim (int): Input vocabrary size. + idim (int): Input vocabulary size. odim (int): Acoustic feature dimension. The actual output channels will be 1 since VITS is the end-to-end text-to-wave model but for the compatibility odim is used to indicate the acoustic feature dimension. sampling_rate (int): Sampling rate, not used for the training but it will be referred in saving waveform during the inference. + model_type (str): If not empty, must be one of: low, medium, high generator_type (str): Generator type. generator_params (Dict[str, Any]): Parameter dict for generator. discriminator_type (str): Discriminator type. 
@@ -181,6 +214,24 @@ def __init__(
         """
         super().__init__()
 
+        generator_params = copy.deepcopy(generator_params)
+        discriminator_params = copy.deepcopy(discriminator_params)
+        generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params)
+        discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params)
+        feat_match_loss_params = copy.deepcopy(feat_match_loss_params)
+        mel_loss_params = copy.deepcopy(mel_loss_params)
+
+        if model_type != "":
+            assert model_type in ("low", "medium", "high"), model_type
+            if model_type == "low":
+                generator_params.update(LOW_CONFIG)
+            elif model_type == "medium":
+                generator_params.update(MEDIUM_CONFIG)
+            elif model_type == "high":
+                generator_params.update(HIGH_CONFIG)
+            else:
+                raise ValueError(f"Unknown model_type: {model_type}")
+
         # define modules
         generator_class = AVAILABLE_GENERATERS[generator_type]
         if generator_type == "vits_generator":
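
The last hunk above wires `--model-type` into `VITS.__init__`: the flag selects one of `LOW_CONFIG`, `MEDIUM_CONFIG`, or `HIGH_CONFIG` and overrides the default generator parameters before the generator is built. The stand-alone sketch below illustrates only that selection step; the config dictionaries are abridged to two keys each, and the helper name `apply_model_type` is hypothetical rather than part of the diff.

```python
# Minimal sketch of the --model-type dispatch added in vits.py (not part of the diff).
# LOW/MEDIUM/HIGH dicts are abridged; apply_model_type() is a hypothetical helper.
import copy
from typing import Any, Dict

LOW_CONFIG = {"hidden_channels": 96, "decoder_channels": 256}      # abridged
MEDIUM_CONFIG = {"hidden_channels": 192, "decoder_channels": 256}  # abridged
HIGH_CONFIG = {"hidden_channels": 192, "decoder_channels": 512}    # abridged


def apply_model_type(generator_params: Dict[str, Any], model_type: str) -> Dict[str, Any]:
    """Return a copy of generator_params with the size overrides for model_type applied."""
    params = copy.deepcopy(generator_params)  # never mutate the caller's dict
    if model_type == "":
        return params  # keep the caller-supplied defaults
    configs = {"low": LOW_CONFIG, "medium": MEDIUM_CONFIG, "high": HIGH_CONFIG}
    if model_type not in configs:
        raise ValueError(f"Unknown model_type: {model_type}")
    params.update(configs[model_type])
    return params


if __name__ == "__main__":
    defaults = {"hidden_channels": 192, "spks": None, "decoder_channels": 512}
    print(apply_model_type(defaults, "low"))
    # -> {'hidden_channels': 96, 'spks': None, 'decoder_channels': 256}
```

The real dictionaries also change the decoder upsampling scales, resblock kernels, and text-encoder kernel size, which is why the generator parameter counts reported by `test_model.py` above drop from 35.63 M (`high`) to 23.61 M (`medium`) and 7.55 M (`low`).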