diff --git a/ignite/metrics/accumulation.py b/ignite/metrics/accumulation.py
index fabe04f4bbfc..926e7816bae2 100644
--- a/ignite/metrics/accumulation.py
+++ b/ignite/metrics/accumulation.py
@@ -66,9 +66,9 @@ def _check_output_type(self, output: Union[Any, torch.Tensor, numbers.Number]) -
     def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
         self._check_output_type(output)
 
-        if self._device is not None:
-            # Put output to the metric's device
-            if isinstance(output, torch.Tensor) and (output.device != self._device):
+        if isinstance(output, torch.Tensor):
+            output = output.detach()
+            if output.device != self._device:
                 output = output.to(self._device)
 
         self.accumulator = self._op(self.accumulator, output)
diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py
index fe53db66f828..7d6c939e4b53 100644
--- a/ignite/metrics/accuracy.py
+++ b/ignite/metrics/accuracy.py
@@ -146,9 +146,9 @@ def reset(self) -> None:
 
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
-        self._check_shape((y_pred, y))
-        self._check_type((y_pred, y))
+        self._check_shape(output)
+        self._check_type(output)
+        y_pred, y = output[0].detach(), output[1].detach()
 
         if self._type == "binary":
             correct = torch.eq(y_pred.view(-1).to(y), y.view(-1))
@@ -163,7 +163,6 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
             y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
             correct = torch.all(y == y_pred.type_as(y), dim=-1)
 
-        # Don't need to detach here because torch.eq is not differentiable, so the computation graph is detached anyway.
         self._num_correct += torch.sum(correct).to(self._device)
         self._num_examples += correct.shape[0]
 
diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py
index 3d179a458d62..3c797efaf3e4 100644
--- a/ignite/metrics/confusion_matrix.py
+++ b/ignite/metrics/confusion_matrix.py
@@ -63,7 +63,7 @@ def reset(self) -> None:
         self._num_examples = 0
 
     def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
 
         if y_pred.ndimension() < 2:
             raise ValueError(
@@ -94,7 +94,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
         self._check_shape(output)
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
 
         self._num_examples += y_pred.shape[0]
 
diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py
index 055f11c9c391..f1d417ee015f 100644
--- a/ignite/metrics/precision.py
+++ b/ignite/metrics/precision.py
@@ -141,9 +141,9 @@ def __init__(
 
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
         self._check_shape(output)
-        self._check_type((y_pred, y))
+        self._check_type(output)
+        y_pred, y = output[0].detach(), output[1].detach()
 
         if self._type == "binary":
             y_pred = y_pred.view(-1)
diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py
index a094557249c5..ad391705a004 100644
--- a/ignite/metrics/recall.py
+++ b/ignite/metrics/recall.py
@@ -79,9 +79,9 @@ def __init__(
 
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
         self._check_shape(output)
-        self._check_type((y_pred, y))
+        self._check_type(output)
+        y_pred, y = output[0].detach(), output[1].detach()
 
         if self._type == "binary":
             y_pred = y_pred.view(-1)
diff --git a/ignite/metrics/ssim.py b/ignite/metrics/ssim.py
index 91491432db37..7c20a3260cf7 100644
--- a/ignite/metrics/ssim.py
+++ b/ignite/metrics/ssim.py
@@ -116,7 +116,8 @@ def _gaussian_or_uniform_kernel(self, kernel_size, sigma):
 
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
+
         if y_pred.dtype != y.dtype:
             raise TypeError(
                 "Expected y_pred and y to have the same data type. Got y_pred: {} and y: {}.".format(
diff --git a/ignite/metrics/top_k_categorical_accuracy.py b/ignite/metrics/top_k_categorical_accuracy.py
index aa64f5b45319..dad423f86af1 100644
--- a/ignite/metrics/top_k_categorical_accuracy.py
+++ b/ignite/metrics/top_k_categorical_accuracy.py
@@ -28,12 +28,11 @@ def reset(self) -> None:
 
     @reinit__is_reduced
     def update(self, output: Sequence) -> None:
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
 
         sorted_indices = torch.topk(y_pred, self._k, dim=1)[1]
         expanded_y = y.view(-1, 1).expand(-1, self._k)
         correct = torch.sum(torch.eq(sorted_indices, expanded_y), dim=1)
-        # Don't need to detach here because torch.eq is not differentiable, so the computation graph is detached anyway.
         self._num_correct += torch.sum(correct).to(self._device)
         self._num_examples += correct.shape[0]
 
diff --git a/tests/ignite/metrics/test_accumulation.py b/tests/ignite/metrics/test_accumulation.py
index 949819fdd2ee..1e47ab5b1ade 100644
--- a/tests/ignite/metrics/test_accumulation.py
+++ b/tests/ignite/metrics/test_accumulation.py
@@ -232,7 +232,7 @@ def _test(metric_device):
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
-            _test(device)
+            _test(idist.device())
 
 
 def _test_distrib_average(device):
@@ -271,7 +271,7 @@ def _test(metric_device):
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
-            _test(device)
+            _test(idist.device())
 
 
 def _test_distrib_geom_average(device):
@@ -310,7 +310,7 @@ def _test(metric_device):
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
-            _test(device)
+            _test(idist.device())
 
 
 def _test_distrib_integration(device):
@@ -362,7 +362,7 @@ def _geom_mean(y_true):
 
     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(Average, _mean, metric_device)
         _test(GeometricAverage, _geom_mean, metric_device, tol=1e-4)
@@ -372,7 +372,7 @@ def _test_distrib_accumulator_device(device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
 
         m = VariableAccumulation(lambda a, x: x, device=metric_device)
diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py
index 6ca82e8cca72..3960a09ec7f7 100644
--- a/tests/ignite/metrics/test_accuracy.py
+++ b/tests/ignite/metrics/test_accuracy.py
@@ -702,7 +702,7 @@ def _test(metric_device):
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
-            _test(device)
+            _test(idist.device())
 
 
 def _test_distrib_integration_multiclass(device):
@@ -751,7 +751,7 @@ def update(engine, i):
 
     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         for _ in range(2):
             _test(n_epochs=1, metric_device=metric_device)
@@ -804,7 +804,7 @@ def update(engine, i):
 
     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         for _ in range(2):
             _test(n_epochs=1, metric_device=metric_device)
@@ -815,7 +815,7 @@ def _test_distrib_accumulator_device(device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
 
         acc = Accuracy(device=metric_device)
diff --git a/tests/ignite/metrics/test_confusion_matrix.py b/tests/ignite/metrics/test_confusion_matrix.py
index 1b5ba4ba67c4..8da41e07a411 100644
--- a/tests/ignite/metrics/test_confusion_matrix.py
+++ b/tests/ignite/metrics/test_confusion_matrix.py
@@ -619,7 +619,7 @@ def _test_distrib_accumulator_device(device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
 
         cm = ConfusionMatrix(num_classes=3, device=metric_device)
diff --git a/tests/ignite/metrics/test_fbeta.py b/tests/ignite/metrics/test_fbeta.py
index a51df4b3583f..da4a8df58c23 100644
--- a/tests/ignite/metrics/test_fbeta.py
+++ b/tests/ignite/metrics/test_fbeta.py
@@ -133,7 +133,7 @@ def update(engine, i):
 
     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(None, None, average=True, n_epochs=1, metric_device=metric_device)
         _test(None, None, average=True, n_epochs=2, metric_device=metric_device)
diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py
index 8c2014431d5c..35638684916d 100644
--- a/tests/ignite/metrics/test_loss.py
+++ b/tests/ignite/metrics/test_loss.py
@@ -108,14 +108,14 @@ def _test(metric_device):
 
     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())
 
 
 def _test_distrib_accumulator_device(device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         loss = Loss(nll_loss, device=metric_device)
         assert loss._device == metric_device
diff --git a/tests/ignite/metrics/test_mean_absolute_error.py b/tests/ignite/metrics/test_mean_absolute_error.py
index 946a540f7f75..c47a2c281b7f 100644
--- a/tests/ignite/metrics/test_mean_absolute_error.py
+++ b/tests/ignite/metrics/test_mean_absolute_error.py
@@ -67,14 +67,14 @@ def _test(metric_device):
 
     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())
 
 
 def _test_distrib_accumulator_device(device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         mae = MeanAbsoluteError(device=metric_device)
         assert mae._device == metric_device
diff --git a/tests/ignite/metrics/test_mean_pairwise_distance.py b/tests/ignite/metrics/test_mean_pairwise_distance.py
index c85a84674072..b0db1917c1d2 100644
--- a/tests/ignite/metrics/test_mean_pairwise_distance.py
+++ b/tests/ignite/metrics/test_mean_pairwise_distance.py
@@ -80,14 +80,14 @@ def _test(metric_device):
 
     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())
 
 
 def _test_distrib_accumulator_device(device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
 
         mpd = MeanPairwiseDistance(device=metric_device)
diff --git a/tests/ignite/metrics/test_mean_squared_error.py b/tests/ignite/metrics/test_mean_squared_error.py
index eea8a7f9428c..91a87b561974 100644
--- a/tests/ignite/metrics/test_mean_squared_error.py
+++ b/tests/ignite/metrics/test_mean_squared_error.py
@@ -67,14 +67,14 @@ def _test(metric_device):
 
     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())
 
 
 def _test_distrib_accumulator_device(device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
 
         device = torch.device(device)
diff --git a/tests/ignite/metrics/test_metrics_lambda.py b/tests/ignite/metrics/test_metrics_lambda.py
index b53a53ae46cd..c8e77736370c 100644
--- a/tests/ignite/metrics/test_metrics_lambda.py
+++ b/tests/ignite/metrics/test_metrics_lambda.py
@@ -369,7 +369,7 @@ def Fbeta(r, p, beta):
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
-            _test(device)
+            _test(idist.device())
 
 
 def _test_distrib_metrics_on_diff_devices(device):
diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py
index ef33208c0882..493d3902b3e9 100644
--- a/tests/ignite/metrics/test_precision.py
+++ b/tests/ignite/metrics/test_precision.py
@@ -759,7 +759,7 @@ def update(engine, i):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for _ in range(2):
         for metric_device in metric_devices:
             _test(average=True, n_epochs=1, metric_device=metric_device)
@@ -818,7 +818,7 @@ def update(engine, i):
 
     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for _ in range(2):
         for metric_device in metric_devices:
             _test(average=True, n_epochs=1, metric_device=metric_device)
@@ -863,7 +863,7 @@ def _test(average, metric_device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(True, metric_device=metric_device)
         _test(False, metric_device=metric_device)
@@ -896,7 +896,7 @@ def _test(average, metric_device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(True, metric_device=metric_device)
         _test(False, metric_device=metric_device)
diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py
index 4ee7f8298d10..0f3e4eaf497c 100644
--- a/tests/ignite/metrics/test_recall.py
+++ b/tests/ignite/metrics/test_recall.py
@@ -759,7 +759,7 @@ def update(engine, i):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for _ in range(2):
         for metric_device in metric_devices:
             _test(average=True, n_epochs=1, metric_device=metric_device)
@@ -818,7 +818,7 @@ def update(engine, i):
 
     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for _ in range(2):
         for metric_device in metric_devices:
             _test(average=True, n_epochs=1, metric_device=metric_device)
@@ -863,7 +863,7 @@ def _test(average, metric_device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(True, metric_device=metric_device)
         _test(False, metric_device=metric_device)
@@ -896,7 +896,7 @@ def _test(average, metric_device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(True, metric_device=metric_device)
         _test(False, metric_device=metric_device)
diff --git a/tests/ignite/metrics/test_root_mean_squared_error.py b/tests/ignite/metrics/test_root_mean_squared_error.py
index 9cc3d84f1b32..ad2d5bf71d2f 100644
--- a/tests/ignite/metrics/test_root_mean_squared_error.py
+++ b/tests/ignite/metrics/test_root_mean_squared_error.py
@@ -69,7 +69,7 @@ def _test(metric_device):
 
     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())
 
 
 @pytest.mark.distributed
diff --git a/tests/ignite/metrics/test_running_average.py b/tests/ignite/metrics/test_running_average.py
index 21416814c017..047e3e19f6e5 100644
--- a/tests/ignite/metrics/test_running_average.py
+++ b/tests/ignite/metrics/test_running_average.py
@@ -269,7 +269,7 @@ def update_fn(engine, batch):
 
     trainer = Engine(update_fn)
     alpha = 0.98
-    metric_device = device if torch.device(device).type != "xla" else "cpu"
+    metric_device = idist.device() if torch.device(device).type != "xla" else "cpu"
     avg_output = RunningAverage(output_transform=lambda x: x, alpha=alpha, epoch_bound=False, device=metric_device)
     avg_output.attach(trainer, "running_avg_output")
 
@@ -365,14 +365,14 @@ def assert_equal_running_avg_acc_values(engine):
 
     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())
 
 
 def _test_distrib_accumulator_device(device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
 
         # Don't test the src=Metric case because compute() returns a scalar,
diff --git a/tests/ignite/metrics/test_top_k_categorical_accuracy.py b/tests/ignite/metrics/test_top_k_categorical_accuracy.py
index 6f4c68327271..af6d003800f3 100644
--- a/tests/ignite/metrics/test_top_k_categorical_accuracy.py
+++ b/tests/ignite/metrics/test_top_k_categorical_accuracy.py
@@ -96,7 +96,7 @@ def update(engine, i):
 
     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for _ in range(3):
         for metric_device in metric_devices:
             _test(n_epochs=1, metric_device=metric_device)
@@ -107,7 +107,7 @@ def _test_distrib_accumulator_device(device):
 
     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
 
         acc = TopKCategoricalAccuracy(2, device=metric_device)
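
The library-side hunks above all apply one pattern: `update()` now calls `.detach()` on incoming tensors before using them, so a metric fed with outputs that still require grad no longer keeps the training step's autograd graph alive between iterations. The snippet below is a minimal illustrative sketch of that pattern, not ignite code; `RunningSum` is a hypothetical stand-in for a `VariableAccumulation`-style accumulator.

```python
import torch
import torch.nn as nn


class RunningSum:
    """Toy accumulator mirroring the detach-before-accumulate pattern above."""

    def __init__(self, device: torch.device = torch.device("cpu")):
        self._device = device
        self.accumulator = torch.tensor(0.0, device=device)

    def update(self, output) -> None:
        if isinstance(output, torch.Tensor):
            output = output.detach()  # drop the output's computation graph
            if output.device != self._device:
                output = output.to(self._device)
        self.accumulator = self.accumulator + output


model = nn.Linear(4, 1)
metric = RunningSum()
for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    metric.update(loss)  # the metric stores a plain tensor, not the graph
    loss.backward()

assert not metric.accumulator.requires_grad
```

Without the `detach()`, the accumulator would become a node in every iteration's graph and keep the corresponding intermediate tensors in memory until the metric is reset.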
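The test-side hunks swap the raw `device` argument for `idist.device()` when building the list of metric devices to exercise, so on non-XLA distributed backends each worker tests the device it actually owns (for example `cuda:<local rank>` under NCCL) rather than an index-less device string. A rough sketch of the repeated pattern follows; `pick_metric_devices` is a hypothetical helper name, not part of the patch.

```python
import torch
import ignite.distributed as idist


def pick_metric_devices(device: torch.device):
    # Always exercise the CPU accumulator; on non-XLA backends also exercise
    # the per-process device reported by idist.device().
    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())
    return metric_devices
```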