Add ParlerTTS (#7)
austin-bowen authored May 28, 2024
1 parent 61222e7 commit adbbc02
Showing 5 changed files with 138 additions and 0 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -86,6 +86,14 @@ Online TTS engine used by Google Translate.
1. `pip install "voicebox-tts[gtts]"`
2. Install ffmpeg or libav for `pydub` ([docs](https://github.com/jiaaro/pydub#getting-ffmpeg-set-up))

### 🤗 Parler TTS [🌐](https://github.com/huggingface/parler-tts)

Offline TTS engine released by Hugging Face that uses a promptable
deep learning model to generate speech.

- Class: [`voicebox.tts.ParlerTTS`](voicebox.tts.parlertts.ParlerTTS)
- Setup: `pip install git+https://github.com/huggingface/parler-tts.git`
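
As an illustration (not part of this commit's README addition), here is a minimal usage sketch, assuming parler-tts is installed per the Setup step above; the voice description and spoken text are example values:

```python
from voicebox.tts import ParlerTTS

# Build the engine with a natural-language description of the desired voice.
tts = ParlerTTS.build(
    description='A female speaker with a clear, friendly voice.',
)

# Generate speech; returns a voicebox Audio object.
audio = tts.get_speech('Hello from Parler TTS!')
```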

### Pico TTS

Very basic offline TTS engine.
8 changes: 8 additions & 0 deletions docs/voicebox.tts.rst
@@ -52,6 +52,14 @@ voicebox.tts.gtts module
   :undoc-members:
   :show-inheritance:

voicebox.tts.parlertts module
-----------------------------

.. automodule:: voicebox.tts.parlertts
   :members:
   :undoc-members:
   :show-inheritance:

voicebox.tts.picotts module
---------------------------

5 changes: 5 additions & 0 deletions src/voicebox/tts/__init__.py
@@ -34,6 +34,11 @@
except ImportError:
    pass

try:
    from voicebox.tts.parlertts import ParlerTTS
except ImportError:
    pass

from voicebox.tts.picotts import PicoTTS

try:
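
Because the import above is wrapped in `try`/`except ImportError`, `ParlerTTS` is only exported from `voicebox.tts` when parler-tts is installed. A short sketch (not part of the commit) of how downstream code might guard against the missing dependency; the `None` fallback is just one illustrative convention:

```python
try:
    from voicebox.tts import ParlerTTS
except ImportError:
    # parler-tts is not installed, so voicebox.tts does not export ParlerTTS.
    ParlerTTS = None

if ParlerTTS is not None:
    tts = ParlerTTS.build()
```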
113 changes: 113 additions & 0 deletions src/voicebox/tts/parlertts.py
@@ -0,0 +1,113 @@
from typing import Union

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

from voicebox.audio import Audio
from voicebox.tts import TTS
from voicebox.types import StrOrSSML


class ParlerTTS(TTS):
    """
    Offline TTS engine released by Hugging Face 🤗 that uses a promptable
    deep learning model to generate speech.

    Use ``ParlerTTS.build()`` to create a new instance of this class
    instead of instantiating it directly.

    See repo for details: https://github.com/huggingface/parler-tts
    """

    model: torch.nn.Module
    tokenizer: AutoTokenizer
    device: torch.device

    def __init__(
            self,
            model: torch.nn.Module,
            tokenizer: AutoTokenizer,
            device: torch.device,
            description: str,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        self._tokenized_description = None
        self.description = description

    @property
    def description(self) -> str:
        return self._description

    @description.setter
    def description(self, value: str):
        self._description = value
        self._tokenized_description = self._tokenize(value)

    @classmethod
    def build(
            cls,
            description: str = '',
            model_name: str = 'parler-tts/parler_tts_mini_v0.1',
            device: Union[str, torch.device] = None,
            torch_dtype: torch.dtype = None,
    ):
        """
        Build a new instance of ``ParlerTTS``.

        Args:
            description:
                (Optional) The description of the voice to use.
                This can be changed later by setting the ``description``
                property on the returned instance.
            model_name:
                (Optional) The name of the Parler model to use.
            device:
                (Optional) The device to use for inference.
                If not given, an appropriate default is chosen based on the
                available hardware.
            torch_dtype:
                (Optional) The data type to use for inference.
                If not given, an appropriate default is chosen based on the
                device.
        """

        if device is None:
            if torch.cuda.is_available():
                device = 'cuda:0'
            elif torch.backends.mps.is_available():
                device = 'mps'
            elif torch.xpu.is_available():
                device = 'xpu'
            else:
                device = 'cpu'
        device = torch.device(device)

        if torch_dtype is None:
            torch_dtype = torch.float16 if device.type != 'cpu' else torch.float32

        model = ParlerTTSForConditionalGeneration.from_pretrained(model_name)
        model.to(device, dtype=torch_dtype)

        tokenizer = AutoTokenizer.from_pretrained(model_name)

        return cls(model, tokenizer, device, description)

    def get_speech(self, text: StrOrSSML) -> Audio:
        tokenized_text = self._tokenize(text)

        signal = self.model.generate(
            input_ids=self._tokenized_description,
            prompt_input_ids=tokenized_text,
        )
        signal = signal.cpu().to(torch.float32).numpy().squeeze()

        sample_rate = self.model.config.sampling_rate

        return Audio(signal, sample_rate)

    def _tokenize(self, text: str) -> torch.Tensor:
        return self.tokenizer(text, return_tensors='pt').input_ids.to(self.device)
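
For reference, a hedged usage sketch of the new module based only on the API above; the device, dtype, voice descriptions, and spoken text are illustrative choices, and the model weights are fetched by `from_pretrained` on first use:

```python
import torch

from voicebox.tts import ParlerTTS

# Explicit device/dtype selection (this example assumes a CUDA GPU is present);
# omit these arguments to let build() pick sensible defaults.
tts = ParlerTTS.build(
    description='A male speaker with a deep, calm voice.',
    device=torch.device('cuda:0'),
    torch_dtype=torch.float16,
)

audio = tts.get_speech('The same text can be spoken in different voices.')

# Changing the description re-tokenizes it for subsequent generate() calls.
tts.description = 'A female speaker with a fast, energetic delivery.'
audio = tts.get_speech('The same text can be spoken in different voices.')
```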
4 changes: 4 additions & 0 deletions tests/integration/test_all_ttss_basic.py
@@ -12,6 +12,7 @@
    ESpeakNG,
    GoogleCloudTTS,
    gTTS,
    ParlerTTS,
    PicoTTS,
    Pyttsx3TTS,
)
@@ -36,6 +37,9 @@ def test_get_speech(tts_class: Type[TTS]):
            voice_params=VoiceSelectionParams(language_code='en-US'),
        )

    elif tts_class is ParlerTTS:
        tts = ParlerTTS.build()

    else:
        tts = tts_class()
