
Commit

add demo
seungheondoh committed Jul 12, 2023
1 parent 5d4e65d commit 57c5e3b
Showing 6 changed files with 579 additions and 4 deletions.
85 changes: 85 additions & 0 deletions demo/app.py
@@ -0,0 +1,85 @@
import os
import argparse
import gradio as gr
from timeit import default_timer as timer
import torch
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from model.bart import BartCaptionModel
from utils.audio_utils import load_audio, STR_CH_FIRST

if not os.path.isfile("transfer.pth"):
torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/transfer.pth', 'transfer.pth')
torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/electronic.mp3', 'electronic.mp3')
torch.hub.download_url_to_file('https://huggingface.co/seungheondoh/lp-music-caps/resolve/main/orchestra.wav', 'orchestra.wav')

device = "cuda:0" if torch.cuda.is_available() else "cpu"

example_list = ['electronic.mp3', 'orchestra.wav']
model = BartCaptionModel(max_length = 128)
pretrained_object = torch.load('./transfer.pth', map_location='cpu')
state_dict = pretrained_object['state_dict']
model.load_state_dict(state_dict)
if torch.cuda.is_available():
torch.cuda.set_device(device)
model = model.cuda(device)
model.eval()

def get_audio(audio_path, duration=10, target_sr=16000):
n_samples = int(duration * target_sr)
audio, sr = load_audio(
path= audio_path,
ch_format= STR_CH_FIRST,
sample_rate= target_sr,
downmix_to_mono= True,
)
if len(audio.shape) == 2:
audio = audio.mean(0, False) # to mono
input_size = int(n_samples)
if audio.shape[-1] < input_size: # pad sequence
pad = np.zeros(input_size)
pad[: audio.shape[-1]] = audio
audio = pad
ceil = int(audio.shape[-1] // n_samples)
audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
return audio

def captioning(audio_path):
audio_tensor = get_audio(audio_path = audio_path)
if device is not None:
audio_tensor = audio_tensor.to(device)
with torch.no_grad():
output = model.generate(
samples=audio_tensor,
num_beams=5,
)
inference = ""
number_of_chunks = range(audio_tensor.shape[0])
for chunk, text in zip(number_of_chunks, output):
time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
inference += f"{time}\n{text} \n \n"
return inference

title = "Interactive demo: Music Captioning 🤖🎵"
description = """
<p style='text-align: center'> LP-MusicCaps: LLM-Based Pseudo Music Captioning</p>
<p style='text-align: center'> SeungHeon Doh, Keunwoo Choi, Jongpil Lee, Juhan Nam, ISMIR 2023</p>
<p style='text-align: center'> <a href='#' target='_blank'>ArXiv</a> | <a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>Github</a> | <a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>LP-MusicCaps-Dataset</a> </p>
<p style='text-align: center'> To use it, simply upload your audio and click 'submit', or click one of the examples to load them. Read more at the links below. </p>
"""
article = "<p style='text-align: center'><a href='https://github.com/seungheondoh/lp-music-caps' target='_blank'>LP-MusicCaps Github</a> | <a href='#' target='_blank'>LP-MusicCaps Paper</a></p>"


demo = gr.Interface(fn=captioning,
inputs=gr.Audio(type="filepath"),
outputs=[
gr.Textbox(label="Caption generated by LP-MusicCaps Transfer Model"),
],
examples=example_list,
title=title,
description=description,
article=article,
cache_examples=False
)
demo.launch()
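
For a quick sanity check without the Gradio UI, here is a minimal sketch, assuming the script above has already been run (or imported) so that the model and the downloaded example files are available:

# hypothetical programmatic use of the captioning() helper defined above
caption_text = captioning("electronic.mp3")  # returns one caption per 10-second chunk
print(caption_text)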
151 changes: 151 additions & 0 deletions demo/model/bart.py
@@ -0,0 +1,151 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .modules import AudioEncoder
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig

class BartCaptionModel(nn.Module):
def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
super(BartCaptionModel, self).__init__()
        # non-finetuning case: BART is built from its config (randomly initialized), not from a pretrained checkpoint
bart_config = BartConfig.from_pretrained(bart_type)
self.tokenizer = BartTokenizer.from_pretrained(bart_type)
self.bart = BartForConditionalGeneration(bart_config)

self.n_sample = sr * duration
        self.hop_length = int(0.01 * sr) # hard-coded hop size (10 ms)
self.n_frames = int(self.n_sample // self.hop_length)
self.num_of_stride_conv = num_of_conv - 1
self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
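        # with the defaults (sr=16000, duration=10, num_of_conv=6): n_sample=160000, hop_length=160,
        # n_frames=1000, num_of_stride_conv=5, so n_ctx = 1000 // 2**5 + 1 = 32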
self.audio_encoder = AudioEncoder(
n_mels = n_mels, # hard coding n_mel
n_ctx = self.n_ctx,
audio_dim = audio_dim,
text_dim = self.bart.config.hidden_size,
num_of_stride_conv = self.num_of_stride_conv
)

self.max_length = max_length
self.loss_fct = nn.CrossEntropyLoss(label_smoothing= label_smoothing, ignore_index=-100)

@property
def device(self):
return list(self.parameters())[0].device

def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
        Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id

if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids

def forward_encoder(self, audio):
audio_embs = self.audio_encoder(audio)
encoder_outputs = self.bart.model.encoder(
input_ids=None,
inputs_embeds=audio_embs,
return_dict=True
)["last_hidden_state"]
return encoder_outputs, audio_embs

def forward_decoder(self, text, encoder_outputs):
text = self.tokenizer(text,
padding='longest',
truncation=True,
max_length=self.max_length,
return_tensors="pt")
input_ids = text["input_ids"].to(self.device)
attention_mask = text["attention_mask"].to(self.device)

decoder_targets = input_ids.masked_fill(
input_ids == self.tokenizer.pad_token_id, -100
)

decoder_input_ids = self.shift_tokens_right(
decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
)

decoder_outputs = self.bart(
input_ids=None,
attention_mask=None,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=attention_mask,
inputs_embeds=None,
labels=None,
encoder_outputs=(encoder_outputs,),
return_dict=True
)
lm_logits = decoder_outputs["logits"]
loss = self.loss_fct(lm_logits.view(-1, self.tokenizer.vocab_size), decoder_targets.view(-1))
return loss

def forward(self, audio, text):
encoder_outputs, _ = self.forward_encoder(audio)
loss = self.forward_decoder(text, encoder_outputs)
return loss

def generate(self,
samples,
use_nucleus_sampling=False,
num_beams=5,
max_length=128,
min_length=2,
top_p=0.9,
repetition_penalty=1.0,
):

# self.bart.force_bos_token_to_be_generated = True
audio_embs = self.audio_encoder(samples)
encoder_outputs = self.bart.model.encoder(
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=audio_embs,
output_attentions=None,
output_hidden_states=None,
return_dict=True)

input_ids = torch.zeros((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
input_ids[:, 0] = self.bart.config.decoder_start_token_id
decoder_attention_mask = torch.ones((encoder_outputs['last_hidden_state'].size(0), 1)).long().to(self.device)
if use_nucleus_sampling:
outputs = self.bart.generate(
input_ids=None,
attention_mask=None,
decoder_input_ids=input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
max_length=max_length,
min_length=min_length,
do_sample=True,
top_p=top_p,
num_return_sequences=1,
repetition_penalty=1.1)
else:
outputs = self.bart.generate(input_ids=None,
attention_mask=None,
decoder_input_ids=input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
head_mask=None,
decoder_head_mask=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
repetition_penalty=repetition_penalty)

captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
return captions
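
For reference, a minimal standalone sketch of the class above driven with random audio. It assumes the demo package layout (demo/model/bart.py) for the import path; only the facebook/bart-base config and tokenizer are fetched, so the randomly initialized decoder will emit nonsense text:

import torch
from model.bart import BartCaptionModel  # import path assumed, as in demo/app.py

model = BartCaptionModel(max_length=128)
model.eval()
dummy_audio = torch.randn(1, 16000 * 10)  # one 10-second clip at 16 kHz, the shape get_audio() in app.py produces
with torch.no_grad():
    captions = model.generate(samples=dummy_audio, num_beams=5)
print(captions)  # list with one generated string per input chunk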
95 changes: 95 additions & 0 deletions demo/model/modules.py
@@ -0,0 +1,95 @@
### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py

import os
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Dict, Iterable, Optional

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 1024
N_MELS = 128
HOP_LENGTH = int(0.01 * SAMPLE_RATE)
DURATION = 10
N_SAMPLES = int(DURATION * SAMPLE_RATE)
N_FRAMES = N_SAMPLES // HOP_LENGTH + 1
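# with these values: HOP_LENGTH = 160 samples (10 ms), N_SAMPLES = 160000, N_FRAMES = 1001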

def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

class MelEncoder(nn.Module):
"""
    time-frequency representation
"""
def __init__(self,
sample_rate= 16000,
f_min=0,
f_max=8000,
n_fft=1024,
win_length=1024,
hop_length = int(0.01 * 16000),
n_mels = 128,
power = None,
pad= 0,
normalized= False,
center= True,
pad_mode= "reflect"
):
super(MelEncoder, self).__init__()
self.window = torch.hann_window(win_length)
self.spec_fn = torchaudio.transforms.Spectrogram(
n_fft = n_fft,
win_length = win_length,
hop_length = hop_length,
power = power
)
self.mel_scale = torchaudio.transforms.MelScale(
n_mels,
sample_rate,
f_min,
f_max,
n_fft // 2 + 1)

self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

def forward(self, wav):
spec = self.spec_fn(wav)
power_spec = spec.real.abs().pow(2)
mel_spec = self.mel_scale(power_spec)
        mel_spec = self.amplitude_to_db(mel_spec) # convert to dB: 10 * log10(max(power, amin))
return mel_spec

class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
):
super().__init__()
self.mel_encoder = MelEncoder(n_mels=n_mels)
self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
self.conv_stack = nn.ModuleList([])
for _ in range(num_of_stride_conv):
self.conv_stack.append(
nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
)
# self.proj = nn.Linear(audio_dim, text_dim, bias=False)
self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))

def forward(self, x: Tensor):
"""
        x : torch.Tensor, shape = (batch_size, n_samples)
        single-channel waveform
"""
        x = self.mel_encoder(x) # (batch_size, n_mels, n_frames)
x = F.gelu(self.conv1(x))
for conv in self.conv_stack:
x = F.gelu(conv(x))
x = x.permute(0, 2, 1)
x = (x + self.positional_embedding).to(x.dtype)
return x
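
As a quick shape check, a minimal sketch of running the encoder on its own; the argument values follow from BartCaptionModel's defaults (10 s of 16 kHz audio, five stride-2 convolutions, BART hidden size 768), and the import path is assumed from the demo package layout:

import torch
from model.modules import AudioEncoder  # import path assumed (demo/model/modules.py)

enc = AudioEncoder(n_mels=128, n_ctx=32, audio_dim=768, text_dim=768, num_of_stride_conv=5)
wav = torch.randn(2, 160000)   # (batch, samples): two 10-second clips at 16 kHz
out = enc(wav)
print(out.shape)               # torch.Size([2, 32, 768]) -> (batch, n_ctx, text_dim)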