feat: training diffusion models on sound #604

Status: Draft. Wants to merge 2 commits into base: master.
13 changes: 2 additions & 11 deletions data/image_folder.py
@@ -94,18 +94,9 @@ def make_labeled_path_dataset(dir, paths, max_dataset_size=float("inf")):
             ):  # we allow B not having a label
                 images.append(line_split[0])

-            elif len(line_split) == 2:
+            elif len(line_split) >= 2:
                 images.append(line_split[0])
-                labels.append(line_split[1])
-
-            elif len(line_split) > 2:
-                images.append(line_split[0])
-
-                label_line = line_split[1]
-                for i in range(2, len(line_split)):
-                    label_line += " " + line_split[i]
-
-                labels.append(label_line)
+                labels.append(" ".join(line_split[1:]))

     return (
         images[: min(max_dataset_size, len(images))],
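The removed branches collapse into a single `" ".join` call with identical behavior for both the 2-token and multi-token cases. A quick equivalence check (the paths.txt line is made up):

    line_split = "sounds/example.wav 0 1 0".split(" ")  # hypothetical paths.txt entry

    label_line = line_split[1]  # old loop-based concatenation
    for i in range(2, len(line_split)):
        label_line += " " + line_split[i]

    assert label_line == " ".join(line_split[1:]) == "0 1 0"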
1 change: 0 additions & 1 deletion data/self_supervised_labeled_mask_cls_dataset.py
@@ -41,7 +41,6 @@ def get_img(
         )

         try:
-
             if self.opt.data_online_creation_rand_mask_A:
                 A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1)
             elif self.opt.data_online_creation_color_mask_A:
1 change: 0 additions & 1 deletion data/self_supervised_labeled_mask_cls_online_dataset.py
@@ -43,7 +43,6 @@ def get_img(
         )

         try:
-
             if self.opt.data_online_creation_rand_mask_A:
                 A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1)
             elif self.opt.data_online_creation_color_mask_A:
1 change: 0 additions & 1 deletion data/self_supervised_labeled_mask_dataset.py
@@ -42,7 +42,6 @@ def get_img(
         )

         try:
-
             if self.opt.data_online_creation_rand_mask_A:
                 A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1)
             elif self.opt.data_online_creation_color_mask_A:
1 change: 0 additions & 1 deletion data/self_supervised_labeled_mask_online_ref_dataset.py
@@ -44,7 +44,6 @@ def get_img(
         )

         try:
-
             if self.opt.data_online_creation_rand_mask_A:
                 A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1)
             elif self.opt.data_online_creation_color_mask_A:
1 change: 0 additions & 1 deletion data/self_supervised_labeled_mask_ref_dataset.py
@@ -42,7 +42,6 @@ def get_img(
         )

         try:
-
             if self.opt.data_online_creation_rand_mask_A:
                 A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1)
             elif self.opt.data_online_creation_color_mask_A:
66 changes: 66 additions & 0 deletions data/self_supervised_labeled_sound_dataset.py
@@ -0,0 +1,66 @@
import os.path
from data.unaligned_labeled_cls_dataset import UnalignedLabeledClsDataset
from data.base_dataset import BaseDataset
from data.online_creation import fill_mask_with_random, fill_mask_with_color
from data.image_folder import make_labeled_path_dataset
from data.sound_folder import load_sound
from PIL import Image
import numpy as np
import torch
from torch.fft import fft

# TODO optional?
import torchaudio
import warnings


class SelfSupervisedLabeledSoundDataset(UnalignedLabeledClsDataset):
    """
    This dataset class creates self-supervised (noise, sound) pairs with class
    labels from a single sound domain.
    """

    def __init__(self, opt, phase):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt, phase)

        self.A_img_paths, self.A_label = make_labeled_path_dataset(
            self.dir_A, "/paths.txt", opt.data_max_dataset_size
        )  # load sounds from '/path/to/data/trainA/paths.txt' as well as labels

        # Split multilabel
        self.A_label = [lbl.split(" ") for lbl in self.A_label]
        self.A_label = np.array(self.A_label, dtype=np.float32)

        self.A_size = len(self.A_img_paths)  # get the size of dataset A
        self.B_size = 0  # get the size of dataset B

    def get_img(
        self,
        A_sound_path,
        A_label_mask_path,
        A_label_cls,
        B_img_path=None,
        B_label_mask_path=None,
        B_label_cls=None,
        index=None,
    ):
        try:
            target = load_sound(A_sound_path)
            # XXX: some datasets don't convert to int, which means they are never
            # used with a palette, because the palette requires cls to be int
            A_label = torch.tensor(self.A_label[index % self.A_size].astype(int))
            result = {
                "A": torch.randn_like(target),
                "B": target,
                "A_img_paths": A_sound_path,
                "A_label_cls": A_label,
                "B_label_cls": A_label,
            }
        except Exception as e:
            print(e, "self supervised sound data loading")
            return None

        return result
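Net effect of get_img above: "A" is pure Gaussian noise and "B" is the 2D sound representation, so the diffusion model is trained to map noise to sound. A minimal standalone sketch of that pairing, assuming a repo-root working directory and a made-up .wav path:

    import torch
    from data.sound_folder import load_sound

    target = load_sound("trainA/sounds/example.wav")  # [3, 256, 256] sound image
    pair = {"A": torch.randn_like(target), "B": target}  # noise -> sound training pair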
88 changes: 88 additions & 0 deletions data/sound_folder.py
@@ -0,0 +1,88 @@
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""


import os
import os.path

import torch
import torch.nn.functional as F
from torch.fft import fft, ifft
import torch.utils.data as data

# TODO optional?
import torchaudio


def window(t):
"""
t between 0 & 1
"""
return (1 - torch.cos(t * (torch.pi * 2))) / 2


# TODO write a test to check that `wav2D_to_wav(wav_to_2D(x))` is consistent
def wav_to_2D(data, chunk_size, norm_factor=256):
"""
Transform sound data to image-like data (2D, normalized between -1 & 1)
"""
chunk_gap = chunk_size // 2
chunks = torch.stack(
[
data[i : i + chunk_size]
for i in range(0, len(data) - chunk_size + 1, chunk_gap)
]
)
chunks_fft = fft(chunks)[:, : chunk_size // 2]
chunks_fft = torch.stack([chunks_fft.real, chunks_fft.imag, torch.abs(chunks_fft)])
chunks_fft /= norm_factor
# TODO manage sound longer than input size
# TODO don't hard code input size
chunks_fft = torch.nn.functional.pad(
chunks_fft, (0, 0, 0, 256 - chunks_fft.shape[-2]), value=0
)
# print(torch.max(chunks_fft), torch.min(chunks_fft))
return chunks_fft


def wav2D_to_wav(sound2d, norm_factor=256):
"""
Transform image-like data (2D, normalized between -1 & 1) to waveform. This
function is the inverse of wav_to_2D

Parameters:
sound2d -- The 2D matrix containing the sound, with shape [n_channel, width, height]
"""
# sound2d: channel, time, fourier features
chunk_size = sound2d.shape[-1] * 2
sound2d = (sound2d[0] + 1j * sound2d[1]) * norm_factor
chunks_fft = F.pad(sound2d, (0, chunk_size // 2), mode="constant", value=0)
chunks = ifft(chunks_fft).real

# Apply window and paste chunks together
chunk_window = window(torch.linspace(0, 1, chunk_size + 1, device=sound2d.device))[
:-1
]
chunks = chunks * chunk_window

chunks_odd = F.pad(torch.flatten(chunks[1::2]), (chunk_size // 2, 0))
chunks_even = torch.flatten(chunks[0::2])
total_size = min(len(chunks_even), len(chunks_odd))

signal = chunks_odd[:total_size] + chunks_even[:total_size]
return signal.unsqueeze(0)


def load_sound(sound_path):
data, rate = torchaudio.load(sound_path)

# Ensure mono audio
data = data[0]

# TODO dynamic chunk_size
chunk_size = 512
norm_factor = 256 # 65536
return wav_to_2D(data, chunk_size, norm_factor)
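On the round-trip TODO above, a minimal consistency check could look like the sketch below (an assumption-laden sketch, not part of the PR: it compares an interior slice and fits a single scale factor, because the Hann window attenuates the signal edges and dropping the conjugate-symmetric FFT bins roughly halves the amplitude):

    import torch
    from data.sound_folder import wav_to_2D, wav2D_to_wav

    sr = 16000
    t = torch.arange(2 * sr) / sr
    wave = 0.5 * torch.sin(2 * torch.pi * 440.0 * t)  # 440 Hz test tone

    img = wav_to_2D(wave, chunk_size=512)  # [3, 256, 256]
    recon = wav2D_to_wav(img)              # [1, n_samples]

    a = wave[1024:30000]
    b = recon[0, 1024:30000]
    scale = (a @ b) / (b @ b)              # least-squares fit of a ~= scale * b
    rel_err = torch.norm(a - scale * b) / torch.norm(a)
    print(img.shape, recon.shape, scale.item(), rel_err.item())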
1 change: 0 additions & 1 deletion data/unaligned_labeled_mask_cls_dataset.py
@@ -35,7 +35,6 @@ def get_img(
         B_label_cls=None,
         index=None,
     ):
-
         return_dict = super().get_img(
             A_img_path,
             A_label_mask_path,
6 changes: 3 additions & 3 deletions docs/source/inference.rst
@@ -208,13 +208,13 @@ Download a pretrained model:
 Run the inference script
 ========================

-The ``--cls`` parameter controls the pose for Mario (1 = standing, 2 = walking, 3 = jumping, etc).
+The ``--cls_value`` parameter controls the pose for Mario (1 = standing, 2 = walking, 3 = jumping, etc).

 .. code:: bash

    mkdir mario_inference_output
    cd scripts/
-   python3 gen_single_image_diffusion.py --model_in_file ../checkpoints/mario/latest_net_G_A.pth --img_in ../datasets/online_mario2sonic_lite/mario/imgs/mario_frame_19538.jpg --bbox_in ../datasets/online_mario2sonic_lite/mario/bbox/r_mario_frame_19538.jpg.txt --dir_out ../mario_inference_output --img_width 128 --img_height 128 --mask_delta 10 --cls 3
+   python3 gen_single_image_diffusion.py --model_in_file ../checkpoints/mario/latest_net_G_A.pth --img_in ../datasets/online_mario2sonic_lite/mario/imgs/mario_frame_19538.jpg --bbox_in ../datasets/online_mario2sonic_lite/mario/bbox/r_mario_frame_19538.jpg.txt --dir_out ../mario_inference_output --img_width 128 --img_height 128 --mask_delta 10 --cls_value 3

 The output files will be in the ``mario_inference_output`` folder, with:

@@ -276,7 +276,7 @@ Download a pretrained model:
 Run the inference script
 ========================

-The ``--cond-in`` parameter specifies the conditioning image to use.
+The ``--cond_in`` parameter specifies the conditioning image to use.

 .. code:: bash
18 changes: 18 additions & 0 deletions models/base_model.py
@@ -98,6 +98,7 @@ def __init__(self, opt, rank):
         self.loss_names = []
         self.model_names = []
         self.visual_names = []
+        self.sound_names = []
         self.display_param = []
         self.set_display_param()
         self.optimizers = []

@@ -737,6 +738,10 @@ def compute_visuals(self):
         """Calculate additional output images for visdom and HTML visualization"""
         pass

+    def compute_sounds(self):
+        """Calculate sounds to listen to on the visualizer"""
+        pass
+
     def get_image_paths(self):
         """Return image paths that are used to load current data"""
         return self.image_paths
@@ -767,6 +772,14 @@ def get_current_visuals(self, phase="train"):
             visual_ret.append(cur_visual)
         return visual_ret

+    def get_current_sounds(self):
+        # TODO phase? do same as visuals? create "types" of visuals?
+        sound_ret = {}
+        for i, name in enumerate(self.sound_names):
+            sound_ret[name] = getattr(self, name)
+
+        return sound_ret
+
     def get_display_param(self):
         param = OrderedDict()
         for name in self.display_param:
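For context, a hypothetical sketch of how a model could feed these hooks (the subclass, the `fake_B` attribute, and the use of `wav2D_to_wav` are assumptions, not part of this diff):

    from models.base_model import BaseModel
    from data.sound_folder import wav2D_to_wav

    class SoundDiffusionModel(BaseModel):  # hypothetical subclass
        def __init__(self, opt, rank):
            super().__init__(opt, rank)
            self.sound_names = ["fake_B_audio"]  # picked up by get_current_sounds()

        def compute_sounds(self):
            # invert the generated 2D sound image back to a mono waveform
            self.fake_B_audio = wav2D_to_wav(self.fake_B[0].detach().cpu())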
@@ -933,6 +946,11 @@ def load_networks(self, epoch):
                     state_dict[new_key] = state_dict[key].clone()
                     del state_dict[key]

+            # TODO auto detect when necessary
+            for key in list(state_dict.keys()):
+                if key.startswith("denoise_fn") and key.endswith("_test"):
+                    state_dict[key] = net.state_dict()[key]
+
             state1 = list(state_dict.keys())
             state2 = list(net.state_dict().keys())
             state1.sort()
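One possible shape for the "auto detect when necessary" TODO above, sketched under the assumption that only missing or shape-mismatched `_test` (EMA) keys should be substituted:

    net_state = net.state_dict()
    for key in net_state:
        if key.startswith("denoise_fn") and key.endswith("_test"):
            if key not in state_dict or state_dict[key].shape != net_state[key].shape:
                state_dict[key] = net_state[key].clone()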
3 changes: 0 additions & 3 deletions models/cm_model.py
@@ -148,7 +148,6 @@ def __init__(self, opt, rank):
         self.iter_calculator_init()

     def set_input(self, data):
-
         if (
             len(data["A"].to(self.device).shape) == 5
         ):  # we're using temporal successive frames

@@ -203,7 +202,6 @@ def set_input(self, data):
         self.real_B = self.gt_image

     def compute_cm_loss(self):
-
         y_0 = self.gt_image  # ground truth
         y_cond = self.cond_image  # conditioning
         mask = self.mask

@@ -224,7 +222,6 @@ def compute_cm_loss(self):
         self.loss_G_tot = loss * self.opt.alg_diffusion_lambda_G

     def inference(self):
-
         if hasattr(self.netG_A, "module"):
             netG = self.netG_A.module
         else:
2 changes: 0 additions & 2 deletions models/cycle_gan_model.py
@@ -64,7 +64,6 @@ def after_parse(opt):
         return opt

     def __init__(self, opt, rank):
-
         super().__init__(opt, rank)

         if opt.alg_cyclegan_lambda_identity > 0.0:

@@ -113,7 +112,6 @@ def __init__(self, opt, rank):
         # Discriminators

         if self.isTrain:
-
             self.netD_As = gan_networks.define_D(**vars(opt))
             self.netD_Bs = gan_networks.define_D(**vars(opt))

1 change: 0 additions & 1 deletion models/diffusion_networks.py
@@ -91,7 +91,6 @@ def define_G(
     in_channel += alg_diffusion_cond_embed_dim

     if G_netG == "unet_mha":
-
         if model_prior_321_backwardcompatibility:
             cond_embed_dim = G_ngf * 4
         else:
3 changes: 0 additions & 3 deletions models/gan_networks.py
@@ -268,7 +268,6 @@ def define_D(
     train_feat_wavelet,
     **unused_options
 ):
-
     """Create a discriminator

     Parameters:

@@ -309,7 +308,6 @@ def define_D(
     img_size = data_crop_size

     for netD in D_netDs:
-
         if netD == "basic":  # default PatchGAN classifier
             net = NLayerDiscriminator(
                 model_input_nc,

@@ -360,7 +358,6 @@ def define_D(
             download_segformer_weight(weight_path)

         elif D_proj_network_type == "depth":
-
             weight_path = model_depth_network

         else:
3 changes: 2 additions & 1 deletion models/modules/classifiers.py
@@ -55,7 +55,6 @@ def forward(self, x):


 class VGG16_FCN8s(nn.Module):
-
     transform = torchvision.transforms.Compose(
         [
             torchvision.transforms.ToTensor(),

@@ -243,6 +242,8 @@ def forward(self, x):
     "mnasnet1_0": models.mnasnet1_0,
     "mnasnet1_3": models.mnasnet1_3,
 }
+
+
 # all models are RGB internally.
 class torch_model(nn.Module):
     def __init__(self, input_nc, ndf, nclasses, img_size, template, pretrained):