add audio spectrogram transformer, and full audio clip #406
base: main
Conversation
…hich it is, as the premise is to change the waveform to a 2d "image" and patchify and attend)
@lucidrains awesome, we should probably put this audio specific stuff in a new file, was thinking of splitting the other sub-transformers at some point too ... audio_transformer.py ?
@rwightman sure, by modality, or by functionality, or both, either way is fine, just let me know
will still need to add the functions for generating from cfg as well as the full ... perhaps by modality is good
yeah, was thinking modality, leave the base transformer as the parent, and split off modality specific transformers, at least in this case audio since it's new. can split the others later, as other PRs are probably based on the current structure
You got it, will make the changes next week
Have a bunch of meetings with people around the valley this week, I'll get around to finishing this next week
Hi @lucidrains the current code looks great! Feel free to ping us (Ke and me) when you are finished!
Hi @lucidrains We have briefly scanned your code and it looks great to us. After you finish the code, just let us know. We will mainly go over the spec-augment (time masking, freq masking, screeching, etc.) and the hyperparameters on the spectrogram transformer. If you provide me the specific location in your code, that would be better. Thanks!
@lukewys @RetroCirce Hello Yusong and Ke! Thank you so much for offering your audio expertise; it is more helpful than you realize. The hyperparameters that I am unsure about are listed here to here. But also whatever you think are reasonable default values would be good too!
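(For reference, a rough sketch of the time and frequency masking side of the spec-augment being discussed above; the function name and mask widths below are made up for illustration and are not this PR's API or hyperparameters.)

```python
import torch

def spec_augment(spec, freq_mask_width = 16, time_mask_width = 64):
    # spec: (batch, freq_bins, time_steps); zero out one random band along each axis
    spec = spec.clone()
    batch, n_freq, n_time = spec.shape

    for b in range(batch):
        f = torch.randint(0, freq_mask_width + 1, (1,)).item()
        f0 = torch.randint(0, max(n_freq - f, 1), (1,)).item()
        spec[b, f0:f0 + f, :] = 0.          # frequency mask

        t = torch.randint(0, time_mask_width + 1, (1,)).item()
        t0 = torch.randint(0, max(n_time - t, 1), (1,)).item()
        spec[b, :, t0:t0 + t] = 0.          # time mask

    return spec

spectrogram = torch.randn(2, 32, 1024)
augmented = spec_augment(spectrogram)
print(augmented.shape)  # torch.Size([2, 32, 1024])
```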
```python
# audio clip

class AudioCLIP(nn.Module):
```
should decide whether to extend CLIP; similarly, decide whether to just extend CoCa to AudioCoCa and override the visual modality transformer. also, decided to keep a lot of the ...
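(For context, a minimal sketch of the design question in the review comment above, i.e. keeping a text tower and swapping the visual tower for a spectrogram transformer. The class and argument names are placeholders for illustration, not what this PR actually defines.)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AudioCLIPSketch(nn.Module):
    # hypothetical contrastive wrapper: any spectrogram transformer can stand in
    # for the "audio tower", any text transformer for the "text tower"
    def __init__(self, audio_tower: nn.Module, text_tower: nn.Module):
        super().__init__()
        self.audio = audio_tower                       # patchifies the 2d spectrogram and attends
        self.text = text_tower
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)  # ~ln(1/0.07), as in CLIP

    def forward(self, spectrogram, text):
        audio_latents = F.normalize(self.audio(spectrogram), dim = -1)
        text_latents = F.normalize(self.text(text), dim = -1)
        return audio_latents, text_latents, self.logit_scale.exp()
```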
…ether spectrogram is augmented or not
Hi @lucidrains ! Can you use a riffusion spectrogram as input in the
@marianna13 oh hey Marianna! good to hear from you. yes, it should be able to accept spectrograms (you just have to pass in a tensor of shape
@marianna13 can you make sure the following code can run

```python
import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

mulan = AudioCLIP(
    embed_dim = 512,
    audio_cfg = CLIPAudioCfg(),
    text_cfg = CLIPTextCfg()
)

spectrogram = torch.randn(2, 32, 1024)
text = torch.randint(0, 10, (2, 77))

audio_latents, text_latents, _ = mulan(spectrogram, text)

print(audio_latents.shape)  # (2, 512)
print(text_latents.shape)   # (2, 512)
```
@lucidrains no, unfortunately I get this error:
@marianna13 ohh, what is the shape of the input tensor you are passing in? i thought spectrograms only have 1 channel, but i am not really an audio expert
@marianna13 i can make it accommodate 3 channels, if that is the case
```python
import torch
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg

mulan = AudioCLIP(
    embed_dim = 512,
    audio_cfg = CLIPAudioCfg(channels = 3),
    text_cfg = CLIPTextCfg(),
)

spectrogram = torch.randn(2, 3, 32, 1024)
text = torch.randint(0, 10, (2, 77))

audio_latents, text_latents, _ = mulan(spectrogram, text)

print(audio_latents.shape)  # (2, 512)
print(text_latents.shape)   # (2, 512)
```
hmm, how are you testing this? are you checking out the entire PR? this error may also suggest you don't have the necessary changes to the vision transformer (to be able to configure it to have 1 channel)
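(A minimal illustration of what configuring the vision transformer for 1 channel amounts to: the patch embedding just needs its input channels to match the spectrogram. The width and patch size below are arbitrary example values, not this PR's defaults.)

```python
import torch
import torch.nn as nn

channels = 1      # 1 for a plain magnitude spectrogram, 3 for RGB-rendered ones
patch_size = 16
width = 768

# the ViT patch embedding is just a strided conv; in_channels is the only thing that changes
patch_embed = nn.Conv2d(channels, width, kernel_size = patch_size, stride = patch_size, bias = False)

spectrogram = torch.randn(2, channels, 32, 1024)   # (batch, channels, freq bins, time steps)
patches = patch_embed(spectrogram)                  # (2, width, 32 // 16, 1024 // 16)
tokens = patches.flatten(2).transpose(1, 2)         # (2, num_patches, width), ready for attention
print(tokens.shape)                                 # torch.Size([2, 128, 768])
```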
@lucidrains I checked again and now it works! (I just forgot that I'd made changes to the code) sorry, that's my bad!
@marianna13 oh great! can you confirm that you are using 1 channel then? i should revert that commit
@marianna13 i'll add the
@lucidrains yes, I changed back to 1 channel and it worked, but I also tried to run it over a batch of images and it didn't work :(
That's great! Thank you :)
oh, that's odd, what is the shape of the batch of images you are sending in?
…ogram being passed in" This reverts commit bf0ce8b.
@marianna13 if you can show me a reproducible error like the sample script above, i can fix it
Hi @lucidrains ! Sorry for the late reply. Here's the code I'm using:

```python
import torch
import cv2
from src.open_clip import AudioCLIP, CLIPAudioCfg, CLIPTextCfg
import webdataset as wds
import sys
import os
from torchvision import transforms
from PIL import Image
import numpy as np
import time

transform = transforms.Compose([
    transforms.ToTensor()
])

def preprocess(sample: tuple):
    image, json_data = sample
    # json_data = json.loads(json_data.decode())
    audio_meta = json_data.get('audio_meta', None)
    if audio_meta is not None:
        tags = audio_meta.get('tags', None)
        if tags is not None:
            try:
                title, artist, genre = '', '', ''
                for k in tags.keys():
                    if k in ['title', 'TITLE']:
                        title = f'titled {tags[k]}'
                    if k in ['artist', 'ARTIST']:
                        artist = f'by {tags[k]}'
                    if k in ['genre', 'GENRE']:
                        genre = tags[k]
                label = f'{genre} song "{title}" {artist}'
            except:
                pass
    label = f'{json_data["caption"]}'
    return image, {'label': label}

def get_dataset(urls: list):
    '''
    Pass s3 urls and get processed torch dataset
    '''
    dataset = (
        wds.WebDataset(urls)
        .decode("pil")
        .to_tuple("jpg", "json")
        .map_tuple(transform)
        .map(preprocess)
    )
    return dataset

urls = [f'{i:05}.tar' for i in range(1)]
dataset = get_dataset(urls)

batch_size = 32
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

mulan = AudioCLIP(
    embed_dim = 32,
    audio_cfg = CLIPAudioCfg(**{'image_size': (512, 1001), 'patch_size': 16}),
    text_cfg = CLIPTextCfg()
)

for i, batch in enumerate(loader):
    im, label = batch
    print(type(mulan.encode_image(im)))
```

One example of the dataset can be found here: https://drive.google.com/file/d/15VFMSovEWCHJcDeg9lXqFnlACJmi5gr5/view?usp=sharing Thank you!
@marianna13 hey Marianna, thanks for sharing the script. it looks good except for the image dimensions, whose height and width need to be divisible by the patch size. however, that assert should be in the code somewhere; maybe left for a separate PR. it also does not matter for the vision transformer other than generating the absolute positions, so long as the image dimensions are the maximum of what you send in during training. The spectrogram must be of fixed shape during training as well, for now.

Could you try rerunning your script? And also insert a print statement before the mulan invocation, in case it fails again even with my recent changes?

```python
for i, batch in enumerate(loader):
    im, label = batch
    print('input shape is:', im.shape)
    print(type(mulan.encode_image(im)))
```
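(A small sketch of the divisibility constraint mentioned above, with a crop to the nearest multiple of the patch size; the helper name is made up and this is not necessarily where the assert will land in the PR.)

```python
import torch

def crop_to_patch_multiple(spec, patch_size = 16):
    # keep only the largest height/width that divides evenly into patches
    *_, height, width = spec.shape
    assert height >= patch_size and width >= patch_size, 'spectrogram smaller than one patch'
    new_h = (height // patch_size) * patch_size
    new_w = (width // patch_size) * patch_size
    return spec[..., :new_h, :new_w]

spec = torch.randn(2, 1, 512, 1001)   # width 1001 is not divisible by 16
spec = crop_to_patch_multiple(spec)
print(spec.shape)                     # torch.Size([2, 1, 512, 992])
```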
@lucidrains it works! Thank you! :)
@marianna13 hey Marianna, were you able to do a small test run? if we can even get a training run to overfit on a small training set, maybe we can try to get this PR merged
Hey @lucidrains, I tried to train a model with a small fraction of the dataset but it gets stuck at the first epoch and then gets killed. I can post my training script (I think it might be an issue on my side) but anyway
@marianna13 ohh got it, what kind of error do you see before it dies?
@lucidrains it just says "terminated" (I think some oom issue)
@marianna13 ohh, yea i think Romain mentioned this to me. maybe i should take a look at your dataset class. you could also try just plucking the code from here; others have been training parts of audiolm successfully with it
Thank you @lucidrains ! Does the AudioCLIP accept only audio? I mean, I have a bunch of spectrograms :)
@marianna13 ohh i see, you will probably have to set a max length on the time dimension of the spectrogram. do you know if the spectrogram is generated from the full piece of music, or just chunks of it?
@marianna13 actually, i can also just allow the audio clip to take care of that (curtailing the time dimension to some maximum number of patches)
@lucidrains I split every audio into 10 sec pieces and then convert them into spectrograms, so they should have the same time dimension
@marianna13 ohh, that should be ok then 🤔 i'm not sure what's going wrong
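(For reference, a sketch of the "curtail the time dimension to some maximum number of patches" idea from a few comments up: truncate spectrograms that are too long and pad ones that are too short so a batch stacks cleanly. The helper name and max length are illustrative only, not what the PR implements.)

```python
import torch
import torch.nn.functional as F

def curtail_or_pad_time(spec, max_time = 1024):
    # spec: (..., freq_bins, time_steps); enforce a fixed time length
    time = spec.shape[-1]
    if time > max_time:
        return spec[..., :max_time]          # truncate long clips
    return F.pad(spec, (0, max_time - time)) # right-pad short clips along the time axis

short = torch.randn(1, 32, 400)
long = torch.randn(1, 32, 2000)
batch = torch.cat([curtail_or_pad_time(short), curtail_or_pad_time(long)], dim = 0)
print(batch.shape)   # torch.Size([2, 32, 1024])
```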
Hi, excuse me, I would just like to ask whether work on this has stalled or whether training is ongoing. I tried looking on https://discord.gg/xBPBXfcFHd as suggested on https://github.com/lucidrains/musiclm-pytorch but I could not find an active discussion. This implementation would really be something. Many thanks, L
Can AudioCLIP be used for training? If yes, how should I modify my config? Thank you.
for building out MuLaN
Now one can do