Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,21 @@ specific condition (e.g. ignore user-defined classes):

class CustomAccuracy(Metric):

def __init__(self, ignored_class, output_transform=lambda x: x):
def __init__(self, ignored_class, output_transform=lambda x: x, device="cpu"):
self.ignored_class = ignored_class
self._num_correct = None
self._num_examples = None
super(CustomAccuracy, self).__init__(output_transform=output_transform)
super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
def reset(self):
self._num_correct = 0
self._num_correct = torch.tensor(0, device=self._device)
self._num_examples = 0
super(CustomAccuracy, self).reset()

@reinit__is_reduced
def update(self, output):
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()

indices = torch.argmax(y_pred, dim=1)

Expand All @@ -144,21 +144,25 @@ specific condition (e.g. ignore user-defined classes):
indices = indices[mask]
correct = torch.eq(indices, y).view(-1)

self._num_correct += torch.sum(correct).item()
self._num_correct += torch.sum(correct).to(self._device)
self._num_examples += correct.shape[0]

@sync_all_reduce("_num_examples", "_num_correct")
def compute(self):
if self._num_examples == 0:
raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.')
return self._num_correct / self._num_examples
return self._num_correct.item() / self._num_examples


We imported necessary classes as :class:`~ignite.metrics.Metric`, :class:`~ignite.exceptions.NotComputableError` and
decorators to adapt the metric for distributed setting. In ``reset`` method, we reset internal variables ``_num_correct``
and ``_num_examples`` which are used to compute the custom metric. In ``updated`` method we define how to update
the internal variables. And finally in ``compute`` method, we compute metric value.

Notice that ``_num_correct`` is a tensor, since in ``update`` we accumulate tensor values. ``_num_examples`` is a python
scalar since we accumulate normal integers. For differentiable metrics, you must detach the accumulated values before
adding them to the internal variables.

We can check this implementation in a simple case:

.. code-block:: python
Expand Down
35 changes: 24 additions & 11 deletions ignite/metrics/accumulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numbers
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Union

import torch

Expand Down Expand Up @@ -31,14 +31,19 @@ class VariableAccumulation(Metric):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

"""

_required_output_keys = None

def __init__(
self, op: Callable, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None
self,
op: Callable,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
):
if not callable(op):
raise TypeError("Argument op should be a callable, but given {}".format(type(op)))
Expand All @@ -61,12 +66,13 @@ def _check_output_type(self, output: Union[Any, torch.Tensor, numbers.Number]) -
def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
self._check_output_type(output)

if self._device is not None:
# Put output to the metric's device
if isinstance(output, torch.Tensor) and (output.device != self._device):
if isinstance(output, torch.Tensor):
output = output.detach()
if output.device != self._device:
output = output.to(self._device)

self.accumulator = self._op(self.accumulator, output)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we detach output if it is a tensor here ? To have something like

if self._device is not None:
    # Put output to the metric's device
    if isinstance(output, torch.Tensor):
        output = output.detach()
        if (output.device != self._device):
            output = output.to(self._device)

self.accumulator = self._op(self.accumulator, output)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should, thanks for catching this.

if hasattr(output, "shape"):
self.num_examples += output.shape[0] if len(output.shape) > 1 else 1
else:
Expand Down Expand Up @@ -111,11 +117,14 @@ class Average(VariableAccumulation):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.

device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
"""

def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None):
def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
def _mean_op(a, x):
if isinstance(x, torch.Tensor) and x.ndim > 1:
x = x.sum(dim=0)
Expand Down Expand Up @@ -155,11 +164,15 @@ class GeometricAverage(VariableAccumulation):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

"""

def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None):
def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
def _geom_op(a: torch.Tensor, x: Union[Any, numbers.Number, torch.Tensor]) -> torch.Tensor:
if not isinstance(x, torch.Tensor):
x = torch.tensor(x)
Expand Down
22 changes: 12 additions & 10 deletions ignite/metrics/accuracy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Sequence, Union

import torch

Expand All @@ -13,7 +13,7 @@ def __init__(
self,
output_transform: Callable = lambda x: x,
is_multilabel: bool = False,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
):
self._is_multilabel = is_multilabel
self._type = None
Expand Down Expand Up @@ -122,31 +122,33 @@ def thresholded_output_transform(output):
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
is_multilabel (bool, optional): flag to use in multilabel case. By default, False.
device (str of torch.device, optional): unused argument.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

"""

def __init__(
self,
output_transform: Callable = lambda x: x,
is_multilabel: bool = False,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
):
self._num_correct = None
self._num_examples = None
super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device)

@reinit__is_reduced
def reset(self) -> None:
self._num_correct = 0
self._num_correct = torch.tensor(0, device=self._device)
self._num_examples = 0
super(Accuracy, self).reset()

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
self._check_shape((y_pred, y))
self._check_type((y_pred, y))
self._check_shape(output)
self._check_type(output)
y_pred, y = output[0].detach(), output[1].detach()

if self._type == "binary":
correct = torch.eq(y_pred.view(-1).to(y), y.view(-1))
Expand All @@ -161,11 +163,11 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
correct = torch.all(y == y_pred.type_as(y), dim=-1)

self._num_correct += torch.sum(correct).item()
self._num_correct += torch.sum(correct).to(self._device)
self._num_examples += correct.shape[0]

@sync_all_reduce("_num_examples", "_num_correct")
def compute(self) -> torch.Tensor:
if self._num_examples == 0:
raise NotComputableError("Accuracy must have at least one example before it can be computed.")
return self._num_correct / self._num_examples
return self._num_correct.item() / self._num_examples
10 changes: 6 additions & 4 deletions ignite/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class ConfusionMatrix(Metric):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

Note:
In case of the targets `y` in `(batch_size, ...)` format, target indices between 0 and `num_classes` only
Expand All @@ -44,7 +46,7 @@ def __init__(
num_classes: int,
average: Optional[str] = None,
output_transform: Callable = lambda x: x,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
):
if average is not None and average not in ("samples", "recall", "precision"):
raise ValueError("Argument average can None or one of 'samples', 'recall', 'precision'")
Expand All @@ -61,7 +63,7 @@ def reset(self) -> None:
self._num_examples = 0

def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()

if y_pred.ndimension() < 2:
raise ValueError(
Expand Down Expand Up @@ -92,7 +94,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
self._check_shape(output)
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()

self._num_examples += y_pred.shape[0]

Expand Down
6 changes: 4 additions & 2 deletions ignite/metrics/fbeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def Fbeta(
precision: Optional[Precision] = None,
recall: Optional[Recall] = None,
output_transform: Optional[Callable] = None,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
) -> MetricsLambda:
"""Calculates F-beta score

Expand All @@ -28,7 +28,9 @@ def Fbeta(
output_transform (callable, optional): a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. It is used only if precision or recall are not provided.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

Returns:
MetricsLambda, F-beta metric
Expand Down
6 changes: 5 additions & 1 deletion ignite/metrics/frequency.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable, Optional, Union

import torch

import ignite.distributed as idist
Expand Down Expand Up @@ -35,7 +37,9 @@ class Frequency(Metric):
# Epoch [2/10]: [50/100] 50%|█████ , wps=400 [00:17<00:35]
"""

def __init__(self, output_transform=lambda x: x, device=None):
def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
self._timer = None
self._acc = None
self._n = None
Expand Down
16 changes: 9 additions & 7 deletions ignite/metrics/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Sequence, Union

import torch

Expand Down Expand Up @@ -26,7 +26,9 @@ class Loss(Metric):
keywords arguments. If extra keywords arguments are provided they are passed to `loss_fn`.
batch_size (callable): a callable taking a target tensor that returns the
first dimension size (usually the batch size).
device (str of torch.device, optional): unused argument.
device (str or torch.device): specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

"""

Expand All @@ -37,15 +39,15 @@ def __init__(
loss_fn: Callable,
output_transform: Callable = lambda x: x,
batch_size: Callable = lambda x: len(x),
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
):
super(Loss, self).__init__(output_transform, device=device)
self._loss_fn = loss_fn
self._batch_size = batch_size

@reinit__is_reduced
def reset(self) -> None:
self._sum = 0
self._sum = torch.tensor(0.0, device=self._device)
self._num_examples = 0

@reinit__is_reduced
Expand All @@ -55,17 +57,17 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None:
kwargs = {}
else:
y_pred, y, kwargs = output
average_loss = self._loss_fn(y_pred, y, **kwargs)
average_loss = self._loss_fn(y_pred.detach(), y.detach(), **kwargs)

if len(average_loss.shape) != 0:
raise ValueError("loss_fn did not return the average loss.")

n = self._batch_size(y)
self._sum += average_loss.item() * n
self._sum += average_loss.to(self._device) * n
self._num_examples += n

@sync_all_reduce("_sum", "_num_examples")
def compute(self) -> None:
if self._num_examples == 0:
raise NotComputableError("Loss must have at least one example before it can be computed.")
return self._sum / self._num_examples
return self._sum.item() / self._num_examples
8 changes: 4 additions & 4 deletions ignite/metrics/mean_absolute_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@ class MeanAbsoluteError(Metric):

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_absolute_errors = 0.0
self._sum_of_absolute_errors = torch.tensor(0.0, device=self._device)
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()
absolute_errors = torch.abs(y_pred - y.view_as(y_pred))
self._sum_of_absolute_errors += torch.sum(absolute_errors).item()
self._sum_of_absolute_errors += torch.sum(absolute_errors).to(self._device)
self._num_examples += y.shape[0]

@sync_all_reduce("_sum_of_absolute_errors", "_num_examples")
def compute(self) -> Union[float, torch.Tensor]:
if self._num_examples == 0:
raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.")
return self._sum_of_absolute_errors / self._num_examples
return self._sum_of_absolute_errors.item() / self._num_examples
12 changes: 6 additions & 6 deletions ignite/metrics/mean_pairwise_distance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Sequence, Union

import torch
from torch.nn.functional import pairwise_distance
Expand All @@ -21,26 +21,26 @@ def __init__(
p: int = 2,
eps: float = 1e-6,
output_transform: Callable = lambda x: x,
device: Optional[Union[str, torch.device]] = None,
device: Union[str, torch.device] = torch.device("cpu"),
):
super(MeanPairwiseDistance, self).__init__(output_transform, device=device)
self._p = p
self._eps = eps

@reinit__is_reduced
def reset(self):
self._sum_of_distances = 0.0
self._sum_of_distances = torch.tensor(0.0, device=self._device)
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()
distances = pairwise_distance(y_pred, y, p=self._p, eps=self._eps)
self._sum_of_distances += torch.sum(distances).item()
self._sum_of_distances += torch.sum(distances).to(self._device)
self._num_examples += y.shape[0]

@sync_all_reduce("_sum_of_distances", "_num_examples")
def compute(self) -> Union[float, torch.Tensor]:
if self._num_examples == 0:
raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.")
return self._sum_of_distances / self._num_examples
return self._sum_of_distances.item() / self._num_examples
Loading