Port Fairseq TTS models (#2628)
* Load fairseq models

* Add docs and missing files

* Managing fairseq models and docs for API

* Make style

* Use scarf URL

* Add tests

* Fix URL

* Pass cpu

* Make lint

* Fixup

* Make lint

* fixup

* Fixup

* Change tokenization order

* Update README

* Fixup

* Fixup
erogol authored Jun 5, 2023
1 parent 0d5e68a commit e785d10
Showing 10 changed files with 314 additions and 38 deletions.
34 changes: 28 additions & 6 deletions README.md
@@ -1,10 +1,13 @@


## 🐸Coqui.ai News
- 📣 The Coqui Studio API has landed in 🐸TTS. You can use the studio voices in combination with 🐸TTS models. [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api)
- 📣 Voice generation with prompts - **Prompt to Voice** - is live on Coqui.ai!! [Blog Post](https://coqui.ai/blog/tts/prompt-to-voice)
- 📣 Clone your voice with a single click on [🐸Coqui.ai](https://app.coqui.ai/auth/signin)
<br>
- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
- 📣 🐸TTS now supports 🐢Tortoise with faster inference.
- 📣 The **Coqui Studio API** has landed in 🐸TTS. - [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api)
- 📣 [**Coqui Studio API**](https://docs.coqui.ai/docs) is live.
- 📣 Voice generation with prompts - **Prompt to Voice** - is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin)!! - [Blog Post](https://coqui.ai/blog/tts/prompt-to-voice)
- 📣 Voice generation with fusion - **Voice fusion** - is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin).
- 📣 Voice cloning is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin).

## <img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/coqui-log-green-TTS.png" height="56"/>

@@ -185,7 +188,9 @@ from TTS.api import TTS
model_name = TTS.list_models()[0]
# Init TTS
tts = TTS(model_name)

# Run TTS

# ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language
# Text to speech with a numpy output
wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
@@ -199,7 +204,8 @@ tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False,
# Run TTS
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)

# Example voice cloning with YourTTS in English, French and Portuguese:
# Example voice cloning with YourTTS in English, French and Portuguese

tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr-fr", file_path="output.wav")
@@ -221,7 +227,9 @@ tts.tts_with_vc_to_file(
file_path="ouptut.wav"
)

# Example text to speech using [🐸Coqui Studio](https://coqui.ai) models. You can use all of your available speakers in the studio.
# Example text to speech using [🐸Coqui Studio](https://coqui.ai) models.

# You can use all of your available speakers in the studio.
# [🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account).
# You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token.

@@ -234,6 +242,20 @@ tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_b
tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH)
# Run TTS with emotion and speed control
tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH, emotion="Happy", speed=1.5)


# Example text to speech using **Fairseq models in ~1100 languages** 🤯.

# For these models, use the following name format: `tts_models/<lang-iso_code>/fairseq/vits`.
# You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms).

# TTS with on-the-fly voice conversion
api = TTS("tts_models/deu/fairseq/vits")
api.tts_with_vc_to_file(
"Wie sage ich auf Italienisch, dass ich dich liebe?",
speaker_wav="target/speaker.wav",
file_path="ouptut.wav"
)
```
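For reference, a minimal sketch of plain Fairseq synthesis through the same Python API, without voice conversion (the `tts_models/deu/fairseq/vits` model name and the German sample text are illustrative):

```python
from TTS.api import TTS

# Plain synthesis with a Fairseq VITS model; no reference speaker audio is needed.
api = TTS("tts_models/deu/fairseq/vits")
api.tts_to_file("Dies ist ein Test.", file_path="output.wav")
```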

### Command line `tts`
8 changes: 6 additions & 2 deletions TTS/api.py
@@ -130,7 +130,7 @@ def name_to_speaker(self, name):
for speaker in self.speakers:
if speaker.name == name:
return speaker
raise ValueError(f"Speaker {name} not found.")
raise ValueError(f"Speaker {name} not found in {self.speakers}")

def id_to_speaker(self, speaker_id):
for speaker in self.speakers:
@@ -264,6 +264,10 @@ def __init__(
>>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
>>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")
Example Fairseq TTS models (using the ISO language codes listed at https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html):
>>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True)
>>> tts.tts_to_file("This is a test.", file_path="output.wav")
Args:
model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
model_path (str, optional): Path to the model checkpoint. Defaults to None.
@@ -342,7 +346,7 @@ def list_models():

def download_model_by_name(self, model_name: str):
model_path, config_path, model_item = self.manager.download_model(model_name)
if isinstance(model_item["github_rls_url"], list):
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["github_rls_url"], list)):
# return model directory if there are multiple files
# we assume that the model knows how to load itself
return None, None, None, None, model_path
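A hedged sketch of the calling side, assuming the rest of the method returns a five-tuple of the form `(model_path, config_path, vocoder_path, vocoder_config_path, model_dir)` as the early return above suggests (the variable names and model name are illustrative):

```python
from TTS.api import TTS

tts = TTS()  # bare wrapper; no model is loaded yet

# For Fairseq (and other multi-file) models only the model directory comes back,
# and the model is expected to load itself from that directory.
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = tts.download_model_by_name(
    "tts_models/deu/fairseq/vits"
)
if model_dir is not None:
    print(f"Multi-file model downloaded to: {model_dir}")
```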
68 changes: 67 additions & 1 deletion TTS/tts/models/vits.py
@@ -25,11 +25,12 @@
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.io import load_fsspec
@@ -1723,6 +1724,50 @@ def load_checkpoint(
self.eval()
assert not self.training

def load_fairseq_checkpoint(
    self, config, checkpoint_dir, eval=False
):  # pylint: disable=unused-argument, redefined-builtin
    """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
    Performs some changes for compatibility.

    Args:
        config (Coqpit): 🐸TTS model config.
        checkpoint_dir (str): Path to the checkpoint directory.
        eval (bool, optional): Set to True for evaluation. Defaults to False.
    """
    import json

    from TTS.tts.utils.text.cleaners import basic_cleaners

    self.disc = None
    # set paths
    config_file = os.path.join(checkpoint_dir, "config.json")
    checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth")
    vocab_file = os.path.join(checkpoint_dir, "vocab.txt")
    # set config params
    with open(config_file, "r", encoding="utf-8") as file:
        # Load the JSON data as a dictionary
        config_org = json.load(file)
    self.config.audio.sample_rate = config_org["data"]["sampling_rate"]
    # self.config.add_blank = config['add_blank']
    # set tokenizer
    vocab = FairseqVocab(vocab_file)
    self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels)
    self.tokenizer = TTSTokenizer(
        use_phonemes=False,
        text_cleaner=basic_cleaners,
        characters=vocab,
        phonemizer=None,
        add_blank=config_org["data"]["add_blank"],
        use_eos_bos=False,
    )
    # load fairseq checkpoint
    new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file)
    self.load_state_dict(new_chk)
    if eval:
        self.eval()
        assert not self.training

@staticmethod
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
"""Initiate model from config
@@ -1919,3 +1964,24 @@ def to_config(self) -> "CharactersConfig":
is_unique=False,
is_sorted=True,
)


class FairseqVocab(BaseVocabulary):
    def __init__(self, vocab: str):
        super(FairseqVocab).__init__()
        self.vocab = vocab

    @property
    def vocab(self):
        """Return the vocabulary dictionary."""
        return self._vocab

    @vocab.setter
    def vocab(self, vocab_file):
        with open(vocab_file, encoding="utf-8") as f:
            self._vocab = [x.replace("\n", "") for x in f.readlines()]
            self.blank = self._vocab[0]
            print(self._vocab)
            self.pad = " "
            self._char_to_id = {s: i for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
            self._id_to_char = {i: s for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
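A hedged sketch of driving the new loader directly, assuming an already-downloaded MMS checkpoint directory that contains `config.json`, `G_100000.pth`, and `vocab.txt` (the directory path is illustrative; the high-level `TTS` API wraps these steps):

```python
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

# A default VitsConfig is enough here: load_fairseq_checkpoint replaces the
# sample rate, tokenizer, and text embedding from the files in checkpoint_dir.
config = VitsConfig()
model = Vits.init_from_config(config)
model.load_fairseq_checkpoint(config, checkpoint_dir="path/to/mms/deu", eval=True)
assert not model.training  # eval=True leaves the model in inference mode
```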
48 changes: 48 additions & 0 deletions TTS/tts/utils/fairseq.py
@@ -0,0 +1,48 @@
import torch


def rehash_fairseq_vits_checkpoint(checkpoint_file):
    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"]
    new_chk = {}
    for k, v in chk.items():
        if "enc_p." in k:
            new_chk[k.replace("enc_p.", "text_encoder.")] = v
        elif "dec." in k:
            new_chk[k.replace("dec.", "waveform_decoder.")] = v
        elif "enc_q." in k:
            new_chk[k.replace("enc_q.", "posterior_encoder.")] = v
        elif "flow.flows.2." in k:
            new_chk[k.replace("flow.flows.2.", "flow.flows.1.")] = v
        elif "flow.flows.4." in k:
            new_chk[k.replace("flow.flows.4.", "flow.flows.2.")] = v
        elif "flow.flows.6." in k:
            new_chk[k.replace("flow.flows.6.", "flow.flows.3.")] = v
        elif "dp.flows.0.m" in k:
            new_chk[k.replace("dp.flows.0.m", "duration_predictor.flows.0.translation")] = v
        elif "dp.flows.0.logs" in k:
            new_chk[k.replace("dp.flows.0.logs", "duration_predictor.flows.0.log_scale")] = v
        elif "dp.flows.1" in k:
            new_chk[k.replace("dp.flows.1", "duration_predictor.flows.1")] = v
        elif "dp.flows.3" in k:
            new_chk[k.replace("dp.flows.3", "duration_predictor.flows.2")] = v
        elif "dp.flows.5" in k:
            new_chk[k.replace("dp.flows.5", "duration_predictor.flows.3")] = v
        elif "dp.flows.7" in k:
            new_chk[k.replace("dp.flows.7", "duration_predictor.flows.4")] = v
        elif "dp.post_flows.0.m" in k:
            new_chk[k.replace("dp.post_flows.0.m", "duration_predictor.post_flows.0.translation")] = v
        elif "dp.post_flows.0.logs" in k:
            new_chk[k.replace("dp.post_flows.0.logs", "duration_predictor.post_flows.0.log_scale")] = v
        elif "dp.post_flows.1" in k:
            new_chk[k.replace("dp.post_flows.1", "duration_predictor.post_flows.1")] = v
        elif "dp.post_flows.3" in k:
            new_chk[k.replace("dp.post_flows.3", "duration_predictor.post_flows.2")] = v
        elif "dp.post_flows.5" in k:
            new_chk[k.replace("dp.post_flows.5", "duration_predictor.post_flows.3")] = v
        elif "dp.post_flows.7" in k:
            new_chk[k.replace("dp.post_flows.7", "duration_predictor.post_flows.4")] = v
        elif "dp." in k:
            new_chk[k.replace("dp.", "duration_predictor.")] = v
        else:
            new_chk[k] = v
    return new_chk
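A small, self-contained illustration of the remapping, using a toy checkpoint saved in the fairseq `{"model": state_dict}` layout that `rehash_fairseq_vits_checkpoint` expects (tensor values are placeholders):

```python
import torch

from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint

# Three fairseq-style keys that exercise different branches of the remapping.
toy = {
    "model": {
        "enc_p.emb.weight": torch.zeros(1),
        "dec.conv_pre.weight": torch.zeros(1),
        "dp.flows.3.logs": torch.zeros(1),
    }
}
torch.save(toy, "toy_fairseq.pth")

new_chk = rehash_fairseq_vits_checkpoint("toy_fairseq.pth")
print(sorted(new_chk))
# ['duration_predictor.flows.2.logs', 'text_encoder.emb.weight', 'waveform_decoder.conv_pre.weight']
```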
43 changes: 38 additions & 5 deletions TTS/tts/utils/text/characters.py
Expand Up @@ -63,6 +63,18 @@ def blank_id(self) -> int:
the vocabulary."""
return self.char_to_id(self.blank) if self.blank else len(self.vocab)

@property
def bos_id(self) -> int:
"""Return the index of the bos character. If the bos character is not specified, return the length of the
vocabulary."""
return self.char_to_id(self.bos) if self.bos else len(self.vocab)

@property
def eos_id(self) -> int:
"""Return the index of the eos character. If the eos character is not specified, return the length of the
vocabulary."""
return self.char_to_id(self.eos) if self.eos else len(self.vocab)

@property
def vocab(self):
"""Return the vocabulary dictionary."""
@@ -71,11 +83,13 @@ def vocab(self):
@vocab.setter
def vocab(self, vocab):
"""Set the vocabulary dictionary and character mapping dictionaries."""
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
}
self._vocab, self._char_to_id, self._id_to_char = None, None, None
if vocab is not None:
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
}

@staticmethod
def init_from_config(config, **kwargs):
@@ -93,6 +107,17 @@ def init_from_config(config, **kwargs):
)
return BaseVocabulary(**kwargs), config

def to_config(self) -> "CharactersConfig":
return CharactersConfig(
vocab_dict=self._vocab,
pad=self.pad,
eos=self.eos,
bos=self.bos,
blank=self.blank,
is_unique=False,
is_sorted=False,
)

@property
def num_chars(self):
"""Return number of tokens in the vocabulary."""
@@ -174,6 +199,14 @@ def pad_id(self) -> int:
def blank_id(self) -> int:
return self.char_to_id(self.blank) if self.blank else len(self.vocab)

@property
def eos_id(self) -> int:
return self.char_to_id(self.eos) if self.eos else len(self.vocab)

@property
def bos_id(self) -> int:
return self.char_to_id(self.bos) if self.bos else len(self.vocab)

@property
def characters(self):
return self._characters
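A minimal sketch of the fallback behaviour of the new `bos_id`/`eos_id` properties on `BaseVocabulary` (the toy character set is arbitrary):

```python
from TTS.tts.utils.text.characters import BaseVocabulary

vocab = BaseVocabulary(vocab=["<blk>", "a", "b", "c"], blank="<blk>")
print(vocab.blank_id)  # 0 -> index of "<blk>"
print(vocab.bos_id)    # 4 -> bos is unset, so it falls back to len(vocab)
print(vocab.eos_id)    # 4 -> eos is unset, so it falls back to len(vocab)
```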
7 changes: 4 additions & 3 deletions TTS/tts/utils/text/tokenizer.py
@@ -108,26 +108,27 @@ def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint:
text = self.text_cleaner(text)
if self.use_phonemes:
text = self.phonemizer.phonemize(text, separator="", language=language)
text = self.encode(text)
if self.add_blank:
text = self.intersperse_blank_char(text, True)
if self.use_eos_bos:
text = self.pad_with_bos_eos(text)
return self.encode(text)
return text

def ids_to_text(self, id_sequence: List[int]) -> str:
"""Converts a sequence of token IDs to a string of text."""
return self.decode(id_sequence)

def pad_with_bos_eos(self, char_sequence: List[str]):
"""Pads a sequence with the special BOS and EOS characters."""
return [self.characters.bos] + list(char_sequence) + [self.characters.eos]
return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id]

def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
"""Intersperses the blank character between characters in a sequence.
Use the ```blank``` character if defined else use the ```pad``` character.
"""
char_to_use = self.characters.blank if use_blank_char else self.characters.pad
char_to_use = self.characters.blank_id if use_blank_char else self.characters.pad
result = [char_to_use] * (len(char_sequence) * 2 + 1)
result[1::2] = char_sequence
return result
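A hedged walk-through of the new ordering with a toy vocabulary and `basic_cleaners` (the cleaner also used by `load_fairseq_checkpoint` above): the text is encoded to IDs first, so blank interspersion and BOS/EOS padding now operate on integer IDs instead of characters.

```python
from TTS.tts.utils.text.characters import BaseVocabulary
from TTS.tts.utils.text.cleaners import basic_cleaners
from TTS.tts.utils.text.tokenizer import TTSTokenizer

# id 0 is the blank token, followed by a tiny character set.
vocab = BaseVocabulary(vocab=["<blk>", "a", "b", "c"], blank="<blk>")
tok = TTSTokenizer(
    use_phonemes=False,
    text_cleaner=basic_cleaners,
    characters=vocab,
    add_blank=True,
    use_eos_bos=False,
)
print(tok.text_to_ids("abc"))  # [0, 1, 0, 2, 0, 3, 0] -> blank id interleaved with character ids
```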