From 9f7daa1041afedcd5b8345cc47dd3f9554d912bb Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Fri, 7 Aug 2020 16:52:59 -0400 Subject: [PATCH 01/14] update accuracy to accumulate _num_correct in a tensor on the right device --- ignite/metrics/accuracy.py | 6 +++--- tests/ignite/metrics/test_accuracy.py | 19 +++++++++++++++++++ tests/ignite/metrics/test_running_average.py | 2 +- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index 8ac5a25f083f..c41053ef5b99 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -138,7 +138,7 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._num_correct = 0 + self._num_correct = torch.tensor(0, device=self._device) self._num_examples = 0 super(Accuracy, self).reset() @@ -161,11 +161,11 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes) correct = torch.all(y == y_pred.type_as(y), dim=-1) - self._num_correct += torch.sum(correct).item() + self._num_correct += torch.sum(correct).to(self._device) self._num_examples += correct.shape[0] @sync_all_reduce("_num_examples", "_num_correct") def compute(self) -> torch.Tensor: if self._num_examples == 0: raise NotComputableError("Accuracy must have at least one example before it can be computed.") - return self._num_correct / self._num_examples + return self._num_correct.item() / self._num_examples diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py index 3ca32257d7bf..bd1a8901c848 100644 --- a/tests/ignite/metrics/test_accuracy.py +++ b/tests/ignite/metrics/test_accuracy.py @@ -778,6 +778,18 @@ def update(engine, i): _test(n_epochs=2) +def _test_distrib_accumulator_device(device): + device = torch.device(device) + acc = Accuracy(device=device) + assert acc._device == device + + y_pred = torch.randint(0, 2, size=(10,)).long() + y = torch.randint(0, 2, 
size=(10,)).long() + acc.update((y_pred, y)) + + assert acc._num_correct.device == device + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") @@ -786,6 +798,7 @@ def test_distrib_gpu(distributed_context_single_node_nccl): _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -796,6 +809,7 @@ def test_distrib_cpu(distributed_context_single_node_gloo): _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -809,6 +823,7 @@ def test_distrib_hvd(gloo_hvd_executor): gloo_hvd_executor(_test_distrib_multilabel_input_NHW, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_integration_multiclass, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_integration_multilabel, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @@ -819,6 +834,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @@ -829,6 +845,7 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -839,6 +856,7 @@ def test_distrib_single_device_xla(): _test_distrib_multilabel_input_NHW(device) 
_test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): @@ -846,6 +864,7 @@ def _test_distrib_xla_nprocs(index): _test_distrib_multilabel_input_NHW(device) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_running_average.py b/tests/ignite/metrics/test_running_average.py index c66fdfabdc5e..ab1528e10596 100644 --- a/tests/ignite/metrics/test_running_average.py +++ b/tests/ignite/metrics/test_running_average.py @@ -345,7 +345,7 @@ def manual_running_avg_acc(engine): ) true_acc_metric.update(output) - batch_acc = true_acc_metric._num_correct * 1.0 / true_acc_metric._num_examples + batch_acc = true_acc_metric._num_correct.item() * 1.0 / true_acc_metric._num_examples if running_avg_acc[0] is None: running_avg_acc[0] = batch_acc From a87f93de6594a66721d1c454070eedb2eff80d6d Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Fri, 7 Aug 2020 17:14:21 -0400 Subject: [PATCH 02/14] update loss metric to accumulate _sum in a tensor on the right device --- ignite/metrics/loss.py | 6 +++--- tests/ignite/metrics/test_loss.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index 5a4133c84d95..f44ce3f6f193 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -45,7 +45,7 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._sum = 0 + self._sum = torch.tensor(0.0, device=self._device) self._num_examples = 0 @reinit__is_reduced @@ -61,11 +61,11 @@ def update(self, output: Sequence[Union[torch.Tensor, dict]]) -> None: raise ValueError("loss_fn did not return the average loss.") n = self._batch_size(y) - self._sum += average_loss.item() * n + self._sum += average_loss.detach().to(self._device) * n 
self._num_examples += n @sync_all_reduce("_sum", "_num_examples") def compute(self) -> None: if self._num_examples == 0: raise NotComputableError("Loss must have at least one example before it can be computed.") - return self._sum / self._num_examples + return self._sum.item() / self._num_examples diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 5244e7991d7d..4739bca16d1f 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -107,6 +107,28 @@ def _test_distrib_compute_on_criterion(device): assert_almost_equal(res, true_loss_value.item()) +def _test_distrib_sum_device(device): + device = torch.device(device) + loss = Loss(nll_loss, device=device) + assert loss._device == device + + y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]]).log() + y = torch.tensor([2, 2]).long() + loss.update((y_pred, y)) + + assert loss._sum.device == device + + +def test_sum_detached(): + loss = Loss(nll_loss) + + y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], requires_grad=True).log() + y = torch.tensor([2, 2]).long() + loss.update((y_pred, y)) + + assert not loss._sum.requires_grad + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") @@ -114,6 +136,7 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): device = "cuda:{}".format(local_rank) _test_distrib_compute_on_criterion(device) + _test_distrib_sum_device(device) @pytest.mark.distributed @@ -122,6 +145,7 @@ def test_distrib_cpu(distributed_context_single_node_gloo): device = "cpu" _test_distrib_compute_on_criterion(device) + _test_distrib_sum_device(device) @pytest.mark.distributed @@ -133,6 +157,7 @@ def test_distrib_hvd(gloo_hvd_executor): nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_compute_on_criterion, 
(device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_sum_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @@ -141,6 +166,7 @@ def test_distrib_hvd(gloo_hvd_executor): def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): device = "cpu" _test_distrib_compute_on_criterion(device) + _test_distrib_sum_device(device) @pytest.mark.multinode_distributed @@ -149,6 +175,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib_compute_on_criterion(device) + _test_distrib_sum_device(device) @pytest.mark.tpu @@ -157,6 +184,7 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_compute_on_criterion(device) + _test_distrib_sum_device(device) def _test_distrib_xla_nprocs(index): From 30b2e19e400aca12cc0b8e07dc6bc72f50df5ab1 Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Fri, 7 Aug 2020 17:25:27 -0400 Subject: [PATCH 03/14] update mae metric to accumulate in a tensor on the right device --- ignite/metrics/mean_absolute_error.py | 6 ++-- .../metrics/test_mean_absolute_error.py | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/mean_absolute_error.py b/ignite/metrics/mean_absolute_error.py index 86e699be096f..c3c9a9a9ed2f 100644 --- a/ignite/metrics/mean_absolute_error.py +++ b/ignite/metrics/mean_absolute_error.py @@ -17,18 +17,18 @@ class MeanAbsoluteError(Metric): @reinit__is_reduced def reset(self) -> None: - self._sum_of_absolute_errors = 0.0 + self._sum_of_absolute_errors = torch.tensor(0.0, device=self._device) self._num_examples = 0 @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output absolute_errors = torch.abs(y_pred - y.view_as(y_pred)) - 
self._sum_of_absolute_errors += torch.sum(absolute_errors).item() + self._sum_of_absolute_errors += torch.sum(absolute_errors).detach().to(self._device) self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_absolute_errors", "_num_examples") def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.") - return self._sum_of_absolute_errors / self._num_examples + return self._sum_of_absolute_errors.item() / self._num_examples diff --git a/tests/ignite/metrics/test_mean_absolute_error.py b/tests/ignite/metrics/test_mean_absolute_error.py index 8279470dd07e..557c4182f3b4 100644 --- a/tests/ignite/metrics/test_mean_absolute_error.py +++ b/tests/ignite/metrics/test_mean_absolute_error.py @@ -65,12 +65,34 @@ def update(engine, i): assert pytest.approx(res) == true_res +def _test_distrib_accumulator_device(device): + device = torch.device(device) + mae = MeanAbsoluteError(device=device) + assert mae._device == device + + y_pred = torch.tensor([[2.0], [-2.0]]) + y = torch.zeros(2) + mae.update((y_pred, y)) + assert mae._sum_of_absolute_errors.device == device + + +def test_accumulator_detached(): + mae = MeanAbsoluteError() + + y_pred = torch.tensor([[2.0], [-2.0]], requires_grad=True) + y = torch.zeros(2) + mae.update((y_pred, y)) + + assert not mae._sum_of_absolute_errors.requires_grad + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): device = "cuda:{}".format(local_rank) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -78,6 +100,7 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): def test_distrib_cpu(distributed_context_single_node_gloo): device = 
"cpu" _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -89,6 +112,7 @@ def test_distrib_hvd(gloo_hvd_executor): nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @@ -97,6 +121,7 @@ def test_distrib_hvd(gloo_hvd_executor): def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): device = "cpu" _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @@ -105,6 +130,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -113,11 +139,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu From a3e237c42ecd1860ea31d6f539c9678cdbb36267 Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Fri, 7 Aug 2020 17:37:40 -0400 Subject: [PATCH 04/14] update mpd metric to accumulate in a tensor on the right device --- ignite/metrics/mean_pairwise_distance.py | 6 ++-- .../metrics/test_mean_pairwise_distance.py | 29 +++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/mean_pairwise_distance.py b/ignite/metrics/mean_pairwise_distance.py index 9e9239ee6553..b1f1c9a52c26 100644 --- a/ignite/metrics/mean_pairwise_distance.py +++ 
b/ignite/metrics/mean_pairwise_distance.py @@ -29,18 +29,18 @@ def __init__( @reinit__is_reduced def reset(self): - self._sum_of_distances = 0.0 + self._sum_of_distances = torch.tensor(0.0, device=self._device) self._num_examples = 0 @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output distances = pairwise_distance(y_pred, y, p=self._p, eps=self._eps) - self._sum_of_distances += torch.sum(distances).item() + self._sum_of_distances += torch.sum(distances).detach().to(self._device) self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_distances", "_num_examples") def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError("MeanAbsoluteError must have at least one example before it can be computed.") - return self._sum_of_distances / self._num_examples + return self._sum_of_distances.item() / self._num_examples diff --git a/tests/ignite/metrics/test_mean_pairwise_distance.py b/tests/ignite/metrics/test_mean_pairwise_distance.py index 7ada0b5474a5..f5ea4e731ef7 100644 --- a/tests/ignite/metrics/test_mean_pairwise_distance.py +++ b/tests/ignite/metrics/test_mean_pairwise_distance.py @@ -78,12 +78,35 @@ def update(engine, i): assert pytest.approx(res) == true_res +def _test_distrib_accumulator_device(device): + device = torch.device(device) + mpd = MeanPairwiseDistance(device=device) + assert mpd._device == device + + y_pred = torch.Tensor([[3.0, 4.0], [-3.0, -4.0]]) + y = torch.zeros(2, 2) + mpd.update((y_pred, y)) + + assert mpd._sum_of_distances.device == device + + +def test_accumulator_detached(): + mpd = MeanPairwiseDistance() + + y_pred = torch.tensor([[3.0, 4.0], [-3.0, -4.0]], requires_grad=True) + y = torch.zeros(2, 2) + mpd.update((y_pred, y)) + + assert not mpd._sum_of_distances.requires_grad + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, 
reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): device = "cuda:{}".format(local_rank) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -91,6 +114,7 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): def test_distrib_cpu(distributed_context_single_node_gloo): device = "cpu" _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -102,6 +126,7 @@ def test_distrib_hvd(gloo_hvd_executor): nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @@ -110,6 +135,7 @@ def test_distrib_hvd(gloo_hvd_executor): def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): device = "cpu" _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @@ -118,6 +144,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -126,11 +153,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu From 71001760f91c5cb390ded84216461402e2f48557 Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Fri, 7 Aug 2020 17:41:45 -0400 Subject: [PATCH 05/14] update mse metric to 
accumulate in a tensor on the right device --- ignite/metrics/mean_squared_error.py | 6 ++-- .../ignite/metrics/test_mean_squared_error.py | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/mean_squared_error.py b/ignite/metrics/mean_squared_error.py index 4c5a9ee3371c..29d1dc07639c 100644 --- a/ignite/metrics/mean_squared_error.py +++ b/ignite/metrics/mean_squared_error.py @@ -17,18 +17,18 @@ class MeanSquaredError(Metric): @reinit__is_reduced def reset(self) -> None: - self._sum_of_squared_errors = 0.0 + self._sum_of_squared_errors = torch.tensor(0.0, device=self._device) self._num_examples = 0 @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output squared_errors = torch.pow(y_pred - y.view_as(y_pred), 2) - self._sum_of_squared_errors += torch.sum(squared_errors).item() + self._sum_of_squared_errors += torch.sum(squared_errors).detach().to(self._device) self._num_examples += y.shape[0] @sync_all_reduce("_sum_of_squared_errors", "_num_examples") def compute(self) -> Union[float, torch.Tensor]: if self._num_examples == 0: raise NotComputableError("MeanSquaredError must have at least one example before it can be computed.") - return self._sum_of_squared_errors / self._num_examples + return self._sum_of_squared_errors.item() / self._num_examples diff --git a/tests/ignite/metrics/test_mean_squared_error.py b/tests/ignite/metrics/test_mean_squared_error.py index 59ce1fdc567d..08552836531f 100644 --- a/tests/ignite/metrics/test_mean_squared_error.py +++ b/tests/ignite/metrics/test_mean_squared_error.py @@ -65,6 +65,27 @@ def update(engine, i): assert pytest.approx(res, rel=tol) == true_res +def _test_distrib_accumulator_device(device): + device = torch.device(device) + mse = MeanSquaredError(device=device) + assert mse._device == device + + y_pred = torch.tensor([[2.0], [-2.0]]) + y = torch.zeros(2) + mse.update((y_pred, y)) + assert mse._sum_of_squared_errors.device == device 
+ + +def test_accumulator_detached(): + mse = MeanSquaredError() + + y_pred = torch.tensor([[2.0], [-2.0]], requires_grad=True) + y = torch.zeros(2) + mse.update((y_pred, y)) + + assert not mse._sum_of_squared_errors.requires_grad + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") @@ -72,6 +93,7 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): device = "cuda:{}".format(local_rank) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -79,6 +101,7 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): def test_distrib_cpu(distributed_context_single_node_gloo): device = "cpu" _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -90,6 +113,7 @@ def test_distrib_hvd(gloo_hvd_executor): nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @@ -98,6 +122,7 @@ def test_distrib_hvd(gloo_hvd_executor): def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): device = "cpu" _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @@ -106,6 +131,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -114,11 +140,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = 
idist.device() _test_distrib_integration(device, tol=1e-4) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration(device, tol=1e-4) + _test_distrib_accumulator_device(device) @pytest.mark.tpu From 3228a0abc270af213c1fbd7a60189d2fefa7a54a Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Fri, 7 Aug 2020 17:53:19 -0400 Subject: [PATCH 06/14] update top k accuracy metric to accumulate in a tensor on the right device --- ignite/metrics/top_k_categorical_accuracy.py | 6 +++--- .../metrics/test_top_k_categorical_accuracy.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/top_k_categorical_accuracy.py b/ignite/metrics/top_k_categorical_accuracy.py index 3fb493ed8441..6d33d51001c5 100644 --- a/ignite/metrics/top_k_categorical_accuracy.py +++ b/ignite/metrics/top_k_categorical_accuracy.py @@ -23,7 +23,7 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._num_correct = 0 + self._num_correct = torch.tensor(0, device=self._device) self._num_examples = 0 @reinit__is_reduced @@ -32,7 +32,7 @@ def update(self, output: Sequence) -> None: sorted_indices = torch.topk(y_pred, self._k, dim=1)[1] expanded_y = y.view(-1, 1).expand(-1, self._k) correct = torch.sum(torch.eq(sorted_indices, expanded_y), dim=1) - self._num_correct += torch.sum(correct).item() + self._num_correct += torch.sum(correct).to(self._device) self._num_examples += correct.shape[0] @sync_all_reduce("_num_correct", "_num_examples") @@ -41,4 +41,4 @@ def compute(self) -> Union[float, torch.Tensor]: raise NotComputableError( "TopKCategoricalAccuracy must have at" "least one example before it can be computed." 
) - return self._num_correct / self._num_examples + return self._num_correct.item() / self._num_examples diff --git a/tests/ignite/metrics/test_top_k_categorical_accuracy.py b/tests/ignite/metrics/test_top_k_categorical_accuracy.py index 6caf39a08f71..0f22d2f6b697 100644 --- a/tests/ignite/metrics/test_top_k_categorical_accuracy.py +++ b/tests/ignite/metrics/test_top_k_categorical_accuracy.py @@ -99,12 +99,24 @@ def update(engine, i): _test(n_epochs=2) +def _test_distrib_accumulator_device(device): + device = torch.device(device) + acc = TopKCategoricalAccuracy(2, device=device) + assert acc._device == device + + y_pred = torch.tensor([[0.2, 0.4, 0.6, 0.8], [0.8, 0.6, 0.4, 0.2]]) + y = torch.ones(2).long() + acc.update((y_pred, y)) + assert acc._num_correct.device == device + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): device = "cuda:{}".format(local_rank) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -112,6 +124,7 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): def test_distrib_cpu(local_rank, distributed_context_single_node_gloo): device = "cpu" _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -123,6 +136,7 @@ def test_distrib_hvd(gloo_hvd_executor): nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @@ -131,6 +145,7 @@ def test_distrib_hvd(gloo_hvd_executor): def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): device = "cpu" _test_distrib_integration(device) + 
_test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @@ -139,6 +154,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -147,11 +163,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu From 412551edc5161e64632911225e4ace06164ddc85 Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Fri, 7 Aug 2020 23:08:21 -0400 Subject: [PATCH 07/14] update precision and recall metrics to accumulate in tensors on the right device --- ignite/metrics/precision.py | 28 +++++++++------ ignite/metrics/recall.py | 12 +++---- tests/ignite/metrics/test_precision.py | 50 ++++++++++++++++++++++++++ tests/ignite/metrics/test_recall.py | 50 ++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 18 deletions(-) diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index 2b8152630ddb..aa84cff2f453 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -18,7 +18,7 @@ def __init__( output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): if idist.get_world_size() > 1: if (not average) and is_multilabel: @@ -39,13 +39,20 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - dtype = torch.float64 - self._true_positives = torch.tensor([], dtype=dtype) if (self._is_multilabel and not 
self._average) else 0 - self._positives = torch.tensor([], dtype=dtype) if (self._is_multilabel and not self._average) else 0 + if self._is_multilabel: + init_value = 0.0 if self._average else [] + kws = {'dtype': torch.float64, 'device': self._device} + self._true_positives = torch.tensor(init_value, **kws) + self._positives = torch.tensor(init_value, **kws) + else: + self._true_positives = 0 + self._positives = 0 + super(_BasePrecisionRecall, self).reset() def compute(self) -> Union[torch.Tensor, float]: - if not (isinstance(self._positives, torch.Tensor) or self._positives > 0): + is_scalar = not isinstance(self._positives, torch.Tensor) or self._positives.ndim == 0 + if is_scalar and self._positives == 0: raise NotComputableError( "{} must have at least one example before" " it can be computed.".format(self.__class__.__name__) ) @@ -124,7 +131,7 @@ def __init__( output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): super(Precision, self).__init__( output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device @@ -155,17 +162,16 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) y = torch.transpose(y, 1, 0).reshape(num_classes, -1) - y = y.to(y_pred) + # Convert from int cuda/cpu to double on self._device + y_pred = y_pred.to(dtype=torch.float64, device=self._device) + y = y.to(dtype=torch.float64, device=self._device) correct = y * y_pred - all_positives = y_pred.sum(dim=0).type(torch.DoubleTensor) # Convert from int cuda/cpu to double cpu + all_positives = y_pred.sum(dim=0) if correct.sum() == 0: true_positives = torch.zeros_like(all_positives) else: true_positives = correct.sum(dim=0) - # Convert from int cuda/cpu to double cpu - # We need double precision for the division true_positives / 
all_positives - true_positives = true_positives.type(torch.DoubleTensor) if self._type == "multilabel": if not self._average: diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py index 048c11b10c5b..c3378d3f3f6a 100644 --- a/ignite/metrics/recall.py +++ b/ignite/metrics/recall.py @@ -69,7 +69,7 @@ def __init__( output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): super(Recall, self).__init__( output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device @@ -100,19 +100,17 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1) y = torch.transpose(y, 1, 0).reshape(num_classes, -1) - y = y.type_as(y_pred) + # Convert from int cuda/cpu to double on self._device + y_pred = y_pred.to(dtype=torch.float64, device=self._device) + y = y.to(dtype=torch.float64, device=self._device) correct = y * y_pred - actual_positives = y.sum(dim=0).type(torch.DoubleTensor) # Convert from int cuda/cpu to double cpu + actual_positives = y.sum(dim=0) if correct.sum() == 0: true_positives = torch.zeros_like(actual_positives) else: true_positives = correct.sum(dim=0) - # Convert from int cuda/cpu to double cpu - # We need double precision for the division true_positives / actual_positives - true_positives = true_positives.type(torch.DoubleTensor) - if self._type == "multilabel": if not self._average: self._true_positives = torch.cat([self._true_positives, true_positives], dim=0) diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py index 94f1d643585f..9ea0741fc835 100644 --- a/tests/ignite/metrics/test_precision.py +++ b/tests/ignite/metrics/test_precision.py @@ -833,6 +833,42 @@ def update(engine, i): assert (pr_compute1 == pr_compute2).all() +def 
_test_distrib_accumulator_device(device): + # Binary accuracy on input of shape (N, 1) or (N, ) + device = torch.device(device) + + def _test(average): + pr = Precision(average=average, device=device) + assert pr._device == device + + y_pred = torch.randint(0, 2, size=(10,)) + y = torch.randint(0, 2, size=(10,)).long() + pr.update((y_pred, y)) + + assert pr._true_positives.device == device + assert pr._positives.device == device + + _test(True) + _test(False) + + +def _test_distrib_multilabel_accumulator_device(device): + # Multiclass input data of shape (N, ) and (N, C) + device = torch.device(device) + + def _test(average): + pr = Precision(is_multilabel=True, average=average, device=device) + y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) + y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() + pr.update((y_pred, y)) + + assert pr._true_positives.device == device + assert pr._positives.device == device + + _test(True) + _test(False) + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") @@ -840,6 +876,8 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): device = "cuda:{}".format(local_rank) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.distributed @@ -848,6 +886,8 @@ def test_distrib_cpu(local_rank, distributed_context_single_node_gloo): device = "cpu" _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.distributed @@ -860,6 +900,8 @@ def test_distrib_hvd(gloo_hvd_executor): gloo_hvd_executor(_test_distrib_integration_multiclass, (device,), np=nproc, do_init=True) 
gloo_hvd_executor(_test_distrib_integration_multilabel, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_multilabel_accumulator_device, (device,), np=nproc, do_init=True) @@ -869,6 +911,8 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): device = "cpu" _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @@ -878,6 +922,8 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @@ -887,12 +933,16 @@ def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py index 214fc43c4314..b783b2ec83ce 100644 --- a/tests/ignite/metrics/test_recall.py +++ b/tests/ignite/metrics/test_recall.py @@ -833,6 +833,42 @@ def update(engine, i): assert (re_compute1 == re_compute2).all() +def _test_distrib_accumulator_device(device): + # Binary recall on input of shape (N, 1) or (N, ) + device = torch.device(device) + + def _test(average): + re 
= Recall(average=average, device=device) + assert re._device == device + + y_pred = torch.randint(0, 2, size=(10,)) + y = torch.randint(0, 2, size=(10,)).long() + re.update((y_pred, y)) + + assert re._true_positives.device == device + assert re._positives.device == device + + _test(True) + _test(False) + + +def _test_distrib_multilabel_accumulator_device(device): + # Multilabel input data of shape (N, C, H, W) + device = torch.device(device) + + def _test(average): + re = Recall(is_multilabel=True, average=average, device=device) + y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) + y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() + re.update((y_pred, y)) + + assert re._true_positives.device == device + assert re._positives.device == device + + _test(True) + _test(False) + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") @@ -840,6 +876,8 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): device = "cuda:{}".format(local_rank) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @@ -848,6 +886,8 @@ def test_distrib_cpu(distributed_context_single_node_gloo): device = "cpu" _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @@ -860,6 +900,8 @@ def test_distrib_hvd(gloo_hvd_executor): gloo_hvd_executor(_test_distrib_integration_multiclass, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_integration_multilabel, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) + 
gloo_hvd_executor(_test_distrib_multilabel_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @@ -869,6 +911,8 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): device = "cpu" _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.multinode_distributed @@ -878,6 +922,8 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.tpu @@ -887,12 +933,16 @@ def test_distrib_single_device_xla(): device = idist.device() _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_integration_multiclass(device) _test_distrib_integration_multilabel(device) + _test_distrib_accumulator_device(device) + _test_distrib_multilabel_accumulator_device(device) @pytest.mark.tpu From 4c4a76cffb40d795d27d7018f90c175a49c5736e Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Fri, 7 Aug 2020 23:08:30 -0400 Subject: [PATCH 08/14] ..... 
--- tests/run_cpu_tests.sh | 0 tests/run_gpu_tests.sh | 0 2 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 tests/run_cpu_tests.sh mode change 100644 => 100755 tests/run_gpu_tests.sh diff --git a/tests/run_cpu_tests.sh b/tests/run_cpu_tests.sh old mode 100644 new mode 100755 diff --git a/tests/run_gpu_tests.sh b/tests/run_gpu_tests.sh old mode 100644 new mode 100755 From b1e6956814b8f551231683bf6cb8c72bbfce0586 Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Fri, 7 Aug 2020 23:09:03 -0400 Subject: [PATCH 09/14] black formatting --- ignite/metrics/precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index aa84cff2f453..b8d920d61964 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -41,7 +41,7 @@ def __init__( def reset(self) -> None: if self._is_multilabel: init_value = 0.0 if self._average else [] - kws = {'dtype': torch.float64, 'device': self._device} + kws = {"dtype": torch.float64, "device": self._device} self._true_positives = torch.tensor(init_value, **kws) self._positives = torch.tensor(init_value, **kws) else: From b081e92df7a40ebea446adb0c06c23f9bbe9369a Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Mon, 10 Aug 2020 00:06:44 -0400 Subject: [PATCH 10/14] reverted run*.sh --- tests/run_cpu_tests.sh | 0 tests/run_gpu_tests.sh | 0 2 files changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 tests/run_cpu_tests.sh mode change 100755 => 100644 tests/run_gpu_tests.sh diff --git a/tests/run_cpu_tests.sh b/tests/run_cpu_tests.sh old mode 100755 new mode 100644 diff --git a/tests/run_gpu_tests.sh b/tests/run_gpu_tests.sh old mode 100755 new mode 100644 From a343c358923acea1448f85967e9ac57ddd31084b Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Sun, 16 Aug 2020 03:42:08 -0400 Subject: [PATCH 11/14] change all metrics default device to cpu except running_average --- 
ignite/metrics/accumulation.py | 13 ++++++++++--- ignite/metrics/accuracy.py | 4 ++-- ignite/metrics/confusion_matrix.py | 2 +- ignite/metrics/fbeta.py | 2 +- ignite/metrics/frequency.py | 6 +++++- ignite/metrics/loss.py | 2 +- ignite/metrics/mean_pairwise_distance.py | 2 +- ignite/metrics/metric.py | 4 +++- ignite/metrics/top_k_categorical_accuracy.py | 5 ++++- 9 files changed, 28 insertions(+), 12 deletions(-) diff --git a/ignite/metrics/accumulation.py b/ignite/metrics/accumulation.py index dff45ee87fcc..e72debfa98c3 100644 --- a/ignite/metrics/accumulation.py +++ b/ignite/metrics/accumulation.py @@ -38,7 +38,10 @@ class VariableAccumulation(Metric): _required_output_keys = None def __init__( - self, op: Callable, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None + self, + op: Callable, + output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): if not callable(op): raise TypeError("Argument op should be a callable, but given {}".format(type(op))) @@ -115,7 +118,9 @@ class Average(VariableAccumulation): """ - def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None): + def __init__( + self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = torch.device("cpu") + ): def _mean_op(a, x): if isinstance(x, torch.Tensor) and x.ndim > 1: x = x.sum(dim=0) @@ -159,7 +164,9 @@ class GeometricAverage(VariableAccumulation): """ - def __init__(self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None): + def __init__( + self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = torch.device("cpu") + ): def _geom_op(a: torch.Tensor, x: Union[Any, numbers.Number, torch.Tensor]) -> torch.Tensor: if not isinstance(x, torch.Tensor): x = torch.tensor(x) diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py 
index c41053ef5b99..c79d93b79c64 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -13,7 +13,7 @@ def __init__( self, output_transform: Callable = lambda x: x, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): self._is_multilabel = is_multilabel self._type = None @@ -130,7 +130,7 @@ def __init__( self, output_transform: Callable = lambda x: x, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): self._num_correct = None self._num_examples = None diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py index 2ab1f436bace..ab9b188a6880 100644 --- a/ignite/metrics/confusion_matrix.py +++ b/ignite/metrics/confusion_matrix.py @@ -44,7 +44,7 @@ def __init__( num_classes: int, average: Optional[str] = None, output_transform: Callable = lambda x: x, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): if average is not None and average not in ("samples", "recall", "precision"): raise ValueError("Argument average can None or one of 'samples', 'recall', 'precision'") diff --git a/ignite/metrics/fbeta.py b/ignite/metrics/fbeta.py index 05e217846115..1383b520364d 100644 --- a/ignite/metrics/fbeta.py +++ b/ignite/metrics/fbeta.py @@ -15,7 +15,7 @@ def Fbeta( precision: Optional[Precision] = None, recall: Optional[Recall] = None, output_transform: Optional[Callable] = None, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ) -> MetricsLambda: """Calculates F-beta score diff --git a/ignite/metrics/frequency.py b/ignite/metrics/frequency.py index 75eba360bb53..fa91e36df0b1 100644 --- a/ignite/metrics/frequency.py +++ b/ignite/metrics/frequency.py @@ -1,3 +1,5 @@ +from typing import 
Callable, Optional, Union + import torch import ignite.distributed as idist @@ -35,7 +37,9 @@ class Frequency(Metric): # Epoch [2/10]: [50/100] 50%|█████ , wps=400 [00:17<00:35] """ - def __init__(self, output_transform=lambda x: x, device=None): + def __init__( + self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = torch.device("cpu") + ): self._timer = None self._acc = None self._n = None diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index f44ce3f6f193..c6fb85171894 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -37,7 +37,7 @@ def __init__( loss_fn: Callable, output_transform: Callable = lambda x: x, batch_size: Callable = lambda x: len(x), - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): super(Loss, self).__init__(output_transform, device=device) self._loss_fn = loss_fn diff --git a/ignite/metrics/mean_pairwise_distance.py b/ignite/metrics/mean_pairwise_distance.py index b1f1c9a52c26..1edbf6009d3a 100644 --- a/ignite/metrics/mean_pairwise_distance.py +++ b/ignite/metrics/mean_pairwise_distance.py @@ -21,7 +21,7 @@ def __init__( p: int = 2, eps: float = 1e-6, output_transform: Callable = lambda x: x, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): super(MeanPairwiseDistance, self).__init__(output_transform, device=device) self._p = p diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 4d78044f0406..89a62da2b259 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -129,7 +129,9 @@ class Metric(metaclass=ABCMeta): _required_output_keys = ("y_pred", "y") def __init__( - self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None, + self, + output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): 
self._output_transform = output_transform diff --git a/ignite/metrics/top_k_categorical_accuracy.py b/ignite/metrics/top_k_categorical_accuracy.py index 6d33d51001c5..8c41265d81e6 100644 --- a/ignite/metrics/top_k_categorical_accuracy.py +++ b/ignite/metrics/top_k_categorical_accuracy.py @@ -16,7 +16,10 @@ class TopKCategoricalAccuracy(Metric): """ def __init__( - self, k=5, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = None + self, + k=5, + output_transform: Callable = lambda x: x, + device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): super(TopKCategoricalAccuracy, self).__init__(output_transform, device=device) self._k = k From 854860113abfa5bf81f61f9437335c58c6977c57 Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Sun, 16 Aug 2020 14:57:54 -0400 Subject: [PATCH 12/14] Update ignite/metrics/precision.py Co-authored-by: vfdev --- ignite/metrics/precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index b8d920d61964..bd79916eb0cd 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -54,7 +54,7 @@ def compute(self) -> Union[torch.Tensor, float]: is_scalar = not isinstance(self._positives, torch.Tensor) or self._positives.ndim == 0 if is_scalar and self._positives == 0: raise NotComputableError( - "{} must have at least one example before" " it can be computed.".format(self.__class__.__name__) + "{} must have at least one example before it can be computed.".format(self.__class__.__name__) ) if not (self._type == "multilabel" and not self._average): From b84226bacce77d997f81b6263cbc7025c0510ef9 Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Sun, 16 Aug 2020 15:20:10 -0400 Subject: [PATCH 13/14] remove Optional type from metric devices since default is cpu --- ignite/metrics/accumulation.py | 8 ++++---- ignite/metrics/accuracy.py | 6 +++--- ignite/metrics/confusion_matrix.py | 2 +- 
ignite/metrics/fbeta.py | 2 +- ignite/metrics/frequency.py | 2 +- ignite/metrics/loss.py | 4 ++-- ignite/metrics/mean_pairwise_distance.py | 4 ++-- ignite/metrics/metric.py | 6 ++---- ignite/metrics/precision.py | 6 +++--- ignite/metrics/recall.py | 4 ++-- ignite/metrics/top_k_categorical_accuracy.py | 7 ++----- 11 files changed, 23 insertions(+), 28 deletions(-) diff --git a/ignite/metrics/accumulation.py b/ignite/metrics/accumulation.py index e72debfa98c3..0de2ee469cc7 100644 --- a/ignite/metrics/accumulation.py +++ b/ignite/metrics/accumulation.py @@ -1,5 +1,5 @@ import numbers -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Union import torch @@ -41,7 +41,7 @@ def __init__( self, op: Callable, output_transform: Callable = lambda x: x, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), ): if not callable(op): raise TypeError("Argument op should be a callable, but given {}".format(type(op))) @@ -119,7 +119,7 @@ class Average(VariableAccumulation): """ def __init__( - self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = torch.device("cpu") + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") ): def _mean_op(a, x): if isinstance(x, torch.Tensor) and x.ndim > 1: @@ -165,7 +165,7 @@ class GeometricAverage(VariableAccumulation): """ def __init__( - self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = torch.device("cpu") + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") ): def _geom_op(a: torch.Tensor, x: Union[Any, numbers.Number, torch.Tensor]) -> torch.Tensor: if not isinstance(x, torch.Tensor): diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index c79d93b79c64..52f897e7aedd 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ 
-1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch @@ -13,7 +13,7 @@ def __init__( self, output_transform: Callable = lambda x: x, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), ): self._is_multilabel = is_multilabel self._type = None @@ -130,7 +130,7 @@ def __init__( self, output_transform: Callable = lambda x: x, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), ): self._num_correct = None self._num_examples = None diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py index ab9b188a6880..574e7e6a1105 100644 --- a/ignite/metrics/confusion_matrix.py +++ b/ignite/metrics/confusion_matrix.py @@ -44,7 +44,7 @@ def __init__( num_classes: int, average: Optional[str] = None, output_transform: Callable = lambda x: x, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), ): if average is not None and average not in ("samples", "recall", "precision"): raise ValueError("Argument average can None or one of 'samples', 'recall', 'precision'") diff --git a/ignite/metrics/fbeta.py b/ignite/metrics/fbeta.py index 1383b520364d..6af40b7234d0 100644 --- a/ignite/metrics/fbeta.py +++ b/ignite/metrics/fbeta.py @@ -15,7 +15,7 @@ def Fbeta( precision: Optional[Precision] = None, recall: Optional[Recall] = None, output_transform: Optional[Callable] = None, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), ) -> MetricsLambda: """Calculates F-beta score diff --git a/ignite/metrics/frequency.py b/ignite/metrics/frequency.py index fa91e36df0b1..447cbbf63fd8 100644 --- a/ignite/metrics/frequency.py +++ b/ignite/metrics/frequency.py 
@@ -38,7 +38,7 @@ class Frequency(Metric): """ def __init__( - self, output_transform: Callable = lambda x: x, device: Optional[Union[str, torch.device]] = torch.device("cpu") + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") ): self._timer = None self._acc = None diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index c6fb85171894..b0e7d1955fd7 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch @@ -37,7 +37,7 @@ def __init__( loss_fn: Callable, output_transform: Callable = lambda x: x, batch_size: Callable = lambda x: len(x), - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), ): super(Loss, self).__init__(output_transform, device=device) self._loss_fn = loss_fn diff --git a/ignite/metrics/mean_pairwise_distance.py b/ignite/metrics/mean_pairwise_distance.py index 1edbf6009d3a..01a1eb7386e5 100644 --- a/ignite/metrics/mean_pairwise_distance.py +++ b/ignite/metrics/mean_pairwise_distance.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch from torch.nn.functional import pairwise_distance @@ -21,7 +21,7 @@ def __init__( p: int = 2, eps: float = 1e-6, output_transform: Callable = lambda x: x, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), ): super(MeanPairwiseDistance, self).__init__(output_transform, device=device) self._p = p diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 89a62da2b259..a12a09fbb4dc 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod from collections.abc import Mapping from functools import wraps -from typing import 
Any, Callable, Optional, Union +from typing import Any, Callable, Union import torch @@ -129,9 +129,7 @@ class Metric(metaclass=ABCMeta): _required_output_keys = ("y_pred", "y") def __init__( - self, - output_transform: Callable = lambda x: x, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ): self._output_transform = output_transform diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index bd79916eb0cd..7754c47f47cc 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch @@ -18,7 +18,7 @@ def __init__( output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), ): if idist.get_world_size() > 1: if (not average) and is_multilabel: @@ -131,7 +131,7 @@ def __init__( output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), ): super(Precision, self).__init__( output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py index c3378d3f3f6a..d446583f6bd6 100644 --- a/ignite/metrics/recall.py +++ b/ignite/metrics/recall.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch @@ -69,7 +69,7 @@ def __init__( output_transform: Callable = lambda x: x, average: bool = False, is_multilabel: bool = False, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + device: 
Union[str, torch.device] = torch.device("cpu"), ): super(Recall, self).__init__( output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device diff --git a/ignite/metrics/top_k_categorical_accuracy.py b/ignite/metrics/top_k_categorical_accuracy.py index 8c41265d81e6..abf7996d3047 100644 --- a/ignite/metrics/top_k_categorical_accuracy.py +++ b/ignite/metrics/top_k_categorical_accuracy.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Sequence, Union import torch @@ -16,10 +16,7 @@ class TopKCategoricalAccuracy(Metric): """ def __init__( - self, - k=5, - output_transform: Callable = lambda x: x, - device: Optional[Union[str, torch.device]] = torch.device("cpu"), + self, k=5, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), ): super(TopKCategoricalAccuracy, self).__init__(output_transform, device=device) self._k = k From 685c23bc21ba3b32eb0732c3bdb2f314b75f7ae4 Mon Sep 17 00:00:00 2001 From: Nicholas Vadivelu Date: Sun, 16 Aug 2020 15:30:39 -0400 Subject: [PATCH 14/14] add comment explaining lack of detach in accuracy metrics --- ignite/metrics/accuracy.py | 1 + ignite/metrics/top_k_categorical_accuracy.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index 52f897e7aedd..79ba5e4c1d80 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -161,6 +161,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes) correct = torch.all(y == y_pred.type_as(y), dim=-1) + # Don't need to detach here because torch.eq is not differentiable, so the computation graph is detached anyway. 
self._num_correct += torch.sum(correct).to(self._device) self._num_examples += correct.shape[0] diff --git a/ignite/metrics/top_k_categorical_accuracy.py b/ignite/metrics/top_k_categorical_accuracy.py index abf7996d3047..aa64f5b45319 100644 --- a/ignite/metrics/top_k_categorical_accuracy.py +++ b/ignite/metrics/top_k_categorical_accuracy.py @@ -32,6 +32,8 @@ def update(self, output: Sequence) -> None: sorted_indices = torch.topk(y_pred, self._k, dim=1)[1] expanded_y = y.view(-1, 1).expand(-1, self._k) correct = torch.sum(torch.eq(sorted_indices, expanded_y), dim=1) + + # Don't need to detach here because torch.eq is not differentiable, so the computation graph is detached anyway. self._num_correct += torch.sum(correct).to(self._device) self._num_examples += correct.shape[0]