Skip to content

Commit

Permalink
Add error on wrong shape in CosineSimilarity metric (#2241)
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Nov 28, 2023
1 parent 58ffb01 commit 245ee24
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 25 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in `Metric._reduce_states(...)` when using `dist_sync_fn="cat"` ([#2226](https://github.com/Lightning-AI/torchmetrics/pull/2226))


- Fixed bug in `CosineSimilarity` where 2d is expected but 1d input was given ([#2241](https://github.com/Lightning-AI/torchmetrics/pull/2241))


- Fixed bug in `MetricCollection` when using compute groups and `compute` is called more than once ([#2211](https://github.com/Lightning-AI/torchmetrics/pull/2211))


## [1.2.0] - 2023-09-22

### Added
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/regression/cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def _cosine_similarity_update(
"""
_check_same_shape(preds, target)
if preds.ndim != 2:
raise ValueError(
"Expected input to cosine similarity to be 2D tensors of shape `[N,D]` where `N` is the number of samples"
f" and `D` is the number of dimensions, but got tensor of shape {preds.shape}"
)
preds = preds.float()
target = target.float()

Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/regression/cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def plot(
>>> # Example plotting a single value
>>> from torchmetrics.regression import CosineSimilarity
>>> metric = CosineSimilarity()
>>> metric.update(randn(10,), randn(10,))
>>> metric.update(randn(10,2), randn(10,2))
>>> fig_, ax_ = metric.plot()
.. plot::
Expand All @@ -130,7 +130,7 @@ def plot(
>>> metric = CosineSimilarity()
>>> values = []
>>> for _ in range(10):
... values.append(metric(randn(10,), randn(10,)))
... values.append(metric(randn(10,2), randn(10,2)))
>>> fig, ax = metric.plot(values)
"""
Expand Down
38 changes: 16 additions & 22 deletions tests/unittests/regression/test_cosine_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@


_single_target_inputs = _Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1),
)

_multi_target_inputs = _Input(
Expand All @@ -40,23 +40,10 @@
)


def _multi_target_ref_metric(preds, target, reduction, sk_fn=sk_cosine):
sk_preds = preds.view(-1, num_targets).numpy()
sk_target = target.view(-1, num_targets).numpy()
result_array = sk_fn(sk_target, sk_preds)
col = np.diagonal(result_array)
col_sum = col.sum()
if reduction == "sum":
return col_sum
if reduction == "mean":
return col_sum / len(col)
return col


def _single_target_ref_metric(preds, target, reduction, sk_fn=sk_cosine):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
result_array = sk_fn(np.expand_dims(sk_preds, axis=0), np.expand_dims(sk_target, axis=0))
def _ref_metric(preds, target, reduction):
sk_preds = preds.numpy()
sk_target = target.numpy()
result_array = sk_cosine(sk_target, sk_preds)
col = np.diagonal(result_array)
col_sum = col.sum()
if reduction == "sum":
Expand All @@ -70,8 +57,8 @@ def _single_target_ref_metric(preds, target, reduction, sk_fn=sk_cosine):
@pytest.mark.parametrize(
"preds, target, ref_metric",
[
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_metric),
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_metric),
(_single_target_inputs.preds, _single_target_inputs.target, _ref_metric),
(_multi_target_inputs.preds, _multi_target_inputs.target, _ref_metric),
],
)
class TestCosineSimilarity(MetricTester):
Expand Down Expand Up @@ -104,4 +91,11 @@ def test_error_on_different_shape(metric_class=CosineSimilarity):
"""Test that error is raised on different shapes of input."""
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
metric(torch.randn(100, 2), torch.randn(50, 2))


def test_error_on_non_2d_input():
"""Test that error is raised if input is not 2-dimensional."""
metric = CosineSimilarity()
with pytest.raises(ValueError, match="Expected input to cosine similarity to be 2D tensors of shape.*"):
metric(torch.randn(100), torch.randn(100))
2 changes: 1 addition & 1 deletion tests/unittests/utilities/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@
id="learned perceptual image patch similarity",
),
pytest.param(ConcordanceCorrCoef, _rand_input, _rand_input, id="concordance corr coef"),
pytest.param(CosineSimilarity, _rand_input, _rand_input, id="cosine similarity"),
pytest.param(CosineSimilarity, _multilabel_rand_input, _multilabel_rand_input, id="cosine similarity"),
pytest.param(ExplainedVariance, _rand_input, _rand_input, id="explained variance"),
pytest.param(KendallRankCorrCoef, _rand_input, _rand_input, id="kendall rank corr coef"),
pytest.param(
Expand Down

0 comments on commit 245ee24

Please sign in to comment.