diff --git a/server.py b/server.py index 79e201f8..0f35ae3c 100644 --- a/server.py +++ b/server.py @@ -63,7 +63,10 @@ def reload_config_and_restart_ui(): else default_config ) -with gr.Blocks(css=full_css) as demo: +with gr.Blocks( + css=full_css, + title="TTS Generation WebUI", +) as demo: gr.Markdown("# TTS Generation WebUI (Bark, MusicGen, Tortoise)") with gr.Tabs() as tabs: register_use_as_history_button = generation_tab_bark(tabs) diff --git a/src/css/css.py b/src/css/css.py index 1f90086b..28250f0e 100644 --- a/src/css/css.py +++ b/src/css/css.py @@ -1,7 +1,7 @@ from src.bark.bark_css import bark_css from src.css.material_symbols_css import material_symbols_css from src.history_tab.history_css import history_css -from src.tortoise.generation_tab_tortoise import css_tortoise +from src.tortoise.css_tortoise import css_tortoise from src.musicgen.musicgen_css import musicgen_css full_css = "" diff --git a/src/musicgen/musicgen_tab.py b/src/musicgen/musicgen_tab.py index f50a4cb7..31acbe3d 100644 --- a/src/musicgen/musicgen_tab.py +++ b/src/musicgen/musicgen_tab.py @@ -21,6 +21,9 @@ import json from typing import Optional +from importlib.metadata import version + +AUDIOCRAFT_VERSION = version("audiocraft") class MusicGenGeneration(TypedDict): @@ -53,6 +56,7 @@ def generate_and_save_metadata( "_version": "0.0.1", "_hash_version": "0.0.3", "_type": "musicgen", + "_audiocraft_version": AUDIOCRAFT_VERSION, "models": {}, "prompt": prompt, "hash": audio_array_to_sha256(audio_array), @@ -209,6 +213,7 @@ def predict(params: MusicGenGeneration, melody_in: Optional[Tuple[int, np.ndarra def generation_tab_musicgen(): with gr.Tab("MusicGen") as tab: musicgen_atom.render() + gr.Markdown(f"""Audiocraft version: {AUDIOCRAFT_VERSION}""") with gr.Row(): with gr.Column(): text = gr.Textbox( @@ -259,7 +264,7 @@ def generation_tab_musicgen(): interactive=True, step=0.1, ) - seed, set_old_seed_button = setup_seed_ui_musicgen() + seed, set_old_seed_button, _ = setup_seed_ui_musicgen() with gr.Column(): output = gr.Audio( diff --git a/src/musicgen/setup_seed_ui_musicgen.py b/src/musicgen/setup_seed_ui_musicgen.py index 9a660117..1a595c30 100644 --- a/src/musicgen/setup_seed_ui_musicgen.py +++ b/src/musicgen/setup_seed_ui_musicgen.py @@ -20,4 +20,12 @@ def setup_seed_ui_musicgen(): ) set_old_seed_button.style(size="sm") - return seed_input, set_old_seed_button + + def link_seed_cache(seed_cache): + set_old_seed_button.click( + fn=lambda x: gr.Number.update(value=x), + inputs=seed_cache, + outputs=seed_input, + ) + + return seed_input, set_old_seed_button, link_seed_cache diff --git a/src/tortoise/css_tortoise.py b/src/tortoise/css_tortoise.py new file mode 100644 index 00000000..db5ef6d7 --- /dev/null +++ b/src/tortoise/css_tortoise.py @@ -0,0 +1,6 @@ +css_tortoise = """ +.btn-sm { + min-width: 2em !important; + flex-grow: 0 !important; +} +""" diff --git a/src/tortoise/gen_tortoise.py b/src/tortoise/gen_tortoise.py index 605cb2ac..b4dfcbd3 100644 --- a/src/tortoise/gen_tortoise.py +++ b/src/tortoise/gen_tortoise.py @@ -9,28 +9,39 @@ from models.tortoise.tortoise.api import TextToSpeech, MODELS_DIR from models.tortoise.tortoise.utils.audio import load_voices +import gradio as gr SAMPLE_RATE = 24_000 OUTPUT_PATH = "outputs/" +MODEL = None + + +def get_tts(): + global MODEL + if MODEL is None: + MODEL = TextToSpeech(models_dir=MODELS_DIR) + return MODEL + def generate_tortoise( text="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.", voice="random", preset="fast", - model_dir=MODELS_DIR, candidates=3, seed=None, cvvp_amount=0.0, ): os.makedirs(OUTPUT_PATH, exist_ok=True) - tts = TextToSpeech(models_dir=model_dir) - - filenames = [] + datas = [] voice_sel = voice.split("&") if "&" in voice else [voice] voice_samples, conditioning_latents = load_voices(voice_sel) + if seed == -1: + seed = None + + tts = get_tts() gen, state = tts.tts_with_preset( text, k=candidates, @@ -46,16 +57,14 @@ def generate_tortoise( if isinstance(gen, list): for j, g in enumerate(gen): - process_gen( - text, voice, preset, candidates, seed, cvvp_amount, filenames, g, j - ) + process_gen(text, voice, preset, candidates, seed, cvvp_amount, datas, g, j) else: - process_gen(text, voice, preset, candidates, seed, cvvp_amount, filenames, gen) - return filenames + process_gen(text, voice, preset, candidates, seed, cvvp_amount, datas, gen) + return datas def process_gen( - text, voice, preset, candidates, seed, cvvp_amount, filenames, gen, j=0 + text, voice, preset, candidates, seed, cvvp_amount, datas, gen, j=0 ): audio_tensor = gen.squeeze(0).cpu() @@ -74,11 +83,14 @@ def process_gen( filename_json = f"{base_filename}.json" metadata = { + "_version": "0.0.1", + "_type": model, + "date": date, "text": text, "voice": voice, "preset": preset, "candidates": candidates, - "seed": seed, + "seed": str(seed), "cvvp_amount": cvvp_amount, } import json @@ -86,7 +98,17 @@ def process_gen( with open(filename_json, "w") as f: json.dump(metadata, f) - filenames.extend((filename, filename_png)) + history_bundle_name_data = os.path.dirname(filename) + + datas.extend( + ( + filename, + filename_png, + gr.Button.update(value="Save to favorites", visible=True), + seed, + history_bundle_name_data, + ) + ) def generate_tortoise_n(n): diff --git a/src/tortoise/generation_tab_tortoise.py b/src/tortoise/generation_tab_tortoise.py index e758cae0..790fcbbc 100644 --- a/src/tortoise/generation_tab_tortoise.py +++ b/src/tortoise/generation_tab_tortoise.py @@ -1,33 +1,175 @@ -from src.tortoise.gen_tortoise import generate_tortoise_n -from models.tortoise.tortoise.utils.audio import get_voices +from typing import Any +import numpy as np +import torch +import torchaudio +from src.bark.split_text_functions import split_by_lines +from src.history_tab.save_to_favorites import save_to_favorites +from src.musicgen.setup_seed_ui_musicgen import setup_seed_ui_musicgen +from src.tortoise.gen_tortoise import generate_tortoise +from models.tortoise.tortoise.utils.audio import get_voices +from src.css.css import full_css import gradio as gr +MAX_OUTPUTS = 9 + + +class TortoiseOutputRow: + def __init__(self, audio, image, save_button, seed, bundle_name): + self.audio: gr.Audio = audio + self.image: gr.Image = image + self.save_button: gr.Button = save_button + self.seed: gr.State = seed + self.bundle_name: gr.State = bundle_name + + def to_list(self): + return [ + self.audio, + self.image, + self.save_button, + self.seed, + self.bundle_name, + ] + + @staticmethod + def from_list(components): + return TortoiseOutputRow( + audio=components[0], + image=components[1], + save_button=components[2], + seed=components[3], + bundle_name=components[4], + ) + + # def __iter__(self): + # return iter(self.to_list()) + + +def create_components(index): + with gr.Column(visible=index == 0) as col: + audio = gr.Audio( + type="filepath", label="Generated audio", elem_classes="tts-audio" + ) + image = gr.Image(label="Waveform", shape=(None, 100), elem_classes="tts-image") # type: ignore + with gr.Row(): + save_button = gr.Button("Save to favorites", visible=False) + seed = gr.State() # type: ignore + bundle_name = gr.State() # type: ignore + + save_button.click( + fn=save_to_favorites, + inputs=[bundle_name], + outputs=[save_button], + ) + + return ( + TortoiseOutputRow(audio, image, save_button, seed, bundle_name).to_list(), + col, + seed, + ) + + +def generate_tortoise_long(outs: list[TortoiseOutputRow], count: int): + def gen( + prompt_raw, + voice="random", + preset="ultra_fast", + seed=None, + cvvp_amount=0.0, + split_prompt=False, + ): + prompts = split_by_lines(prompt_raw) if split_prompt else [prompt_raw] + audio_pieces = [[] for _ in range(count)] + + for i, prompt in enumerate(prompts): + datas = generate_tortoise( + text=prompt, + voice=voice, + preset=preset, + seed=seed, + cvvp_amount=cvvp_amount, + candidates=count, + ) + for i in range(count): + yield { + outs[i].audio: datas[5 * i], + outs[i].image: datas[5 * i + 1], + outs[i].save_button: gr.Button.update(visible=True), + outs[i].seed: datas[5 * i + 3], + outs[i].bundle_name: datas[5 * i + 4], + } + # accumulate audio filenames + for i in range(count): + audio_filename = datas[5 * i] + audio_tensor, _ = torchaudio.load(audio_filename) + audio_array = audio_tensor.t().numpy() + audio_pieces[i].append(audio_array) + + # if there is only one prompt, then we don't need to concatenate + if len(prompts) == 1: + return {} + + # concatenate audio pieces + def concat_and_save_pieces(audio_pieces): + for i in range(count): + audio_pieces[i] = np.concatenate(audio_pieces[i]) + audio_tensor = torch.from_numpy(audio_pieces[i]).t() + # TEMP - before long generations are supported + audio_filename = datas[5 * i] + "_long.wav" + torchaudio.save(audio_filename, audio_tensor, 22050) + yield audio_filename + + # get audio filenames + audio_filenames = list(concat_and_save_pieces(audio_pieces)) + yield {outs[i].audio: audio_filenames[i] for i in range(count)} + return {} + + return gen + def generation_tab_tortoise(): - with gr.Tab("Generation (Tortoise)"): - prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter text here...") + with gr.Tab("Tortoise TTS"): + inputs, output_rows = tortoise_ui() + total_columns = len(output_rows) with gr.Row(): - # with gr.Box(): - # gr.Markdown("### Voice") - with gr.Row(): - voice = gr.Dropdown( - choices=["random"] + list(get_voices()), - value="random", - # show_label=False, - label="Voice", + for i in range(total_columns): + target_count = total_columns - i + generate_button( + text=f"Generate {target_count if target_count > 1 else ''}", + count=target_count, + variant="primary" if target_count == 1 else "secondary", + inputs=inputs, + output_rows=output_rows, + total_columns=total_columns, ) - # voice.style(container=False) - # reload_voices = gr.Button("🔁", elem_classes="btn-sm") - # reload_voices.style(size="sm") - # def reload_voices_fn(): - # choices = - # print(choices) - # return [ - # gr.Dropdown.update(choices=choices), - # ] - # reload_voices.click(fn=reload_voices_fn, outputs=[voice]) + + +def tortoise_ui(): + with gr.Row(): + with gr.Column(): + with gr.Box(): + gr.Markdown("Voice") + with gr.Row(): + voice = gr.Dropdown( + choices=["random"] + list(get_voices()), + value="random", + show_label=False, + ) + voice.style(container=False) + reload_voices = gr.Button( + "refresh", + elem_classes="btn-sm material-symbols-outlined", + ) + reload_voices.style(size="sm") + + def reload_voices_fn(): + choices = ["random"] + list(get_voices()) + return [ + gr.Dropdown.update(choices=choices), + ] + + reload_voices.click(fn=reload_voices_fn, outputs=[voice]) preset = gr.Dropdown( label="Preset", choices=[ @@ -38,83 +180,66 @@ def generation_tab_tortoise(): ], value="ultra_fast", ) - # Args: - # seed (int): The desired seed. Value must be within the inclusive range - # `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError - # is raised. Negative inputs are remapped to positive values with the formula - # `0xffff_ffff_ffff_ffff + seed`. - seed = gr.Textbox( - label="Seed", - lines=1, - placeholder="Enter seed here...", - value="None", - visible=False, - ) - cvvp_amount = gr.Slider( - label="CVVP Amount", value=0.0, minimum=0.0, maximum=1.0, step=0.1 - ) + with gr.Column(): + cvvp_amount = gr.Slider( + label="CVVP Amount", value=0.0, minimum=0.0, maximum=1.0, step=0.1 + ) + seed, _, link_seed_cache = setup_seed_ui_musicgen() - inputs = [prompt, voice, preset, seed, cvvp_amount] + split_prompt = gr.Checkbox(label="Split prompt by lines", value=False) - with gr.Row(): - audio_1 = gr.Audio(type="filepath", label="Generated audio") - audio_2 = gr.Audio(type="filepath", label="Generated audio", visible=False) - audio_3 = gr.Audio(type="filepath", label="Generated audio", visible=False) + prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter text here...") - with gr.Row(): - image_1 = gr.Image(label="Waveform") - image_2 = gr.Image(label="Waveform", visible=False) - image_3 = gr.Image(label="Waveform", visible=False) + with gr.Row(): + output_rows = [create_components(i) for i in range(MAX_OUTPUTS)] - outputs = [audio_1, image_1] - outputs2 = [audio_2, image_2] - outputs3 = [audio_3, image_3] + link_seed_cache(seed_cache=output_rows[0][2]) - with gr.Row(): - generate3_button = gr.Button("Generate 3") - generate2_button = gr.Button("Generate 2") - generate1_button = gr.Button("Generate", variant="primary") + inputs = [prompt, voice, preset, seed, cvvp_amount, split_prompt] + return inputs, output_rows - prompt.submit(fn=generate_tortoise_n(1), inputs=inputs, outputs=outputs) - generate1_button.click( - fn=generate_tortoise_n(1), inputs=inputs, outputs=outputs - ) - generate2_button.click( - fn=generate_tortoise_n(2), inputs=inputs, outputs=outputs + outputs2 + +def generate_button(text, count, variant, inputs, output_rows, total_columns): + def get_all_components(count): + return [i for i, _, _ in output_rows[:count]] + + def get_output_list(count): + return sum(get_all_components(count), []) + + def get_all_outs(count): + return [TortoiseOutputRow.from_list(i) for i in get_all_components(count)] + + def hide_all_save_buttons(list_of_outs: list[TortoiseOutputRow]): + return lambda: { + outs.save_button: gr.Button.update(visible=False) for outs in list_of_outs + } + + def show(count): + return [gr.Column.update(visible=count > i) for i in range(total_columns)] + + output_cols: list[Any] = [col for _, col, _ in output_rows] + return ( + gr.Button(text, variant=variant) + .click(fn=lambda: show(count), outputs=output_cols) + .then( + fn=hide_all_save_buttons(get_all_outs(count)), + outputs=get_output_list(count), ) - generate3_button.click( - fn=generate_tortoise_n(3), + .then( + fn=generate_tortoise_long( + get_all_outs(count), + count, + ), inputs=inputs, - outputs=outputs + outputs2 + outputs3, + outputs=get_output_list(count), ) + ) - def show_closure(count): - def show(): - return [ - gr.Audio.update(visible=True), - gr.Image.update(visible=True), - gr.Audio.update(visible=count > 1), - gr.Image.update(visible=count > 1), - gr.Audio.update(visible=count > 2), - gr.Image.update(visible=count > 2), - ] - - return show - - generate1_button.click( - fn=show_closure(1), outputs=outputs + outputs2 + outputs3 - ) - generate2_button.click( - fn=show_closure(2), outputs=outputs + outputs2 + outputs3 - ) - generate3_button.click( - fn=show_closure(3), outputs=outputs + outputs2 + outputs3 - ) +if __name__ == "__main__": + with gr.Blocks(css=full_css) as demo: + generation_tab_tortoise() -css_tortoise = """ -.btn-sm { - min-width: 2em !important; - flex-grow: 0 !important; -} -""" + demo.launch( + enable_queue=True, + )