Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Tortoise-upgrade #45

Merged
merged 3 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def reload_config_and_restart_ui():
else default_config
)

with gr.Blocks(css=full_css) as demo:
with gr.Blocks(
css=full_css,
title="TTS Generation WebUI",
) as demo:
gr.Markdown("# TTS Generation WebUI (Bark, MusicGen, Tortoise)")
with gr.Tabs() as tabs:
register_use_as_history_button = generation_tab_bark(tabs)
Expand Down
2 changes: 1 addition & 1 deletion src/css/css.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from src.bark.bark_css import bark_css
from src.css.material_symbols_css import material_symbols_css
from src.history_tab.history_css import history_css
from src.tortoise.generation_tab_tortoise import css_tortoise
from src.tortoise.css_tortoise import css_tortoise
from src.musicgen.musicgen_css import musicgen_css

full_css = ""
Expand Down
7 changes: 6 additions & 1 deletion src/musicgen/musicgen_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

import json
from typing import Optional
from importlib.metadata import version

AUDIOCRAFT_VERSION = version("audiocraft")


class MusicGenGeneration(TypedDict):
Expand Down Expand Up @@ -53,6 +56,7 @@ def generate_and_save_metadata(
"_version": "0.0.1",
"_hash_version": "0.0.3",
"_type": "musicgen",
"_audiocraft_version": AUDIOCRAFT_VERSION,
"models": {},
"prompt": prompt,
"hash": audio_array_to_sha256(audio_array),
Expand Down Expand Up @@ -209,6 +213,7 @@ def predict(params: MusicGenGeneration, melody_in: Optional[Tuple[int, np.ndarra
def generation_tab_musicgen():
with gr.Tab("MusicGen") as tab:
musicgen_atom.render()
gr.Markdown(f"""Audiocraft version: {AUDIOCRAFT_VERSION}""")
with gr.Row():
with gr.Column():
text = gr.Textbox(
Expand Down Expand Up @@ -259,7 +264,7 @@ def generation_tab_musicgen():
interactive=True,
step=0.1,
)
seed, set_old_seed_button = setup_seed_ui_musicgen()
seed, set_old_seed_button, _ = setup_seed_ui_musicgen()

with gr.Column():
output = gr.Audio(
Expand Down
10 changes: 9 additions & 1 deletion src/musicgen/setup_seed_ui_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,12 @@ def setup_seed_ui_musicgen():
)

set_old_seed_button.style(size="sm")
return seed_input, set_old_seed_button

# Register a click handler on `set_old_seed_button`: on click, the value of
# `seed_cache` is passed to the lambda, whose `gr.Number.update(value=...)`
# result is applied to `seed_input`, i.e. the cached value is copied into the
# seed field.
# `set_old_seed_button` and `seed_input` come from the enclosing function's
# scope. `seed_cache` is supplied by the caller; presumably a Gradio component
# holding a previously used seed -- TODO confirm against callers.
def link_seed_cache(seed_cache):
set_old_seed_button.click(
fn=lambda x: gr.Number.update(value=x),
inputs=seed_cache,
outputs=seed_input,
)

return seed_input, set_old_seed_button, link_seed_cache
6 changes: 6 additions & 0 deletions src/tortoise/css_tortoise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# CSS fragment exported as a plain string (imported by src/css/css.py).
# Elements with the `.btn-sm` class get a minimum width of 2em and are not
# allowed to grow inside a flex container; `!important` lets both rules
# override other declarations of normal importance.
css_tortoise = """
.btn-sm {
min-width: 2em !important;
flex-grow: 0 !important;
}
"""
46 changes: 34 additions & 12 deletions src/tortoise/gen_tortoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,39 @@

from models.tortoise.tortoise.api import TextToSpeech, MODELS_DIR
from models.tortoise.tortoise.utils.audio import load_voices
import gradio as gr

SAMPLE_RATE = 24_000
OUTPUT_PATH = "outputs/"

MODEL = None


def get_tts():
    """Return the module-wide TextToSpeech instance.

    The instance is constructed with ``models_dir=MODELS_DIR`` the first
    time this function is called and stored in the ``MODEL`` global; later
    calls hand back the stored object instead of building a new one.
    """
    global MODEL
    # Fast path: something is already stored, reuse it.
    if MODEL is not None:
        return MODEL
    MODEL = TextToSpeech(models_dir=MODELS_DIR)
    return MODEL


def generate_tortoise(
text="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.",
voice="random",
preset="fast",
model_dir=MODELS_DIR,
candidates=3,
seed=None,
cvvp_amount=0.0,
):
os.makedirs(OUTPUT_PATH, exist_ok=True)

tts = TextToSpeech(models_dir=model_dir)

filenames = []
datas = []
voice_sel = voice.split("&") if "&" in voice else [voice]
voice_samples, conditioning_latents = load_voices(voice_sel)

if seed == -1:
seed = None

tts = get_tts()
gen, state = tts.tts_with_preset(
text,
k=candidates,
Expand All @@ -46,16 +57,14 @@ def generate_tortoise(

if isinstance(gen, list):
for j, g in enumerate(gen):
process_gen(
text, voice, preset, candidates, seed, cvvp_amount, filenames, g, j
)
process_gen(text, voice, preset, candidates, seed, cvvp_amount, datas, g, j)
else:
process_gen(text, voice, preset, candidates, seed, cvvp_amount, filenames, gen)
return filenames
process_gen(text, voice, preset, candidates, seed, cvvp_amount, datas, gen)
return datas


def process_gen(
text, voice, preset, candidates, seed, cvvp_amount, filenames, gen, j=0
text, voice, preset, candidates, seed, cvvp_amount, datas, gen, j=0
):
audio_tensor = gen.squeeze(0).cpu()

Expand All @@ -74,19 +83,32 @@ def process_gen(
filename_json = f"{base_filename}.json"

metadata = {
"_version": "0.0.1",
"_type": model,
"date": date,
"text": text,
"voice": voice,
"preset": preset,
"candidates": candidates,
"seed": seed,
"seed": str(seed),
"cvvp_amount": cvvp_amount,
}
import json

with open(filename_json, "w") as f:
json.dump(metadata, f)

filenames.extend((filename, filename_png))
history_bundle_name_data = os.path.dirname(filename)

datas.extend(
(
filename,
filename_png,
gr.Button.update(value="Save to favorites", visible=True),
seed,
history_bundle_name_data,
)
)


def generate_tortoise_n(n):
Expand Down
Loading