From 5c68cbd3754761b6c4d2ff4c995c16fbc84440d5 Mon Sep 17 00:00:00 2001 From: AyaseNana <49900969+NKNaN@users.noreply.github.com> Date: Fri, 19 Jan 2024 11:53:35 +0800 Subject: [PATCH] =?UTF-8?q?AudioLDM2=E6=A8=A1=E5=9E=8B=E5=A4=8D=E7=8E=B0?= =?UTF-8?q?=E5=89=8D=E5=90=91=E6=8E=A8=E7=90=86=20(#366)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 任务:https://github.com/PaddlePaddle/PaddleMIX/issues/250 - text-to-audio推理已跑通 --- paddlemix/examples/audioldm2/README.md | 27 + paddlemix/examples/audioldm2/run_predict.py | 302 +++++ paddlemix/models/__init__.py | 2 + .../models/audioldm2/audiomae/__init__.py | 13 + paddlemix/models/audioldm2/audiomae/mae.py | 367 ++++++ .../models/audioldm2/clap_module/clap.py | 47 + .../audioldm2/clap_module/feature_fusion.py | 200 +++ .../audioldm2/clap_module/htsat_model.py | 1105 +++++++++++++++++ .../models/audioldm2/clap_module/model.py | 403 ++++++ .../models/audioldm2/clap_module/utils.py | 344 +++++ paddlemix/models/audioldm2/configuration.py | 204 +++ .../models/audioldm2/diffusionwrapper.py | 166 +++ .../audioldm2/encoders/audiomae_encoder.py | 155 +++ .../models/audioldm2/encoders/clap_encoder.py | 395 ++++++ .../audioldm2/encoders/flant5_encoder.py | 96 ++ .../encoders/phoneme_encoder/__init__.py | 13 + .../encoders/phoneme_encoder/cleaners.py | 103 ++ .../encoders/phoneme_encoder/symbols.py | 28 + .../encoders/phoneme_encoder/text.py | 62 + .../encoders/sequence2audiomae_encoder.py | 487 ++++++++ paddlemix/models/audioldm2/hifigan/model.py | 333 +++++ .../audioldm2/latent_encoder/autoencoder.py | 140 +++ .../audioldm2/latentdiffusion_samplers.py | 870 +++++++++++++ paddlemix/models/audioldm2/modeling.py | 898 ++++++++++++++ paddlemix/models/audioldm2/requirement.txt | 4 + paddlemix/models/audioldm2/unet/attention.py | 199 +++ .../models/audioldm2/unet/openaimodel.py | 868 +++++++++++++ paddlemix/models/audioldm2/utils.py | 86 ++ 28 files changed, 7917 insertions(+) create mode 100644 paddlemix/examples/audioldm2/README.md create mode 100644 paddlemix/examples/audioldm2/run_predict.py create mode 100644 paddlemix/models/audioldm2/audiomae/__init__.py create mode 100644 paddlemix/models/audioldm2/audiomae/mae.py create mode 100644 paddlemix/models/audioldm2/clap_module/clap.py create mode 100644 paddlemix/models/audioldm2/clap_module/feature_fusion.py create mode 100644 paddlemix/models/audioldm2/clap_module/htsat_model.py create mode 100644 paddlemix/models/audioldm2/clap_module/model.py create mode 100644 paddlemix/models/audioldm2/clap_module/utils.py create mode 100644 paddlemix/models/audioldm2/configuration.py create mode 100644 paddlemix/models/audioldm2/diffusionwrapper.py create mode 100644 paddlemix/models/audioldm2/encoders/audiomae_encoder.py create mode 100644 paddlemix/models/audioldm2/encoders/clap_encoder.py create mode 100644 paddlemix/models/audioldm2/encoders/flant5_encoder.py create mode 100644 paddlemix/models/audioldm2/encoders/phoneme_encoder/__init__.py create mode 100644 paddlemix/models/audioldm2/encoders/phoneme_encoder/cleaners.py create mode 100644 paddlemix/models/audioldm2/encoders/phoneme_encoder/symbols.py create mode 100644 paddlemix/models/audioldm2/encoders/phoneme_encoder/text.py create mode 100644 paddlemix/models/audioldm2/encoders/sequence2audiomae_encoder.py create mode 100644 paddlemix/models/audioldm2/hifigan/model.py create mode 100644 paddlemix/models/audioldm2/latent_encoder/autoencoder.py create mode 100644 paddlemix/models/audioldm2/latentdiffusion_samplers.py create mode 100644 paddlemix/models/audioldm2/modeling.py create mode 100644 paddlemix/models/audioldm2/requirement.txt create mode 100644 paddlemix/models/audioldm2/unet/attention.py create mode 100644 paddlemix/models/audioldm2/unet/openaimodel.py create mode 100644 paddlemix/models/audioldm2/utils.py diff --git a/paddlemix/examples/audioldm2/README.md b/paddlemix/examples/audioldm2/README.md new file mode 100644 index 000000000..428b11e4f --- /dev/null +++ b/paddlemix/examples/audioldm2/README.md @@ -0,0 +1,27 @@ +# AudioLDM2 + +## 1. 模型简介 + +该模型是 [AudioLDM2](https://arxiv.org/abs/2308.05734) 的 paddle 实现。 + + +## 2. Demo + +### 2.1 依赖安装 + +- 请确保已安装 ppdiffusers ([参考方法](https://github.com/PaddlePaddle/PaddleMIX/blob/develop/README.md?plain=1#L62)) + +- 其余依赖安装: + +```bash +cd /paddlemix/models/audioldm2 +pip install -r requirement.txt +``` + +### 2.2 动态图推理 +```bash +python run_predict.py \ +--text "Musical constellations twinkling in the night sky, forming a cosmic melody." \ +--model_name_or_path "/my_model_path" \ +--seed 1001 \ +``` diff --git a/paddlemix/examples/audioldm2/run_predict.py b/paddlemix/examples/audioldm2/run_predict.py new file mode 100644 index 000000000..5319b0a5c --- /dev/null +++ b/paddlemix/examples/audioldm2/run_predict.py @@ -0,0 +1,302 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +import paddle +from paddlenlp.trainer import PdArgumentParser +import os +import time +import soundfile as sf +from paddlemix.models.audioldm2.modeling import AudioLDM2Model +from paddlemix.models.audioldm2.encoders.phoneme_encoder import text as text +import random +import numpy as np +import re + +def seed_everything(seed): + os.environ["PYTHONHASHSEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + +def text2phoneme(data): + return text._clean_text(re.sub(r'<.*?>', '', data), ["english_cleaners2"]) + +def text_to_filename(text): + return text.replace(" ", "_").replace("'", "_").replace('"', "_") + +CACHE = { + "get_vits_phoneme_ids":{ + "PAD_LENGTH": 310, + "_pad": '_', + "_punctuation": ';:,.!?¡¿—…"«»“” ', + "_letters": 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz', + "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ", + "_special": "♪☎☒☝⚠" + } +} + +CACHE["get_vits_phoneme_ids"]["symbols"] = [CACHE["get_vits_phoneme_ids"]["_pad"]] + list(CACHE["get_vits_phoneme_ids"]["_punctuation"]) + list(CACHE["get_vits_phoneme_ids"]["_letters"]) + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"]) + list(CACHE["get_vits_phoneme_ids"]["_special"]) +CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])} + +def get_vits_phoneme_ids_no_padding(phonemes): + pad_token_id = 0 + pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"] + _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] + batchsize = len(phonemes) + + clean_text = phonemes[0] + "⚠" + sequence = [] + + for symbol in clean_text: + if(symbol not in _symbol_to_id.keys()): + print("%s is not in the vocabulary. %s" % (symbol, clean_text)) + symbol = "_" + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + + def _pad_phonemes(phonemes_list): + return phonemes_list + [pad_token_id] * (pad_length-len(phonemes_list)) + + sequence = sequence[:pad_length] + + return {"phoneme_idx": paddle.to_tensor(_pad_phonemes(sequence), dtype="int64").unsqueeze(0).expand([batchsize, -1])} + + +def make_batch_for_text_to_audio(text, transcription="", waveform=None, fbank=None, batchsize=1): + text = [text] * batchsize + if(transcription): + transcription = text2phoneme(transcription) + transcription = [transcription] * batchsize + + if batchsize < 1: + print("Warning: Batchsize must be at least 1. Batchsize is set to .") + + if fbank is None: + fbank = paddle.zeros( + (batchsize, 1024, 64) + ) # Not used, here to keep the code format + else: + fbank = paddle.to_tensor(fbank, dtype="float32") + fbank = fbank.expand([batchsize, 1024, 64]) + assert fbank.shape[0] == batchsize + + stft = paddle.zeros((batchsize, 1024, 512)) # Not used + phonemes = get_vits_phoneme_ids_no_padding(transcription) + + waveform = paddle.zeros((batchsize, 160000)) # Not used + ta_kaldi_fbank = paddle.zeros((batchsize, 1024, 128)) + + batch = { + "text": text, # list + "fname": [text_to_filename(t) for t in text], # list + "waveform": waveform, + "stft": stft, + "log_mel_spec": fbank, + "ta_kaldi_fbank": ta_kaldi_fbank, + } + batch.update(phonemes) + return batch + +def get_time(): + t = time.localtime() + return time.strftime("%d_%m_%Y_%H_%M_%S", t) + +def save_wave(waveform, savepath, name="outwav", samplerate=16000): + if type(name) is not list: + name = [name] * waveform.shape[0] + + for i in range(waveform.shape[0]): + if waveform.shape[0] > 1: + fname = "%s_%s.wav" % ( + os.path.basename(name[i]) + if (not ".wav" in name[i]) + else os.path.basename(name[i]).split(".")[0], + i, + ) + else: + fname = "%s.wav" % os.path.basename(name[i]) if (not ".wav" in name[i]) else os.path.basename(name[i]).split(".")[0] + # Avoid the file name too long to be saved + if len(fname) > 255: + fname = f"{hex(hash(fname))}.wav" + + path = os.path.join( + savepath, fname + ) + print("Save audio to %s" % path) + sf.write(path, waveform[i, 0], samplerate=samplerate) + +def read_list(fname): + result = [] + with open(fname, "r", encoding="utf-8") as f: + for each in f.readlines(): + each = each.strip('\n') + result.append(each) + return result + +def text_to_audio( + model, + text, + transcription="", + seed=42, + ddim_steps=200, + duration=10, + batchsize=1, + guidance_scale=3.5, + n_candidate_gen_per_text=3, + latent_t_per_second=25.6, + ): + + seed_everything(int(seed)) + waveform = None + + batch = make_batch_for_text_to_audio(text, transcription=transcription, waveform=waveform, batchsize=batchsize) + + model.latent_t_size = int(duration * latent_t_per_second) + + waveform = model( + batch, + unconditional_guidance_scale=guidance_scale, + ddim_steps=ddim_steps, + n_gen=n_candidate_gen_per_text, + duration=duration, + ) + + return waveform + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + Using `PdArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + text: str = field(default="", metadata={"help": "Text prompt to the model for audio generation."}) + transcription: str = field(default="", metadata={"help": "Transcription for Text-to-Speech."}) + text_list: str = field(default="", metadata={"help": "A file (utf-8 encoded) that contains text prompt to the model for audio generation."}) + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + default="audioldm2-full", + metadata={"help": "Path to pretrained model or model identifier"}, + ) + save_path: str = field( + default="./output", + metadata={"help": "The path to save model output."}, + ) + device: str = field( + default="gpu", + metadata={"help": "The device for computation. If not specified, the script will automatically choose gpu."}, + ) + batchsize: int = field( + default=1, + metadata={"help": "Generate how many samples at the same time."}, + ) + ddim_steps: int = field( + default=200, + metadata={"help": "The sampling step for DDIM."}, + ) + guidance_scale: float = field( + default=3.5, + metadata={"help": "Guidance scale (Large => better quality and relavancy to text; Small => better diversity)."}, + ) + duration: float = field( + default=10.0, + metadata={"help": "The duration of the samples."}, + ) + n_candidate_gen_per_text: int = field( + default=3, + metadata={"help": "Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation."}, + ) + seed: int = field( + default=42, + metadata={"help": "Change this value (any integer number) will lead to a different generation result."}, + ) + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments)) + model_args, data_args = parser.parse_args_into_dataclasses() + + # process args + text = data_args.text + transcription = data_args.transcription + text_list = data_args.text_list + + save_path = os.path.join(model_args.save_path, get_time()) + random_seed = model_args.seed + duration = model_args.duration + sample_rate = 16000 + latent_t_per_second = 25.6 + + print("Warning: For AudioLDM2 we currently only support 10s of generation. Please use audioldm_48k or audioldm_16k_crossattn_t5 if you want a different duration.") + duration = 10 + + guidance_scale = model_args.guidance_scale + n_candidate_gen_per_text = model_args.n_candidate_gen_per_text + + if transcription: + if "speech" not in model_args.model_name_or_path: + print("Warning: You choose to perform Text-to-Speech by providing the transcription. However you do not choose the correct model name (audioldm2-speech-gigaspeech or audioldm2-speech-ljspeech).") + print("Warning: We will use audioldm2-speech-gigaspeech by default") + model_args.model_name_or_path = "audioldm2-speech-gigaspeech" + if not text: + print("Warning: You should provide text as a input to describe the speaker. Use default (A male reporter is speaking).") + text = "A female reporter is speaking full of emotion" + + if text_list: + print("Generate audio based on the text prompts in %s" % text_list) + prompt_todo = read_list(text_list) + else: + prompt_todo = [text] + + # build audioldm2 model + paddle.set_device(model_args.device) + audioldm2 = AudioLDM2Model.from_pretrained(model_args.model_name_or_path) + + # predict + os.makedirs(save_path, exist_ok=True) + for text in prompt_todo: + if "|" in text: + text, name = text.split("|") + else: + name = text[:128] + + if transcription: + name += "-TTS-%s" % transcription + + waveform = text_to_audio( + audioldm2, + text, + transcription=transcription, # To avoid the model to ignore the last vocab + seed=random_seed, + duration=duration, + guidance_scale=guidance_scale, + ddim_steps=model_args.ddim_steps, + n_candidate_gen_per_text=n_candidate_gen_per_text, + batchsize=model_args.batchsize, + latent_t_per_second=latent_t_per_second + ) + + save_wave(waveform, save_path, name=name, samplerate=sample_rate) + +if __name__ == "__main__": + main() diff --git a/paddlemix/models/__init__.py b/paddlemix/models/__init__.py index 9bc505502..9688b7e53 100644 --- a/paddlemix/models/__init__.py +++ b/paddlemix/models/__init__.py @@ -21,3 +21,5 @@ from .qwen_vl import * from .visualglm.configuration import * from .visualglm.modeling import * +from .audioldm2.modeling import * +from .audioldm2.configuration import * diff --git a/paddlemix/models/audioldm2/audiomae/__init__.py b/paddlemix/models/audioldm2/audiomae/__init__.py new file mode 100644 index 000000000..fd05a9208 --- /dev/null +++ b/paddlemix/models/audioldm2/audiomae/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlemix/models/audioldm2/audiomae/mae.py b/paddlemix/models/audioldm2/audiomae/mae.py new file mode 100644 index 000000000..a00305ee5 --- /dev/null +++ b/paddlemix/models/audioldm2/audiomae/mae.py @@ -0,0 +1,367 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import paddle +import paddle.nn as nn +from ..utils import to_2tuple, DropPath, Mlp +from ..clap_module.htsat_model import SwinTransformerBlock + +class Attention(nn.Layer): + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Layer = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = False + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, self.head_dim]).transpose([2, 0, 3, 1, 4]) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = nn.functional.scaled_dot_product_attention( + q, k, v, + dropout=self.attn_drop.p if self.training else 0., + )[0] + else: + q = q * self.scale + k_perm = list(range(k.dim())) + new_perm = k_perm + new_perm[-2],new_perm[-1] = k_perm[-1],k_perm[-2] + attn = q @ k.transpose(new_perm) + attn = nn.functional.softmax(attn,axis=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x_perm = list(range(x.dim())) + new_perm = x_perm + new_perm[1],new_perm[2] = x_perm[2],x_perm[1] + x = x.transpose(new_perm).reshape([B, N, C]) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class LayerScale(nn.Layer): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + tmp = init_values * paddle.ones(dim) + self.gamma = paddle.create_parameter(shape=tmp.shape, + dtype=tmp.dtype, + default_initializer=nn.initializer.Assign(tmp)) + self.gamma.stop_gradient = False + + def forward(self, x): + if self.inplace: + x = paddle.multiply(x, self.gamma) + return x + else: + return x * self.gamma + # return paddle.multiply(x, self.gamma) if self.inplace else x * self.gamma + +class Block(nn.Layer): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + init_values=None, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + +class PatchEmbed_org(nn.Layer): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2D( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + y = x.flatten(2).transpose([0, 2, 1]) + return y + +class MaskedAutoencoderViT(nn.Layer): + """Masked Autoencoder with VisionTransformer backbone""" + + def __init__( + self, + img_size=224, + patch_size=16, + stride=10, + in_chans=3, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4.0, + norm_layer=nn.LayerNorm, + norm_pix_loss=False, + audio_exp=False, + alpha=0.0, + temperature=0.2, + mode=0, + contextual_depth=8, + split_pos=False, + pos_trainable=False, + use_nce=False, + beta=4.0, + decoder_mode=0, + mask_t_prob=0.6, + mask_f_prob=0.5, + mask_2d=False, + epoch=0, + no_shift=False, + use_custom_patch=False, + ): + super().__init__() + + self.audio_exp = audio_exp + self.embed_dim = embed_dim + self.decoder_embed_dim = decoder_embed_dim + # -------------------------------------------------------------------------- + # MAE encoder specifics + self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim) + self.use_custom_patch = use_custom_patch + + num_patches = self.patch_embed.num_patches + tmp = paddle.zeros([1, 1, embed_dim]) + self.cls_token = paddle.create_parameter(shape=tmp.shape, + dtype=tmp.dtype, + default_initializer=nn.initializer.Assign(tmp)) + self.cls_token.stop_gradient = False + + # self.split_pos = split_pos # not useful + tmp = paddle.zeros([1, num_patches + 1, embed_dim]) + self.pos_embed = paddle.create_parameter(shape=tmp.shape, + dtype=tmp.dtype, + default_initializer=nn.initializer.Assign(tmp)) # fixed sin-cos embedding + self.pos_embed.stop_gradient = not pos_trainable + + self.encoder_depth = depth + self.contextual_depth = contextual_depth + self.blocks = nn.LayerList( + [ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) # qk_scale=None + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias_attr=True) + + tmp = paddle.zeros([1, 1, decoder_embed_dim]) + self.mask_token = paddle.create_parameter(shape=tmp.shape, + dtype=tmp.dtype, + default_initializer=nn.initializer.Assign(tmp)) + self.mask_token.stop_gradient = False + + tmp = paddle.zeros([1, num_patches + 1, decoder_embed_dim]) + self.decoder_pos_embed = paddle.create_parameter(shape=tmp.shape, + dtype=tmp.dtype, + default_initializer=nn.initializer.Assign(tmp)) # fixed sin-cos embedding + self.decoder_pos_embed.stop_gradient = not pos_trainable + + self.no_shift = no_shift + + self.decoder_mode = decoder_mode + if ( + self.use_custom_patch + ): # overlapped patches as in AST. Similar performance yet compute heavy + window_size = (6, 6) + feat_size = (102, 12) + else: + window_size = (4, 4) + feat_size = (64, 8) + if self.decoder_mode == 1: + decoder_modules = [] + for index in range(16): + if self.no_shift: + shift_size = (0, 0) + else: + if (index % 2) == 0: + shift_size = (0, 0) + else: + shift_size = (2, 0) + decoder_modules.append( + SwinTransformerBlock( + dim=decoder_embed_dim, + num_heads=16, + input_resolution=feat_size, + window_size=window_size, + shift_size=shift_size, + mlp_ratio=mlp_ratio, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + # extra_norm=False, + # sequential_attn=False, + norm_layer=norm_layer, # nn.LayerNorm, + ) + ) + self.decoder_blocks = nn.LayerList(decoder_modules) + else: + # Transfomer + self.decoder_blocks = nn.LayerList( + [ + Block( + decoder_embed_dim, + decoder_num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) # qk_scale=None, + for i in range(decoder_depth) + ] + ) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear( + decoder_embed_dim, patch_size**2 * in_chans, bias_attr=True + ) # decoder to patch + + # -------------------------------------------------------------------------- + + self.norm_pix_loss = norm_pix_loss + + self.patch_size = patch_size + self.stride = stride + + # audio exps + self.alpha = alpha + self.T = temperature + self.mode = mode + self.use_nce = use_nce + self.beta = beta + + self.log_softmax = nn.LogSoftmax(axis=-1) + + self.mask_t_prob = mask_t_prob + self.mask_f_prob = mask_f_prob + self.mask_2d = mask_2d + + self.epoch = epoch + + # self.initialize_weights() + + def forward_encoder_no_mask(self, x): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + # x, mask, ids_restore = self.random_masking(x, mask_ratio) + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand([x.shape[0], -1, -1]) + x = paddle.concat((cls_tokens, x), axis=1) + + # apply Transformer blocks + contextual_embs = [] + for n, blk in enumerate(self.blocks): + x = blk(x) + if n > self.contextual_depth: + contextual_embs.append(self.norm(x)) + contextual_emb = paddle.stack(contextual_embs, axis=0).mean(axis=0) + + return contextual_emb + + +def mae_vit_base_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + decoder_embed_dim=512, + decoder_num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, epsilon=1e-6), + **kwargs, + ) + return model + + +# set recommended archs +mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks diff --git a/paddlemix/models/audioldm2/clap_module/clap.py b/paddlemix/models/audioldm2/clap_module/clap.py new file mode 100644 index 000000000..c6bccbbb7 --- /dev/null +++ b/paddlemix/models/audioldm2/clap_module/clap.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .model import CLAP, CLAPAudioCfg, CLAPTextCfg +import dataclasses +from dataclasses import dataclass + +@dataclass +class CLAPConfig: + embed_dim: int = 1024 + audio_cfg: CLAPAudioCfg = CLAPAudioCfg() + text_cfg: CLAPTextCfg = CLAPTextCfg() + +def create_clap_model( + amodel_name: str, + tmodel_name: str, + pretrained: str = "", + precision: str = "fp32", + force_quick_gelu: bool = False, + enable_fusion: bool = False, + fusion_type: str = "None" +): + pretrained = pretrained.lower() + + model_cfg = CLAPConfig() + model_cfg = dataclasses.asdict(model_cfg) + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + model_cfg["text_cfg"]["model_type"] = tmodel_name + model_cfg["enable_fusion"] = enable_fusion + model_cfg["fusion_type"] = fusion_type + model = CLAP(**model_cfg) + + return model, model_cfg diff --git a/paddlemix/models/audioldm2/clap_module/feature_fusion.py b/paddlemix/models/audioldm2/clap_module/feature_fusion.py new file mode 100644 index 000000000..4a2fc987d --- /dev/null +++ b/paddlemix/models/audioldm2/clap_module/feature_fusion.py @@ -0,0 +1,200 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + + +class DAF(nn.Layer): + """ + 直接相加 DirectAddFuse + """ + + def __init__(self): + super(DAF, self).__init__() + + def forward(self, x, residual): + return x + residual + + +class iAFF(nn.Layer): + """ + 多特征融合 iAFF + """ + + def __init__(self, channels=64, r=4, type="2D"): + super(iAFF, self).__init__() + inter_channels = int(channels // r) + + if type == "1D": + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv1D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1D(1), + nn.Conv1D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(inter_channels), + nn.ReLU(), + nn.Conv1D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv1D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(inter_channels), + nn.ReLU(), + nn.Conv1D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool1D(1), + nn.Conv1D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(inter_channels), + nn.ReLU(), + nn.Conv1D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(channels), + ) + elif type == "2D": + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv2D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(inter_channels), + nn.ReLU(), + nn.Conv2D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2D(1), + nn.Conv2D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(inter_channels), + nn.ReLU(), + nn.Conv2D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv2D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(inter_channels), + nn.ReLU(), + nn.Conv2D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool2D(1), + nn.Conv2D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(inter_channels), + nn.ReLU(), + nn.Conv2D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(channels), + ) + else: + raise f"the type is not supported" + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = paddle.concat([xa, xa], axis=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xi = x * wei + residual * (1 - wei) + + xl2 = self.local_att2(xi) + xg2 = self.global_att(xi) + xlg2 = xl2 + xg2 + wei2 = self.sigmoid(xlg2) + xo = x * wei2 + residual * (1 - wei2) + if flag: + xo = xo[0].unsqueeze(0) + return xo + + +class AFF(nn.Layer): + """ + 多特征融合 AFF + """ + + def __init__(self, channels=64, r=4, type="2D"): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + if type == "1D": + self.local_att = nn.Sequential( + nn.Conv1D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(inter_channels), + nn.ReLU(), + nn.Conv1D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1D(1), + nn.Conv1D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(inter_channels), + nn.ReLU(), + nn.Conv1D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1D(channels), + ) + elif type == "2D": + self.local_att = nn.Sequential( + nn.Conv2D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(inter_channels), + nn.ReLU(), + nn.Conv2D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2D(1), + nn.Conv2D(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(inter_channels), + nn.ReLU(), + nn.Conv2D(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2D(channels), + ) + else: + raise f"the type is not supported." + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = paddle.concat([xa, xa], axis=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xo = 2 * x * wei + 2 * residual * (1 - wei) + if flag: + xo = xo[0].unsqueeze(0) + return xo diff --git a/paddlemix/models/audioldm2/clap_module/htsat_model.py b/paddlemix/models/audioldm2/clap_module/htsat_model.py new file mode 100644 index 000000000..c6588a929 --- /dev/null +++ b/paddlemix/models/audioldm2/clap_module/htsat_model.py @@ -0,0 +1,1105 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +import math +import warnings +import random + +from .utils import do_mixup, interpolate, Spectrogram, LogmelFilterBank, SpecAugmentation +from ..utils import to_2tuple, DropPath, Mlp +from .feature_fusion import iAFF, AFF, DAF + +class PatchEmbed(nn.Layer): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + patch_stride=16, + enable_fusion=False, + fusion_type="None", + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patch_stride = to_2tuple(patch_stride) + self.img_size = img_size + self.patch_size = patch_size + self.patch_stride = patch_stride + self.grid_size = ( + img_size[0] // patch_stride[0], + img_size[1] // patch_stride[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + padding = ( + (patch_size[0] - patch_stride[0]) // 2, + (patch_size[1] - patch_stride[1]) // 2, + ) + + if (self.enable_fusion) and (self.fusion_type == "channel_map"): + self.proj = nn.Conv2D( + in_chans * 4, + embed_dim, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + else: + self.proj = nn.Conv2D( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + self.mel_conv2d = nn.Conv2D( + in_chans, + embed_dim, + kernel_size=(patch_size[0], patch_size[1] * 3), + stride=(patch_stride[0], patch_stride[1] * 3), + padding=padding, + ) + if self.fusion_type == "daf_2d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_2d": + self.fusion_model = AFF(channels=embed_dim, type="2D") + elif self.fusion_type == "iaff_2d": + self.fusion_model = iAFF(channels=embed_dim, type="2D") + + def forward(self, x, longer_idx=None): + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + global_x = x[:, 0:1, :, :] + + # global processing + B, C, H, W = global_x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + global_x = self.proj(global_x) + TW = global_x.shape[-1] + if len(longer_idx) > 0: + # local processing + local_x = x[longer_idx, 1:, :, :] + B, C, H, W = local_x.shape + local_x = local_x.reshape([B * C, 1, H, W]) + local_x = self.mel_conv2d(local_x) + local_x = local_x.reshape( + [B, C, local_x.shape[1], local_x.shape[2], local_x.shape[3]] + ) + local_x = local_x.transpose([0, 2, 3, 1, 4]).flatten(3) + TB, TC, TH, _ = local_x.shape + if local_x.shape[-1] < TW: + local_x = paddle.concat( + [ + local_x, + paddle.zeros( + [TB, TC, TH, TW - local_x.shape[-1]] + ), + ], + axis=-1, + ) + else: + local_x = local_x[:, :, :, :TW] + + global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x) + x = global_x + else: + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + + if self.flatten: + x = x.flatten(2).transpose([0, 2, 1]) # BCHW -> BNC + x = self.norm(x) + return x + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with paddle.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor = paddle.multiply(tensor, paddle.to_tensor(std) * math.sqrt(2.0)) + # tensor.mul_(std * math.sqrt(2.0)) + tensor = paddle.add(tensor, paddle.to_tensor(mean)) + # tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clip_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.reshape([B, H // window_size, window_size, W // window_size, window_size, C]) + windows = ( + x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C]) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.reshape( + [B, H // window_size, W // window_size, window_size, window_size, -1] + ) + x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1]) + return x + + +class WindowAttention(nn.Layer): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + relative_position_bias_table = paddle.zeros([(2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads]) + self.relative_position_bias_table = paddle.create_parameter( + shape=relative_position_bias_table.shape, + dtype=str(relative_position_bias_table.numpy().dtype), + default_initializer=nn.initializer.Assign(relative_position_bias_table) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = paddle.arange(self.window_size[0]) + coords_w = paddle.arange(self.window_size[1]) + coords = paddle.stack(paddle.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.transpose( + [1, 2, 0] + ) # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(axis=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape([B_, N, 3, self.num_heads, C // self.num_heads]) + .transpose([2, 0, 3, 1, 4]) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) + + q = q * self.scale + k_perm_shape = list(range(k.dim())) + k_new_perm_shape = k_perm_shape + k_new_perm_shape[-1], k_new_perm_shape[-2] = k_perm_shape[-2], k_perm_shape[-1] + attn = q @ k.transpose(k_new_perm_shape) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.reshape([-1]) + ].reshape( + [self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1] + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.transpose( + [2, 0, 1] + ) # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.reshape([B_ // nW, nW, self.num_heads, N, N]) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.reshape([-1, self.num_heads, N, N]) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + tmp = attn @ v + tmp_perm_shape = list(range(tmp.dim())) + new_tmp_perm_shape = tmp_perm_shape + new_tmp_perm_shape[1], new_tmp_perm_shape[2] = tmp_perm_shape[2], tmp_perm_shape[1] + x = tmp.transpose(new_tmp_perm_shape).reshape([B_, N, C]) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" + + +# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model +class SwinTransformerBlock(nn.Layer): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Layer, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + norm_before_mlp="ln", + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.norm_before_mlp = norm_before_mlp + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if self.norm_before_mlp == "ln": + self.norm2 = nn.LayerNorm(dim) + elif self.norm_before_mlp == "bn": + self.bn2 = nn.BatchNorm1D(dim) + def norm2_fun(x): + perm_shape = list(range(x.dim())) + new_perm_shape = perm_shape + new_perm_shape[1], new_perm_shape[2] = perm_shape[2], perm_shape[1] + return self.bn2(x.transpose(new_perm_shape)).transpose(new_perm_shape) + + self.norm2 = norm2_fun + else: + raise NotImplementedError + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = paddle.zeros([1, H, W, 1]) # 1 H W 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.reshape([-1, self.window_size * self.window_size]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = paddle.where(attn_mask != 0, paddle.ones_like(attn_mask)*float(-100.0), attn_mask) + attn_mask = paddle.where(attn_mask == 0, paddle.ones_like(attn_mask)*float(0.0), attn_mask) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + + shortcut = x + x = self.norm1(x) + x = x.reshape([B, H, W, C]) + + # cyclic shift + if self.shift_size > 0: + shifted_x = paddle.roll( + x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2) + ) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.reshape( + [-1, self.window_size * self.window_size, C] + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, attn = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C]) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = paddle.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), axis=(1, 2) + ) + else: + x = shifted_x + x = x.reshape([B, H * W, C]) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x, attn + + def extra_repr(self): + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) + + +class PatchMerging(nn.Layer): + r"""Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.reshape([B, H, W, C]) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.reshape([B, -1, 4 * C]) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self): + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + +class BasicLayer(nn.Layer): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + norm_before_mlp="ln", + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.LayerList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + norm_before_mlp=norm_before_mlp, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer + ) + else: + self.downsample = None + + def forward(self, x): + attns = [] + for blk in self.blocks: + x, attn = blk(x) + if not self.training: + attns.append(attn.unsqueeze(0)) + if self.downsample is not None: + x = self.downsample(x) + if not self.training: + attn = paddle.concat(attns, axis=0) + attn = paddle.mean(attn, axis=0) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +# The Core of HTSAT +class HTSAT_Swin_Transformer(nn.Layer): + r"""HTSAT based on the Swin Transformer + Args: + spec_size (int | tuple(int)): Input Spectrogram size. Default 256 + patch_size (int | tuple(int)): Patch size. Default: 4 + path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 + in_chans (int): Number of input image channels. Default: 1 (mono) + num_classes (int): Number of classes for classification head. Default: 527 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 8 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Layer): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + config (module): The configuration Module from config.py + """ + + def __init__( + self, + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + in_chans=1, + num_classes=527, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + norm_before_mlp="ln", + config=None, + enable_fusion=False, + fusion_type="None", + **kwargs, + ): + super(HTSAT_Swin_Transformer, self).__init__() + + self.config = config + self.spec_size = spec_size + self.patch_stride = patch_stride + self.patch_size = patch_size + self.window_size = window_size + self.embed_dim = embed_dim + self.depths = depths + self.ape = ape + self.in_chans = in_chans + self.num_classes = num_classes + self.num_heads = num_heads + self.num_layers = len(self.depths) + self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) + + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + + self.qkv_bias = qkv_bias + self.qk_scale = None + + self.patch_norm = patch_norm + self.norm_layer = norm_layer if self.patch_norm else None + self.norm_before_mlp = norm_before_mlp + self.mlp_ratio = mlp_ratio + + self.use_checkpoint = use_checkpoint + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # process mel-spec ; used only once + self.freq_ratio = self.spec_size // self.config.mel_bins + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=config.window_size, + hop_length=config.hop_size, + win_length=config.window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=config.sample_rate, + n_fft=config.window_size, + n_mels=config.mel_bins, + fmin=config.fmin, + fmax=config.fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) # 2 2 + self.bn0 = nn.BatchNorm2D(self.config.mel_bins) + + # split spctrogram into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=self.spec_size, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + norm_layer=self.norm_layer, + patch_stride=patch_stride, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.grid_size + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + absolute_pos_embed = paddle.zeros([1, num_patches, self.embed_dim]) + self.absolute_pos_embed = paddle.create_parameter( + shape=absolute_pos_embed.shape, + dtype=str(absolute_pos_embed.numpy().dtype), + default_initializer=nn.initializer.Assign(absolute_pos_embed) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=self.drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in paddle.linspace(0, self.drop_path_rate, sum(self.depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.LayerList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(self.embed_dim * 2**i_layer), + input_resolution=( + patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer), + ), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_scale=self.qk_scale, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[ + sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1]) + ], + norm_layer=self.norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + norm_before_mlp=self.norm_before_mlp, + ) + self.layers.append(layer) + + self.norm = self.norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1D(1) + self.maxpool = nn.AdaptiveMaxPool1D(1) + + SF = ( + self.spec_size + // (2 ** (len(self.depths) - 1)) + // self.patch_stride[0] + // self.freq_ratio + ) + self.tscam_conv = nn.Conv2D( + in_channels=self.num_features, + out_channels=self.num_classes, + kernel_size=(SF, 3), + padding=(0, 1), + ) + self.head = nn.Linear(num_classes, num_classes) + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] + ): + self.mel_conv1d = nn.Sequential( + nn.Conv1D(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1D(64), + ) + if self.fusion_type == "daf_1d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_1d": + self.fusion_model = AFF(channels=64, type="1D") + elif self.fusion_type == "iaff_1d": + self.fusion_model = iAFF(channels=64, type="1D") + + @paddle.jit.not_to_static + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @paddle.jit.not_to_static + def no_weight_decay_keywords(self): + return {"relative_position_bias_table"} + + def forward_features(self, x, longer_idx=None): + # A deprecated optimization for using a hierarchical output from different blocks + + frames_num = x.shape[2] + x = self.patch_embed(x, longer_idx=longer_idx) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + for i, layer in enumerate(self.layers): + x, attn = layer(x) + # for x + x = self.norm(x) + B, N, C = x.shape + SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] + ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] + x = x.transpose([0, 2, 1]).reshape([B, C, SF, ST]) + B, C, F, T = x.shape + # group 2D CNN + c_freq_bin = F // self.freq_ratio + x = x.reshape([B, C, F // c_freq_bin, c_freq_bin, T]) + x = x.transpose([0, 1, 3, 2, 4]).reshape([B, C, c_freq_bin, -1]) + # get latent_output + fine_grained_latent_output = paddle.mean(x, axis=2) + fine_grained_latent_output = interpolate( + fine_grained_latent_output.transpose([0, 2, 1]), + 8 * self.patch_stride[1], + ) + + latent_output = self.avgpool(paddle.flatten(x, 2)) + latent_output = paddle.flatten(latent_output, 1) + + # display the attention map, if needed + + x = self.tscam_conv(x) + x = paddle.flatten(x, 2) # B, C, T + + fpx = interpolate( + nn.functional.sigmoid(x).transpose([0, 2, 1]), 8 * self.patch_stride[1] + ) + + x = self.avgpool(x) + x = paddle.flatten(x, 1) + + output_dict = { + "framewise_output": fpx, # already sigmoided + "clipwise_output": nn.functional.sigmoid(x), + "fine_grained_embedding": fine_grained_latent_output, + "embedding": latent_output, + } + + return output_dict + + def crop_wav(self, x, crop_size, spe_pos=None): + time_steps = x.shape[2] + tx = paddle.zeros([x.shape[0], x.shape[1], crop_size, x.shape[3]]) + for i in range(len(x)): + if spe_pos is None: + crop_pos = random.randint(0, time_steps - crop_size - 1) + else: + crop_pos = spe_pos + tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :] + return tx + + # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model + def reshape_wav2img(self, x): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert ( + T <= target_T and F <= target_F + ), "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate( + x, (target_T, x.shape[3]), mode="bicubic", align_corners=True + ) + if F < target_F: + x = nn.functional.interpolate( + x, (x.shape[2], target_F), mode="bicubic", align_corners=True + ) + x = x.transpose([0, 1, 3, 2]) + x = x.reshape( + [x.shape[0], + x.shape[1], + x.shape[2], + self.freq_ratio, + x.shape[3] // self.freq_ratio] + ) + # print(x.shape) + x = x.transpose([0, 1, 3, 2, 4]) + x = x.reshape([x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]]) + return x + + # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model + def repeat_wat2img(self, x, cur_pos): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert ( + T <= target_T and F <= target_F + ), "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate( + x, (target_T, x.shape[3]), mode="bicubic", align_corners=True + ) + if F < target_F: + x = nn.functional.interpolate( + x, (x.shape[2], target_F), mode="bicubic", align_corners=True + ) + x = x.transpose([0, 1, 3, 2]) # B C F T + x = x[:, :, :, cur_pos : cur_pos + self.spec_size] + # x = x.repeat_interleave(repeats=(1, 1, 4, 1)) + x = x.repeat_interleave(repeats=4, axis=2) + return x + + def forward( + self, x: paddle.Tensor, mixup_lambda=None, infer_mode=False, device=None + ): # out_feat_keys: List[str] = None): + if self.enable_fusion and x["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + x["longer"][paddle.randint(0, x["longer"].shape[0], (1,))] = True + + if not self.enable_fusion: + x = x["waveform"] + x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + x = x.transpose([0, 3, 2, 1]) + x = self.bn0(x) + x = x.transpose([0, 3, 2, 1]) + if self.training: + x = self.spec_augmenter(x) + + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x) + else: + longer_list = x["longer"] + x = x["mel_fusion"] + x = x.transpose([0, 3, 2, 1]) + x = self.bn0(x) + x = x.transpose([0, 3, 2, 1]) + longer_list_idx = paddle.where(longer_list)[0].squeeze() + if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: + new_x = x[:, 0:1, :, :].clone() + if len(longer_list_idx) > 0: + # local processing + fusion_x_local = x[longer_list_idx, 1:, :, :].clone() + FB, FC, FT, FF = fusion_x_local.shape + fusion_x_local = fusion_x_local.reshape([FB * FC, FT, FF]) + fusion_x_local = paddle.transpose( + fusion_x_local, (0, 2, 1) + ) + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.reshape( + FB, FC, FF, fusion_x_local.shape[-1] + ) + fusion_x_local = ( + paddle.transpose(fusion_x_local, (0, 2, 1, 3)) + .flatten(2) + ) + if fusion_x_local.shape[-1] < FT: + fusion_x_local = paddle.concat( + [ + fusion_x_local, + paddle.zeros( + (FB, FF, FT - fusion_x_local.size(-1)) + ), + ], + axis=-1, + ) + else: + fusion_x_local = fusion_x_local[:, :, :FT] + # 1D fusion + new_x = new_x.squeeze(1).transpose((0, 2, 1)) + new_x[longer_list_idx] = self.fusion_model( + new_x[longer_list_idx], fusion_x_local + ) + x = new_x.transpose((0, 2, 1))[:, None, :, :] + else: + x = new_x + + elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x, longer_idx=longer_list_idx) + + return output_dict + + +def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"): + try: + assert audio_cfg.model_name in [ + "base", + ], "model name for HTS-AT is wrong!" + if audio_cfg.model_name == "base": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=128, + depths=[2, 2, 12, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + + return model + except: + raise RuntimeError( + f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." + ) diff --git a/paddlemix/models/audioldm2/clap_module/model.py b/paddlemix/models/audioldm2/clap_module/model.py new file mode 100644 index 000000000..c1c213d2e --- /dev/null +++ b/paddlemix/models/audioldm2/clap_module/model.py @@ -0,0 +1,403 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional + +import numpy as np +import paddle +import paddle.nn.functional as F +import paddle.nn as nn +from dataclasses import dataclass + +import logging + +from .htsat_model import create_htsat_model +from paddlenlp.transformers import RobertaModel, BertModel, BartModel + + +class MLPLayers(nn.Layer): + def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1): + super(MLPLayers, self).__init__() + self.nonlin = nonlin + self.dropout = dropout + + sequence = [] + for u0, u1 in zip(units[:-1], units[1:]): + sequence.append(nn.Linear(u0, u1)) + sequence.append(self.nonlin) + sequence.append(nn.Dropout(self.dropout)) + sequence = sequence[:-2] + + self.sequential = nn.Sequential(*sequence) + + def forward(self, X): + X = self.sequential(X) + return X + +class ResidualAttentionBlock(nn.Layer): + def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU): + super().__init__() + + self.attn = nn.MultiHeadAttention(d_model, n_head) + self.ln_1 = nn.LayerNorm(d_model) + self.mlp = nn.Sequential( + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ) + self.ln_2 = nn.LayerNorm(d_model) + + def attention(self, x: paddle.Tensor, attn_mask: Optional[paddle.Tensor] = None): + return self.attn(x, x, x, attn_mask=attn_mask)[0] + + def forward(self, x: paddle.Tensor, attn_mask: Optional[paddle.Tensor] = None): + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Layer): + def __init__( + self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.LayerList( + [ + ResidualAttentionBlock(width, heads, act_layer=act_layer) + for _ in range(layers) + ] + ) + + def forward(self, x: paddle.Tensor, attn_mask: Optional[paddle.Tensor] = None): + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + +# Audio Config Class +@dataclass +class CLAPAudioCfg: + model_type: str = "HTSAT" + model_name: str = "base" + sample_rate: int = 48000 + audio_length: int = 1024 + window_size: int = 1024 + hop_size: int = 480 + fmin: int = 50 + fmax: int = 14000 + class_num: int = 527 + mel_bins: int = 64 + clip_samples: int = 480000 + +@dataclass +class CLAPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + model_type: str = "roberta" + +class CLAP(nn.Layer): + def __init__( + self, + embed_dim: int, + audio_cfg: CLAPAudioCfg, + text_cfg: CLAPTextCfg, + quick_gelu: bool = False, + enable_fusion: bool = False, + fusion_type: str = "None", + joint_embed_shape: int = 512, + mlp_act: str = "relu", + ): + super().__init__() + if isinstance(audio_cfg, dict): + audio_cfg = CLAPAudioCfg(**audio_cfg) + if isinstance(text_cfg, dict): + text_cfg = CLAPTextCfg(**text_cfg) + + self.audio_cfg = audio_cfg + self.text_cfg = text_cfg + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + self.joint_embed_shape = joint_embed_shape + self.mlp_act = mlp_act + + self.context_length = text_cfg.context_length + + act_layer = nn.GELU + + if mlp_act == "relu": + mlp_act_layer = nn.ReLU() + elif mlp_act == "gelu": + mlp_act_layer = nn.GELU() + else: + raise NotImplementedError + + # audio branch + # audio branch parameters + if audio_cfg.model_type == "PANN": + raise ValueError("PANN has not been implemented.") + elif audio_cfg.model_type == "HTSAT": + self.audio_branch = create_htsat_model( + audio_cfg, enable_fusion, fusion_type + ) + else: + logging.error(f"Model config for {audio_cfg.model_type} not found") + raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.") + + # text branch + # text branch parameters + if text_cfg.model_type == "transformer": + self.text_branch = Transformer( + width=text_cfg.width, + layers=text_cfg.layers, + heads=text_cfg.heads, + act_layer=act_layer, + ) + self.vocab_size = text_cfg.vocab_size + self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) + positional_embedding = paddle.empty([self.context_length, text_cfg.width]) + self.positional_embedding = paddle.create_parameter( + shape=positional_embedding.shape, + dtype=str(positional_embedding.numpy().dtype), + default_initializer=nn.initializer.Assign(positional_embedding) + ) + # self.ln_final = LayerNorm(text_cfg.width) + self.ln_final = nn.LayerNorm(text_cfg.width) + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(text_cfg.width, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bert": + self.text_branch = BertModel.from_pretrained("bert-base-uncased") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "roberta": + self.text_branch = RobertaModel.from_pretrained("roberta-base") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bart": + self.text_branch = BartModel.from_pretrained("bart-base") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + else: + logging.error(f"Model config for {text_cfg.model_type} not found") + raise RuntimeError(f"Model config for {text_cfg.model_type} not found.") + self.text_branch_type = text_cfg.model_type + # text branch parameters + + # audio branch parameters + self.audio_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + + # below here is text branch parameters + + self.audio_projection = nn.Sequential( + nn.Linear(embed_dim, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + + self.logit_scale_a = paddle.create_parameter([],"float32",default_initializer=nn.initializer.Assign(paddle.ones([])*np.log(1 / 0.07))) + self.logit_scale_t = paddle.create_parameter([],"float32",default_initializer=nn.initializer.Assign(paddle.ones([])*np.log(1 / 0.07))) + self.register_buffer("attn_mask", self.build_attention_mask(), persistable=False) + + def build_attention_mask(self): + + mask = paddle.empty([self.context_length, self.context_length]) * float("-inf") + # mask.fill_(float("-inf")) + mask = paddle.triu(mask, 1) # zero out the lower diagonal + # mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_audio(self, audio): + return self.audio_branch( + audio, mixup_lambda=None + ) # mix lambda needs to add + + + def encode_text(self, text): + if self.text_branch_type == "transformer": + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.transpose([1, 0, 2]) # NLD -> LND + x = self.text_branch(x, attn_mask=self.attn_mask) + x = x.transpose([1, 0, 2]) # LND -> NLD + x = self.ln_final(x) + + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = self.text_projection(x[paddle.arange(x.shape[0]), text.argmax(axis=-1)]) + elif self.text_branch_type == "bert": + x = self.text_branch( + input_ids=text["input_ids"], + attention_mask=text["attention_mask"], + token_type_ids=text["token_type_ids"], + return_dict=True, + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "roberta": + x = self.text_branch( + input_ids=text["input_ids"], + attention_mask=text["attention_mask"], + return_dict=True, + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "bart": + x = paddle.mean( + self.text_branch( + input_ids=text["input_ids"], + attention_mask=text["attention_mask"], + return_dict=True, + )["encoder_last_hidden_state"], + axis=1, + ) + x = self.text_projection(x) + else: + logging.error(f"Model type {self.text_branch_type} not found") + raise RuntimeError(f"Model type {self.text_branch_type} not found.") + return x + + def forward(self, audio, text): + """Forward audio and text into the CLAP + + Parameters + ---------- + audio: paddle.Tensor (batch_size, audio_length) + the time-domain audio input / the batch of mel_spec and longer list. + text: paddle.Tensor () // need to add + the text token input + """ + + if audio is None and text is None: + # a hack to get the logit scale + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + elif audio is None: + return self.encode_text(text) + elif text is None: + return self.audio_projection( + self.encode_audio(audio)["embedding"] + ) + audio_features = self.audio_projection( + self.encode_audio(audio)["embedding"] + ) + audio_features = F.normalize(audio_features, axis=-1) + + text_features = self.encode_text(text) + text_features = F.normalize(text_features, axis=-1) + + audio_features_mlp = self.audio_transform(audio_features) + text_features_mlp = self.text_transform(text_features) + # Four outputs: audio features (basic & MLP), text features (basic & MLP) + return ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + self.logit_scale_a.exp(), + self.logit_scale_t.exp(), + ) + + def get_logit_scale(self): + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + + def get_text_embedding(self, data): + """Get the text embedding from the model + + Parameters + ---------- + data: paddle.Tensor + a tensor of text embedding + + Returns + ---------- + text_embed: paddle.Tensor + a tensor of text_embeds (N, D) + + """ + text_embeds = self.encode_text(data) + text_embeds = F.normalize(text_embeds, axis=-1) + + return text_embeds + + def get_audio_embedding(self, data): + """Get the audio embedding from the model + + Parameters + ---------- + data: a list of dict + the audio input dict list from 'get_audio_feature' method + + Returns + ---------- + audio_embed: paddle.Tensor + a tensor of audio_embeds (N, D) + + """ + audio_embeds = self.audio_projection( + self.encode_audio(data)["embedding"] + ) + audio_embeds = F.normalize(audio_embeds, axis=-1) + + return audio_embeds diff --git a/paddlemix/models/audioldm2/clap_module/utils.py b/paddlemix/models/audioldm2/clap_module/utils.py new file mode 100644 index 000000000..51c6e76d7 --- /dev/null +++ b/paddlemix/models/audioldm2/clap_module/utils.py @@ -0,0 +1,344 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import numpy as np +import librosa + +def interpolate(x, ratio): + """Interpolate data in time domain. This is used to compensate the + resolution reduction in downsampling of a CNN. + + Args: + x: (batch_size, time_steps, classes_num) + ratio: int, ratio to interpolate + Returns: + upsampled: (batch_size, time_steps * ratio, classes_num) + """ + (batch_size, time_steps, classes_num) = x.shape + upsampled = x[:, :, None, :].repeat_interleave(ratio, 2) + upsampled = upsampled.reshape([batch_size, time_steps * ratio, classes_num]) + return upsampled + +def do_mixup(x, mixup_lambda): + """ + Args: + x: (batch_size , ...) + mixup_lambda: (batch_size,) + Returns: + out: (batch_size, ...) + """ + perm_shape = list(range(x.dim())) + new_perm_shape = perm_shape + new_perm_shape[0], new_perm_shape[-1] = perm_shape[-1], perm_shape[0] + out = ( + x.transpose(new_perm_shape) * mixup_lambda + + paddle.flip(x, axis=[0]).transpose(new_perm_shape) * (1 - mixup_lambda) + ).transpose(new_perm_shape) + return out + + +class DFTBase(nn.Layer): + def __init__(self): + r"""Base class for DFT and IDFT matrix. + """ + super(DFTBase, self).__init__() + + def dft_matrix(self, n): + (x, y) = np.meshgrid(np.arange(n), np.arange(n)) + omega = np.exp(-2 * np.pi * 1j / n) + W = np.power(omega, x * y) # shape: (n, n) + return W + + def idft_matrix(self, n): + (x, y) = np.meshgrid(np.arange(n), np.arange(n)) + omega = np.exp(2 * np.pi * 1j / n) + W = np.power(omega, x * y) # shape: (n, n) + return W + + +class STFT(DFTBase): + def __init__(self, n_fft=2048, hop_length=None, win_length=None, + window='hann', center=True, pad_mode='reflect', freeze_parameters=True): + r"""Paddle implementation of STFT with Conv1d. The function has the + same output as librosa.stft. + + Args: + n_fft: int, fft window size, e.g., 2048 + hop_length: int, hop length samples, e.g., 441 + win_length: int, window length e.g., 2048 + window: str, window function name, e.g., 'hann' + center: bool + pad_mode: str, e.g., 'reflect' + freeze_parameters: bool, set to True to freeze all parameters. Set + to False to finetune all parameters. + """ + super(STFT, self).__init__() + + assert pad_mode in ['constant', 'reflect'] + + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.center = center + self.pad_mode = pad_mode + + # By default, use the entire frame. + if self.win_length is None: + self.win_length = n_fft + + # Set the default hop, if it's not already specified. + if self.hop_length is None: + self.hop_length = int(self.win_length // 4) + + fft_window = librosa.filters.get_window(window, self.win_length, fftbins=True) + + # Pad the window out to n_fft size. + fft_window = librosa.util.pad_center(data=fft_window, size=n_fft) + + # DFT & IDFT matrix. + self.W = self.dft_matrix(n_fft) + + out_channels = n_fft // 2 + 1 + + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Assign( + paddle.to_tensor( + np.real(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :] + )) + self.conv_real = nn.Conv1D(in_channels=1, out_channels=out_channels, + kernel_size=n_fft, stride=self.hop_length, padding=0, dilation=1, + groups=1, weight_attr=weight_attr, bias_attr=False) + + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Assign( + paddle.to_tensor( + np.imag(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :] + )) + self.conv_imag = nn.Conv1D(in_channels=1, out_channels=out_channels, + kernel_size=n_fft, stride=self.hop_length, padding=0, dilation=1, + groups=1, weight_attr=weight_attr, bias_attr=False) + + if freeze_parameters: + for param in self.parameters(): + param.stop_gradient = True + + def forward(self, input): + r"""Calculate STFT of batch of signals. + + Args: + input: (batch_size, data_length), input signals. + + Returns: + real: (batch_size, 1, time_steps, n_fft // 2 + 1) + imag: (batch_size, 1, time_steps, n_fft // 2 + 1) + """ + + x = input[:, None, :] # (batch_size, channels_num, data_length) + + if self.center: + x = nn.functional.pad(x, pad=(self.n_fft // 2, self.n_fft // 2), mode=self.pad_mode, data_format="NCL") + + real = self.conv_real(x) + imag = self.conv_imag(x) + # (batch_size, n_fft // 2 + 1, time_steps) + + real = real[:, None, :, :].transpose([0, 1, 3, 2]) + imag = imag[:, None, :, :].transpose([0, 1, 3, 2]) + # (batch_size, 1, time_steps, n_fft // 2 + 1) + + return real, imag + + +class Spectrogram(nn.Layer): + def __init__(self, n_fft=2048, hop_length=None, win_length=None, + window='hann', center=True, pad_mode='reflect', power=2.0, + freeze_parameters=True): + r"""Calculate spectrogram using paddle. The STFT is implemented with + Conv1d. The function has the same output of librosa.stft + """ + super(Spectrogram, self).__init__() + + self.power = power + + self.stft = STFT(n_fft=n_fft, hop_length=hop_length, + win_length=win_length, window=window, center=center, + pad_mode=pad_mode, freeze_parameters=True) + + def forward(self, input): + r"""Calculate spectrogram of input signals. + Args: + input: (batch_size, data_length) + + Returns: + spectrogram: (batch_size, 1, time_steps, n_fft // 2 + 1) + """ + + (real, imag) = self.stft.forward(input) + # (batch_size, n_fft // 2 + 1, time_steps) + + spectrogram = real ** 2 + imag ** 2 + + if self.power == 2.0: + pass + else: + spectrogram = spectrogram ** (self.power / 2.0) + + return spectrogram + + +class LogmelFilterBank(nn.Layer): + def __init__(self, sr=22050, n_fft=2048, n_mels=64, fmin=0.0, fmax=None, + is_log=True, ref=1.0, amin=1e-10, top_db=80.0, freeze_parameters=True): + r"""Calculate logmel spectrogram using paddle. The mel filter bank is + the paddle implementation of as librosa.filters.mel + """ + super(LogmelFilterBank, self).__init__() + + self.is_log = is_log + self.ref = paddle.to_tensor(ref, dtype="float32") + self.amin = paddle.to_tensor(amin, dtype="float32") + self.top_db = top_db + if fmax == None: + fmax = sr//2 + + self.melW = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, + fmin=fmin, fmax=fmax).T + # (n_fft // 2 + 1, mel_bins) + + self.melW = paddle.to_tensor(self.melW) + self.melW = paddle.create_parameter( + self.melW.shape, + str(self.melW.numpy().dtype), + default_initializer=nn.initializer.Assign(self.melW) + ) + + if freeze_parameters: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input): + r"""Calculate (log) mel spectrogram from spectrogram. + + Args: + input: (*, n_fft), spectrogram + + Returns: + output: (*, mel_bins), (log) mel spectrogram + """ + + # Mel spectrogram + mel_spectrogram = paddle.matmul(input, self.melW) + # (*, mel_bins) + + # Logmel spectrogram + if self.is_log: + output = self.power_to_db(mel_spectrogram) + else: + output = mel_spectrogram + + return output + + + def power_to_db(self, input): + r"""Power to db, this function is the paddle implementation of + librosa.power_to_lb + """ + ref_value = self.ref + log_spec = 10.0 * paddle.log10(paddle.clip(input, min=self.amin, max=None)) + log_spec -= 10.0 * paddle.log10(paddle.maximum(self.amin, ref_value)) + + if self.top_db is not None: + if self.top_db < 0: + raise librosa.util.exceptions.ParameterError('top_db must be non-negative') + log_spec = paddle.clip(log_spec, min=log_spec.max().item() - self.top_db, max=None) + + return log_spec + +class DropStripes(nn.Layer): + def __init__(self, dim, drop_width, stripes_num): + """Drop stripes. + + Args: + dim: int, dimension along which to drop + drop_width: int, maximum width of stripes to drop + stripes_num: int, how many stripes to drop + """ + super(DropStripes, self).__init__() + + assert dim in [2, 3] # dim 2: time; dim 3: frequency + + self.dim = dim + self.drop_width = drop_width + self.stripes_num = stripes_num + + def forward(self, input): + """input: (batch_size, channels, time_steps, freq_bins)""" + + assert input.ndim == 4 + + if self.training is False: + return input + + else: + batch_size = input.shape[0] + total_width = input.shape[self.dim] + + for n in range(batch_size): + self.transform_slice(input[n], total_width) + + return input + + def transform_slice(self, e, total_width): + """e: (channels, time_steps, freq_bins)""" + + for _ in range(self.stripes_num): + distance = paddle.randint(low=0, high=self.drop_width, shape=(1,))[0] + bgn = paddle.randint(low=0, high=total_width - distance, shape=(1,))[0] + + if self.dim == 2: + e[:, bgn : bgn + distance, :] = 0 + elif self.dim == 3: + e[:, :, bgn : bgn + distance] = 0 + + +class SpecAugmentation(nn.Layer): + def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, + freq_stripes_num): + """Spec augmetation. + [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D. + and Le, Q.V., 2019. Specaugment: A simple data augmentation method + for automatic speech recognition. arXiv preprint arXiv:1904.08779. + + Args: + time_drop_width: int + time_stripes_num: int + freq_drop_width: int + freq_stripes_num: int + """ + + super(SpecAugmentation, self).__init__() + + self.time_dropper = DropStripes(dim=2, drop_width=time_drop_width, + stripes_num=time_stripes_num) + + self.freq_dropper = DropStripes(dim=3, drop_width=freq_drop_width, + stripes_num=freq_stripes_num) + + def forward(self, input): + x = self.time_dropper(input) + x = self.freq_dropper(x) + return x diff --git a/paddlemix/models/audioldm2/configuration.py b/paddlemix/models/audioldm2/configuration.py new file mode 100644 index 000000000..7590ed74e --- /dev/null +++ b/paddlemix/models/audioldm2/configuration.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Union +from paddlenlp.transformers.configuration_utils import PretrainedConfig +from paddlemix.utils.log import logger + +__all__ = ["AudioLDM2Config"] + +class AudioLDM2Config(PretrainedConfig): + + model_type = "audioldm2" + + def __init__( + self, + model_name: str = "audioldm2-full", + first_stage_key: str = "fbank", + sampling_rate: int = 16000, + parameterization: str = "eps", + log_every_t: int = 200, + latent_t_size: int = 256, + latent_f_size: int = 16, + channels: int = 8, + timesteps: int = 1000, + num_timesteps_cond: int = 1, + linear_start: float = 0.0015, + linear_end: float = 0.0195, + unconditional_prob_cfg: float = 0.1, + device: str = "gpu", + unet_image_size: int = 64, + unet_context_dim: list = [768, 1024], + unet_in_channels: int = 8, + unet_out_channels: int = 8, + unet_model_channels: int = 128, + unet_attention_resolutions: list = [8, 4, 2], + unet_num_res_blocks: int = 2, + unet_channel_mult: list = [1, 2, 3, 5], + unet_num_head_channels: int = 32, + unet_use_spatial_transformer: bool = True, + unet_transformer_depth: int = 1, + autoencoder_sampling_rate: int = 16000, + autoencoder_batchsize: int = 4, + autoencoder_image_key: str = "fbank", + autoencoder_subband: int = 1, + autoencoder_embed_dim: int = 8, + autoencoder_time_shuffle: int = 1, + ddconfig_double_z: bool = True, + ddconfig_mel_bins: int = 64, + ddconfig_z_channels: int = 8, + ddconfig_resolution: int = 256, + ddconfig_downsample_time: bool = False, + ddconfig_in_channels: int = 1, + ddconfig_out_ch: int = 1, + ddconfig_ch: int = 128, + ddconfig_ch_mult: list = [1, 2, 4], + ddconfig_num_res_blocks: int = 2, + ddconfig_attn_resolutions: list = [], + ddconfig_dropout: float = 0.0, + sequence2audiomae_always_output_audiomae_gt: bool = False, + sequence2audiomae_learnable: bool = True, + sequence2audiomae_use_gt_mae_output: bool = True, + sequence2audiomae_use_gt_mae_prob: float = 0.0, + sequence2audiomae_base_learning_rate: float = 0.0002, + sequence2audiomae_sequence_gen_length: int = 8, + sequence2audiomae_use_warmup: bool = True, + sequence2audiomae_sequence_input_key: list = ['film_clap_cond1', 'crossattn_flan_t5'], + sequence2audiomae_sequence_input_embed_dim: list = [512, 1024], + sequence2audiomae_batchsize: int = 16, + sequence2audiomae_cond_stage_configs: dict = None, + **kwargs, + ): + kwargs["return_dict"] = kwargs.pop("return_dict", True) + super().__init__(**kwargs) + self.first_stage_key = first_stage_key + self.sampling_rate = sampling_rate + self.parameterization = parameterization + self.log_every_t = log_every_t + self.latent_t_size = latent_t_size + self.latent_f_size = latent_f_size + self.channels = channels + self.timesteps = timesteps + self.num_timesteps_cond = num_timesteps_cond + self.linear_start = linear_start + self.linear_end = linear_end + self.unconditional_prob_cfg = unconditional_prob_cfg + self.device = device + + self.unet_config = {} + self.unet_config["target"] = ".unet.openaimodel.UNetModel" + self.unet_config["params"] = {} + self.unet_config["params"]["image_size"] = unet_image_size + self.unet_config["params"]["context_dim"] = unet_context_dim + self.unet_config["params"]["in_channels"] = unet_in_channels + self.unet_config["params"]["out_channels"] = unet_out_channels + self.unet_config["params"]["model_channels"] = unet_model_channels + self.unet_config["params"]["attention_resolutions"] = unet_attention_resolutions + self.unet_config["params"]["num_res_blocks"] = unet_num_res_blocks + self.unet_config["params"]["channel_mult"] = unet_channel_mult + self.unet_config["params"]["num_head_channels"] = unet_num_head_channels + self.unet_config["params"]["use_spatial_transformer"] = unet_use_spatial_transformer + self.unet_config["params"]["transformer_depth"] = unet_transformer_depth + + self.first_stage_config = {} + self.first_stage_config["target"] = ".latent_encoder.autoencoder.AudioLDMAutoencoderKL" + self.first_stage_config["params"] = {} + self.first_stage_config["params"]["sampling_rate"] = autoencoder_sampling_rate + self.first_stage_config["params"]["batchsize"] = autoencoder_batchsize + self.first_stage_config["params"]["image_key"] = autoencoder_image_key + self.first_stage_config["params"]["subband"] = autoencoder_subband + self.first_stage_config["params"]["embed_dim"] = autoencoder_embed_dim + self.first_stage_config["params"]["time_shuffle"] = autoencoder_time_shuffle + + self.first_stage_config["params"]["ddconfig"] = {} + self.first_stage_config["params"]["ddconfig"]["double_z"] = ddconfig_double_z + self.first_stage_config["params"]["ddconfig"]["mel_bins"] = ddconfig_mel_bins + self.first_stage_config["params"]["ddconfig"]["z_channels"] = ddconfig_z_channels + self.first_stage_config["params"]["ddconfig"]["resolution"] = ddconfig_resolution + self.first_stage_config["params"]["ddconfig"]["downsample_time"] = ddconfig_downsample_time + self.first_stage_config["params"]["ddconfig"]["in_channels"] = ddconfig_in_channels + self.first_stage_config["params"]["ddconfig"]["out_ch"] = ddconfig_out_ch + self.first_stage_config["params"]["ddconfig"]["ch"] = ddconfig_ch + self.first_stage_config["params"]["ddconfig"]["ch_mult"] = ddconfig_ch_mult + self.first_stage_config["params"]["ddconfig"]["num_res_blocks"] = ddconfig_num_res_blocks + self.first_stage_config["params"]["ddconfig"]["attn_resolutions"] = ddconfig_attn_resolutions + self.first_stage_config["params"]["ddconfig"]["dropout"] = ddconfig_dropout + + self.cond_stage_config = {} + self.cond_stage_config["crossattn_audiomae_generated"] = {} + self.cond_stage_config["crossattn_audiomae_generated"]["cond_stage_key"] = "all" + self.cond_stage_config["crossattn_audiomae_generated"]["conditioning_key"] = "crossattn" + self.cond_stage_config["crossattn_audiomae_generated"]["target"] = ".encoders.sequence2audiomae_encoder.SequenceGenAudioMAECond" # gpt2 + self.cond_stage_config["crossattn_audiomae_generated"]["params"] = {} + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["always_output_audiomae_gt"] = sequence2audiomae_always_output_audiomae_gt + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["learnable"] = sequence2audiomae_learnable + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["use_gt_mae_output"] = sequence2audiomae_use_gt_mae_output + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["use_gt_mae_prob"] = sequence2audiomae_use_gt_mae_prob + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["base_learning_rate"] = sequence2audiomae_base_learning_rate + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["sequence_gen_length"] = sequence2audiomae_sequence_gen_length + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["use_warmup"] = sequence2audiomae_use_warmup + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["sequence_input_key"] = sequence2audiomae_sequence_input_key + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["sequence_input_embed_dim"] = sequence2audiomae_sequence_input_embed_dim + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["batchsize"] = sequence2audiomae_batchsize + + if "speech" not in model_name: + self.cond_stage_config["crossattn_flan_t5"] = {} + self.cond_stage_config["crossattn_flan_t5"]["cond_stage_key"] = "text" + self.cond_stage_config["crossattn_flan_t5"]["conditioning_key"] = "crossattn" + self.cond_stage_config["crossattn_flan_t5"]["target"] = ".encoders.flant5_encoder.FlanT5HiddenState" + + if sequence2audiomae_cond_stage_configs is None: + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"] = {} + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["film_clap_cond1"] = {} + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["film_clap_cond1"]["cond_stage_key"] = "text" + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["film_clap_cond1"]["conditioning_key"] = "film" + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["film_clap_cond1"]["target"] = ".encoders.clap_encoder.CLAPAudioEmbeddingClassifierFreev2" + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["film_clap_cond1"]["params"] = {} + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["film_clap_cond1"]["params"]["sampling_rate"] = 48000 + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["film_clap_cond1"]["params"]["embed_mode"] = "text" + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["film_clap_cond1"]["params"]["amodel"] = "HTSAT-base" + + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_flan_t5"] = {} + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_flan_t5"]["cond_stage_key"] = "text" + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_flan_t5"]["conditioning_key"] = "crossattn" + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_flan_t5"]["target"] = ".encoders.flant5_encoder.FlanT5HiddenState" + + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"] = {} + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["cond_stage_key"] = "ta_kaldi_fbank" + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["conditioning_key"] = "crossattn" + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["target"] = ".encoders.audiomae_encoder.AudioMAEConditionCTPoolRand" + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"] = {} + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["regularization"] = False + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["no_audiomae_mask"] = True + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["time_pooling_factors"] = [8] + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["freq_pooling_factors"] = [8] + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["eval_time_pooling"] = 8 + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["eval_freq_pooling"] = 8 + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["mask_ratio"] = 0 + else: + self.cond_stage_config["crossattn_audiomae_generated"]["params"]["cond_stage_config"] = sequence2audiomae_cond_stage_configs + + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) diff --git a/paddlemix/models/audioldm2/diffusionwrapper.py b/paddlemix/models/audioldm2/diffusionwrapper.py new file mode 100644 index 000000000..37b391b33 --- /dev/null +++ b/paddlemix/models/audioldm2/diffusionwrapper.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +from inspect import isfunction +import importlib +import numpy as np + +class DiffusionWrapper(nn.Layer): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + + self.conditioning_key = conditioning_key + + for key in self.conditioning_key: + if ( + "concat" in key + or "crossattn" in key + or "hybrid" in key + or "film" in key + or "noncond" in key + ): + continue + else: + raise ValueError("The conditioning key %s is illegal" % key) + + self.being_verbosed_once = False + + def forward(self, x, t, cond_dict: dict = {}): + # x with condition (or maybe not) + xc = x + + y = None + context_list, attn_mask_list = [], [] + + conditional_keys = cond_dict.keys() + + for key in conditional_keys: + if "concat" in key: + xc = paddle.concat([x, cond_dict[key].unsqueeze(1)], axis=1) + elif "film" in key: + if y is None: + y = cond_dict[key].squeeze(1) + else: + y = paddle.concat([y, cond_dict[key].squeeze(1)], axis=-1) + elif "crossattn" in key: + # assert context is None, "You can only have one context matrix, got %s" % (cond_dict.keys()) + if isinstance(cond_dict[key], dict): + for k in cond_dict[key].keys(): + if "crossattn" in k: + context, attn_mask = cond_dict[key][ + k + ] # crossattn_audiomae_pooled: paddle.Size([12, 128, 768]) + else: + assert len(cond_dict[key]) == 2, ( + "The context condition for %s you returned should have two element, one context one mask" + % (key) + ) + context, attn_mask = cond_dict[key] + + # The input to the UNet model is a list of context matrix + context_list.append(context) + attn_mask_list.append(attn_mask) + + elif ( + "noncond" in key + ): # If you use loss function in the conditional module, include the keyword "noncond" in the return dictionary + continue + else: + raise NotImplementedError() + + out = self.diffusion_model( + xc, t, context_list=context_list, y=y, context_attn_mask_list=attn_mask_list + ) + + return out + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package="paddlemix.models.audioldm2"), cls) + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + tmp = float(total_params * 1.e-6) + print(f"{model.__class__.__name__} has {tmp:.2f} M params.") + return total_params + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + paddle.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype="float64" + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + paddle.arange(n_timestep + 1, dtype="float64") / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = paddle.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = paddle.linspace( + linear_start, linear_end, n_timestep, dtype="float64" + ) + elif schedule == "sqrt": + betas = ( + paddle.linspace(linear_start, linear_end, n_timestep, dtype="float64") + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(t, -1) + return out.reshape((b,) + ((1,) * (len(x_shape) - 1))) + +def noise_like(shape, repeat=False): + repeat_noise = lambda: paddle.randn((1, *shape[1:])).repeat_interleave(repeats=shape[0], axis=0) + noise = lambda: paddle.randn(shape) + return repeat_noise() if repeat else noise() + +def default(val, d): + if val is not None: + return val + return d() if isfunction(d) else d + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self diff --git a/paddlemix/models/audioldm2/encoders/audiomae_encoder.py b/paddlemix/models/audioldm2/encoders/audiomae_encoder.py new file mode 100644 index 000000000..e06e2e687 --- /dev/null +++ b/paddlemix/models/audioldm2/encoders/audiomae_encoder.py @@ -0,0 +1,155 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import numpy as np +from ..audiomae import mae as models_mae + +class Vanilla_AudioMAE(nn.Layer): + """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)""" + + def __init__( + self, + ): + super().__init__() + model = models_mae.__dict__["mae_vit_base_patch16"]( + in_chans=1, audio_exp=True, img_size=(1024, 128) + ) + + self.model = model.eval() + + def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False): + """ + x: mel fbank [Batch, 1, 1024 (T), 128 (F)] + mask_ratio: 'masking ratio (percentage of removed patches).' + """ + with paddle.no_grad(): + # embed: [B, 513, 768] for mask_ratio=0.0 + if no_mask: + if no_average: + raise RuntimeError("This function is deprecated") + else: + embed = self.model.forward_encoder_no_mask(x) # mask_ratio + else: + raise RuntimeError("This function is deprecated") + return embed + +class AudioMAEConditionCTPoolRand(nn.Layer): + def __init__( + self, + time_pooling_factors=[1, 2, 4, 8], + freq_pooling_factors=[1, 2, 4, 8], + eval_time_pooling=None, + eval_freq_pooling=None, + mask_ratio=0.0, + regularization=False, + no_audiomae_mask=True, + no_audiomae_average=False, + ): + super().__init__() + self.device = None + self.time_pooling_factors = time_pooling_factors + self.freq_pooling_factors = freq_pooling_factors + self.no_audiomae_mask = no_audiomae_mask + self.no_audiomae_average = no_audiomae_average + + self.eval_freq_pooling = eval_freq_pooling + self.eval_time_pooling = eval_time_pooling + self.mask_ratio = mask_ratio + self.use_reg = regularization + + self.audiomae = Vanilla_AudioMAE() + self.audiomae.eval() + for p in self.audiomae.parameters(): + p.stop_gradient = True + + # Required + def get_unconditional_condition(self, batchsize): + param = self.audiomae.parameters()[0] + assert param.stop_gradient == True + + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + + token_num = int(512 / (time_pool * freq_pool)) + return [ + paddle.zeros((batchsize, token_num, 768), dtype="float32"), + paddle.ones((batchsize, token_num), dtype="float32"), + ] + + def pool(self, representation, time_pool=None, freq_pool=None): + assert representation.shape[-1] == 768 + representation = representation[:, 1:, :] + perm = list(range(representation.dim())) + new_perm = perm + new_perm[1], new_perm[2] = perm[2], perm[1] + representation = representation.transpose(new_perm) + bs, embedding_dim, token_num = representation.shape + representation = representation.reshape([bs, embedding_dim, 64, 8]) + + if self.training: + if time_pool is None and freq_pool is None: + time_pool = min( + 64, + self.time_pooling_factors[ + np.random.choice(list(range(len(self.time_pooling_factors)))) + ], + ) + freq_pool = min(8, time_pool) # TODO here I make some modification. + else: + time_pool, freq_pool = min(self.eval_time_pooling, 64), min( + self.eval_freq_pooling, 8 + ) + + self.avgpooling = nn.AvgPool2D( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + self.maxpooling = nn.MaxPool2D( + kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool) + ) + + pooled = ( + self.avgpooling(representation) + self.maxpooling(representation) + ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num] + pooled = pooled.flatten(2).transpose([0, 2, 1]) + return pooled # [bs, token_num, embedding_dim] + + def regularization(self, x): + assert x.shape[-1] == 768 + x = nn.functional.normalize(x, p=2, axis=-1) + return x + + # Required + def forward(self, batch, time_pool=None, freq_pool=None): + assert batch.shape[-2] == 1024 and batch.shape[-1] == 128 + + batch = batch.unsqueeze(1) + with paddle.no_grad(): + representation = self.audiomae( + batch, + mask_ratio=self.mask_ratio, + no_mask=self.no_audiomae_mask, + no_average=self.no_audiomae_average, + ) + + representation = self.pool(representation, time_pool, freq_pool) + if self.use_reg: + representation = self.regularization(representation) + return [ + representation, + paddle.ones((representation.shape[0], representation.shape[1]), dtype="float32"), + ] + \ No newline at end of file diff --git a/paddlemix/models/audioldm2/encoders/clap_encoder.py b/paddlemix/models/audioldm2/encoders/clap_encoder.py new file mode 100644 index 000000000..fe0f35660 --- /dev/null +++ b/paddlemix/models/audioldm2/encoders/clap_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import warnings +from paddle.audio.features import MelSpectrogram +from ..clap_module.clap import create_clap_model +from paddlenlp.transformers.roberta.tokenizer import RobertaTokenizer +from typing import Optional + +def get_audio_features( + audio_data, mel, max_len, data_truncating, data_filling, audio_cfg +): + """ + Calculate and add audio features to sample. + Sample: a dict containing all the data of current sample. + audio_data: a tensor of shape (T) containing audio data. + max_len: the maximum length of audio data. + data_truncating: the method of truncating data. + data_filling: the method of filling data. + audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg']. + """ + sample = {} + + # assert audio_data.size(-1) <= max_len, str(audio_data.size()) + + # split to three parts + chunk_frames = ( + max_len // audio_cfg["hop_size"] + 1 + ) # the +1 related to how the spectrogram is computed + mel = mel[:chunk_frames] + + audio_data = audio_data[..., :max_len] + sample["mel_fusion"] = mel + longer = paddle.to_tensor([True], dtype="bool") + + sample["longer"] = longer + sample["waveform"] = audio_data + + return sample + +def _get_sinc_resample_kernel( + orig_freq: int, + new_freq: int, + gcd: int, + lowpass_filter_width: int = 6, + rolloff: float = 0.99, + resampling_method: str = "sinc_interp_hann", + beta: Optional[float] = None, + dtype: Optional[paddle.dtype] = None, +): + if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq): + raise Exception( + "Frequencies must be of integer type to ensure quality resampling computation. " + ) + + if resampling_method in ["sinc_interpolation", "kaiser_window"]: + method_map = { + "sinc_interpolation": "sinc_interp_hann", + "kaiser_window": "sinc_interp_kaiser", + } + warnings.warn( + f'"{resampling_method}" resampling method name is being deprecated and replaced by ' + f'"{method_map[resampling_method]}" in the next release. ' + "The default behavior remains unchanged.", + stacklevel=3, + ) + elif resampling_method not in ["sinc_interp_hann", "sinc_interp_kaiser"]: + raise ValueError("Invalid resampling method: {}".format(resampling_method)) + + orig_freq = int(orig_freq) // gcd + new_freq = int(new_freq) // gcd + + if lowpass_filter_width <= 0: + raise ValueError("Low pass filter width should be positive.") + base_freq = min(orig_freq, new_freq) + # This will perform antialiasing filtering by removing the highest frequencies. + base_freq *= rolloff + + width = math.ceil(lowpass_filter_width * orig_freq / base_freq) + # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e., + # they will have a lot of almost zero values to the left or to the right... + # There is probably a way to evaluate those filters more efficiently, but this is kept for + # future work. + idx_dtype = dtype if dtype is not None else paddle.float64 + + idx = paddle.arange(-width, width + orig_freq, dtype=idx_dtype)[None, None] / orig_freq + + t = paddle.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx + t *= base_freq + t = t.clip_(-lowpass_filter_width, lowpass_filter_width) + + if resampling_method == "sinc_interp_hann": + window = paddle.cos(t * math.pi / lowpass_filter_width / 2) ** 2 + else: + # sinc_interp_kaiser + if beta is None: + beta = 14.769656459379492 + beta_tensor = paddle.to_tensor(float(beta)) + window = paddle.i0(beta_tensor * paddle.sqrt(1 - (t / lowpass_filter_width) ** 2)) / paddle.i0(beta_tensor) + + t *= math.pi + + scale = base_freq / orig_freq + kernels = paddle.where(t == 0, paddle.to_tensor(1.0, dtype=t.dtype), t.sin() / t) + kernels *= window * scale + + if dtype is None: + kernels = paddle.cast(kernels, dtype=paddle.float32) + + return kernels, width + +def _apply_sinc_resample_kernel( + waveform: paddle.Tensor, + orig_freq: int, + new_freq: int, + gcd: int, + kernel: paddle.Tensor, + width: int, +): + if not "float" in str(waveform.dtype): + raise TypeError(f"Expected floating point type for waveform tensor, but received {waveform.dtype}.") + + orig_freq = int(orig_freq) // gcd + new_freq = int(new_freq) // gcd + + # pack batch + shape = waveform.shape + waveform = waveform.reshape([-1, shape[-1]]) + + num_wavs, length = waveform.shape + waveform = nn.functional.pad(waveform.unsqueeze(0), (width, width + orig_freq), data_format='NCL').squeeze(0) + resampled = nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq) + perm_shape = list(range(resampled.dim())) + new_perm_shape = perm_shape + new_perm_shape[1], new_perm_shape[2] = perm_shape[2], perm_shape[1] + resampled = resampled.transpose(new_perm_shape).reshape([num_wavs, -1]) + target_length = paddle.cast(paddle.ceil(paddle.to_tensor(new_freq * length / orig_freq)), dtype="int64") + resampled = resampled[..., :target_length] + + # unpack batch + resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) + return resampled + + +def resample( + waveform: paddle.Tensor, + orig_freq: int, + new_freq: int, + lowpass_filter_width: int = 6, + rolloff: float = 0.99, + resampling_method: str = "sinc_interp_hann", + beta: Optional[float] = None, +) -> paddle.Tensor: + r"""Resamples the waveform at the new frequency using bandlimited interpolation. :cite:`RESAMPLE`. + + Note: + ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in + more efficient computation if resampling multiple waveforms with the same resampling parameters. + + Args: + waveform (Tensor): The input signal of dimension `(..., time)` + orig_freq (int): The original frequency of the signal + new_freq (int): The desired frequency + lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper + but less efficient. (Default: ``6``) + rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. + Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``) + resampling_method (str, optional): The resampling method to use. + Options: [``"sinc_interp_hann"``, ``"sinc_interp_kaiser"``] (Default: ``"sinc_interp_hann"``) + beta (float or None, optional): The shape parameter used for kaiser window. + + Returns: + Tensor: The waveform at the new frequency of dimension `(..., time).` + """ + + if orig_freq <= 0.0 or new_freq <= 0.0: + raise ValueError("Original frequency and desired frequecy should be positive") + + if orig_freq == new_freq: + return waveform + + gcd = math.gcd(int(orig_freq), int(new_freq)) + + kernel, width = _get_sinc_resample_kernel( + orig_freq, + new_freq, + gcd, + lowpass_filter_width, + rolloff, + resampling_method, + beta, + waveform.dtype, + ) + resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width) + return resampled + +class CLAPAudioEmbeddingClassifierFreev2(nn.Layer): + def __init__( + self, + pretrained_path="", + enable_cuda=False, + sampling_rate=16000, + embed_mode="audio", + amodel="HTSAT-base", + unconditional_prob=0.1, + random_mute=False, + max_random_mute_portion=0.5, + training_mode=True, + ): + super().__init__() + self.device = "cpu" # The model itself is on cpu + self.cuda = enable_cuda + self.precision = "fp32" + self.amodel = amodel # or 'PANN-14' + self.tmodel = "roberta" # the best text encoder in our training + self.enable_fusion = False # False if you do not want to use the fusion model + self.fusion_type = "aff_2d" + self.pretrained = pretrained_path + self.embed_mode = embed_mode + self.embed_mode_orig = embed_mode + self.sampling_rate = sampling_rate + self.unconditional_prob = unconditional_prob + self.random_mute = random_mute + self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") + self.max_random_mute_portion = max_random_mute_portion + self.training_mode = training_mode + self.model, self.model_cfg = create_clap_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + self.model = self.model.to(self.device) + audio_cfg = self.model_cfg["audio_cfg"] + self.mel_transform = MelSpectrogram( + sr=audio_cfg["sample_rate"], + n_fft=audio_cfg["window_size"], + hop_length=audio_cfg["hop_size"], + win_length=audio_cfg["window_size"], + power=2.0, + center=True, + pad_mode="reflect", + # onesided=True, + n_mels=64, + f_min=audio_cfg["fmin"], + f_max=audio_cfg["fmax"], + norm=None, + ) + for p in self.model.parameters(): + # p.requires_grad = False + p.stop_gradient = True + self.unconditional_token = None + self.model.eval() + + def get_unconditional_condition(self, batchsize): + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + return paddle.concat([self.unconditional_token.unsqueeze(0)] * batchsize, axis=0) + + def batch_to_list(self, batch): + ret = [] + for i in range(batch.size(0)): + ret.append(batch[i]) + return ret + + def make_decision(self, probability): + if float(paddle.rand([])) < probability: + return True + else: + return False + + def random_uniform(self, start, end): + val = paddle.rand([]).item() + return start + (end - start) * val + + def _random_mute(self, waveform): + # waveform: [bs, t-steps] + t_steps = waveform.shape[-1] + for i in range(waveform.shape[0]): + mute_size = int( + self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) + ) + mute_start = int(self.random_uniform(0, t_steps - mute_size)) + waveform[i, mute_start : mute_start + mute_size] = 0 + return waveform + + def cos_similarity(self, waveform, text): + # waveform: [bs, t_steps] + original_embed_mode = self.embed_mode + with paddle.no_grad(): + self.embed_mode = "audio" + audio_emb = self(waveform) + self.embed_mode = "text" + text_emb = self(text) + similarity = F.cosine_similarity(audio_emb, text_emb, axis=2) + self.embed_mode = original_embed_mode + return similarity.squeeze() + + def build_unconditional_emb(self): + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + + def forward(self, batch): + # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 + # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 + if self.model.training == True and not self.training_mode: + print( + "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." + ) + self.model, self.model_cfg = create_clap_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device="cuda" if self.cuda else "cpu", + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + for p in self.model.parameters(): + # p.requires_grad = False + p.stop_gradient = True + self.model.eval() + + if self.unconditional_token is None: + self.build_unconditional_emb() + + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + if self.embed_mode == "audio": + if not self.training: + print("INFO: clap model calculate the audio embedding as condition") + with paddle.no_grad(): + if self.sampling_rate != 48000: + batch = resample( + batch, orig_freq=self.sampling_rate, new_freq=48000 + ) + audio_data = batch.squeeze(1) + mel = self.mel_transform(audio_data) + audio_dict = get_audio_features( + audio_data, + mel, + 480000, + data_truncating="fusion", + data_filling="repeatpad", + audio_cfg=self.model_cfg["audio_cfg"], + ) + # [bs, 512] + embed = self.model.get_audio_embedding(audio_dict) + elif self.embed_mode == "text": + with paddle.no_grad(): + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + text_data = self.tokenizer(batch) + + if isinstance(batch, str) or ( + isinstance(batch, list) and len(batch) == 1 + ): + for key in text_data.keys(): + text_data[key] = text_data[key].unsqueeze(0) + + embed = self.model.get_text_embedding(text_data) + + embed = embed.unsqueeze(1) + for i in range(embed.shape[0]): + if self.make_decision(self.unconditional_prob): + embed[i] = self.unconditional_token + return embed.detach() + + def tokenizer(self, text): + result = self.tokenize( + text, + padding="max_length", + truncation=True, + max_length=512, + return_tensors="pd", + return_attention_mask=True, + ) + return {k: v.squeeze(0) for k, v in result.items()} diff --git a/paddlemix/models/audioldm2/encoders/flant5_encoder.py b/paddlemix/models/audioldm2/encoders/flant5_encoder.py new file mode 100644 index 000000000..97b77c111 --- /dev/null +++ b/paddlemix/models/audioldm2/encoders/flant5_encoder.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import paddle +import paddle.nn as nn +from paddlenlp.transformers import AutoTokenizer, T5EncoderModel, T5Config + +class FlanT5HiddenState(nn.Layer): + """ + llama = FlanT5HiddenState() + data = ["","this is not an empty sentence"] + encoder_hidden_states = llama(data) + import ipdb;ipdb.set_trace() + """ + + def __init__( + self, text_encoder_name="t5-v1_1-large", freeze_text_encoder=True # t5-v1_1-large -> google/flan-t5-large + ): + super().__init__() + self.freeze_text_encoder = freeze_text_encoder + self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name) + self.model = T5EncoderModel(T5Config.from_pretrained(text_encoder_name)) + if freeze_text_encoder: + self.model.eval() + for p in self.model.parameters(): + p.stop_gradient = True + else: + print("=> The text encoder is learnable") + + self.empty_hidden_state_cfg = None + self.device = None + + # Required + def get_unconditional_condition(self, batchsize): + param = self.model.parameters()[0] + if self.freeze_text_encoder: + assert param.stop_gradient == True + + # device = param.device + if self.empty_hidden_state_cfg is None: + self.empty_hidden_state_cfg, _ = self([""]) + + hidden_state = paddle.cast(paddle.concat([self.empty_hidden_state_cfg] * batchsize), dtype="float32") + attention_mask = ( + paddle.ones((batchsize, hidden_state.shape[1]), dtype="float32") + ) + return [hidden_state, attention_mask] # Need to return float type + + def forward(self, batch): + param = self.model.parameters()[0] + if self.freeze_text_encoder: + assert param.stop_gradient == True + + try: + return self.encode_text(batch) + except Exception as e: + print(e, batch) + logging.exception("An error occurred: %s", str(e)) + + def encode_text(self, prompt): + # device = self.model.device + batch = self.tokenizer( + prompt, + max_length=128, # self.tokenizer.model_max_length + padding=True, + truncation=True, + return_tensors="pd", + ) + input_ids, attention_mask = batch.input_ids, batch.attention_mask + # Get text encoding + if self.freeze_text_encoder: + with paddle.no_grad(): + encoder_hidden_states = self.model( + input_ids=input_ids, attention_mask=attention_mask + )[0] + else: + encoder_hidden_states = self.model( + input_ids=input_ids, attention_mask=attention_mask + )[0] + return [ + encoder_hidden_states.detach(), + paddle.cast(attention_mask, dtype="float32"), + ] # Attention mask == 1 means usable token + \ No newline at end of file diff --git a/paddlemix/models/audioldm2/encoders/phoneme_encoder/__init__.py b/paddlemix/models/audioldm2/encoders/phoneme_encoder/__init__.py new file mode 100644 index 000000000..fd05a9208 --- /dev/null +++ b/paddlemix/models/audioldm2/encoders/phoneme_encoder/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlemix/models/audioldm2/encoders/phoneme_encoder/cleaners.py b/paddlemix/models/audioldm2/encoders/phoneme_encoder/cleaners.py new file mode 100644 index 000000000..5433a9e8b --- /dev/null +++ b/paddlemix/models/audioldm2/encoders/phoneme_encoder/cleaners.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" from https://github.com/keithito/tacotron """ + +import re +from unidecode import unidecode +from phonemizer import phonemize + +__all__ = [ + "basic_cleaners", + "transliteration_cleaners", + "english_cleaners", + "english_cleaners2" +] + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + '''Pipeline for non-English text that transliterates to ASCII.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + '''Pipeline for English text, including abbreviation expansion.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = phonemize(text, language='en-us', backend='espeak', strip=True) + phonemes = collapse_whitespace(phonemes) + return phonemes + + +def english_cleaners2(text): + '''Pipeline for English text, including abbreviation expansion. + punctuation + stress''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = phonemize(text, language='en-us', backend='espeak', strip=True, preserve_punctuation=True, with_stress=True) + phonemes = collapse_whitespace(phonemes) + return phonemes \ No newline at end of file diff --git a/paddlemix/models/audioldm2/encoders/phoneme_encoder/symbols.py b/paddlemix/models/audioldm2/encoders/phoneme_encoder/symbols.py new file mode 100644 index 000000000..fd5abd88c --- /dev/null +++ b/paddlemix/models/audioldm2/encoders/phoneme_encoder/symbols.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' +Defines the set of symbols used in text input to the model. +''' +_pad = '_' +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +# Special symbol ids +SPACE_ID = symbols.index(" ") diff --git a/paddlemix/models/audioldm2/encoders/phoneme_encoder/text.py b/paddlemix/models/audioldm2/encoders/phoneme_encoder/text.py new file mode 100644 index 000000000..efbd661b5 --- /dev/null +++ b/paddlemix/models/audioldm2/encoders/phoneme_encoder/text.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" from https://github.com/keithito/tacotron """ + +from .cleaners import * +from .symbols import symbols + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + +cleaner = english_cleaners2 + +def text_to_sequence(text, cleaner_names): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + Returns: + List of integers corresponding to the symbols in the text + ''' + sequence = [] + + clean_text = _clean_text(text, cleaner_names) + for symbol in clean_text: + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + return sequence + +def cleaned_text_to_sequence(cleaned_text): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + ''' + sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] + return sequence + +def sequence_to_text(sequence): + '''Converts a sequence of IDs back to a string''' + result = '' + for symbol_id in sequence: + s = _id_to_symbol[symbol_id] + result += s + return result + +def _clean_text(text, cleaner_names): + text = cleaner(text) + return text diff --git a/paddlemix/models/audioldm2/encoders/sequence2audiomae_encoder.py b/paddlemix/models/audioldm2/encoders/sequence2audiomae_encoder.py new file mode 100644 index 000000000..4122bbbff --- /dev/null +++ b/paddlemix/models/audioldm2/encoders/sequence2audiomae_encoder.py @@ -0,0 +1,487 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +from paddlenlp.transformers import GPTModel +import importlib + +class Sequence2AudioMAE(nn.Layer): + def __init__( + self, + base_learning_rate, + sequence_gen_length, + sequence_input_key, + sequence_input_embed_dim, + cond_stage_config, + optimizer_type="AdamW", + use_warmup=True, + use_ar_gen_loss=False, + use_audiomae_linear=False, + target_tokens_mask_ratio=0.0, + random_mask_ratio=False, + **kwargs + ): + super().__init__() + assert use_audiomae_linear == False + self.random_mask_ratio = random_mask_ratio + self.learning_rate = base_learning_rate + self.cond_stage_config = cond_stage_config + self.use_audiomae_linear = use_audiomae_linear + self.optimizer_type = optimizer_type + self.use_warmup = use_warmup + self.use_ar_gen_loss = use_ar_gen_loss + # Even though the LDM can be conditioned on mutliple pooling rate + # Our model always predict the higest pooling rate + + self.mae_token_num = sequence_gen_length + self.sequence_input_key = sequence_input_key + self.sequence_input_embed_dim = sequence_input_embed_dim + self.target_tokens_mask_ratio = target_tokens_mask_ratio + + self.start_of_sequence_tokens = nn.Embedding(32, 768) + self.end_of_sequence_tokens = nn.Embedding(32, 768) + + self.input_sequence_embed_linear = nn.LayerList([]) + self.initial_learning_rate = None + + for dim in self.sequence_input_embed_dim: + self.input_sequence_embed_linear.append(nn.Linear(dim, 768)) + + self.cond_stage_models = nn.LayerList([]) + self.instantiate_cond_stage(cond_stage_config) + self.initialize_param_check_toolkit() + + self.model = GPTModel.from_pretrained("gpt2") + + self.loss_fn = nn.L1Loss() + + self.logger_save_dir = None + self.logger_exp_name = None + self.logger_exp_group_name = None + self.logger_version = None + + def set_log_dir(self, save_dir, exp_group_name, exp_name): + self.logger_save_dir = save_dir + self.logger_exp_group_name = exp_group_name + self.logger_exp_name = exp_name + + def cfg_uncond(self, batch_size): + unconditional_conditioning = {} + for key in self.cond_stage_model_metadata: + model_idx = self.cond_stage_model_metadata[key]["model_idx"] + unconditional_conditioning[key] = self.cond_stage_models[ + model_idx + ].get_unconditional_condition(batch_size) + assert ( + "crossattn_audiomae_pooled" in unconditional_conditioning.keys() + ), "The module is not initialized with AudioMAE" + unconditional_conditioning[ + "crossattn_clap_to_audiomae_feature" + ] = unconditional_conditioning["crossattn_audiomae_pooled"] + return unconditional_conditioning + + def add_sos_eos_tokens(self, _id, sequence, attn_mask): + batchsize = sequence.shape[0] + + new_attn_mask_step = paddle.ones((batchsize, 1)) + key_id = paddle.to_tensor([_id]) + + # Add two more steps to attn mask + new_attn_mask = paddle.concat( + [new_attn_mask_step, attn_mask, new_attn_mask_step], axis=1 + ) + + # Add two more tokens in the sequence + sos_token = self.start_of_sequence_tokens(key_id).expand([batchsize, 1, -1]) + eos_token = self.end_of_sequence_tokens(key_id).expand([batchsize, 1, -1]) + new_sequence = paddle.concat([sos_token, sequence, eos_token], axis=1) + return new_sequence, new_attn_mask + + def truncate_sequence_and_mask(self, sequence, mask, max_len=512): + if sequence.shape[1] > max_len: + print( + "The input sequence length to GPT-2 model is too long:", + sequence.shape[1], + ) + return sequence[:, :max_len], mask[:, :max_len] + else: + return sequence, mask + + def get_input_sequence_and_mask(self, cond_dict): + input_embeds = None + input_embeds_attn_mask = None + for _id, sequence_key in enumerate(self.sequence_input_key): + assert sequence_key in cond_dict.keys(), ( + "Invalid sequence key %s" % sequence_key + ) + cond_embed = cond_dict[sequence_key] + if isinstance(cond_embed, list): + assert ( + len(cond_embed) == 2 + ), "The crossattn returned list should have length 2, including embed and attn_mask" + item_input_embeds, item_attn_mask = cond_embed + + item_input_embeds = self.input_sequence_embed_linear[_id]( + item_input_embeds + ) + + item_input_embeds, item_attn_mask = self.add_sos_eos_tokens( + _id, item_input_embeds, item_attn_mask + ) + + if input_embeds is None and input_embeds_attn_mask is None: + input_embeds, input_embeds_attn_mask = ( + item_input_embeds, + item_attn_mask, + ) + else: + input_embeds = paddle.concat( + [input_embeds, item_input_embeds], axis=1 + ) # The 1-st dimension is time steps + input_embeds_attn_mask = paddle.concat( + [input_embeds_attn_mask, item_attn_mask], axis=1 + ) # The 1-st dimension is time steps + else: + assert isinstance(cond_embed, paddle.Tensor) + cond_embed = self.input_sequence_embed_linear[_id](cond_embed) + attn_mask = paddle.ones((cond_embed.shape[0], cond_embed.shape[1])) + + item_input_embeds, item_attn_mask = self.add_sos_eos_tokens( + _id, cond_embed, attn_mask + ) + + if input_embeds is None and input_embeds_attn_mask is None: + input_embeds, input_embeds_attn_mask = ( + item_input_embeds, + item_attn_mask, + ) + else: + input_embeds, input_embeds_attn_mask = paddle.concat( + [input_embeds, item_input_embeds], axis=1 + ), paddle.concat([input_embeds_attn_mask, item_attn_mask], axis=1) + + assert input_embeds is not None and input_embeds_attn_mask is not None + + input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask( + input_embeds, input_embeds_attn_mask, int(1024 - self.mae_token_num) + ) + cond_sequence_end_time_idx = input_embeds.shape[ + 1 + ] # The index that we start to collect the output embeds + + return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx + + def mask_target_sequence(self, target_embeds, target_embeds_attn_mask): + time_seq_mask = None + if self.target_tokens_mask_ratio > 1e-4: + batchsize, time_seq_len, embed_dim = target_embeds.shape + _, time_seq_len = target_embeds_attn_mask.shape + # Generate random mask + if self.random_mask_ratio: + mask_ratio = paddle.rand((1,)).item() * self.target_tokens_mask_ratio + else: + mask_ratio = self.target_tokens_mask_ratio + + time_seq_mask = (paddle.rand((batchsize, time_seq_len)) > mask_ratio) + + # Mask the target embedding + target_embeds = target_embeds * time_seq_mask.unsqueeze(-1) + target_embeds_attn_mask = target_embeds_attn_mask * time_seq_mask + return target_embeds, target_embeds_attn_mask, time_seq_mask + + def generate_partial(self, batch, cond_dict=None, no_grad=False): + if cond_dict is None: + cond_dict = self.get_input(batch) + + print("Generate partially prompted audio with in-context learning") + + target_embeds, target_embeds_attn_mask = ( + cond_dict["crossattn_audiomae_pooled"][0], + cond_dict["crossattn_audiomae_pooled"][1], + ) + + target_time_steps = target_embeds.shape[1] + + ( + input_embeds, + input_embeds_attn_mask, + cond_sequence_end_time_idx, + ) = self.get_input_sequence_and_mask(cond_dict) + + model_input = paddle.concat( + [input_embeds, target_embeds[:, : target_time_steps // 4, :]], axis=1 + ) + model_input_mask = paddle.concat( + [ + input_embeds_attn_mask, + target_embeds_attn_mask[:, : target_time_steps // 4], + ], + axis=1, + ) + + steps = self.mae_token_num + + for _ in range(3 * steps // 4): + output = self.model( + inputs_embeds=model_input, attention_mask=model_input_mask, return_dict=True + )["last_hidden_state"] + # Update the model input + model_input = paddle.concat([model_input, output[:, -1:, :]], axis=1) + # Update the attention mask + attention_mask_new_step = paddle.ones((model_input_mask.shape[0], 1)) + model_input_mask = paddle.concat( + [model_input_mask, attention_mask_new_step], axis=1 + ) + + output = model_input[:, cond_sequence_end_time_idx:] + + return output, cond_dict + + def generate(self, batch, cond_dict=None, no_grad=False): + if cond_dict is None: + cond_dict = self.get_input(batch) + + ( + input_embeds, + input_embeds_attn_mask, + cond_sequence_end_time_idx, + ) = self.get_input_sequence_and_mask(cond_dict) + model_input = input_embeds + model_input_mask = input_embeds_attn_mask + + steps = self.mae_token_num + + for _ in range(steps): + output = self.model( + inputs_embeds=model_input, attention_mask=model_input_mask, return_dict=True + )["last_hidden_state"] + # Update the model input + model_input = paddle.concat([model_input, output[:, -1:, :]], axis=1) + # Update the attention mask + attention_mask_new_step = paddle.ones((model_input_mask.shape[0], 1)) + model_input_mask = paddle.concat( + [model_input_mask, attention_mask_new_step], axis=1 + ) + + return model_input[:, cond_sequence_end_time_idx:], cond_dict + + def get_input_item(self, batch, k): + fname, text, waveform, stft, fbank = ( + batch["fname"], + batch["text"], + batch["waveform"], + batch["stft"], + batch["log_mel_spec"], + ) + ret = {} + + ret["fbank"] = ( + paddle.cast(fbank.unsqueeze(1), dtype="float32") + ) + ret["stft"] = paddle.cast(stft, dtype="float32") + ret["waveform"] = paddle.cast(waveform, dtype="float32") + ret["text"] = list(text) + ret["fname"] = fname + + for key in batch.keys(): + if key not in ret.keys(): + ret[key] = batch[key] + + return ret[k] + + def get_input(self, batch): + cond_dict = {} + if len(self.cond_stage_model_metadata.keys()) > 0: + unconditional_cfg = False + + for cond_model_key in self.cond_stage_model_metadata.keys(): + cond_stage_key = self.cond_stage_model_metadata[cond_model_key][ + "cond_stage_key" + ] + + # The original data for conditioning + xc = self.get_input_item(batch, cond_stage_key) + if type(xc) == paddle.Tensor: + xc = xc + + c = self.get_learned_conditioning( + xc, key=cond_model_key, unconditional_cfg=unconditional_cfg + ) + cond_dict[cond_model_key] = c + + return cond_dict + + def instantiate_cond_stage(self, config): + self.cond_stage_model_metadata = {} + + for i, cond_model_key in enumerate(config.keys()): + model = instantiate_from_config(config[cond_model_key]) + self.cond_stage_models.append(model) + self.cond_stage_model_metadata[cond_model_key] = { + "model_idx": i, + "cond_stage_key": config[cond_model_key]["cond_stage_key"], + "conditioning_key": config[cond_model_key]["conditioning_key"], + } + + def get_learned_conditioning(self, c, key, unconditional_cfg): + assert key in self.cond_stage_model_metadata.keys() + + # Classifier-free guidance + if not unconditional_cfg: + c = self.cond_stage_models[ + self.cond_stage_model_metadata[key]["model_idx"] + ](c) + else: + if isinstance(c, paddle.Tensor): + batchsize = c.shape[0] + elif isinstance(c, list): + batchsize = len(c) + else: + raise NotImplementedError() + c = self.cond_stage_models[ + self.cond_stage_model_metadata[key]["model_idx"] + ].get_unconditional_condition(batchsize) + + return c + + def initialize_param_check_toolkit(self): + self.tracked_steps = 0 + self.param_dict = {} + + def statistic_require_grad_tensor_number(self, module, name=None): + requires_grad_num = 0 + total_num = 0 + require_grad_tensor = None + for p in module.parameters(): + if not p.stop_gradient: + requires_grad_num += 1 + if require_grad_tensor is None: + require_grad_tensor = p + total_num += 1 + print( + "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)" + % (name, requires_grad_num, total_num, requires_grad_num / total_num) + ) + return require_grad_tensor + + +class SequenceGenAudioMAECond(Sequence2AudioMAE): + def __init__( + self, + cond_stage_config, + base_learning_rate, + sequence_gen_length, + sequence_input_key, + sequence_input_embed_dim, + batchsize, + always_output_audiomae_gt=False, + pretrained_path=None, + force_reload_pretrain_avoid_overwrite=False, + learnable=True, + use_warmup=True, + use_gt_mae_output=True, # False: does not use AudioMAE GT, True: Use AudioMAE GT + use_gt_mae_prob=0.0, + ): # The prob of using AudioMAE GT + if use_warmup: + use_warmup = False + + super().__init__( + base_learning_rate=base_learning_rate, + cond_stage_config=cond_stage_config, + sequence_gen_length=sequence_gen_length, + sequence_input_key=sequence_input_key, + use_warmup=use_warmup, + sequence_input_embed_dim=sequence_input_embed_dim, + batchsize=batchsize, + ) + + assert use_gt_mae_output is not None and use_gt_mae_prob is not None + self.always_output_audiomae_gt = always_output_audiomae_gt + self.force_reload_pretrain_avoid_overwrite = ( + force_reload_pretrain_avoid_overwrite + ) + self.pretrained_path = pretrained_path + if self.force_reload_pretrain_avoid_overwrite: + self.is_reload = False + else: + self.is_reload = True + + self.load_pretrain_model() + + self.use_gt_mae_output = use_gt_mae_output + self.use_gt_mae_prob = use_gt_mae_prob + self.learnable = learnable + + if not learnable: + # Only optimize the GPT2 model + for p in self.model.parameters(): + p.stop_gradient = True + self.eval() + + def load_pretrain_model(self): + if self.pretrained_path is not None: + print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path) + state_dict = paddle.load(self.pretrained_path)["state_dict"] + self.load_dict(state_dict) + + # Required + def get_unconditional_condition(self, batchsize): + return_dict = self.cfg_uncond(batchsize) + return_dict["crossattn_audiomae_generated"] = [ + return_dict["crossattn_audiomae_pooled"][0], + paddle.ones_like(return_dict["crossattn_audiomae_pooled"][1], dtype="float32"), + ] + return return_dict + + def forward(self, batch): + # The conditional module can return both tensor or dictionaries + # The returned tensor will be corresponding to the cond_stage_key + # The returned dict will have keys that correspond to the cond_stage_key + ret_dict = {} + + if self.force_reload_pretrain_avoid_overwrite and not self.is_reload: + self.load_pretrain_model() + self.is_reload = True + + input_embeds, cond_dict = self.generate(batch) + input_embeds_mask = ( + paddle.ones((input_embeds.shape[0], input_embeds.shape[1]), dtype="float32") + ) + ret_dict["crossattn_audiomae_generated"] = [ + input_embeds, + input_embeds_mask, + ] # Input sequence and mask + + # If the following two keys are not in cond_stage_key, then they will not be used as condition + for key in cond_dict.keys(): + ret_dict[key] = cond_dict[key] + + return ret_dict + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package="paddlemix.models.audioldm2"), cls) diff --git a/paddlemix/models/audioldm2/hifigan/model.py b/paddlemix/models/audioldm2/hifigan/model.py new file mode 100644 index 000000000..d0df98101 --- /dev/null +++ b/paddlemix/models/audioldm2/hifigan/model.py @@ -0,0 +1,333 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.nn.utils import weight_norm, remove_weight_norm +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn import Conv1D, Conv1DTranspose + +LRELU_SLOPE = 0.1 + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def get_vocoder_config(): + return { + "resblock": "1", + "num_gpus": 6, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "segment_size": 8192, + "num_mels": 64, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 160, + "win_size": 1024, + "sampling_rate": 16000, + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, + } + + +def get_vocoder_config_48k(): + return { + "resblock": "1", + "num_gpus": 8, + "batch_size": 128, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + + "upsample_rates": [6,5,4,2,2], + "upsample_kernel_sizes": [12,10,8,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11,15], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5], [1,3,5]], + + "segment_size": 15360, + "num_mels": 256, + "n_fft": 2048, + "hop_size": 480, + "win_size": 2048, + + "sampling_rate": 48000, + + "fmin": 20, + "fmax": 24000, + "fmax_for_loss": None, + + "num_workers": 8, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:18273", + "world_size": 1 + } + } + + +class ResBlock(nn.Layer): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.h = h + weight_attr1 = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=0.01) + ) + weight_attr2 = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=0.01) + ) + weight_attr3 = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=0.01) + ) + self.convs1 = nn.LayerList( + [ + weight_norm( + Conv1D( + channels, + channels, + kernel_size, + 1, + padding=get_padding(kernel_size, dilation[0]), + dilation=dilation[0], + weight_attr=weight_attr1, + ) + ), + weight_norm( + Conv1D( + channels, + channels, + kernel_size, + 1, + padding=get_padding(kernel_size, dilation[1]), + dilation=dilation[1], + weight_attr=weight_attr2, + ) + ), + weight_norm( + Conv1D( + channels, + channels, + kernel_size, + 1, + padding=get_padding(kernel_size, dilation[2]), + dilation=dilation[2], + weight_attr=weight_attr3, + ) + ), + ] + ) + + weight_attr4 = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=0.01) + ) + weight_attr5 = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=0.01) + ) + weight_attr6 = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=0.01) + ) + self.convs2 = nn.LayerList( + [ + weight_norm( + Conv1D( + channels, + channels, + kernel_size, + 1, + padding=get_padding(kernel_size, 1), + dilation=1, + weight_attr=weight_attr4, + ) + ), + weight_norm( + Conv1D( + channels, + channels, + kernel_size, + 1, + padding=get_padding(kernel_size, 1), + dilation=1, + weight_attr=weight_attr5, + ) + ), + weight_norm( + Conv1D( + channels, + channels, + kernel_size, + 1, + padding=get_padding(kernel_size, 1), + dilation=1, + weight_attr=weight_attr6, + ) + ), + ] + ) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class Generator(nn.Layer): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1D(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock + + self.ups = nn.LayerList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + weight_attr_tmp = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=0.01) + ) + self.ups.append( + weight_norm( + Conv1DTranspose( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + weight_attr=weight_attr_tmp, + ) + ) + ) + + self.resblocks = nn.LayerList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=0.01) + ) + self.conv_post = weight_norm(Conv1D(ch, 1, 7, 1, padding=3, weight_attr=weight_attr)) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = paddle.tanh(x) + + return x + + def remove_weight_norm(self): + # print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +def get_vocoder(config, mel_bins): + if(mel_bins == 64): + config = get_vocoder_config() + config = AttrDict(config) + vocoder = Generator(config) + vocoder.eval() + vocoder.remove_weight_norm() + else: + config = get_vocoder_config_48k() + config = AttrDict(config) + vocoder = Generator(config) + vocoder.eval() + vocoder.remove_weight_norm() + + return vocoder + + +def vocoder_infer(mels, vocoder, lengths=None): + with paddle.no_grad(): + wavs = vocoder(mels).squeeze(1) + + wavs = (wavs.numpy() * 32768).astype("int16") + + if lengths is not None: + wavs = wavs[:, :lengths] + + return wavs + + +def synth_one_sample(mel_input, mel_prediction, labels, vocoder): + if vocoder is not None: + + wav_reconstruction = vocoder_infer( + mel_input.transpose([0, 2, 1]), + vocoder, + ) + wav_prediction = vocoder_infer( + mel_prediction.transpose([0, 2, 1]), + vocoder, + ) + else: + wav_reconstruction = wav_prediction = None + + return wav_reconstruction, wav_prediction diff --git a/paddlemix/models/audioldm2/latent_encoder/autoencoder.py b/paddlemix/models/audioldm2/latent_encoder/autoencoder.py new file mode 100644 index 000000000..31aac6f2d --- /dev/null +++ b/paddlemix/models/audioldm2/latent_encoder/autoencoder.py @@ -0,0 +1,140 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np + +from ppdiffusers import AutoencoderKL +from ..hifigan.model import get_vocoder + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = paddle.chunk(parameters, 2, axis=1) + self.logvar = paddle.clip(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = paddle.exp(0.5 * self.logvar) + self.var = paddle.exp(self.logvar) + if self.deterministic: + self.var = self.std = paddle.zeros_like(self.mean) + + def sample(self): + x = self.mean + self.std * paddle.randn(self.mean.shape) + return x + + def kl(self, other=None): + if self.deterministic: + return paddle.to_tensor([0.0]) + else: + if other is None: + return 0.5 * paddle.mean( + paddle.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * paddle.mean( + paddle.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return paddle.to_tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * paddle.sum( + logtwopi + self.logvar + paddle.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +class AudioLDMAutoencoderKL(AutoencoderKL): + def __init__( + self, + ddconfig=None, + lossconfig=None, + batchsize=None, + embed_dim=None, + time_shuffle=1, + subband=1, + sampling_rate=16000, + reload_from_ckpt=None, + ignore_keys=[], + image_key="fbank", + colorize_nlabels=None, + monitor=None, + base_learning_rate=1e-5, + ): + super().__init__( + in_channels = ddconfig["in_channels"], + out_channels = ddconfig["out_ch"], + down_block_types = ("DownEncoderBlock2D",) * len(ddconfig["ch_mult"]), + up_block_types = ("UpDecoderBlock2D",) * len(ddconfig["ch_mult"]), + block_out_channels = tuple([ddconfig["ch"]*i for i in ddconfig["ch_mult"]]), + layers_per_block = ddconfig["num_res_blocks"], + latent_channels = ddconfig["z_channels"], + ) + self.automatic_optimization = False + assert ( + "mel_bins" in ddconfig.keys() + ), "mel_bins is not specified in the Autoencoder config" + num_mel = ddconfig["mel_bins"] + self.image_key = image_key + self.sampling_rate = sampling_rate + + self.loss = None + self.subband = int(subband) + + if self.subband > 1: + print("Use subband decomposition %s" % self.subband) + + if self.image_key == "fbank": + self.vocoder = get_vocoder(None, num_mel) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", paddle.randn([3, colorize_nlabels, 1, 1])) + if monitor is not None: + self.monitor = monitor + self.learning_rate = float(base_learning_rate) + # print("Initial learning rate %s" % self.learning_rate) + + self.time_shuffle = time_shuffle + self.reload_from_ckpt = reload_from_ckpt + self.reloaded = False + self.mean, self.std = None, None + + self.feature_cache = None + self.flag_first_run = True + self.train_step = 0 + + self.logger_save_dir = None + self.logger_exp_name = None + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec diff --git a/paddlemix/models/audioldm2/latentdiffusion_samplers.py b/paddlemix/models/audioldm2/latentdiffusion_samplers.py new file mode 100644 index 000000000..561e3dcab --- /dev/null +++ b/paddlemix/models/audioldm2/latentdiffusion_samplers.py @@ -0,0 +1,870 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import numpy as np +from tqdm import tqdm + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(t, -1) + # return out.reshape(b, *((1,) * (len(x_shape) - 1))) + return out.reshape((b,) + ((1,) * (len(x_shape) - 1))) + +def noise_like(shape, repeat=False): + repeat_noise = lambda: paddle.randn((1, *shape[1:])).repeat_interleave(repeats=shape[0], axis=0) + noise = lambda: paddle.randn(shape) + return repeat_noise() if repeat else noise() + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", device="cpu", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.device = device + + def register_buffer(self, name, attr): + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_paddle = lambda x: paddle.cast(x.clone().detach(), dtype="float32") if isinstance(x, paddle.Tensor) else paddle.to_tensor(x, dtype="float32") + + self.register_buffer("betas", to_paddle(self.model.betas)) + self.register_buffer("alphas_cumprod", to_paddle(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_paddle(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_paddle(np.sqrt(alphas_cumprod.numpy())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_paddle(np.sqrt(1.0 - alphas_cumprod.numpy())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_paddle(np.log(1.0 - alphas_cumprod.numpy())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_paddle(np.sqrt(1.0 / alphas_cumprod.numpy())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_paddle(np.sqrt(1.0 / alphas_cumprod.numpy() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.numpy(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * paddle.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @paddle.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs, + ): + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule, + ) + return samples, intermediates + + @paddle.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ucg_schedule=None, + ): + + b = shape[0] + if x_T is None: + img = paddle.randn(shape) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = paddle.full((b,), step, dtype="int64") + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? + img = img_orig * mask + (1.0 - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @paddle.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): + b, *_ = x.shape + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + model_output = self.model.apply_model(x, t, c) + else: + x_in = x + t_in = t + + assert isinstance(c, dict) + assert isinstance(unconditional_conditioning, dict) + + model_uncond = self.model.apply_model( + x_in, t_in, unconditional_conditioning + ) + model_t = self.model.apply_model(x_in, t_in, c) + + model_output = model_uncond + unconditional_guidance_scale * ( + model_t - model_uncond + ) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", "not implemented" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = paddle.full((b, 1, 1, 1), alphas[index]) + a_prev = paddle.full((b, 1, 1, 1), alphas_prev[index]) + sigma_t = paddle.full((b, 1, 1, 1), sigmas[index]) + sqrt_one_minus_at = paddle.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index] + ) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @paddle.no_grad() + def encode( + self, + x0, + c, + t_enc, + use_original_steps=False, + return_intermediates=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + callback=None, + ): + num_reference_steps = ( + self.ddpm_num_timesteps + if use_original_steps + else self.ddim_timesteps.shape[0] + ) + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = paddle.to_tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc="Encoding Image"): + t = paddle.full( + (x0.shape[0],), i, dtype="int64" + ) + if unconditional_guidance_scale == 1.0: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = paddle.chunk( + self.model.apply_model( + paddle.concat((x_next, x_next)), + paddle.concat((t, t)), + paddle.concat((unconditional_conditioning, c)), + ), + 2, + ) + noise_pred = e_t_uncond + unconditional_guidance_scale * ( + noise_pred - e_t_uncond + ) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = ( + alphas_next[i].sqrt() + * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) + * noise_pred + ) + x_next = xt_weighted + weighted_noise_pred + if ( + return_intermediates + and i % (num_steps // return_intermediates) == 0 + and i < num_steps - 1 + ): + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: + callback(i) + + out = {"x_encoded": x_next, "intermediate_steps": inter_steps} + if return_intermediates: + out.update({"intermediates": intermediates}) + return x_next, out + + @paddle.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = paddle.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = paddle.randn(x0.shape) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @paddle.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + callback=None, + ): + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = paddle.full( + (x_latent.shape[0],), step, dtype="int64" + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + if callback: + callback(i) + return x_dec + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + if ddim_eta != 0: + ddim_eta = 0 + # raise ValueError('ddim_eta must be 0 for PLMS') + + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_paddle = lambda x: paddle.cast(x.clone().detach(), dtype="float32") + + self.register_buffer("betas", to_paddle(self.model.betas)) + self.register_buffer("alphas_cumprod", to_paddle(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_paddle(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_paddle(np.sqrt(alphas_cumprod.numpy())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_paddle(np.sqrt(1.0 - alphas_cumprod.numpy())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_paddle(np.log(1.0 - alphas_cumprod.numpy())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_paddle(np.sqrt(1.0 / alphas_cumprod.numpy())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_paddle(np.sqrt(1.0 / alphas_cumprod.numpy() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.numpy(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * paddle.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @paddle.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f"Data shape for PLMS sampling is {size}") + + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @paddle.no_grad() + def plms_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + + b = shape[0] + if x_T is None: + img = paddle.randn(shape) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + list(reversed(range(0, timesteps))) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = paddle.full((b,), step, dtype="int64") + ts_next = paddle.full( + (b,), + time_range[min(i + 1, len(time_range) - 1)], + dtype="int64", + ) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? + img = img_orig * mask + (1.0 - mask) * img + + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + ) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @paddle.no_grad() + def p_sample_plms( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + ): + b, *_ = x.shape + + def get_model_output(x, t): + if ( + unconditional_conditioning is None + or unconditional_guidance_scale == 1.0 + ): + e_t = self.model.apply_model(x, t, c) + else: + x_in = paddle.concat([x] * 2) + t_in = paddle.concat([t] * 2) + c_in = paddle.concat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = paddle.full((b, 1, 1, 1), alphas[index]) + a_prev = paddle.full((b, 1, 1, 1), alphas_prev[index]) + sigma_t = paddle.full((b, 1, 1, 1), sigmas[index]) + sqrt_one_minus_at = paddle.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index] + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = ( + 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] + ) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/paddlemix/models/audioldm2/modeling.py b/paddlemix/models/audioldm2/modeling.py new file mode 100644 index 000000000..4e52ac8d6 --- /dev/null +++ b/paddlemix/models/audioldm2/modeling.py @@ -0,0 +1,898 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import os +import numpy as np +from paddlemix.models.model_utils import MixPretrainedModel +# from ppdiffusers.models import LitEma +import soundfile as sf +import tqdm +from .encoders.clap_encoder import CLAPAudioEmbeddingClassifierFreev2 +from .latentdiffusion_samplers import DDIMSampler, PLMSSampler +from .latent_encoder.autoencoder import DiagonalGaussianDistribution +from .diffusionwrapper import ( + DiffusionWrapper, + make_beta_schedule, + extract_into_tensor, + noise_like, + default, + instantiate_from_config, + disabled_train +) +from .configuration import AudioLDM2Config + +__all__ = [ + "AudioLDM2Model", + "AudioLDM2PretrainedModel", +] + +class AudioLDM2PretrainedModel(MixPretrainedModel): + """ + The class for pretrained model of AudioLDM2. + """ + + model_config_file = "config.json" + config_class = AudioLDM2Config + resource_files_names = {"model_state": "model_state.pdparams"} + base_model_prefix = "audioldm2" + +class AudioLDM2Model(AudioLDM2PretrainedModel): + """ + Args: + config (:class:`AudioLDM2Config`): + """ + + def __init__(self, config: AudioLDM2Config): + super(AudioLDM2Model, self).__init__(config) + assert config.parameterization in [ + "eps", + "x0", + "v", + ], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = config.parameterization + self.device_name = config.device + self.clip_denoised = False + self.log_every_t = config.log_every_t + self.first_stage_key = config.first_stage_key + self.sampling_rate = config.sampling_rate + # self.use_ema = True + # if self.use_ema: + # self.model_ema = LitEma(self.model) + + self.clap = CLAPAudioEmbeddingClassifierFreev2( + pretrained_path="", + enable_cuda=self.device_name=="gpu", + sampling_rate=self.sampling_rate, + embed_mode="audio", + amodel="HTSAT-base", + ) + self.latent_t_size = config.latent_t_size + self.latent_f_size = config.latent_f_size + self.channels = config.channels + self.use_positional_encodings = False + self.conditioning_key = list(config.cond_stage_config.keys()) + self.model = DiffusionWrapper(config.unet_config, self.conditioning_key) + + self.v_posterior = 0.0 # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + + self.num_timesteps_cond = default(config.num_timesteps_cond, 1) + assert self.num_timesteps_cond <= config.timesteps + self.register_schedule( + beta_schedule="linear", + timesteps=config.timesteps, + linear_start=config.linear_start, + linear_end=config.linear_end, + cosine_s=8e-3, + ) + logvar_init = 0.0 + self.logvar = paddle.full(shape=(self.num_timesteps,), fill_value=logvar_init) + self.logvar = paddle.create_parameter( + shape=self.logvar.shape, + dtype=str(self.logvar.numpy().dtype), + default_initializer=nn.initializer.Assign(self.logvar) + ) + self.logvar.stop_gradient = True + + self.register_buffer("scale_factor", paddle.to_tensor(1.0)) + self.instantiate_first_stage(config.first_stage_config) + self.unconditional_prob_cfg = config.unconditional_prob_cfg + self.cond_stage_models = nn.LayerList([]) + self.instantiate_cond_stage(config.cond_stage_config) + self.conditional_dry_run_finished = False + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.stop_gradient = True + + def instantiate_cond_stage(self, config): + self.cond_stage_model_metadata = {} + for i, cond_model_key in enumerate(config.keys()): + if "params" in config[cond_model_key] and "device" in config[cond_model_key]["params"]: + config[cond_model_key]["params"]["device"] = self.device_name + model = instantiate_from_config(config[cond_model_key]) + model = model.to(self.device_name) + self.cond_stage_models.append(model) + self.cond_stage_model_metadata[cond_model_key] = { + "model_idx": i, + "cond_stage_key": config[cond_model_key]["cond_stage_key"], + "conditioning_key": config[cond_model_key]["conditioning_key"], + } + + def make_cond_schedule( + self, + ): + self.cond_ids = paddle.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype="int64", + ) + ids = paddle.cast( + paddle.round( + paddle.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ), + dtype="int64" + ) + self.cond_ids[: self.num_timesteps_cond] = ids + + + def register_schedule( + self, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + self.register_buffer("betas", paddle.to_tensor(betas, dtype="float32")) + self.register_buffer("alphas_cumprod", paddle.to_tensor(alphas_cumprod, dtype="float32")) + self.register_buffer("alphas_cumprod_prev", paddle.to_tensor(alphas_cumprod_prev, dtype="float32")) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", paddle.to_tensor(np.sqrt(alphas_cumprod), dtype="float32")) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", paddle.to_tensor(np.sqrt(1.0 - alphas_cumprod), dtype="float32") + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", paddle.to_tensor(np.log(1.0 - alphas_cumprod), dtype="float32") + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", paddle.to_tensor(np.sqrt(1.0 / alphas_cumprod), dtype="float32") + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", paddle.to_tensor(np.sqrt(1.0 / alphas_cumprod - 1), dtype="float32") + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", paddle.to_tensor(posterior_variance, dtype="float32")) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + paddle.to_tensor(np.log(np.maximum(posterior_variance, 1e-20)), dtype="float32"), + ) + self.register_buffer( + "posterior_mean_coef1", + paddle.to_tensor(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod), dtype="float32"), + ) + self.register_buffer( + "posterior_mean_coef2", + paddle.to_tensor( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod), + dtype="float32" + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * paddle.to_tensor(alphas, dtype="float32") + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(paddle.to_tensor(alphas_cumprod, dtype="float32")) + / (2.0 * 1 - paddle.to_tensor(alphas_cumprod, dtype="float32")) + ) + elif self.parameterization == "v": + lvlb_weights = paddle.ones_like( + self.betas**2 + / ( + 2 + * self.posterior_variance + * paddle.to_tensor(alphas, dtype="float32") + * (1 - self.alphas_cumprod) + ) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistable=False) + assert not paddle.isnan(self.lvlb_weights).all() + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def make_decision(self, probability): + if float(paddle.rand([])) < probability: + return True + else: + return False + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score( + self, model_out, x, t, c, **corrector_kwargs + ) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clip_(-1.0, 1.0) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @paddle.no_grad() + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): + b, *_ = x.shape + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = ( + (1 - paddle.cast(t == 0, "float32")).reshape((b, *((1,) * (len(x.shape) - 1)))) + ) + + if return_x0: + return ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @paddle.no_grad() + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + b = shape[0] + if x_T is None: + img = paddle.randn(shape) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = paddle.full((b,), i, dtype="int64") + + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts] + cond = self.q_sample(x_start=cond, t=tc, noise=paddle.randn(cond.shapes)) + + img = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + ) + + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @paddle.no_grad() + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): + if shape is None: + shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + **kwargs, + ) + + @paddle.no_grad() + def sample_log( + self, + cond, + batch_size, + ddim, + ddim_steps, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_plms=False, + mask=None, + **kwargs, + ): + if mask is not None: + shape = (self.channels, mask.shape[-2], mask.shape[-1]) + else: + shape = (self.channels, self.latent_t_size, self.latent_f_size) + + intermediate = None + if ddim and not use_plms: + ddim_sampler = DDIMSampler(self, device=self.device) + samples, intermediates = ddim_sampler.sample( + ddim_steps, + batch_size, + shape, + cond, + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask, + **kwargs, + ) + elif use_plms: + plms_sampler = PLMSSampler(self) + samples, intermediates = plms_sampler.sample( + ddim_steps, + batch_size, + shape, + cond, + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + mask=mask, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + + else: + samples, intermediates = self.sample( + cond=cond, + batch_size=batch_size, + return_intermediates=True, + unconditional_guidance_scale=unconditional_guidance_scale, + mask=mask, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + + return samples, intermediate + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: paddle.randn(x_start.shape)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + * x_t + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def _get_input(self, batch, k): + fname, text, waveform, stft, fbank, phoneme_idx = ( + batch["fname"], + batch["text"], + batch["waveform"], + batch["stft"], + batch["log_mel_spec"], + batch["phoneme_idx"] + ) + ret = {} + + ret["fbank"] = ( + paddle.cast(fbank.unsqueeze(1), dtype="float32") + ) + ret["stft"] = paddle.cast(stft, dtype="float32") + ret["waveform"] = paddle.cast(waveform, dtype="float32") + ret["phoneme_idx"] = paddle.cast(phoneme_idx, dtype="int64") + ret["text"] = list(text) + ret["fname"] = fname + + for key in batch.keys(): + if key not in ret.keys(): + ret[key] = batch[key] + + return ret[k] + + def get_first_stage_encoding(self, encoder_posterior): + z = encoder_posterior.sample() + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, paddle.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def get_learned_conditioning(self, c, key, unconditional_cfg): + assert key in self.cond_stage_model_metadata.keys() + + # Classifier-free guidance + if not unconditional_cfg: + c = self.cond_stage_models[ + self.cond_stage_model_metadata[key]["model_idx"] + ](c) + else: + # when the cond_stage_key is "all", pick one random element out + if isinstance(c, dict): + c = c[list(c.keys())[0]] + + if isinstance(c, paddle.Tensor): + batchsize = c.shape[0] + elif isinstance(c, list): + batchsize = len(c) + else: + raise NotImplementedError() + + c = self.cond_stage_models[ + self.cond_stage_model_metadata[key]["model_idx"] + ].get_unconditional_condition(batchsize) + + return c + + def get_input( + self, + batch, + k, + return_first_stage_encode=True, + return_decoding_output=False, + return_encoder_input=False, + return_encoder_output=False, + unconditional_prob_cfg=0.1, + ): + x = self._get_input(batch, k) + + if return_first_stage_encode: + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + else: + z = None + cond_dict = {} + if len(self.cond_stage_model_metadata.keys()) > 0: + unconditional_cfg = False + if self.conditional_dry_run_finished and self.make_decision( + unconditional_prob_cfg + ): + unconditional_cfg = True + for cond_model_key in self.cond_stage_model_metadata.keys(): + cond_stage_key = self.cond_stage_model_metadata[cond_model_key][ + "cond_stage_key" + ] + + if cond_model_key in cond_dict.keys(): + continue + + # The original data for conditioning + # If cond_model_key is "all", that means the conditional model need all the information from a batch + if cond_stage_key != "all": + xc = self._get_input(batch, cond_stage_key) + else: + xc = batch + # if cond_stage_key is "all", xc will be a dictionary containing all keys + # Otherwise xc will be an entry of the dictionary + c = self.get_learned_conditioning( + xc, key=cond_model_key, unconditional_cfg=unconditional_cfg + ) + # cond_dict will be used to condition the diffusion model + # If one conditional model return multiple conditioning signal + if isinstance(c, dict): + for k in c.keys(): + cond_dict[k] = c[k] + else: + cond_dict[cond_model_key] = c + + out = [z, cond_dict] + + if return_decoding_output: + xrec = self.decode_first_stage(z) + out += [xrec] + + if return_encoder_input: + out += [x] + + if return_encoder_output: + out += [encoder_posterior] + + if not self.conditional_dry_run_finished: + self.conditional_dry_run_finished = True + + # Output is a dictionary, where the value could only be tensor or tuple + return out + + def encode_first_stage(self, x): + with paddle.no_grad(): + return self.first_stage_model.encode(x) + + def decode_first_stage(self, z): + with paddle.no_grad(): + z = 1.0 / self.scale_factor * z + decoding = self.first_stage_model.decode(z) + return decoding + + def mel_spectrogram_to_waveform( + self, mel, savepath=".", bs=None, name="outwav", save=True + ): + # Mel: [bs, 1, t-steps, fbins] + if len(mel.shape) == 4: + mel = mel.squeeze(1) + mel = mel.transpose([0, 2, 1]) + waveform = self.first_stage_model.vocoder(mel) + waveform = waveform.cpu().detach().numpy() + if save: + self.save_waveform(waveform, savepath, name) + return waveform + + def save_waveform(self, waveform, savepath, name="outwav"): + for i in range(waveform.shape[0]): + if type(name) is str: + path = os.path.join( + savepath, "%s_%s_%s.wav" % (self.global_step, i, name) + ) + elif type(name) is list: + path = os.path.join( + savepath, + "%s.wav" + % ( + os.path.basename(name[i]) + if (not ".wav" in name[i]) + else os.path.basename(name[i]).split(".")[0] + ), + ) + else: + raise NotImplementedError + todo_waveform = waveform[i, 0] + todo_waveform = ( + todo_waveform / np.max(np.abs(todo_waveform)) + ) * 0.8 # Normalize the energy of the generation output + sf.write(path, todo_waveform, samplerate=self.sampling_rate) + + def filter_useful_cond_dict(self, cond_dict): + new_cond_dict = {} + for key in cond_dict.keys(): + if key in self.cond_stage_model_metadata.keys(): + new_cond_dict[key] = cond_dict[key] + + # All the conditional key in the metadata should be used + for key in self.cond_stage_model_metadata.keys(): + assert key in new_cond_dict.keys(), "%s, %s" % ( + key, + str(new_cond_dict.keys()), + ) + + return new_cond_dict + + def reorder_cond_dict(self, cond_dict): + # To make sure the order is correct + new_cond_dict = {} + for key in self.conditioning_key: + new_cond_dict[key] = cond_dict[key] + return new_cond_dict + + def apply_model(self, x_noisy, t, cond, return_ids=False): + cond = self.reorder_cond_dict(cond) + + x_recon = self.model(x_noisy, t, cond_dict=cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def forward( + self, + batch, + ddim_steps=200, + ddim_eta=1.0, + x_T=None, + n_gen=1, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_plms=False, + **kwargs, + ): + # Generate n_gen times and select the best + # Batch: audio, text, fnames + assert x_T is None + + if use_plms: + assert ddim_steps is not None + + use_ddim = ddim_steps is not None + + # with self.ema_scope("Plotting"): + for i in range(1): + z, c = self.get_input( + batch, + self.first_stage_key, + unconditional_prob_cfg=0.0, # Do not output unconditional information in the c + ) + + c = self.filter_useful_cond_dict(c) + + text = self._get_input(batch, "text") + + # Generate multiple samples + batch_size = z.shape[0] * n_gen + + # Generate multiple samples at a time and filter out the best + # The condition to the diffusion wrapper can have many format + for cond_key in c.keys(): + if isinstance(c[cond_key], list): + for i in range(len(c[cond_key])): + c[cond_key][i] = paddle.concat([c[cond_key][i]] * n_gen, axis=0) + elif isinstance(c[cond_key], dict): + for k in c[cond_key].keys(): + c[cond_key][k] = paddle.concat([c[cond_key][k]] * n_gen, axis=0) + else: + c[cond_key] = paddle.concat([c[cond_key]] * n_gen, axis=0) + + text = text * n_gen + + if unconditional_guidance_scale != 1.0: + unconditional_conditioning = {} + for key in self.cond_stage_model_metadata: + model_idx = self.cond_stage_model_metadata[key]["model_idx"] + unconditional_conditioning[key] = self.cond_stage_models[ + model_idx + ].get_unconditional_condition(batch_size) + + fnames = list(self._get_input(batch, "fname")) + samples, _ = self.sample_log( + cond=c, + batch_size=batch_size, + x_T=x_T, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + use_plms=use_plms, + ) + + mel = self.decode_first_stage(samples) + + waveform = self.mel_spectrogram_to_waveform( + mel, savepath="", bs=None, name=fnames, save=False + ) + + if n_gen > 1: + best_index = [] + similarity = self.clap.cos_similarity( + paddle.to_tensor(waveform, dtype="float32").squeeze(1), text + ) + for i in range(z.shape[0]): + candidates = similarity[i :: z.shape[0]] + max_index = paddle.argmax(candidates).item() + best_index.append(i + max_index * z.shape[0]) + + waveform = waveform[best_index] + + print("Similarity between generated audio and text:") + print(' '.join('{:.2f}'.format(num) for num in similarity.detach().numpy().tolist())) + print("Choose the following indexes as the output:", best_index) + + return waveform diff --git a/paddlemix/models/audioldm2/requirement.txt b/paddlemix/models/audioldm2/requirement.txt new file mode 100644 index 000000000..4ee0937a9 --- /dev/null +++ b/paddlemix/models/audioldm2/requirement.txt @@ -0,0 +1,4 @@ +librosa +unidecode +phonemizer +espeak \ No newline at end of file diff --git a/paddlemix/models/audioldm2/unet/attention.py b/paddlemix/models/audioldm2/unet/attention.py new file mode 100644 index 000000000..5a6aaa3eb --- /dev/null +++ b/paddlemix/models/audioldm2/unet/attention.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import nn +from ppdiffusers.models.attention import GEGLU +from einops import rearrange, repeat +from ..diffusionwrapper import default + +def Normalize(in_channels): + return nn.GroupNorm( + num_groups=32, num_channels=in_channels, epsilon=1e-6 + ) + +class FeedForward(nn.Layer): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class CrossAttention(nn.Layer): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias_attr=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias_attr=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + sim = paddle.einsum("b i d, b j d -> b i j", q, k) * self.scale + + if mask is not None: + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -paddle.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + tmp = paddle.full(sim.shape, max_neg_value, sim.dtype) + sim = paddle.where(~(mask == 1), tmp, sim) + + # attention, what we cannot get enough of + attn = nn.functional.softmax(sim, axis=-1) + out = paddle.einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + return self.to_out(out) + + +class LinearAttention(nn.Layer): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2D(dim, hidden_dim * 3, 1, bias_attr=False) + self.to_out = nn.Conv2D(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = nn.functional.softmax(k, axis=-1) + context = paddle.einsum("bhdn,bhen->bhde", k, v) + out = paddle.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + +class BasicTransformerBlock(nn.Layer): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None, mask=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context, mask=mask) + x + x = self.ff(self.norm3(x)) + x + return x + +class SpatialTransformer(nn.Layer): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + ): + super().__init__() + + context_dim = context_dim + + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2D( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.LayerList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(value=0.0) + ) + self.proj_out = nn.Conv2D(inner_dim, in_channels, kernel_size=1, stride=1, padding=0, weight_attr=weight_attr) + + def forward(self, x, context=None, mask=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + for block in self.transformer_blocks: + x = block(x, context=context, mask=mask) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/paddlemix/models/audioldm2/unet/openaimodel.py b/paddlemix/models/audioldm2/unet/openaimodel.py new file mode 100644 index 000000000..e40d5a6c0 --- /dev/null +++ b/paddlemix/models/audioldm2/unet/openaimodel.py @@ -0,0 +1,868 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import math +import numpy as np +from abc import abstractmethod +from .attention import SpatialTransformer +from einops import repeat + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1D(*args, **kwargs) + elif dims == 2: + return nn.Conv2D(*args, **kwargs) + elif dims == 3: + return nn.Conv3D(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1D(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2D(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3D(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = paddle.exp( + -math.log(max_period) + * paddle.arange(start=0, end=half, dtype="float32") + / half + ) + args = paddle.cast(timesteps[:, None], dtype="float32") * freqs[None] + embedding = paddle.concat([paddle.cos(args), paddle.sin(args)], axis=-1) + if dim % 2: + embedding = paddle.concat( + [embedding, paddle.zeros_like(embedding[:, :1])], axis=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return paddle.cast(super().forward(paddle.cast(x, dtype="float32")), dtype = x.dtype) + + +class TimestepBlock(nn.Layer): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + +class Upsample(nn.Layer): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Layer): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.Silu(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.Silu(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(value=0.0) + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.Silu(), + nn.Dropout(p=dropout), + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, weight_attr=weight_attr) + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = paddle.cast(self.emb_layers(emb), dtype = h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = paddle.chunk(emb_out, 2, axis=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class QKVAttention(nn.Layer): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, axis=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = paddle.einsum( + "bct,bcs->bts", + (q * scale).rehsape([bs * self.n_heads, ch, length]), + (k * scale).rehsape([bs * self.n_heads, ch, length]), + ) # More stable with f16 than dividing afterwards + weight = paddle.cast(F.softmax(paddle.cast(weight, dtype="float32"), axis=-1), dtype=weight.dtype) + a = paddle.einsum( + "bts,bcs->bct", + weight, + v.reshape([bs * self.n_heads, ch, length]), + ) + return a.reshape([bs, -1, length]) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttentionLegacy(nn.Layer): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = ( + qkv.reshape([bs * self.n_heads, ch * 3, length]).split(ch, axis=1) + ) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = paddle.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = paddle.cast(F.softmax(paddle.cast(weight, dtype="float32"), axis=-1), dtype=weight.dtype) + a = paddle.einsum("bts,bcs->bct", weight, v) + return a.reshape([bs, -1, length]) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class AttentionBlock(nn.Layer): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(value=0.0) + ) + self.proj_out = conv_nd(1, channels, channels, 1, weight_attr=weight_attr) + + def forward(self, x): + b, c, *spatial = x.shape + x = x.reshape([b, c, -1]) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape([b, c, *spatial]) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += paddle.to_tensor([matmul_ops], dtype="float64") + + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context_list=None, mask_list=None): + # The first spatial transformer block does not have context + spatial_transformer_id = 0 + context_list = [None] + context_list + mask_list = [None] + mask_list + + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + if spatial_transformer_id >= len(context_list): + context, mask = None, None + else: + context, mask = ( + context_list[spatial_transformer_id], + mask_list[spatial_transformer_id], + ) + if mask is not None: + mask = paddle.cast(mask, dtype="bool") + x = layer(x, context, mask=mask) + spatial_transformer_id += 1 + else: + x = layer(x) + return x + + +class UNetModel(nn.Layer): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + extra_sa_layer=True, + num_classes=None, + extra_film_condition_dim=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=True, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.extra_film_condition_dim = extra_film_condition_dim + self.use_checkpoint = use_checkpoint + self._dtype = "float16" if use_fp16 else "float32" + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(model_channels, time_embed_dim), + nn.Silu(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + # assert not ( + # self.num_classes is not None and self.extra_film_condition_dim is not None + # ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.use_extra_film_by_concat = self.extra_film_condition_dim is not None + + if self.extra_film_condition_dim is not None: + self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) + print( + "+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " + % self.extra_film_condition_dim + ) + + if context_dim is not None and not use_spatial_transformer: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + + if context_dim is not None and not isinstance(context_dim, list): + context_dim = [context_dim] + elif context_dim is None: + context_dim = [None] # At least use one spatial transformer + + self.input_blocks = nn.LayerList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + middle_layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + if extra_sa_layer: + middle_layers.append( + SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=None + ) + ) + for context_dim_id in range(len(context_dim)): + middle_layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + middle_layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + self.middle_block = TimestepEmbedSequential(*middle_layers) + + self._feature_size += ch + + self.output_blocks = nn.LayerList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if extra_sa_layer: + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=None, + ) + ) + for context_dim_id in range(len(context_dim)): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim[context_dim_id], + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Constant(value=0.0) + ) + self.out = nn.Sequential( + normalization(ch), + nn.Silu(), + conv_nd(dims, model_channels, out_channels, 3, padding=1, weight_attr=weight_attr), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + self.shape_reported = False + + def forward( + self, + x, + timesteps=None, + y=None, + context_list=None, + context_attn_mask_list=None, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional + :return: an [N x C x ...] Tensor of outputs. + """ + if not self.shape_reported: + # print("The shape of UNet input is", x.size()) + self.shape_reported = True + + assert (y is not None) == ( + self.num_classes is not None or self.extra_film_condition_dim is not None + ), "must specify y if and only if the model is class-conditional or film embedding conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.use_extra_film_by_concat: + emb = paddle.concat([emb, self.film_emb(y)], axis=-1) + + h = paddle.cast(x, dtype="float32") + for module in self.input_blocks: + h = module(h, emb, context_list, context_attn_mask_list) + hs.append(h) + h = self.middle_block(h, emb, context_list, context_attn_mask_list) + for module in self.output_blocks: + concate_tensor = hs.pop() + h = paddle.concat([h, concate_tensor], axis=1) + h = module(h, emb, context_list, context_attn_mask_list) + h = paddle.cast(h, dtype=x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/paddlemix/models/audioldm2/utils.py b/paddlemix/models/audioldm2/utils.py new file mode 100644 index 000000000..7adb63e30 --- /dev/null +++ b/paddlemix/models/audioldm2/utils.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +from itertools import repeat +import collections.abc +from functools import partial + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + +to_2tuple = _ntuple(2) + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype) + random_tensor.floor_() # binarize + output = x.divide(keep_prob) * random_tensor + return output + + +class DropPath(nn.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class Mlp(nn.Layer): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + bias=True, + drop=0., + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2D, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias_attr=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias_attr=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + \ No newline at end of file