Skip to content

Commit

Permalink
Remove check that preds value need to be smaller than num_classes (#357)
Browse files Browse the repository at this point in the history
* fix

* changelog

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
SkafteNicki and pre-commit-ci[bot] authored Jul 9, 2021
1 parent 0424c17 commit 3de06f9
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Removed restriction that `threshold` has to be in (0,1) range to support logit input ([#351](https://github.com/PyTorchLightning/metrics/pull/351))


- Removed restriction that `preds` could not be bigger than `num_classes` to support logit input ([#357](https://github.com/PyTorchLightning/metrics/pull/357))


### Fixed


Expand Down
2 changes: 1 addition & 1 deletion tests/classification/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@

_input_multilabel_no_match = Input(preds=__temp_preds, target=__temp_target)

__mc_prob_logits = torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
__mc_prob_logits = 10 * torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
__mc_prob_preds = __mc_prob_logits.abs() / __mc_prob_logits.abs().sum(dim=2, keepdim=True)

_input_multiclass_prob = Input(
Expand Down
2 changes: 0 additions & 2 deletions tests/classification/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,6 @@ def test_threshold():
(_mc_prob.preds[0], randint(low=NUM_CLASSES + 1, high=100, size=(BATCH_SIZE, NUM_CLASSES)), 4, None),
# Max target larger than num_classes (with #dim preds = #dims target)
(randint(high=4, size=(7, 3)), randint(low=5, high=7, size=(7, 3)), 4, None),
# Max preds larger than num_classes (with #dim preds = #dims target)
(randint(low=5, high=7, size=(7, 3)), randint(high=4, size=(7, 3)), 4, None),
# Num_classes=1, but multiclass not false
(randint(high=2, size=(7, )), randint(high=2, size=(7, )), 1, None),
# multiclass=False, but implied class dimension (for multi-label, from shape) != num_classes
Expand Down
2 changes: 0 additions & 2 deletions torchmetrics/utilities/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,6 @@ def _check_num_classes_mc(
)
if num_classes <= target.max():
raise ValueError("The highest label in `target` should be smaller than `num_classes`.")
if num_classes <= preds.max():
raise ValueError("The highest label in `preds` should be smaller than `num_classes`.")
if preds.shape != target.shape and num_classes != implied_classes:
raise ValueError("The size of C dimension of `preds` does not match `num_classes`.")

Expand Down

0 comments on commit 3de06f9

Please sign in to comment.