AudioLDM2 model reproduction: forward inference #366

Merged · 8 commits · Jan 19, 2024
27 changes: 27 additions & 0 deletions paddlemix/examples/audioldm2/README.md
@@ -0,0 +1,27 @@
# AudioLDM2

## 1. Model Overview

This model is a Paddle implementation of [AudioLDM2](https://arxiv.org/abs/2308.05734).


## 2. Demo

### 2.1 Install Dependencies

- Make sure ppdiffusers is installed ([instructions](https://github.com/PaddlePaddle/PaddleMIX/blob/develop/README.md?plain=1#L62))

- Install the remaining dependencies:

```bash
cd paddlemix/models/audioldm2
pip install -r requirement.txt
```

### 2.2 Dynamic Graph Inference
```bash
python run_predict.py \
--text "Musical constellations twinkling in the night sky, forming a cosmic melody." \
--model_name_or_path "/my_model_path" \
    --seed 1001
```
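
The script can also generate audio for a batch of prompts via the `--text_list` flag. A minimal sketch, assuming a hypothetical UTF-8 file `prompts.txt` with one prompt per line (an optional `|name` suffix on a line sets the output file name):

```bash
python run_predict.py \
    --text_list "prompts.txt" \
    --model_name_or_path "/my_model_path" \
    --save_path "./output" \
    --seed 1001
```

Outputs are written as 16 kHz wav files to a timestamped subdirectory of `--save_path`.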
302 changes: 302 additions & 0 deletions paddlemix/examples/audioldm2/run_predict.py
@@ -0,0 +1,302 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import re
import time
from dataclasses import dataclass, field

import numpy as np
import paddle
import soundfile as sf
from paddlenlp.trainer import PdArgumentParser

from paddlemix.models.audioldm2.encoders.phoneme_encoder import text as text
from paddlemix.models.audioldm2.modeling import AudioLDM2Model

def seed_everything(seed):
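    """Seed Python, NumPy, and Paddle RNGs so generation is reproducible for a given seed."""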
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)

def text2phoneme(data):
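    """Strip inline tags like <...> from the transcription and convert it to phonemes with the "english_cleaners2" pipeline."""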
return text._clean_text(re.sub(r'<.*?>', '', data), ["english_cleaners2"])

def text_to_filename(text):
return text.replace(" ", "_").replace("'", "_").replace('"', "_")

CACHE = {
"get_vits_phoneme_ids":{
"PAD_LENGTH": 310,
"_pad": '_',
"_punctuation": ';:,.!?¡¿—…"«»“” ',
"_letters": 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz',
"_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
"_special": "♪☎☒☝⚠"
}
}

CACHE["get_vits_phoneme_ids"]["symbols"] = [CACHE["get_vits_phoneme_ids"]["_pad"]] + list(CACHE["get_vits_phoneme_ids"]["_punctuation"]) + list(CACHE["get_vits_phoneme_ids"]["_letters"]) + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"]) + list(CACHE["get_vits_phoneme_ids"]["_special"])
CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])}

def get_vits_phoneme_ids_no_padding(phonemes):
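    """Map a phoneme string to vocabulary ids, truncated/padded to PAD_LENGTH (310).

    A trailing "⚠" is appended so the model does not ignore the last token;
    symbols missing from the vocabulary fall back to the pad symbol "_".
    """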
pad_token_id = 0
pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
_symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]
batchsize = len(phonemes)

clean_text = phonemes[0] + "⚠"
sequence = []

for symbol in clean_text:
        if symbol not in _symbol_to_id:
print("%s is not in the vocabulary. %s" % (symbol, clean_text))
symbol = "_"
symbol_id = _symbol_to_id[symbol]
sequence += [symbol_id]

def _pad_phonemes(phonemes_list):
return phonemes_list + [pad_token_id] * (pad_length-len(phonemes_list))

sequence = sequence[:pad_length]

return {"phoneme_idx": paddle.to_tensor(_pad_phonemes(sequence), dtype="int64").unsqueeze(0).expand([batchsize, -1])}


def make_batch_for_text_to_audio(text, transcription="", waveform=None, fbank=None, batchsize=1):
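    """Build the model input batch. The waveform and stft entries are zero
    placeholders kept only to satisfy the expected input format; fbank is
    zero-filled too unless an explicit spectrogram is passed in."""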
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Setting batchsize to 1.")
        batchsize = 1

    text = [text] * batchsize
    if transcription:
        transcription = text2phoneme(transcription)
    transcription = [transcription] * batchsize

if fbank is None:
fbank = paddle.zeros(
(batchsize, 1024, 64)
) # Not used, here to keep the code format
else:
fbank = paddle.to_tensor(fbank, dtype="float32")
fbank = fbank.expand([batchsize, 1024, 64])
assert fbank.shape[0] == batchsize

stft = paddle.zeros((batchsize, 1024, 512)) # Not used
phonemes = get_vits_phoneme_ids_no_padding(transcription)

waveform = paddle.zeros((batchsize, 160000)) # Not used
ta_kaldi_fbank = paddle.zeros((batchsize, 1024, 128))

batch = {
"text": text, # list
"fname": [text_to_filename(t) for t in text], # list
"waveform": waveform,
"stft": stft,
"log_mel_spec": fbank,
"ta_kaldi_fbank": ta_kaldi_fbank,
}
batch.update(phonemes)
return batch

def get_time():
t = time.localtime()
return time.strftime("%d_%m_%Y_%H_%M_%S", t)

def save_wave(waveform, savepath, name="outwav", samplerate=16000):
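    """Write each waveform in the batch to savepath as a .wav file at the given sample rate."""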
    if not isinstance(name, list):
        name = [name] * waveform.shape[0]

    for i in range(waveform.shape[0]):
        basename = os.path.basename(name[i])
        if ".wav" in basename:
            basename = basename.split(".")[0]
        if waveform.shape[0] > 1:
            fname = "%s_%s.wav" % (basename, i)
        else:
            fname = "%s.wav" % basename
        # Hash file names that are too long to be saved
if len(fname) > 255:
fname = f"{hex(hash(fname))}.wav"

path = os.path.join(
savepath, fname
)
print("Save audio to %s" % path)
sf.write(path, waveform[i, 0], samplerate=samplerate)

def read_list(fname):
result = []
with open(fname, "r", encoding="utf-8") as f:
for each in f.readlines():
each = each.strip('\n')
result.append(each)
return result

def text_to_audio(
model,
text,
transcription="",
seed=42,
ddim_steps=200,
duration=10,
batchsize=1,
guidance_scale=3.5,
n_candidate_gen_per_text=3,
latent_t_per_second=25.6,
):
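    """Generate audio for a text prompt: seed the RNGs, build the input batch,
    size the latent sequence for the requested duration, and run the model
    with classifier-free guidance for ddim_steps DDIM sampling steps."""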

seed_everything(int(seed))
waveform = None

batch = make_batch_for_text_to_audio(text, transcription=transcription, waveform=waveform, batchsize=batchsize)

model.latent_t_size = int(duration * latent_t_per_second)

waveform = model(
batch,
unconditional_guidance_scale=guidance_scale,
ddim_steps=ddim_steps,
n_gen=n_candidate_gen_per_text,
duration=duration,
)

return waveform


@dataclass
class DataArguments:
"""
    Arguments pertaining to the data we are going to input to our model for inference.
    Using `PdArgumentParser` we can turn this class into argparse arguments
    that can be specified on the command line.
"""

text: str = field(default="", metadata={"help": "Text prompt to the model for audio generation."})
transcription: str = field(default="", metadata={"help": "Transcription for Text-to-Speech."})
    text_list: str = field(default="", metadata={"help": "A file (UTF-8 encoded) that contains text prompts for the model, one per line."})

@dataclass
class ModelArguments:
"""
    Arguments pertaining to which model/config/tokenizer we are going to run inference with.
"""

model_name_or_path: str = field(
default="audioldm2-full",
metadata={"help": "Path to pretrained model or model identifier"},
)
save_path: str = field(
default="./output",
metadata={"help": "The path to save model output."},
)
device: str = field(
default="gpu",
metadata={"help": "The device for computation. If not specified, the script will automatically choose gpu."},
)
batchsize: int = field(
default=1,
metadata={"help": "Generate how many samples at the same time."},
)
ddim_steps: int = field(
default=200,
metadata={"help": "The sampling step for DDIM."},
)
guidance_scale: float = field(
default=3.5,
metadata={"help": "Guidance scale (Large => better quality and relavancy to text; Small => better diversity)."},
)
duration: float = field(
default=10.0,
metadata={"help": "The duration of the samples."},
)
n_candidate_gen_per_text: int = field(
default=3,
metadata={"help": "Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation."},
)
seed: int = field(
default=42,
metadata={"help": "Change this value (any integer number) will lead to a different generation result."},
)

def main():
parser = PdArgumentParser((ModelArguments, DataArguments))
model_args, data_args = parser.parse_args_into_dataclasses()

# process args
text = data_args.text
transcription = data_args.transcription
text_list = data_args.text_list

save_path = os.path.join(model_args.save_path, get_time())
random_seed = model_args.seed
duration = model_args.duration
sample_rate = 16000
latent_t_per_second = 25.6

print("Warning: For AudioLDM2 we currently only support 10s of generation. Please use audioldm_48k or audioldm_16k_crossattn_t5 if you want a different duration.")
duration = 10

guidance_scale = model_args.guidance_scale
n_candidate_gen_per_text = model_args.n_candidate_gen_per_text

if transcription:
if "speech" not in model_args.model_name_or_path:
print("Warning: You choose to perform Text-to-Speech by providing the transcription. However you do not choose the correct model name (audioldm2-speech-gigaspeech or audioldm2-speech-ljspeech).")
print("Warning: We will use audioldm2-speech-gigaspeech by default")
model_args.model_name_or_path = "audioldm2-speech-gigaspeech"
if not text:
print("Warning: You should provide text as a input to describe the speaker. Use default (A male reporter is speaking).")
text = "A female reporter is speaking full of emotion"

if text_list:
print("Generate audio based on the text prompts in %s" % text_list)
prompt_todo = read_list(text_list)
else:
prompt_todo = [text]

# build audioldm2 model
paddle.set_device(model_args.device)
audioldm2 = AudioLDM2Model.from_pretrained(model_args.model_name_or_path)

# predict
os.makedirs(save_path, exist_ok=True)
for text in prompt_todo:
if "|" in text:
text, name = text.split("|")
else:
name = text[:128]

if transcription:
name += "-TTS-%s" % transcription

waveform = text_to_audio(
audioldm2,
text,
            transcription=transcription,
seed=random_seed,
duration=duration,
guidance_scale=guidance_scale,
ddim_steps=model_args.ddim_steps,
n_candidate_gen_per_text=n_candidate_gen_per_text,
batchsize=model_args.batchsize,
latent_t_per_second=latent_t_per_second
)

save_wave(waveform, save_path, name=name, samplerate=sample_rate)

if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions paddlemix/models/__init__.py
@@ -21,3 +21,5 @@
from .qwen_vl import *
from .visualglm.configuration import *
from .visualglm.modeling import *
from .audioldm2.modeling import *
from .audioldm2.configuration import *
13 changes: 13 additions & 0 deletions paddlemix/models/audioldm2/audiomae/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.