Inception Score (#299)
* implementation and test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* changelog

* add example

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

* update to torch fidelity 0.3.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bump CI job timeout to 35min

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
4 people authored Jun 17, 2021
1 parent 1841cad commit f54ccca
Showing 9 changed files with 335 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `is_differentiable` property to `ConfusionMatrix`, `F1`, `FBeta`, `Hamming`, `Hinge`, `IOU`, `MatthewsCorrcoef`, `Precision`, `Recall`, `PrecisionRecallCurve`, `ROC`, `StatScores` ([#253](https://github.com/PyTorchLightning/metrics/pull/253))


- Added Inception Score metric to image module ([#299](https://github.com/PyTorchLightning/metrics/pull/299))


### Changed

- Forward cache is now reset when `reset` method is called ([#260](https://github.com/PyTorchLightning/metrics/pull/260))
2 changes: 1 addition & 1 deletion azure-pipelines.yml
@@ -19,7 +19,7 @@ pr:
jobs:
- job: pytest
  # how long to run the job before automatically cancelling
  timeoutInMinutes: 25
  timeoutInMinutes: 35
  # how much time to give 'run always even if cancelled tasks' before stopping them
  cancelTimeoutInMinutes: 2

8 changes: 8 additions & 0 deletions docs/source/references/modules.rst
@@ -257,9 +257,17 @@ Image Quality Metrics
Image quality metrics can be used to assess the quality of synthetically generated images from machine
learning algorithms such as `Generative Adversarial Networks (GANs) <https://en.wikipedia.org/wiki/Generative_adversarial_network>`_.

FID
~~~

.. autoclass:: torchmetrics.FID
    :noindex:

IS
~~

.. autoclass:: torchmetrics.IS
    :noindex:

******************
Regression Metrics
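For orientation (not part of the diff itself), here is a minimal usage sketch of the new metric, distilled from the tests added in this commit; it assumes `torch-fidelity` is installed so the default InceptionV3 feature extractor can be loaded:

```python
import torch
from torchmetrics.image.inception import IS  # also re-exported as torchmetrics.IS

# Default settings pull in the InceptionV3 backbone via torch-fidelity.
metric = IS()

# As exercised in the tests, inputs are uint8 images of shape (N, 3, 299, 299);
# only generated images are needed (no "real" distribution, unlike FID).
imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)
metric.update(imgs)

# compute() returns the mean and standard deviation of the score over splits.
mean, std = metric.compute()
```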
16 changes: 12 additions & 4 deletions tests/image/test_fid.py
@@ -91,14 +91,18 @@ def test_fid_raises_errors_and_warnings():
    ):
        _ = FID()

    with pytest.raises(TypeError, match='Got unknown input to argument `feature`'):
        _ = FID(feature=[1, 2])


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_fid_same_input():
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
def test_fid_same_input(feature):
""" if real and fake are update on the same data the fid score should be 0 """
metric = FID(feature=192)
metric = FID(feature=feature)

for _ in range(2):
img = torch.randint(0, 255, (5, 3, 299, 299), dtype=torch.uint8)
img = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8)
metric.update(img, real=True)
metric.update(img, real=False)

@@ -140,7 +144,11 @@ def test_compare_fid(tmpdir, feature=2048):
        metric.update(img2[batch_size * i:batch_size * (i + 1)].cuda(), real=False)

    torch_fid = calculate_metrics(
        _ImgDataset(img1), _ImgDataset(img2), fid=True, feature_layer_fid=str(feature), batch_size=batch_size
        input1=_ImgDataset(img1),
        input2=_ImgDataset(img2),
        fid=True,
        feature_layer_fid=str(feature),
        batch_size=batch_size
    )

    tm_res = metric.compute()
125 changes: 125 additions & 0 deletions tests/image/test_inception.py
@@ -0,0 +1,125 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle

import pytest
import torch
from torch.utils.data import Dataset

from torchmetrics.image.inception import IS
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

torch.manual_seed(42)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
def test_no_train():
""" Assert that metric never leaves evaluation mode """

class MyModel(torch.nn.Module):

def __init__(self):
super().__init__()
self.metric = IS()

def forward(self, x):
return x

model = MyModel()
model.train()
assert model.training
assert not model.metric.inception.training, 'IS metric was changed to training mode which should not happen'


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_is_pickle():
""" Assert that we can initialize the metric and pickle it"""
metric = IS()
assert metric

# verify metrics work after being loaded from pickled state
pickled_metric = pickle.dumps(metric)
metric = pickle.loads(pickled_metric)


def test_is_raises_errors_and_warnings():
""" Test that expected warnings and errors are raised """
with pytest.warns(
UserWarning,
match='Metric `IS` will save all extracted features in buffer.'
' For large datasets this may lead to large memory footprint.'
):
IS()

if _TORCH_FIDELITY_AVAILABLE:
with pytest.raises(ValueError, match='Integer input to argument `feature` must be one of .*'):
_ = IS(feature=2)
else:
with pytest.raises(
ValueError,
match='IS metric requires that Torch-fidelity is installed.'
'Either install as `pip install torchmetrics[image-quality]`'
' or `pip install torch-fidelity`'
):
IS()

with pytest.raises(TypeError, match='Got unknown input to argument `feature`'):
IS(feature=[1, 2])


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_is_update_compute():
    metric = IS()

    for _ in range(2):
        img = torch.randint(0, 255, (10, 3, 299, 299), dtype=torch.uint8)
        metric.update(img)

    mean, std = metric.compute()
    assert mean != 0.0
    assert std != 0.0


class _ImgDataset(Dataset):

    def __init__(self, imgs):
        self.imgs = imgs

    def __getitem__(self, idx):
        return self.imgs[idx]

    def __len__(self):
        return self.imgs.shape[0]


@pytest.mark.skipif(not torch.cuda.is_available(), reason='test is too slow without gpu')
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason='test requires torch-fidelity')
def test_compare_is(tmpdir):
""" check that the hole pipeline give the same result as torch-fidelity """
from torch_fidelity import calculate_metrics

metric = IS(splits=1).cuda()

# Generate some synthetic data
img1 = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)

batch_size = 10
for i in range(img1.shape[0] // batch_size):
metric.update(img1[batch_size * i:batch_size * (i + 1)].cuda())

torch_fid = calculate_metrics(input1=_ImgDataset(img1), isc=True, isc_splits=1, batch_size=batch_size)

tm_mean, tm_std = metric.compute()

assert torch.allclose(tm_mean.cpu(), torch.tensor([torch_fid['inception_score_mean']]), atol=1e-3)
2 changes: 1 addition & 1 deletion torchmetrics/__init__.py
@@ -36,7 +36,7 @@
    StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: F401 E402
from torchmetrics.image import FID # noqa: F401 E402
from torchmetrics.image import FID, IS # noqa: F401 E402
from torchmetrics.metric import Metric # noqa: F401 E402
from torchmetrics.regression import ( # noqa: F401 E402
    PSNR,
1 change: 1 addition & 0 deletions torchmetrics/image/__init__.py
@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.image.fid import FID # noqa: F401
from torchmetrics.image.inception import IS # noqa: F401
12 changes: 8 additions & 4 deletions torchmetrics/image/fid.py
@@ -141,7 +141,7 @@ class FID(Metric):
determines if the images should update the statistics of the real distribution or the fake distribution.
.. note:: using this metric with the default feature extractor requires that ``torch-fidelity``
is installed. Either install as ``pip install torchmetrics[image-quality]`` or
is installed. Either install as ``pip install torchmetrics[image]`` or
``pip install torch-fidelity``
.. note:: the ``forward`` method can be used but ``compute_on_step`` is disabled by default (opposite of
@@ -182,6 +182,8 @@ class FID(Metric):
If ``feature`` is set to an ``int`` (default settings) and ``torch-fidelity`` is not installed
ValueError:
If ``feature`` is set to an ``int`` not in [64, 192, 768, 2048]
TypeError:
If ``feature`` is not a ``str``, ``int`` or ``torch.nn.Module``
Example:
>>> import torch
@@ -204,7 +206,7 @@ def __init__(
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Callable = None
        dist_sync_fn: Callable[[Tensor], List[Tensor]] = None
    ):
        super().__init__(
            compute_on_step=compute_on_step,
@@ -222,7 +224,7 @@ def __init__(
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ValueError(
                    'FID metric requires that Torch-fidelity is installed.'
                    'Either install as `pip install torchmetrics[image-quality]` or `pip install torch-fidelity`'
                    'Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`'
                )
            valid_int_input = [64, 192, 768, 2048]
            if feature not in valid_int_input:
@@ -231,8 +233,10 @@
                )

            self.inception = NoTrainInceptionV3(name='inception-v3-compat', features_list=[str(feature)])
        else:
        elif isinstance(feature, torch.nn.Module):
            self.inception = feature
        else:
            raise TypeError('Got unknown input to argument `feature`')

        self.add_state("real_features", [], dist_reduce_fx=None)
        self.add_state("fake_features", [], dist_reduce_fx=None)
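The ``feature`` handling above now distinguishes three cases: an ``int`` selects one of the bundled InceptionV3 feature layers (requires ``torch-fidelity``), an ``nn.Module`` is used as a custom extractor, and anything else raises ``TypeError``. A hedged sketch of the two accepted paths; ``TinyExtractor`` is a made-up stand-in for illustration, not part of the library:

```python
import torch
from torchmetrics import FID


class TinyExtractor(torch.nn.Module):
    """Illustrative stand-in for a custom feature extractor (not a real torchmetrics class)."""

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        # Map uint8 images to a small (N, d) float feature matrix.
        x = imgs.float() / 255.0
        return torch.nn.functional.adaptive_avg_pool2d(x, 4).flatten(start_dim=1)


fid_default = FID(feature=2048)             # int path: one of [64, 192, 768, 2048], needs torch-fidelity
fid_custom = FID(feature=TinyExtractor())   # nn.Module path added by the new `elif` branch
# FID(feature=[1, 2]) now raises TypeError('Got unknown input to argument `feature`')
```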
