Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove remaining threshold checks #401

Merged
merged 2 commits into from
Jul 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

- Removed restriction that `threshold` has to be in (0,1) range to support logit input ([#351](https://github.com/PyTorchLightning/metrics/pull/351))
- Removed restriction that `threshold` has to be in (0,1) range to support logit input (
[#351](https://github.com/PyTorchLightning/metrics/pull/351)
[#401](https://github.com/PyTorchLightning/metrics/pull/401))


- Removed restriction that `preds` could not be bigger than `num_classes` to support logit input ([#357](https://github.com/PyTorchLightning/metrics/pull/357))
Expand Down
5 changes: 0 additions & 5 deletions torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ class Accuracy(StatScores):
will be used to perform the allgather

Raises:
ValueError:
If ``threshold`` is not between ``0`` and ``1``.
ValueError:
If ``top_k`` is not an ``integer`` larger than ``0``.
ValueError:
Expand Down Expand Up @@ -205,9 +203,6 @@ def __init__(
self.add_state("correct", default=tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

if not 0 < threshold < 1:
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")

if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")

Expand Down
5 changes: 3 additions & 2 deletions torchmetrics/classification/cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class labels.
- ``target`` (long tensor): ``(N, ...)``

If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
to convert into integer labels. This is the case for binary and multi-label probabilities.
to convert into integer labels. This is the case for binary and multi-label probabilities or logits.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

Expand All @@ -55,7 +55,8 @@ class labels.
- ``'quadratic'``: quadratic weighting

threshold:
Threshold value for binary or multi-label probabilites. default: 0.5
Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
Expand Down
2 changes: 0 additions & 2 deletions torchmetrics/classification/hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def __init__(
self.add_state("correct", default=tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

if not 0 < threshold < 1:
raise ValueError("The `threshold` should lie in the (0,1) interval.")
self.threshold = threshold

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
Expand Down
2 changes: 0 additions & 2 deletions torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ class StatScores(Metric):
will be used to perform the allgather.

Raises:
ValueError:
If ``threshold`` is not a ``float`` between ``0`` and ``1``.
ValueError:
If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``.
ValueError:
Expand Down
6 changes: 0 additions & 6 deletions torchmetrics/functional/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,6 @@ def accuracy(
still applies in both cases, if set.

Raises:
ValueError:
If ``threshold`` is not a ``float`` between ``0`` and ``1``.
ValueError:
If ``top_k`` parameter is set for ``multi-label`` inputs.
ValueError:
Expand Down Expand Up @@ -270,10 +268,6 @@ def accuracy(
>>> accuracy(preds, target, top_k=2)
tensor(0.6667)
"""

if not 0 < threshold < 1:
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")

allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/utilities/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _check_classification_inputs(
preds: Tensor with predictions (labels or probabilities)
target: Tensor with ground truth labels, always integers (labels)
threshold:
Threshold probability value for transforming probability predictions to binary
Threshold value for transforming probability/logit predictions to binary
(0,1) predictions, in the case of binary or multi-label inputs.
num_classes:
Number of classes. If not explicitly set, the number of classes will be inferred
Expand Down Expand Up @@ -371,7 +371,7 @@ def _input_format_classification(
preds: Tensor with predictions (labels or probabilities)
target: Tensor with ground truth labels, always integers (labels)
threshold:
Threshold probability value for transforming probability predictions to binary
Threshold value for transforming probability/logit predictions to binary
(0 or 1) predictions, in the case of binary or multi-label inputs.
num_classes:
Number of classes. If not explicitly set, the number of classes will be inferred
Expand Down