Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming inference for XTTS 🚀 #3035

Merged
merged 13 commits into from
Oct 6, 2023
8 changes: 4 additions & 4 deletions TTS/.models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
"xtts_v1": {
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json"
"https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
"https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/config.json",
"https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/vocab.json"
],
"default_vocoder": null,
"commit": "e9a1953e",
"commit": "e5140314",
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": true
Expand Down
24 changes: 23 additions & 1 deletion TTS/tts/layers/xtts/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def get_grad_norm_parameter_groups(self):
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
}

def init_gpt_for_inference(self, kv_cache=True):
def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
Expand All @@ -195,6 +195,17 @@ def init_gpt_for_inference(self, kv_cache=True):
)
self.gpt.wte = self.mel_embedding

if use_deepspeed:
import deepspeed
self.ds_engine = deepspeed.init_inference(
model=self.gpt_inference.half(), # Transformers models
mp_size=1, # Number of GPU
dtype=torch.float32, # desired data type of output
replace_method="auto", # Lets DS autmatically identify the layer to replace
replace_with_kernel_inject=True, # replace the model with the kernel injector
)
self.gpt_inference = self.ds_engine.module.eval()

def set_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
Expand Down Expand Up @@ -543,3 +554,14 @@ def generate(
if "return_dict_in_generate" in hf_generate_kwargs:
return gen.sequences[:, gpt_inputs.shape[1] :], gen
return gen[:, gpt_inputs.shape[1] :]

def get_generator(self, fake_inputs, **hf_generate_kwargs):
return self.gpt_inference.generate_stream(
fake_inputs,
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
do_stream=True,
**hf_generate_kwargs,
)
Loading
Loading