diff --git a/ddpm_torch/metrics/__init__.py b/ddpm_torch/metrics/__init__.py index 830e5ca..6ad11bf 100644 --- a/ddpm_torch/metrics/__init__.py +++ b/ddpm_torch/metrics/__init__.py @@ -40,7 +40,7 @@ def eval(self, sample_fn, is_leader=True): with trange(num_batches, desc="Evaluating FID", disable=not is_leader) as t: for i in t: if i == len(t) - 1: - batch_size = self.eval_total_size % self.eval_batch_size + batch_size = (self.eval_total_size % self.eval_batch_size) or self.eval_batch_size else: batch_size = self.eval_batch_size x = sample_fn(sample_size=batch_size, diffusion=self.diffusion)