From e1d99e646a6e1e5d0f3d582ca25cf692ce54d5d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 5 Jun 2023 11:15:13 +0200 Subject: [PATCH] Port Fairseq TTS models (#2628) * Load fairseq models * Add docs and missing files * Managing fairseq models and docs for API * Make style * Use scarf URL * Add tests * Fix URL * Pass cpu * Make lint * Fixup * Make lint * fixup * Fixup * Change tokenization order * Update README * Fixup * Fixup --- README.md | 34 ++++++++-- TTS/api.py | 8 ++- TTS/tts/models/vits.py | 68 +++++++++++++++++++- TTS/tts/utils/fairseq.py | 48 ++++++++++++++ TTS/tts/utils/text/characters.py | 43 +++++++++++-- TTS/tts/utils/text/tokenizer.py | 7 ++- TTS/utils/manage.py | 80 ++++++++++++++++++++---- TTS/utils/synthesizer.py | 23 ++++++- docs/source/inference.md | 35 +++++++++-- tests/inference_tests/test_python_api.py | 6 +- 10 files changed, 314 insertions(+), 38 deletions(-) create mode 100644 TTS/tts/utils/fairseq.py diff --git a/README.md b/README.md index 0fc099acb7..05c846ef39 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,13 @@ ## 🐸Coqui.ai News -- 📣 Coqui Studio API is landed on 🐸TTS. You can use the studio voices in combination with 🐸TTS models. [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api) -- 📣 Voice generation with prompts - **Prompt to Voice** - is live on Coqui.ai!! [Blog Post](https://coqui.ai/blog/tts/prompt-to-voice) -- 📣 Clone your voice with a single click on [🐸Coqui.ai](https://app.coqui.ai/auth/signin) -
+- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS. +- 📣 🐸TTS now supports 🐢Tortoise with faster inference. +- 📣 **Coqui Studio API** is landed on 🐸TTS. - [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api) +- 📣 [**Coqui Sudio API**](https://docs.coqui.ai/docs) is live. +- 📣 Voice generation with prompts - **Prompt to Voice** - is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin)!! - [Blog Post](https://coqui.ai/blog/tts/prompt-to-voice) +- 📣 Voice generation with fusion - **Voice fusion** - is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin). +- 📣 Voice cloning is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin). ## @@ -185,7 +188,9 @@ from TTS.api import TTS model_name = TTS.list_models()[0] # Init TTS tts = TTS(model_name) + # Run TTS + # ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language # Text to speech with a numpy output wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0]) @@ -199,7 +204,8 @@ tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, # Run TTS tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH) -# Example voice cloning with YourTTS in English, French and Portuguese: +# Example voice cloning with YourTTS in English, French and Portuguese + tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True) tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav") tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr-fr", file_path="output.wav") @@ -221,7 +227,9 @@ tts.tts_with_vc_to_file( file_path="ouptut.wav" ) -# Example text to speech using [🐸Coqui Studio](https://coqui.ai) models. You can use all of your available speakers in the studio. +# Example text to speech using [🐸Coqui Studio](https://coqui.ai) models. + +# You can use all of your available speakers in the studio. # [🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account). # You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token. @@ -234,6 +242,20 @@ tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_b tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH) # Run TTS with emotion and speed control tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH, emotion="Happy", speed=1.5) + + +#Example text to speech using **Fairseq models in ~1100 languages** 🤯. + +#For these models use the following name format: `tts_models//fairseq/vits`. +#You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms). + +# TTS with on the fly voice conversion +api = TTS("tts_models/deu/fairseq/vits") +api.tts_with_vc_to_file( + "Wie sage ich auf Italienisch, dass ich dich liebe?", + speaker_wav="target/speaker.wav", + file_path="ouptut.wav" +) ``` ### Command line `tts` diff --git a/TTS/api.py b/TTS/api.py index 86d3692845..81337da429 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -130,7 +130,7 @@ def name_to_speaker(self, name): for speaker in self.speakers: if speaker.name == name: return speaker - raise ValueError(f"Speaker {name} not found.") + raise ValueError(f"Speaker {name} not found in {self.speakers}") def id_to_speaker(self, speaker_id): for speaker in self.speakers: @@ -264,6 +264,10 @@ def __init__( >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav") >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav") + Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html): + >>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True) + >>> tts.tts_to_file("This is a test.", file_path="output.wav") + Args: model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None. model_path (str, optional): Path to the model checkpoint. Defaults to None. @@ -342,7 +346,7 @@ def list_models(): def download_model_by_name(self, model_name: str): model_path, config_path, model_item = self.manager.download_model(model_name) - if isinstance(model_item["github_rls_url"], list): + if "fairseq" in model_name or (model_item is not None and isinstance(model_item["github_rls_url"], list)): # return model directory if there are multiple files # we assume that the model knows how to load itself return None, None, None, None, model_path diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 366252e65b..6bccbce3a0 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -25,11 +25,12 @@ from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations +from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment from TTS.utils.io import load_fsspec @@ -1724,6 +1725,50 @@ def load_checkpoint( self.eval() assert not self.training + def load_fairseq_checkpoint( + self, config, checkpoint_dir, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms + Performs some changes for compatibility. + + Args: + config (Coqpit): 🐸TTS model config. + checkpoint_dir (str): Path to the checkpoint directory. + eval (bool, optional): Set to True for evaluation. Defaults to False. + """ + import json + + from TTS.tts.utils.text.cleaners import basic_cleaners + + self.disc = None + # set paths + config_file = os.path.join(checkpoint_dir, "config.json") + checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth") + vocab_file = os.path.join(checkpoint_dir, "vocab.txt") + # set config params + with open(config_file, "r", encoding="utf-8") as file: + # Load the JSON data as a dictionary + config_org = json.load(file) + self.config.audio.sample_rate = config_org["data"]["sampling_rate"] + # self.config.add_blank = config['add_blank'] + # set tokenizer + vocab = FairseqVocab(vocab_file) + self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels) + self.tokenizer = TTSTokenizer( + use_phonemes=False, + text_cleaner=basic_cleaners, + characters=vocab, + phonemizer=None, + add_blank=config_org["data"]["add_blank"], + use_eos_bos=False, + ) + # load fairseq checkpoint + new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file) + self.load_state_dict(new_chk) + if eval: + self.eval() + assert not self.training + @staticmethod def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): """Initiate model from config @@ -1920,3 +1965,24 @@ def to_config(self) -> "CharactersConfig": is_unique=False, is_sorted=True, ) + + +class FairseqVocab(BaseVocabulary): + def __init__(self, vocab: str): + super(FairseqVocab).__init__() + self.vocab = vocab + + @property + def vocab(self): + """Return the vocabulary dictionary.""" + return self._vocab + + @vocab.setter + def vocab(self, vocab_file): + with open(vocab_file, encoding="utf-8") as f: + self._vocab = [x.replace("\n", "") for x in f.readlines()] + self.blank = self._vocab[0] + print(self._vocab) + self.pad = " " + self._char_to_id = {s: i for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension + self._id_to_char = {i: s for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension diff --git a/TTS/tts/utils/fairseq.py b/TTS/tts/utils/fairseq.py new file mode 100644 index 0000000000..3d8eec2b4e --- /dev/null +++ b/TTS/tts/utils/fairseq.py @@ -0,0 +1,48 @@ +import torch + + +def rehash_fairseq_vits_checkpoint(checkpoint_file): + chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"] + new_chk = {} + for k, v in chk.items(): + if "enc_p." in k: + new_chk[k.replace("enc_p.", "text_encoder.")] = v + elif "dec." in k: + new_chk[k.replace("dec.", "waveform_decoder.")] = v + elif "enc_q." in k: + new_chk[k.replace("enc_q.", "posterior_encoder.")] = v + elif "flow.flows.2." in k: + new_chk[k.replace("flow.flows.2.", "flow.flows.1.")] = v + elif "flow.flows.4." in k: + new_chk[k.replace("flow.flows.4.", "flow.flows.2.")] = v + elif "flow.flows.6." in k: + new_chk[k.replace("flow.flows.6.", "flow.flows.3.")] = v + elif "dp.flows.0.m" in k: + new_chk[k.replace("dp.flows.0.m", "duration_predictor.flows.0.translation")] = v + elif "dp.flows.0.logs" in k: + new_chk[k.replace("dp.flows.0.logs", "duration_predictor.flows.0.log_scale")] = v + elif "dp.flows.1" in k: + new_chk[k.replace("dp.flows.1", "duration_predictor.flows.1")] = v + elif "dp.flows.3" in k: + new_chk[k.replace("dp.flows.3", "duration_predictor.flows.2")] = v + elif "dp.flows.5" in k: + new_chk[k.replace("dp.flows.5", "duration_predictor.flows.3")] = v + elif "dp.flows.7" in k: + new_chk[k.replace("dp.flows.7", "duration_predictor.flows.4")] = v + elif "dp.post_flows.0.m" in k: + new_chk[k.replace("dp.post_flows.0.m", "duration_predictor.post_flows.0.translation")] = v + elif "dp.post_flows.0.logs" in k: + new_chk[k.replace("dp.post_flows.0.logs", "duration_predictor.post_flows.0.log_scale")] = v + elif "dp.post_flows.1" in k: + new_chk[k.replace("dp.post_flows.1", "duration_predictor.post_flows.1")] = v + elif "dp.post_flows.3" in k: + new_chk[k.replace("dp.post_flows.3", "duration_predictor.post_flows.2")] = v + elif "dp.post_flows.5" in k: + new_chk[k.replace("dp.post_flows.5", "duration_predictor.post_flows.3")] = v + elif "dp.post_flows.7" in k: + new_chk[k.replace("dp.post_flows.7", "duration_predictor.post_flows.4")] = v + elif "dp." in k: + new_chk[k.replace("dp.", "duration_predictor.")] = v + else: + new_chk[k] = v + return new_chk diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index 1b375e4fca..8fa45ed84b 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -63,6 +63,18 @@ def blank_id(self) -> int: the vocabulary.""" return self.char_to_id(self.blank) if self.blank else len(self.vocab) + @property + def bos_id(self) -> int: + """Return the index of the bos character. If the bos character is not specified, return the length of the + vocabulary.""" + return self.char_to_id(self.bos) if self.bos else len(self.vocab) + + @property + def eos_id(self) -> int: + """Return the index of the eos character. If the eos character is not specified, return the length of the + vocabulary.""" + return self.char_to_id(self.eos) if self.eos else len(self.vocab) + @property def vocab(self): """Return the vocabulary dictionary.""" @@ -71,11 +83,13 @@ def vocab(self): @vocab.setter def vocab(self, vocab): """Set the vocabulary dictionary and character mapping dictionaries.""" - self._vocab = vocab - self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)} - self._id_to_char = { - idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension - } + self._vocab, self._char_to_id, self._id_to_char = None, None, None + if vocab is not None: + self._vocab = vocab + self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)} + self._id_to_char = { + idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension + } @staticmethod def init_from_config(config, **kwargs): @@ -93,6 +107,17 @@ def init_from_config(config, **kwargs): ) return BaseVocabulary(**kwargs), config + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + vocab_dict=self._vocab, + pad=self.pad, + eos=self.eos, + bos=self.bos, + blank=self.blank, + is_unique=False, + is_sorted=False, + ) + @property def num_chars(self): """Return number of tokens in the vocabulary.""" @@ -174,6 +199,14 @@ def pad_id(self) -> int: def blank_id(self) -> int: return self.char_to_id(self.blank) if self.blank else len(self.vocab) + @property + def eos_id(self) -> int: + return self.char_to_id(self.eos) if self.eos else len(self.vocab) + + @property + def bos_id(self) -> int: + return self.char_to_id(self.bos) if self.bos else len(self.vocab) + @property def characters(self): return self._characters diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index 04cbbd329b..b7faf86e8a 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -108,11 +108,12 @@ def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint: text = self.text_cleaner(text) if self.use_phonemes: text = self.phonemizer.phonemize(text, separator="", language=language) + text = self.encode(text) if self.add_blank: text = self.intersperse_blank_char(text, True) if self.use_eos_bos: text = self.pad_with_bos_eos(text) - return self.encode(text) + return text def ids_to_text(self, id_sequence: List[int]) -> str: """Converts a sequence of token IDs to a string of text.""" @@ -120,14 +121,14 @@ def ids_to_text(self, id_sequence: List[int]) -> str: def pad_with_bos_eos(self, char_sequence: List[str]): """Pads a sequence with the special BOS and EOS characters.""" - return [self.characters.bos] + list(char_sequence) + [self.characters.eos] + return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id] def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False): """Intersperses the blank character between characters in a sequence. Use the ```blank``` character if defined else use the ```pad``` character. """ - char_to_use = self.characters.blank if use_blank_char else self.characters.pad + char_to_use = self.characters.blank_id if use_blank_char else self.characters.pad result = [char_to_use] * (len(char_sequence) * 2 + 1) result[1::2] = char_sequence return result diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 0d0b90648e..98e48a2a12 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -1,5 +1,6 @@ import json import os +import tarfile import zipfile from pathlib import Path from shutil import copyfile, rmtree @@ -245,6 +246,30 @@ def print_model_license(model_item: Dict): else: print(" > Model's license - No license information available") + def download_fairseq_model(self, model_name, output_path): + URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/" + _, lang, _, _ = model_name.split("/") + model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz") + self._download_tar_file(model_download_uri, output_path, self.progress_bar) + + def _set_model_item(self, model_name): + # fetch model info from the dict + model_type, lang, dataset, model = model_name.split("/") + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + if "fairseq" in model_name: + model_item = { + "model_type": "tts_models", + "license": "CC BY-NC 4.0", + "default_vocoder": None, + "author": "fairseq", + "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.", + } + else: + # get model from models.json + model_item = self.models_dict[model_type][lang][dataset][model] + model_item["model_type"] = model_type + return model_item, model_full_name, model + def download_model(self, model_name): """Download model files given the full model name. Model name is in the format @@ -259,11 +284,7 @@ def download_model(self, model_name): Args: model_name (str): model name as explained above. """ - # fetch model info from the dict - model_type, lang, dataset, model = model_name.split("/") - model_full_name = f"{model_type}--{lang}--{dataset}--{model}" - model_item = self.models_dict[model_type][lang][dataset][model] - model_item["model_type"] = model_type + model_item, model_full_name, model = self._set_model_item(model_name) # set the model specific output path output_path = os.path.join(self.output_prefix, model_full_name) if os.path.exists(output_path): @@ -271,16 +292,20 @@ def download_model(self, model_name): else: os.makedirs(output_path, exist_ok=True) print(f" > Downloading model to {output_path}") - # download from github release - if isinstance(model_item["github_rls_url"], list): - self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) + # download from fairseq + if "fairseq" in model_name: + self.download_fairseq_model(model_name, output_path) else: - self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) - self.print_model_license(model_item=model_item) + # download from github release + if isinstance(model_item["github_rls_url"], list): + self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) + self.print_model_license(model_item=model_item) # find downloaded files output_model_path = output_path output_config_path = None - if model != "tortoise-v2": + if model != "tortoise-v2" and "fairseq" not in model_name: output_model_path, output_config_path = self._find_files(output_path) # update paths in the config.json self._update_paths(output_path, output_config_path) @@ -421,6 +446,39 @@ def _download_zip_file(file_url, output_folder, progress_bar): # remove the extracted folder rmtree(os.path.join(output_folder, z.namelist()[0])) + @staticmethod + def _download_tar_file(file_url, output_folder, progress_bar): + """Download the github releases""" + # download the file + r = requests.get(file_url, stream=True) + # extract the file + try: + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + if progress_bar: + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1]) + with open(temp_tar_name, "wb") as file: + for data in r.iter_content(block_size): + if progress_bar: + progress_bar.update(len(data)) + file.write(data) + with tarfile.open(temp_tar_name) as t: + t.extractall(output_folder) + tar_names = t.getnames() + os.remove(temp_tar_name) # delete tar after extract + except tarfile.ReadError: + print(f" > Error: Bad tar file - {file_url}") + raise tarfile.ReadError # pylint: disable=raise-missing-from + # move the files to the outer path + for file_path in os.listdir(os.path.join(output_folder, tar_names[0])): + src_path = os.path.join(output_folder, tar_names[0], file_path) + dst_path = os.path.join(output_folder, os.path.basename(file_path)) + if src_path != dst_path: + copyfile(src_path, dst_path) + # remove the extracted folder + rmtree(os.path.join(output_folder, tar_names[0])) + @staticmethod def _download_model_files(file_urls, output_folder, progress_bar): """Download the github releases""" diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index bdecc82e91..f1dce70f73 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -7,7 +7,9 @@ import torch from TTS.config import load_config +from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.models import setup_model as setup_tts_model +from TTS.tts.models.vits import Vits # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import @@ -98,8 +100,12 @@ def __init__( self.output_sample_rate = self.vc_config.audio["output_sample_rate"] if model_dir: - self._load_tts_from_dir(model_dir, use_cuda) - self.output_sample_rate = self.tts_config.audio["output_sample_rate"] + if "fairseq" in model_dir: + self._load_fairseq_from_dir(model_dir, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] + else: + self._load_tts_from_dir(model_dir, use_cuda) + self.output_sample_rate = self.tts_config.audio["output_sample_rate"] @staticmethod def _get_segmenter(lang: str): @@ -133,12 +139,23 @@ def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> N if use_cuda: self.vc_model.cuda() + def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None: + """Load the fairseq model from a directory. + + We assume it is VITS and the model knows how to load itself from the directory and there is a config.json file in the directory. + """ + self.tts_config = VitsConfig() + self.tts_model = Vits.init_from_config(self.tts_config) + self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True) + self.tts_config = self.tts_model.config + if use_cuda: + self.tts_model.cuda() + def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: """Load the TTS model from a directory. We assume the model knows how to load itself from the directory and there is a config.json file in the directory. """ - config = load_config(os.path.join(model_dir, "config.json")) self.tts_config = config self.tts_model = setup_tts_model(config) diff --git a/docs/source/inference.md b/docs/source/inference.md index 4abdd3271c..3dd9232e59 100644 --- a/docs/source/inference.md +++ b/docs/source/inference.md @@ -128,7 +128,7 @@ wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav") ``` -Here is an example for a single speaker model. +#### Here is an example for a single speaker model. ```python # Init TTS with the target model name @@ -137,7 +137,7 @@ tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH) ``` -Example voice cloning with YourTTS in English, French and Portuguese: +#### Example voice cloning with YourTTS in English, French and Portuguese: ```python tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True) @@ -146,15 +146,16 @@ tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wa tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav") ``` -Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav` +#### Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav` ```python tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True) tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav") ``` -Example voice cloning by a single speaker TTS model combining with the voice conversion model. This way, you can -clone voices by using any model in 🐸TTS. +#### Example voice cloning by a single speaker TTS model combining with the voice conversion model. + +This way, you can clone voices by using any model in 🐸TTS. ```python tts = TTS("tts_models/de/thorsten/tacotron2-DDC") @@ -163,8 +164,11 @@ tts.tts_with_vc_to_file( speaker_wav="target/speaker.wav", file_path="ouptut.wav" ) +``` -Example text to speech using [🐸Coqui Studio](https://coqui.ai) models. You can use all of your available speakers in the studio. +#### Example text to speech using [🐸Coqui Studio](https://coqui.ai) models. + +You can use all of your available speakers in the studio. [🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account). You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token. @@ -193,4 +197,23 @@ api.emotions api.list_speakers() api.list_voices() wav, sample_rate = api.tts(text="This is a test.", speaker=api.speakers[0].name, emotion="Happy", speed=1.5) +``` + +#### Example text to speech using **Fairseq models in ~1100 languages** 🤯. +For these models use the following name format: `tts_models//fairseq/vits`. + +You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms). + +```python +from TTS.api import TTS +api = TTS(model_name="tts_models/eng/fairseq/vits", gpu=True) +api.tts_to_file("This is a test.", file_path="output.wav") + +# TTS with on the fly voice conversion +api = TTS("tts_models/deu/fairseq/vits") +api.tts_with_vc_to_file( + "Wie sage ich auf Italienisch, dass ich dich liebe?", + speaker_wav="target/speaker.wav", + file_path="ouptut.wav" +) ``` \ No newline at end of file diff --git a/tests/inference_tests/test_python_api.py b/tests/inference_tests/test_python_api.py index f8ee4505d4..2025fcd9c6 100644 --- a/tests/inference_tests/test_python_api.py +++ b/tests/inference_tests/test_python_api.py @@ -60,7 +60,7 @@ def test_single_speaker_model(self): self.assertIsNone(tts.languages) def test_studio_model(self): - tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio") + tts = TTS(model_name="coqui_studio/en/Zacharie Aimilios/coqui_studio") tts.tts_to_file(text="This is a test.") # check speed > 2.0 raises error @@ -83,6 +83,10 @@ def test_studio_model(self): wav = tts.tts(text="This is a test.", speed=2.0, emotion="Sad") self.assertGreater(len(wav), 0) + def test_fairseq_model(self): # pylint: disable=no-self-use + tts = TTS(model_name="tts_models/eng/fairseq/vits") + tts.tts_to_file(text="This is a test.") + def test_multi_speaker_multi_lingual_model(self): tts = TTS() tts.load_tts_model_by_name(tts.models[0]) # YourTTS