adding KID metric #301

Merged Jun 21, 2021 (27 commits)
Changes from 9 commits

Commits
9d11ab0  implementation (SkafteNicki, Jun 16, 2021)
948a178  parameter testing (SkafteNicki, Jun 16, 2021)
c534a60  fix test (SkafteNicki, Jun 16, 2021)
72608f7  implementation (SkafteNicki, Jun 16, 2021)
c0de4f7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 16, 2021)
36c5a82  update to torch fidelity 0.3.0 (SkafteNicki, Jun 16, 2021)
1cd256c  Merge branch 'kid' of https://github.com/PyTorchLightning/metrics int… (SkafteNicki, Jun 16, 2021)
6aa956d  changelog (SkafteNicki, Jun 16, 2021)
19fefc1  docs (SkafteNicki, Jun 16, 2021)
90b7a76  Merge branch 'master' into kid (mergify[bot], Jun 17, 2021)
a8a2ab5  Apply suggestions from code review (SkafteNicki, Jun 17, 2021)
3cabc96  Apply suggestions from code review (Borda, Jun 17, 2021)
6a2a20a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 17, 2021)
751b145  add test (SkafteNicki, Jun 17, 2021)
dae0fd8  Merge branch 'master' into kid (SkafteNicki, Jun 18, 2021)
7e99282  Merge branches 'kid' and 'kid' of https://github.com/PyTorchLightning… (SkafteNicki, Jun 18, 2021)
39cc0f9  update (SkafteNicki, Jun 18, 2021)
beb9d27  Merge branch 'master' into kid (mergify[bot], Jun 21, 2021)
42ec431  fix tests (SkafteNicki, Jun 21, 2021)
81968df  Merge branch 'kid' of https://github.com/PyTorchLightning/metrics int… (SkafteNicki, Jun 21, 2021)
5dca79a  typing (SkafteNicki, Jun 21, 2021)
4203628  fix typing (SkafteNicki, Jun 21, 2021)
889c066  Merge branch 'master' into kid (mergify[bot], Jun 21, 2021)
94d44a5  fix bus error (SkafteNicki, Jun 21, 2021)
08a5bb7  Merge branch 'kid' of https://github.com/PyTorchLightning/metrics int… (SkafteNicki, Jun 21, 2021)
128fc84  Merge branch 'master' into kid (mergify[bot], Jun 21, 2021)
3f1e0e0  Apply suggestions from code review (Borda, Jun 21, 2021)
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `is_differentiable` property to `ConfusionMatrix`, `F1`, `FBeta`, `Hamming`, `Hinge`, `IOU`, `MatthewsCorrcoef`, `Precision`, `Recall`, `PrecisionRecallCurve`, `ROC`, `StatScores` ([#253](https://github.com/PyTorchLightning/metrics/pull/253))


- Added KID metric ([#301](https://github.com/PyTorchLightning/metrics/pull/301))

### Changed

- Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260))
4 changes: 4 additions & 0 deletions docs/source/references/modules.rst
@@ -261,6 +261,10 @@ learning algorithms such as `Generative Adverserial Networks (GANs) <https://en.
    :noindex:



.. autoclass:: torchmetrics.KID
    :noindex:

******************
Regression Metrics
******************
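A minimal usage sketch of the metric added here, based on the interface exercised in the tests below: images are passed as uint8 tensors via update(..., real=True/False) and compute() returns a (mean, std) pair estimated over random feature subsets. The feature/subsets/subset_size values are illustrative choices for this sketch, not necessarily the library defaults.

import torch
from torchmetrics import KID

# Illustrative settings; subset_size should not exceed the number of
# samples accumulated for either distribution.
kid = KID(feature=2048, subsets=10, subset_size=50)

real_imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)
fake_imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)

kid.update(real_imgs, real=True)   # accumulate features of real images
kid.update(fake_imgs, real=False)  # accumulate features of generated images

kid_mean, kid_std = kid.compute()  # KID estimate and its standard deviation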
161 changes: 161 additions & 0 deletions tests/image/test_kid.py
@@ -0,0 +1,161 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle

import pytest
import torch
from torch.utils.data import Dataset

from torchmetrics.image.kid import KID
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

torch.manual_seed(42)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
def test_no_train():
    """ Assert that the metric never leaves evaluation mode """

    class MyModel(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.metric = KID()

        def forward(self, x):
            return x

    model = MyModel()
    model.train()
    assert model.training
    assert not model.metric.inception.training, 'KID metric was changed to training mode which should not happen'


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_kid_pickle():
    """ Assert that we can initialize the metric and pickle it """
    metric = KID()
    assert metric

    # verify metrics work after being loaded from pickled state
    pickled_metric = pickle.dumps(metric)
    metric = pickle.loads(pickled_metric)


def test_kid_raises_errors_and_warnings():
    """ Test that expected warnings and errors are raised """
    with pytest.warns(
        UserWarning,
        match='Metric `KID` will save all extracted features in buffer.'
        ' For large datasets this may lead to large memory footprint.'
    ):
        KID()

    if _TORCH_FIDELITY_AVAILABLE:
        with pytest.raises(ValueError, match='Integer input to argument `feature` must be one of .*'):
            KID(feature=2)
    else:
        with pytest.raises(
            ValueError,
            match='KID metric requires that Torch-fidelity is installed.'
            'Either install as `pip install torchmetrics[image]`'
            ' or `pip install torch-fidelity`'
        ):
            KID()

    with pytest.raises(ValueError, match='Got unknown input to argument `feature`'):
        KID(feature=[1, 2])


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_kid_extra_parameters():
    with pytest.raises(ValueError, match="Argument `subsets` expected to be integer larger than 0"):
        KID(subsets=-1)

    with pytest.raises(ValueError, match="Argument `subset_size` expected to be integer larger than 0"):
        KID(subset_size=-1)

    with pytest.raises(ValueError, match="Argument `degree` expected to be integer larger than 0"):
        KID(degree=-1)

    with pytest.raises(ValueError, match="Argument `gamma` expected to be `None` or float larger than 0"):
        KID(gamma=-1)

    with pytest.raises(ValueError, match="Argument `coef` expected to be float larger than 0"):
        KID(coef=-1)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
def test_kid_same_input(feature):
    """ If real and fake are updated on the same data, the KID score should be 0 """
    metric = KID(feature=feature, subsets=5, subset_size=2)

    for _ in range(2):
        img = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8)
        metric.update(img, real=True)
        metric.update(img, real=False)

    assert torch.allclose(torch.cat(metric.real_features, dim=0), torch.cat(metric.fake_features, dim=0))

    mean, std = metric.compute()
    assert mean != 0.0
    assert std != 0.0


class _ImgDataset(Dataset):

    def __init__(self, imgs):
        self.imgs = imgs

    def __getitem__(self, idx):
        return self.imgs[idx]

    def __len__(self):
        return self.imgs.shape[0]


@pytest.mark.skipif(not torch.cuda.is_available(), reason='test is too slow without gpu')
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_compare_kid(tmpdir, feature=2048):
    """ Check that the whole pipeline gives the same result as torch-fidelity """
    from torch_fidelity import calculate_metrics

    metric = KID(feature=feature, subsets=10, subset_size=10).cuda()

    # Generate some synthetic data
    img1 = torch.randint(0, 180, (100, 3, 299, 299), dtype=torch.uint8)
    img2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)

    batch_size = 10
    for i in range(img1.shape[0] // batch_size):
        metric.update(img1[batch_size * i:batch_size * (i + 1)].cuda(), real=True)

    for i in range(img2.shape[0] // batch_size):
        metric.update(img2[batch_size * i:batch_size * (i + 1)].cuda(), real=False)

    torch_fid = calculate_metrics(
        input1=_ImgDataset(img1),
        input2=_ImgDataset(img2),
        kid=True,
        feature_layer_fid=str(feature),
        batch_size=batch_size,
        kid_subsets=10,
        kid_subset_size=10
    )

    tm_mean, tm_std = metric.compute()

    assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid['kernel_inception_distance_mean']]), atol=1e-3)
    assert torch.allclose(tm_std.cpu(), torch.tensor([torch_fid['kernel_inception_distance_std']]), atol=1e-3)
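As background for the subsets, subset_size, degree, gamma and coef arguments checked above: KID is the squared maximum mean discrepancy between real and generated Inception features under a polynomial kernel, averaged over random subsets (Binkowski et al., 2018, "Demystifying MMD GANs"). The sketch below illustrates that computation on precomputed feature matrices; it is not the PR's implementation, which extracts features with torch-fidelity's InceptionV3 and manages metric state and distributed synchronization.

import torch


def poly_kernel(f1, f2, degree=3, gamma=None, coef=1.0):
    # Polynomial kernel k(x, y) = (gamma * <x, y> + coef) ** degree,
    # with gamma defaulting to 1 / feature_dim.
    if gamma is None:
        gamma = 1.0 / f1.shape[1]
    return (f1 @ f2.t() * gamma + coef) ** degree


def mmd2_unbiased(k_rr, k_ff, k_rf):
    # Unbiased estimate of the squared MMD: diagonal terms are excluded
    # from the within-set kernel sums.
    m = k_rr.shape[0]
    term_rr = (k_rr.sum() - k_rr.diag().sum()) / (m * (m - 1))
    term_ff = (k_ff.sum() - k_ff.diag().sum()) / (m * (m - 1))
    return term_rr + term_ff - 2 * k_rf.mean()


def kid_from_features(real_feats, fake_feats, subsets=100, subset_size=1000):
    # Average the unbiased MMD^2 over `subsets` random subsets of size
    # `subset_size`; report mean and standard deviation across subsets.
    scores = []
    for _ in range(subsets):
        r = real_feats[torch.randperm(real_feats.shape[0])[:subset_size]]
        f = fake_feats[torch.randperm(fake_feats.shape[0])[:subset_size]]
        scores.append(mmd2_unbiased(poly_kernel(r, r), poly_kernel(f, f), poly_kernel(r, f)))
    scores = torch.stack(scores)
    return scores.mean(), scores.std()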
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
@@ -36,7 +36,7 @@
    StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: F401 E402
from torchmetrics.image import FID # noqa: F401 E402
from torchmetrics.image import FID, KID # noqa: F401 E402
from torchmetrics.metric import Metric # noqa: F401 E402
from torchmetrics.regression import ( # noqa: F401 E402
    PSNR,
1 change: 1 addition & 0 deletions torchmetrics/image/__init__.py
@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.image.fid import FID # noqa: F401
from torchmetrics.image.kid import KID # noqa: F401
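With the import additions above, the class is reachable from the package root as well as from the image subpackage, e.g.:

from torchmetrics import KID              # package root
# from torchmetrics.image import KID      # image subpackage
# from torchmetrics.image.kid import KID  # defining module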