Support different tts model types. (#1541)
csukuangfj committed Mar 12, 2024
1 parent 959906e commit 81f518e
Showing 12 changed files with 265 additions and 49 deletions.
17 changes: 13 additions & 4 deletions docs/source/recipes/TTS/ljspeech/vits.rst
@@ -56,14 +56,20 @@ Training
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--tokens data/tokens.txt \
--model-type high \
--max-duration 500
.. note::

You can adjust the hyper-parameters to control the size of the VITS model and
the training configurations. For more details, please run ``./vits/train.py --help``.

.. warning::

If you want a model that runs faster on CPU, please use ``--model-type low``
or ``--model-type medium``.

.. note::

The training can take a long time (usually a couple of days).
@@ -95,8 +101,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
Export models
-------------

Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
``vits-epoch-*.onnx``.

.. code-block:: bash
@@ -120,4 +126,7 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:

- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
- ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_
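
A minimal download sketch (assuming ``git`` and ``git-lfs`` are installed; the
repository names come from the links above):

.. code-block:: bash

  # Hypothetical example: fetch the medium-quality pretrained model.
  git lfs install
  git clone https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12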

73 changes: 69 additions & 4 deletions egs/ljspeech/TTS/README.md
@@ -1,10 +1,10 @@
# Introduction

This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
A transcription is provided for each clip.
Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours.

The texts were published between 1884 and 1964, and are in the public domain.
The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain.

The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/).
@@ -35,4 +35,69 @@ To inference, use:
--exp-dir vits/exp \
--epoch 1000 \
--tokens data/tokens.txt
```

## Quality vs speed

If you feel that the trained model is slow at runtime, you can specify the
argument `--model-type` during training. Possible values are:

- `low` means **low** quality. The resulting model is very small in file size
and runs very fast. The following is a wave file generated by a `low` quality model:

https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633

The text is `Ask not what your country can do for you; ask what you can do for your country.`

The exported onnx model has a file size of ``26.8 MB`` (float32).

- `medium` means **medium** quality.
The following is a wave file generated by a `medium` quality model:

https://github.com/k2-fsa/icefall/assets/5284924/b199d960-3665-4d0d-9ae9-a1bb69cbc8ac

The text is `Ask not what your country can do for you; ask what you can do for your country.`

The exported onnx model has a file size of ``70.9 MB`` (float32).

- `high` means **high** quality. This is the default value.

The following is a wave file generated by a `high` quality model:

https://github.com/k2-fsa/icefall/assets/5284924/b39f3048-73a6-4267-bf95-df5abfdb28fc

The text is `Ask not what your country can do for you; ask what you can do for your country.`

The exported onnx model has a file size of ``113 MB`` (float32).
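
You can verify these sizes locally with a quick parameter count. The sketch below mirrors `vits/test_model.py`, which is added in this commit; it assumes `data/tokens.txt` has been prepared and that you run it from `egs/ljspeech/TTS`:

```python
# Count generator parameters for each --model-type
# (mirrors vits/test_model.py from this commit).
from tokenizer import Tokenizer
from train import get_model, get_params

for model_type in ["low", "medium", "high"]:
    params = get_params()
    tokenizer = Tokenizer("./data/tokens.txt")
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_type = model_type

    model = get_model(params)
    num_param = sum(p.numel() for p in model.generator.parameters())
    print(f"{model_type}: {num_param / 1e6:.2f} M generator parameters")
```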


A pre-trained `low` model, trained on 4 V100 32GB GPUs with the following command, can be found at
<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>

```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
./vits/train.py \
--world-size 4 \
--num-epochs 1601 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--model-type low \
--max-duration 800
```

A pre-trained `medium` model, trained on 4 V100 32GB GPUs with the following command, can be found at
<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>
```bash
export CUDA_VISIBLE_DEVICES=4,5,6,7
./vits/train.py \
--world-size 4 \
--num-epochs 1000 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp-medium \
--model-type medium \
--max-duration 500

# (Note: training was stopped after `epoch-820.pt`)
```
2 changes: 1 addition & 1 deletion egs/ljspeech/TTS/prepare.sh
@@ -54,7 +54,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LJSpeech manifest"
# We assume that you have downloaded the LJSpeech corpus
# to $dl_dir/LJSpeech
# to $dl_dir/LJSpeech-1.1
mkdir -p data/manifests
if [ ! -e data/manifests/.ljspeech.done ]; then
lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
40 changes: 20 additions & 20 deletions egs/ljspeech/TTS/vits/export-onnx.py
@@ -25,9 +25,8 @@
--exp-dir vits/exp \
--tokens data/tokens.txt
It will generate two files inside vits/exp:
It will generate one file inside vits/exp:
- vits-epoch-1000.onnx
- vits-epoch-1000.int8.onnx (quantized model)
See ./test_onnx.py for how to use the exported ONNX models.
"""
@@ -40,7 +39,6 @@
import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params

@@ -75,6 +73,16 @@ def get_parser():
help="""Path to vocabulary.""",
)

parser.add_argument(
"--model-type",
type=str,
default="high",
choices=["low", "medium", "high"],
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)

return parser
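
With the new flag, an export invocation would look like the sketch below (paths are illustrative; `--model-type` must match the value used during training so the checkpoint matches the model architecture):

```bash
./vits/export-onnx.py \
  --epoch 1000 \
  --exp-dir vits/exp \
  --tokens data/tokens.txt \
  --model-type medium
```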


@@ -136,7 +144,7 @@ def forward(
Return a tuple containing:
- audio, generated wavform tensor, (B, T_wav)
"""
audio, _, _ = self.model.inference(
audio, _, _ = self.model.generator.inference(
text=tokens,
text_lengths=tokens_lens,
noise_scale=noise_scale,
@@ -198,6 +206,11 @@ def export_model_onnx(
},
)

if model.model.spks is None:
num_speakers = 1
else:
num_speakers = model.model.spks

meta_data = {
"model_type": "vits",
"version": "1",
@@ -206,8 +219,8 @@
"language": "English",
"voice": "en-us", # Choose your language appropriately
"has_espeak": 1,
"n_speakers": 1,
"sample_rate": 22050, # Must match the real sample rate
"n_speakers": num_speakers,
"sample_rate": model.model.sampling_rate, # Must match the real sample rate
}
logging.info(f"meta_data: {meta_data}")

@@ -233,14 +246,13 @@ def main():

load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

model = model.generator
model.to("cpu")
model.eval()

model = OnnxModel(model=model)

num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"generator parameters: {num_param}")
logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")

suffix = f"epoch-{params.epoch}"

@@ -256,18 +268,6 @@ def main():
)
logging.info(f"Exported generator to {model_filename}")

# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

logging.info("Generate int8 quantization models")

model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
quantize_dynamic(
model_input=model_filename,
model_output=model_filename_int8,
weight_type=QuantType.QUInt8,
)


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
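The export script no longer writes an int8 model. If you still want one, you can quantize the exported float32 model yourself; a minimal sketch using the same onnxruntime API as the code removed above (paths are illustrative):

```python
from onnxruntime.quantization import QuantType, quantize_dynamic

# Optional dynamic int8 quantization; mirrors the removed export step.
quantize_dynamic(
    model_input="vits/exp/vits-epoch-1000.onnx",
    model_output="vits/exp/vits-epoch-1000.int8.onnx",
    weight_type=QuantType.QUInt8,
)
```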
2 changes: 1 addition & 1 deletion egs/ljspeech/TTS/vits/generator.py
@@ -189,7 +189,7 @@ def __init__(
self.upsample_factor = int(np.prod(decoder_upsample_scales))
self.spks = None
if spks is not None and spks > 1:
assert global_channels > 0
assert global_channels > 0, global_channels
self.spks = spks
self.global_emb = torch.nn.Embedding(spks, global_channels)
self.spk_embed_dim = None
11 changes: 11 additions & 0 deletions egs/ljspeech/TTS/vits/infer.py
@@ -72,6 +72,16 @@ def get_parser():
help="""Path to vocabulary.""",
)

parser.add_argument(
"--model-type",
type=str,
default="high",
choices=["low", "medium", "high"],
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)

return parser


@@ -94,6 +104,7 @@ def infer_dataset(
tokenizer:
Used to convert text to phonemes.
"""

# Background worker that saves audios to disk.
def _save_worker(
batch_size: int,
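With the new flag, an inference run would look like the sketch below (based on the README example; `--model-type` must match the value used at training time):

```bash
./vits/infer.py \
  --exp-dir vits/exp \
  --epoch 1000 \
  --tokens data/tokens.txt \
  --model-type medium
```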
50 changes: 50 additions & 0 deletions egs/ljspeech/TTS/vits/test_model.py
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from tokenizer import Tokenizer
from train import get_model, get_params
from vits import VITS


def test_model_type(model_type):
tokens = "./data/tokens.txt"

params = get_params()

tokenizer = Tokenizer(tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_type = model_type

model = get_model(params)
generator = model.generator

num_param = sum([p.numel() for p in generator.parameters()])
print(
f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
)


def main():
test_model_type("high") # 35.63 M
test_model_type("low") # 7.55 M
test_model_type("medium") # 23.61 M


if __name__ == "__main__":
main()
27 changes: 23 additions & 4 deletions egs/ljspeech/TTS/vits/test_onnx.py
@@ -54,14 +54,28 @@
help="""Path to vocabulary.""",
)

parser.add_argument(
"--text",
type=str,
default="Ask not what your country can do for you; ask what you can do for your country.",
help="Text to generate speech for",
)

parser.add_argument(
"--output-filename",
type=str,
default="test_onnx.wav",
help="Filename to save the generated wave file.",
)

return parser


class OnnxModel:
def __init__(self, model_filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
session_opts.intra_op_num_threads = 1

self.session_opts = session_opts

@@ -72,6 +86,9 @@ def __init__(self, model_filename: str):
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")

metadata = self.model.get_modelmeta().custom_metadata_map
self.sample_rate = int(metadata["sample_rate"])

def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
@@ -101,22 +118,24 @@ def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:

def main():
args = get_parser().parse_args()
logging.info(vars(args))

tokenizer = Tokenizer(args.tokens)

logging.info("About to create onnx model")
model = OnnxModel(args.model_filename)

text = "I went there to see the land, the people and how their system works, end quote."
text = args.text
tokens = tokenizer.texts_to_token_ids(
[text], intersperse_blank=True, add_sos=True, add_eos=True
)
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
audio = model(tokens, tokens_lens) # (1, T')

torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
logging.info("Saved to test_onnx.wav")
output_filename = args.output_filename
torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
logging.info(f"Saved to {output_filename}")


if __name__ == "__main__":
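Putting the new options together, a usage sketch for this test script (the model path is illustrative; `--model-filename` matches the `args.model_filename` read in `main()`):

```bash
./vits/test_onnx.py \
  --model-filename vits/exp/vits-epoch-1000.onnx \
  --tokens data/tokens.txt \
  --text "Ask not what your country can do for you; ask what you can do for your country." \
  --output-filename test_onnx.wav
```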