From adbbc02a84c0cff6dd3f877d89d947d604cdab2d Mon Sep 17 00:00:00 2001
From: Austin Bowen
Date: Tue, 28 May 2024 18:55:09 -0400
Subject: [PATCH] Add ParlerTTS (#7)

---
 README.md                                |   8 ++
 docs/voicebox.tts.rst                    |   8 ++
 src/voicebox/tts/__init__.py             |   5 +
 src/voicebox/tts/parlertts.py            | 113 +++++++++++++++++++++++
 tests/integration/test_all_ttss_basic.py |   4 +
 5 files changed, 138 insertions(+)
 create mode 100644 src/voicebox/tts/parlertts.py

diff --git a/README.md b/README.md
index 6421c79..c5f1845 100644
--- a/README.md
+++ b/README.md
@@ -86,6 +86,14 @@ Online TTS engine used by Google Translate.
 1. `pip install "voicebox-tts[gtts]"`
 2. Install ffmpeg or libav for `pydub` ([docs](https://github.com/jiaaro/pydub#getting-ffmpeg-set-up))
 
+### 🤗 Parler TTS [🌐](https://github.com/huggingface/parler-tts)
+
+Offline TTS engine released by Hugging Face that uses a promptable
+deep learning model to generate speech.
+
+- Class: [`voicebox.tts.ParlerTTS`](voicebox.tts.parlertts.ParlerTTS)
+- Setup: `pip install git+https://github.com/huggingface/parler-tts.git`
+
 ### Pico TTS
 
 Very basic offline TTS engine.
diff --git a/docs/voicebox.tts.rst b/docs/voicebox.tts.rst
index 1889a0b..79e1b23 100644
--- a/docs/voicebox.tts.rst
+++ b/docs/voicebox.tts.rst
@@ -52,6 +52,14 @@ voicebox.tts.gtts module
    :undoc-members:
    :show-inheritance:
 
+voicebox.tts.parlertts module
+-----------------------------
+
+.. automodule:: voicebox.tts.parlertts
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 voicebox.tts.picotts module
 ---------------------------
 
diff --git a/src/voicebox/tts/__init__.py b/src/voicebox/tts/__init__.py
index bf85ab8..4d73d87 100644
--- a/src/voicebox/tts/__init__.py
+++ b/src/voicebox/tts/__init__.py
@@ -34,6 +34,11 @@
 except ImportError:
     pass
 
+try:
+    from voicebox.tts.parlertts import ParlerTTS
+except ImportError:
+    pass
+
 from voicebox.tts.picotts import PicoTTS
 
 try:
diff --git a/src/voicebox/tts/parlertts.py b/src/voicebox/tts/parlertts.py
new file mode 100644
index 0000000..0d9491c
--- /dev/null
+++ b/src/voicebox/tts/parlertts.py
@@ -0,0 +1,113 @@
+from typing import Optional, Union
+
+import torch
+from parler_tts import ParlerTTSForConditionalGeneration
+from transformers import AutoTokenizer
+
+from voicebox.audio import Audio
+from voicebox.tts import TTS
+from voicebox.types import StrOrSSML
+
+
+class ParlerTTS(TTS):
+    """
+    Offline TTS engine released by Hugging Face 🤗 that uses a promptable
+    deep learning model to generate speech.
+
+    Use ``ParlerTTS.build()`` to create a new instance of this class
+    instead of instantiating it directly.
+
+    See repo for details: https://github.com/huggingface/parler-tts
+    """
+
+    model: torch.nn.Module
+    tokenizer: AutoTokenizer
+    device: torch.device
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        tokenizer: AutoTokenizer,
+        device: torch.device,
+        description: str,
+    ):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.device = device
+
+        self._tokenized_description = None
+        self.description = description
+
+    @property
+    def description(self) -> str:
+        return self._description
+
+    @description.setter
+    def description(self, value: str):
+        self._description = value
+        self._tokenized_description = self._tokenize(value)
+
+    @classmethod
+    def build(
+        cls,
+        description: str = '',
+        model_name: str = 'parler-tts/parler_tts_mini_v0.1',
+        device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+    ):
+        """
+        Build a new instance of ``ParlerTTS``.
+
+        Args:
+            description:
+                (Optional) The description of the voice to use.
+                This can be changed later by setting the ``description``
+                property on the returned instance.
+            model_name:
+                (Optional) The name of the Parler model to use.
+            device:
+                (Optional) The device to use for inference.
+                If not given, an appropriate default is chosen based on the
+                available hardware.
+            torch_dtype:
+                (Optional) The data type to use for inference.
+                If not given, an appropriate default is chosen based on the
+                device.
+        """
+
+        if device is None:
+            if torch.cuda.is_available():
+                device = 'cuda:0'
+            elif torch.backends.mps.is_available():
+                device = 'mps'
+            elif torch.xpu.is_available():
+                device = 'xpu'
+            else:
+                device = 'cpu'
+        device = torch.device(device)
+
+        if torch_dtype is None:
+            torch_dtype = torch.float16 if device.type != 'cpu' else torch.float32
+
+        model = ParlerTTSForConditionalGeneration.from_pretrained(model_name)
+        model.to(device, dtype=torch_dtype)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        return cls(model, tokenizer, device, description)
+
+    def get_speech(self, text: StrOrSSML) -> Audio:
+        tokenized_text = self._tokenize(text)
+
+        signal = self.model.generate(
+            input_ids=self._tokenized_description,
+            prompt_input_ids=tokenized_text,
+        )
+        signal = signal.cpu().to(torch.float32).numpy().squeeze()
+
+        sample_rate = self.model.config.sampling_rate
+
+        return Audio(signal, sample_rate)
+
+    def _tokenize(self, text: str) -> torch.Tensor:
+        return self.tokenizer(text, return_tensors='pt').input_ids.to(self.device)
diff --git a/tests/integration/test_all_ttss_basic.py b/tests/integration/test_all_ttss_basic.py
index 2ac42d9..006edad 100644
--- a/tests/integration/test_all_ttss_basic.py
+++ b/tests/integration/test_all_ttss_basic.py
@@ -12,6 +12,7 @@
     ESpeakNG,
     GoogleCloudTTS,
     gTTS,
+    ParlerTTS,
     PicoTTS,
     Pyttsx3TTS,
 )
@@ -36,6 +37,9 @@ def test_get_speech(tts_class: Type[TTS]):
             voice_params=VoiceSelectionParams(language_code='en-US'),
         )
 
+    elif tts_class is ParlerTTS:
+        tts = ParlerTTS.build()
+
     else:
         tts = tts_class()
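
For reference, a minimal usage sketch of the API this patch adds. This is illustrative only, not part of the patch: the description strings are made up, the first call downloads the model weights, and wiring the returned Audio into a voicebox sink for playback is omitted.

    from voicebox.tts import ParlerTTS

    # Build with defaults: loads 'parler-tts/parler_tts_mini_v0.1' and
    # picks CUDA, MPS, or XPU when available, falling back to CPU.
    tts = ParlerTTS.build(
        description='A calm female voice with a slight British accent.',
    )

    # get_speech() returns a voicebox Audio holding the generated
    # waveform and the model's sampling rate.
    audio = tts.get_speech('Hello from Parler TTS!')

    # The voice can be re-prompted at runtime; the description setter
    # re-tokenizes the prompt for subsequent generations.
    tts.description = 'An excited male voice, speaking quickly.'
    audio = tts.get_speech('Same text, different voice.')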