Streaming inference for XTTS 🚀 (#3035)
WeberJulian authored Oct 6, 2023
1 parent 2150136 commit e5e0cbf
Showing 8 changed files with 2,192 additions and 129 deletions.
8 changes: 4 additions & 4 deletions TTS/.models.json
@@ -5,12 +5,12 @@
"xtts_v1": {
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json"
"https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
"https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/config.json",
"https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/vocab.json"
],
"default_vocoder": null,
"commit": "e9a1953e",
"commit": "e5140314",
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": true
24 changes: 23 additions & 1 deletion TTS/tts/layers/xtts/gpt.py
@@ -172,7 +172,7 @@ def get_grad_norm_parameter_groups(self):
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
}

def init_gpt_for_inference(self, kv_cache=True):
def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
@@ -195,6 +195,17 @@ def init_gpt_for_inference(self, kv_cache=True):
         )
         self.gpt.wte = self.mel_embedding

+        if use_deepspeed:
+            import deepspeed
+
+            self.ds_engine = deepspeed.init_inference(
+                model=self.gpt_inference.half(),  # Transformers model, cast to fp16
+                mp_size=1,  # number of GPUs used for model parallelism
+                dtype=torch.float32,  # desired data type of the output
+                replace_method="auto",  # let DeepSpeed automatically identify the layers to replace
+                replace_with_kernel_inject=True,  # replace the modules with optimized inference kernels
+            )
+            self.gpt_inference = self.ds_engine.module.eval()

     def set_inputs_and_targets(self, input, start_token, stop_token):
         inp = F.pad(input, (1, 0), value=start_token)
         tar = F.pad(input, (0, 1), value=stop_token)
@@ -543,3 +554,14 @@ def generate(
if "return_dict_in_generate" in hf_generate_kwargs:
return gen.sequences[:, gpt_inputs.shape[1] :], gen
return gen[:, gpt_inputs.shape[1] :]

+    def get_generator(self, fake_inputs, **hf_generate_kwargs):
+        return self.gpt_inference.generate_stream(
+            fake_inputs,
+            bos_token_id=self.start_audio_token,
+            pad_token_id=self.stop_audio_token,
+            eos_token_id=self.stop_audio_token,
+            max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
+            do_stream=True,
+            **hf_generate_kwargs,
+        )