Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally Avoid recomputing features #722

Merged
merged 17 commits into from
Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `adaptive_k` for the `RetrievalPrecision` metric ([#910](https://github.com/PyTorchLightning/metrics/pull/910))


- Added `reset_real_features` argument image quality assesment metrics ([#722](https://github.com/PyTorchLightning/metrics/pull/722))


### Changed

- Made `num_classes` in `jaccard_index` a required argument ([#853](https://github.com/PyTorchLightning/metrics/pull/853), [#914](https://github.com/PyTorchLightning/metrics/pull/914))
Expand Down
25 changes: 25 additions & 0 deletions tests/image/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,28 @@ def test_compare_fid(tmpdir, feature=2048):
tm_res = metric.compute()

assert torch.allclose(tm_res.cpu(), torch.tensor([torch_fid["frechet_inception_distance"]]), atol=1e-3)


@pytest.mark.parametrize("reset_real_features", [True, False])
def test_reset_real_features_arg(reset_real_features):
metric = FrechetInceptionDistance(feature=64, reset_real_features=reset_real_features)

metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=True)
metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=False)

assert len(metric.real_features) == 1
assert list(metric.real_features[0].shape) == [2, 64]

assert len(metric.fake_features) == 1
assert list(metric.fake_features[0].shape) == [2, 64]

metric.reset()

# fake features should always reset
assert len(metric.fake_features) == 0

if reset_real_features:
assert len(metric.real_features) == 0
else:
assert len(metric.real_features) == 1
assert list(metric.real_features[0].shape) == [2, 64]
25 changes: 25 additions & 0 deletions tests/image/test_kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,28 @@ def test_compare_kid(tmpdir, feature=2048):

assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid["kernel_inception_distance_mean"]]), atol=1e-3)
assert torch.allclose(tm_std.cpu(), torch.tensor([torch_fid["kernel_inception_distance_std"]]), atol=1e-3)


@pytest.mark.parametrize("reset_real_features", [True, False])
def test_reset_real_features_arg(reset_real_features):
metric = KernelInceptionDistance(feature=64, reset_real_features=reset_real_features)

metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=True)
metric.update(torch.randint(0, 180, (2, 3, 299, 299), dtype=torch.uint8), real=False)

assert len(metric.real_features) == 1
assert list(metric.real_features[0].shape) == [2, 64]

assert len(metric.fake_features) == 1
assert list(metric.fake_features[0].shape) == [2, 64]

metric.reset()

# fake features should always reset
assert len(metric.fake_features) == 0

if reset_real_features:
assert len(metric.real_features) == 0
else:
assert len(metric.real_features) == 1
assert list(metric.real_features[0].shape) == [2, 64]
23 changes: 22 additions & 1 deletion torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ class FrechetInceptionDistance(Metric):
- an ``nn.Module`` for using a custom feature extractor. Expects that its forward method returns
an ``[N,d]`` matrix where ``N`` is the batch size and ``d`` is the feature size.

reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if
your dataset does not change.

compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.

Expand All @@ -186,6 +190,8 @@ class FrechetInceptionDistance(Metric):
If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048]
TypeError:
If ``feature`` is not an ``str``, ``int`` or ``torch.nn.Module``
ValueError:
If ``reset_real_features`` is not an ``bool``

Example:
>>> import torch
Expand All @@ -203,11 +209,13 @@ class FrechetInceptionDistance(Metric):
"""
real_features: List[Tensor]
fake_features: List[Tensor]
higher_is_better = False
higher_is_better: bool = False
is_differentiable: bool = False

def __init__(
self,
feature: Union[int, torch.nn.Module] = 2048,
reset_real_features: bool = True,
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
Expand Down Expand Up @@ -237,6 +245,10 @@ def __init__(
else:
raise TypeError("Got unknown input to argument `feature`")

if not isinstance(reset_real_features, bool):
raise ValueError("Arugment `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

self.add_state("real_features", [], dist_reduce_fx=None)
self.add_state("fake_features", [], dist_reduce_fx=None)

Expand Down Expand Up @@ -274,3 +286,12 @@ def compute(self) -> Tensor:

# compute fid
return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype)

def reset(self) -> None:
if not self.reset_real_features:
# remove temporarily to avoid resetting
value = self._defaults.pop("real_features")
super().reset()
self._defaults["real_features"] = value
else:
super().reset()
22 changes: 21 additions & 1 deletion torchmetrics/image/kid.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class KernelInceptionDistance(Metric):
Scale-length of polynomial kernel. If set to ``None`` will be automatically set to the feature size
coef:
Bias term in the polynomial kernel.
reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not
change, the features can cached them to avoid recomputing them which is costly. Set this to ``False`` if
your dataset does not change.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.

Expand Down Expand Up @@ -145,6 +148,8 @@ class KernelInceptionDistance(Metric):
If ``gamma`` is niether ``None`` or a float larger than 0
ValueError:
If ``coef`` is not an float larger than 0
ValueError:
If ``reset_real_features`` is not an ``bool``

Example:
>>> import torch
Expand All @@ -163,7 +168,8 @@ class KernelInceptionDistance(Metric):
"""
real_features: List[Tensor]
fake_features: List[Tensor]
higher_is_better = False
higher_is_better: bool = False
is_differentiable: bool = False

def __init__(
self,
Expand All @@ -173,6 +179,7 @@ def __init__(
degree: int = 3,
gamma: Optional[float] = None, # type: ignore
coef: float = 1.0,
reset_real_features: bool = True,
compute_on_step: Optional[bool] = None,
**kwargs: Dict[str, Any],
) -> None:
Expand Down Expand Up @@ -222,6 +229,10 @@ def __init__(
raise ValueError("Argument `coef` expected to be float larger than 0")
self.coef = coef

if not isinstance(reset_real_features, bool):
raise ValueError("Arugment `reset_real_features` expected to be a bool")
self.reset_real_features = reset_real_features

# states for extracted features
self.add_state("real_features", [], dist_reduce_fx=None)
self.add_state("fake_features", [], dist_reduce_fx=None)
Expand Down Expand Up @@ -267,3 +278,12 @@ def compute(self) -> Tuple[Tensor, Tensor]:
kid_scores_.append(o)
kid_scores = torch.stack(kid_scores_)
return kid_scores.mean(), kid_scores.std(unbiased=False)

def reset(self) -> None:
if not self.reset_real_features:
# remove temporarily to avoid resetting
value = self._defaults.pop("real_features")
super().reset()
self._defaults["real_features"] = value
else:
super().reset()