diff --git a/demo/app.py b/demo/app.py
new file mode 100644
index 0000000..8998805
--- /dev/null
+++ b/demo/app.py
@@ -0,0 +1,85 @@
+import os
+import argparse
+import gradio as gr
+from timeit import default_timer as timer
+import torch
+import numpy as np
+import pandas as pd
+from huggingface_hub import hf_hub_download
+from model.bart import BartCaptionModel
+from utils.audio_utils import load_audio, STR_CH_FIRST
+
+if not os.path.isfile("transfer.pth"):
+    torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth', 'transfer.pth')
+    torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/electronic.mp3', 'electronic.mp3')
+    torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/orchestra.wav', 'orchestra.wav')
+
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+example_list = ['electronic.mp3', 'orchestra.wav']
+model = BartCaptionModel(max_length=128)
+pretrained_object = torch.load('./transfer.pth', map_location='cpu')
+state_dict = pretrained_object['state_dict']
+model.load_state_dict(state_dict)
+if torch.cuda.is_available():
+    torch.cuda.set_device(device)
+model = model.cuda(device)
+model.eval()
+
+def get_audio(audio_path, duration=10, target_sr=16000):
+    n_samples = int(duration * target_sr)
+    audio, sr = load_audio(
+        path=audio_path,
+        ch_format=STR_CH_FIRST,
+        sample_rate=target_sr,
+        downmix_to_mono=True,
+    )
+    if len(audio.shape) == 2:
+        audio = audio.mean(axis=0)  # to mono
+    input_size = int(n_samples)
+    if audio.shape[-1] < input_size:  # pad sequence
+        pad = np.zeros(input_size)
+        pad[: audio.shape[-1]] = audio
+        audio = pad
+    ceil = int(audio.shape[-1] // n_samples)
+    audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
+    return audio
+
+def captioning(audio_path):
+    audio_tensor = get_audio(audio_path=audio_path)
+    if device is not None:
+        audio_tensor = audio_tensor.to(device)
+    with torch.no_grad():
+        output = model.generate(
+            samples=audio_tensor,
+            num_beams=5,
+        )
+    inference = ""
+    number_of_chunks = range(audio_tensor.shape[0])
+    for chunk, text in zip(number_of_chunks, output):
+        time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
+        inference += f"{time}\n{text} \n \n"
+    return inference
+
+title = "Interactive demo: Music Captioning 🤖🎵"
+description = """
+LP-MusicCaps: LLM-Based Pseudo Music Captioning
+SeungHeon Doh, Keunwoo Choi, Jongpil Lee, Juhan Nam, ISMIR 2023
+ArXiv | Github | LP-MusicCaps-Dataset
+To use it, simply upload your audio and click 'submit', or click one of the examples to load them. Read more at the links below.
+""" +article = "" + + +demo = gr.Interface(fn=captioning, + inputs=gr.Audio(type="filepath"), + outputs=[ + gr.Textbox(label="Caption generated by LP-MusicCaps Transfer Model"), + ], + examples=example_list, + title=title, + description=description, + article=article, + cache_examples=False + ) +demo.launch() \ No newline at end of file diff --git a/demo/model/bart.py b/demo/model/bart.py new file mode 100644 index 0000000..49b3986 --- /dev/null +++ b/demo/model/bart.py @@ -0,0 +1,151 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .modules import AudioEncoder +from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig + +class BartCaptionModel(nn.Module): + def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768): + super(BartCaptionModel, self).__init__() + # non-finetunning case + bart_config = BartConfig.from_pretrained(bart_type) + self.tokenizer = BartTokenizer.from_pretrained(bart_type) + self.bart = BartForConditionalGeneration(bart_config) + + self.n_sample = sr * duration + self.hop_length = int(0.01 * sr) # hard coding hop_size + self.n_frames = int(self.n_sample // self.hop_length) + self.num_of_stride_conv = num_of_conv - 1 + self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1 + self.audio_encoder = AudioEncoder( + n_mels = n_mels, # hard coding n_mel + n_ctx = self.n_ctx, + audio_dim = audio_dim, + text_dim = self.bart.config.hidden_size, + num_of_stride_conv = self.num_of_stride_conv + ) + + self.max_length = max_length + self.loss_fct = nn.CrossEntropyLoss(label_smoothing= label_smoothing, ignore_index=-100) + + @property + def device(self): + return list(self.parameters())[0].device + + def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right.ls + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + return shifted_input_ids + + def forward_encoder(self, audio): + audio_embs = self.audio_encoder(audio) + encoder_outputs = self.bart.model.encoder( + input_ids=None, + inputs_embeds=audio_embs, + return_dict=True + )["last_hidden_state"] + return encoder_outputs, audio_embs + + def forward_decoder(self, text, encoder_outputs): + text = self.tokenizer(text, + padding='longest', + truncation=True, + max_length=self.max_length, + return_tensors="pt") + input_ids = text["input_ids"].to(self.device) + attention_mask = text["attention_mask"].to(self.device) + + decoder_targets = input_ids.masked_fill( + input_ids == self.tokenizer.pad_token_id, -100 + ) + + decoder_input_ids = self.shift_tokens_right( + decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id + ) + + decoder_outputs = self.bart( + input_ids=None, + attention_mask=None, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=attention_mask, + inputs_embeds=None, + labels=None, + encoder_outputs=(encoder_outputs,), + return_dict=True + ) + lm_logits = decoder_outputs["logits"] + loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size), 
diff --git a/demo/model/bart.py b/demo/model/bart.py
new file mode 100644
index 0000000..49b3986
--- /dev/null
+++ b/demo/model/bart.py
@@ -0,0 +1,151 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from .modules import AudioEncoder
+from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
+
+class BartCaptionModel(nn.Module):
+    def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
+        super(BartCaptionModel, self).__init__()
+        # non-finetuning case
+        bart_config = BartConfig.from_pretrained(bart_type)
+        self.tokenizer = BartTokenizer.from_pretrained(bart_type)
+        self.bart = BartForConditionalGeneration(bart_config)
+
+        self.n_sample = sr * duration
+        self.hop_length = int(0.01 * sr)  # hard-coded hop size
+        self.n_frames = int(self.n_sample // self.hop_length)
+        self.num_of_stride_conv = num_of_conv - 1
+        self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
+        self.audio_encoder = AudioEncoder(
+            n_mels=n_mels,  # hard-coded n_mels
+            n_ctx=self.n_ctx,
+            audio_dim=audio_dim,
+            text_dim=self.bart.config.hidden_size,
+            num_of_stride_conv=self.num_of_stride_conv
+        )
+
+        self.max_length = max_length
+        self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)
+
+    @property
+    def device(self):
+        return list(self.parameters())[0].device
+
+    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+        """
+        Shift input ids one token to the right.
+        """
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+        shifted_input_ids[:, 0] = decoder_start_token_id
+
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
+        # replace possible -100 values in labels by `pad_token_id`
+        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+        return shifted_input_ids
+
+    def forward_encoder(self, audio):
+        audio_embs = self.audio_encoder(audio)
+        encoder_outputs = self.bart.model.encoder(
+            input_ids=None,
+            inputs_embeds=audio_embs,
+            return_dict=True
+        )["last_hidden_state"]
+        return encoder_outputs, audio_embs
+
+    def forward_decoder(self, text, encoder_outputs):
+        text = self.tokenizer(text,
+                              padding='longest',
+                              truncation=True,
+                              max_length=self.max_length,
+                              return_tensors="pt")
+        input_ids = text["input_ids"].to(self.device)
+        attention_mask = text["attention_mask"].to(self.device)
+
+        decoder_targets = input_ids.masked_fill(
+            input_ids == self.tokenizer.pad_token_id, -100
+        )
+
+        decoder_input_ids = self.shift_tokens_right(
+            decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
+        )
+
+        decoder_outputs = self.bart(
+            input_ids=None,
+            attention_mask=None,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=attention_mask,
+            inputs_embeds=None,
+            labels=None,
+            encoder_outputs=(encoder_outputs,),
+            return_dict=True
+        )
+        lm_logits = decoder_outputs["logits"]
+        loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size),
+                             decoder_targets.view(-1))
+        return loss
+
+    def forward(self, audio, text):
+        encoder_outputs, _ = self.forward_encoder(audio)
+        loss = self.forward_decoder(text, encoder_outputs)
+        return loss
+
+    def generate(self,
+                 samples,
+                 use_nucleus_sampling=False,
+                 num_beams=5,
+                 max_length=128,
+                 min_length=2,
+                 top_p=0.9,
+                 repetition_penalty=1.0,
+                 ):
+
+        # self.bart.force_bos_token_to_be_generated = True
+        audio_embs = self.audio_encoder(samples)
+        encoder_outputs = self.bart.model.encoder(
+            input_ids=None,
+            attention_mask=None,
+            head_mask=None,
+            inputs_embeds=audio_embs,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=True)
+
+        input_ids = torch.zeros((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
+        input_ids[:, 0] = self.bart.config.decoder_start_token_id
+        decoder_attention_mask = torch.ones((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
+        if use_nucleus_sampling:
+            outputs = self.bart.generate(
+                input_ids=None,
+                attention_mask=None,
+                decoder_input_ids=input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+                encoder_outputs=encoder_outputs,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=True,
+                top_p=top_p,
+                num_return_sequences=1,
+                repetition_penalty=1.1)
+        else:
+            outputs = self.bart.generate(input_ids=None,
+                                         attention_mask=None,
+                                         decoder_input_ids=input_ids,
+                                         decoder_attention_mask=decoder_attention_mask,
+                                         encoder_outputs=encoder_outputs,
+                                         head_mask=None,
+                                         decoder_head_mask=None,
+                                         inputs_embeds=None,
+                                         decoder_inputs_embeds=None,
+                                         use_cache=None,
+                                         output_attentions=None,
+                                         output_hidden_states=None,
+                                         max_length=max_length,
+                                         min_length=min_length,
+                                         num_beams=num_beams,
+                                         repetition_penalty=repetition_penalty)
+
+        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return captions
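Side note, outside the diff: BartCaptionModel.generate() consumes raw 16 kHz waveforms shaped (batch, 160000); the mel frontend lives inside AudioEncoder, not in the caller. A minimal shape sketch, assuming it is run from the demo/ directory with network access for the facebook/bart-base config and tokenizer; the weights stay randomly initialized here unless the transfer.pth checkpoint from demo/app.py is loaded:

    import torch
    from model.bart import BartCaptionModel

    model = BartCaptionModel(max_length=128).eval()
    wav = torch.randn(2, 16000 * 10)                # two 10-second chunks at 16 kHz
    with torch.no_grad():
        captions = model.generate(samples=wav, num_beams=5)
    print(len(captions))                            # 2 strings, one caption per chunk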
diff --git a/demo/model/modules.py b/demo/model/modules.py
new file mode 100644
index 0000000..788baa6
--- /dev/null
+++ b/demo/model/modules.py
@@ -0,0 +1,95 @@
+### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
+
+import os
+import torch
+import torchaudio
+import numpy as np
+import torch.nn.functional as F
+from torch import Tensor, nn
+from typing import Dict, Iterable, Optional
+
+# hard-coded audio hyperparameters
+SAMPLE_RATE = 16000
+N_FFT = 1024
+N_MELS = 128
+HOP_LENGTH = int(0.01 * SAMPLE_RATE)
+DURATION = 10
+N_SAMPLES = int(DURATION * SAMPLE_RATE)
+N_FRAMES = N_SAMPLES // HOP_LENGTH + 1
+
+def sinusoids(length, channels, max_timescale=10000):
+    """Returns sinusoids for positional embedding"""
+    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+class MelEncoder(nn.Module):
+    """
+    time-frequency representation
+    """
+    def __init__(self,
+                 sample_rate=16000,
+                 f_min=0,
+                 f_max=8000,
+                 n_fft=1024,
+                 win_length=1024,
+                 hop_length=int(0.01 * 16000),
+                 n_mels=128,
+                 power=None,
+                 pad=0,
+                 normalized=False,
+                 center=True,
+                 pad_mode="reflect"
+                 ):
+        super(MelEncoder, self).__init__()
+        self.window = torch.hann_window(win_length)
+        self.spec_fn = torchaudio.transforms.Spectrogram(
+            n_fft=n_fft,
+            win_length=win_length,
+            hop_length=hop_length,
+            power=power
+        )
+        self.mel_scale = torchaudio.transforms.MelScale(
+            n_mels,
+            sample_rate,
+            f_min,
+            f_max,
+            n_fft // 2 + 1)
+
+        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
+
+    def forward(self, wav):
+        spec = self.spec_fn(wav)
+        power_spec = spec.real.abs().pow(2)
+        mel_spec = self.mel_scale(power_spec)
+        mel_spec = self.amplitude_to_db(mel_spec)  # Log10(max(reference value and amin))
+        return mel_spec
+
+class AudioEncoder(nn.Module):
+    def __init__(
+        self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
+    ):
+        super().__init__()
+        self.mel_encoder = MelEncoder(n_mels=n_mels)
+        self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
+        self.conv_stack = nn.ModuleList([])
+        for _ in range(num_of_stride_conv):
+            self.conv_stack.append(
+                nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
+            )
+        # self.proj = nn.Linear(audio_dim, text_dim, bias=False)
+        self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))
+
+    def forward(self, x: Tensor):
+        """
+        x : torch.Tensor, shape = (batch_size, waveform)
+            single-channel waveform
+        """
+        x = self.mel_encoder(x)  # (batch_size, n_mels, n_ctx)
+        x = F.gelu(self.conv1(x))
+        for conv in self.conv_stack:
+            x = F.gelu(conv(x))
+        x = x.permute(0, 2, 1)
+        x = (x + self.positional_embedding).to(x.dtype)
+        return x
\ No newline at end of file
diff --git a/demo/utils/audio_utils.py b/demo/utils/audio_utils.py
new file mode 100644
index 0000000..d033238
--- /dev/null
+++ b/demo/utils/audio_utils.py
@@ -0,0 +1,247 @@
+STR_CLIP_ID = 'clip_id'
+STR_AUDIO_SIGNAL = 'audio_signal'
+STR_TARGET_VECTOR = 'target_vector'
+
+
+STR_CH_FIRST = 'channels_first'
+STR_CH_LAST = 'channels_last'
+
+import io
+import os
+import tqdm
+import logging
+import subprocess
+from typing import Tuple
+from pathlib import Path
+
+# import librosa
+import numpy as np
+import soundfile as sf
+
+import itertools
+from numpy.fft import irfft
+
+def _resample_load_ffmpeg(path: str, sample_rate: int, downmix_to_mono: bool) -> Tuple[np.ndarray, int]:
+    """
+    Decoding, downmixing, and downsampling by ffmpeg.
+    Returns a channel-first audio signal.
+
+    Args:
+        path:
+        sample_rate:
+        downmix_to_mono:
+
+    Returns:
+        (audio signal, sample rate)
+    """
+
+    def _decode_resample_by_ffmpeg(filename, sr):
+        """decode, downmix, and resample audio file"""
+        channel_cmd = '-ac 1 ' if downmix_to_mono else ''  # downmixing option
+        resampling_cmd = f'-ar {str(sr)}' if sr else ''  # downsampling option
+        cmd = f"ffmpeg -i \"{filename}\" {channel_cmd} {resampling_cmd} -f wav -"
+        p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        out, err = p.communicate()
+        return out
+
+    src, sr = sf.read(io.BytesIO(_decode_resample_by_ffmpeg(path, sr=sample_rate)))
+    return src.T, sr
+
+
+def _resample_load_librosa(path: str, sample_rate: int, downmix_to_mono: bool, **kwargs) -> Tuple[np.ndarray, int]:
+    """
+    Decoding, downmixing, and downsampling by librosa.
+    Returns a channel-first audio signal.
+    """
+    src, sr = librosa.load(path, sr=sample_rate, mono=downmix_to_mono, **kwargs)
+    return src, sr
+
+
+def load_audio(
+        path: str or Path,
+        ch_format: str,
+        sample_rate: int = None,
+        downmix_to_mono: bool = False,
+        resample_by: str = 'ffmpeg',
+        **kwargs,
+) -> Tuple[np.ndarray, int]:
+    """A wrapper of librosa.load that:
+    - forces the returned audio to be 2-dim,
+    - defaults to sr=None, and
+    - defaults to downmix_to_mono=False.
+
+    The audio decoding is done by the `audioread` or `soundfile` package and ultimately, often by ffmpeg.
+    The resampling is done by `librosa`'s child package `resampy`.
+
+    Args:
+        path: audio file path
+        ch_format: one of 'channels_first' or 'channels_last'
+        sample_rate: target sampling rate. if None, use the rate of the audio file
+        downmix_to_mono:
+        resample_by (str): 'librosa' or 'ffmpeg'. it decides backend for audio decoding and resampling.
+        **kwargs: keyword args for librosa.load - offset, duration, dtype, res_type.
+
+    Returns:
+        (audio, sr) tuple
+    """
+    if ch_format not in (STR_CH_FIRST, STR_CH_LAST):
+        raise ValueError(f'ch_format is wrong here -> {ch_format}')
+
+    if os.stat(path).st_size > 8000:
+        if resample_by == 'librosa':
+            src, sr = _resample_load_librosa(path, sample_rate, downmix_to_mono, **kwargs)
+        elif resample_by == 'ffmpeg':
+            src, sr = _resample_load_ffmpeg(path, sample_rate, downmix_to_mono)
+        else:
+            raise NotImplementedError(f'resample_by: "{resample_by}" is not supported yet')
+    else:
+        raise ValueError('Given audio is too short!')
+    return src, sr
+
+    # if src.ndim == 1:
+    #     src = np.expand_dims(src, axis=0)
+    # # now always 2d and channels_first
+
+    # if ch_format == STR_CH_FIRST:
+    #     return src, sr
+    # else:
+    #     return src.T, sr
+
+def ms(x):
+    """Mean value of signal `x` squared.
+    :param x: Dynamic quantity.
+    :returns: Mean squared of `x`.
+    """
+    return (np.abs(x)**2.0).mean()
+
+def normalize(y, x=None):
+    """Normalize power in y to a (standard normal) white noise signal.
+    Optionally normalize to power in signal `x`.
+    # The mean power of a Gaussian with :math:`\\mu=0` and :math:`\\sigma=1` is 1.
+    """
+    if x is not None:
+        x = ms(x)
+    else:
+        x = 1.0
+    return y * np.sqrt(x / ms(y))
+
+def noise(N, color='white', state=None):
+    """Noise generator.
+    :param N: Amount of samples.
+    :param color: Color of noise.
+    :param state: State of PRNG.
+    :type state: :class:`np.random.RandomState`
+    """
+    try:
+        return _noise_generators[color](N, state)
+    except KeyError:
+        raise ValueError("Incorrect color.")
+
+def white(N, state=None):
+    """
+    White noise.
+    :param N: Amount of samples.
+    :param state: State of PRNG.
+    :type state: :class:`np.random.RandomState`
+    White noise has a constant power density. Its narrowband spectrum is therefore flat.
+    The power in white noise will increase by a factor of two for each octave band,
+    and therefore increases with 3 dB per octave.
+    """
+    state = np.random.RandomState() if state is None else state
+    return state.randn(N)
+
+def pink(N, state=None):
+    """
+    Pink noise.
+    :param N: Amount of samples.
+    :param state: State of PRNG.
+    :type state: :class:`np.random.RandomState`
+    Pink noise has equal power in bands that are proportionally wide.
+    Power density decreases with 3 dB per octave.
+    """
+    state = np.random.RandomState() if state is None else state
+    uneven = N % 2
+    X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+    S = np.sqrt(np.arange(len(X)) + 1.)  # +1 to avoid divide by zero
+    y = (irfft(X / S)).real
+    if uneven:
+        y = y[:-1]
+    return normalize(y)
+
+def blue(N, state=None):
+    """
+    Blue noise.
+    :param N: Amount of samples.
+    :param state: State of PRNG.
+    :type state: :class:`np.random.RandomState`
+    Power increases with 6 dB per octave.
+    Power density increases with 3 dB per octave.
+    """
+    state = np.random.RandomState() if state is None else state
+    uneven = N % 2
+    X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+    S = np.sqrt(np.arange(len(X)))  # Filter
+    y = (irfft(X * S)).real
+    if uneven:
+        y = y[:-1]
+    return normalize(y)
+
+def brown(N, state=None):
+    """
+    Brown noise.
+    :param N: Amount of samples.
+    :param state: State of PRNG.
+    :type state: :class:`np.random.RandomState`
+    Power decreases with -3 dB per octave.
+    Power density decreases with 6 dB per octave.
+    """
+    state = np.random.RandomState() if state is None else state
+    uneven = N % 2
+    X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+    S = (np.arange(len(X)) + 1)  # Filter
+    y = (irfft(X / S)).real
+    if uneven:
+        y = y[:-1]
+    return normalize(y)
+
+def violet(N, state=None):
+    """
+    Violet noise.
+    :param N: Amount of samples.
+    :param state: State of PRNG.
+    :type state: :class:`np.random.RandomState`
+    Power increases with +9 dB per octave.
+    Power density increases with +6 dB per octave.
+    """
+    state = np.random.RandomState() if state is None else state
+    uneven = N % 2
+    X = state.randn(N // 2 + 1 + uneven) + 1j * state.randn(N // 2 + 1 + uneven)
+    S = (np.arange(len(X)))  # Filter
+    y = (irfft(X * S)).real
+    if uneven:
+        y = y[:-1]
+    return normalize(y)
+
+_noise_generators = {
+    'white': white,
+    'pink': pink,
+    'blue': blue,
+    'brown': brown,
+    'violet': violet,
+}
+
+def noise_generator(N=44100, color='white', state=None):
+    """Noise generator.
+    :param N: Amount of unique samples to generate.
+    :param color: Color of noise.
+    Generate `N` amount of unique samples and cycle over these samples.
+    """
+    # yield from itertools.cycle(noise(N, color))  # Python 3.3
+    for sample in itertools.cycle(noise(N, color, state)):
+        yield sample
+
+def heaviside(N):
+    """Heaviside.
+    Returns the value 0 for `x < 0`, 1 for `x > 0`, and 1/2 for `x = 0`.
+    """
+    return 0.5 * (np.sign(N) + 1)
\ No newline at end of file
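Side note, outside the diff: load_audio() defaults to the ffmpeg backend, so the ffmpeg binary must be on PATH (the 'librosa' backend would additionally need the commented-out librosa import), and it rejects files smaller than roughly 8 kB. A minimal sketch of the call get_audio() in demo/app.py makes, run from the demo/ directory with 'song.mp3' standing in for any local audio file:

    from utils.audio_utils import load_audio, STR_CH_FIRST

    # 'song.mp3' is a placeholder path; decoding shells out to ffmpeg
    audio, sr = load_audio(path="song.mp3", ch_format=STR_CH_FIRST,
                           sample_rate=16000, downmix_to_mono=True, resample_by="ffmpeg")
    print(audio.shape, sr)                          # mono signal resampled to 16 kHz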
diff --git a/lpmc/llm_captioning/eval.py b/lpmc/llm_captioning/eval.py
index b9eceda..abb14af 100644
--- a/lpmc/llm_captioning/eval.py
+++ b/lpmc/llm_captioning/eval.py
@@ -21,8 +21,6 @@ def baseline_generation(dataset, prediction_col):
             predictions.append(tag_concat)
         elif prediction_col == "baseline_template":
             predictions.append(_apply_template(tag_concat))
-        elif prediction_col == "baseline_k2c":
-            predictions.append(nlp(tag_list))
     return predictions
 
 def inference_parsing(dataset, prediction_col):
diff --git a/setup.py b/setup.py
index 605587e..3a3e8bf 100644
--- a/setup.py
+++ b/setup.py
@@ -23,7 +23,6 @@
         'evaluate==0.4.0',
         'bert_score==0.3.13',
         'rouge_score==0.1.2',
-        'opencv-python',
-        'keytotext'
+        'gradio==3.36.1'
     ]
 )
\ No newline at end of file
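Closing note, outside the diff: demo/app.py calls demo.launch() at module level, so the demo is started by executing the file directly (python app.py from the demo/ directory) after installing the gradio==3.36.1 dependency pinned above. If the module ever needs to be importable without starting the server (for example, in tests), one possible follow-up would be to guard the launch call:

    if __name__ == "__main__":
        demo.launch()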