Merged
51 commits
9f7daa1
update accuracy to accumulate _num_correct in a tensor on the right d…
n2cholas Aug 7, 2020
a87f93d
update loss metric to accumulate _sum in a tensor on the right device
n2cholas Aug 7, 2020
30b2e19
update mae metric to accumulate in a tensor on the right device
n2cholas Aug 7, 2020
a3e237c
update mpd metric to accumulate in a tensor on the right device
n2cholas Aug 7, 2020
7100176
update mse metric to accumulate in a tensor on the right device
n2cholas Aug 7, 2020
3228a0a
update top k accuracy metric to accumulate in a tensor on the right …
n2cholas Aug 7, 2020
412551e
update precision and recall metrics to accumulate in tensors on the r…
n2cholas Aug 8, 2020
4c4a76c
.....
n2cholas Aug 8, 2020
b1e6956
black formatting
n2cholas Aug 8, 2020
b081e92
reverted run*.sh
n2cholas Aug 10, 2020
a343c35
change all metrics default device to cpu except running_average
n2cholas Aug 16, 2020
8548601
Update ignite/metrics/precision.py
n2cholas Aug 16, 2020
b84226b
remove Optional type from metric devices since default is cpu
n2cholas Aug 16, 2020
685c23b
add comment explaining lack of detach in accuracy metrics
n2cholas Aug 16, 2020
0b4337d
update docstrings and docs
n2cholas Aug 17, 2020
b2fa213
Update ignite/metrics/accumulation.py
n2cholas Aug 17, 2020
90e0e9a
Update ignite/metrics/accumulation.py
n2cholas Aug 17, 2020
6c1fda4
Update ignite/metrics/accumulation.py
n2cholas Aug 17, 2020
c510e10
Update ignite/metrics/accuracy.py
n2cholas Aug 17, 2020
d5d4854
Update ignite/metrics/fbeta.py
n2cholas Aug 17, 2020
39515b7
Update ignite/metrics/loss.py
n2cholas Aug 17, 2020
3c49871
Update ignite/metrics/metric.py
n2cholas Aug 17, 2020
6de10dd
Update ignite/metrics/precision.py
n2cholas Aug 17, 2020
eca0bc3
Update ignite/metrics/recall.py
n2cholas Aug 17, 2020
ad7082e
add comment explaining lack of detach in metrics docs
n2cholas Aug 17, 2020
c057d52
Merge remote-tracking branch 'pytorch-ignite/metrics_impl' into metri…
n2cholas Aug 17, 2020
90b5b85
support device argument for running_average
n2cholas Aug 17, 2020
3481da1
update support for device argument for accumulation
n2cholas Aug 18, 2020
d340bb7
fix and improve device tests for metrics
n2cholas Aug 18, 2020
4824e24
fix and improve device tests for metrics
n2cholas Aug 18, 2020
1361866
fix TPU tests
n2cholas Aug 18, 2020
556262b
Apply suggestions from code review
vfdev-5 Aug 18, 2020
489620b
Apply suggestions from code review
vfdev-5 Aug 18, 2020
566e9bc
Merge branch 'metrics_impl' of https://github.com/pytorch/ignite into…
Aug 31, 2020
6edd30d
detach tensors earlier in update
Aug 31, 2020
375f91e
remove redundant to() call
Aug 31, 2020
960449c
ensure metrics aren't created on XLA devices
n2cholas Sep 7, 2020
96128fd
Merge branch 'metrics_impl' of https://github.com/pytorch/ignite into…
n2cholas Sep 7, 2020
d192a8f
Fixed isort
vfdev-5 Sep 8, 2020
23d72c0
move xla check to Metric.__init__ instead of individual metrics
n2cholas Sep 9, 2020
edc74a4
update xla tests
n2cholas Sep 9, 2020
2c6b7d2
replace deleted callable check
n2cholas Sep 9, 2020
6e64f37
remove redundant precision and recall __init__
n2cholas Sep 9, 2020
2828e74
replace precision/recall __init__ for docs rendering
n2cholas Sep 9, 2020
bcc3cbb
add support for metrics_lambda with components on diff devices
n2cholas Sep 9, 2020
bb257de
Merge branch 'metrics_impl' of https://github.com/pytorch/ignite into…
n2cholas Sep 9, 2020
2f7e542
fix epoch_metric xla test
n2cholas Sep 9, 2020
0d1ba42
Merge branch 'metrics_impl' of https://github.com/pytorch/ignite into…
n2cholas Sep 11, 2020
ebefd4b
detach output consistently for all metrics
n2cholas Sep 11, 2020
4ce863d
fix horovod two gpu tests
n2cholas Sep 11, 2020
8a94b8f
make confusion matrix detach like other metrics
n2cholas Sep 11, 2020
6 changes: 3 additions & 3 deletions ignite/metrics/accumulation.py
@@ -66,9 +66,9 @@ def _check_output_type(self, output: Union[Any, torch.Tensor, numbers.Number]) -
     def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
         self._check_output_type(output)

-        if self._device is not None:
-            # Put output to the metric's device
-            if isinstance(output, torch.Tensor) and (output.device != self._device):
+        if isinstance(output, torch.Tensor):
+            output = output.detach()
+            if output.device != self._device:
                 output = output.to(self._device)

         self.accumulator = self._op(self.accumulator, output)
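For illustration, the pattern introduced here can be read in isolation: detach first, then move only if the devices differ, so the accumulator never keeps an autograd graph alive. The sketch below is a standalone rendition of that idea; the `accumulate` helper is hypothetical and not part of Ignite.

```python
import numbers
from typing import Union

import torch


def accumulate(accumulator: torch.Tensor, output: Union[numbers.Number, torch.Tensor],
               device: torch.device) -> torch.Tensor:
    # Detach first so the running value never retains the autograd graph,
    # then move to the accumulator's device only when the devices differ.
    if isinstance(output, torch.Tensor):
        output = output.detach()
        if output.device != device:
            output = output.to(device)
    return accumulator + output


acc = torch.zeros(1)
acc = accumulate(acc, torch.tensor([2.0], requires_grad=True), torch.device("cpu"))
print(acc)  # tensor([2.]) -- no grad_fn, accumulation happens outside the graph
```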
7 changes: 3 additions & 4 deletions ignite/metrics/accuracy.py
@@ -146,9 +146,9 @@ def reset(self) -> None:

     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
-        self._check_shape((y_pred, y))
-        self._check_type((y_pred, y))
+        self._check_shape(output)
+        self._check_type(output)
+        y_pred, y = output[0].detach(), output[1].detach()

         if self._type == "binary":
             correct = torch.eq(y_pred.view(-1).to(y), y.view(-1))
@@ -163,7 +163,6 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
             y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
             correct = torch.all(y == y_pred.type_as(y), dim=-1)

-        # Don't need to detach here because torch.eq is not differentiable, so the computation graph is detached anyway.
         self._num_correct += torch.sum(correct).to(self._device)
         self._num_examples += correct.shape[0]

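A brief usage sketch of what the change means for users, assuming the post-PR behaviour where internal counters default to CPU and `update()` detaches `y_pred` and `y` before accumulating; the tensors below are made up for illustration.

```python
import torch
from ignite.metrics import Accuracy

# Internal counters live on the metric's device (CPU by default),
# even if predictions arrive on another device or carry gradients.
acc = Accuracy(device="cpu")

y_pred = torch.tensor([[0.2, 0.8], [0.9, 0.1]], requires_grad=True)
y = torch.tensor([1, 1])

acc.update((y_pred, y))  # inputs are detached inside update()
print(acc.compute())     # 0.5 -- one of the two predictions is correct
```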
4 changes: 2 additions & 2 deletions ignite/metrics/confusion_matrix.py
@@ -63,7 +63,7 @@ def reset(self) -> None:
         self._num_examples = 0

     def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()

         if y_pred.ndimension() < 2:
             raise ValueError(
@@ -94,7 +94,7 @@ def _check_shape(self, output: Sequence[torch.Tensor]) -> None:
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
         self._check_shape(output)
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()

         self._num_examples += y_pred.shape[0]

4 changes: 2 additions & 2 deletions ignite/metrics/precision.py
@@ -141,9 +141,9 @@ def __init__(

     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
         self._check_shape(output)
-        self._check_type((y_pred, y))
+        self._check_type(output)
+        y_pred, y = output[0].detach(), output[1].detach()

         if self._type == "binary":
             y_pred = y_pred.view(-1)
4 changes: 2 additions & 2 deletions ignite/metrics/recall.py
@@ -79,9 +79,9 @@ def __init__(

     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
         self._check_shape(output)
-        self._check_type((y_pred, y))
+        self._check_type(output)
+        y_pred, y = output[0].detach(), output[1].detach()

         if self._type == "binary":
             y_pred = y_pred.view(-1)
3 changes: 2 additions & 1 deletion ignite/metrics/ssim.py
@@ -116,7 +116,8 @@ def _gaussian_or_uniform_kernel(self, kernel_size, sigma):

     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
+
         if y_pred.dtype != y.dtype:
             raise TypeError(
                 "Expected y_pred and y to have the same data type. Got y_pred: {} and y: {}.".format(
3 changes: 1 addition & 2 deletions ignite/metrics/top_k_categorical_accuracy.py
@@ -28,12 +28,11 @@ def reset(self) -> None:

     @reinit__is_reduced
     def update(self, output: Sequence) -> None:
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
         sorted_indices = torch.topk(y_pred, self._k, dim=1)[1]
         expanded_y = y.view(-1, 1).expand(-1, self._k)
         correct = torch.sum(torch.eq(sorted_indices, expanded_y), dim=1)

-        # Don't need to detach here because torch.eq is not differentiable, so the computation graph is detached anyway.
         self._num_correct += torch.sum(correct).to(self._device)
         self._num_examples += correct.shape[0]

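The removed comment referred to the fact that `torch.topk` indices and `torch.eq` results are non-differentiable anyway; once the inputs are detached up front, the note is redundant. A small standalone illustration of the same computation, with made-up tensors rather than anything from the PR:

```python
import torch

k = 2
y_pred = torch.tensor([[0.1, 0.5, 0.4], [0.7, 0.2, 0.1]])  # (batch, num_classes) scores
y = torch.tensor([2, 2])                                     # target class per sample

sorted_indices = torch.topk(y_pred, k, dim=1)[1]             # top-k class indices per row
expanded_y = y.view(-1, 1).expand(-1, k)                     # align targets with the k candidates
correct = torch.sum(torch.eq(sorted_indices, expanded_y), dim=1)  # 1 if target is in the top-k
num_correct = torch.sum(correct)                             # here: tensor(1), only row 0 matches
```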
10 changes: 5 additions & 5 deletions tests/ignite/metrics/test_accumulation.py
@@ -232,7 +232,7 @@ def _test(metric_device):
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
-            _test(device)
+            _test(idist.device())


 def _test_distrib_average(device):
@@ -271,7 +271,7 @@ def _test(metric_device):
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
-            _test(device)
+            _test(idist.device())


 def _test_distrib_geom_average(device):
@@ -310,7 +310,7 @@ def _test(metric_device):
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
-            _test(device)
+            _test(idist.device())


 def _test_distrib_integration(device):
@@ -362,7 +362,7 @@ def _geom_mean(y_true):

     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(Average, _mean, metric_device)
         _test(GeometricAverage, _geom_mean, metric_device, tol=1e-4)
@@ -372,7 +372,7 @@ def _test_distrib_accumulator_device(device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:

         m = VariableAccumulation(lambda a, x: x, device=metric_device)
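The recurring test change above swaps the raw `device` fixture for `idist.device()`, which resolves the device of the current distributed configuration (for example, the per-process GPU under a nccl backend). A minimal sketch of the pattern the tests repeat; the `build_metric_devices` helper name is ours, not from the test suite.

```python
import torch
import ignite.distributed as idist


def build_metric_devices(device: torch.device):
    # Metrics are always exercised on CPU; unless the backend is XLA,
    # also exercise the device reported by the current distributed config.
    metric_devices = [torch.device("cpu")]
    if device.type != "xla":
        metric_devices.append(idist.device())
    return metric_devices
```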
8 changes: 4 additions & 4 deletions tests/ignite/metrics/test_accuracy.py
@@ -702,7 +702,7 @@ def _test(metric_device):
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
-            _test(device)
+            _test(idist.device())


 def _test_distrib_integration_multiclass(device):
@@ -751,7 +751,7 @@ def update(engine, i):

     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         for _ in range(2):
             _test(n_epochs=1, metric_device=metric_device)
@@ -804,7 +804,7 @@ def update(engine, i):

     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         for _ in range(2):
             _test(n_epochs=1, metric_device=metric_device)
@@ -815,7 +815,7 @@ def _test_distrib_accumulator_device(device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:

         acc = Accuracy(device=metric_device)
2 changes: 1 addition & 1 deletion tests/ignite/metrics/test_confusion_matrix.py
@@ -619,7 +619,7 @@ def _test_distrib_accumulator_device(device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:

         cm = ConfusionMatrix(num_classes=3, device=metric_device)
2 changes: 1 addition & 1 deletion tests/ignite/metrics/test_fbeta.py
@@ -133,7 +133,7 @@ def update(engine, i):

     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(None, None, average=True, n_epochs=1, metric_device=metric_device)
         _test(None, None, average=True, n_epochs=2, metric_device=metric_device)
4 changes: 2 additions & 2 deletions tests/ignite/metrics/test_loss.py
@@ -108,14 +108,14 @@ def _test(metric_device):

     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())


 def _test_distrib_accumulator_device(device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         loss = Loss(nll_loss, device=metric_device)
         assert loss._device == metric_device
4 changes: 2 additions & 2 deletions tests/ignite/metrics/test_mean_absolute_error.py
@@ -67,14 +67,14 @@ def _test(metric_device):

     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())


 def _test_distrib_accumulator_device(device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         mae = MeanAbsoluteError(device=metric_device)
         assert mae._device == metric_device
4 changes: 2 additions & 2 deletions tests/ignite/metrics/test_mean_pairwise_distance.py
@@ -80,14 +80,14 @@ def _test(metric_device):

     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())


 def _test_distrib_accumulator_device(device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:

         mpd = MeanPairwiseDistance(device=metric_device)
4 changes: 2 additions & 2 deletions tests/ignite/metrics/test_mean_squared_error.py
@@ -67,14 +67,14 @@ def _test(metric_device):

     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())


 def _test_distrib_accumulator_device(device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:

         device = torch.device(device)
2 changes: 1 addition & 1 deletion tests/ignite/metrics/test_metrics_lambda.py
@@ -369,7 +369,7 @@ def Fbeta(r, p, beta):
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
-            _test(device)
+            _test(idist.device())


 def _test_distrib_metrics_on_diff_devices(device):
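One commit in this PR adds support for MetricsLambda whose component metrics keep their state on different devices, which the `_test_distrib_metrics_on_diff_devices` test above exercises. A hedged sketch of what that enables; the F-beta formula here is illustrative and not copied from the test file.

```python
import torch
from ignite.metrics import MetricsLambda, Precision, Recall

# Component metrics may keep their internal state on different devices;
# MetricsLambda only combines their computed values, so mixing is allowed.
precision = Precision(average=False, device="cpu")
recall = Recall(average=False, device="cuda" if torch.cuda.is_available() else "cpu")


def Fbeta(r, p, beta):
    # Classic F-beta from per-class precision/recall tensors.
    return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item()


F1 = MetricsLambda(Fbeta, recall, precision, 1)
# Typical usage: F1.attach(evaluator, "f1"), which also attaches precision and recall.
```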
8 changes: 4 additions & 4 deletions tests/ignite/metrics/test_precision.py
@@ -759,7 +759,7 @@ def update(engine, i):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for _ in range(2):
         for metric_device in metric_devices:
             _test(average=True, n_epochs=1, metric_device=metric_device)
@@ -818,7 +818,7 @@ def update(engine, i):

     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for _ in range(2):
         for metric_device in metric_devices:
             _test(average=True, n_epochs=1, metric_device=metric_device)
@@ -863,7 +863,7 @@ def _test(average, metric_device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(True, metric_device=metric_device)
         _test(False, metric_device=metric_device)
@@ -896,7 +896,7 @@ def _test(average, metric_device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(True, metric_device=metric_device)
         _test(False, metric_device=metric_device)
8 changes: 4 additions & 4 deletions tests/ignite/metrics/test_recall.py
@@ -759,7 +759,7 @@ def update(engine, i):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for _ in range(2):
         for metric_device in metric_devices:
             _test(average=True, n_epochs=1, metric_device=metric_device)
@@ -818,7 +818,7 @@ def update(engine, i):

     metric_devices = ["cpu"]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for _ in range(2):
         for metric_device in metric_devices:
             _test(average=True, n_epochs=1, metric_device=metric_device)
@@ -863,7 +863,7 @@ def _test(average, metric_device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(True, metric_device=metric_device)
         _test(False, metric_device=metric_device)
@@ -896,7 +896,7 @@ def _test(average, metric_device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:
         _test(True, metric_device=metric_device)
         _test(False, metric_device=metric_device)
2 changes: 1 addition & 1 deletion tests/ignite/metrics/test_root_mean_squared_error.py
@@ -69,7 +69,7 @@ def _test(metric_device):

     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())


 @pytest.mark.distributed
6 changes: 3 additions & 3 deletions tests/ignite/metrics/test_running_average.py
@@ -269,7 +269,7 @@ def update_fn(engine, batch):
     trainer = Engine(update_fn)
     alpha = 0.98

-    metric_device = device if torch.device(device).type != "xla" else "cpu"
+    metric_device = idist.device() if torch.device(device).type != "xla" else "cpu"
     avg_output = RunningAverage(output_transform=lambda x: x, alpha=alpha, epoch_bound=False, device=metric_device)
     avg_output.attach(trainer, "running_avg_output")

@@ -365,14 +365,14 @@ def assert_equal_running_avg_acc_values(engine):

     _test("cpu")
     if device.type != "xla":
-        _test(device)
+        _test(idist.device())


 def _test_distrib_accumulator_device(device):

     metric_devices = [torch.device("cpu")]
     if device.type != "xla":
-        metric_devices.append(device)
+        metric_devices.append(idist.device())
     for metric_device in metric_devices:

         # Don't test the src=Metric case because compute() returns a scalar,
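RunningAverage gained a `device` argument in this PR for the output_transform form. A rough end-to-end sketch of how it might be used, mirroring the updated test; the toy `update_fn`, data, and metric name are ours, not from the test suite.

```python
import ignite.distributed as idist
from ignite.engine import Engine
from ignite.metrics import RunningAverage


def update_fn(engine, batch):
    return batch  # pretend the "loss" is just the batch value


trainer = Engine(update_fn)

# Pick the metric device the same way the updated test does: the current
# distributed device, unless running on XLA, where CPU is used instead.
metric_device = idist.device() if idist.device().type != "xla" else "cpu"
avg_output = RunningAverage(output_transform=lambda x: x, alpha=0.98, device=metric_device)
avg_output.attach(trainer, "running_avg_output")

trainer.run([1.0, 0.5, 0.25], max_epochs=1)
print(trainer.state.metrics["running_avg_output"])  # exponential moving average of the outputs
```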