[Enhancement] Revise SWD metric and DCGAN's configs #1528

Merged
2 changes: 2 additions & 0 deletions configs/dcgan/dcgan_1xb128-300kiters_celeba-cropped-64.py
@@ -45,6 +45,8 @@
sample_model='orig',
image_shape=(3, 64, 64))
]
# save best checkpoints
default_hooks = dict(checkpoint=dict(save_best='swd/avg', rule='less'))

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
2 changes: 2 additions & 0 deletions configs/dcgan/dcgan_1xb128-5epoches_lsun-bedroom-64x64.py
@@ -44,6 +44,8 @@
sample_model='orig',
image_shape=(3, 64, 64))
]
# save best checkpoints
default_hooks = dict(checkpoint=dict(save_best='swd/avg', rule='less'))

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
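Both configs above now save the best checkpoint according to the averaged SWD. As a minimal sketch (not part of the diff; the logged values are made up), rule='less' simply keeps the checkpoint with the lowest 'swd/avg' seen during validation:

# Hypothetical illustration of save_best='swd/avg' with rule='less':
# the checkpoint whose logged 'swd/avg' is lowest is kept as "best".
logged_swd_avg = {500: 21.3, 1000: 18.7, 1500: 19.2}  # iteration -> swd/avg (made-up values)
best_iter = min(logged_swd_avg, key=logged_swd_avg.get)
print(best_iter)  # 1000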
40 changes: 20 additions & 20 deletions configs/dcgan/dcgan_Glr4e-4_Dlr1e-4_1xb128-5kiters_mnist-64x64.py
@@ -5,47 +5,44 @@
]

# output single channel
model = dict(generator=dict(out_channels=1), discriminator=dict(in_channels=1))
model = dict(
data_preprocessor=dict(mean=[127.5], std=[127.5]),
generator=dict(out_channels=1),
discriminator=dict(in_channels=1))

# define dataset
# modify train_pipeline to load grayscale images
train_pipeline = [
dict(
type='LoadImageFromFile',
key='img',
io_backend='disk',
color_type='grayscale'),
dict(type='LoadImageFromFile', key='img', color_type='grayscale'),
dict(type='Resize', scale=(64, 64)),
dict(type='PackEditInputs', meta_keys=[])
dict(type='PackEditInputs')
]

# set ``batch_size`` and ``data_root``
batch_size = 128
data_root = 'data/mnist_64/train'
train_dataloader = dict(
batch_size=batch_size, dataset=dict(data_root=data_root))
batch_size=batch_size,
dataset=dict(data_root=data_root, pipeline=train_pipeline))

val_dataloader = dict(batch_size=batch_size, dataset=dict(data_root=data_root))
val_dataloader = dict(
batch_size=batch_size,
dataset=dict(data_root=data_root, pipeline=train_pipeline))

test_dataloader = dict(
batch_size=batch_size, dataset=dict(data_root=data_root))

default_hooks = dict(
checkpoint=dict(
interval=500,
save_best=['swd/avg', 'ms-ssim/avg'],
rule=['less', 'greater']))
batch_size=batch_size,
dataset=dict(data_root=data_root, pipeline=train_pipeline))

# VIS_HOOK
custom_hooks = [
dict(
type='GenVisualizationHook',
interval=10000,
interval=500,
fixed_input=True,
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]

train_cfg = dict(max_iters=5000)
train_cfg = dict(max_iters=5000, val_interval=500)

# METRICS
metrics = [
@@ -55,10 +52,13 @@
dict(
type='SWD',
prefix='swd',
fake_nums=16384,
fake_nums=-1,
sample_model='orig',
image_shape=(3, 64, 64))
image_shape=(1, 64, 64))
]
# save best checkpoints
default_hooks = dict(
checkpoint=dict(interval=500, save_best='swd/avg', rule='less'))

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
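As a side note on the data_preprocessor added above: a minimal sketch (assumption, not part of the diff) of how mean=[127.5] and std=[127.5] map 8-bit grayscale pixels into the [-1, 1] range a DCGAN generator typically produces via tanh:

import torch

# (x - mean) / std with mean = std = 127.5 maps [0, 255] onto [-1, 1].
img = torch.tensor([0., 127.5, 255.])
print((img - 127.5) / 127.5)  # tensor([-1., 0., 1.])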
10 changes: 7 additions & 3 deletions mmedit/evaluation/metrics/base_gen_metric.py
@@ -155,10 +155,14 @@ def get_metric_sampler(self, model: nn.Module, dataloader: DataLoader,
DataLoader: Default sampler for normal metrics.
"""
batch_size = dataloader.batch_size

dataset_length = len(dataloader.dataset)
rank, num_gpus = get_dist_info()
item_subset = [(i * num_gpus + rank) % self.real_nums
for i in range((self.real_nums - 1) // num_gpus + 1)]
assert self.real_nums <= dataset_length, (
f'\'real_nums\'({self.real_nums}) can not be larger than the '
f'length of the dataset ({dataset_length}).')
nums = dataset_length if self.real_nums == -1 else self.real_nums
item_subset = [(i * num_gpus + rank) % nums
for i in range((nums - 1) // num_gpus + 1)]

metric_dataloader = DataLoader(
dataloader.dataset,
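A standalone sketch (assumption, mirroring the logic added above) of how the per-rank subset indices are built once real_nums == -1 falls back to the full dataset length:

def build_item_subset(real_nums, dataset_length, rank, num_gpus):
    # real_nums == -1 means "use every sample in the dataset".
    assert real_nums == -1 or real_nums <= dataset_length
    nums = dataset_length if real_nums == -1 else real_nums
    # round-robin split: rank r takes indices r, r + num_gpus, r + 2 * num_gpus, ...
    return [(i * num_gpus + rank) % nums
            for i in range((nums - 1) // num_gpus + 1)]

print(build_item_subset(-1, 10, rank=0, num_gpus=2))  # [0, 2, 4, 6, 8]
print(build_item_subset(-1, 10, rank=1, num_gpus=2))  # [1, 3, 5, 7, 9]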
7 changes: 6 additions & 1 deletion mmedit/evaluation/metrics/swd.py
@@ -255,7 +255,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
if self._num_processed >= self.fake_nums_per_device:
if self.fake_nums != -1 and (self._num_processed >=
self.fake_nums_per_device):
return

real_imgs, fake_imgs = [], []
@@ -279,6 +280,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:

# real images
assert real_imgs.shape[1:] == self.image_shape
if real_imgs.shape[1] == 1:
real_imgs = real_imgs.repeat(1, 3, 1, 1)
real_pyramid = laplacian_pyramid(real_imgs, self.n_pyramids - 1,
self.gaussian_k)
# lod: layer_of_descriptors
@@ -291,6 +294,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:

# fake images
assert fake_imgs.shape[1:] == self.image_shape
if fake_imgs.shape[1] == 1:
fake_imgs = fake_imgs.repeat(1, 3, 1, 1)
fake_pyramid = laplacian_pyramid(fake_imgs, self.n_pyramids - 1,
self.gaussian_k)
# lod: layer_of_descriptors
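For reference, a minimal sketch (assumption, not part of the diff) of the grayscale handling added above: single-channel batches are tiled to three channels so the SWD pipeline, which assumes RGB inputs, runs unchanged:

import torch

imgs = torch.rand(4, 1, 64, 64)     # N x 1 x H x W grayscale batch
if imgs.shape[1] == 1:
    imgs = imgs.repeat(1, 3, 1, 1)  # tile the channel dim -> N x 3 x H x W
print(imgs.shape)  # torch.Size([4, 3, 64, 64])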
7 changes: 7 additions & 0 deletions tests/test_evaluation/test_metrics/test_base_gen_metric.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch

import pytest
import torch
from mmengine.model import MMDistributedDataParallel

@@ -43,6 +44,7 @@ def test_GenMetric():
# test get_metric_sampler
model = MagicMock()
dataset = MagicMock()
dataset.__len__.return_value = 10
dataloader = MagicMock()
dataloader.batch_size = 4
dataloader.dataset = dataset
@@ -57,6 +59,11 @@ def test_GenMetric():
metric.prepare(model, dataloader)
assert metric.data_preprocessor == preprocessor

# test raise error when the dataset length is smaller than real_nums
dataset.__len__.return_value = 5
with pytest.raises(AssertionError):
metric.get_metric_sampler(model, dataloader, [metric])


def test_GenerativeMetric():
metric = ToyGenerativeMetric(11, need_cond_input=True)
41 changes: 39 additions & 2 deletions tests/test_evaluation/test_metrics/test_swd.py
@@ -13,8 +13,7 @@
class TestSWD(TestCase):

def test_init(self):
swd = SlicedWassersteinDistance(
fake_nums=10, image_shape=(3, 32, 32)) # noqa
swd = SlicedWassersteinDistance(fake_nums=10, image_shape=(3, 32, 32))
self.assertEqual(len(swd.real_results), 2)

def test_prosess(self):
@@ -35,6 +34,13 @@ def test_prosess(self):
]

swd.process(real_samples, fake_samples)
# 100 samples are passed in 1 batch, _num_processed should be 100
self.assertEqual(swd._num_processed, 100)
# _num_processed(100) > fake_nums(4), _num_processed should be
# unchanged
swd.process(real_samples, fake_samples)
self.assertEqual(swd._num_processed, 100)

output = swd.evaluate()
result = [16.495922580361366, 24.15413036942482, 20.325026474893093]
output = [item / 100 for item in output.values()]
@@ -48,3 +54,34 @@ def test_prosess(self):
sample_model='orig',
image_shape=(3, 32, 32))
swd.prepare(model, None)

# test gray scale input
swd.image_shape = (1, 32, 32)
real_samples = [
dict(inputs=torch.rand(1, 32, 32) * 255.) for _ in range(100)
]
fake_samples = [
EditDataSample(
fake_img=PixelData(data=torch.rand(1, 32, 32) * 2 - 1),
gt_img=PixelData(data=torch.rand(1, 32, 32) * 2 -
1)).to_dict() for _ in range(100)
]
swd.process(real_samples, fake_samples)

# test fake_nums is -1
swd = SlicedWassersteinDistance(
fake_nums=-1,
fake_key='fake',
real_key='img',
sample_model='orig',
image_shape=(3, 32, 32))
fake_samples = [
EditDataSample(
fake_img=PixelData(data=torch.rand(3, 32, 32) * 2 - 1),
gt_img=PixelData(data=torch.rand(3, 32, 32) * 2 -
1)).to_dict() for _ in range(10)
]
for _ in range(3):
swd.process(None, fake_samples)
# fake_nums is -1, so all samples (10 * 3 = 30) are processed
self.assertEqual(swd._num_processed, 30)
5 changes: 5 additions & 0 deletions tests/test_models/test_editors/test_wgan_gp/test_wgan_gp.py
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
from copy import deepcopy
from unittest import TestCase

import pytest
import torch
from mmengine import MessageHub
from mmengine.optim import OptimWrapper, OptimWrapperDict
@@ -54,6 +56,9 @@ def test_init(self):
gan = WGANGP(generator=gen, data_preprocessor=GenDataPreprocessor())
self.assertEqual(gan.discriminator, None)

@pytest.mark.skipif(
'win' in platform.system().lower() and 'cu' in torch.__version__,
reason='skip on windows-cuda due to limited RAM.')
def test_train_step(self):
# prepare model
accu_iter = 1
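The same windows-cuda skip guard is added to the loss tests below. A minimal sketch (assumption) of when the condition evaluates to True:

import platform
import torch

# Intended to match Windows hosts running a CUDA build of PyTorch
# (e.g. torch.__version__ == '1.13.1+cu117'). Note that 'darwin' also
# contains 'win', but macOS wheels do not ship CUDA, so the second
# check keeps the guard from firing there.
skip = 'win' in platform.system().lower() and 'cu' in torch.__version__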
5 changes: 5 additions & 0 deletions tests/test_models/test_losses/test_feature_loss.py
@@ -1,10 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform

import pytest
import torch

from mmedit.models.losses import LightCNNFeatureLoss


@pytest.mark.skipif(
'win' in platform.system().lower() and 'cu' in torch.__version__,
reason='skip on windows-cuda due to limited RAM.')
def test_light_cnn_feature_loss():

pretrained = 'https://download.openmmlab.com/mmediting/' + \
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform

import pytest
import torch
from mmengine.utils.dl_utils import TORCH_VERSION
@@ -8,6 +10,9 @@
from mmedit.models.losses import GeneratorPathRegularizerComps


@pytest.mark.skipif(
'win' in platform.system().lower() and 'cu' in torch.__version__,
reason='skip on windows-cuda due to limited RAM.')
class TestPathRegularizer:

@classmethod
4 changes: 4 additions & 0 deletions tests/test_models/test_losses/test_perceptual_loss.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
from unittest.mock import patch

import pytest
@@ -8,6 +9,9 @@
TransferalPerceptualLoss)


@pytest.mark.skipif(
'win' in platform.system().lower() and 'cu' in torch.__version__,
reason='skip on windows-cuda due to limited RAM.')
@patch.object(PerceptualVGG, 'init_weights')
def test_perceptual_loss(init_weights):
if torch.cuda.is_available():