FrechetInceptionDistance wrong behavior for calculating two features with different batch size #1027

Closed
3288103265 opened this issue May 11, 2022 · 4 comments · Fixed by #1028
Labels: bug / fix, help wanted, topic: Image, v0.8.x

Comments

@3288103265

🐛 Bug

I want to calculate the FID between real and fake images, but each batch of fake images is 5 times the size of the batch of real images (following the usual layout2img evaluation setting), and the resulting score is wrong. Looking at the FID implementation, it uses the number of real samples when normalizing both covariance matrices, which implicitly assumes that the real and fake sets have the same size. As far as I can tell, that assumption is incorrect when the sizes differ.

To Reproduce

    def test_step(self, batch, batch_idx):
        real_imgs, label, bbox, id = batch
        self.FID.update(unit2uint8(real_imgs), real=True)
        fake_list = []
        for _ in range(self.hparams.test_sample):  # test_sample is 5
            z_obj = torch.randn(real_imgs.shape[0], self.hparams.num_obj, self.hparams.z_dim, device=self.device)
            z_im = torch.randn(real_imgs.shape[0], self.hparams.z_dim, device=self.device)
            fake_imgs = self(z_obj, bbox, z_im=z_im, y=label.squeeze(dim=-1))
            fake_list.append(fake_imgs)
        fake_imgs = torch.cat(fake_list, dim=0)
        self.FID.update(unit2uint8(fake_imgs), real=False)
        self.IS.update(unit2uint8(fake_imgs))

Code sample

Here is the current implementation of compute in TorchMetrics 0.8.0:

    def compute(self) -> Tensor:
        """Calculate FID score based on accumulated extracted features from the two distributions."""
        real_features = dim_zero_cat(self.real_features)
        fake_features = dim_zero_cat(self.fake_features)
        # computation is extremely sensitive so it needs to happen in double precision
        orig_dtype = real_features.dtype
        real_features = real_features.double()
        fake_features = fake_features.double()

        # calculate mean and covariance
        n = real_features.shape[0]
        mean1 = real_features.mean(dim=0)
        mean2 = fake_features.mean(dim=0)
        diff1 = real_features - mean1
        diff2 = fake_features - mean2
        cov1 = 1.0 / (n - 1) * diff1.t().mm(diff1)
        cov2 = 1.0 / (n - 1) * diff2.t().mm(diff2)  # the same n (real sample count) is used for the fake features

        # compute fid
        return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype)
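
To see the effect in isolation, here is a minimal sketch that uses random tensors in place of the Inception features. The sizes n, m and d are assumptions chosen to mirror the 5:1 setup above; they are not taken from TorchMetrics. Normalizing the fake covariance by the real sample count should inflate it by a factor of (m - 1) / (n - 1):

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: 100 real samples, 500 fake samples, 8-dim features.
n, m, d = 100, 500, 8
fake_features = torch.randn(m, d, dtype=torch.double)

diff2 = fake_features - fake_features.mean(dim=0)
scatter = diff2.t().mm(diff2)  # un-normalized scatter matrix, shape (d, d)

cov2_buggy = scatter / (n - 1)    # normalized by the real sample count, as in 0.8.0
cov2_correct = scatter / (m - 1)  # normalized by the fake sample count

# The buggy estimate equals the correct one scaled by (m - 1) / (n - 1).
ratio = (m - 1) / (n - 1)
print(torch.allclose(cov2_buggy, cov2_correct * ratio))

# torch.cov expects variables in rows, hence the transpose.
print(torch.allclose(cov2_correct, torch.cov(fake_features.t())))
```

With these sizes the ratio is 499 / 99, roughly 5.04, so the fake covariance passed to _compute_fid would be about five times too large.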

Expected behavior

Environment

  • TorchMetrics version (pip): 0.8.0
  • PyTorch version: 1.12.0a0+bd13bc6
  • Any other relevant information such as OS (e.g., Linux): Docker environment; the image is based largely on nvcr.io/nvidia/pytorch:22.04-py3

Additional context

@3288103265 added the bug / fix and help wanted labels on May 11, 2022
@github-actions

Hi! Thanks for your contribution! Great first issue!

@3288103265
Author

FYI, a more suitable implementation might look something like the following:

    def compute(self) -> Tensor:
        """Calculate FID score based on accumulated extracted features from the two distributions."""
        real_features = dim_zero_cat(self.real_features)
        fake_features = dim_zero_cat(self.fake_features)
        # computation is extremely sensitive so it needs to happen in double precision
        orig_dtype = real_features.dtype
        real_features = real_features.double()
        fake_features = fake_features.double()

        # calculate mean and covariance
        n = real_features.shape[0]
        m = fake_features.shape[0]  # new: number of fake samples
        mean1 = real_features.mean(dim=0)
        mean2 = fake_features.mean(dim=0)
        diff1 = real_features - mean1
        diff2 = fake_features - mean2
        cov1 = 1.0 / (n - 1) * diff1.t().mm(diff1)
        cov2 = 1.0 / (m - 1) * diff2.t().mm(diff2)  # changed: n -> m

        # compute fid
        return _compute_fid(mean1, cov1, mean2, cov2).to(orig_dtype)
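
One way to make this kind of mix-up harder to reintroduce is to compute each set's statistics in a helper that reads the sample count from its own input. The sketch below is only an illustration: mean_and_cov is a hypothetical name, not a TorchMetrics function, and random tensors stand in for the Inception features.

```python
from typing import Tuple

import torch
from torch import Tensor


def mean_and_cov(features: Tensor) -> Tuple[Tensor, Tensor]:
    """Return the mean and unbiased covariance of an (N, D) feature matrix."""
    features = features.double()
    num_samples = features.shape[0]  # taken from this tensor only
    mean = features.mean(dim=0)
    diff = features - mean
    cov = diff.t().mm(diff) / (num_samples - 1)
    return mean, cov


torch.manual_seed(0)
real_features = torch.randn(64, 16)   # hypothetical stand-in: 64 real samples
fake_features = torch.randn(320, 16)  # hypothetical stand-in: 5x as many fake samples

mean1, cov1 = mean_and_cov(real_features)
mean2, cov2 = mean_and_cov(fake_features)

# Each covariance should match torch.cov (variables in rows, hence the transpose).
print(torch.allclose(cov1, torch.cov(real_features.double().t())))
print(torch.allclose(cov2, torch.cov(fake_features.double().t())))
```

The resulting mean1, cov1, mean2 and cov2 have the same shapes as the arguments passed to _compute_fid in the snippet above.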

@SkafteNicki
Member

Hi @3288103265,
I have created PR #1028 with your proposed solution. Thank you for reporting this issue :]

@3288103265
Author

@SkafteNicki Thanks a lot :)
