-
Notifications
You must be signed in to change notification settings - Fork 148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for Coqui TTS #59
base: main
Are you sure you want to change the base?
Changes from all commits
28dd92f
06da573
14e6762
384cc4e
0711cb0
d829ecf
4f8bedc
b098208
c7178c9
5ab8749
854230e
4e88170
c4f1b74
ad1edec
180671e
e4ceee2
a2fc0c1
c38951a
521bfec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ venv/ | |
.idea | ||
.history/ | ||
.run/ | ||
.python-version | ||
|
||
# Temporary files | ||
*.tmp | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
TTS_AZURE = "azure" | ||
TTS_OPENAI = "openai" | ||
TTS_EDGE = "edge" | ||
TTS_COQUI = "coqui" | ||
|
||
|
||
class BaseTTSProvider: # Base interface for TTS providers | ||
|
@@ -34,18 +35,29 @@ def get_output_file_extension(self): | |
|
||
# Common support methods for all TTS providers | ||
def get_supported_tts_providers() -> List[str]:
    """Return the identifiers of every TTS provider that can be selected.

    The list mirrors the branches handled by ``get_tts_provider``; the Coqui
    identifier is included so that provider can be chosen.

    Returns:
        The provider identifier constants, in the order Azure, OpenAI, Edge,
        Coqui.
    """
    return [TTS_AZURE, TTS_OPENAI, TTS_EDGE, TTS_COQUI]
|
||
|
||
def get_tts_provider(config) -> BaseTTSProvider:
    """Build the TTS provider selected by ``config.tts``.

    Each provider class is imported inside its own branch, so this call only
    imports the module of the provider that was actually requested.

    Args:
        config: Configuration object whose ``tts`` attribute is compared with
            the ``TTS_*`` identifier constants; it is passed unchanged to the
            provider constructor.

    Returns:
        The provider instance constructed from ``config``.

    Raises:
        ValueError: If ``config.tts`` matches none of the known identifiers.
    """
    # NOTE(review): the provider imports stay on a single line each;
    # re-wrapping them was flagged in review as "no functional change, just
    # cosmetics, not needed".
    if config.tts == TTS_AZURE:
        from audiobook_generator.tts_providers.azure_tts_provider import AzureTTSProvider

        return AzureTTSProvider(config)
    elif config.tts == TTS_OPENAI:
        from audiobook_generator.tts_providers.openai_tts_provider import OpenAITTSProvider

        return OpenAITTSProvider(config)
    elif config.tts == TTS_EDGE:
        from audiobook_generator.tts_providers.edge_tts_provider import EdgeTTSProvider

        return EdgeTTSProvider(config)
    elif config.tts == TTS_COQUI:
        from audiobook_generator.tts_providers.coqui_tts_provider import CoquiTTSProvider

        return CoquiTTSProvider(config)
    else:
        raise ValueError(f"Invalid TTS provider: {config.tts}")
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import logging | ||
import math | ||
import tempfile | ||
|
||
import torch | ||
from pydub import AudioSegment | ||
from TTS.api import TTS | ||
|
||
from audiobook_generator.config.general_config import GeneralConfig | ||
from audiobook_generator.core.audio_tags import AudioTags | ||
from audiobook_generator.core.utils import set_audio_tags | ||
from audiobook_generator.tts_providers.base_tts_provider import BaseTTSProvider | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class CoquiTTSProvider(BaseTTSProvider):
    """TTS provider that synthesizes speech with a Coqui ``TTS`` model.

    The model is loaded once in ``__init__`` and moved to ``'cuda'`` when
    ``torch.cuda.is_available()`` is true, otherwise to ``'cpu'``.
    """

    def __init__(self, config: GeneralConfig):
        """Apply Coqui defaults to ``config`` and load the model.

        Args:
            config: General configuration. ``output_format``, ``model_name``,
                ``language_coqui`` and ``voice_sample_wav_path`` are replaced
                by their defaults when they are falsy; ``config.log`` is used
                as this module's logger level.
        """
        logger.setLevel(config.log)

        # TTS provider specific config
        config.output_format = config.output_format or 'mp3'
        config.model_name = config.model_name or 'tts_models/en/ljspeech/tacotron2-DDC'
        config.language_coqui = config.language_coqui or 'en'
        config.voice_sample_wav_path = config.voice_sample_wav_path or ''

        # Init TTS with the target model name on the best available device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tts = TTS(
            model_name=config.model_name,
            progress_bar=True,
        ).to(device)

        # Price per 1000 characters, consumed by estimate_cost().
        self.price = 0.000
        super().__init__(config)

    def __str__(self) -> str:
        """Return the string form of the provider's configuration."""
        return f'{self.config}'

    def validate_config(self):
        """Perform no validation; no Coqui-specific checks are implemented."""
        pass

    def text_to_speech(
        self,
        text: str,
        output_file: str,
        audio_tags: AudioTags,
    ):
        """Synthesize ``text`` into ``output_file`` and tag the result.

        The model first writes a WAV file into a temporary directory; that
        file is then converted to ``config.output_format`` with pydub and the
        audio tags are applied to the converted file.

        Args:
            text: Text to synthesize.
            output_file: Destination path of the converted audio file.
            audio_tags: Tags handed to ``set_audio_tags`` for ``output_file``.
        """
        # Function-scope import keeps this block self-contained.
        import os

        with tempfile.TemporaryDirectory() as tmp_dir:
            logger.debug('created temporary directory %s', tmp_dir)
            wav_path = os.path.join(tmp_dir, 'file.wav')

            synth_kwargs = {'file_path': wav_path, 'split_sentences': True}
            if self.tts.is_multi_lingual:
                # Multilingual models additionally receive the configured
                # speaker sample and language.
                logger.debug('synthesizing %d characters', len(text))
                synth_kwargs['speaker_wav'] = self.config.voice_sample_wav_path
                synth_kwargs['language'] = self.config.language_coqui
            self.tts.tts_to_file(text, **synth_kwargs)

            # Convert the wav file to the desired format
            AudioSegment.from_wav(wav_path).export(
                output_file, format=self.config.output_format
            )

        set_audio_tags(output_file, audio_tags)

    def estimate_cost(self, total_chars):
        """Return the estimated cost of synthesizing ``total_chars`` characters.

        The character count is rounded up to whole thousands and multiplied by
        ``self.price``.
        """
        return math.ceil(total_chars / 1000) * self.price

    def get_break_string(self):
        """Return the string used as a break marker for this provider."""
        return ' '

    def get_output_file_extension(self):
        """Map ``config.output_format`` to the output file extension.

        Rules are checked in order and the first match wins. ``wav`` is
        accepted in addition to the other formats because the configured
        format is passed directly to pydub's ``export``.

        Returns:
            The file extension without a leading dot.

        Raises:
            NotImplementedError: If no rule matches the configured format.
        """
        output_format = self.config.output_format
        rules = (
            (str.startswith, 'amr', 'amr'),
            (str.startswith, 'ogg', 'ogg'),
            (str.endswith, 'truesilk', 'silk'),
            (str.endswith, 'pcm', 'pcm'),
            (str.startswith, 'raw', 'wav'),
            (str.startswith, 'webm', 'webm'),
            (str.endswith, 'opus', 'opus'),
            (str.endswith, 'mp3', 'mp3'),
            (str.endswith, 'wav', 'wav'),
        )
        for matches, token, extension in rules:
            if matches(output_format, token):
                return extension
        raise NotImplementedError(
            f'Unknown file extension for output format: {self.config.output_format}'
        )

    def get_supported_models(self):
        """Print the value returned by the loaded model's ``list_models()``."""
        print(self.tts.list_models())
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,23 +94,23 @@ def handle_args(): | |
help=''' | ||
Speaking rate of the text. Valid relative values range from -50%%(--xxx='-50%%') to +100%%. | ||
For negative value use format --arg=value, | ||
''' | ||
''', | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the last argument in the function call, so the trailing comma is not required. |
||
) | ||
|
||
edge_tts_group.add_argument( | ||
"--voice_volume", | ||
help=''' | ||
Volume level of the speaking voice. Valid relative values floor to -100%%. | ||
For negative value use format --arg=value, | ||
''' | ||
''', | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the last argument in the function call, so the trailing comma is not required. |
||
) | ||
|
||
edge_tts_group.add_argument( | ||
"--voice_pitch", | ||
help=''' | ||
Baseline pitch for the text.Valid relative values like -80Hz,+50Hz, pitch changes should be within 0.5 to 1.5 times the original audio. | ||
For negative value use format --arg=value, | ||
''' | ||
''', | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the last argument in the function call, so the trailing comma is not required. |
||
) | ||
|
||
edge_tts_group.add_argument( | ||
|
@@ -125,6 +125,19 @@ def handle_args(): | |
help="Break duration in milliseconds for the different paragraphs or sections (default: 1250). Valid values range from 0 to 5000 milliseconds.", | ||
) | ||
|
||
coqui_tts_group = parser.add_argument_group(title="coqui specific") | ||
coqui_tts_group.add_argument( | ||
"--voice_sample_wav_path", | ||
default="sample_voices/samples_en_man_1.wav", | ||
help="Path to the sample wav file to be used for the voice of the TTS provider", | ||
) | ||
|
||
coqui_tts_group.add_argument( | ||
"--language_coqui", | ||
default="en", | ||
help="Language for the text-to-speech service using Coqui provider(default: en). Possible values are ['en', 'es', 'fr', 'de', 'it', 'pt', 'pl', 'tr', 'ru', 'nl', 'cs', 'ar', 'zh-cn', 'hu', 'ko', 'ja', 'hi']", | ||
) | ||
|
||
args = parser.parse_args() | ||
return GeneralConfig(args) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,5 @@ openai==1.2.2 | |
requests==2.31.0 | ||
socksio==1.0.0 | ||
edge-tts==6.1.10 | ||
pydub==0.25.1 | ||
pydub==0.25.1 | ||
TTS==0.22.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no functional change, just cosmetics, not needed