From 320b9a2ee4013eaad14a3332c8238412d24e0f9a Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Fri, 30 Dec 2022 17:10:49 +0800 Subject: [PATCH] [Enhancement] Revise SWD metric and DCGAN's configs (#1528) * add number checking for base gen metric * support -1 for fake imgs in SWD metric * revise dcgan's config * update unit test of base gen metric * update unit test of swd * ignore some windows unit test --- ...cgan_1xb128-300kiters_celeba-cropped-64.py | 2 + ...cgan_1xb128-5epoches_lsun-bedroom-64x64.py | 2 + ...4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py | 40 +++++++++--------- mmedit/evaluation/metrics/base_gen_metric.py | 10 +++-- mmedit/evaluation/metrics/swd.py | 7 +++- .../test_metrics/test_base_gen_metric.py | 7 ++++ .../test_evaluation/test_metrics/test_swd.py | 41 ++++++++++++++++++- .../test_editors/test_wgan_gp/test_wgan_gp.py | 5 +++ .../test_losses/test_feature_loss.py | 5 +++ .../test_gen_auxiliary_loss_comps.py | 5 +++ .../test_losses/test_perceptual_loss.py | 4 ++ 11 files changed, 102 insertions(+), 26 deletions(-) diff --git a/configs/dcgan/dcgan_1xb128-300kiters_celeba-cropped-64.py b/configs/dcgan/dcgan_1xb128-300kiters_celeba-cropped-64.py index 95324ff8e4..2f1ab816e6 100644 --- a/configs/dcgan/dcgan_1xb128-300kiters_celeba-cropped-64.py +++ b/configs/dcgan/dcgan_1xb128-300kiters_celeba-cropped-64.py @@ -45,6 +45,8 @@ sample_model='orig', image_shape=(3, 64, 64)) ] +# save best checkpoints +default_hooks = dict(checkpoint=dict(save_best='swd/avg', rule='less')) val_evaluator = dict(metrics=metrics) test_evaluator = dict(metrics=metrics) diff --git a/configs/dcgan/dcgan_1xb128-5epoches_lsun-bedroom-64x64.py b/configs/dcgan/dcgan_1xb128-5epoches_lsun-bedroom-64x64.py index 79f4e56f5f..e4396a3462 100644 --- a/configs/dcgan/dcgan_1xb128-5epoches_lsun-bedroom-64x64.py +++ b/configs/dcgan/dcgan_1xb128-5epoches_lsun-bedroom-64x64.py @@ -44,6 +44,8 @@ sample_model='orig', image_shape=(3, 64, 64)) ] +# save best checkpoints +default_hooks = dict(checkpoint=dict(save_best='swd/avg', rule='less')) val_evaluator = dict(metrics=metrics) test_evaluator = dict(metrics=metrics) diff --git a/configs/dcgan/dcgan_Glr4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py b/configs/dcgan/dcgan_Glr4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py index 3eb1b4ad99..3625022ccf 100644 --- a/configs/dcgan/dcgan_Glr4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py +++ b/configs/dcgan/dcgan_Glr4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py @@ -5,47 +5,44 @@ ] # output single channel -model = dict(generator=dict(out_channels=1), discriminator=dict(in_channels=1)) +model = dict( + data_preprocessor=dict(mean=[127.5], std=[127.5]), + generator=dict(out_channels=1), + discriminator=dict(in_channels=1)) # define dataset # modify train_pipeline to load gray scale images train_pipeline = [ - dict( - type='LoadImageFromFile', - key='img', - io_backend='disk', - color_type='grayscale'), + dict(type='LoadImageFromFile', key='img', color_type='grayscale'), dict(type='Resize', scale=(64, 64)), - dict(type='PackEditInputs', meta_keys=[]) + dict(type='PackEditInputs') ] # set ``batch_size``` and ``data_root``` batch_size = 128 data_root = 'data/mnist_64/train' train_dataloader = dict( - batch_size=batch_size, dataset=dict(data_root=data_root)) + batch_size=batch_size, + dataset=dict(data_root=data_root, pipeline=train_pipeline)) -val_dataloader = dict(batch_size=batch_size, dataset=dict(data_root=data_root)) +val_dataloader = dict( + batch_size=batch_size, + dataset=dict(data_root=data_root, pipeline=train_pipeline)) test_dataloader = dict( - batch_size=batch_size, dataset=dict(data_root=data_root)) - -default_hooks = dict( - checkpoint=dict( - interval=500, - save_best=['swd/avg', 'ms-ssim/avg'], - rule=['less', 'greater'])) + batch_size=batch_size, + dataset=dict(data_root=data_root, pipeline=train_pipeline)) # VIS_HOOK custom_hooks = [ dict( type='GenVisualizationHook', - interval=10000, + interval=500, fixed_input=True, vis_kwargs_list=dict(type='GAN', name='fake_img')) ] -train_cfg = dict(max_iters=5000) +train_cfg = dict(max_iters=5000, val_interval=500) # METRICS metrics = [ @@ -55,10 +52,13 @@ dict( type='SWD', prefix='swd', - fake_nums=16384, + fake_nums=-1, sample_model='orig', - image_shape=(3, 64, 64)) + image_shape=(1, 64, 64)) ] +# save best checkpoints +default_hooks = dict( + checkpoint=dict(interval=500, save_best='swd/avg', rule='less')) val_evaluator = dict(metrics=metrics) test_evaluator = dict(metrics=metrics) diff --git a/mmedit/evaluation/metrics/base_gen_metric.py b/mmedit/evaluation/metrics/base_gen_metric.py index c5916cdc4c..eb8424b9ca 100644 --- a/mmedit/evaluation/metrics/base_gen_metric.py +++ b/mmedit/evaluation/metrics/base_gen_metric.py @@ -155,10 +155,14 @@ def get_metric_sampler(self, model: nn.Module, dataloader: DataLoader, DataLoader: Default sampler for normal metrics. """ batch_size = dataloader.batch_size - + dataset_length = len(dataloader.dataset) rank, num_gpus = get_dist_info() - item_subset = [(i * num_gpus + rank) % self.real_nums - for i in range((self.real_nums - 1) // num_gpus + 1)] + assert self.real_nums <= dataset_length, ( + f'\'real_nums\'({self.real_nums}) can not larger than length of ' + f'dataset ({dataset_length}).') + nums = dataset_length if self.real_nums == -1 else self.real_nums + item_subset = [(i * num_gpus + rank) % nums + for i in range((nums - 1) // num_gpus + 1)] metric_dataloader = DataLoader( dataloader.dataset, diff --git a/mmedit/evaluation/metrics/swd.py b/mmedit/evaluation/metrics/swd.py index ecbb9a282c..ba8a46bc18 100644 --- a/mmedit/evaluation/metrics/swd.py +++ b/mmedit/evaluation/metrics/swd.py @@ -255,7 +255,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: data_batch (dict): A batch of data from the dataloader. data_samples (Sequence[dict]): A batch of outputs from the model. """ - if self._num_processed >= self.fake_nums_per_device: + if self.fake_nums != -1 and (self._num_processed >= + self.fake_nums_per_device): return real_imgs, fake_imgs = [], [] @@ -279,6 +280,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: # real images assert real_imgs.shape[1:] == self.image_shape + if real_imgs.shape[1] == 1: + real_imgs = real_imgs.repeat(1, 3, 1, 1) real_pyramid = laplacian_pyramid(real_imgs, self.n_pyramids - 1, self.gaussian_k) # lod: layer_of_descriptors @@ -291,6 +294,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: # fake images assert fake_imgs.shape[1:] == self.image_shape + if fake_imgs.shape[1] == 1: + fake_imgs = fake_imgs.repeat(1, 3, 1, 1) fake_pyramid = laplacian_pyramid(fake_imgs, self.n_pyramids - 1, self.gaussian_k) # lod: layer_of_descriptors diff --git a/tests/test_evaluation/test_metrics/test_base_gen_metric.py b/tests/test_evaluation/test_metrics/test_base_gen_metric.py index da0602fe60..699f188ac2 100644 --- a/tests/test_evaluation/test_metrics/test_base_gen_metric.py +++ b/tests/test_evaluation/test_metrics/test_base_gen_metric.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest.mock import MagicMock, patch +import pytest import torch from mmengine.model import MMDistributedDataParallel @@ -43,6 +44,7 @@ def test_GenMetric(): # test get_metric_sampler model = MagicMock() dataset = MagicMock() + dataset.__len__.return_value = 10 dataloader = MagicMock() dataloader.batch_size = 4 dataloader.dataset = dataset @@ -57,6 +59,11 @@ def test_GenMetric(): metric.prepare(model, dataloader) assert metric.data_preprocessor == preprocessor + # test raise error with dataset is length than real_nums + dataset.__len__.return_value = 5 + with pytest.raises(AssertionError): + metric.get_metric_sampler(model, dataloader, [metric]) + def test_GenerativeMetric(): metric = ToyGenerativeMetric(11, need_cond_input=True) diff --git a/tests/test_evaluation/test_metrics/test_swd.py b/tests/test_evaluation/test_metrics/test_swd.py index 9e85310198..201772670f 100644 --- a/tests/test_evaluation/test_metrics/test_swd.py +++ b/tests/test_evaluation/test_metrics/test_swd.py @@ -13,8 +13,7 @@ class TestSWD(TestCase): def test_init(self): - swd = SlicedWassersteinDistance( - fake_nums=10, image_shape=(3, 32, 32)) # noqa + swd = SlicedWassersteinDistance(fake_nums=10, image_shape=(3, 32, 32)) self.assertEqual(len(swd.real_results), 2) def test_prosess(self): @@ -35,6 +34,13 @@ def test_prosess(self): ] swd.process(real_samples, fake_samples) + # 100 samples are passed in 1 batch, _num_processed should be 100 + self.assertEqual(swd._num_processed, 100) + # _num_processed(100) > fake_nums(4), _num_processed should be + # unchanged + swd.process(real_samples, fake_samples) + self.assertEqual(swd._num_processed, 100) + output = swd.evaluate() result = [16.495922580361366, 24.15413036942482, 20.325026474893093] output = [item / 100 for item in output.values()] @@ -48,3 +54,34 @@ def test_prosess(self): sample_model='orig', image_shape=(3, 32, 32)) swd.prepare(model, None) + + # test gray scale input + swd.image_shape = (1, 32, 32) + real_samples = [ + dict(inputs=torch.rand(1, 32, 32) * 255.) for _ in range(100) + ] + fake_samples = [ + EditDataSample( + fake_img=PixelData(data=torch.rand(1, 32, 32) * 2 - 1), + gt_img=PixelData(data=torch.rand(1, 32, 32) * 2 - + 1)).to_dict() for _ in range(100) + ] + swd.process(real_samples, fake_samples) + + # test fake_nums is -1 + swd = SlicedWassersteinDistance( + fake_nums=-1, + fake_key='fake', + real_key='img', + sample_model='orig', + image_shape=(3, 32, 32)) + fake_samples = [ + EditDataSample( + fake_img=PixelData(data=torch.rand(3, 32, 32) * 2 - 1), + gt_img=PixelData(data=torch.rand(3, 32, 32) * 2 - + 1)).to_dict() for _ in range(10) + ] + for _ in range(3): + swd.process(None, fake_samples) + # fake_nums is -1, all samples (10 * 3 = 30) is processed + self.assertEqual(swd._num_processed, 30) diff --git a/tests/test_models/test_editors/test_wgan_gp/test_wgan_gp.py b/tests/test_models/test_editors/test_wgan_gp/test_wgan_gp.py index 99b82cb468..fde0125b54 100644 --- a/tests/test_models/test_editors/test_wgan_gp/test_wgan_gp.py +++ b/tests/test_models/test_editors/test_wgan_gp/test_wgan_gp.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform from copy import deepcopy from unittest import TestCase +import pytest import torch from mmengine import MessageHub from mmengine.optim import OptimWrapper, OptimWrapperDict @@ -54,6 +56,9 @@ def test_init(self): gan = WGANGP(generator=gen, data_preprocessor=GenDataPreprocessor()) self.assertEqual(gan.discriminator, None) + @pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') def test_train_step(self): # prepare model accu_iter = 1 diff --git a/tests/test_models/test_losses/test_feature_loss.py b/tests/test_models/test_losses/test_feature_loss.py index b0ef06b92a..1b71a99627 100644 --- a/tests/test_models/test_losses/test_feature_loss.py +++ b/tests/test_models/test_losses/test_feature_loss.py @@ -1,10 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform + import pytest import torch from mmedit.models.losses import LightCNNFeatureLoss +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') def test_light_cnn_feature_loss(): pretrained = 'https://download.openmmlab.com/mmediting/' + \ diff --git a/tests/test_models/test_losses/test_loss_comps/test_gen_auxiliary_loss_comps.py b/tests/test_models/test_losses/test_loss_comps/test_gen_auxiliary_loss_comps.py index 8753f77380..160170d609 100644 --- a/tests/test_models/test_losses/test_loss_comps/test_gen_auxiliary_loss_comps.py +++ b/tests/test_models/test_losses/test_loss_comps/test_gen_auxiliary_loss_comps.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform + import pytest import torch from mmengine.utils.dl_utils import TORCH_VERSION @@ -8,6 +10,9 @@ from mmedit.models.losses import GeneratorPathRegularizerComps +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') class TestPathRegularizer: @classmethod diff --git a/tests/test_models/test_losses/test_perceptual_loss.py b/tests/test_models/test_losses/test_perceptual_loss.py index 3c2ebc12ff..d928600049 100644 --- a/tests/test_models/test_losses/test_perceptual_loss.py +++ b/tests/test_models/test_losses/test_perceptual_loss.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform from unittest.mock import patch import pytest @@ -8,6 +9,9 @@ TransferalPerceptualLoss) +@pytest.mark.skipif( + 'win' in platform.system().lower() and 'cu' in torch.__version__, + reason='skip on windows-cuda due to limited RAM.') @patch.object(PerceptualVGG, 'init_weights') def test_perceptual_loss(init_weights): if torch.cuda.is_available():