From 23e1c567ade5578eef48570aace13b9ec4323ce0 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 19 Jan 2024 16:45:47 +0000 Subject: [PATCH 1/2] feat: training diffusion models on sound Add conditionning by multiple classes Add sounds on visdom Reload diffusion model with different timestep count --- data/image_folder.py | 13 +-- ...elf_supervised_labeled_mask_cls_dataset.py | 1 - ...ervised_labeled_mask_cls_online_dataset.py | 1 - data/self_supervised_labeled_mask_dataset.py | 1 - ...ervised_labeled_mask_online_ref_dataset.py | 1 - ...elf_supervised_labeled_mask_ref_dataset.py | 1 - data/self_supervised_labeled_sound_dataset.py | 66 ++++++++++++++ data/sound_folder.py | 88 +++++++++++++++++++ data/unaligned_labeled_mask_cls_dataset.py | 1 - models/base_model.py | 18 ++++ models/cycle_gan_model.py | 2 - models/diffusion_networks.py | 1 - models/gan_networks.py | 3 - models/modules/classifiers.py | 3 +- models/modules/diffusion_generator.py | 10 +-- models/modules/fid | 1 + models/modules/image_bind/imagebind_model.py | 1 - models/modules/ittr/ittr_generator.py | 1 - models/modules/op/upfirdn2d.py | 1 - models/modules/palette_denoise_fn.py | 38 ++++++-- models/modules/projected_d/diffusion.py | 1 - .../resnet_architecture/resnet_generator.py | 1 - .../resnet_generator_diff.py | 2 - .../sub_mobile_resnet_generator.py | 1 - models/modules/segformer/backbone.py | 2 - models/palette_model.py | 47 +++++++++- options/common_options.py | 6 +- train.py | 10 ++- util/metrics.py | 2 - util/visualizer.py | 55 +++++++++++- 30 files changed, 316 insertions(+), 63 deletions(-) create mode 100644 data/self_supervised_labeled_sound_dataset.py create mode 100644 data/sound_folder.py create mode 160000 models/modules/fid diff --git a/data/image_folder.py b/data/image_folder.py index 972afbbc9..85b3f6505 100644 --- a/data/image_folder.py +++ b/data/image_folder.py @@ -94,18 +94,9 @@ def make_labeled_path_dataset(dir, paths, max_dataset_size=float("inf")): ): # we allow B not having a label images.append(line_split[0]) - elif len(line_split) == 2: + elif len(line_split) >= 2: images.append(line_split[0]) - labels.append(line_split[1]) - - elif len(line_split) > 2: - images.append(line_split[0]) - - label_line = line_split[1] - for i in range(2, len(line_split)): - label_line += " " + line_split[i] - - labels.append(label_line) + labels.append(" ".join(line_split[1:])) return ( images[: min(max_dataset_size, len(images))], diff --git a/data/self_supervised_labeled_mask_cls_dataset.py b/data/self_supervised_labeled_mask_cls_dataset.py index 53f78833e..8f7fb4d17 100644 --- a/data/self_supervised_labeled_mask_cls_dataset.py +++ b/data/self_supervised_labeled_mask_cls_dataset.py @@ -41,7 +41,6 @@ def get_img( ) try: - if self.opt.data_online_creation_rand_mask_A: A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1) elif self.opt.data_online_creation_color_mask_A: diff --git a/data/self_supervised_labeled_mask_cls_online_dataset.py b/data/self_supervised_labeled_mask_cls_online_dataset.py index 0ac101839..4ec1c3410 100644 --- a/data/self_supervised_labeled_mask_cls_online_dataset.py +++ b/data/self_supervised_labeled_mask_cls_online_dataset.py @@ -43,7 +43,6 @@ def get_img( ) try: - if self.opt.data_online_creation_rand_mask_A: A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1) elif self.opt.data_online_creation_color_mask_A: diff --git a/data/self_supervised_labeled_mask_dataset.py b/data/self_supervised_labeled_mask_dataset.py index 480b0b948..3b5097a5c 100644 --- 
a/data/self_supervised_labeled_mask_dataset.py +++ b/data/self_supervised_labeled_mask_dataset.py @@ -42,7 +42,6 @@ def get_img( ) try: - if self.opt.data_online_creation_rand_mask_A: A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1) elif self.opt.data_online_creation_color_mask_A: diff --git a/data/self_supervised_labeled_mask_online_ref_dataset.py b/data/self_supervised_labeled_mask_online_ref_dataset.py index 03afb8a07..8959b0eb9 100644 --- a/data/self_supervised_labeled_mask_online_ref_dataset.py +++ b/data/self_supervised_labeled_mask_online_ref_dataset.py @@ -44,7 +44,6 @@ def get_img( ) try: - if self.opt.data_online_creation_rand_mask_A: A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1) elif self.opt.data_online_creation_color_mask_A: diff --git a/data/self_supervised_labeled_mask_ref_dataset.py b/data/self_supervised_labeled_mask_ref_dataset.py index 0c5a52d60..e35ce37b8 100644 --- a/data/self_supervised_labeled_mask_ref_dataset.py +++ b/data/self_supervised_labeled_mask_ref_dataset.py @@ -42,7 +42,6 @@ def get_img( ) try: - if self.opt.data_online_creation_rand_mask_A: A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1) elif self.opt.data_online_creation_color_mask_A: diff --git a/data/self_supervised_labeled_sound_dataset.py b/data/self_supervised_labeled_sound_dataset.py new file mode 100644 index 000000000..a68344f5c --- /dev/null +++ b/data/self_supervised_labeled_sound_dataset.py @@ -0,0 +1,66 @@ +import os.path +from data.unaligned_labeled_cls_dataset import UnalignedLabeledClsDataset +from data.base_dataset import BaseDataset +from data.online_creation import fill_mask_with_random, fill_mask_with_color +from data.image_folder import make_labeled_path_dataset +from data.sound_folder import load_sound +from PIL import Image +import numpy as np +import torch +from torch.fft import fft + +# TODO optional? +import torchaudio +import warnings + + +class SelfSupervisedLabeledSoundDataset(UnalignedLabeledClsDataset): + """ + This dataset class can create paired datasets with mask labels from only one domain. + """ + + def __init__(self, opt, phase): + """Initialize this dataset class. 
+ + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseDataset.__init__(self, opt, phase) + + self.A_img_paths, self.A_label = make_labeled_path_dataset( + self.dir_A, "/paths.txt", opt.data_max_dataset_size + ) # load images from '/path/to/data/trainA/paths.txt' as well as labels + + # Split multilabel + self.A_label = [lbl.split(" ") for lbl in self.A_label] + self.A_label = np.array(self.A_label, dtype=np.float32) + + self.A_size = len(self.A_img_paths) # get the size of dataset A + self.B_size = 0 # get the size of dataset B + + def get_img( + self, + A_sound_path, + A_label_mask_path, + A_label_cls, + B_img_path=None, + B_label_mask_path=None, + B_label_cls=None, + index=None, + ): + try: + target = load_sound(A_sound_path) + # XXX: some datasets don't convert to int, which mean they are never used with palette, because palette requires cls to be int + A_label = torch.tensor(self.A_label[index % self.A_size].astype(int)) + result = { + "A": torch.randn_like(target), + "B": target, + "A_img_paths": A_sound_path, + "A_label_cls": A_label, + "B_label_cls": A_label, + } + except Exception as e: + print(e, "self supervised sound data loading") + return None + + return result diff --git a/data/sound_folder.py b/data/sound_folder.py new file mode 100644 index 000000000..a1ab76714 --- /dev/null +++ b/data/sound_folder.py @@ -0,0 +1,88 @@ +"""A modified image folder class + +We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) +so that this class can load images from both current directory and its subdirectories. +""" + + +import os +import os.path + +import torch +import torch.nn.functional as F +from torch.fft import fft, ifft +import torch.utils.data as data + +# TODO optional? +import torchaudio + + +def window(t): + """ + t between 0 & 1 + """ + return (1 - torch.cos(t * (torch.pi * 2))) / 2 + + +# TODO write a test to check that `wav2D_to_wav(wav_to_2D(x))` is consistent +def wav_to_2D(data, chunk_size, norm_factor): + """ + Transform sound data to image-like data (2D, normalized between -1 & 1) + """ + chunk_gap = chunk_size // 2 + chunks = torch.stack( + [ + data[i : i + chunk_size] + for i in range(0, len(data) - chunk_size + 1, chunk_gap) + ] + ) + chunks_fft = fft(chunks)[:, : chunk_size // 2] + chunks_fft = torch.stack([chunks_fft.real, chunks_fft.imag, torch.abs(chunks_fft)]) + chunks_fft /= norm_factor + # TODO manage sound longer than input size + # TODO don't hard code input size + chunks_fft = torch.nn.functional.pad( + chunks_fft, (0, 0, 0, 256 - chunks_fft.shape[-2]), value=0 + ) + # print(torch.max(chunks_fft), torch.min(chunks_fft)) + return chunks_fft + + +def wav2D_to_wav(sound2d, norm_factor): + """ + Transform image-like data (2D, normalized between -1 & 1) to waveform. 
This + function is the inverse of wav_to_2D + + Parameters: + sound2d -- The 2D matrix containing the sound, with shape [n_channel, width, height] + """ + # sound2d: channel, time, fourier features + chunk_size = sound2d.shape[-1] * 2 + sound2d = (sound2d[0] + 1j * sound2d[1]) * norm_factor + chunks_fft = F.pad(sound2d, (0, chunk_size // 2), mode="constant", value=0) + chunks = ifft(chunks_fft).real + + # Apply window and paste chunks together + chunk_window = window(torch.linspace(0, 1, chunk_size + 1, device=sound2d.device))[ + :-1 + ] + chunks = chunks * chunk_window + + chunks_odd = F.pad(torch.flatten(chunks[1::2]), (chunk_size // 2, 0)) + chunks_even = torch.flatten(chunks[0::2]) + total_size = min(len(chunks_even), len(chunks_odd)) + + signal = chunks_odd[:total_size] + chunks_even[:total_size] + return signal.unsqueeze(0) + + +def load_sound(sound_path): + data, rate = torchaudio.load(sound_path) + + # Ensure mono audio + data = data[0] + + # TODO dynamic chunk_size + chunk_size = 512 + norm_factor = 256 # 65536 + return wav_to_2D(data, chunk_size, norm_factor) diff --git a/data/unaligned_labeled_mask_cls_dataset.py b/data/unaligned_labeled_mask_cls_dataset.py index 1d625a041..708bac7c6 100644 --- a/data/unaligned_labeled_mask_cls_dataset.py +++ b/data/unaligned_labeled_mask_cls_dataset.py @@ -35,7 +35,6 @@ def get_img( B_label_cls=None, index=None, ): - return_dict = super().get_img( A_img_path, A_label_mask_path, diff --git a/models/base_model.py b/models/base_model.py index 1009f5b01..e7b26025c 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -98,6 +98,7 @@ def __init__(self, opt, rank): self.loss_names = [] self.model_names = [] self.visual_names = [] + self.sound_names = [] self.display_param = [] self.set_display_param() self.optimizers = [] @@ -737,6 +738,10 @@ def compute_visuals(self): """Calculate additional output images for visdom and HTML visualization""" pass + def compute_sounds(self): + """Calculate sounds to listen to on the visualizer""" + pass + def get_image_paths(self): """Return image paths that are used to load current data""" return self.image_paths @@ -767,6 +772,14 @@ def get_current_visuals(self, phase="train"): visual_ret.append(cur_visual) return visual_ret + def get_current_sounds(self): + # TODO phase? do same as visuals? create "types" of visuals? 
+ sound_ret = {} + for i, name in enumerate(self.sound_names): + sound_ret[name] = getattr(self, name) + + return sound_ret + def get_display_param(self): param = OrderedDict() for name in self.display_param: @@ -933,6 +946,11 @@ def load_networks(self, epoch): state_dict[new_key] = state_dict[key].clone() del state_dict[key] + # TODO auto detect when necessary + for key in list(state_dict.keys()): + if key.startswith("denoise_fn") and key.endswith("_test"): + state_dict[key] = net.state_dict()[key] + state1 = list(state_dict.keys()) state2 = list(net.state_dict().keys()) state1.sort() diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 249ddd1db..df0a03db7 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -64,7 +64,6 @@ def after_parse(opt): return opt def __init__(self, opt, rank): - super().__init__(opt, rank) if opt.alg_cyclegan_lambda_identity > 0.0: @@ -113,7 +112,6 @@ def __init__(self, opt, rank): # Discriminators if self.isTrain: - self.netD_As = gan_networks.define_D(**vars(opt)) self.netD_Bs = gan_networks.define_D(**vars(opt)) diff --git a/models/diffusion_networks.py b/models/diffusion_networks.py index 8829d4ba9..cb58ec949 100644 --- a/models/diffusion_networks.py +++ b/models/diffusion_networks.py @@ -91,7 +91,6 @@ def define_G( in_channel += alg_diffusion_cond_embed_dim if G_netG == "unet_mha": - if model_prior_321_backwardcompatibility: cond_embed_dim = G_ngf * 4 else: diff --git a/models/gan_networks.py b/models/gan_networks.py index af8266392..58f8b6f66 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -268,7 +268,6 @@ def define_D( train_feat_wavelet, **unused_options ): - """Create a discriminator Parameters: @@ -309,7 +308,6 @@ def define_D( img_size = data_crop_size for netD in D_netDs: - if netD == "basic": # default PatchGAN classifier net = NLayerDiscriminator( model_input_nc, @@ -360,7 +358,6 @@ def define_D( download_segformer_weight(weight_path) elif D_proj_network_type == "depth": - weight_path = model_depth_network else: diff --git a/models/modules/classifiers.py b/models/modules/classifiers.py index b7d0e1eb7..a38c0bc02 100644 --- a/models/modules/classifiers.py +++ b/models/modules/classifiers.py @@ -55,7 +55,6 @@ def forward(self, x): class VGG16_FCN8s(nn.Module): - transform = torchvision.transforms.Compose( [ torchvision.transforms.ToTensor(), @@ -243,6 +242,8 @@ def forward(self, x): "mnasnet1_0": models.mnasnet1_0, "mnasnet1_3": models.mnasnet1_3, } + + # all models are RGB internally. 
class torch_model(nn.Module): def __init__(self, input_nc, ndf, nclasses, img_size, template, pretrained): diff --git a/models/modules/diffusion_generator.py b/models/modules/diffusion_generator.py index 1a93307db..8892a84be 100644 --- a/models/modules/diffusion_generator.py +++ b/models/modules/diffusion_generator.py @@ -41,7 +41,6 @@ def __init__( if loading_backward_compatibility: if type(self.denoise_fn.model).__name__ == "ResnetGenerator_attn_diff": - inner_channel = G_ngf self.cond_embed = nn.Sequential( nn.Linear(inner_channel, cond_embed_dim), @@ -50,7 +49,6 @@ def __init__( ) elif type(self.denoise_fn.model).__name__ == "UNet": - inner_channel = G_ngf cond_embed_dim = inner_channel * 4 @@ -63,11 +61,7 @@ def __init__( self.cond_embed_gammas_in = inner_channel else: self.cond_embed_dim = cond_embed_dim - - if any(cond in self.denoise_fn.conditioning for cond in ["class", "ref"]): - self.cond_embed_gammas = self.cond_embed_dim // 2 - else: - self.cond_embed_gammas = self.cond_embed_dim + self.cond_embed_gammas = self.denoise_fn.cond_embed_gammas self.cond_embed = nn.Sequential( nn.Linear(self.cond_embed_gammas, self.cond_embed_gammas), @@ -249,7 +243,6 @@ def p_sample( y_cond=None, guidance_scale=0.0, ): - model_mean, model_log_variance = self.p_mean_variance( y_t=y_t, t=t, @@ -429,7 +422,6 @@ def ddim_p_mean_variance( return model_mean, posterior_log_variance def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0): - b, *_ = y_0.shape t = torch.randint( 1, self.denoise_fn.model.num_timesteps_train, (b,), device=y_0.device diff --git a/models/modules/fid b/models/modules/fid new file mode 160000 index 000000000..d1e426cca --- /dev/null +++ b/models/modules/fid @@ -0,0 +1 @@ +Subproject commit d1e426ccabdce4ba4d0604c5d5d0dae8a60d576d diff --git a/models/modules/image_bind/imagebind_model.py b/models/modules/image_bind/imagebind_model.py index f0d3fd146..bf6c11816 100644 --- a/models/modules/image_bind/imagebind_model.py +++ b/models/modules/image_bind/imagebind_model.py @@ -498,7 +498,6 @@ def imagebind_huge(pretrained=False): ) if pretrained: - path = ".models/configs/bind/pretrain" file_name = "imagebind_huge.pth" diff --git a/models/modules/ittr/ittr_generator.py b/models/modules/ittr/ittr_generator.py index c70cdde47..864f175c3 100644 --- a/models/modules/ittr/ittr_generator.py +++ b/models/modules/ittr/ittr_generator.py @@ -288,7 +288,6 @@ class ITTRGenerator(nn.Module): """ def __init__(self, input_nc, output_nc, img_size, n_blocks=9, ngf=64): - assert n_blocks >= 0 super(ITTRGenerator, self).__init__() diff --git a/models/modules/op/upfirdn2d.py b/models/modules/op/upfirdn2d.py index 6e4f03b54..66967055d 100755 --- a/models/modules/op/upfirdn2d.py +++ b/models/modules/op/upfirdn2d.py @@ -22,7 +22,6 @@ class UpFirDn2dBackward(Function): def forward( ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size ): - up_x, up_y = up down_x, down_y = down g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad diff --git a/models/modules/palette_denoise_fn.py b/models/modules/palette_denoise_fn.py index b06354261..d0ed9ca86 100644 --- a/models/modules/palette_denoise_fn.py +++ b/models/modules/palette_denoise_fn.py @@ -42,15 +42,30 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses) self.conditioning = conditioning self.cond_embed_dim = cond_embed_dim self.ref_embed_net = ref_embed_net + self.cond_embed_gammas = cond_embed_dim # Label embedding if "class" in conditioning: - cond_embed_class = cond_embed_dim // 2 - 
self.netl_embedder_class = LabelEmbedder( - nclasses, - cond_embed_class, # * image_size * image_size - ) - nn.init.normal_(self.netl_embedder_class.embedding_table.weight, std=0.02) + if type(nclasses) == list: + # TODO this is arbitrary, half for class & half for detector + cond_embed_class = cond_embed_dim // (len(nclasses) + 1) + self.netl_embedders_class = nn.ModuleList( + [LabelEmbedder(nc, cond_embed_class) for nc in nclasses] + ) + for embed in self.netl_embedders_class: + self.cond_embed_gammas -= cond_embed_class + nn.init.normal_(embed.embedding_table.weight, std=0.02) + else: + # TODO this can be included in the general case + cond_embed_class = cond_embed_dim // 2 + self.netl_embedder_class = LabelEmbedder( + nclasses, + cond_embed_class, # * image_size * image_size + ) + self.cond_embed_gammas -= cond_embed_class + nn.init.normal_( + self.netl_embedder_class.embedding_table.weight, std=0.02 + ) if "mask" in conditioning: cond_embed_mask = cond_embed_dim @@ -58,6 +73,7 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses) nclasses, cond_embed_mask, # * image_size * image_size ) + self.cond_embed_gammas -= cond_embed_class nn.init.normal_(self.netl_embedder_mask.embedding_table.weight, std=0.02) # Instantiate model @@ -90,6 +106,7 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses) self.emb_layers = nn.Sequential( torch.nn.SiLU(), nn.Linear(ref_embed_dim, cond_embed_class) ) + self.cond_embed_gammas -= cond_embed_class def forward(self, input, embed_noise_level, cls, mask, ref): cls_embed, mask_embed, ref_embed = self.compute_cond(input, cls, mask, ref) @@ -114,7 +131,14 @@ def forward(self, input, embed_noise_level, cls, mask, ref): def compute_cond(self, input, cls, mask, ref): if "class" in self.conditioning and cls is not None: - cls_embed = self.netl_embedder_class(cls) + if hasattr(self, "netl_embedders_class"): + cls_embed = [] + for i in range(len(self.netl_embedders_class)): + cls_embed.append(self.netl_embedders_class[i](cls[:, i])) + cls_embed = torch.cat(cls_embed, dim=1) + else: + # TODO general case + cls_embed = self.netl_embedder_class(cls) else: cls_embed = None diff --git a/models/modules/projected_d/diffusion.py b/models/modules/projected_d/diffusion.py index 838b6c7ff..c70354b9f 100755 --- a/models/modules/projected_d/diffusion.py +++ b/models/modules/projected_d/diffusion.py @@ -106,7 +106,6 @@ def __init__( self.noise_std = float(noise_std) # Standard deviation of additive RGB noise. 
def set_diffusion_process(self, t, beta_schedule): - betas = get_beta_schedule( beta_schedule=beta_schedule, beta_start=self.beta_start, diff --git a/models/modules/resnet_architecture/resnet_generator.py b/models/modules/resnet_architecture/resnet_generator.py index 5dc07b1b4..bce5dcf21 100644 --- a/models/modules/resnet_architecture/resnet_generator.py +++ b/models/modules/resnet_architecture/resnet_generator.py @@ -255,7 +255,6 @@ def compute_feats(self, input, extract_layer_ids=[]): feat = input feats = [] for layer_id, layer in enumerate(self.model): - feat = layer(feat) if layer_id in extract_layer_ids: feats.append(feat) diff --git a/models/modules/resnet_architecture/resnet_generator_diff.py b/models/modules/resnet_architecture/resnet_generator_diff.py index f430eeb53..fb31bee7c 100644 --- a/models/modules/resnet_architecture/resnet_generator_diff.py +++ b/models/modules/resnet_architecture/resnet_generator_diff.py @@ -427,7 +427,6 @@ def weight_init(self, mean, std): normal_init(self._modules[m], mean, std) def compute_feats(self, input, embed_gammas, extract_layer_ids=[]): - if embed_gammas is None: # Only for GAN b = (input.shape[0], self.cond_embed_dim) @@ -458,7 +457,6 @@ def compute_feats(self, input, embed_gammas, extract_layer_ids=[]): return feat, feats, emb def compute_attention_content(self, feat, emb): - x_content = feat for layer_id, layer in enumerate(self.decoder_content): diff --git a/models/modules/resnet_architecture/sub_mobile_resnet_generator.py b/models/modules/resnet_architecture/sub_mobile_resnet_generator.py index a6a2236a0..e4207bd81 100644 --- a/models/modules/resnet_architecture/sub_mobile_resnet_generator.py +++ b/models/modules/resnet_architecture/sub_mobile_resnet_generator.py @@ -89,7 +89,6 @@ def __init__( n_downsampling = 2 for i in range(n_downsampling): # add downsampling layers - mult = 2**i ic = config["channels"][i] oc = config["channels"][i + 1] diff --git a/models/modules/segformer/backbone.py b/models/modules/segformer/backbone.py index 4a809dc3d..580b535be 100644 --- a/models/modules/segformer/backbone.py +++ b/models/modules/segformer/backbone.py @@ -328,7 +328,6 @@ def forward(self, x, hw_shape, identity=None): return identity + self.dropout_layer(self.proj_drop(out)) def legacy_forward(self, x, hw_shape, identity=None): - x_q = x if self.sr_ratio > 1: x_kv = nlc_to_nchw(x, hw_shape) @@ -671,7 +670,6 @@ class AdaptivePadding(nn.Module): """ def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"): - super(AdaptivePadding, self).__init__() assert padding in ("same", "corner") diff --git a/models/palette_model.py b/models/palette_model.py index 4900a1e12..2faec5869 100644 --- a/models/palette_model.py +++ b/models/palette_model.py @@ -9,6 +9,7 @@ from torch import nn from data.online_creation import fill_mask_with_color, fill_mask_with_random +from data.sound_folder import wav2D_to_wav from models.modules.sam.sam_inference import compute_mask_with_sam from util.iter_calculator import IterCalculator from util.mask_generation import random_edge_mask @@ -102,9 +103,11 @@ def __init__(self, opt, rank): else: self.inference_num = min(self.opt.alg_diffusion_inference_num, batch_size) - self.num_classes = max( - self.opt.f_s_semantic_nclasses, self.opt.cls_semantic_nclasses - ) + # self.num_classes = max( + # self.opt.f_s_semantic_nclasses, self.opt.cls_semantic_nclasses + # ) + # TODO decide if we keep cls_semantic_nclasses (not used atm) + self.num_classes = self.opt.f_s_semantic_nclasses self.use_ref = ( 
self.opt.alg_diffusion_cond_image_creation == "ref" @@ -156,6 +159,15 @@ def __init__(self, opt, rank): ) opt.G_nblocks = 2 + # Sounds + # TODO if data input is sound + self.sound_names.extend( + ["gt_sound_" + str(i) for i in range(self.inference_num)] + ) + self.sound_names.extend( + ["output_sound_" + str(i) for i in range(self.inference_num)] + ) + # Define networks self.netG_A = diffusion_networks.define_G(**vars(opt)) @@ -581,10 +593,18 @@ def inference(self): # task: super resolution, pix2pix elif self.task in ["super_resolution", "pix2pix"]: + cls = None + + if "class" in self.opt.alg_palette_conditioning: + cls = [] + for i in self.num_classes: + cls.append(torch.randint_like(self.cls[:, 0], 0, i)) + cls = torch.stack(cls, dim=1) + self.output, self.visuals = netG.restoration( y_cond=self.cond_image[: self.inference_num], sample_num=self.sample_num, - cls=None, + cls=cls, ) self.fake_B = self.output @@ -616,6 +636,25 @@ def compute_visuals(self): with torch.no_grad(): self.inference() + def compute_sounds(self): + super().compute_sounds() + # print("Visuals: " , self.visual_names) + # TODO only when sound input data + # print("Computing sounds") + # print("inference num =", self.inference_num) + # print("n images =", self.gt_image.shape) + for i in range(self.inference_num): + name = "output_" + str(i) + # print("has %s: %s" % (name, hasattr(self, name))) + if hasattr(self, name): + img = getattr(self, name) + sound = wav2D_to_wav(img[0], 256) + name = "output_sound_" + str(i) + setattr(self, name, sound) + + gt_sound = wav2D_to_wav(self.gt_image[i], 256) + setattr(self, "gt_sound_" + str(i), gt_sound) + def get_dummy_input(self, device=None): if device is None: device = self.device diff --git a/options/common_options.py b/options/common_options.py index 102b5f269..f79b90b0c 100644 --- a/options/common_options.py +++ b/options/common_options.py @@ -491,9 +491,10 @@ def initialize(self, parser): ) parser.add_argument( "--f_s_semantic_nclasses", - default=2, + default=[2], + nargs="+", type=int, - help="number of classes of the semantic loss classifier", + help="number of classes of the semantic loss classifiers", ) parser.add_argument( "--f_s_class_weights", @@ -618,6 +619,7 @@ def initialize(self, parser): "self_supervised_labeled_mask_ref", "unaligned_labeled_mask_online_ref", "self_supervised_labeled_mask_online_ref", + "self_supervised_labeled_sound", ], help="chooses how datasets are loaded.", ) diff --git a/train.py b/train.py index 2d20299b7..2a0e321f8 100644 --- a/train.py +++ b/train.py @@ -240,6 +240,7 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal): ): # display images on visdom and save images to a HTML file save_result = total_iters % opt.output_update_html_freq == 0 model.compute_visuals() + model.compute_sounds() if not "none" in opt.output_display_type: visualizer.display_current_results( model.get_current_visuals(), @@ -249,6 +250,14 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal): first=(total_iters == batch_size), ) + # Play sounds in visdom + sounds = model.get_current_sounds() + if len(sounds) > 0: + visualizer.play_current_sounds( + sounds, + epoch, + ) + if ( total_iters % opt.train_save_latest_freq < batch_size ): # cache our latest model every iterations @@ -418,7 +427,6 @@ def launch_training(opt): def compute_test_metrics(model, dataloader): - return metrics diff --git a/util/metrics.py b/util/metrics.py index d7ac0bf50..809f2c862 100755 --- a/util/metrics.py +++ b/util/metrics.py @@ -124,7 +124,6 @@ def 
get_activations( # This happens if you choose a dimensionality not equal 2048. if len(pred.shape) == 4: - if pred.size(2) != 1 or pred.size(3) != 1: pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) @@ -150,7 +149,6 @@ def _compute_statistics_of_dataloader( nb_max_img=float("inf"), root=None, ): - if path_sv is not None and os.path.isfile(path_sv): print("Activations loaded for domain %s, from %s." % (domain, path_sv)) f = torch.load(path_sv) diff --git a/util/visualizer.py b/util/visualizer.py index 7cd2ebbce..956c9d64c 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -3,12 +3,16 @@ import sys import ntpath import time +import base64 +import io from . import util, html_util from subprocess import Popen, PIPE from PIL import Image import json from torchinfo import summary import math +import numpy as np +import scipy.io.wavfile if sys.version_info[0] == 2: VisdomExceptionBase = Exception @@ -128,6 +132,9 @@ def create_visdom_connections(self): def display_current_results( self, visuals, epoch, save_result, params=[], first=False, phase="train" ): + """ + Display visuals for current model in visdom or aim + """ if "visdom" in self.display_type: self.display_current_results_visdom( visuals, epoch, save_result, params, phase=phase @@ -287,6 +294,46 @@ def display_current_results_visdom( img_path = os.path.join(self.img_dir, "latest_%s.png" % label) util.save_image(image_numpy, img_path) + def convert_audio_to_b64(self, tensor): + tensor = np.array(tensor) + # Normalize only if sound is too loud + # XXX: clip instead? + tensor = np.int16(tensor / max(np.max(np.abs(tensor)), 1) * 32767) + output = io.BytesIO() + scipy.io.wavfile.write(output, 44100, tensor) + return base64.b64encode(output.getvalue()).decode("utf-8") + + def play_current_sounds(self, sounds, epoch): + """ + Play a sound in visdom + + sounds: a dict with sound name and a 1D tensor representing the sound over time + """ + if "visdom" in self.display_type: + opts = { + "width": 330, + "height": len(sounds) * 50, + "title": "Audio", + } + html_content = "" + + for name in sounds: + # video_path = os.path.join(self.img_dir, "latest_%s.mp4" % name) + sound = sounds[name].squeeze(0).cpu() + b64 = self.convert_audio_to_b64(sound) + mimetype = "wav" + html_content += """
+ + """ % ( + mimetype, + mimetype, + b64, + ) + self.vis.text(text=html_content, win="Audio", env=None, opts=opts) + def plot_current_losses(self, epoch, counter_ratio, losses): if "visdom" in self.display_type: self.plot_current_losses_visdom(epoch, counter_ratio, losses) @@ -431,12 +478,12 @@ def plot_metrics_dict( json.dump(self.metrics_dict, fp) def plot_current_metrics(self, epoch, counter_ratio, metrics): - """display the current fid values on visdom display: dictionary of fid labels and values + """display the current metrics values on visdom display Parameters: epoch (int) -- current epoch counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 - fids (OrderedDict) -- training fid values stored in the format of (name, float) pairs + metrics (OrderedDict) -- training metrics values stored in the format of (name, float) pairs """ self.plot_metrics_dict( "metric", @@ -449,7 +496,7 @@ def plot_current_metrics(self, epoch, counter_ratio, metrics): ) def plot_current_D_accuracies(self, epoch, counter_ratio, accuracies): - """display the current fid values on visdom display: dictionary of fid labels and values + """display the current accuracies values on visdom display Parameters: epoch (int) -- current epoch @@ -478,7 +525,7 @@ def plot_current_APA_prob(self, epoch, counter_ratio, p): ) def plot_current_miou(self, epoch, counter_ratio, miou): - """display the current fid values on visdom display: dictionary of fid labels and values + """display the current miou values on visdom display Parameters: epoch (int) -- current epoch From ecbae1fe379f8b78d51cc657bdfccfc9432db09c Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 23 Feb 2024 17:38:44 +0000 Subject: [PATCH 2/2] feat: generate sound with gen_single_image_diffusion --- data/sound_folder.py | 4 ++-- docs/source/inference.rst | 6 ++--- models/cm_model.py | 3 --- models/modules/cm_generator.py | 3 --- models/modules/diffusion_utils.py | 1 - models/palette_model.py | 8 ++----- options/base_options.py | 5 ++-- options/inference_diffusion_options.py | 9 ++++++- scripts/gen_single_image.py | 2 -- scripts/gen_single_image_diffusion.py | 33 ++++++++++++++++---------- scripts/gen_video_diffusion.py | 1 - tests/test_api_predict_diffusion.py | 8 ------- tests/test_api_predict_gan.py | 7 ------ tests/test_client_server.py | 1 - tests/test_run_cm.py | 1 - tests/test_run_diffusion.py | 1 - tests/test_run_mask_online_ref.py | 2 +- tests/test_run_mask_ref.py | 2 +- 18 files changed, 41 insertions(+), 56 deletions(-) diff --git a/data/sound_folder.py b/data/sound_folder.py index a1ab76714..a56870599 100644 --- a/data/sound_folder.py +++ b/data/sound_folder.py @@ -25,7 +25,7 @@ def window(t): # TODO write a test to check that `wav2D_to_wav(wav_to_2D(x))` is consistent -def wav_to_2D(data, chunk_size, norm_factor): +def wav_to_2D(data, chunk_size, norm_factor=256): """ Transform sound data to image-like data (2D, normalized between -1 & 1) """ @@ -48,7 +48,7 @@ def wav_to_2D(data, chunk_size, norm_factor): return chunks_fft -def wav2D_to_wav(sound2d, norm_factor): +def wav2D_to_wav(sound2d, norm_factor=256): """ Transform image-like data (2D, normalized between -1 & 1) to waveform. 
This function is the inverse of wav_to_2D diff --git a/docs/source/inference.rst b/docs/source/inference.rst index 8de64308c..81b3336f6 100644 --- a/docs/source/inference.rst +++ b/docs/source/inference.rst @@ -208,13 +208,13 @@ Download a pretrained model: Run the inference script ======================== -The ``--cls`` parameter controls the pose for Mario (1 = standing, 2 = walking, 3 = jumping, etc). +The ``--cls_value`` parameter controls the pose for Mario (1 = standing, 2 = walking, 3 = jumping, etc). .. code:: bash mkdir mario_inference_output cd scripts/ - python3 gen_single_image_diffusion.py --model_in_file ../checkpoints/mario/latest_net_G_A.pth --img_in ../datasets/online_mario2sonic_lite/mario/imgs/mario_frame_19538.jpg --bbox_in ../datasets/online_mario2sonic_lite/mario/bbox/r_mario_frame_19538.jpg.txt --dir_out ../mario_inference_output --img_width 128 --img_height 128 --mask_delta 10 --cls 3 + python3 gen_single_image_diffusion.py --model_in_file ../checkpoints/mario/latest_net_G_A.pth --img_in ../datasets/online_mario2sonic_lite/mario/imgs/mario_frame_19538.jpg --bbox_in ../datasets/online_mario2sonic_lite/mario/bbox/r_mario_frame_19538.jpg.txt --dir_out ../mario_inference_output --img_width 128 --img_height 128 --mask_delta 10 --cls_value 3 The output files will be in the ``mario_inference_output`` folder, with: @@ -276,7 +276,7 @@ Download a pretrained model: Run the inference script ======================== -The ``--cond-in`` parameter specifies the conditioning image to use. +The ``--cond_in`` parameter specifies the conditioning image to use. .. code:: bash diff --git a/models/cm_model.py b/models/cm_model.py index bf15d7c2f..d607f1254 100644 --- a/models/cm_model.py +++ b/models/cm_model.py @@ -148,7 +148,6 @@ def __init__(self, opt, rank): self.iter_calculator_init() def set_input(self, data): - if ( len(data["A"].to(self.device).shape) == 5 ): # we're using temporal successive frames @@ -203,7 +202,6 @@ def set_input(self, data): self.real_B = self.gt_image def compute_cm_loss(self): - y_0 = self.gt_image # ground truth y_cond = self.cond_image # conditioning mask = self.mask @@ -224,7 +222,6 @@ def compute_cm_loss(self): self.loss_G_tot = loss * self.opt.alg_diffusion_lambda_G def inference(self): - if hasattr(self.netG_A, "module"): netG = self.netG_A.module else: diff --git a/models/modules/cm_generator.py b/models/modules/cm_generator.py index f3a493b0f..0381ba251 100644 --- a/models/modules/cm_generator.py +++ b/models/modules/cm_generator.py @@ -284,7 +284,6 @@ def forward( mask=None, x_cond=None, ): - num_timesteps = improved_timesteps_schedule( self.current_t, total_training_steps, @@ -344,7 +343,6 @@ def forward( ) def restoration(self, y, y_cond, sigmas, mask, clip_denoised=True): - if mask is not None: mask = torch.clamp( mask, min=0.0, max=1.0 @@ -370,7 +368,6 @@ def restoration(self, y, y_cond, sigmas, mask, clip_denoised=True): x = x * mask + (1 - mask) * y for sigma in sigmas[1:]: - sigma = torch.full((x.shape[0],), sigma, dtype=x.dtype, device=x.device) x = x + pad_dims_like( (sigma**2 - self.sigma_min**2) ** 0.5, x diff --git a/models/modules/diffusion_utils.py b/models/modules/diffusion_utils.py index e432b5ecd..e51c47441 100644 --- a/models/modules/diffusion_utils.py +++ b/models/modules/diffusion_utils.py @@ -28,7 +28,6 @@ def gamma_embedding_1D(gammas, dim, max_period): def gamma_embedding(gammas, dim, max_period=10000): - return_list = [] reduced_dim = dim // gammas.shape[1] diff --git a/models/palette_model.py b/models/palette_model.py index 
2faec5869..466fb7208 100644 --- a/models/palette_model.py +++ b/models/palette_model.py @@ -595,7 +595,7 @@ def inference(self): elif self.task in ["super_resolution", "pix2pix"]: cls = None - if "class" in self.opt.alg_palette_conditioning: + if "class" in self.opt.alg_diffusion_cond_embed: cls = [] for i in self.num_classes: cls.append(torch.randint_like(self.cls[:, 0], 0, i)) @@ -638,14 +638,10 @@ def compute_visuals(self): def compute_sounds(self): super().compute_sounds() - # print("Visuals: " , self.visual_names) # TODO only when sound input data - # print("Computing sounds") - # print("inference num =", self.inference_num) - # print("n images =", self.gt_image.shape) for i in range(self.inference_num): name = "output_" + str(i) - # print("has %s: %s" % (name, hasattr(self, name))) + if hasattr(self, name): img = getattr(self, name) sound = wav2D_to_wav(img[0], 256) diff --git a/options/base_options.py b/options/base_options.py index 99d74840c..2eb371591 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -266,10 +266,11 @@ def _after_parse(self, opt, set_device=True): return self.opt - def parse(self): + def parse(self, save_config=True): """Parse our options, create checkpoints directory suffix, and set up gpu device.""" self.opt = self.gather_options() - self.save_options() + if save_config: + self.save_options() opt = self._after_parse(self.opt) return opt diff --git a/options/inference_diffusion_options.py b/options/inference_diffusion_options.py index 1e5637f30..4b9af2482 100644 --- a/options/inference_diffusion_options.py +++ b/options/inference_diffusion_options.py @@ -100,10 +100,17 @@ def initialize(self, parser): parser.add_argument( "--cls_value", type=int, - default=-1, + nargs="+", + default=[-1], help="override input bbox classe for generation", ) + parser.add_argument( + "--convert_to_sound", + action="store_true", + help="Whether the image should be converted to a sound", + ) + # XXX: options that are not in gen_single_video parser.add_argument("--previous_frame", help="image to transform", default=None) parser.add_argument( diff --git a/scripts/gen_single_image.py b/scripts/gen_single_image.py index 9b5832e2d..d8ba2cd88 100644 --- a/scripts/gen_single_image.py +++ b/scripts/gen_single_image.py @@ -50,7 +50,6 @@ def load_model(modelpath, model_in_file, cpu, gpuid): def inference_logger(name): - PROCESS_NAME = "gen_single_image" LOG_PATH = os.environ.get( "LOG_PATH", os.path.join(os.path.dirname(__file__), "../logs") @@ -70,7 +69,6 @@ def inference_logger(name): def inference(args): - PROGRESS_NUM_STEPS = 6 logger = inference_logger(args.name) logger.info(f"[1/%i] launch inference" % PROGRESS_NUM_STEPS) diff --git a/scripts/gen_single_image_diffusion.py b/scripts/gen_single_image_diffusion.py index 3c171a5de..c60b12e99 100644 --- a/scripts/gen_single_image_diffusion.py +++ b/scripts/gen_single_image_diffusion.py @@ -176,6 +176,7 @@ def generate( img_height, dir_out, write, + convert_to_sound, previous_frame, name, mask_delta, @@ -199,7 +200,6 @@ def generate( nb_samples, **unused_options, ): - PROGRESS_NUM_STEPS = 4 # seed if seed >= 0: @@ -268,12 +268,10 @@ def generate( elts = line.rstrip().split() bboxes.append([int(elts[1]), int(elts[2]), int(elts[3]), int(elts[4])]) if conditioning: - if cls_value > 0: - cls = cls_value - else: - cls = int(elts[0]) + if cls_value <= 0: + cls_value = int(elts[0]) else: - cls = 1 + cls_value = 1 if bbox_ref_id == -1: # sample a bbox here since we are calling crop_image multiple times @@ -336,7 +334,7 @@ def 
generate( crop_coordinates=crop_coordinates, crop_center=True, bbox_ref_id=bbox_idx, - override_class=cls, + override_class=cls_value, ) x_crop, y_crop, crop_size = crop_coordinates @@ -348,7 +346,7 @@ def generate( if len(mask_delta) == 1: index_cls = 0 else: - index_cls = int(cls) - 1 + index_cls = int(cls_value) - 1 if not isinstance(mask_delta[0][0], float): bbox_select[0] -= mask_delta[index_cls][0] @@ -608,7 +606,12 @@ def generate( if opt.model_type == "palette": if "class" in model.denoise_fn.conditioning: - cls_tensor = torch.ones(1, dtype=torch.int64, device=device) * cls + if len(cls_value) > 1: + cls_tensor = torch.tensor( + cls_value, dtype=torch.int64, device=device + ).unsqueeze(0) + else: + cls_tensor = torch.ones(1, dtype=torch.int64, device=device) * cls_value else: cls_tensor = None if ref is not None: @@ -696,6 +699,14 @@ def generate( if generated_bbox: with open(os.path.join(dir_out, name + "_generated_bbox.json"), "w") as out: out.write(json.dumps(generated_bbox)) + if convert_to_sound: + from data.sound_folder import wav2D_to_wav + import torchaudio + + sound = wav2D_to_wav(out_tensor.squeeze(0)) + torchaudio.save( + os.path.join(dir_out, name + "_generated.wav"), sound.to("cpu"), 44100 + ) print("Successfully generated image ", name) @@ -709,7 +720,6 @@ def generate( def inference_logger(name): - PROCESS_NAME = "gen_single_image_diffusion" LOG_PATH = os.environ.get( "LOG_PATH", os.path.join(os.path.dirname(__file__), "../logs") @@ -729,7 +739,6 @@ def inference_logger(name): def inference(args): - PROGRESS_NUM_STEPS = 6 logger = inference_logger(args.name) @@ -760,5 +769,5 @@ def inference(args): if __name__ == "__main__": - args = InferenceDiffusionOptions().parse() + args = InferenceDiffusionOptions().parse(save_config=False) inference(args) diff --git a/scripts/gen_video_diffusion.py b/scripts/gen_video_diffusion.py index bf7c2702b..5a6624e20 100644 --- a/scripts/gen_video_diffusion.py +++ b/scripts/gen_video_diffusion.py @@ -96,7 +96,6 @@ def natural_keys(text): lmodel = None lopt = None for i, (image, label) in tqdm(enumerate(zip(images, labels)), total=len(images)): - args.img_in = args.data_prefix + image if label.endswith(".txt"): diff --git a/tests/test_api_predict_diffusion.py b/tests/test_api_predict_diffusion.py index fb5f65175..53b0eb499 100644 --- a/tests/test_api_predict_diffusion.py +++ b/tests/test_api_predict_diffusion.py @@ -21,7 +21,6 @@ def api(): @pytest.fixture(autouse=True) def run_before_and_after_tests(dataroot): - name = "joligen_utest_api_palette" json_like_dict = { @@ -63,7 +62,6 @@ def run_before_and_after_tests(dataroot): @pytest.mark.asyncio async def test_predict_endpoint_diffusion_success(dataroot, api): - name = "joligen_utest_api_palette" dir_model = "/".join(dataroot.split("/")[:-1]) @@ -109,11 +107,8 @@ async def test_predict_endpoint_diffusion_success(dataroot, api): assert len(json_response["name"]) > 0 with api.websocket_connect(f"/ws/predict/%s" % predict_name) as ws: - while True: - try: - data = ws.receive_json() if data["status"] != "log": @@ -142,7 +137,6 @@ async def test_predict_endpoint_diffusion_success(dataroot, api): def test_predict_endpoint_sync_success(dataroot, api): - name = "joligen_utest_api_palette" dir_model = "/".join(dataroot.split("/")[:-1]) @@ -200,7 +194,6 @@ def test_predict_endpoint_sync_success(dataroot, api): def test_predict_endpoint_sync_base64(dataroot, api): - name = "joligen_utest_api_palette" dir_model = "/".join(dataroot.split("/")[:-1]) @@ -247,7 +240,6 @@ def 
test_predict_endpoint_sync_base64(dataroot, api): assert len(json_response["base64"]) == 4 for index, output in enumerate(["cond", "generated", "orig", "y_t"]): - img_out = os.path.join(dir_model, f"%s_0_%s.png" % (predict_name, output)) assert os.path.exists(img_out) diff --git a/tests/test_api_predict_gan.py b/tests/test_api_predict_gan.py index 7689b67c9..efcbab29e 100644 --- a/tests/test_api_predict_gan.py +++ b/tests/test_api_predict_gan.py @@ -20,7 +20,6 @@ def api(): @pytest.fixture(autouse=True) def run_before_and_after_tests(dataroot): - name = "joligen_utest_api_cut" print(dataroot) @@ -60,7 +59,6 @@ def run_before_and_after_tests(dataroot): @pytest.mark.asyncio async def test_predict_endpoint_gan_success(dataroot, api): - name = "joligen_utest_api_cut" dir_model = "/".join(dataroot.split("/")[:-1]) @@ -108,11 +106,8 @@ async def test_predict_endpoint_gan_success(dataroot, api): predict_name = json_response["name"] with api.websocket_connect(f"/ws/predict/%s" % predict_name) as ws: - while True: - try: - data = ws.receive_json() if data["status"] != "log": @@ -135,7 +130,6 @@ async def test_predict_endpoint_gan_success(dataroot, api): def test_predict_endpoint_sync_success(dataroot, api): - name = "joligen_utest_api_cut" dir_model = "/".join(dataroot.split("/")[:-1]) @@ -185,7 +179,6 @@ def test_predict_endpoint_sync_success(dataroot, api): def test_predict_endpoint_sync_base64(dataroot, api): - name = "joligen_utest_api_cut" dir_model = "/".join(dataroot.split("/")[:-1]) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 374c81c45..dd700c0f1 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -49,7 +49,6 @@ def do_POST(self): path == "/train/test_client" and self.parse_json(body)["train_options"] is not None ): - self.write_response( 200, "application/json", diff --git a/tests/test_run_cm.py b/tests/test_run_cm.py index 46ef803d1..b1d40def0 100644 --- a/tests/test_run_cm.py +++ b/tests/test_run_cm.py @@ -68,7 +68,6 @@ def test_semantic_mask(dataroot): Gtype, alg_diffusion_cond_embed, ) in product_list: - json_like_dict_c = json_like_dict.copy() json_like_dict_c["model_type"] = model json_like_dict_c["name"] += "_" + model diff --git a/tests/test_run_diffusion.py b/tests/test_run_diffusion.py index 5cc6a0dc8..f6cb3cc96 100644 --- a/tests/test_run_diffusion.py +++ b/tests/test_run_diffusion.py @@ -72,7 +72,6 @@ def test_semantic_mask(dataroot): G_efficient, train_feat_wavelet, ) in product_list: - json_like_dict_c = json_like_dict.copy() json_like_dict_c["model_type"] = model json_like_dict_c["name"] += "_" + model diff --git a/tests/test_run_mask_online_ref.py b/tests/test_run_mask_online_ref.py index 4d9b916a1..099960197 100644 --- a/tests/test_run_mask_online_ref.py +++ b/tests/test_run_mask_online_ref.py @@ -36,7 +36,7 @@ ["cut", "unaligned_labeled_mask_online_ref"], ] conditionings = [ - "alg_palette_conditioning", + "alg_diffusion_cond_embed", "alg_palette_cond_image_creation", ] diff --git a/tests/test_run_mask_ref.py b/tests/test_run_mask_ref.py index ebb0568f9..5f2860d06 100644 --- a/tests/test_run_mask_ref.py +++ b/tests/test_run_mask_ref.py @@ -36,7 +36,7 @@ ["cut", "unaligned_labeled_mask_ref"], ] conditionings = [ - "alg_palette_conditioning", + "alg_diffusion_cond_embed", "alg_palette_cond_image_creation", ]
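
Note on the ``wav_to_2D`` / ``wav2D_to_wav`` pair introduced in ``data/sound_folder.py``: the patch leaves a TODO asking for a test that the inverse transform is consistent with the forward one. A minimal sketch of such a test is below; it assumes the module layout added here and the hard-coded ``chunk_size=512`` / ``norm_factor=256`` defaults, and it only checks that the reconstruction is proportional to the input away from the signal edges, since the half-spectrum truncation and Hann overlap-add are not expected to be exactly invertible.

.. code:: python

    # Hypothetical round-trip check for the TODO in data/sound_folder.py.
    import torch

    from data.sound_folder import wav_to_2D, wav2D_to_wav


    def test_wav_round_trip():
        torch.manual_seed(0)
        chunk_size, norm_factor = 512, 256
        x = torch.randn(chunk_size * 64)  # mono signal, 64 full chunks

        img = wav_to_2D(x, chunk_size, norm_factor)
        assert img.shape[0] == 3  # real, imaginary and magnitude channels

        y = wav2D_to_wav(img, norm_factor).squeeze(0)

        # Compare away from the edges and ignore any constant gain: the
        # windowed overlap-add is only expected to be proportional to x there.
        n = min(len(x), len(y))
        a = x[chunk_size : n - chunk_size]
        b = y[chunk_size : n - chunk_size]
        corr = torch.dot(a, b) / (a.norm() * b.norm())
        assert corr > 0.95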
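
Note on the multi-class conditioning added to ``PaletteDenoiseFn``: when ``nclasses`` is a list, ``cond_embed_dim`` is split into one chunk per label head plus whatever remains for the noise-level (gamma) embedding, which is what the new ``cond_embed_gammas`` attribute tracks and what ``DiffusionGenerator`` now reads instead of recomputing the split itself. A small worked example of that bookkeeping, with made-up numbers (the patch itself flags the even split as arbitrary):

.. code:: python

    # Made-up dimensions illustrating the split done in PaletteDenoiseFn.__init__
    # when nclasses is given as a list.
    cond_embed_dim = 256
    nclasses = [10, 5]  # e.g. one semantic class head and one detector head

    cond_embed_class = cond_embed_dim // (len(nclasses) + 1)               # 85 per head
    cond_embed_gammas = cond_embed_dim - len(nclasses) * cond_embed_class  # 86 left

    # At forward time, cls has shape [batch, len(nclasses)]; each column goes
    # through its own LabelEmbedder, and the per-head embeddings are concatenated
    # with the gamma embedding, giving cond_embed_dim features in total.
    assert len(nclasses) * cond_embed_class + cond_embed_gammas == cond_embed_dim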
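
Note on ``play_current_sounds``: the visdom text pane is fed HTML that embeds each waveform as a base64 ``data:`` URI, but the markup inside the triple-quoted string does not appear above (the empty string is %-formatted with ``(mimetype, mimetype, b64)``). A plausible shape for that snippet, given those three arguments, is sketched below; it is an assumption, not the exact string from the branch.

.. code:: python

    # Hypothetical markup for one sound entry in play_current_sounds(); the
    # (mimetype, mimetype, b64) arguments suggest an <audio> element whose
    # source is a base64 data URI.
    html_content += """
    <audio controls>
      <source type="audio/%s" src="data:audio/%s;base64,%s">
    </audio>
    """ % (
        mimetype,
        mimetype,
        b64,
    )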