diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index bc2da2e12cba..46fe02aa9f11 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -120,21 +120,21 @@ specific condition (e.g. ignore user-defined classes): class CustomAccuracy(Metric): - def __init__(self, ignored_class, output_transform=lambda x: x): + def __init__(self, ignored_class, output_transform=lambda x: x, device="cpu"): self.ignored_class = ignored_class self._num_correct = None self._num_examples = None - super(CustomAccuracy, self).__init__(output_transform=output_transform) + super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device) @reinit__is_reduced def reset(self): - self._num_correct = 0 + self._num_correct = torch.tensor(0, device=self._device) self._num_examples = 0 super(CustomAccuracy, self).reset() @reinit__is_reduced def update(self, output): - y_pred, y = output + y_pred, y = output[0].detach(), output[1].detach() indices = torch.argmax(y_pred, dim=1) @@ -144,14 +144,14 @@ specific condition (e.g. ignore user-defined classes): indices = indices[mask] correct = torch.eq(indices, y).view(-1) - self._num_correct += torch.sum(correct).item() + self._num_correct += torch.sum(correct).to(self._device) self._num_examples += correct.shape[0] @sync_all_reduce("_num_examples", "_num_correct") def compute(self): if self._num_examples == 0: raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.') - return self._num_correct / self._num_examples + return self._num_correct.item() / self._num_examples We imported necessary classes as :class:`~ignite.metrics.Metric`, :class:`~ignite.exceptions.NotComputableError` and @@ -159,6 +159,10 @@ decorators to adapt the metric for distributed setting. In ``reset`` method, we and ``_num_examples`` which are used to compute the custom metric. In ``updated`` method we define how to update the internal variables. And finally in ``compute`` method, we compute metric value. +Notice that ``_num_correct`` is a tensor, since in ``update`` we accumulate tensor values. ``_num_examples`` is a python +scalar since we accumulate normal integers. For differentiable metrics, you must detach the accumulated values before +adding them to the internal variables. + We can check this implementation in a simple case: .. code-block:: python diff --git a/ignite/metrics/accumulation.py b/ignite/metrics/accumulation.py index dff45ee87fcc..926e7816bae2 100644 --- a/ignite/metrics/accumulation.py +++ b/ignite/metrics/accumulation.py @@ -1,5 +1,5 @@ import numbers -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Union import torch @@ -31,14 +31,19 @@ class VariableAccumulation(Metric): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - device (str of torch.device, optional): optional device specification for internal storage. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ _required_output_keys = None def __init__( - self, op: Callable, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None + self, + op: Callable, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), ): if not callable(op): raise TypeError("Argument op should be a callable, but given {}".format(type(op))) @@ -61,12 +66,13 @@ def _check_output_type(self, output: Union[Any, torch.Tensor, numbers.Number]) - def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None: self._check_output_type(output) - if self._device is not None: - # Put output to the metric's device - if isinstance(output, torch.Tensor) and (output.device != self._device): + if isinstance(output, torch.Tensor): + output = output.detach() + if output.device != self._device: output = output.to(self._device) self.accumulator = self._op(self.accumulator, output) + if hasattr(output, "shape"): self.num_examples += output.shape[0] if len(output.shape) > 1 else 1 else: @@ -111,11 +117,14 @@ class Average(VariableAccumulation): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - device (str of torch.device, optional): optional device specification for internal storage. - + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ - def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None): + def __init__( + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") + ): def _mean_op(a, x): if isinstance(x, torch.Tensor) and x.ndim > 1: x = x.sum(dim=0) @@ -155,11 +164,15 @@ class GeometricAverage(VariableAccumulation): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - device (str of torch.device, optional): optional device specification for internal storage. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ - def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None): + def __init__( + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") + ): def _geom_op(a: torch.Tensor, x: Union[Any, numbers.Number, torch.Tensor]) -> torch.Tensor: if not isinstance(x, torch.Tensor): x = torch.tensor(x) diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index 8ac5a25f083f..7d6c939e4b53 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch @@ -13,7 +13,7 @@ def __init__( self, output_transform: Callable = lambda x: x, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = None, + device: Union[str, torch.device] = torch.device("cpu"), ): self._is_multilabel = is_multilabel self._type = None @@ -122,7 +122,9 @@ def thresholded_output_transform(output): form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. is_multilabel (bool, optional): flag to use in multilabel case. By default, False. - device (str of torch.device, optional): unused argument. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ @@ -130,7 +132,7 @@ def __init__( self, output_transform: Callable = lambda x: x, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = None, + device: Union[str, torch.device] = torch.device("cpu"), ): self._num_correct = None self._num_examples = None @@ -138,15 +140,15 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._num_correct = 0 + self._num_correct = torch.tensor(0, device=self._device) self._num_examples = 0 super(Accuracy, self).reset() @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: - y_pred, y = output - self._check_shape((y_pred, y)) - self._check_type((y_pred, y)) + self._check_shape(output) + self._check_type(output) + y_pred, y = output[0].detach(), output[1].detach() if self._type == "binary": correct = torch.eq(y_pred.view(-1).to(y), y.view(-1)) @@ -161,11 +163,11 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes) correct = torch.all(y == y_pred.type_as(y), dim=-1) - self._num_correct += torch.sum(correct).item() + self._num_correct += torch.sum(correct).to(self._device) self._num_examples += correct.shape[0] @sync_all_reduce("_num_examples", "_num_correct") def compute(self) -> torch.Tensor: if self._num_examples == 0: raise NotComputableError("Accuracy must have at least one example before it can be computed.") - return self._num_correct / self._num_examples + return self._num_correct.item() / self._num_examples diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py index 2ab1f436bace..3c797efaf3e4 100644 --- a/ignite/metrics/confusion_matrix.py +++ b/ignite/metrics/confusion_matrix.py @@ -30,7 +30,9 @@ class ConfusionMatrix(Metric): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - device (str of torch.device, optional): optional device specification for internal storage. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. Note: In case of the targets `y` in `(batch_size, ...)` format, target indices between 0 and `num_classes` only @@ -44,7 +46,7 @@ def __init__( num_classes: int, average: Optional[str] = None, output_transform: Callable = lambda x: x, - device: Optional[Union[str, torch.device]] = None, + device: Union[str, torch.device] = torch.device("cpu"), ): if average is not None and average not in ("samples", "recall", "precision"): raise ValueError("Argument average can None or one of 'samples', 'recall', 'precision'") @@ -61,7 +63,7 @@ def reset(self) -> None: self._num_examples = 0 def _check_shape(self, output: Sequence[torch.Tensor]) -> None: - y_pred, y = output + y_pred, y = output[0].detach(), output[1].detach() if y_pred.ndimension() < 2: raise ValueError( @@ -92,7 +94,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None: @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: self._check_shape(output) - y_pred, y = output + y_pred, y = output[0].detach(), output[1].detach() self._num_examples += y_pred.shape[0] diff --git a/ignite/metrics/fbeta.py b/ignite/metrics/fbeta.py index 05e217846115..eb6776a17eb8 100644 --- a/ignite/metrics/fbeta.py +++ b/ignite/metrics/fbeta.py @@ -15,7 +15,7 @@ def Fbeta( precision: Optional[Precision] = None, recall: Optional[Recall] = None, output_transform: Optional[Callable] = None, - device: Optional[Union[str, torch.device]] = None, + device: Union[str, torch.device] = torch.device("cpu"), ) -> MetricsLambda: """Calculates F-beta score @@ -28,7 +28,9 @@ def Fbeta( output_transform (callable, optional): a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. It is used only if precision or recall are not provided. - device (str of torch.device, optional): optional device specification for internal storage. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. Returns: MetricsLambda, F-beta metric diff --git a/ignite/metrics/frequency.py b/ignite/metrics/frequency.py index 75eba360bb53..447cbbf63fd8 100644 --- a/ignite/metrics/frequency.py +++ b/ignite/metrics/frequency.py @@ -1,3 +1,5 @@ +from typing import Callable, Optional, Union + import torch import ignite.distributed as idist @@ -35,7 +37,9 @@ class Frequency(Metric): # Epoch [2/10]: [50/100] 50%|█████ , wps=400 [00:17<00:35] """ - def __init__(self, output_transform=lambda x: x, device=None): + def __init__( + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") + ): self._timer = None self._acc = None self._n = None diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index 5a4133c84d95..c667cf3ade52 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch @@ -26,7 +26,9 @@ class Loss(Metric): keywords arguments. If extra keywords arguments are provided they are passed to `loss_fn`. batch_size (callable): a callable taking a target tensor that returns the first dimension size (usually the batch size). - device (str of torch.device, optional): unused argument. + device (str or torch.device): specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. """ @@ -37,7 +39,7 @@ def __init__( loss_fn: Callable, output_transform: Callable = lambda x: x, batch_size: Callable = lambda x: len(x), - device: Optional[Union[str, torch.device]] = None, + device: Union[str, torch.device] = torch.device("cpu"), ): super(Loss, self).__init__(output_transform, device=device) self._loss_fn = loss_fn @@ -45,7 +47,7 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._sum = 0 + self._sum = torch.tensor(0.0, device=self._device) self._num_examples = 0 @reinit__is_reduced @@ -55,17 +57,17 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None: kwargs = {} else: y_pred, y, kwargs = output - average_loss = self._loss_fn(y_pred, y, **kwargs) + average_loss = self._loss_fn(y_pred.detach(), y.detach(), **kwargs) if len(average_loss.shape) != 0: raise ValueError("loss_fn did not return the average loss.") n = self._batch_size(y) - self._sum += average_loss.item() * n + self._sum += average_loss.to(self._device) * n self._num_examples += n @sync_all_reduce("_sum", "_num_examples") def compute(self) -> None: if self._num_examples == 0: raise NotComputableError("Loss must have at least one example before it can be computed.") - return self._sum / self._num_examples + return self._sum.item() / self._num_examples diff --git a/ignite/metrics/mean_absolute_error.py b/ignite/metrics/mean_absolute_error.py index 86e699be096f..1b1b618cbc5a 100644 --- a/ignite/metrics/mean_absolute_error.py +++ b/ignite/metrics/mean_absolute_error.py @@ -17,18 +17,18 @@ class MeanAbsoluteError(Metric): @reinit__is_reduced def reset(self) -> None: - self._sum_of_absolute_errors = 0.0 + self._sum_of_absolute_errors = torch.tensor(0.0, device=self._device) self._num_examples = 0 @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: - y_pred, y = output + y_pred, y = output[0].detach(), output[1].detach() absolute_errors = torch.abs(y_pred - y.view_as(y_pred)) - self._sum_of_absolute_errors += torch.sum(absolute_errors).item() + self._sum_of_absolute_errors += torch.sum(absolute_errors).to(self._device) self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_absolute_errors", "_num_examples") def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.") - return self._sum_of_absolute_errors / self._num_examples + return self._sum_of_absolute_errors.item() / self._num_examples diff --git a/ignite/metrics/mean_pairwise_distance.py b/ignite/metrics/mean_pairwise_distance.py index 9e9239ee6553..0473fc07ff50 100644 --- a/ignite/metrics/mean_pairwise_distance.py +++ b/ignite/metrics/mean_pairwise_distance.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch from torch.nn.functional import pairwise_distance @@ -21,7 +21,7 @@ def __init__( p: int = 2, eps: float = 1e-6, output_transform: Callable = lambda x: x, - device: Optional[Union[str, torch.device]] = None, + device: Union[str, torch.device] = torch.device("cpu"), ): super(MeanPairwiseDistance, self).__init__(output_transform, device=device) self._p = p @@ -29,18 +29,18 @@ def __init__( @reinit__is_reduced def reset(self): - self._sum_of_distances = 0.0 + self._sum_of_distances = torch.tensor(0.0, device=self._device) self._num_examples = 0 @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: - y_pred, y = output + y_pred, y = output[0].detach(), output[1].detach() distances = pairwise_distance(y_pred, y, p=self._p, eps=self._eps) - self._sum_of_distances += torch.sum(distances).item() + self._sum_of_distances += torch.sum(distances).to(self._device) self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_distances", "_num_examples") def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.") - return self._sum_of_distances / self._num_examples + return self._sum_of_distances.item() / self._num_examples diff --git a/ignite/metrics/mean_squared_error.py b/ignite/metrics/mean_squared_error.py index 4c5a9ee3371c..f9addc301c42 100644 --- a/ignite/metrics/mean_squared_error.py +++ b/ignite/metrics/mean_squared_error.py @@ -17,18 +17,18 @@ class MeanSquaredError(Metric): @reinit__is_reduced def reset(self) -> None: - self._sum_of_squared_errors = 0.0 + self._sum_of_squared_errors = torch.tensor(0.0, device=self._device) self._num_examples = 0 @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: - y_pred, y = output + y_pred, y = output[0].detach(), output[1].detach() squared_errors = torch.pow(y_pred - y.view_as(y_pred), 2) - self._sum_of_squared_errors += torch.sum(squared_errors).item() + self._sum_of_squared_errors += torch.sum(squared_errors).to(self._device) self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_squared_errors", "_num_examples") def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError("MeanSquaredError must have at least one example before it can be computed.") - return self._sum_of_squared_errors / self._num_examples + return self._sum_of_squared_errors.item() / self._num_examples diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 2d334e3c79f0..228a89d2d1bf 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod from collections.abc import Mapping from functools import wraps -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Union import torch @@ -122,14 +122,15 @@ class Metric(metaclass=ABCMeta): form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. - device (str or torch.device, optional): optional device specification for internal storage. - + device (str or torch.device): specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. """ _required_output_keys = ("y_pred", "y") def __init__( - self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None, + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ): self._output_transform = output_transform @@ -143,7 +144,12 @@ def __init__( "across all computing devices".format(self.__class__.__name__), RuntimeWarning, ) - self._device = device + + # Some metrics have a large performance regression when run on XLA devices, so for now, we disallow it. + if torch.device(device).type == "xla": + raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.") + + self._device = torch.device(device) self._is_reduced = False self.reset() diff --git a/ignite/metrics/metrics_lambda.py b/ignite/metrics/metrics_lambda.py index 651edca933b8..14348ebf51b1 100644 --- a/ignite/metrics/metrics_lambda.py +++ b/ignite/metrics/metrics_lambda.py @@ -1,6 +1,8 @@ import itertools from typing import Any, Callable, Union +import torch + from ignite.engine import Engine, Events from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced @@ -83,8 +85,8 @@ def update(self, output) -> None: pass def compute(self) -> Any: - materialized = [i.compute() if isinstance(i, Metric) else i for i in self.args] - materialized_kwargs = {k: (v.compute() if isinstance(v, Metric) else v) for k, v in self.kwargs.items()} + materialized = [_get_value_on_cpu(i) for i in self.args] + materialized_kwargs = {k: _get_value_on_cpu(v) for k, v in self.kwargs.items()} return self.function(*materialized, **materialized_kwargs) def _internal_attach(self, engine: Engine, usage: MetricUsage) -> None: @@ -134,3 +136,11 @@ def _internal_is_attached(self, engine: Engine, usage: MetricUsage) -> bool: if not engine.has_event_handler(metric.iteration_completed, usage.ITERATION_COMPLETED): is_detached = True return not is_detached + + +def _get_value_on_cpu(v: Any): + if isinstance(v, Metric): + v = v.compute() + if isinstance(v, torch.Tensor): + v = v.cpu() + return v diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index 2b8152630ddb..f1d417ee015f 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch @@ -18,7 +18,7 @@ def __init__( output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = None, + device: Union[str, torch.device] = torch.device("cpu"), ): if idist.get_world_size() > 1: if (not average) and is_multilabel: @@ -39,15 +39,22 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - dtype = torch.float64 - self._true_positives = torch.tensor([], dtype=dtype) if (self._is_multilabel and not self._average) else 0 - self._positives = torch.tensor([], dtype=dtype) if (self._is_multilabel and not self._average) else 0 + if self._is_multilabel: + init_value = 0.0 if self._average else [] + kws = {"dtype": torch.float64, "device": self._device} + self._true_positives = torch.tensor(init_value, **kws) + self._positives = torch.tensor(init_value, **kws) + else: + self._true_positives = 0 + self._positives = 0 + super(_BasePrecisionRecall, self).reset() def compute(self) -> Union[torch.Tensor, float]: - if not (isinstance(self._positives, torch.Tensor) or self._positives > 0): + is_scalar = not isinstance(self._positives, torch.Tensor) or self._positives.ndim == 0 + if is_scalar and self._positives == 0: raise NotComputableError( - "{} must have at least one example before" " it can be computed.".format(self.__class__.__name__) + "{} must have at least one example before it can be computed.".format(self.__class__.__name__) ) if not (self._type == "multilabel" and not self._average): @@ -115,7 +122,9 @@ def thresholded_output_transform(output): in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case). is_multilabel (bool, optional) flag to use in multilabel case. By default, value is False. If True, average parameter should be True and the average is computed across samples, instead of classes. - device (str of torch.device, optional): unused argument. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ @@ -124,7 +133,7 @@ def __init__( output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = None, + device: Union[str, torch.device] = torch.device("cpu"), ): super(Precision, self).__init__( output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device @@ -132,9 +141,9 @@ def __init__( @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: - y_pred, y = output self._check_shape(output) - self._check_type((y_pred, y)) + self._check_type(output) + y_pred, y = output[0].detach(), output[1].detach() if self._type == "binary": y_pred = y_pred.view(-1) @@ -155,17 +164,16 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) y = torch.transpose(y, 1, 0).reshape(num_classes, -1) - y = y.to(y_pred) + # Convert from int cuda/cpu to double on self._device + y_pred = y_pred.to(dtype=torch.float64, device=self._device) + y = y.to(dtype=torch.float64, device=self._device) correct = y * y_pred - all_positives = y_pred.sum(dim=0).type(torch.DoubleTensor) # Convert from int cuda/cpu to double cpu + all_positives = y_pred.sum(dim=0) if correct.sum() == 0: true_positives = torch.zeros_like(all_positives) else: true_positives = correct.sum(dim=0) - # Convert from int cuda/cpu to double cpu - # We need double precision for the division true_positives / all_positives - true_positives = true_positives.type(torch.DoubleTensor) if self._type == "multilabel": if not self._average: diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py index 048c11b10c5b..ad391705a004 100644 --- a/ignite/metrics/recall.py +++ b/ignite/metrics/recall.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch @@ -60,7 +60,9 @@ def thresholded_output_transform(output): in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case). is_multilabel (bool, optional) flag to use in multilabel case. By default, value is False. If True, average parameter should be True and the average is computed across samples, instead of classes. - device (str of torch.device, optional): unused argument. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ @@ -69,7 +71,7 @@ def __init__( output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = None, + device: Union[str, torch.device] = torch.device("cpu"), ): super(Recall, self).__init__( output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device @@ -77,9 +79,9 @@ def __init__( @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: - y_pred, y = output self._check_shape(output) - self._check_type((y_pred, y)) + self._check_type(output) + y_pred, y = output[0].detach(), output[1].detach() if self._type == "binary": y_pred = y_pred.view(-1) @@ -100,19 +102,17 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) y = torch.transpose(y, 1, 0).reshape(num_classes, -1) - y = y.type_as(y_pred) + # Convert from int cuda/cpu to double on self._device + y_pred = y_pred.to(dtype=torch.float64, device=self._device) + y = y.to(dtype=torch.float64, device=self._device) correct = y * y_pred - actual_positives = y.sum(dim=0).type(torch.DoubleTensor) # Convert from int cuda/cpu to double cpu + actual_positives = y.sum(dim=0) if correct.sum() == 0: true_positives = torch.zeros_like(actual_positives) else: true_positives = correct.sum(dim=0) - # Convert from int cuda/cpu to double cpu - # We need double precision for the division true_positives / actual_positives - true_positives = true_positives.type(torch.DoubleTensor) - if self._type == "multilabel": if not self._average: self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) diff --git a/ignite/metrics/running_average.py b/ignite/metrics/running_average.py index 2094e3866ada..0fa1216c7940 100644 --- a/ignite/metrics/running_average.py +++ b/ignite/metrics/running_average.py @@ -20,7 +20,11 @@ class RunningAverage(Metric): corresponds the output of process function. Otherwise it should be None. epoch_bound (boolean, optional): whether the running average should be reset after each epoch (defaults to True). - device (str of torch.device, optional): unused argument. + device (str or torch.device, optional): specifies which device updates are accumulated on. Should be + None when ``src`` is an instance of :class:`~ignite.metrics.Metric`, as the running average will + use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value + from the metric is a tensor. + Examples: @@ -63,6 +67,7 @@ def __init__( self.src = src self._get_src_value = self._get_metric_value self.iteration_completed = self._metric_iteration_completed + device = src._device else: if output_transform is None: raise ValueError( @@ -71,6 +76,8 @@ def __init__( ) self._get_src_value = self._get_output_value self.update = self._output_update + if device is None: + device = torch.device("cpu") self.alpha = alpha self.epoch_bound = epoch_bound @@ -118,5 +125,5 @@ def _metric_iteration_completed(self, engine: Engine) -> None: @reinit__is_reduced def _output_update(self, output: Union[torch.Tensor, float]) -> None: if isinstance(output, torch.Tensor): - output = output.detach().clone() + output = output.detach().to(self._device, copy=True) self.src = output diff --git a/ignite/metrics/ssim.py b/ignite/metrics/ssim.py index 91491432db37..7c20a3260cf7 100644 --- a/ignite/metrics/ssim.py +++ b/ignite/metrics/ssim.py @@ -116,7 +116,8 @@ def _gaussian_or_uniform_kernel(self, kernel_size, sigma): @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: - y_pred, y = output + y_pred, y = output[0].detach(), output[1].detach() + if y_pred.dtype != y.dtype: raise TypeError( "Expected y_pred and y to have the same data type. Got y_pred: {} and y: {}.".format( diff --git a/ignite/metrics/top_k_categorical_accuracy.py b/ignite/metrics/top_k_categorical_accuracy.py index 3fb493ed8441..dad423f86af1 100644 --- a/ignite/metrics/top_k_categorical_accuracy.py +++ b/ignite/metrics/top_k_categorical_accuracy.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch @@ -16,23 +16,24 @@ class TopKCategoricalAccuracy(Metric): """ def __init__( - self, k=5, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None + self, k=5, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ): super(TopKCategoricalAccuracy, self).__init__(output_transform, device=device) self._k = k @reinit__is_reduced def reset(self) -> None: - self._num_correct = 0 + self._num_correct = torch.tensor(0, device=self._device) self._num_examples = 0 @reinit__is_reduced def update(self, output: Sequence) -> None: - y_pred, y = output + y_pred, y = output[0].detach(), output[1].detach() sorted_indices = torch.topk(y_pred, self._k, dim=1)[1] expanded_y = y.view(-1, 1).expand(-1, self._k) correct = torch.sum(torch.eq(sorted_indices, expanded_y), dim=1) - self._num_correct += torch.sum(correct).item() + + self._num_correct += torch.sum(correct).to(self._device) self._num_examples += correct.shape[0] @sync_all_reduce("_num_correct", "_num_examples") @@ -41,4 +42,4 @@ def compute(self) -> Union[float, torch.Tensor]: raise NotComputableError( "TopKCategoricalAccuracy must have at" "least one example before it can be computed." ) - return self._num_correct / self._num_examples + return self._num_correct.item() / self._num_examples diff --git a/tests/ignite/metrics/test_accumulation.py b/tests/ignite/metrics/test_accumulation.py index abd71f3ad1df..adc7acac85d1 100644 --- a/tests/ignite/metrics/test_accumulation.py +++ b/tests/ignite/metrics/test_accumulation.py @@ -198,105 +198,127 @@ def compute_mean_std(engine, batch): def _test_distrib_variable_accumulation(device): + def _test(metric_device): + mean_var = VariableAccumulation(lambda a, x: a + x, device=metric_device) + y_true = torch.rand(100, device=device, dtype=torch.float64) - mean_var = VariableAccumulation(lambda a, x: a + x, device=device) - y_true = torch.rand(100, device=device, dtype=torch.float64) + for y in y_true: + mean_var.update(y) - for y in y_true: - mean_var.update(y) + y_true = idist.all_reduce(y_true) + a, n = mean_var.compute() + assert a.item() == pytest.approx(y_true.sum().item()) + assert n == len(y_true) * idist.get_world_size() + # check if call compute twice + a, n = mean_var.compute() + assert a.item() == pytest.approx(y_true.sum().item()) + assert n == len(y_true) * idist.get_world_size() - y_true = idist.all_reduce(y_true) - a, n = mean_var.compute() - assert a.item() == pytest.approx(y_true.sum().item()) - assert n == len(y_true) * idist.get_world_size() - # check if call compute twice - a, n = mean_var.compute() - assert a.item() == pytest.approx(y_true.sum().item()) - assert n == len(y_true) * idist.get_world_size() + mean_var = VariableAccumulation(lambda a, x: a + x, device=metric_device) + y_true = torch.rand(50, 10, device=device, dtype=torch.float64) - mean_var = VariableAccumulation(lambda a, x: a + x, device=device) - y_true = torch.rand(50, 10, device=device, dtype=torch.float64) + for y in y_true: + mean_var.update(y) - for y in y_true: - mean_var.update(y) + y_true = idist.all_reduce(y_true) + a, n = mean_var.compute() + assert n == len(y_true) * idist.get_world_size() + np.testing.assert_almost_equal(a.cpu().numpy(), y_true.sum(dim=0).cpu().numpy(), decimal=4) + a, n = mean_var.compute() + assert n == len(y_true) * idist.get_world_size() + np.testing.assert_almost_equal(a.cpu().numpy(), y_true.sum(dim=0).cpu().numpy(), decimal=4) - y_true = idist.all_reduce(y_true) - a, n = mean_var.compute() - assert n == len(y_true) * idist.get_world_size() - np.testing.assert_almost_equal(a.cpu().numpy(), y_true.sum(dim=0).cpu().numpy(), decimal=4) - a, n = mean_var.compute() - assert n == len(y_true) * idist.get_world_size() - np.testing.assert_almost_equal(a.cpu().numpy(), y_true.sum(dim=0).cpu().numpy(), decimal=4) + # check multiple random inputs as random exact occurencies are rare + for _ in range(3): + _test("cpu") + if device.type != "xla": + _test(idist.device()) def _test_distrib_average(device): + def _test(metric_device): + with pytest.raises(NotComputableError): + v = Average(device=metric_device) + v.compute() - with pytest.raises(NotComputableError): - v = Average(device=device) - v.compute() + mean_var = Average(device=metric_device) + y_true = torch.rand(100, dtype=torch.float64) + torch.randint(0, 10, size=(100,)).double() + y_true = y_true.to(device) - mean_var = Average(device=device) - y_true = torch.rand(100, dtype=torch.float64) + torch.randint(0, 10, size=(100,)).double() - y_true = y_true.to(device) + for y in y_true: + mean_var.update(y) - for y in y_true: - mean_var.update(y) + m = mean_var.compute() - m = mean_var.compute() + y_true = idist.all_reduce(y_true) + assert m.item() == pytest.approx(y_true.mean().item() / idist.get_world_size()) - y_true = idist.all_reduce(y_true) - assert m.item() == pytest.approx(y_true.mean().item() / idist.get_world_size()) + mean_var = Average(device=metric_device) + y_true = torch.rand(100, 10, dtype=torch.float64) + torch.randint(0, 10, size=(100, 10)).double() + y_true = y_true.to(device) - mean_var = Average(device=device) - y_true = torch.rand(100, 10, dtype=torch.float64) + torch.randint(0, 10, size=(100, 10)).double() - y_true = y_true.to(device) + for y in y_true: + mean_var.update(y) - for y in y_true: - mean_var.update(y) + m = mean_var.compute() - m = mean_var.compute() + y_true = idist.all_reduce(y_true) + np.testing.assert_almost_equal( + m.cpu().numpy(), y_true.mean(dim=0).cpu().numpy() / idist.get_world_size(), decimal=5 + ) - y_true = idist.all_reduce(y_true) - np.testing.assert_almost_equal( - m.cpu().numpy(), y_true.mean(dim=0).cpu().numpy() / idist.get_world_size(), decimal=5 - ) + # check multiple random inputs as random exact occurencies are rare + for _ in range(3): + _test("cpu") + if device.type != "xla": + _test(idist.device()) def _test_distrib_geom_average(device): + def _test(metric_device): + with pytest.raises(NotComputableError): + v = GeometricAverage(device=metric_device) + v.compute() - with pytest.raises(NotComputableError): - v = GeometricAverage(device=device) - v.compute() + decimal = 5 if device.type != "xla" else 4 - mean_var = GeometricAverage(device=device) - y_true = torch.rand(100, dtype=torch.float64) + torch.randint(0, 10, size=(100,)).double() - y_true = y_true.to(device) + mean_var = GeometricAverage(device=metric_device) + y_true = torch.rand(100, dtype=torch.float64) + torch.randint(0, 10, size=(100,)).double() + y_true = y_true.to(device) - for y in y_true: - mean_var.update(y) + for y in y_true: + mean_var.update(y) - m = mean_var.compute() - log_y_true = torch.log(y_true) - log_y_true = idist.all_reduce(log_y_true) - assert m.item() == pytest.approx(torch.exp(log_y_true.mean(dim=0) / idist.get_world_size()).item()) + m = mean_var.compute() + log_y_true = torch.log(y_true) + log_y_true = idist.all_reduce(log_y_true) + np.testing.assert_almost_equal( + m.item(), torch.exp(log_y_true.mean(dim=0) / idist.get_world_size()).item(), decimal=decimal + ) - mean_var = GeometricAverage(device=device) - y_true = torch.rand(100, 10, dtype=torch.float64) + torch.randint(0, 10, size=(100, 10)).double() - y_true = y_true.to(device) + mean_var = GeometricAverage(device=metric_device) + y_true = torch.rand(100, 10, dtype=torch.float64) + torch.randint(0, 10, size=(100, 10)).double() + y_true = y_true.to(device) - for y in y_true: - mean_var.update(y) + for y in y_true: + mean_var.update(y) - m = mean_var.compute() - log_y_true = torch.log(y_true) - log_y_true = idist.all_reduce(log_y_true) - np.testing.assert_almost_equal( - m.cpu().numpy(), torch.exp(log_y_true.mean(dim=0) / idist.get_world_size()).cpu().numpy(), decimal=5 - ) + m = mean_var.compute() + log_y_true = torch.log(y_true) + log_y_true = idist.all_reduce(log_y_true) + np.testing.assert_almost_equal( + m.cpu().numpy(), torch.exp(log_y_true.mean(dim=0) / idist.get_world_size()).cpu().numpy(), decimal=decimal + ) + + # check multiple random inputs as random exact occurencies are rare + for _ in range(3): + _test("cpu") + if device.type != "xla": + _test(idist.device()) def _test_distrib_integration(device): - def _test(metric_cls, true_result_fn, tol=1e-5): + def _test(metric_cls, true_result_fn, metric_device, tol=1e-5): size = 100 custom_variable = 10.0 + 5.0 * torch.rand(size, 12, dtype=torch.float64) @@ -307,7 +329,7 @@ def update_fn(engine, batch): engine = Engine(update_fn) - custom_var_mean = metric_cls(output_transform=lambda output: output[1], device=device) + custom_var_mean = metric_cls(output_transform=lambda output: output[1], device=metric_device) custom_var_mean.attach(engine, "agg_custom_var") state = engine.run([0] * size) @@ -326,7 +348,7 @@ def update_fn(engine, batch): engine = Engine(update_fn) - custom_var_mean = metric_cls(output_transform=lambda output: output[1], device=device) + custom_var_mean = metric_cls(output_transform=lambda output: output[1], device=metric_device) custom_var_mean.attach(engine, "agg_custom_var") state = engine.run([0] * size) @@ -342,8 +364,31 @@ def _geom_mean(y_true): np_t = log_y_true.cpu().numpy() return np.exp(np.mean(np_t, axis=0) / idist.get_world_size()) - _test(Average, _mean) - _test(GeometricAverage, _geom_mean, tol=1e-4) + metric_devices = ["cpu"] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + _test(Average, _mean, metric_device) + _test(GeometricAverage, _geom_mean, metric_device, tol=1e-4) + + +def _test_distrib_accumulator_device(device): + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + + m = VariableAccumulation(lambda a, x: x, device=metric_device) + assert m._device == metric_device + assert m.accumulator.device == metric_device, "{}:{} vs {}:{}".format( + type(m.accumulator.device), m.accumulator.device, type(metric_device), metric_device + ) + + m.update(torch.tensor(1, device=device)) + assert m.accumulator.device == metric_device, "{}:{} vs {}:{}".format( + type(m.accumulator.device), m.accumulator.device, type(metric_device), metric_device + ) @pytest.mark.distributed @@ -351,33 +396,36 @@ def _geom_mean(y_true): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(distributed_context_single_node_nccl): - device = "cuda:{}".format(distributed_context_single_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_single_node_nccl["local_rank"])) _test_distrib_variable_accumulation(device) _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_variable_accumulation(device) _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_variable_accumulation(device) _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -385,24 +433,26 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_variable_accumulation, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_average, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_geom_average, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_variable_accumulation(device) _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -414,6 +464,7 @@ def test_distrib_single_device_xla(): _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): @@ -422,6 +473,7 @@ def _test_distrib_xla_nprocs(index): _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py index 3ca32257d7bf..3960a09ec7f7 100644 --- a/tests/ignite/metrics/test_accuracy.py +++ b/tests/ignite/metrics/test_accuracy.py @@ -612,14 +612,19 @@ def _test_distrib_multilabel_input_NHW(device): rank = idist.get_rank() - def _test(): - acc = Accuracy(is_multilabel=True) + def _test(metric_device): + metric_device = torch.device(metric_device) + acc = Accuracy(is_multilabel=True, device=metric_device) torch.manual_seed(10 + rank) y_pred = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long() y = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long() acc.update((y_pred, y)) + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) + # gather y_pred, y y_pred = idist.all_gather(y_pred) y = idist.all_gather(y) @@ -639,6 +644,10 @@ def _test(): y = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long() acc.update((y_pred, y)) + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) + # gather y_pred, y y_pred = idist.all_gather(y_pred) y = idist.all_gather(y) @@ -671,6 +680,10 @@ def _test(): idx = i * batch_size acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) + # gather y_pred, y y_pred = idist.all_gather(y_pred) y = idist.all_gather(y) @@ -686,8 +699,10 @@ def _test(): assert accuracy_score(np_y, np_y_pred) == pytest.approx(res) # check multiple random inputs as random exact occurencies are rare - for _ in range(5): - _test() + for _ in range(3): + _test("cpu") + if device.type != "xla": + _test(idist.device()) def _test_distrib_integration_multiclass(device): @@ -697,7 +712,8 @@ def _test_distrib_integration_multiclass(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(n_epochs): + def _test(n_epochs, metric_device): + metric_device = torch.device(metric_device) n_iters = 80 s = 16 n_classes = 10 @@ -714,12 +730,16 @@ def update(engine, i): engine = Engine(update) - acc = Accuracy() + acc = Accuracy(device=metric_device) acc.attach(engine, "acc") data = list(range(n_iters)) engine.run(data=data, max_epochs=n_epochs) + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) + assert "acc" in engine.state.metrics res = engine.state.metrics["acc"] if isinstance(res, torch.Tensor): @@ -729,9 +749,13 @@ def update(engine, i): assert pytest.approx(res) == true_res - for _ in range(3): - _test(n_epochs=1) - _test(n_epochs=2) + metric_devices = ["cpu"] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + for _ in range(2): + _test(n_epochs=1, metric_device=metric_device) + _test(n_epochs=2, metric_device=metric_device) def _test_distrib_integration_multilabel(device): @@ -741,7 +765,8 @@ def _test_distrib_integration_multilabel(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(n_epochs): + def _test(n_epochs, metric_device): + metric_device = torch.device(metric_device) n_iters = 80 s = 16 n_classes = 10 @@ -758,12 +783,16 @@ def update(engine, i): engine = Engine(update) - acc = Accuracy(is_multilabel=True) + acc = Accuracy(is_multilabel=True, device=metric_device) acc.attach(engine, "acc") data = list(range(n_iters)) engine.run(data=data, max_epochs=n_epochs) + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) + assert "acc" in engine.state.metrics res = engine.state.metrics["acc"] if isinstance(res, torch.Tensor): @@ -773,29 +802,57 @@ def update(engine, i): assert pytest.approx(res) == true_res - for _ in range(3): - _test(n_epochs=1) - _test(n_epochs=2) + metric_devices = ["cpu"] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + for _ in range(2): + _test(n_epochs=1, metric_device=metric_device) + _test(n_epochs=2, metric_device=metric_device) + + +def _test_distrib_accumulator_device(device): + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + + acc = Accuracy(device=metric_device) + assert acc._device == metric_device + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) + + y_pred = torch.randint(0, 2, size=(10,), device=device, dtype=torch.long) + y = torch.randint(0, 2, size=(10,), device=device, dtype=torch.long) + acc.update((y_pred, y)) + + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(distributed_context_single_node_nccl): - device = "cuda:{}".format(distributed_context_single_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_single_node_nccl["local_rank"])) _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -803,32 +860,35 @@ def test_distrib_cpu(distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_multilabel_input_NHW, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_integration_multiclass, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_integration_multilabel, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -839,6 +899,7 @@ def test_distrib_single_device_xla(): _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): @@ -846,6 +907,7 @@ def _test_distrib_xla_nprocs(index): _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_confusion_matrix.py b/tests/ignite/metrics/test_confusion_matrix.py index 2960a59b9d02..8da41e07a411 100644 --- a/tests/ignite/metrics/test_confusion_matrix.py +++ b/tests/ignite/metrics/test_confusion_matrix.py @@ -547,68 +547,94 @@ def test_dice_coefficient(): def _test_distrib_multiclass_images(device): + def _test(metric_device): + num_classes = 3 + cm = ConfusionMatrix(num_classes=num_classes, device=metric_device) - num_classes = 3 - cm = ConfusionMatrix(num_classes=num_classes, device=device) + y_true, y_pred = get_y_true_y_pred() - y_true, y_pred = get_y_true_y_pred() + # Compute confusion matrix with sklearn + true_res = confusion_matrix(y_true.reshape(-1), y_pred.reshape(-1)) - # Compute confusion matrix with sklearn - true_res = confusion_matrix(y_true.reshape(-1), y_pred.reshape(-1)) + th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred) + th_y_true = th_y_true.to(device) + th_y_logits = th_y_logits.to(device) - th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred) - th_y_true = th_y_true.to(device) - th_y_logits = th_y_logits.to(device) + # Update metric + output = (th_y_logits, th_y_true) + cm.update(output) - # Update metric - output = (th_y_logits, th_y_true) - cm.update(output) + res = cm.compute().cpu().numpy() / idist.get_world_size() - res = cm.compute().cpu().numpy() / idist.get_world_size() + assert np.all(true_res == res) - assert np.all(true_res == res) + # Another test on batch of 2 images + num_classes = 3 + cm = ConfusionMatrix(num_classes=num_classes, device=metric_device) + + # Create a batch of two images: + th_y_true1 = torch.from_numpy(y_true).reshape(1, 30, 30) + th_y_true2 = torch.from_numpy(y_true.transpose()).reshape(1, 30, 30) + th_y_true = torch.cat([th_y_true1, th_y_true2], dim=0) + th_y_true = th_y_true.to(device) + + # Create a batch of 2 logits tensors + y_probas = np.ones((3, 30, 30)) * -10 + y_probas[0, (y_pred == 0)] = 720 + y_probas[1, (y_pred == 1)] = 720 + y_probas[2, (y_pred == 2)] = 768 + th_y_logits1 = torch.from_numpy(y_probas).reshape(1, 3, 30, 30) + + y_probas = np.ones((3, 30, 30)) * -10 + y_probas[0, (y_pred.transpose() == 0)] = 720 + y_probas[1, (y_pred.transpose() == 2)] = 720 + y_probas[2, (y_pred.transpose() == 1)] = 768 + th_y_logits2 = torch.from_numpy(y_probas).reshape(1, 3, 30, 30) + + th_y_logits = torch.cat([th_y_logits1, th_y_logits2], dim=0) + # check update if input is on another device + th_y_logits = th_y_logits.to(device) + + # Update metric & compute + output = (th_y_logits, th_y_true) + cm.update(output) + res = cm.compute().cpu().numpy() - # Another test on batch of 2 images - num_classes = 3 - cm = ConfusionMatrix(num_classes=num_classes, device=device) + # Compute confusion matrix with sklearn + th_y_true = idist.all_gather(th_y_true) + th_y_logits = idist.all_gather(th_y_logits) - # Create a batch of two images: - th_y_true1 = torch.from_numpy(y_true).reshape(1, 30, 30) - th_y_true2 = torch.from_numpy(y_true.transpose()).reshape(1, 30, 30) - th_y_true = torch.cat([th_y_true1, th_y_true2], dim=0) - th_y_true = th_y_true.to(device) + np_y_true = th_y_true.cpu().numpy().reshape(-1) + np_y_pred = np.argmax(th_y_logits.cpu().numpy(), axis=1).reshape(-1) + true_res = confusion_matrix(np_y_true, np_y_pred) - # Create a batch of 2 logits tensors - y_probas = np.ones((3, 30, 30)) * -10 - y_probas[0, (y_pred == 0)] = 720 - y_probas[1, (y_pred == 1)] = 720 - y_probas[2, (y_pred == 2)] = 768 - th_y_logits1 = torch.from_numpy(y_probas).reshape(1, 3, 30, 30) + assert np.all(true_res == res) - y_probas = np.ones((3, 30, 30)) * -10 - y_probas[0, (y_pred.transpose() == 0)] = 720 - y_probas[1, (y_pred.transpose() == 2)] = 720 - y_probas[2, (y_pred.transpose() == 1)] = 768 - th_y_logits2 = torch.from_numpy(y_probas).reshape(1, 3, 30, 30) + _test("cpu") + if device.type != "xla": + _test(idist.device()) - th_y_logits = torch.cat([th_y_logits1, th_y_logits2], dim=0) - # check update if input is on another device - th_y_logits = th_y_logits.to(device) - # Update metric & compute - output = (th_y_logits, th_y_true) - cm.update(output) - res = cm.compute().cpu().numpy() +def _test_distrib_accumulator_device(device): - # Compute confusion matrix with sklearn - th_y_true = idist.all_gather(th_y_true) - th_y_logits = idist.all_gather(th_y_logits) + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: - np_y_true = th_y_true.cpu().numpy().reshape(-1) - np_y_pred = np.argmax(th_y_logits.cpu().numpy(), axis=1).reshape(-1) - true_res = confusion_matrix(np_y_true, np_y_pred) + cm = ConfusionMatrix(num_classes=3, device=metric_device) + assert cm._device == metric_device + assert cm.confusion_matrix.device == metric_device, "{}:{} vs {}:{}".format( + type(cm.confusion_matrix.device), cm._num_correct.device, type(metric_device), metric_device + ) - assert np.all(true_res == res) + y_true, y_pred = get_y_true_y_pred() + th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred) + cm.update((th_y_logits, th_y_true)) + + assert cm.confusion_matrix.device == metric_device, "{}:{} vs {}:{}".format( + type(cm.confusion_matrix.device), acc._num_correct.device, type(metric_device), metric_device + ) @pytest.mark.distributed @@ -616,16 +642,18 @@ def _test_distrib_multiclass_images(device): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -633,26 +661,29 @@ def test_distrib_cpu(distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_multiclass_images, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -661,11 +692,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_epoch_metric.py b/tests/ignite/metrics/test_epoch_metric.py index fc488ad92ebb..ce0e4d719a12 100644 --- a/tests/ignite/metrics/test_epoch_metric.py +++ b/tests/ignite/metrics/test_epoch_metric.py @@ -161,7 +161,7 @@ def compute_fn(y_preds, y_targets): def _test_distrib_integration(device=None): if device is None: - device = idist.device() + device = idist.device() if idist.device().type != "xla" else "cpu" rank = idist.get_rank() torch.manual_seed(12) diff --git a/tests/ignite/metrics/test_fbeta.py b/tests/ignite/metrics/test_fbeta.py index 53b883c451af..da4a8df58c23 100644 --- a/tests/ignite/metrics/test_fbeta.py +++ b/tests/ignite/metrics/test_fbeta.py @@ -94,7 +94,7 @@ def _test_distrib_integration(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(p, r, average, n_epochs): + def _test(p, r, average, n_epochs, metric_device): n_iters = 60 s = 16 n_classes = 7 @@ -111,7 +111,7 @@ def update(engine, i): engine = Engine(update) - fbeta = Fbeta(beta=2.5, average=average, device=device) + fbeta = Fbeta(beta=2.5, average=average, device=metric_device) fbeta.attach(engine, "f2.5") data = list(range(n_iters)) @@ -131,26 +131,30 @@ def update(engine, i): assert pytest.approx(res) == true_res - _test(None, None, average=True, n_epochs=1) - _test(None, None, average=True, n_epochs=2) - precision = Precision(average=False) - recall = Recall(average=False) - _test(precision, recall, average=False, n_epochs=1) - _test(precision, recall, average=False, n_epochs=2) + metric_devices = ["cpu"] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + _test(None, None, average=True, n_epochs=1, metric_device=metric_device) + _test(None, None, average=True, n_epochs=2, metric_device=metric_device) + precision = Precision(average=False, device=metric_device) + recall = Recall(average=False, device=metric_device) + _test(precision, recall, average=False, n_epochs=1, metric_device=metric_device) + _test(precision, recall, average=False, n_epochs=2, metric_device=metric_device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_integration(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) @@ -159,7 +163,7 @@ def test_distrib_cpu(distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) @@ -169,7 +173,7 @@ def test_distrib_hvd(gloo_hvd_executor): @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) @@ -177,7 +181,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_integration(device) diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 5244e7991d7d..35638684916d 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -76,35 +76,70 @@ def test_reset(): def _test_distrib_compute_on_criterion(device): + def _test(metric_device): + criterion = nn.NLLLoss().to(device) + loss = Loss(criterion, device=metric_device) - criterion = nn.NLLLoss().to(device) - loss = Loss(criterion, device=device) + y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], device=device).log() + y = torch.tensor([2, 2], device=device).long() + loss.update((y_pred, y)) + n = loss._num_examples + assert n == len(y) + res = loss.compute() + assert n * idist.get_world_size() == loss._num_examples + + y_pred = idist.all_gather(y_pred) + y = idist.all_gather(y) + true_loss_value = criterion(y_pred, y) + assert_almost_equal(res, true_loss_value.item()) + + loss.reset() + y_pred = torch.tensor([[0.1, 0.3, 0.6], [0.6, 0.2, 0.2], [0.2, 0.7, 0.1]], device=device).log() + y = torch.tensor([2, 0, 2], device=device).long() + loss.update((y_pred, y)) + n = loss._num_examples + res = loss.compute() + assert n * idist.get_world_size() == loss._num_examples + + y_pred = idist.all_gather(y_pred) + y = idist.all_gather(y) + true_loss_value = criterion(y_pred, y) + assert_almost_equal(res, true_loss_value.item()) + + _test("cpu") + if device.type != "xla": + _test(idist.device()) + + +def _test_distrib_accumulator_device(device): + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + loss = Loss(nll_loss, device=metric_device) + assert loss._device == metric_device + assert loss._sum.device == metric_device, "{}:{} vs {}:{}".format( + type(loss._sum.device), loss._sum.device, type(metric_device), metric_device + ) + + y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]]).log() + y = torch.tensor([2, 2]).long() + loss.update((y_pred, y)) - y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], device=device).log() - y = torch.tensor([2, 2], device=device).long() - loss.update((y_pred, y)) - n = loss._num_examples - assert n == len(y) - res = loss.compute() - assert n * idist.get_world_size() == loss._num_examples + assert loss._sum.device == metric_device, "{}:{} vs {}:{}".format( + type(loss._sum.device), loss._sum.device, type(metric_device), metric_device + ) - y_pred = idist.all_gather(y_pred) - y = idist.all_gather(y) - true_loss_value = criterion(y_pred, y) - assert_almost_equal(res, true_loss_value.item()) - loss.reset() - y_pred = torch.tensor([[0.1, 0.3, 0.6], [0.6, 0.2, 0.2], [0.2, 0.7, 0.1]], device=device).log() - y = torch.tensor([2, 0, 2], device=device).long() +def test_sum_detached(): + loss = Loss(nll_loss) + + y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], requires_grad=True).log() + y = torch.tensor([2, 2]).long() loss.update((y_pred, y)) - n = loss._num_examples - res = loss.compute() - assert n * idist.get_world_size() == loss._num_examples - y_pred = idist.all_gather(y_pred) - y = idist.all_gather(y) - true_loss_value = criterion(y_pred, y) - assert_almost_equal(res, true_loss_value.item()) + assert not loss._sum.requires_grad @pytest.mark.distributed @@ -112,16 +147,18 @@ def _test_distrib_compute_on_criterion(device): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_compute_on_criterion(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_compute_on_criterion(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -129,26 +166,29 @@ def test_distrib_cpu(distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_compute_on_criterion, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_compute_on_criterion(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_compute_on_criterion(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -157,11 +197,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_compute_on_criterion(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_compute_on_criterion(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_mean_absolute_error.py b/tests/ignite/metrics/test_mean_absolute_error.py index 8279470dd07e..c47a2c281b7f 100644 --- a/tests/ignite/metrics/test_mean_absolute_error.py +++ b/tests/ignite/metrics/test_mean_absolute_error.py @@ -49,35 +49,78 @@ def update(engine, i): y_true[i * s + offset * rank : (i + 1) * s + offset * rank], ) - engine = Engine(update) + def _test(metric_device): + engine = Engine(update) - m = MeanAbsoluteError() - m.attach(engine, "mae") + m = MeanAbsoluteError(device=metric_device) + m.attach(engine, "mae") - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) - assert "mae" in engine.state.metrics - res = engine.state.metrics["mae"] + assert "mae" in engine.state.metrics + res = engine.state.metrics["mae"] - true_res = np.mean(np.abs((y_true - y_preds).cpu().numpy())) + true_res = np.mean(np.abs((y_true - y_preds).cpu().numpy())) - assert pytest.approx(res) == true_res + assert pytest.approx(res) == true_res + + _test("cpu") + if device.type != "xla": + _test(idist.device()) + + +def _test_distrib_accumulator_device(device): + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + mae = MeanAbsoluteError(device=metric_device) + assert mae._device == metric_device + assert mae._sum_of_absolute_errors.device == metric_device, "{}:{} vs {}:{}".format( + type(mae._sum_of_absolute_errors.device), + mae._sum_of_absolute_errors.device, + type(metric_device), + metric_device, + ) + + y_pred = torch.tensor([[2.0], [-2.0]]) + y = torch.zeros(2) + mae.update((y_pred, y)) + assert mae._sum_of_absolute_errors.device == metric_device, "{}:{} vs {}:{}".format( + type(mae._sum_of_absolute_errors.device), + mae._sum_of_absolute_errors.device, + type(metric_device), + metric_device, + ) + + +def test_accumulator_detached(): + mae = MeanAbsoluteError() + + y_pred = torch.tensor([[2.0], [-2.0]], requires_grad=True) + y = torch.zeros(2) + mae.update((y_pred, y)) + + assert not mae._sum_of_absolute_errors.requires_grad @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -85,26 +128,29 @@ def test_distrib_cpu(distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -113,11 +159,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_mean_pairwise_distance.py b/tests/ignite/metrics/test_mean_pairwise_distance.py index 7ada0b5474a5..b0db1917c1d2 100644 --- a/tests/ignite/metrics/test_mean_pairwise_distance.py +++ b/tests/ignite/metrics/test_mean_pairwise_distance.py @@ -52,45 +52,84 @@ def update(engine, i): y_true[i * s + offset * rank : (i + 1) * s + offset * rank, ...], ) - engine = Engine(update) + def _test(metric_device): + engine = Engine(update) + + m = MeanPairwiseDistance(device=metric_device) + m.attach(engine, "mpwd") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) + + assert "mpwd" in engine.state.metrics + res = engine.state.metrics["mpwd"] + + true_res = [] + for i in range(n_iters * idist.get_world_size()): + true_res.append( + torch.pairwise_distance( + y_true[i * s : (i + 1) * s, ...], y_preds[i * s : (i + 1) * s, ...], p=m._p, eps=m._eps + ) + .cpu() + .numpy() + ) + true_res = np.array(true_res).ravel() + true_res = true_res.mean() - m = MeanPairwiseDistance() - m.attach(engine, "mpwd") + assert pytest.approx(res) == true_res - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + _test("cpu") + if device.type != "xla": + _test(idist.device()) - assert "mpwd" in engine.state.metrics - res = engine.state.metrics["mpwd"] - true_res = [] - for i in range(n_iters * idist.get_world_size()): - true_res.append( - torch.pairwise_distance( - y_true[i * s : (i + 1) * s, ...], y_preds[i * s : (i + 1) * s, ...], p=m._p, eps=m._eps - ) - .cpu() - .numpy() +def _test_distrib_accumulator_device(device): + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + + mpd = MeanPairwiseDistance(device=metric_device) + assert mpd._device == metric_device + assert mpd._sum_of_distances.device == metric_device, "{}:{} vs {}:{}".format( + type(mpd._sum_of_distances.device), mpd._sum_of_distances.device, type(metric_device), metric_device + ) + + y_pred = torch.Tensor([[3.0, 4.0], [-3.0, -4.0]]) + y = torch.zeros(2, 2) + mpd.update((y_pred, y)) + + assert mpd._sum_of_distances.device == metric_device, "{}:{} vs {}:{}".format( + type(mpd._sum_of_distances.device), mpd._sum_of_distances.device, type(metric_device), metric_device ) - true_res = np.array(true_res).ravel() - true_res = true_res.mean() - assert pytest.approx(res) == true_res + +def test_accumulator_detached(): + mpd = MeanPairwiseDistance() + + y_pred = torch.tensor([[3.0, 4.0], [-3.0, -4.0]], requires_grad=True) + y = torch.zeros(2, 2) + mpd.update((y_pred, y)) + + assert not mpd._sum_of_distances.requires_grad @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -98,26 +137,29 @@ def test_distrib_cpu(distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -126,11 +168,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_mean_squared_error.py b/tests/ignite/metrics/test_mean_squared_error.py index 59ce1fdc567d..91a87b561974 100644 --- a/tests/ignite/metrics/test_mean_squared_error.py +++ b/tests/ignite/metrics/test_mean_squared_error.py @@ -49,20 +49,63 @@ def update(engine, i): y_true[i * s + offset * rank : (i + 1) * s + offset * rank], ) - engine = Engine(update) + def _test(metric_device): + engine = Engine(update) - m = MeanSquaredError() - m.attach(engine, "mse") + m = MeanSquaredError(device=metric_device) + m.attach(engine, "mse") - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) - assert "mse" in engine.state.metrics - res = engine.state.metrics["mse"] + assert "mse" in engine.state.metrics + res = engine.state.metrics["mse"] - true_res = np.mean(np.power((y_true - y_preds).cpu().numpy(), 2.0)) + true_res = np.mean(np.power((y_true - y_preds).cpu().numpy(), 2.0)) - assert pytest.approx(res, rel=tol) == true_res + assert pytest.approx(res, rel=tol) == true_res + + _test("cpu") + if device.type != "xla": + _test(idist.device()) + + +def _test_distrib_accumulator_device(device): + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + + device = torch.device(device) + mse = MeanSquaredError(device=metric_device) + assert mse._device == metric_device + assert mse._sum_of_squared_errors.device == metric_device, "{}:{} vs {}:{}".format( + type(mse._sum_of_squared_errors.device), + mse._sum_of_squared_errors.device, + type(metric_device), + metric_device, + ) + + y_pred = torch.tensor([[2.0], [-2.0]]) + y = torch.zeros(2) + mse.update((y_pred, y)) + assert mse._sum_of_squared_errors.device == metric_device, "{}:{} vs {}:{}".format( + type(mse._sum_of_squared_errors.device), + mse._sum_of_squared_errors.device, + type(metric_device), + metric_device, + ) + + +def test_accumulator_detached(): + mse = MeanSquaredError() + + y_pred = torch.tensor([[2.0], [-2.0]], requires_grad=True) + y = torch.zeros(2) + mse.update((y_pred, y)) + + assert not mse._sum_of_squared_errors.requires_grad @pytest.mark.distributed @@ -70,15 +113,17 @@ def update(engine, i): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -86,26 +131,29 @@ def test_distrib_cpu(distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -114,11 +162,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration(device, tol=1e-4) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration(device, tol=1e-4) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py index a9b5dc54d86a..31395b99ec98 100644 --- a/tests/ignite/metrics/test_metric.py +++ b/tests/ignite/metrics/test_metric.py @@ -568,13 +568,19 @@ def update(self, output): self.a += 10.0 self.b -= 5.0 - m = DummyMetric(device=device) + metric_device = device if torch.device(device).type != "xla" else "cpu" + m = DummyMetric(device=metric_device) m.update(None) m.compute() # check if can call compute multiple times without all reduce invocation m.compute() +def _test_creating_on_xla_fails(device): + with pytest.raises(ValueError, match=r"Cannot create metric on an XLA device. Use device='cpu' instead."): + DummyMetric2(device=device) + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") @@ -625,11 +631,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_sync_all_reduce_decorator(device) + _test_creating_on_xla_fails(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_sync_all_reduce_decorator(device) + _test_creating_on_xla_fails(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_metrics_lambda.py b/tests/ignite/metrics/test_metrics_lambda.py index a6a0902496dd..c8e77736370c 100644 --- a/tests/ignite/metrics/test_metrics_lambda.py +++ b/tests/ignite/metrics/test_metrics_lambda.py @@ -325,7 +325,7 @@ def _test_distrib_integration(device): batch_size = 10 n_classes = 10 - def _test(): + def _test(metric_device): y_true = np.arange(0, n_iters * batch_size * idist.get_world_size(), dtype="int64") % n_classes y_pred = 0.2 * np.random.rand(n_iters * batch_size * idist.get_world_size(), n_classes) for i in range(n_iters * batch_size * idist.get_world_size()): @@ -345,8 +345,8 @@ def update_fn(engine, i): evaluator = Engine(update_fn) - precision = Precision(average=False, device=device) - recall = Recall(average=False, device=device) + precision = Precision(average=False, device=metric_device) + recall = Recall(average=False, device=metric_device) def Fbeta(r, p, beta): return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item() @@ -366,8 +366,37 @@ def Fbeta(r, p, beta): assert f1_true == approx(state.metrics["f1"]) assert 1.0 + f1_true == approx(state.metrics["ff1"]) - for _ in range(5): - _test() + for _ in range(3): + _test("cpu") + if device.type != "xla": + _test(idist.device()) + + +def _test_distrib_metrics_on_diff_devices(device): + n_classes = 10 + n_iters = 12 + s = 16 + offset = n_iters * s + rank = idist.get_rank() + + y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device) + y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device) + + def update(engine, i): + return ( + y_preds[i * s + rank * offset : (i + 1) * s + rank * offset], + y_true[i * s + rank * offset : (i + 1) * s + rank * offset], + ) + + precision = Precision(average=False, device="cpu") + recall = Recall(average=False, device=device) + custom_metric = precision * recall + + engine = Engine(update) + custom_metric.attach(engine, "custom_metric") + + data = list(range(n_iters)) + engine.run(data, max_epochs=2) @pytest.mark.distributed @@ -375,15 +404,16 @@ def Fbeta(r, p, beta): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_integration(device) + _test_distrib_metrics_on_diff_devices(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(local_rank, distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) @@ -392,17 +422,18 @@ def test_distrib_cpu(local_rank, distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_metrics_on_diff_devices, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) @@ -410,8 +441,9 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_integration(device) + _test_distrib_metrics_on_diff_devices(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py index 94f1d643585f..493d3902b3e9 100644 --- a/tests/ignite/metrics/test_precision.py +++ b/tests/ignite/metrics/test_precision.py @@ -722,7 +722,7 @@ def _test_distrib_integration_multiclass(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(average, n_epochs): + def _test(average, n_epochs, metric_device): n_iters = 60 s = 16 n_classes = 7 @@ -739,7 +739,7 @@ def update(engine, i): engine = Engine(update) - pr = Precision(average=average) + pr = Precision(average=average, device=metric_device) pr.attach(engine, "pr") data = list(range(n_iters)) @@ -748,7 +748,7 @@ def update(engine, i): assert "pr" in engine.state.metrics res = engine.state.metrics["pr"] if isinstance(res, torch.Tensor): - assert res.device.type == "cpu" + assert res.device == metric_device res = res.cpu().numpy() true_res = precision_score( @@ -757,11 +757,15 @@ def update(engine, i): assert pytest.approx(res) == true_res + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) for _ in range(2): - _test(average=True, n_epochs=1) - _test(average=True, n_epochs=2) - _test(average=False, n_epochs=1) - _test(average=False, n_epochs=2) + for metric_device in metric_devices: + _test(average=True, n_epochs=1, metric_device=metric_device) + _test(average=True, n_epochs=2, metric_device=metric_device) + _test(average=False, n_epochs=1, metric_device=metric_device) + _test(average=False, n_epochs=2, metric_device=metric_device) def _test_distrib_integration_multilabel(device): @@ -771,7 +775,7 @@ def _test_distrib_integration_multilabel(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(average, n_epochs): + def _test(average, n_epochs, metric_device): n_iters = 60 s = 16 n_classes = 7 @@ -812,9 +816,13 @@ def update(engine, i): assert pytest.approx(res) == true_res + metric_devices = ["cpu"] + if device.type != "xla": + metric_devices.append(idist.device()) for _ in range(2): - _test(average=True, n_epochs=1) - _test(average=True, n_epochs=2) + for metric_device in metric_devices: + _test(average=True, n_epochs=1, metric_device=metric_device) + _test(average=True, n_epochs=2, metric_device=metric_device) if idist.get_world_size() > 1: with pytest.warns( @@ -833,21 +841,86 @@ def update(engine, i): assert (pr_compute1 == pr_compute2).all() +def _test_distrib_accumulator_device(device): + # Binary accuracy on input of shape (N, 1) or (N, ) + + def _test(average, metric_device): + pr = Precision(average=average, device=metric_device) + assert pr._device == metric_device + # Since the shape of the accumulated amount isn't known before the first update + # call, the internal variables aren't tensors on the right device yet. + + y_pred = torch.randint(0, 2, size=(10,)) + y = torch.randint(0, 2, size=(10,)).long() + pr.update((y_pred, y)) + + assert pr._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._true_positives.device), pr._true_positives.device, type(metric_device), metric_device + ) + assert pr._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._positives.device), pr._positives.device, type(metric_device), metric_device + ) + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + _test(True, metric_device=metric_device) + _test(False, metric_device=metric_device) + + +def _test_distrib_multilabel_accumulator_device(device): + # Multiclass input data of shape (N, ) and (N, C) + + def _test(average, metric_device): + pr = Precision(is_multilabel=True, average=average, device=metric_device) + + assert pr._device == metric_device + assert pr._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._true_positives.device), pr._true_positives.device, type(metric_device), metric_device + ) + assert pr._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._positives.device), pr._positives.device, type(metric_device), metric_device + ) + + y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) + y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() + pr.update((y_pred, y)) + + assert pr._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._true_positives.device), pr._true_positives.device, type(metric_device), metric_device + ) + assert pr._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._positives.device), pr._positives.device, type(metric_device), metric_device + ) + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + _test(True, metric_device=metric_device) + _test(False, metric_device=metric_device) + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(local_rank, distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.distributed @@ -855,29 +928,35 @@ def test_distrib_cpu(local_rank, distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration_multiclass, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_integration_multilabel, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_integration_multilabel, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.tpu @@ -887,12 +966,16 @@ def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py index 214fc43c4314..0f3e4eaf497c 100644 --- a/tests/ignite/metrics/test_recall.py +++ b/tests/ignite/metrics/test_recall.py @@ -722,7 +722,7 @@ def _test_distrib_integration_multiclass(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(average, n_epochs): + def _test(average, n_epochs, metric_device): n_iters = 60 s = 16 n_classes = 7 @@ -739,7 +739,7 @@ def update(engine, i): engine = Engine(update) - re = Recall(average=average) + re = Recall(average=average, device=metric_device) re.attach(engine, "re") data = list(range(n_iters)) @@ -748,7 +748,7 @@ def update(engine, i): assert "re" in engine.state.metrics res = engine.state.metrics["re"] if isinstance(res, torch.Tensor): - assert res.device.type == "cpu" + assert res.device == metric_device res = res.cpu().numpy() true_res = recall_score( @@ -757,11 +757,15 @@ def update(engine, i): assert pytest.approx(res) == true_res + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) for _ in range(2): - _test(average=True, n_epochs=1) - _test(average=True, n_epochs=2) - _test(average=False, n_epochs=1) - _test(average=False, n_epochs=2) + for metric_device in metric_devices: + _test(average=True, n_epochs=1, metric_device=metric_device) + _test(average=True, n_epochs=2, metric_device=metric_device) + _test(average=False, n_epochs=1, metric_device=metric_device) + _test(average=False, n_epochs=2, metric_device=metric_device) def _test_distrib_integration_multilabel(device): @@ -771,7 +775,7 @@ def _test_distrib_integration_multilabel(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(average, n_epochs): + def _test(average, n_epochs, metric_device): n_iters = 60 s = 16 n_classes = 7 @@ -788,7 +792,7 @@ def update(engine, i): engine = Engine(update) - re = Recall(average=average, is_multilabel=True) + re = Recall(average=average, is_multilabel=True, device=metric_device) re.attach(engine, "re") data = list(range(n_iters)) @@ -812,9 +816,13 @@ def update(engine, i): assert pytest.approx(res) == true_res + metric_devices = ["cpu"] + if device.type != "xla": + metric_devices.append(idist.device()) for _ in range(2): - _test(average=True, n_epochs=1) - _test(average=True, n_epochs=2) + for metric_device in metric_devices: + _test(average=True, n_epochs=1, metric_device=metric_device) + _test(average=True, n_epochs=2, metric_device=metric_device) if idist.get_world_size() > 1: with pytest.warns( @@ -833,21 +841,86 @@ def update(engine, i): assert (re_compute1 == re_compute2).all() +def _test_distrib_accumulator_device(device): + # Binary accuracy on input of shape (N, 1) or (N, ) + + def _test(average, metric_device): + re = Recall(average=average, device=metric_device) + assert re._device == metric_device + # Since the shape of the accumulated amount isn't known before the first update + # call, the internal variables aren't tensors on the right device yet. + + y_reed = torch.randint(0, 2, size=(10,)) + y = torch.randint(0, 2, size=(10,)).long() + re.update((y_reed, y)) + + assert re._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._true_positives.device), re._true_positives.device, type(metric_device), metric_device + ) + assert re._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._positives.device), re._positives.device, type(metric_device), metric_device + ) + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + _test(True, metric_device=metric_device) + _test(False, metric_device=metric_device) + + +def _test_distrib_multilabel_accumulator_device(device): + # Multiclass input data of shape (N, ) and (N, C) + + def _test(average, metric_device): + re = Recall(is_multilabel=True, average=average, device=metric_device) + + assert re._device == metric_device + assert re._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._true_positives.device), re._true_positives.device, type(metric_device), metric_device + ) + assert re._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._positives.device), re._positives.device, type(metric_device), metric_device + ) + + y_reed = torch.randint(0, 2, size=(10, 4, 20, 23)) + y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() + re.update((y_reed, y)) + + assert re._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._true_positives.device), re._true_positives.device, type(metric_device), metric_device + ) + assert re._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._positives.device), re._positives.device, type(metric_device), metric_device + ) + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + _test(True, metric_device=metric_device) + _test(False, metric_device=metric_device) + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.distributed @@ -855,29 +928,35 @@ def test_distrib_cpu(distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration_multiclass, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_integration_multilabel, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_multilabel_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.tpu @@ -887,12 +966,16 @@ def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_root_mean_squared_error.py b/tests/ignite/metrics/test_root_mean_squared_error.py index 878ef9df367d..ad2d5bf71d2f 100644 --- a/tests/ignite/metrics/test_root_mean_squared_error.py +++ b/tests/ignite/metrics/test_root_mean_squared_error.py @@ -46,25 +46,30 @@ def _test_distrib_integration(device, tol=1e-6): def update(engine, i): return y_preds[i * s : (i + 1) * s], y_true[i * s + offset * rank : (i + 1) * s + offset * rank] - engine = Engine(update) + def _test(metric_device): + engine = Engine(update) - m = RootMeanSquaredError() - m.attach(engine, "rmse") + m = RootMeanSquaredError(device=metric_device) + m.attach(engine, "rmse") - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) - assert "rmse" in engine.state.metrics - res = engine.state.metrics["rmse"] + assert "rmse" in engine.state.metrics + res = engine.state.metrics["rmse"] - y_preds_full = [] - for i in range(idist.get_world_size()): - y_preds_full.append((i + 1) * torch.ones(offset)) - y_preds_full = torch.stack(y_preds_full).to(device).flatten() + y_preds_full = [] + for i in range(idist.get_world_size()): + y_preds_full.append((i + 1) * torch.ones(offset)) + y_preds_full = torch.stack(y_preds_full).to(device).flatten() - true_res = np.sqrt(np.mean(np.square((y_true - y_preds_full).cpu().numpy()))) + true_res = np.sqrt(np.mean(np.square((y_true - y_preds_full).cpu().numpy()))) - assert pytest.approx(res, rel=tol) == true_res + assert pytest.approx(res, rel=tol) == true_res + + _test("cpu") + if device.type != "xla": + _test(idist.device()) @pytest.mark.distributed @@ -72,7 +77,7 @@ def update(engine, i): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_integration(device) @@ -80,7 +85,7 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(local_rank, distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) @@ -89,7 +94,7 @@ def test_distrib_cpu(local_rank, distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) @@ -99,7 +104,7 @@ def test_distrib_hvd(gloo_hvd_executor): @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) @@ -107,7 +112,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_integration(device) diff --git a/tests/ignite/metrics/test_running_average.py b/tests/ignite/metrics/test_running_average.py index c66fdfabdc5e..047e3e19f6e5 100644 --- a/tests/ignite/metrics/test_running_average.py +++ b/tests/ignite/metrics/test_running_average.py @@ -269,7 +269,8 @@ def update_fn(engine, batch): trainer = Engine(update_fn) alpha = 0.98 - avg_output = RunningAverage(output_transform=lambda x: x, alpha=alpha, epoch_bound=False, device=device) + metric_device = idist.device() if torch.device(device).type != "xla" else "cpu" + avg_output = RunningAverage(output_transform=lambda x: x, alpha=alpha, epoch_bound=False, device=metric_device) avg_output.attach(trainer, "running_avg_output") @trainer.on(Events.STARTED) @@ -305,61 +306,88 @@ def _test_distrib_on_metric(device): batch_size = 10 n_classes = 10 - data = list(range(n_iters)) - np.random.seed(12) - all_y_true_batch_values = np.random.randint( - 0, n_classes, size=(idist.get_world_size(), n_epochs * n_iters, batch_size) - ) - all_y_pred_batch_values = np.random.rand(idist.get_world_size(), n_epochs * n_iters, batch_size, n_classes) + def _test(metric_device): + data = list(range(n_iters)) + np.random.seed(12) + all_y_true_batch_values = np.random.randint( + 0, n_classes, size=(idist.get_world_size(), n_epochs * n_iters, batch_size) + ) + all_y_pred_batch_values = np.random.rand(idist.get_world_size(), n_epochs * n_iters, batch_size, n_classes) - y_true_batch_values = iter(all_y_true_batch_values[rank, ...]) - y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...]) + y_true_batch_values = iter(all_y_true_batch_values[rank, ...]) + y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...]) - def update_fn(engine, batch): - y_true_batch = next(y_true_batch_values) - y_pred_batch = next(y_pred_batch_values) - return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) + def update_fn(engine, batch): + y_true_batch = next(y_true_batch_values) + y_pred_batch = next(y_pred_batch_values) + return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) - trainer = Engine(update_fn) - alpha = 0.98 + trainer = Engine(update_fn) + alpha = 0.98 - acc_metric = RunningAverage( - Accuracy(output_transform=lambda x: [x[0], x[1]], device=device), alpha=alpha, epoch_bound=False - ) - acc_metric.attach(trainer, "running_avg_accuracy") + acc_metric = RunningAverage( + Accuracy(output_transform=lambda x: [x[0], x[1]], device=metric_device), alpha=alpha, epoch_bound=False + ) + acc_metric.attach(trainer, "running_avg_accuracy") + + running_avg_acc = [ + None, + ] + true_acc_metric = Accuracy(device=metric_device) + + @trainer.on(Events.ITERATION_COMPLETED) + def manual_running_avg_acc(engine): + i = engine.state.iteration - 1 + + true_acc_metric.reset() + for j in range(idist.get_world_size()): + output = ( + torch.from_numpy(all_y_pred_batch_values[j, i, :, :]), + torch.from_numpy(all_y_true_batch_values[j, i, :]), + ) + true_acc_metric.update(output) + + batch_acc = true_acc_metric._num_correct.item() * 1.0 / true_acc_metric._num_examples + + if running_avg_acc[0] is None: + running_avg_acc[0] = batch_acc + else: + running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc + engine.state.running_avg_acc = running_avg_acc[0] + + @trainer.on(Events.ITERATION_COMPLETED) + def assert_equal_running_avg_acc_values(engine): + assert engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"], "{} vs {}".format( + engine.state.running_avg_acc, engine.state.metrics["running_avg_accuracy"] + ) - running_avg_acc = [ - None, - ] - true_acc_metric = Accuracy(device=device) + trainer.run(data, max_epochs=3) - @trainer.on(Events.ITERATION_COMPLETED) - def manual_running_avg_acc(engine): - i = engine.state.iteration - 1 + _test("cpu") + if device.type != "xla": + _test(idist.device()) - true_acc_metric.reset() - for j in range(idist.get_world_size()): - output = ( - torch.from_numpy(all_y_pred_batch_values[j, i, :, :]), - torch.from_numpy(all_y_true_batch_values[j, i, :]), - ) - true_acc_metric.update(output) - batch_acc = true_acc_metric._num_correct * 1.0 / true_acc_metric._num_examples +def _test_distrib_accumulator_device(device): - if running_avg_acc[0] is None: - running_avg_acc[0] = batch_acc - else: - running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc - engine.state.running_avg_acc = running_avg_acc[0] + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: - @trainer.on(Events.ITERATION_COMPLETED) - def assert_equal_running_avg_acc_values(engine): - assert engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"], "{} vs {}".format( - engine.state.running_avg_acc, engine.state.metrics["running_avg_accuracy"] - ) + # Don't test the src=Metric case because compute() returns a scalar, + # so the metric doesn't accumulate on the device specified + avg = RunningAverage(output_transform=lambda x: x, device=metric_device) + assert avg._device == metric_device + # Value is None until the first update then compute call - trainer.run(data, max_epochs=3) + for _ in range(3): + avg.update(torch.tensor(1.0, device=device)) + avg.compute() + + assert avg._value.device == metric_device, "{}:{} vs {}:{}".format( + type(avg._value.device), avg._value.device, type(metric_device), metric_device + ) @pytest.mark.distributed @@ -367,18 +395,20 @@ def assert_equal_running_avg_acc_values(engine): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -386,29 +416,32 @@ def test_distrib_cpu(distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_on_output, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_on_metric, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -418,12 +451,14 @@ def test_distrib_single_device_xla(): device = idist.device() _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_top_k_categorical_accuracy.py b/tests/ignite/metrics/test_top_k_categorical_accuracy.py index 6caf39a08f71..af6d003800f3 100644 --- a/tests/ignite/metrics/test_top_k_categorical_accuracy.py +++ b/tests/ignite/metrics/test_top_k_categorical_accuracy.py @@ -59,7 +59,7 @@ def _test_distrib_integration(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(n_epochs): + def _test(n_epochs, metric_device): n_iters = 100 s = 16 n_classes = 10 @@ -79,7 +79,7 @@ def update(engine, i): engine = Engine(update) k = 5 - acc = TopKCategoricalAccuracy(k=k, device=device) + acc = TopKCategoricalAccuracy(k=k, device=metric_device) acc.attach(engine, "acc") data = list(range(n_iters)) @@ -94,24 +94,52 @@ def update(engine, i): assert pytest.approx(res) == true_res - for _ in range(5): - _test(n_epochs=1) - _test(n_epochs=2) + metric_devices = ["cpu"] + if device.type != "xla": + metric_devices.append(idist.device()) + for _ in range(3): + for metric_device in metric_devices: + _test(n_epochs=1, metric_device=metric_device) + _test(n_epochs=2, metric_device=metric_device) + + +def _test_distrib_accumulator_device(device): + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + + acc = TopKCategoricalAccuracy(2, device=metric_device) + assert acc._device == metric_device + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) + + y_pred = torch.tensor([[0.2, 0.4, 0.6, 0.8], [0.8, 0.6, 0.4, 0.2]]) + y = torch.ones(2).long() + acc.update((y_pred, y)) + + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): - device = "cuda:{}".format(local_rank) + device = torch.device("cuda:{}".format(local_rank)) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_cpu(local_rank, distributed_context_single_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -119,26 +147,29 @@ def test_distrib_cpu(local_rank, distributed_context_single_node_gloo): @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") def test_distrib_hvd(gloo_hvd_executor): - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): - device = "cpu" + device = torch.device("cpu") _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): - device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -147,11 +178,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu