Skip to content

Commit

Permalink
Fix FID computation for non equal size (#1028)
Browse files Browse the repository at this point in the history
* fix fid computation
* changelog
  • Loading branch information
SkafteNicki authored May 12, 2022
1 parent 6f5ac1e commit 6cffeb5
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0



- Fixed `FID` calculation for non-equal size real and fake input ([#1028](https://github.com/PyTorchLightning/metrics/pull/1028))


## [0.8.2] - 2022-05-06


Expand Down
14 changes: 9 additions & 5 deletions tests/image/test_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,25 @@ def __len__(self):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
def test_compare_fid(tmpdir, feature=2048):
@pytest.mark.parametrize("equal_size", [False, True])
def test_compare_fid(tmpdir, equal_size, feature=2048):
"""check that the hole pipeline give the same result as torch-fidelity."""
from torch_fidelity import calculate_metrics

metric = FrechetInceptionDistance(feature=feature).cuda()

n = 100
m = 100 if equal_size else 90

# Generate some synthetic data
img1 = torch.randint(0, 180, (100, 3, 299, 299), dtype=torch.uint8)
img2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
img1 = torch.randint(0, 180, (n, 3, 299, 299), dtype=torch.uint8)
img2 = torch.randint(100, 255, (m, 3, 299, 299), dtype=torch.uint8)

batch_size = 10
for i in range(img1.shape[0] // batch_size):
for i in range(n // batch_size):
metric.update(img1[batch_size * i : batch_size * (i + 1)].cuda(), real=True)

for i in range(img2.shape[0] // batch_size):
for i in range(m // batch_size):
metric.update(img2[batch_size * i : batch_size * (i + 1)].cuda(), real=False)

torch_fid = calculate_metrics(
Expand Down
3 changes: 2 additions & 1 deletion torchmetrics/image/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,13 @@ def compute(self) -> Tensor:

# calculate mean and covariance
n = real_features.shape[0]
m = fake_features.shape[0]
mean1 = real_features.mean(dim=0)
mean2 = fake_features.mean(dim=0)
diff1 = real_features - mean1
diff2 = fake_features - mean2
cov1 = 1.0 / (n - 1) * diff1.t().mm(diff1)
cov2 = 1.0 / (n - 1) * diff2.t().mm(diff2)
cov2 = 1.0 / (m - 1) * diff2.t().mm(diff2)

# compute fid
return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype)
Expand Down

0 comments on commit 6cffeb5

Please sign in to comment.