Commit

rename f1
cuent committed Jan 9, 2022
1 parent de79133 commit ab68b55
Showing 6 changed files with 247 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -64,6 +64,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `SpearmanCorrcoef` -> `SpearmanCorrCoef`


- Renamed `torchmetrics.functional.f1` to `torchmetrics.functional.f1_score` and `torchmetrics.F1` to `torchmetrics.F1Score` to unify name convention with other frameworks ([#731](https://github.com/PyTorchLightning/metrics/pull/731))

### Removed

- Removed `embedding_similarity` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638))
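
A minimal migration sketch for this rename (assuming torchmetrics v0.7, where the old names still work but are slated for removal in v0.8):

```python
import torch
from torchmetrics import F1Score               # new name; `F1` remains as a deprecated alias
from torchmetrics.functional import f1_score   # new name; `f1` remains as a deprecated alias

preds = torch.tensor([0, 2, 1, 0, 0, 1])
target = torch.tensor([0, 1, 2, 0, 1, 2])

# Modular API: F1Score replaces F1
metric = F1Score(num_classes=3)
print(metric(preds, target))                   # tensor(0.3333)

# Functional API: f1_score replaces f1
print(f1_score(preds, target, num_classes=3))  # tensor(0.3333)
```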
2 changes: 2 additions & 0 deletions torchmetrics/__init__.py
@@ -18,6 +18,7 @@
AUC,
AUROC,
    F1,
    F1Score,
ROC,
Accuracy,
AveragePrecision,
@@ -103,6 +104,7 @@
"CosineSimilarity",
"TweedieDevianceScore",
"ExplainedVariance",
"F1",
"F1Score",
"FBeta",
"HammingDistance",
2 changes: 1 addition & 1 deletion torchmetrics/classification/__init__.py
@@ -21,7 +21,7 @@
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401
from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from torchmetrics.classification.f_beta import F1Score, FBeta # noqa: F401
from torchmetrics.classification.f_beta import F1, F1Score, FBeta # noqa: F401
from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import Hinge # noqa: F401
from torchmetrics.classification.iou import IoU # noqa: F401
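
Because the deprecated class is implemented as a subclass of its replacement (see `f_beta.py` below), existing type checks against the new name also accept the old one. A small sketch, assuming v0.7:

```python
from torchmetrics import F1, F1Score

# The deprecated alias subclasses the new class, so an F1 instance
# still passes isinstance checks written against F1Score.
assert issubclass(F1, F1Score)
assert isinstance(F1(num_classes=3), F1Score)
```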
132 changes: 131 additions & 1 deletion torchmetrics/classification/f_beta.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from warnings import warn

import torch
from torch import Tensor

from torchmetrics.classification.stat_scores import StatScores
@@ -300,3 +300,133 @@ def __init__(
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)


class F1(F1Score):
"""Computes F1 metric. F1 metrics correspond to a harmonic mean of the precision and recall scores.
Works with binary, multiclass, and multilabel data. Accepts logits or probabilities from a model
output or integer class values in prediction. Works with multi-dimensional preds and target.
Forward accepts
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
This is the case for binary and multi-label logits.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
Args:
num_classes:
Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
threshold:
Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
average:
Defines the reduction that is applied. Should be one of the following:
- ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
- ``'macro'``: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.
- ``'samples'``: Calculate the metric for each sample, and average the metrics
across samples (with equal weights for each sample).
.. note:: What is considered a sample in the multi-dimensional multi-class case
depends on the value of ``mdmc_average``.
mdmc_average:
Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
``average`` parameter). Should be one of the following:
- ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
multi-class.
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`references/modules:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
ignore_index:
Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
or ``'none'``, the score for the ignored class will be returned as ``nan``.
top_k:
Number of highest probability or logit score predictions considered to find the correct label,
relevant only for (multi-dimensional) multi-class inputs. The
default value (``None``) will be interpreted as 1 for these inputs.
Should be left at default (``None``) for all other types of inputs.
multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <references/modules:using the multiclass parameter>`
for a more detailed explanation and examples.
        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
            before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
default: ``None`` (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather.
    .. deprecated:: v0.7
        Use :class:`torchmetrics.F1Score`. Will be removed in v0.8.

    Example:
>>> from torchmetrics import F1
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f1 = F1(num_classes=3)
>>> f1(preds, target)
tensor(0.3333)
"""

is_differentiable = False
higher_is_better = True

def __init__(
self,
num_classes: Optional[int] = None,
threshold: float = 0.5,
average: str = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
) -> None:
warn("`F1` was renamed to `F1Score` in v0.7 and it will be removed in v0.8", DeprecationWarning)
super().__init__(
num_classes=num_classes,
threshold=threshold,
average=average,
mdmc_average=mdmc_average,
ignore_index=ignore_index,
top_k=top_k,
multiclass=multiclass,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
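
A minimal sketch of how this warning surfaces (and can be inspected) in user code, given the `warn(..., DeprecationWarning)` call in `__init__` above; assumes torchmetrics v0.7:

```python
import warnings

from torchmetrics import F1

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # ensure DeprecationWarning is not suppressed
    metric = F1(num_classes=3)       # constructing the alias triggers the rename warning

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```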
2 changes: 1 addition & 1 deletion torchmetrics/functional/classification/__init__.py
@@ -19,7 +19,7 @@
from torchmetrics.functional.classification.cohen_kappa import cohen_kappa # noqa: F401
from torchmetrics.functional.classification.confusion_matrix import confusion_matrix # noqa: F401
from torchmetrics.functional.classification.dice import dice_score # noqa: F401
from torchmetrics.functional.classification.f_beta import f1_score, fbeta # noqa: F401
from torchmetrics.functional.classification.f_beta import f1, f1_score, fbeta # noqa: F401
from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401
from torchmetrics.functional.classification.hinge import hinge # noqa: F401
from torchmetrics.functional.classification.jaccard import jaccard_index # noqa: F401
110 changes: 110 additions & 0 deletions torchmetrics/functional/classification/f_beta.py
@@ -349,3 +349,113 @@ def f1_score(
tensor(0.3333)
"""
return fbeta(preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass)


def f1(
preds: Tensor,
target: Tensor,
beta: float = 1.0,
average: str = "micro",
mdmc_average: Optional[str] = None,
ignore_index: Optional[int] = None,
num_classes: Optional[int] = None,
threshold: float = 0.5,
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
) -> Tensor:
"""Computes F1 metric. F1 metrics correspond to a equally weighted average of the precision and recall scores.
Works with binary, multiclass, and multilabel data.
Accepts probabilities or logits from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument
to convert into integer labels. This is the case for binary and multi-label probabilities or logits.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
The reduction method (how the precision scores are aggregated) is controlled by the
``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
multi-dimensional multi-class case. Accepts all inputs listed in :ref:`references/modules:input types`.
Args:
preds: Predictions from model (probabilities, logits or labels)
target: Ground truth values
average:
Defines the reduction that is applied. Should be one of the following:
- ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
- ``'macro'``: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.
- ``'samples'``: Calculate the metric for each sample, and average the metrics
across samples (with equal weights for each sample).
.. note:: What is considered a sample in the multi-dimensional multi-class case
depends on the value of ``mdmc_average``.
            .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``,
                the value for the class will be ``nan``.
mdmc_average:
Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
``average`` parameter). Should be one of the following:
- ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
multi-class.
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`references/modules:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
ignore_index:
Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
or ``'none'``, the score for the ignored class will be returned as ``nan``.
num_classes:
Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
threshold:
Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
top_k:
Number of highest probability or logit score predictions considered to find the correct label,
relevant only for (multi-dimensional) multi-class inputs. The
default value (``None``) will be interpreted as 1 for these inputs.
Should be left at default (``None``) for all other types of inputs.
multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <references/modules:using the multiclass parameter>`
for a more detailed explanation and examples.
    Return:
        The shape of the returned tensor depends on the ``average`` parameter

        - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned
        - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
          of classes
    .. deprecated:: v0.7
        Use :func:`torchmetrics.functional.f1_score`. Will be removed in v0.8.

    Example:
>>> from torchmetrics.functional import f1
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f1(preds, target, num_classes=3)
tensor(0.3333)
"""
return f1_score(preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass)
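
Since the deprecated wrapper simply forwards to `f1_score`, both entry points should agree exactly; a quick equivalence sketch, assuming v0.7:

```python
import torch
from torchmetrics.functional import f1, f1_score

preds = torch.tensor([0, 2, 1, 0, 0, 1])
target = torch.tensor([0, 1, 2, 0, 1, 2])

# The deprecated wrapper and its replacement return identical tensors.
assert torch.equal(
    f1(preds, target, num_classes=3),
    f1_score(preds, target, num_classes=3),
)
```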
