Skip to content

Commit

Permalink
Optionally Avoid recomputing features (#722)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
  • Loading branch information
3 people authored Apr 11, 2022
1 parent 0d56376 commit 97f2bf5
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `adaptive_k` for the `RetrievalPrecision` metric ([#910](https://github.com/PyTorchLightning/metrics/pull/910))


- Added `reset_real_features` argument image quality assesment metrics ([#722](https://github.com/PyTorchLightning/metrics/pull/722))


### Changed

- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853), [#914](https://github.com/PyTorchLightning/metrics/pull/914))
Expand Down
25 changes: 25 additions & 0 deletions tests/image/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,28 @@ def test_compare_fid(tmpdir, feature=2048):
tm_res = metric.compute()

assert torch.allclose(tm_res.cpu(), torch.tensor([torch_fid["frechet_inception_distance"]]), atol=1e-3)


@pytest.mark.parametrize("reset_real_features", [True, False])
def test_reset_real_features_arg(reset_real_features):
metric = FrechetInceptionDistance(feature=64, reset_real_features=reset_real_features)

metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=True)
metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=False)

assert len(metric.real_features) == 1
assert list(metric.real_features[0].shape) == [2, 64]

assert len(metric.fake_features) == 1
assert list(metric.fake_features[0].shape) == [2, 64]

metric.reset()

# fake features should always reset
assert len(metric.fake_features) == 0

if reset_real_features:
assert len(metric.real_features) == 0
else:
assert len(metric.real_features) == 1
assert list(metric.real_features[0].shape) == [2, 64]
25 changes: 25 additions & 0 deletions tests/image/test_kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,28 @@ def test_compare_kid(tmpdir, feature=2048):

assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid["kernel_inception_distance_mean"]]), atol=1e-3)
assert torch.allclose(tm_std.cpu(), torch.tensor([torch_fid["kernel_inception_distance_std"]]), atol=1e-3)


@pytest.mark.parametrize("reset_real_features", [True, False])
def test_reset_real_features_arg(reset_real_features):
metric = KernelInceptionDistance(feature=64, reset_real_features=reset_real_features)

metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=True)
metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=False)

assert len(metric.real_features) == 1
assert list(metric.real_features[0].shape) == [2, 64]

assert len(metric.fake_features) == 1
assert list(metric.fake_features[0].shape) == [2, 64]

metric.reset()

# fake features should always reset
assert len(metric.fake_features) == 0

if reset_real_features:
assert len(metric.real_features) == 0
else:
assert len(metric.real_features) == 1
assert list(metric.real_features[0].shape) == [2, 64]
23 changes: 22 additions & 1 deletion torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ class FrechetInceptionDistance(Metric):
- an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns
an ``[N,d]`` matrix where ``N`` is the batch size and ``d`` is the feature size.
reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if
your dataset does not change.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
Expand All @@ -186,6 +190,8 @@ class FrechetInceptionDistance(Metric):
If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048]
TypeError:
If ``feature`` is not an ``str``, ``int`` or ``torch.nn.Module``
ValueError:
If ``reset_real_features`` is not an ``bool``
Example:
>>> import torch
Expand All @@ -203,11 +209,13 @@ class FrechetInceptionDistance(Metric):
"""
real_features: List[Tensor]
fake_features: List[Tensor]
higher_is_better = False
higher_is_better: bool = False
is_differentiable: bool = False

def __init__(
self,
feature: Union[int, torch.nn.Module] = 2048,
reset_real_features: bool = True,
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
Expand Down Expand Up @@ -237,6 +245,10 @@ def __init__(
else:
raise TypeError("Got unknown input to argument `feature`")

if not isinstance(reset_real_features, bool):
raise ValueError("Arugment `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

self.add_state("real_features", [], dist_reduce_fx=None)
self.add_state("fake_features", [], dist_reduce_fx=None)

Expand Down Expand Up @@ -274,3 +286,12 @@ def compute(self) -> Tensor:

# compute fid
return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype)

def reset(self) -> None:
if not self.reset_real_features:
# remove temporarily to avoid resetting
value = self._defaults.pop("real_features")
super().reset()
self._defaults["real_features"] = value
else:
super().reset()
22 changes: 21 additions & 1 deletion torchmetrics/image/kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class KernelInceptionDistance(Metric):
Scale-length of polynomial kernel. If set to ``None`` will be automatically set to the feature size
coef:
Bias term in the polynomial kernel.
reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if
your dataset does not change.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
Expand Down Expand Up @@ -145,6 +148,8 @@ class KernelInceptionDistance(Metric):
If ``gamma`` is niether ``None`` or a float larger than 0
ValueError:
If ``coef`` is not an float larger than 0
ValueError:
If ``reset_real_features`` is not an ``bool``
Example:
>>> import torch
Expand All @@ -163,7 +168,8 @@ class KernelInceptionDistance(Metric):
"""
real_features: List[Tensor]
fake_features: List[Tensor]
higher_is_better = False
higher_is_better: bool = False
is_differentiable: bool = False

def __init__(
self,
Expand All @@ -173,6 +179,7 @@ def __init__(
degree: int = 3,
gamma: Optional[float] = None, # type: ignore
coef: float = 1.0,
reset_real_features: bool = True,
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
Expand Down Expand Up @@ -222,6 +229,10 @@ def __init__(
raise ValueError("Argument `coef` expected to be float larger than 0")
self.coef = coef

if not isinstance(reset_real_features, bool):
raise ValueError("Arugment `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

# states for extracted features
self.add_state("real_features", [], dist_reduce_fx=None)
self.add_state("fake_features", [], dist_reduce_fx=None)
Expand Down Expand Up @@ -267,3 +278,12 @@ def compute(self) -> Tuple[Tensor, Tensor]:
kid_scores_.append(o)
kid_scores = torch.stack(kid_scores_)
return kid_scores.mean(), kid_scores.std(unbiased=False)

def reset(self) -> None:
if not self.reset_real_features:
# remove temporarily to avoid resetting
value = self._defaults.pop("real_features")
super().reset()
self._defaults["real_features"] = value
else:
super().reset()

0 comments on commit 97f2bf5

Please sign in to comment.