Skip to content

Commit

Permalink
Add Tortoise support, refactor code (#1)
Browse files Browse the repository at this point in the history
* simple tortoise generator
* basic refactor, splitting tabs and utils
* group all history tab functions in one file
  • Loading branch information
rsxdalv authored May 1, 2023
1 parent 3e02e07 commit 958fdab
Show file tree
Hide file tree
Showing 15 changed files with 757 additions and 504 deletions.
36 changes: 36 additions & 0 deletions BarkModelManager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from models.bark.bark.generation import preload_models


class BarkModelManager:
    """Tracks whether the Bark models are loaded and (re)loads them on demand."""

    def __init__(self, config):
        # Nothing is loaded until explicitly requested (or startup opts in).
        self.models_loaded = False
        if config["load_models_on_startup"]:
            self.reload_models(config)

    def reload_models(self, config):
        """Load (or reload) all Bark models according to ``config["model"]``."""
        # Prints "Reloading" on subsequent calls, "Loading" on the first.
        print(f"{'Rel' if self.models_loaded else 'L'}oading Bark models")
        self.models_loaded = True

        model_config = config["model"]
        flags = {
            key: model_config[key]
            for key in (
                "text_use_gpu",
                "text_use_small",
                "coarse_use_gpu",
                "coarse_use_small",
                "fine_use_gpu",
                "fine_use_small",
                "codec_use_gpu",
            )
        }

        def yn(flag):
            # Human-readable rendering of a boolean flag.
            return "Yes" if flag else "No"

        print(f'''\t- Text Generation:\t\t GPU: {yn(flags["text_use_gpu"])}, Small Model: {yn(flags["text_use_small"])}
\t- Coarse-to-Fine Inference:\t GPU: {yn(flags["coarse_use_gpu"])}, Small Model: {yn(flags["coarse_use_small"])}
\t- Fine-tuning:\t\t\t GPU: {yn(flags["fine_use_gpu"])}, Small Model: {yn(flags["fine_use_small"])}
\t- Codec:\t\t\t GPU: {yn(flags["codec_use_gpu"])}''')

        preload_models(force_reload=True, **flags)
4 changes: 4 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from load_config import load_config


# Shared module-level configuration, loaded once at import time so every
# tab/module sees the same dict.
# NOTE(review): load_config is project-local; its schema isn't visible here.
config = load_config()
5 changes: 5 additions & 0 deletions create_base_filename.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os


def create_base_filename(title, output_path, model, date):
    """Return the extension-less output path ``<output_path>/audio__<model>__<title>__<date>``."""
    basename = "audio__{}__{}__{}".format(model, title, date)
    return os.path.join(output_path, basename)
90 changes: 90 additions & 0 deletions gen_tortoise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import datetime
import os
from create_base_filename import create_base_filename

from get_date import get_date
from save_waveform_plot import save_waveform_plot

import torchaudio

from models.tortoise.tortoise.api import TextToSpeech, MODELS_DIR
from models.tortoise.tortoise.utils.audio import load_voices

# Output sample rate (Hz) used when saving Tortoise audio.
SAMPLE_RATE = 24_000


def generate_tortoise(text="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.",
                      voice='random',
                      preset='fast',
                      output_path='results/',
                      model_dir=MODELS_DIR,
                      candidates=3,
                      seed=None,
                      cvvp_amount=.0):
    """Generate speech for *text* with Tortoise TTS.

    Returns a flat list of output file paths: one ``.wav`` per candidate
    when multiple candidates are produced, otherwise
    ``[wav_path, png_path]`` for the single clip (with a ``.json``
    metadata sidecar written next to them).
    """
    import json  # stdlib; hoisted from mid-function so the dependency is visible up front

    os.makedirs(output_path, exist_ok=True)

    tts = TextToSpeech(models_dir=model_dir)

    filenames = []
    # str.split already returns [voice] when there is no '&' separator,
    # so no need to special-case single voices.
    voice_sel = voice.split('&')
    voice_samples, conditioning_latents = load_voices(voice_sel)

    gen, state = tts.tts_with_preset(text,
                                     k=candidates,
                                     voice_samples=voice_samples,
                                     conditioning_latents=conditioning_latents,
                                     preset=preset,
                                     use_deterministic_seed=seed,
                                     return_deterministic_state=True,
                                     cvvp_amount=cvvp_amount)

    # The deterministic state carries the seed actually used (relevant when seed=None).
    seed, _, _, _ = state

    if isinstance(gen, list):
        # Multiple candidates: save each clip.
        # NOTE(review): unlike the single-candidate path below, this branch
        # writes no .png waveform or .json metadata — confirm intentional.
        for j, g in enumerate(gen):
            filename = os.path.join(output_path, f'{voice}_{j}.wav')
            torchaudio.save(filename, g.squeeze(0).cpu(), SAMPLE_RATE)
            filenames.append(filename)
    else:
        audio_tensor = gen.squeeze(0).cpu()

        model = "tortoise"
        date = get_date()

        base_filename = create_base_filename(voice, output_path, model, date)
        filename = f'{base_filename}.wav'
        torchaudio.save(filename, audio_tensor, SAMPLE_RATE)
        audio_array = audio_tensor.t().numpy()
        # Plot the waveform using matplotlib
        filename_png = f'{base_filename}.png'
        save_waveform_plot(audio_array, filename_png)

        filename_json = f'{base_filename}.json'

        metadata = {
            "text": text,
            "voice": voice,
            "preset": preset,
            "candidates": candidates,
            "seed": seed,
            "cvvp_amount": cvvp_amount,
            "filename": filename,
            "filename_png": filename_png,
            "filename_json": filename_json,
        }
        with open(filename_json, 'w') as f:
            json.dump(metadata, f)

        filenames.extend((filename, filename_png))
    return filenames

def generate_tortoise_(prompt):
    """UI entry point: single-candidate Tortoise generation with fast defaults."""
    return generate_tortoise(
        text=prompt,
        voice="random",
        output_path="outputs/",
        preset="ultra_fast",
        candidates=1,
        cvvp_amount=.0,
    )

227 changes: 227 additions & 0 deletions generation_tab_bark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
from create_base_filename import create_base_filename
from gen_tortoise import generate_tortoise_
from get_date import get_date
from models.bark.bark import SAMPLE_RATE, generate_audio
from scipy.io.wavfile import write as write_wav
import json
from models.bark.bark.generation import SUPPORTED_LANGS
import gradio as gr
from save_waveform_plot import save_waveform_plot
from model_manager import model_manager
from config import config

def generate(prompt, useHistory, language=None, speaker_id=0, text_temp=0.7, waveform_temp=0.7):
    """Generate one Bark clip for *prompt*; returns ``[wav_path, png_path]``.

    When *useHistory* is set, a voice preset name is built from the
    SUPPORTED_LANGS index *language* and *speaker_id*.
    """
    if not model_manager.models_loaded:
        model_manager.reload_models(config)

    # generate audio from text
    if useHistory:
        history_prompt = f"{SUPPORTED_LANGS[language][1]}_speaker_{speaker_id}"
    else:
        history_prompt = None

    print("Generating:", prompt, "history_prompt:", history_prompt,
          "text_temp:", text_temp, "waveform_temp:", waveform_temp)
    audio_array = generate_audio(
        prompt, history_prompt=history_prompt, text_temp=text_temp, waveform_temp=waveform_temp)

    date = get_date()
    base_filename = create_base_filename(history_prompt, "outputs", "bark", date)

    filename = f"{base_filename}.wav"
    write_wav(filename, SAMPLE_RATE, audio_array)

    filename_png = f"{base_filename}.png"
    save_waveform_plot(audio_array, filename_png)

    filename_json = f"{base_filename}.json"
    # Sidecar metadata so a clip can be reproduced later.
    metadata = {
        "prompt": prompt,
        "language": SUPPORTED_LANGS[language][0] if useHistory else None,
        "speaker_id": speaker_id if useHistory else None,
        "history_prompt": history_prompt,
        "text_temp": text_temp,
        "waveform_temp": waveform_temp,
        "date": date,
        "filename": filename,
        "filename_png": filename_png,
        "filename_json": filename_json,
    }
    with open(filename_json, "w") as outfile:
        json.dump(metadata, outfile, indent=2)

    return [filename, filename_png]


def generate_multi(count=1):
    """Return a Gradio handler that runs ``generate`` *count* times and
    flattens the resulting (wav, png) pairs into one list."""
    def gen(prompt, useHistory, language=None, speaker_id=0, text_temp=0.7, waveform_temp=0.7):
        collected = []
        for _ in range(count):
            wav_path, png_path = generate(
                prompt, useHistory, language, speaker_id,
                text_temp=text_temp, waveform_temp=waveform_temp)
            collected += [wav_path, png_path]
        return collected
    return gen


def toggleHistory(choice):
    """Show/hide the language and speaker-id radios.

    *choice* is the value of the "Use a voice" checkbox (a bool), so it can
    drive visibility directly — no ``== True`` comparison or duplicated
    branches needed.
    """
    return [gr.Radio.update(visible=choice), gr.Radio.update(visible=choice)]


def generation_tab_bark():
    """Build the "Generation (Bark)" tab: prompt + voice controls wired to
    single- and multi-clip generation handlers."""
    with gr.Tab("Generation (Bark)"):
        useHistory = gr.Checkbox(
            label="Use a voice (History Prompt):", value=False)

        languages = [lang[0] for lang in SUPPORTED_LANGS]
        languageRadio = gr.Radio(languages, type="index", show_label=False,
                                 value="English", visible=False)

        speaker_ids = [str(i) for i in range(10)]
        speakerIdRadio = gr.Radio(speaker_ids, type="value",
                                  label="Speaker ID", value="0", visible=False)

        # Show the language and speakerId radios only when useHistory is checked
        useHistory.change(fn=toggleHistory, inputs=[useHistory],
                          outputs=[languageRadio, speakerIdRadio])

        with gr.Row():
            text_temp = gr.Slider(label="Text temperature",
                                  value=0.7, minimum=0.0, maximum=1.0, step=0.1)
            waveform_temp = gr.Slider(label="Waveform temperature",
                                      value=0.7, minimum=0.0, maximum=1.0, step=0.1)

        prompt = gr.Textbox(label="Prompt", lines=3,
                            placeholder="Enter text here...")

        inputs = [prompt, useHistory, languageRadio, speakerIdRadio,
                  text_temp, waveform_temp]

        # Up to three result slots; slots 2 and 3 start hidden.
        with gr.Row():
            audio_1 = gr.Audio(type="filepath", label="Generated audio")
            audio_2 = gr.Audio(type="filepath", label="Generated audio",
                               visible=False)
            audio_3 = gr.Audio(type="filepath", label="Generated audio",
                               visible=False)

        with gr.Row():
            image_1 = gr.Image(label="Waveform")
            image_2 = gr.Image(label="Waveform", visible=False)
            image_3 = gr.Image(label="Waveform", visible=False)

        outputs = [audio_1, image_1]
        outputs2 = [audio_2, image_2]
        outputs3 = [audio_3, image_3]
        all_outputs = outputs + outputs2 + outputs3

        with gr.Row():
            generate3_button = gr.Button("Generate 3")
            generate2_button = gr.Button("Generate 2")
            generate1_button = gr.Button("Generate", variant="primary")

        prompt.submit(fn=generate, inputs=inputs, outputs=outputs)
        generate1_button.click(fn=generate, inputs=inputs, outputs=outputs)
        generate2_button.click(fn=generate_multi(2), inputs=inputs,
                               outputs=outputs + outputs2)
        generate3_button.click(fn=generate_multi(3), inputs=inputs,
                               outputs=all_outputs)

        def show(count):
            # One visibility update per audio/image slot: slot k is shown
            # only when count > k.
            return [
                gr.Audio.update(visible=True),
                gr.Image.update(visible=True),
                gr.Audio.update(visible=count > 1),
                gr.Image.update(visible=count > 1),
                gr.Audio.update(visible=count > 2),
                gr.Image.update(visible=count > 2),
            ]

        # A second click handler per button adjusts slot visibility.
        generate1_button.click(fn=lambda: show(1), outputs=all_outputs)
        generate2_button.click(fn=lambda: show(2), outputs=all_outputs)
        generate3_button.click(fn=lambda: show(3), outputs=all_outputs)


def test():
    """Smoke-test Bark generation with and without a voice preset."""
    text_prompt = """
    Hello, my name is Suno. And, uh — and I like pizza. [laughs]
    But I also have other interests such as playing tic tac toe.
    """

    # BUG FIX: the old call passed the string "en_speaker_0" as the
    # `language` positional, but generate() expects an integer index into
    # SUPPORTED_LANGS (it builds the history prompt itself from language +
    # speaker_id), so the useHistory=True call raised a TypeError.
    generate(text_prompt, True, language=0, speaker_id=0)  # -> "en_speaker_0"
    generate(text_prompt, False)


def generation_tab_tortoise():
    """Build the "Generation (Tortoise)" tab: a prompt box wired to
    single-candidate Tortoise generation."""
    with gr.Tab("Generation (Tortoise)"):
        prompt_tortoise = gr.Textbox(label="Prompt", lines=3,
                                     placeholder="Enter text here...")

        inputs = [prompt_tortoise]

        # Up to three result slots; slots 2 and 3 start hidden.
        with gr.Row():
            audio_1 = gr.Audio(type="filepath", label="Generated audio")
            audio_2 = gr.Audio(type="filepath", label="Generated audio",
                               visible=False)
            audio_3 = gr.Audio(type="filepath", label="Generated audio",
                               visible=False)

        with gr.Row():
            image_1 = gr.Image(label="Waveform")
            image_2 = gr.Image(label="Waveform", visible=False)
            image_3 = gr.Image(label="Waveform", visible=False)

        outputs = [audio_1, image_1]
        outputs2 = [audio_2, image_2]
        outputs3 = [audio_3, image_3]
        all_outputs = outputs + outputs2 + outputs3

        with gr.Row():
            # Multi-candidate buttons exist for layout parity with the Bark
            # tab but are hidden for Tortoise.
            generate3_button = gr.Button("Generate 3", visible=False)
            generate2_button = gr.Button("Generate 2", visible=False)
            generate1_button = gr.Button("Generate", variant="primary")

        prompt_tortoise.submit(fn=generate_tortoise_,
                               inputs=inputs, outputs=outputs)
        generate1_button.click(fn=generate_tortoise_,
                               inputs=inputs, outputs=outputs)
        generate2_button.click(fn=generate_tortoise_, inputs=inputs,
                               outputs=outputs + outputs2)
        generate3_button.click(fn=generate_tortoise_, inputs=inputs,
                               outputs=all_outputs)

        def show_closure(count):
            # Returns a zero-arg callback revealing the first `count` slots.
            def show():
                return [
                    gr.Audio.update(visible=True),
                    gr.Image.update(visible=True),
                    gr.Audio.update(visible=count > 1),
                    gr.Image.update(visible=count > 1),
                    gr.Audio.update(visible=count > 2),
                    gr.Image.update(visible=count > 2),
                ]
            return show

        # A second click handler per button adjusts slot visibility.
        generate1_button.click(fn=show_closure(1), outputs=all_outputs)
        generate2_button.click(fn=show_closure(2), outputs=all_outputs)
        generate3_button.click(fn=show_closure(3), outputs=all_outputs)

6 changes: 6 additions & 0 deletions get_date.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import datetime


def get_date():
    """Current local time as a filename-safe ``YYYY-MM-DD_HH-MM-SS`` string."""
    return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
Loading

0 comments on commit 958fdab

Please sign in to comment.