Skip to content

Commit

Permalink
[Enhancement] Revise SWD metric and DCGAN's configs (#1528)
Browse files Browse the repository at this point in the history
* add number checking for base gen metric

* support -1 for fake imgs in SWD metric

* revise dcgan's config

* update unit test of base gen metric

* update unit test of swd

* ignore some windows unit test
  • Loading branch information
LeoXing1996 authored Dec 30, 2022
1 parent 0a58456 commit 320b9a2
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 26 deletions.
2 changes: 2 additions & 0 deletions configs/dcgan/dcgan_1xb128-300kiters_celeba-cropped-64.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
sample_model='orig',
image_shape=(3, 64, 64))
]
# save best checkpoints
default_hooks = dict(checkpoint=dict(save_best='swd/avg', rule='less'))

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
2 changes: 2 additions & 0 deletions configs/dcgan/dcgan_1xb128-5epoches_lsun-bedroom-64x64.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
sample_model='orig',
image_shape=(3, 64, 64))
]
# save best checkpoints
default_hooks = dict(checkpoint=dict(save_best='swd/avg', rule='less'))

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
40 changes: 20 additions & 20 deletions configs/dcgan/dcgan_Glr4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,44 @@
]

# output single channel
model = dict(generator=dict(out_channels=1), discriminator=dict(in_channels=1))
model = dict(
data_preprocessor=dict(mean=[127.5], std=[127.5]),
generator=dict(out_channels=1),
discriminator=dict(in_channels=1))

# define dataset
# modify train_pipeline to load gray scale images
train_pipeline = [
dict(
type='LoadImageFromFile',
key='img',
io_backend='disk',
color_type='grayscale'),
dict(type='LoadImageFromFile', key='img', color_type='grayscale'),
dict(type='Resize', scale=(64, 64)),
dict(type='PackEditInputs', meta_keys=[])
dict(type='PackEditInputs')
]

# set ``batch_size``` and ``data_root```
batch_size = 128
data_root = 'data/mnist_64/train'
train_dataloader = dict(
batch_size=batch_size, dataset=dict(data_root=data_root))
batch_size=batch_size,
dataset=dict(data_root=data_root, pipeline=train_pipeline))

val_dataloader = dict(batch_size=batch_size, dataset=dict(data_root=data_root))
val_dataloader = dict(
batch_size=batch_size,
dataset=dict(data_root=data_root, pipeline=train_pipeline))

test_dataloader = dict(
batch_size=batch_size, dataset=dict(data_root=data_root))

default_hooks = dict(
checkpoint=dict(
interval=500,
save_best=['swd/avg', 'ms-ssim/avg'],
rule=['less', 'greater']))
batch_size=batch_size,
dataset=dict(data_root=data_root, pipeline=train_pipeline))

# VIS_HOOK
custom_hooks = [
dict(
type='GenVisualizationHook',
interval=10000,
interval=500,
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]

train_cfg = dict(max_iters=5000)
train_cfg = dict(max_iters=5000, val_interval=500)

# METRICS
metrics = [
Expand All @@ -55,10 +52,13 @@
dict(
type='SWD',
prefix='swd',
fake_nums=16384,
fake_nums=-1,
sample_model='orig',
image_shape=(3, 64, 64))
image_shape=(1, 64, 64))
]
# save best checkpoints
default_hooks = dict(
checkpoint=dict(interval=500, save_best='swd/avg', rule='less'))

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
Expand Down
10 changes: 7 additions & 3 deletions mmedit/evaluation/metrics/base_gen_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,14 @@ def get_metric_sampler(self, model: nn.Module, dataloader: DataLoader,
DataLoader: Default sampler for normal metrics.
"""
batch_size = dataloader.batch_size

dataset_length = len(dataloader.dataset)
rank, num_gpus = get_dist_info()
item_subset = [(i * num_gpus + rank) % self.real_nums
for i in range((self.real_nums - 1) // num_gpus + 1)]
assert self.real_nums <= dataset_length, (
f'\'real_nums\'({self.real_nums}) can not larger than length of '
f'dataset ({dataset_length}).')
nums = dataset_length if self.real_nums == -1 else self.real_nums
item_subset = [(i * num_gpus + rank) % nums
for i in range((nums - 1) // num_gpus + 1)]

metric_dataloader = DataLoader(
dataloader.dataset,
Expand Down
7 changes: 6 additions & 1 deletion mmedit/evaluation/metrics/swd.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
if self._num_processed >= self.fake_nums_per_device:
if self.fake_nums != -1 and (self._num_processed >=
self.fake_nums_per_device):
return

real_imgs, fake_imgs = [], []
Expand All @@ -279,6 +280,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:

# real images
assert real_imgs.shape[1:] == self.image_shape
if real_imgs.shape[1] == 1:
real_imgs = real_imgs.repeat(1, 3, 1, 1)
real_pyramid = laplacian_pyramid(real_imgs, self.n_pyramids - 1,
self.gaussian_k)
# lod: layer_of_descriptors
Expand All @@ -291,6 +294,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:

# fake images
assert fake_imgs.shape[1:] == self.image_shape
if fake_imgs.shape[1] == 1:
fake_imgs = fake_imgs.repeat(1, 3, 1, 1)
fake_pyramid = laplacian_pyramid(fake_imgs, self.n_pyramids - 1,
self.gaussian_k)
# lod: layer_of_descriptors
Expand Down
7 changes: 7 additions & 0 deletions tests/test_evaluation/test_metrics/test_base_gen_metric.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch

import pytest
import torch
from mmengine.model import MMDistributedDataParallel

Expand Down Expand Up @@ -43,6 +44,7 @@ def test_GenMetric():
# test get_metric_sampler
model = MagicMock()
dataset = MagicMock()
dataset.__len__.return_value = 10
dataloader = MagicMock()
dataloader.batch_size = 4
dataloader.dataset = dataset
Expand All @@ -57,6 +59,11 @@ def test_GenMetric():
metric.prepare(model, dataloader)
assert metric.data_preprocessor == preprocessor

# test raise error with dataset is length than real_nums
dataset.__len__.return_value = 5
with pytest.raises(AssertionError):
metric.get_metric_sampler(model, dataloader, [metric])


def test_GenerativeMetric():
metric = ToyGenerativeMetric(11, need_cond_input=True)
Expand Down
41 changes: 39 additions & 2 deletions tests/test_evaluation/test_metrics/test_swd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
class TestSWD(TestCase):

def test_init(self):
swd = SlicedWassersteinDistance(
fake_nums=10, image_shape=(3, 32, 32)) # noqa
swd = SlicedWassersteinDistance(fake_nums=10, image_shape=(3, 32, 32))
self.assertEqual(len(swd.real_results), 2)

def test_prosess(self):
Expand All @@ -35,6 +34,13 @@ def test_prosess(self):
]

swd.process(real_samples, fake_samples)
# 100 samples are passed in 1 batch, _num_processed should be 100
self.assertEqual(swd._num_processed, 100)
# _num_processed(100) > fake_nums(4), _num_processed should be
# unchanged
swd.process(real_samples, fake_samples)
self.assertEqual(swd._num_processed, 100)

output = swd.evaluate()
result = [16.495922580361366, 24.15413036942482, 20.325026474893093]
output = [item / 100 for item in output.values()]
Expand All @@ -48,3 +54,34 @@ def test_prosess(self):
sample_model='orig',
image_shape=(3, 32, 32))
swd.prepare(model, None)

# test gray scale input
swd.image_shape = (1, 32, 32)
real_samples = [
dict(inputs=torch.rand(1, 32, 32) * 255.) for _ in range(100)
]
fake_samples = [
EditDataSample(
fake_img=PixelData(data=torch.rand(1, 32, 32) * 2 - 1),
gt_img=PixelData(data=torch.rand(1, 32, 32) * 2 -
1)).to_dict() for _ in range(100)
]
swd.process(real_samples, fake_samples)

# test fake_nums is -1
swd = SlicedWassersteinDistance(
fake_nums=-1,
fake_key='fake',
real_key='img',
sample_model='orig',
image_shape=(3, 32, 32))
fake_samples = [
EditDataSample(
fake_img=PixelData(data=torch.rand(3, 32, 32) * 2 - 1),
gt_img=PixelData(data=torch.rand(3, 32, 32) * 2 -
1)).to_dict() for _ in range(10)
]
for _ in range(3):
swd.process(None, fake_samples)
# fake_nums is -1, all samples (10 * 3 = 30) is processed
self.assertEqual(swd._num_processed, 30)
5 changes: 5 additions & 0 deletions tests/test_models/test_editors/test_wgan_gp/test_wgan_gp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
from copy import deepcopy
from unittest import TestCase

import pytest
import torch
from mmengine import MessageHub
from mmengine.optim import OptimWrapper, OptimWrapperDict
Expand Down Expand Up @@ -54,6 +56,9 @@ def test_init(self):
gan = WGANGP(generator=gen, data_preprocessor=GenDataPreprocessor())
self.assertEqual(gan.discriminator, None)

@pytest.mark.skipif(
'win' in platform.system().lower() and 'cu' in torch.__version__,
reason='skip on windows-cuda due to limited RAM.')
def test_train_step(self):
# prepare model
accu_iter = 1
Expand Down
5 changes: 5 additions & 0 deletions tests/test_models/test_losses/test_feature_loss.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform

import pytest
import torch

from mmedit.models.losses import LightCNNFeatureLoss


@pytest.mark.skipif(
'win' in platform.system().lower() and 'cu' in torch.__version__,
reason='skip on windows-cuda due to limited RAM.')
def test_light_cnn_feature_loss():

pretrained = 'https://download.openmmlab.com/mmediting/' + \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform

import pytest
import torch
from mmengine.utils.dl_utils import TORCH_VERSION
Expand All @@ -8,6 +10,9 @@
from mmedit.models.losses import GeneratorPathRegularizerComps


@pytest.mark.skipif(
'win' in platform.system().lower() and 'cu' in torch.__version__,
reason='skip on windows-cuda due to limited RAM.')
class TestPathRegularizer:

@classmethod
Expand Down
4 changes: 4 additions & 0 deletions tests/test_models/test_losses/test_perceptual_loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
from unittest.mock import patch

import pytest
Expand All @@ -8,6 +9,9 @@
TransferalPerceptualLoss)


@pytest.mark.skipif(
'win' in platform.system().lower() and 'cu' in torch.__version__,
reason='skip on windows-cuda due to limited RAM.')
@patch.object(PerceptualVGG, 'init_weights')
def test_perceptual_loss(init_weights):
if torch.cuda.is_available():
Expand Down

0 comments on commit 320b9a2

Please sign in to comment.