forked from seungheondoh/lp-music-caps
Commit 57c5e3b (1 parent: 5d4e65d)
Showing 6 changed files with 579 additions and 4 deletions.
@@ -0,0 +1,85 @@
import os
import argparse
import gradio as gr
from timeit import default_timer as timer
import torch
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from model.bart import BartCaptionModel
from utils.audio_utils import load_audio, STR_CH_FIRST

if not os.path.isfile("transfer.pth"):
    torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth', 'transfer.pth')
    torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/electronic.mp3', 'electronic.mp3')
    torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/orchestra.wav', 'orchestra.wav')

device = "cuda:0" if torch.cuda.is_available() else "cpu"

example_list = ['electronic.mp3', 'orchestra.wav']
model = BartCaptionModel(max_length=128)
pretrained_object = torch.load('./transfer.pth', map_location='cpu')
state_dict = pretrained_object['state_dict']
model.load_state_dict(state_dict)
if torch.cuda.is_available():
    torch.cuda.set_device(device)
    model = model.cuda(device)
model.eval()

def get_audio(audio_path, duration=10, target_sr=16000):
    n_samples = int(duration * target_sr)
    audio, sr = load_audio(
        path=audio_path,
        ch_format=STR_CH_FIRST,
        sample_rate=target_sr,
        downmix_to_mono=True,
    )
    if len(audio.shape) == 2:
        audio = audio.mean(0, False)  # to mono
    input_size = int(n_samples)
    if audio.shape[-1] < input_size:  # pad sequence
        pad = np.zeros(input_size)
        pad[: audio.shape[-1]] = audio
        audio = pad
    ceil = int(audio.shape[-1] // n_samples)
    audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
    return audio

def captioning(audio_path):
    audio_tensor = get_audio(audio_path=audio_path)
    if device is not None:
        audio_tensor = audio_tensor.to(device)
    with torch.no_grad():
        output = model.generate(
            samples=audio_tensor,
            num_beams=5,
        )
    inference = ""
    number_of_chunks = range(audio_tensor.shape[0])
    for chunk, text in zip(number_of_chunks, output):
        time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
        inference += f"{time}\n{text} \n \n"
    return inference

title = "Interactive demo: Music Captioning 🤖🎵"
description = """
<p style='text-align: center'> LP-MusicCaps: LLM-Based Pseudo Music Captioning</p>
<p style='text-align: center'> SeungHeon Doh, Keunwoo Choi, Jongpil Lee, Juhan Nam, ISMIR 2023</p>
<p style='text-align: center'> <a href='#' target='_blank'>ArXiv</a> | <a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>Github</a> | <a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>LP-MusicCaps-Dataset</a> </p>
<p style='text-align: center'> To use it, simply upload your audio and click 'submit', or click one of the examples to load them. Read more at the links below. </p>
"""
article = "<p style='text-align: center'><a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>LP-MusicCaps Github</a> | <a href='#' target='_blank'>LP-MusicCaps Paper</a></p>"

demo = gr.Interface(fn=captioning,
                    inputs=gr.Audio(type="filepath"),
                    outputs=[
                        gr.Textbox(label="Caption generated by LP-MusicCaps Transfer Model"),
                    ],
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article,
                    cache_examples=False
                    )
demo.launch()
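For a quick smoke test without launching the web UI, the captioning function above can also be called directly once the checkpoint and example clips have been downloaded. A minimal sketch, assuming the demo script above is saved as app.py at the repository root (the file name is an assumption; it is not shown in this diff):

# Hypothetical usage; assumes the demo above lives in app.py. Importing it
# downloads transfer.pth and the example clips on first run, then loads the model.
from app import captioning

print(captioning("electronic.mp3"))  # one caption per 10-second chunk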
@@ -0,0 +1,151 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .modules import AudioEncoder
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig


class BartCaptionModel(nn.Module):
    def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
        super(BartCaptionModel, self).__init__()
        # non-finetuning case
        bart_config = BartConfig.from_pretrained(bart_type)
        self.tokenizer = BartTokenizer.from_pretrained(bart_type)
        self.bart = BartForConditionalGeneration(bart_config)

        self.n_sample = sr * duration
        self.hop_length = int(0.01 * sr)  # hard-coded hop size
        self.n_frames = int(self.n_sample // self.hop_length)
        self.num_of_stride_conv = num_of_conv - 1
        self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
        self.audio_encoder = AudioEncoder(
            n_mels=n_mels,  # hard-coded n_mels
            n_ctx=self.n_ctx,
            audio_dim=audio_dim,
            text_dim=self.bart.config.hidden_size,
            num_of_stride_conv=self.num_of_stride_conv
        )

        self.max_length = max_length
        self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)

    @property
    def device(self):
        return list(self.parameters())[0].device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.
        """
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward_encoder(self, audio):
        audio_embs = self.audio_encoder(audio)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embs,
            return_dict=True
        )["last_hidden_state"]
        return encoder_outputs, audio_embs

    def forward_decoder(self, text, encoder_outputs):
        text = self.tokenizer(text,
                              padding='longest',
                              truncation=True,
                              max_length=self.max_length,
                              return_tensors="pt")
        input_ids = text["input_ids"].to(self.device)
        attention_mask = text["attention_mask"].to(self.device)

        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )

        decoder_input_ids = self.shift_tokens_right(
            decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
        )

        decoder_outputs = self.bart(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_outputs,),
            return_dict=True
        )
        lm_logits = decoder_outputs["logits"]
        loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size), decoder_targets.view(-1))
        return loss

    def forward(self, audio, text):
        encoder_outputs, _ = self.forward_encoder(audio)
        loss = self.forward_decoder(text, encoder_outputs)
        return loss

    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=5,
                 max_length=128,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):

        # self.bart.force_bos_token_to_be_generated = True
        audio_embs = self.audio_encoder(samples)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            attention_mask=None,
            head_mask=None,
            inputs_embeds=audio_embs,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True)

        input_ids = torch.zeros((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
        input_ids[:, 0] = self.bart.config.decoder_start_token_id
        decoder_attention_mask = torch.ones((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
        if use_nucleus_sampling:
            outputs = self.bart.generate(
                input_ids=None,
                attention_mask=None,
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=encoder_outputs,
                max_length=max_length,
                min_length=min_length,
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=1.1)
        else:
            outputs = self.bart.generate(input_ids=None,
                                         attention_mask=None,
                                         decoder_input_ids=input_ids,
                                         decoder_attention_mask=decoder_attention_mask,
                                         encoder_outputs=encoder_outputs,
                                         head_mask=None,
                                         decoder_head_mask=None,
                                         inputs_embeds=None,
                                         decoder_inputs_embeds=None,
                                         use_cache=None,
                                         output_attentions=None,
                                         output_hidden_states=None,
                                         max_length=max_length,
                                         min_length=min_length,
                                         num_beams=num_beams,
                                         repetition_penalty=repetition_penalty)

        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions
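To orient around the API above: forward() takes a batch of 10-second waveform chunks plus reference captions and returns the label-smoothed cross-entropy loss, while generate() decodes captions from audio alone. A minimal sketch with the default hyperparameters (the module path is inferred from the demo's import; the random waveforms and caption strings are placeholders, and instantiating the model fetches the facebook/bart-base config and tokenizer):

import torch
from model.bart import BartCaptionModel  # import path assumed from the demo script

model = BartCaptionModel(max_length=128)
model.eval()

# Two placeholder 10-second chunks at 16 kHz, shape (batch, samples).
chunks = torch.randn(2, 16000 * 10)
with torch.no_grad():
    captions = model.generate(samples=chunks, num_beams=5)  # list of two caption strings

# Training-style call: forward() returns the label-smoothed cross-entropy loss.
loss = model(chunks, ["a calm piano piece", "an upbeat electronic track"])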
@@ -0,0 +1,95 @@
### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py

import os
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Dict, Iterable, Optional

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 1024
N_MELS = 128
HOP_LENGTH = int(0.01 * SAMPLE_RATE)
DURATION = 10
N_SAMPLES = int(DURATION * SAMPLE_RATE)
N_FRAMES = N_SAMPLES // HOP_LENGTH + 1


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


class MelEncoder(nn.Module):
    """
    time-frequency representation
    """
    def __init__(self,
                 sample_rate=16000,
                 f_min=0,
                 f_max=8000,
                 n_fft=1024,
                 win_length=1024,
                 hop_length=int(0.01 * 16000),
                 n_mels=128,
                 power=None,
                 pad=0,
                 normalized=False,
                 center=True,
                 pad_mode="reflect"
                 ):
        super(MelEncoder, self).__init__()
        self.window = torch.hann_window(win_length)
        self.spec_fn = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            power=power
        )
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels,
            sample_rate,
            f_min,
            f_max,
            n_fft // 2 + 1)

        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, wav):
        spec = self.spec_fn(wav)
        power_spec = spec.real.abs().pow(2)
        mel_spec = self.mel_scale(power_spec)
        mel_spec = self.amplitude_to_db(mel_spec)  # log10(max(reference value, amin))
        return mel_spec


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
    ):
        super().__init__()
        self.mel_encoder = MelEncoder(n_mels=n_mels)
        self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
        self.conv_stack = nn.ModuleList([])
        for _ in range(num_of_stride_conv):
            self.conv_stack.append(
                nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
            )
        # self.proj = nn.Linear(audio_dim, text_dim, bias=False)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, waveform)
            single-channel waveform
        """
        x = self.mel_encoder(x)  # (batch_size, n_mels, n_frames)
        x = F.gelu(self.conv1(x))
        for conv in self.conv_stack:
            x = F.gelu(conv(x))
        x = x.permute(0, 2, 1)
        x = (x + self.positional_embedding).to(x.dtype)
        return x
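A minimal shape-check sketch for AudioEncoder, assuming this module is importable as model.modules (the path is inferred from the `from .modules import AudioEncoder` line in the BART file). With the hard-coded 16 kHz, 10-second settings, the mel spectrogram has roughly 1000 frames, and the five stride-2 convolutions reduce it to 32 time steps, matching n_ctx in BartCaptionModel:

import torch
from model.modules import AudioEncoder, N_SAMPLES  # import path assumed, see note above

encoder = AudioEncoder(n_mels=128, n_ctx=32, audio_dim=768, text_dim=768, num_of_stride_conv=5)
wav = torch.randn(2, N_SAMPLES)  # two placeholder 10-second mono clips at 16 kHz
with torch.no_grad():
    emb = encoder(wav)
print(emb.shape)  # torch.Size([2, 32, 768]): (batch, n_ctx, text_dim), ready for the BART encoder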