Add FT inference example on XTTS docs
Edresson committed Oct 23, 2023
1 parent 67ca70a commit 0f96abb
Showing 2 changed files with 48 additions and 4 deletions.
2 changes: 2 additions & 0 deletions TTS/tts/models/xtts.py
@@ -794,6 +794,8 @@ def get_compatible_checkpoint_state_dict(self, model_path):
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
+# remove xtts gpt trainer extra keys
+ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
for key in list(checkpoint.keys()):
    # check if it is from the coqui Trainer if so convert it
    if key.startswith("xtts."):
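For context, the filtering these added keys feed into is simple: any checkpoint entry whose key starts with an ignored prefix is dropped before the state dict is loaded. A minimal, self-contained sketch of that pattern (toy dictionary, not the real checkpoint; the real method also converts `xtts.`-prefixed Trainer keys):

```python
# Toy sketch of prefix-based state-dict filtering, assuming a plain dict
# checkpoint; not the exact implementation in get_compatible_checkpoint_state_dict.
ignore_keys = ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]

checkpoint = {
    "gpt.layers.0.weight": 0.0,                   # kept
    "dvae.encoder.weight": 0.0,                   # dropped: matches "dvae"
    "torch_mel_spectrogram_dvae.mel_basis": 0.0,  # dropped
}
filtered = {
    key: value
    for key, value in checkpoint.items()
    if not any(key.startswith(prefix) for prefix in ignore_keys)
}
print(sorted(filtered))  # -> ['gpt.layers.0.weight']
```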
50 changes: 46 additions & 4 deletions docs/source/models/xtts.md
@@ -16,8 +16,8 @@ a few tricks to make it faster and support streaming inference.
Current implementation only supports inference.

### Languages
-As of now, XTTS-v1 supports 13 languages: English, Spanish, French, German, Italian, Portuguese,
-Polish, Turkish, Russian, Dutch, Czech, Arabic, and Chinese (Simplified).
+As of now, XTTS-v1.1 supports 14 languages: English, Spanish, French, German, Italian, Portuguese,
+Polish, Turkish, Russian, Dutch, Czech, Arabic, Chinese (Simplified) and Japanese.

Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.

@@ -33,7 +33,7 @@ You can also mail us at info@coqui.ai.

```python
from TTS.api import TTS
-tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1", gpu=True)
+tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1.1", gpu=True)

# generate speech by cloning a voice using default settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
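The hunk cuts the call off mid-argument; for readability, a complete invocation with the v1.1 model name would plausibly look like the sketch below (the `file_path` value and `language` code are illustrative placeholders, not taken from the diff):

```python
from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1.1", gpu=True)

# Clone the voice in speaker.wav; file_path and language are placeholder
# values chosen for this sketch.
tts.tts_to_file(
    text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
    speaker_wav="/path/to/target/speaker.wav",
    language="en",
    file_path="output.wav",
)
```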
@@ -45,7 +45,7 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
#### 🐸TTS Command line

```console
-tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
+tts --model_name tts_models/multilingual/multi-dataset/xtts_v1.1 \
--text "Bugün okula gitmek istemiyorum." \
--speaker_wav /path/to/target/speaker.wav \
--language_idx tr \
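The trailing backslash shows the command continues past the hunk boundary; a plausible complete command ends with an output path and a CUDA toggle (both are standard `tts` CLI flags, but their presence in the original command is an assumption):

```console
# Assumed continuation: --out_path and --use_cuda exist in the tts CLI, but
# the hunk does not show how the original command actually ends.
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1.1 \
     --text "Bugün okula gitmek istemiyorum." \
     --speaker_wav /path/to/target/speaker.wav \
     --language_idx tr \
     --out_path output.wav \
     --use_cuda true
```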
@@ -142,6 +142,48 @@ A recipe for `XTTS_v1.1` GPT encoder training using `LJSpeech` dataset is availa

You need to change the fields of the `BaseDatasetConfig` to match your dataset, and then update the `GPTArgs` and `GPTTrainerConfig` fields as needed. By default, the recipe uses the same parameters that the XTTS v1.1 model was trained with, and to speed up convergence it also downloads and loads the XTTS v1.1 checkpoint.
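As a hedged sketch of what that dataset section might look like (field values are placeholders for your own data; `BaseDatasetConfig` is imported from `TTS.config.shared_configs` in this version of the repo):

```python
from TTS.config.shared_configs import BaseDatasetConfig

# Placeholder values: point these at your own dataset before training.
config_dataset = BaseDatasetConfig(
    formatter="ljspeech",            # metadata parser matching your layout
    dataset_name="my_dataset",       # free-form identifier
    path="/path/to/my_dataset/",     # root folder containing the wavs
    meta_file_train="metadata.csv",  # transcription file, relative to path
    language="en",
)
```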

After training, you can run inference with the code below.

```python
import os
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Add here the xtts_config path
CONFIG_PATH = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT-October-23-2023_10+36AM-653f2e75/config.json"
# Add here the vocab file that you have used to train the model
TOKENIZER_PATH = "recipes/ljspeech/xtts_v1/run/training/XTTS_v1.1_original_model_files/vocab.json"
# Add here the checkpoint that you want to do inference with
XTTS_CHECKPOINT = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT/best_model.pth"
# Add here the speaker reference
SPEAKER_REFERENCE = "LjSpeech_reference.wav"

# output wav path
OUTPUT_WAV_PATH = "xtts-ft.wav"

print("Loading model...")
config = XttsConfig()
config.load_json(CONFIG_PATH)
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=TOKENIZER_PATH, use_deepspeed=False)
model.cuda()

print("Computing speaker latents...")
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=SPEAKER_REFERENCE)

print("Inference...")
out = model.inference(
    "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    diffusion_conditioning,
    temperature=0.7,  # Add custom parameters here
)
torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000)
```
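One knob in the block above worth noting is `use_deepspeed=False`: if the `deepspeed` package is installed, loading with it enabled should reduce inference latency. A continuation of the script above, under that assumption:

```python
# Continuation of the script above: reload the checkpoint with DeepSpeed
# enabled (assumes the `deepspeed` package is installed).
model.load_checkpoint(
    config,
    checkpoint_path=XTTS_CHECKPOINT,
    vocab_path=TOKENIZER_PATH,
    use_deepspeed=True,
)
```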


## Important resources & papers
