-
Notifications
You must be signed in to change notification settings - Fork 193
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Tortoise support, refactor code (#1)
* simple tortoise generator * basic refactor, splitting tabs and utils * group all history tab functions in one file
- Loading branch information
Showing
15 changed files
with
757 additions
and
504 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from models.bark.bark.generation import preload_models | ||
|
||
|
||
class BarkModelManager:
    """Loads and reloads the Bark TTS models according to a config dict.

    The config is expected to contain a ``load_models_on_startup`` flag and a
    ``model`` sub-dict with per-component GPU / small-model switches
    (text, coarse, fine, codec).
    """

    def __init__(self, config):
        # True once preload_models() has completed at least once.
        self.models_loaded = False
        if config["load_models_on_startup"]:
            self.reload_models(config)

    def reload_models(self, config):
        """(Re)load all Bark models as configured by ``config["model"]``.

        Always forces a reload, even if models are already resident.
        """
        print(f"{'Rel' if self.models_loaded else 'L'}oading Bark models")
        model_config = config["model"]
        text_use_gpu = model_config["text_use_gpu"]
        text_use_small = model_config["text_use_small"]
        coarse_use_gpu = model_config["coarse_use_gpu"]
        coarse_use_small = model_config["coarse_use_small"]
        fine_use_gpu = model_config["fine_use_gpu"]
        fine_use_small = model_config["fine_use_small"]
        codec_use_gpu = model_config["codec_use_gpu"]

        print(f'''\t- Text Generation:\t\t GPU: {"Yes" if text_use_gpu else "No"}, Small Model: {"Yes" if text_use_small else "No"}
\t- Coarse-to-Fine Inference:\t GPU: {"Yes" if coarse_use_gpu else "No"}, Small Model: {"Yes" if coarse_use_small else "No"}
\t- Fine-tuning:\t\t\t GPU: {"Yes" if fine_use_gpu else "No"}, Small Model: {"Yes" if fine_use_small else "No"}
\t- Codec:\t\t\t GPU: {"Yes" if codec_use_gpu else "No"}''')

        preload_models(
            text_use_gpu=text_use_gpu,
            text_use_small=text_use_small,
            coarse_use_gpu=coarse_use_gpu,
            coarse_use_small=coarse_use_small,
            fine_use_gpu=fine_use_gpu,
            fine_use_small=fine_use_small,
            codec_use_gpu=codec_use_gpu,
            force_reload=True,
        )
        # BUG FIX: only flag the models as loaded after preload_models()
        # returns; previously the flag was set first, so a failed load left
        # the manager (and callers checking models_loaded) believing the
        # models were ready.
        self.models_loaded = True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from load_config import load_config

# Module-level singleton: the application configuration is loaded once at
# import time so every importer shares the same dict instance.
config = load_config()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import os | ||
|
||
|
||
def create_base_filename(title, output_path, model, date):
    """Return the extension-less output path "audio__<model>__<title>__<date>".

    Callers append ".wav", ".png", or ".json" to the returned path.
    """
    basename = "audio__{}__{}__{}".format(model, title, date)
    return os.path.join(output_path, basename)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import datetime | ||
import os | ||
from create_base_filename import create_base_filename | ||
|
||
from get_date import get_date | ||
from save_waveform_plot import save_waveform_plot | ||
|
||
import torchaudio | ||
|
||
from models.tortoise.tortoise.api import TextToSpeech, MODELS_DIR | ||
from models.tortoise.tortoise.utils.audio import load_voices | ||
|
||
SAMPLE_RATE = 24_000 | ||
|
||
|
||
def generate_tortoise(text="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.",
                      voice='random',
                      preset='fast',
                      output_path='results/',
                      model_dir=MODELS_DIR,
                      candidates=3,
                      seed=None,
                      cvvp_amount=.0):
    """Synthesize ``text`` with Tortoise TTS and save the result(s).

    Returns a list of written file paths: one .wav per candidate when
    multiple candidates are produced, otherwise [wav_path, png_path] for the
    single result (plus a .json metadata sidecar on disk).
    """
    # Hoisted from the middle of the function body (idiom: imports belong at
    # the top of their scope, not between statements).
    import json

    os.makedirs(output_path, exist_ok=True)

    tts = TextToSpeech(models_dir=model_dir)

    filenames = []
    # "&"-separated voice names request a blend of several voices.
    voice_sel = voice.split('&') if '&' in voice else [voice]
    voice_samples, conditioning_latents = load_voices(voice_sel)

    gen, state = tts.tts_with_preset(text,
                                     k=candidates,
                                     voice_samples=voice_samples,
                                     conditioning_latents=conditioning_latents,
                                     preset=preset,
                                     use_deterministic_seed=seed,
                                     return_deterministic_state=True,
                                     cvvp_amount=cvvp_amount)

    # The deterministic state carries the seed actually used (overwrites the
    # `seed` argument, which may have been None).
    seed, _, _, _ = state

    if isinstance(gen, list):
        # Multiple candidates: save each as <voice>_<index>.wav.
        # NOTE(review): no waveform plot or metadata JSON is written in this
        # branch — only the single-result path below gets those.
        for j, g in enumerate(gen):
            filename = os.path.join(output_path, f'{voice}_{j}.wav')
            torchaudio.save(filename, g.squeeze(0).cpu(), SAMPLE_RATE)
            filenames.append(filename)
    else:
        audio_tensor = gen.squeeze(0).cpu()

        model = "tortoise"
        date = get_date()

        base_filename = create_base_filename(voice, output_path, model, date)
        filename = f'{base_filename}.wav'
        torchaudio.save(filename, audio_tensor, SAMPLE_RATE)
        audio_array = audio_tensor.t().numpy()
        # Plot the waveform using matplotlib
        filename_png = f'{base_filename}.png'
        save_waveform_plot(audio_array, filename_png)

        filename_json = f'{base_filename}.json'

        metadata = {
            "text": text,
            "voice": voice,
            "preset": preset,
            "candidates": candidates,
            "seed": seed,
            "cvvp_amount": cvvp_amount,
            "filename": filename,
            "filename_png": filename_png,
            "filename_json": filename_json,
        }
        with open(filename_json, 'w') as f:
            json.dump(metadata, f)

        # Must stay inside this branch: filename_png only exists here.
        filenames.extend((filename, filename_png))
    return filenames
|
||
def generate_tortoise_(prompt):
    """Convenience wrapper: one ultra-fast candidate, random voice, into outputs/."""
    options = {
        "voice": "random",
        "output_path": "outputs/",
        "preset": "ultra_fast",
        "candidates": 1,
        "cvvp_amount": 0.0,
    }
    return generate_tortoise(text=prompt, **options)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
import json
import os

import gradio as gr
from scipy.io.wavfile import write as write_wav

from config import config
from create_base_filename import create_base_filename
from gen_tortoise import generate_tortoise_
from get_date import get_date
from model_manager import model_manager
from models.bark.bark import SAMPLE_RATE, generate_audio
from models.bark.bark.generation import SUPPORTED_LANGS
from save_waveform_plot import save_waveform_plot
|
||
def generate(prompt, useHistory, language=None, speaker_id=0, text_temp=0.7, waveform_temp=0.7):
    """Generate Bark audio for ``prompt``; return [wav_path, png_path].

    When ``useHistory`` is truthy, conditions generation on the voice
    "<lang_code>_speaker_<speaker_id>", where ``language`` is an integer
    index into SUPPORTED_LANGS. Also writes a .json metadata sidecar.
    """
    # Lazy-load the Bark models on first use.
    if not model_manager.models_loaded:
        model_manager.reload_models(config)

    # generate audio from text
    history_prompt = f"{SUPPORTED_LANGS[language][1]}_speaker_{speaker_id}" if useHistory else None

    print("Generating:", prompt, "history_prompt:", history_prompt,
          "text_temp:", text_temp, "waveform_temp:", waveform_temp)
    audio_array = generate_audio(
        prompt, history_prompt=history_prompt, text_temp=text_temp, waveform_temp=waveform_temp)

    model = "bark"
    date = get_date()
    # FIX: ensure the output directory exists before writing — previously a
    # missing outputs/ directory made write_wav fail (gen_tortoise already
    # guards with os.makedirs).
    os.makedirs("outputs", exist_ok=True)
    base_filename = create_base_filename(
        history_prompt, "outputs", model, date)
    filename = f"{base_filename}.wav"
    write_wav(filename, SAMPLE_RATE, audio_array)
    filename_png = f"{base_filename}.png"
    save_waveform_plot(audio_array, filename_png)

    filename_json = f"{base_filename}.json"
    # Generate metadata for the audio file
    metadata = {
        "prompt": prompt,
        "language": SUPPORTED_LANGS[language][0] if useHistory else None,
        "speaker_id": speaker_id if useHistory else None,
        "history_prompt": history_prompt,
        "text_temp": text_temp,
        "waveform_temp": waveform_temp,
        "date": date,
        "filename": filename,
        "filename_png": filename_png,
        "filename_json": filename_json,
    }
    with open(filename_json, "w") as outfile:
        json.dump(metadata, outfile, indent=2)

    return [filename, filename_png]
|
||
|
||
def generate_multi(count=1):
    """Build a generate() variant that produces ``count`` clips per call.

    The returned function has the same signature as generate() and returns
    the flattened list of (wav, png) paths for all clips.
    """
    def gen(prompt, useHistory, language=None, speaker_id=0, text_temp=0.7, waveform_temp=0.7):
        collected = []
        for _ in range(count):
            # generate() returns [wav_path, png_path]; accumulate them flat.
            collected += generate(prompt, useHistory, language, speaker_id,
                                  text_temp=text_temp, waveform_temp=waveform_temp)
        return collected
    return gen
|
||
|
||
def toggleHistory(choice):
    """Show or hide the language and speaker-ID radios.

    ``choice`` is the boolean state of the "use a voice" checkbox; both
    radios track it together.
    """
    # FIX: replaced the `if choice == True:` anti-idiom and the duplicated
    # return branches with a single visibility flag.
    visible = bool(choice)
    return [gr.Radio.update(visible=visible), gr.Radio.update(visible=visible)]
|
||
|
||
def generation_tab_bark():
    """Build the "Generation (Bark)" Gradio tab.

    Wires a prompt textbox, optional voice (language + speaker ID) selection,
    temperature sliders, and 1/2/3-clip generate buttons to generate() /
    generate_multi(). Must be called inside an active gr.Blocks context.
    """
    with gr.Tab("Generation (Bark)"):
        useHistory = gr.Checkbox(
            label="Use a voice (History Prompt):", value=False)

        # Language radio reports the *index* into SUPPORTED_LANGS
        # (type="index"), which is what generate() expects.
        languages = [lang[0] for lang in SUPPORTED_LANGS]
        languageRadio = gr.Radio(languages, type="index", show_label=False,
                                 value="English", visible=False)

        speaker_ids = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
        speakerIdRadio = gr.Radio(speaker_ids, type="value",
                                  label="Speaker ID", value="0", visible=False)

        # Show the language and speakerId radios only when useHistory is checked
        useHistory.change(fn=toggleHistory, inputs=[useHistory], outputs=[
                          languageRadio, speakerIdRadio])

        with gr.Row():
            text_temp = gr.Slider(label="Text temperature",
                                  value=0.7, minimum=0.0, maximum=1.0, step=0.1)
            waveform_temp = gr.Slider(
                label="Waveform temperature", value=0.7, minimum=0.0, maximum=1.0, step=0.1)

        prompt = gr.Textbox(label="Prompt", lines=3,
                            placeholder="Enter text here...")

        # Input order must match generate()'s positional parameters.
        inputs = [
            prompt,
            useHistory,
            languageRadio,
            speakerIdRadio,
            text_temp,
            waveform_temp
        ]

        # Three audio/image slots; slots 2 and 3 start hidden and are
        # revealed by the show() handlers below.
        with gr.Row():
            audio_1 = gr.Audio(type="filepath", label="Generated audio")
            audio_2 = gr.Audio(
                type="filepath", label="Generated audio", visible=False)
            audio_3 = gr.Audio(
                type="filepath", label="Generated audio", visible=False)

        with gr.Row():
            image_1 = gr.Image(label="Waveform")
            image_2 = gr.Image(label="Waveform", visible=False)
            image_3 = gr.Image(label="Waveform", visible=False)

        outputs = [audio_1, image_1]
        outputs2 = [audio_2, image_2]
        outputs3 = [audio_3, image_3]
        # examples = [
        #     ["The quick brown fox jumps over the lazy dog."],
        #     ["To be or not to be, that is the question."],
        #     ["In a hole in the ground there lived a hobbit."],
        #     ["This text uses a history prompt, resulting in a more predictable voice.",
        #         True, "English", "0"],
        # ]

        with gr.Row():
            generate3_button = gr.Button("Generate 3")
            generate2_button = gr.Button("Generate 2")
            generate1_button = gr.Button("Generate", variant="primary")

        prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
        generate1_button.click(fn=generate, inputs=inputs, outputs=outputs)
        generate2_button.click(fn=generate_multi(2), inputs=inputs,
                               outputs=outputs + outputs2)
        generate3_button.click(fn=generate_multi(3), inputs=inputs,
                               outputs=outputs + outputs2 + outputs3)

        # Visibility updates for the six output slots: slot k is shown when
        # at least k clips were requested.
        def show(count): return [
            gr.Audio.update(visible=True),
            gr.Image.update(visible=True),
            gr.Audio.update(visible=count > 1),
            gr.Image.update(visible=count > 1),
            gr.Audio.update(visible=count > 2),
            gr.Image.update(visible=count > 2),
        ]

        # Each button gets a second click handler that only adjusts
        # visibility; it runs alongside the generation handler above.
        generate1_button.click(fn=lambda: show(
            1), outputs=outputs + outputs2 + outputs3)
        generate2_button.click(fn=lambda: show(
            2), outputs=outputs + outputs2 + outputs3)
        generate3_button.click(fn=lambda: show(
            3), outputs=outputs + outputs2 + outputs3)
|
||
|
||
def test():
    """Smoke-test generate() with and without a voice (history prompt)."""
    text_prompt = """
    Hello, my name is Suno. And, uh — and I like pizza. [laughs]
    But I also have other interests such as playing tic tac toe.
    """

    # BUG FIX: generate()'s third positional parameter is `language`, an
    # integer index into SUPPORTED_LANGS — the old code passed the string
    # "en_speaker_0", which would raise a TypeError inside generate()
    # (SUPPORTED_LANGS["en_speaker_0"]). Pass language index 0 (English)
    # and speaker_id 0, which yields the intended "en_speaker_0" prompt.
    generate(text_prompt, True, 0, 0)
    generate(text_prompt, False)
|
||
|
||
def generation_tab_tortoise():
    """Build the "Generation (Tortoise)" Gradio tab.

    Wires a prompt textbox and generate buttons to generate_tortoise_().
    Must be called inside an active gr.Blocks context.
    """
    with gr.Tab("Generation (Tortoise)"):
        prompt_tortoise = gr.Textbox(label="Prompt", lines=3,
                                     placeholder="Enter text here...")

        inputs = [
            prompt_tortoise
        ]

        # Three audio/image slots; slots 2 and 3 start hidden.
        with gr.Row():
            audio_1 = gr.Audio(type="filepath", label="Generated audio")
            audio_2 = gr.Audio(
                type="filepath", label="Generated audio", visible=False)
            audio_3 = gr.Audio(
                type="filepath", label="Generated audio", visible=False)

        with gr.Row():
            image_1 = gr.Image(label="Waveform")
            image_2 = gr.Image(label="Waveform", visible=False)
            image_3 = gr.Image(label="Waveform", visible=False)

        outputs = [audio_1, image_1]
        outputs2 = [audio_2, image_2]
        outputs3 = [audio_3, image_3]

        # NOTE(review): the "Generate 2"/"Generate 3" buttons are created
        # hidden and are never made visible here — presumably multi-candidate
        # generation for Tortoise is not wired up yet.
        with gr.Row():
            generate3_button = gr.Button("Generate 3", visible=False)
            generate2_button = gr.Button("Generate 2", visible=False)
            generate1_button = gr.Button("Generate", variant="primary")

        prompt_tortoise.submit(fn=generate_tortoise_,
                               inputs=inputs, outputs=outputs)
        generate1_button.click(fn=generate_tortoise_,
                               inputs=inputs, outputs=outputs)
        generate2_button.click(fn=generate_tortoise_, inputs=inputs,
                               outputs=outputs + outputs2)
        generate3_button.click(fn=generate_tortoise_, inputs=inputs,
                               outputs=outputs + outputs2 + outputs3)

        # Factory that binds `count` at definition time and returns a
        # no-argument handler producing visibility updates for all six slots.
        def show_closure(count):
            def show():
                return [
                    gr.Audio.update(visible=True),
                    gr.Image.update(visible=True),
                    gr.Audio.update(visible=count > 1),
                    gr.Image.update(visible=count > 1),
                    gr.Audio.update(visible=count > 2),
                    gr.Image.update(visible=count > 2),
                ]
            return show

        # Second click handler per button: adjusts output visibility only.
        generate1_button.click(fn=show_closure(
            1), outputs=outputs + outputs2 + outputs3)
        generate2_button.click(fn=show_closure(
            2), outputs=outputs + outputs2 + outputs3)
        generate3_button.click(fn=show_closure(
            3), outputs=outputs + outputs2 + outputs3)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import datetime | ||
|
||
|
||
def get_date():
    """Return the current local time as a filesystem-safe "YYYY-MM-DD_HH-MM-SS" string."""
    return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
Oops, something went wrong.