Skip to content

Commit

Permalink
Improve numeric stability of LPIPS (#2144)
Browse files Browse the repository at this point in the history
* improve stability

* changelog

(cherry picked from commit 1d10277)
  • Loading branch information
SkafteNicki authored and Borda committed Dec 1, 2023
1 parent df830c0 commit d167aa6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed numerical stability bug in `LearnedPerceptualImagePatchSimilarity` metric ([#2144](https://github.com/Lightning-AI/torchmetrics/pull/2144))


## [1.2.0] - 2023-09-22
Expand Down
8 changes: 4 additions & 4 deletions src/torchmetrics/functional/image/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ def _upsample(in_tens: Tensor, out_hw: Tuple[int, ...] = (64, 64)) -> Tensor:
return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens)


def _normalize_tensor(in_feat: Tensor, eps: float = 1e-10) -> Tensor:
"""Normalize tensors."""
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
return in_feat / (norm_factor + eps)
def _normalize_tensor(in_feat: Tensor, eps: float = 1e-8) -> Tensor:
"""Normalize input tensor."""
norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True))
return in_feat / norm_factor


def _resize_tensor(x: Tensor, size: int = 64) -> Tensor:
Expand Down

0 comments on commit d167aa6

Please sign in to comment.