Merged
15 changes: 11 additions & 4 deletions ignite/metrics/accumulation.py
@@ -1,5 +1,5 @@
 import numbers
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union

 import torch

@@ -38,7 +38,10 @@ class VariableAccumulation(Metric):
     _required_output_keys = None

     def __init__(
-        self, op: Callable, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None
+        self,
+        op: Callable,
+        output_transform: Callable = lambda x: x,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         if not callable(op):
             raise TypeError("Argument op should be a callable, but given {}".format(type(op)))
@@ -115,7 +118,9 @@ class Average(VariableAccumulation):

     """

-    def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None):
+    def __init__(
+        self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
+    ):
         def _mean_op(a, x):
             if isinstance(x, torch.Tensor) and x.ndim > 1:
                 x = x.sum(dim=0)
@@ -159,7 +164,9 @@ class GeometricAverage(VariableAccumulation):

     """

-    def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None):
+    def __init__(
+        self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
+    ):
         def _geom_op(a: torch.Tensor, x: Union[Any, numbers.Number, torch.Tensor]) -> torch.Tensor:
             if not isinstance(x, torch.Tensor):
                 x = torch.tensor(x)
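A minimal usage sketch of the new signature (assuming this branch of ignite is installed): the accumulator now lives on an explicit device, defaulting to CPU, instead of being gated on the old device=None sentinel.

import torch

from ignite.metrics import Average

# CPU storage is now the default; a GPU can be requested explicitly.
avg = Average(device="cuda" if torch.cuda.is_available() else "cpu")

avg.update(torch.tensor(1.0))
avg.update(torch.tensor(2.0))
print(avg.compute())  # mean of the two updates, kept on the chosen device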
13 changes: 7 additions & 6 deletions ignite/metrics/accuracy.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Sequence, Union

 import torch

@@ -13,7 +13,7 @@ def __init__(
         self,
         output_transform: Callable = lambda x: x,
         is_multilabel: bool = False,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         self._is_multilabel = is_multilabel
         self._type = None
@@ -130,15 +130,15 @@ def __init__(
         self,
         output_transform: Callable = lambda x: x,
         is_multilabel: bool = False,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         self._num_correct = None
         self._num_examples = None
         super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device)

     @reinit__is_reduced
     def reset(self) -> None:
-        self._num_correct = 0
+        self._num_correct = torch.tensor(0, device=self._device)
         self._num_examples = 0
         super(Accuracy, self).reset()

@@ -161,11 +161,12 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
             y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
             correct = torch.all(y == y_pred.type_as(y), dim=-1)

-        self._num_correct += torch.sum(correct).item()
+        # Don't need to detach here because torch.eq is not differentiable, so the computation graph is detached anyway.
+        self._num_correct += torch.sum(correct).to(self._device)
         self._num_examples += correct.shape[0]

     @sync_all_reduce("_num_examples", "_num_correct")
     def compute(self) -> torch.Tensor:
         if self._num_examples == 0:
             raise NotComputableError("Accuracy must have at least one example before it can be computed.")
-        return self._num_correct / self._num_examples
+        return self._num_correct.item() / self._num_examples
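The pattern above (tensor state on self._device, with .item() deferred to compute()) is what allows @sync_all_reduce to all-reduce the raw counts across processes before the final division. A distilled, standalone sketch of that pattern, where ToyAccuracy and its methods are illustrative and not part of ignite's API:

import torch


class ToyAccuracy:
    def __init__(self, device: torch.device = torch.device("cpu")):
        self._device = device
        self.reset()

    def reset(self) -> None:
        # Tensor state, not a Python int, so a distributed backend
        # can all-reduce it in place before compute() runs.
        self._num_correct = torch.tensor(0, device=self._device)
        self._num_examples = 0

    def update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
        correct = y_pred == y
        # No detach() needed: torch.eq is not differentiable anyway.
        self._num_correct += torch.sum(correct).to(self._device)
        self._num_examples += correct.shape[0]

    def compute(self) -> float:
        if self._num_examples == 0:
            raise ValueError("at least one example is required")
        return self._num_correct.item() / self._num_examples


metric = ToyAccuracy()
metric.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))
print(metric.compute())  # 0.666...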
2 changes: 1 addition & 1 deletion ignite/metrics/confusion_matrix.py
@@ -44,7 +44,7 @@ def __init__(
         num_classes: int,
         average: Optional[str] = None,
         output_transform: Callable = lambda x: x,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         if average is not None and average not in ("samples", "recall", "precision"):
             raise ValueError("Argument average can None or one of 'samples', 'recall', 'precision'")
2 changes: 1 addition & 1 deletion ignite/metrics/fbeta.py
@@ -15,7 +15,7 @@ def Fbeta(
     precision: Optional[Precision] = None,
     recall: Optional[Recall] = None,
     output_transform: Optional[Callable] = None,
-    device: Optional[Union[str, torch.device]] = None,
+    device: Union[str, torch.device] = torch.device("cpu"),
 ) -> MetricsLambda:
     """Calculates F-beta score

6 changes: 5 additions & 1 deletion ignite/metrics/frequency.py
@@ -1,3 +1,5 @@
+from typing import Callable, Optional, Union
+
 import torch

 import ignite.distributed as idist
@@ -35,7 +37,9 @@ class Frequency(Metric):
     # Epoch [2/10]: [50/100] 50%|█████ , wps=400 [00:17<00:35]
     """

-    def __init__(self, output_transform=lambda x: x, device=None):
+    def __init__(
+        self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
+    ):
         self._timer = None
         self._acc = None
         self._n = None
10 changes: 5 additions & 5 deletions ignite/metrics/loss.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Sequence, Union

 import torch

@@ -37,15 +37,15 @@ def __init__(
         loss_fn: Callable,
         output_transform: Callable = lambda x: x,
         batch_size: Callable = lambda x: len(x),
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         super(Loss, self).__init__(output_transform, device=device)
         self._loss_fn = loss_fn
         self._batch_size = batch_size

     @reinit__is_reduced
     def reset(self) -> None:
-        self._sum = 0
+        self._sum = torch.tensor(0.0, device=self._device)
         self._num_examples = 0

     @reinit__is_reduced
@@ -61,11 +61,11 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None:
             raise ValueError("loss_fn did not return the average loss.")

         n = self._batch_size(y)
-        self._sum += average_loss.item() * n
+        self._sum += average_loss.detach().to(self._device) * n
         self._num_examples += n

     @sync_all_reduce("_sum", "_num_examples")
     def compute(self) -> None:
         if self._num_examples == 0:
             raise NotComputableError("Loss must have at least one example before it can be computed.")
-        return self._sum / self._num_examples
+        return self._sum.item() / self._num_examples
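Two details in this hunk are worth spelling out: the sum stays a tensor so @sync_all_reduce can reduce it, and detach() keeps the accumulator from retaining each batch's autograd graph (which .item() used to avoid by converting to a float). A minimal sketch of the accumulation loop under those assumptions, outside ignite:

import torch

device = torch.device("cpu")
running_sum = torch.tensor(0.0, device=device)
num_examples = 0

for _ in range(3):
    x = torch.randn(4, requires_grad=True)
    average_loss = (x ** 2).mean()  # stand-in for loss_fn's averaged output
    n = 4                           # stand-in for batch_size(y)
    # detach() drops the autograd history before it is accumulated
    running_sum += average_loss.detach().to(device) * n
    num_examples += n

print(running_sum.item() / num_examples)  # same division as compute()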
6 changes: 3 additions & 3 deletions ignite/metrics/mean_absolute_error.py
@@ -17,18 +17,18 @@ class MeanAbsoluteError(Metric):

     @reinit__is_reduced
     def reset(self) -> None:
-        self._sum_of_absolute_errors = 0.0
+        self._sum_of_absolute_errors = torch.tensor(0.0, device=self._device)
         self._num_examples = 0

     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
         y_pred, y = output
         absolute_errors = torch.abs(y_pred - y.view_as(y_pred))
-        self._sum_of_absolute_errors += torch.sum(absolute_errors).item()
+        self._sum_of_absolute_errors += torch.sum(absolute_errors).detach().to(self._device)
         self._num_examples += y.shape[0]

     @sync_all_reduce("_sum_of_absolute_errors", "_num_examples")
     def compute(self) -> Union[float, torch.Tensor]:
         if self._num_examples == 0:
             raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.")
-        return self._sum_of_absolute_errors / self._num_examples
+        return self._sum_of_absolute_errors.item() / self._num_examples
10 changes: 5 additions & 5 deletions ignite/metrics/mean_pairwise_distance.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Sequence, Union

 import torch
 from torch.nn.functional import pairwise_distance
@@ -21,26 +21,26 @@ def __init__(
         p: int = 2,
         eps: float = 1e-6,
         output_transform: Callable = lambda x: x,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         super(MeanPairwiseDistance, self).__init__(output_transform, device=device)
         self._p = p
         self._eps = eps

     @reinit__is_reduced
     def reset(self):
-        self._sum_of_distances = 0.0
+        self._sum_of_distances = torch.tensor(0.0, device=self._device)
         self._num_examples = 0

     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
         y_pred, y = output
         distances = pairwise_distance(y_pred, y, p=self._p, eps=self._eps)
-        self._sum_of_distances += torch.sum(distances).item()
+        self._sum_of_distances += torch.sum(distances).detach().to(self._device)
         self._num_examples += y.shape[0]

     @sync_all_reduce("_sum_of_distances", "_num_examples")
     def compute(self) -> Union[float, torch.Tensor]:
         if self._num_examples == 0:
             raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.")
-        return self._sum_of_distances / self._num_examples
+        return self._sum_of_distances.item() / self._num_examples
6 changes: 3 additions & 3 deletions ignite/metrics/mean_squared_error.py
@@ -17,18 +17,18 @@ class MeanSquaredError(Metric):

     @reinit__is_reduced
     def reset(self) -> None:
-        self._sum_of_squared_errors = 0.0
+        self._sum_of_squared_errors = torch.tensor(0.0, device=self._device)
         self._num_examples = 0

     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
         y_pred, y = output
         squared_errors = torch.pow(y_pred - y.view_as(y_pred), 2)
-        self._sum_of_squared_errors += torch.sum(squared_errors).item()
+        self._sum_of_squared_errors += torch.sum(squared_errors).detach().to(self._device)
         self._num_examples += y.shape[0]

     @sync_all_reduce("_sum_of_squared_errors", "_num_examples")
     def compute(self) -> Union[float, torch.Tensor]:
         if self._num_examples == 0:
             raise NotComputableError("MeanSquaredError must have at least one example before it can be computed.")
-        return self._sum_of_squared_errors / self._num_examples
+        return self._sum_of_squared_errors.item() / self._num_examples
4 changes: 2 additions & 2 deletions ignite/metrics/metric.py
@@ -2,7 +2,7 @@
 from abc import ABCMeta, abstractmethod
 from collections.abc import Mapping
 from functools import wraps
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union

 import torch

@@ -129,7 +129,7 @@ class Metric(metaclass=ABCMeta):
     _required_output_keys = ("y_pred", "y")

     def __init__(
-        self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None,
+        self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"),
     ):
         self._output_transform = output_transform

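The base class keeps accepting either a string or a torch.device; only the default changes from None to torch.device("cpu"). A hypothetical helper shows the equivalence (as_device is illustrative; presumably the constructor normalizes with torch.device internally):

from typing import Union

import torch


def as_device(device: Union[str, torch.device] = torch.device("cpu")) -> torch.device:
    # torch.device() accepts both strings and existing torch.device objects.
    return torch.device(device)


print(as_device())          # device(type='cpu')
print(as_device("cuda:0"))  # device(type='cuda', index=0)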
32 changes: 19 additions & 13 deletions ignite/metrics/precision.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Sequence, Union

 import torch

@@ -18,7 +18,7 @@ def __init__(
         output_transform: Callable = lambda x: x,
         average: bool = False,
         is_multilabel: bool = False,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         if idist.get_world_size() > 1:
             if (not average) and is_multilabel:
@@ -39,15 +39,22 @@ def __init__(

     @reinit__is_reduced
     def reset(self) -> None:
-        dtype = torch.float64
-        self._true_positives = torch.tensor([], dtype=dtype) if (self._is_multilabel and not self._average) else 0
-        self._positives = torch.tensor([], dtype=dtype) if (self._is_multilabel and not self._average) else 0
+        if self._is_multilabel:
+            init_value = 0.0 if self._average else []
+            kws = {"dtype": torch.float64, "device": self._device}
+            self._true_positives = torch.tensor(init_value, **kws)
+            self._positives = torch.tensor(init_value, **kws)
+        else:
+            self._true_positives = 0
+            self._positives = 0

         super(_BasePrecisionRecall, self).reset()

     def compute(self) -> Union[torch.Tensor, float]:
-        if not (isinstance(self._positives, torch.Tensor) or self._positives > 0):
+        is_scalar = not isinstance(self._positives, torch.Tensor) or self._positives.ndim == 0
+        if is_scalar and self._positives == 0:
             raise NotComputableError(
-                "{} must have at least one example before" " it can be computed.".format(self.__class__.__name__)
+                "{} must have at least one example before it can be computed.".format(self.__class__.__name__)
             )

         if not (self._type == "multilabel" and not self._average):
@@ -124,7 +131,7 @@ def __init__(
         output_transform: Callable = lambda x: x,
         average: bool = False,
         is_multilabel: bool = False,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         super(Precision, self).__init__(
             output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device
@@ -155,17 +162,16 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
             y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
             y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

-        y = y.to(y_pred)
+        # Convert from int cuda/cpu to double on self._device
+        y_pred = y_pred.to(dtype=torch.float64, device=self._device)
+        y = y.to(dtype=torch.float64, device=self._device)
         correct = y * y_pred
-        all_positives = y_pred.sum(dim=0).type(torch.DoubleTensor)  # Convert from int cuda/cpu to double cpu
+        all_positives = y_pred.sum(dim=0)

         if correct.sum() == 0:
             true_positives = torch.zeros_like(all_positives)
         else:
             true_positives = correct.sum(dim=0)
-        # Convert from int cuda/cpu to double cpu
-        # We need double precision for the division true_positives / all_positives
-        true_positives = true_positives.type(torch.DoubleTensor)

         if self._type == "multilabel":
             if not self._average:
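The reworked reset() chooses its initial state per mode: multilabel without averaging needs an empty float64 tensor that per-update results can be concatenated onto with torch.cat, multilabel with averaging needs a float64 scalar, and the remaining modes keep plain Python counters. A standalone sketch of that branching (initial_state is an illustrative name, not ignite code):

import torch


def initial_state(is_multilabel: bool, average: bool, device: torch.device):
    if is_multilabel:
        init_value = 0.0 if average else []
        kws = {"dtype": torch.float64, "device": device}
        return torch.tensor(init_value, **kws), torch.tensor(init_value, **kws)
    return 0, 0


tp, pos = initial_state(True, False, torch.device("cpu"))
print(tp.shape)  # torch.Size([0]): empty, ready for torch.cat in update()
tp, pos = initial_state(False, False, torch.device("cpu"))
print(tp)        # 0: a plain scalar accumulator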
14 changes: 6 additions & 8 deletions ignite/metrics/recall.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Sequence, Union

 import torch

@@ -69,7 +69,7 @@ def __init__(
         output_transform: Callable = lambda x: x,
         average: bool = False,
         is_multilabel: bool = False,
-        device: Optional[Union[str, torch.device]] = None,
+        device: Union[str, torch.device] = torch.device("cpu"),
     ):
         super(Recall, self).__init__(
             output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device
@@ -100,19 +100,17 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
             y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
             y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

-        y = y.type_as(y_pred)
+        # Convert from int cuda/cpu to double on self._device
+        y_pred = y_pred.to(dtype=torch.float64, device=self._device)
+        y = y.to(dtype=torch.float64, device=self._device)
         correct = y * y_pred
-        actual_positives = y.sum(dim=0).type(torch.DoubleTensor)  # Convert from int cuda/cpu to double cpu
+        actual_positives = y.sum(dim=0)

         if correct.sum() == 0:
             true_positives = torch.zeros_like(actual_positives)
         else:
             true_positives = correct.sum(dim=0)

-        # Convert from int cuda/cpu to double cpu
-        # We need double precision for the division true_positives / actual_positives
-        true_positives = true_positives.type(torch.DoubleTensor)
-
         if self._type == "multilabel":
             if not self._average:
                 self._true_positives = torch.cat([self._true_positives, true_positives], dim=0)