Skip to content

Commit

Permalink
support -1 for fake imgs in SWD metric
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed Dec 13, 2022
1 parent f59bd90 commit 936cbac
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion mmedit/evaluation/metrics/swd.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
if self._num_processed >= self.fake_nums_per_device:
if self.fake_nums != -1 and (self._num_processed >=
self.fake_nums_per_device):
return

real_imgs, fake_imgs = [], []
Expand All @@ -279,6 +280,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:

# real images
assert real_imgs.shape[1:] == self.image_shape
if real_imgs.shape[1] == 1:
real_imgs = real_imgs.repeat(1, 3, 1, 1)
real_pyramid = laplacian_pyramid(real_imgs, self.n_pyramids - 1,
self.gaussian_k)
# lod: layer_of_descriptors
Expand All @@ -291,6 +294,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:

# fake images
assert fake_imgs.shape[1:] == self.image_shape
if fake_imgs.shape[1] == 1:
fake_imgs = fake_imgs.repeat(1, 3, 1, 1)
fake_pyramid = laplacian_pyramid(fake_imgs, self.n_pyramids - 1,
self.gaussian_k)
# lod: layer_of_descriptors
Expand Down

0 comments on commit 936cbac

Please sign in to comment.