add 48k model
haoheliu committed Aug 27, 2023
1 parent 1740a57 commit c894099
Showing 9 changed files with 458 additions and 74 deletions.
41 changes: 20 additions & 21 deletions README.md
@@ -4,25 +4,20 @@

This repo currently supports Text-to-Audio (including Music) and Text-to-Speech generation.

* [TODO](#todo)
* [Web APP](#web-app)
* [Commandline Usage](#commandline-usage)
+ [Installation](#installation)
+ [Run the model in commandline](#run-the-model-in-commandline)
* [Random Seed Matters](#random-seed-matters)
* [Pretrained Models](#pretrained-models)
* [Other options](#other-options)
* [Cite this work](#cite-this-work)

<hr>

## Change Log
- 2023-08-27: Add two new checkpoints!
- 🌟 **48kHz AudioLDM model**: Now we support high-fidelity audio generation! Use this checkpoint simply by setting `--model_name audioldm_48k`.
- **16kHz improved AudioLDM model**: Trained with more data and an optimized model architecture.

## TODO
- [x] Add the text-to-speech checkpoint
- [ ] Add the text-to-audio checkpoint that does not use FLAN-T5 Cross Attention
- [ ] Open-source the AudioLDM training code.
- [ ] Support the generation of longer audio (> 10s)
- [ ] Optimizing the inference speed of the model.
- [x] Support the generation of longer audio (> 10s)
- [x] Optimizing the inference speed of the model.
- [ ] Integration with the Diffusers library
- [ ] Add the style-transfer and inpainting code for the audioldm_48k checkpoint (PR welcomed, same logic as [AudioLDMv1](https://github.com/haoheliu/AudioLDM))

## Web APP

@@ -90,19 +85,21 @@ You can choose the model checkpoint by setting "model_name":

```shell
# CUDA
audioldm2 --model_name "audioldm2-full-large-1150k" --device cuda -t "Musical constellations twinkling in the night sky, forming a cosmic melody."
audioldm2 --model_name "audioldm_48k" --device cuda -t "Musical constellations twinkling in the night sky, forming a cosmic melody."

# MPS
audioldm2 --model_name "audioldm2-full-large-1150k" --device mps -t "Musical constellations twinkling in the night sky, forming a cosmic melody."
audioldm2 --model_name "audioldm_48k" --device mps -t "Musical constellations twinkling in the night sky, forming a cosmic melody."
```

We have the following checkpoints you can choose from:

1. **audioldm2-full** (default): Generate both sound effect and music generation.
2. **audioldm2-full-large-1150k**: Larger version of audioldm2-full.
3. **audioldm2-music-665k**: Music generation.
4. **audioldm2-speech-gigaspeech** (default for TTS): Text-to-Speech, trained on GigaSpeech Dataset.
5. **audioldm2-speech-ljspeech**: Text-to-Speech, trained on LJSpeech Dataset.
1. **audioldm_48k** (default): This checkpoint generates high-fidelity sound effects and music (see the Python example below this list).
2. **audioldm2-full**: Generates both sound effects and music with the AudioLDM2 architecture.
3. **audioldm_16k_crossattn_t5**: The improved version of [AudioLDM 1.0](https://github.com/haoheliu/AudioLDM).
4. **audioldm2-full-large-1150k**: Larger version of audioldm2-full.
5. **audioldm2-music-665k**: Music generation.
6. **audioldm2-speech-gigaspeech** (default for TTS): Text-to-Speech, trained on the GigaSpeech dataset.
7. **audioldm2-speech-ljspeech**: Text-to-Speech, trained on the LJSpeech dataset.
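If you prefer calling the pipeline from Python rather than the `audioldm2` command, the minimal sketch below mirrors what `app.py` does. It assumes `build_model` and `text_to_audio` can be imported from `audioldm2.pipeline` (where they are defined in this repo) and that `soundfile` is installed for saving; the prompt, seed, and output path are only examples.

```python
import soundfile as sf

from audioldm2.pipeline import build_model, text_to_audio

# Load the new 48 kHz checkpoint; any model name from the list above should work.
latent_diffusion = build_model(model_name="audioldm_48k")

waveform = text_to_audio(
    latent_diffusion=latent_diffusion,
    text="Musical constellations twinkling in the night sky, forming a cosmic melody.",
    seed=45,
    duration=10,                # seconds; exposed on the CLI as -dur/--duration
    guidance_scale=3.5,
    n_candidate_gen_per_text=3,
    latent_t_per_second=12.8,   # 12.8 for audioldm_48k, 25.6 for the 16 kHz models
)  # shape: [batch, 1, samples]

sf.write("output.wav", waveform[0][0], 48000)  # use 16000 for the 16 kHz checkpoints
```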

We currently support 3 devices:
- cpu
@@ -112,7 +109,7 @@ We currently support 3 devices:
## Other options
```shell
usage: audioldm2 [-h] [-t TEXT] [-tl TEXT_LIST] [-s SAVE_PATH]
[--model_name {audioldm2-full,audioldm2-music-665k,audioldm2-full-large-1150k,audioldm2-speech-ljspeech,audioldm2-speech-gigaspeech}] [-d DEVICE]
[--model_name {audioldm_48k,audioldm_16k_crossattn_t5,audioldm2-full,audioldm2-music-665k,audioldm2-full-large-1150k,audioldm2-speech-ljspeech,audioldm2-speech-gigaspeech}] [-d DEVICE]
[-b BATCHSIZE] [--ddim_steps DDIM_STEPS] [-gs GUIDANCE_SCALE] [-n N_CANDIDATE_GEN_PER_TEXT]
[--seed SEED]

@@ -132,6 +129,8 @@ We currently support 3 devices:
-b BATCHSIZE, --batchsize BATCHSIZE
Generate how many samples at the same time
--ddim_steps DDIM_STEPS
The sampling step for DDIM
-dur DURATION, --duration DURATION
The duration of the samples
-gs GUIDANCE_SCALE, --guidance_scale GUIDANCE_SCALE
Guidance scale (Large => better quality and relevance to the text; Small => better diversity)
46 changes: 30 additions & 16 deletions app.py
@@ -8,12 +8,14 @@

os.environ["TOKENIZERS_PARALLELISM"] = "true"

default_checkpoint="audioldm2-full"
# default_checkpoint="audioldm2-full"
default_checkpoint="audioldm_48k"
audioldm = None
current_model_name = None

def text2audio(
text,
duration,
guidance_scale,
random_seed,
n_candidates,
@@ -25,19 +27,26 @@ def text2audio(
if audioldm is None or model_name != current_model_name:
audioldm = build_model(model_name=model_name)
current_model_name = model_name
audioldm = torch.compile(audioldm)

# audioldm = torch.compile(audioldm)
# print(text, length, guidance_scale)
if("48k" in model_name):
latent_t_per_second=12.8
sample_rate=48000
else:
latent_t_per_second=25.6
sample_rate=16000

waveform = text_to_audio(
latent_diffusion=audioldm,
text=text,
seed=random_seed,
duration=10,
duration=duration,
guidance_scale=guidance_scale,
n_candidate_gen_per_text=int(n_candidates),
latent_t_per_second=latent_t_per_second,
) # [bs, 1, samples]
waveform = [
gr.make_waveform((16000, wave[0]), bg_image="bg.png") for wave in waveform
gr.make_waveform((sample_rate, wave[0]), bg_image="bg.png") for wave in waveform
]
# waveform = [(16000, np.random.randn(16000)), (16000, np.random.randn(16000))]
if len(waveform) == 1:
@@ -228,9 +237,9 @@ def text2audio(
value=45,
label="Change this value (any integer number) will lead to a different generation result.",
)
# duration = gr.Slider(
# 10, 10, value=10, step=2.5, label="Duration (seconds)"
# )
duration = gr.Slider(
5, 15, value=10, step=2.5, label="Duration (seconds)"
)
guidance_scale = gr.Slider(
0,
6,
@@ -245,9 +254,9 @@ def text2audio(
step=1,
label="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
)
# model_name = gr.Dropdown(
# ["audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full","audioldm-s-full-v2", "audioldm-s-full", "audioldm-l-full"], value="audioldm-m-full", label="Choose the model to use. audioldm-m-text-ft and audioldm-s-text-ft are recommanded. -s- means small, -m- means medium and -l- means large",
# )
model_name = gr.Dropdown(
["audioldm_48k", "audioldm_crossattn_flant5", "audioldm2-full"], value="audioldm_48k",
)
############# Output
# outputs=gr.Audio(label="Output", type="numpy")
outputs = gr.Video(label="Output", elem_id="output-video")
@@ -270,7 +279,7 @@ def text2audio(
# textbox, duration, guidance_scale, seed, n_candidates, model_name], outputs=[outputs])
btn.click(
text2audio,
inputs=[textbox, guidance_scale, seed, n_candidates],
inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
outputs=[outputs],
api_name="text2audio",
)
@@ -291,43 +300,48 @@ def text2audio(
[
[
"An excited crowd cheering at a sports game.",
10,
3.5,
45,
3,
default_checkpoint,
],
[
"A cat is meowing for attention.",
10,
3.5,
45,
3,
default_checkpoint,
],
[
"Birds singing sweetly in a blooming garden.",
10,
3.5,
45,
3,
default_checkpoint,
],
[
"A modern synthesizer creating futuristic soundscapes.",
10,
3.5,
45,
3,
default_checkpoint,
],
[
"The vibrant beat of Brazilian samba drums.",
10,
3.5,
45,
3,
default_checkpoint,
],
],
fn=text2audio,
# inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
inputs=[textbox, guidance_scale, seed, n_candidates],
inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
# inputs=[textbox, guidance_scale, seed, n_candidates],
outputs=[outputs],
cache_examples=True,
)
@@ -353,5 +367,5 @@ def text2audio(
# <p>This demo is strictly for research demo purpose only. For commercial use please <a href="haoheliu@gmail.com">contact us</a>.</p>

iface.queue(concurrency_count=3)
iface.launch(debug=True)
# iface.launch(debug=True, share=True)
# iface.launch(debug=True)
iface.launch(debug=True, share=True)
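The Gradio wiring added above boils down to this: the new duration slider and the model dropdown are ordinary input components whose current values are passed positionally to `text2audio` by `btn.click`, so the dropdown only affects generation if it appears in the click handler's inputs. A stripped-down, self-contained sketch of that pattern, with a stub generator standing in for the real `text2audio` and illustrative widget labels:

```python
import numpy as np
import gradio as gr

SAMPLE_RATES = {"audioldm_48k": 48000}  # everything else is treated as 16 kHz here

def fake_text2audio(text, duration, guidance_scale, seed, n_candidates, model_name):
    # Stand-in for text2audio(): returns `duration` seconds of silence at the
    # sample rate implied by the chosen checkpoint.
    sr = SAMPLE_RATES.get(model_name, 16000)
    return sr, np.zeros(int(duration * sr), dtype=np.float32)

with gr.Blocks() as demo:
    textbox = gr.Textbox(label="Prompt")
    duration = gr.Slider(5, 15, value=10, step=2.5, label="Duration (seconds)")
    guidance_scale = gr.Slider(0, 6, value=3.5, label="Guidance scale")
    seed = gr.Number(value=45, label="Seed")
    n_candidates = gr.Slider(1, 5, value=3, step=1, label="Candidates")
    model_name = gr.Dropdown(["audioldm_48k", "audioldm2-full"], value="audioldm_48k", label="Checkpoint")
    outputs = gr.Audio(label="Output")
    btn = gr.Button("Generate")
    btn.click(
        fake_text2audio,
        inputs=[textbox, duration, guidance_scale, seed, n_candidates, model_name],
        outputs=[outputs],
        api_name="text2audio",
    )

demo.launch()
```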
20 changes: 10 additions & 10 deletions audioldm2/latent_diffusion/models/ddpm.py
@@ -861,16 +861,16 @@ def get_input(
if cond_model_key in cond_dict.keys():
continue

if not self.training:
if isinstance(
self.cond_stage_models[
self.cond_stage_model_metadata[cond_model_key]["model_idx"]
],
CLAPAudioEmbeddingClassifierFreev2,
):
print(
"Warning: CLAP model normally should use text for evaluation"
)
# if not self.training:
# if isinstance(
# self.cond_stage_models[
# self.cond_stage_model_metadata[cond_model_key]["model_idx"]
# ],
# CLAPAudioEmbeddingClassifierFreev2,
# ):
# print(
# "Warning: CLAP model normally should use text for evaluation"
# )

# The original data for conditioning
# If cond_model_key is "all", that means the conditional model need all the information from a batch
12 changes: 4 additions & 8 deletions audioldm2/pipeline.py
@@ -74,7 +74,8 @@ def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):

def make_batch_for_text_to_audio(text, transcription="", waveform=None, fbank=None, batchsize=1):
text = [text] * batchsize
transcription = text2phoneme(transcription)
if(transcription):
transcription = text2phoneme(transcription)
transcription = [transcription] * batchsize

if batchsize < 1:
@@ -170,9 +171,6 @@ def build_model(ckpt_path=None, config=None, device=None, model_name="audioldm2-

return latent_diffusion

def duration_to_latent_t_size(duration):
return int(duration * 25.6)

def text_to_audio(
latent_diffusion,
text,
@@ -183,18 +181,16 @@
batchsize=1,
guidance_scale=3.5,
n_candidate_gen_per_text=3,
latent_t_per_second=25.6,
config=None,
):
assert (
duration == 10
), "Error: Currently we only support 10 seconds of generation. Generating longer files requires some extra coding, which would be a part of the future work."

seed_everything(int(seed))
waveform = None

batch = make_batch_for_text_to_audio(text, transcription=transcription, waveform=waveform, batchsize=batchsize)

latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
latent_diffusion.latent_t_size = int(duration * latent_t_per_second)

with torch.no_grad():
waveform = latent_diffusion.generate_batch(
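The pipeline change above removes the hard-coded `duration_to_latent_t_size` helper (25.6 latent frames per second) and the 10-second-only assertion; the latent length is now derived from the caller-supplied `latent_t_per_second`, which `app.py` sets to 12.8 for the 48 kHz model and 25.6 otherwise. A standalone check of that arithmetic, using only values that appear in this diff:

```python
def latent_t_size(duration_s: float, latent_t_per_second: float) -> int:
    # Mirrors: latent_diffusion.latent_t_size = int(duration * latent_t_per_second)
    return int(duration_s * latent_t_per_second)

assert latent_t_size(10, 25.6) == 256   # 16 kHz checkpoints, 10 s clip
assert latent_t_size(10, 12.8) == 128   # audioldm_48k, 10 s clip
assert latent_t_size(15, 12.8) == 192   # durations other than 10 s now work
```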
73 changes: 62 additions & 11 deletions audioldm2/utilities/model.py
@@ -36,6 +36,44 @@ def get_vocoder_config():
},
}

def get_vocoder_config_48k():
return {
"resblock": "1",
"num_gpus": 8,
"batch_size": 128,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,

"upsample_rates": [6,5,4,2,2],
"upsample_kernel_sizes": [12,10,8,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11,15],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5], [1,3,5]],

"segment_size": 15360,
"num_mels": 256,
"n_fft": 2048,
"hop_size": 480,
"win_size": 2048,

"sampling_rate": 48000,

"fmin": 20,
"fmax": 24000,
"fmax_for_loss": None,

"num_workers": 8,

"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:18273",
"world_size": 1
}
}


def get_available_checkpoint_keys(model, ckpt):
state_dict = torch.load(ckpt)["state_dict"]
@@ -89,17 +127,30 @@ def get_vocoder(config, device, mel_bins):
vocoder.mel2wav.eval()
vocoder.mel2wav.to(device)
elif name == "HiFi-GAN":
config = get_vocoder_config()
config = hifigan.AttrDict(config)
vocoder = hifigan.Generator_old(config)
# print("Load hifigan/g_01080000")
# ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
# ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
# ckpt = torch_version_orig_mod_remove(ckpt)
# vocoder.load_state_dict(ckpt["generator"])
vocoder.eval()
vocoder.remove_weight_norm()
vocoder.to(device)
if(mel_bins == 64):
config = get_vocoder_config()
config = hifigan.AttrDict(config)
vocoder = hifigan.Generator_old(config)
# print("Load hifigan/g_01080000")
# ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
# ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
# ckpt = torch_version_orig_mod_remove(ckpt)
# vocoder.load_state_dict(ckpt["generator"])
vocoder.eval()
vocoder.remove_weight_norm()
vocoder.to(device)
else:
config = get_vocoder_config_48k()
config = hifigan.AttrDict(config)
vocoder = hifigan.Generator_old(config)
# print("Load hifigan/g_01080000")
# ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
# ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
# ckpt = torch_version_orig_mod_remove(ckpt)
# vocoder.load_state_dict(ckpt["generator"])
vocoder.eval()
vocoder.remove_weight_norm()
vocoder.to(device)
return vocoder
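In `get_vocoder`, 64 mel bins selects the original 16 kHz HiFi-GAN config, while anything else (the 48 kHz model uses 256 mel bins, per `num_mels` above) falls through to `get_vocoder_config_48k`. A quick standalone sanity check of the new config's numbers, copied from the dict above and not part of the commit itself: the upsample rates must multiply out to the hop size, and the hop size fixes the mel frame rate.

```python
import math

# Values copied from get_vocoder_config_48k above.
upsample_rates = [6, 5, 4, 2, 2]
hop_size = 480
sampling_rate = 48000
segment_size = 15360

# HiFi-GAN turns one mel frame into hop_size waveform samples, so the
# upsample factors must multiply to the hop size.
assert math.prod(upsample_rates) == hop_size             # 6*5*4*2*2 = 480

print(sampling_rate / hop_size)      # 100.0 mel frames per second
print(segment_size / sampling_rate)  # 0.32 s of audio per training segment
```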

