From 3d77854570be9d704fc0fafaa675186bc71bbcf3 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 11 May 2022 11:53:25 +0200 Subject: [PATCH 1/2] fix fid computation --- tests/image/test_fid.py | 14 +++++++++----- torchmetrics/image/fid.py | 3 ++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/image/test_fid.py b/tests/image/test_fid.py index dfabd378da4..56e06591734 100644 --- a/tests/image/test_fid.py +++ b/tests/image/test_fid.py @@ -125,21 +125,25 @@ def __len__(self): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu") @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity") -def test_compare_fid(tmpdir, feature=2048): +@pytest.mark.parametrize("equal_size", [False, True]) +def test_compare_fid(tmpdir, equal_size, feature=2048): """check that the hole pipeline give the same result as torch-fidelity.""" from torch_fidelity import calculate_metrics metric = FrechetInceptionDistance(feature=feature).cuda() + n = 100 + m = 100 if equal_size else 90 + # Generate some synthetic data - img1 = torch.randint(0, 180, (100, 3, 299, 299), dtype=torch.uint8) - img2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) + img1 = torch.randint(0, 180, (n, 3, 299, 299), dtype=torch.uint8) + img2 = torch.randint(100, 255, (m, 3, 299, 299), dtype=torch.uint8) batch_size = 10 - for i in range(img1.shape[0] // batch_size): + for i in range(n // batch_size): metric.update(img1[batch_size * i : batch_size * (i + 1)].cuda(), real=True) - for i in range(img2.shape[0] // batch_size): + for i in range(m // batch_size): metric.update(img2[batch_size * i : batch_size * (i + 1)].cuda(), real=False) torch_fid = calculate_metrics( diff --git a/torchmetrics/image/fid.py b/torchmetrics/image/fid.py index 3be4589ff31..36c798aa46d 100644 --- a/torchmetrics/image/fid.py +++ b/torchmetrics/image/fid.py @@ -268,12 +268,13 @@ def compute(self) -> Tensor: # calculate mean and covariance n = real_features.shape[0] + m = fake_features.shape[0] mean1 = real_features.mean(dim=0) mean2 = fake_features.mean(dim=0) diff1 = real_features - mean1 diff2 = fake_features - mean2 cov1 = 1.0 / (n - 1) * diff1.t().mm(diff1) - cov2 = 1.0 / (n - 1) * diff2.t().mm(diff2) + cov2 = 1.0 / (m - 1) * diff2.t().mm(diff2) # compute fid return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype) From ad0e1a3ff8b58de1a2f33aaf8b31f75676796878 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 11 May 2022 11:55:24 +0200 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f4d6354c808..5264ca994f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed non-empty state dict for a few metrics ([#1012](https://github.com/PyTorchLightning/metrics/pull/1012)) +- Fixed `FID` calculation for non-equal size real and fake input ([#1028](https://github.com/PyTorchLightning/metrics/pull/1028)) + + ## [0.8.2] - 2022-05-06