Skip to content

Commit

Permalink
Clamp variance calculation in certain image metrics (#2378)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Feb 14, 2024
1 parent ee1a529 commit afae59e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed negative variance estimates in certain image metrics ([#2378](https://github.com/Lightning-AI/torchmetrics/pull/2378))



Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ def _ssim_update(
mu_target_sq = output_list[1].pow(2)
mu_pred_target = output_list[0] * output_list[1]

sigma_pred_sq = output_list[2] - mu_pred_sq
sigma_target_sq = output_list[3] - mu_target_sq
# Calculate the variance of the predicted and target images, should be non-negative
sigma_pred_sq = torch.clamp(output_list[2] - mu_pred_sq, min=0.0)
sigma_target_sq = torch.clamp(output_list[3] - mu_target_sq, min=0.0)
sigma_pred_target = output_list[4] - mu_pred_target

upper = 2 * sigma_pred_target.to(dtype) + c2
Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/functional/image/uqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ def _uqi_compute(
mu_target_sq = output_list[1].pow(2)
mu_pred_target = output_list[0] * output_list[1]

sigma_pred_sq = output_list[2] - mu_pred_sq
sigma_target_sq = output_list[3] - mu_target_sq
# Calculate the variance of the predicted and target images, should be non-negative
sigma_pred_sq = torch.clamp(output_list[2] - mu_pred_sq, min=0.0)
sigma_target_sq = torch.clamp(output_list[3] - mu_target_sq, min=0.0)
sigma_pred_target = output_list[4] - mu_pred_target

upper = 2 * sigma_pred_target
Expand Down

0 comments on commit afae59e

Please sign in to comment.