diff --git a/CHANGELOG.md b/CHANGELOG.md index fff4578b101..eec4a7b0577 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,7 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed multi device aggregation in `PearsonCorrCoef` ([#998](https://github.com/PyTorchLightning/metrics/pull/998)) -- +- Fixed compatibility with future PyTorch 1.12 in `pairwise_cosine_similarity` ([#1011](https://github.com/PyTorchLightning/metrics/pull/1011)) ## [0.8.1] - 2022-04-27 diff --git a/torchmetrics/functional/pairwise/cosine.py b/torchmetrics/functional/pairwise/cosine.py index 74df87352bb..29103a52dd9 100644 --- a/torchmetrics/functional/pairwise/cosine.py +++ b/torchmetrics/functional/pairwise/cosine.py @@ -20,6 +20,16 @@ from torchmetrics.functional.pairwise.helpers import _check_input, _reduce_distance_matrix +def _safe_matmul(x: Tensor, y: Tensor) -> Tensor: + """Safe calculation of matrix multiplication. + + If input is float16, will cast to float32 for computation and back again. + """ + if x.dtype == torch.float16 or y.dtype == torch.float16: + return (x.float() @ y.T.float()).half() + return x @ y.T + + def _pairwise_cosine_similarity_update( x: Tensor, y: Optional[Tensor] = None, zero_diagonal: Optional[bool] = None ) -> Tensor: @@ -37,7 +47,7 @@ def _pairwise_cosine_similarity_update( norm = torch.norm(y, p=2, dim=1) y /= norm.unsqueeze(1) - distance = x @ y.T + distance = _safe_matmul(x, y) if zero_diagonal: distance.fill_diagonal_(0) return distance