Merged
Changes from all commits (23 commits):
- e6691bc Recall/Precision metrics for ddp : average == false and multilabel ==… (fco-dv, Feb 3, 2021)
- b57d89d For v0.4.3 - Add more versionadded, versionchanged tags - Change v0.5… (fco-dv, Feb 5, 2021)
- 4575e92 added TimeLimit handler with its test and doc (#1611) (ahmedo42, Feb 6, 2021)
- 8935f4c Update handlers to use setup_logger (#1617) (1nF0rmed, Feb 6, 2021)
- d01793a Managing Deprecation using decorators (#1585) (Devanshu24, Feb 6, 2021)
- 3c0b68f Create documentation.md (vfdev-5, Feb 6, 2021)
- 4a52ebc Distributed tests on Windows should be skipped until fixed. (#1620) (ahmedo42, Feb 6, 2021)
- e4571ae Added Checkpoint.get_default_score_fn (#1621) (vfdev-5, Feb 7, 2021)
- 6e8dd3d Update about.rst (vfdev-5, Feb 8, 2021)
- 2c83380 Update pre-commit hooks and CONTRIBUTING.md (#1622) (Devanshu24, Feb 8, 2021)
- 801a6a9 added requirements.txt and updated readme.md (#1624) (sparkingdark, Feb 9, 2021)
- bd4ab8c Replace relative paths with raw.githubusercontent (#1629) (Devanshu24, Feb 11, 2021)
- 944afab Updated cifar10 example (#1632) (vfdev-5, Feb 11, 2021)
- 02e767e Fixed failling CI and typos for cifar10 examples (#1633) (vfdev-5, Feb 12, 2021)
- 61d8c2f Removed temporary hack to install pth 1.7.1 (#1638) (vfdev-5, Feb 14, 2021)
- 27eca29 [docker] Pillow -> Pillow-SIMD (#1509) (#1639) (vfdev-5, Feb 14, 2021)
- 1f47f3f Fix multinode tests script (#1631) (fco-dv, Feb 14, 2021)
- 04f8fd8 remove warning for average=False and is_multilabel=True (fco-dv, Feb 15, 2021)
- 611ea97 Merge branch 'master' into WIP_ddp_precision_recall (fco-dv, Feb 16, 2021)
- dd07cbd Merge branch 'master' into WIP_ddp_precision_recall (sdesrozis, Feb 17, 2021)
- d9a16ee Merge branch 'master' into WIP_ddp_precision_recall (sdesrozis, Feb 17, 2021)
- 1e3e5d3 Merge branch 'master' into WIP_ddp_precision_recall (sdesrozis, Feb 18, 2021)
- f3998cb update docstring and {precision, recall} tests according to test_mult… (fco-dv, Feb 21, 2021)
23 changes: 6 additions & 17 deletions ignite/metrics/precision.py
@@ -20,14 +20,6 @@ def __init__(
         is_multilabel: bool = False,
         device: Union[str, torch.device] = torch.device("cpu"),
     ):
-        if idist.get_world_size() > 1:
-            if (not average) and is_multilabel:
-                warnings.warn(
-                    "Precision/Recall metrics do not work in distributed setting when average=False "
-                    "and is_multilabel=True. Results are not reduced across computing devices. Computed result "
-                    "corresponds to the local rank's (single process) result.",
-                    RuntimeWarning,
-                )

         self._average = average
         self.eps = 1e-20
@@ -53,12 +45,14 @@ def compute(self) -> Union[torch.Tensor, float]:
             raise NotComputableError(
                 f"{self.__class__.__name__} must have at least one example before it can be computed."
             )

-        if not (self._type == "multilabel" and not self._average):
-            if not self._is_reduced:
+        if not self._is_reduced:
+            if not (self._type == "multilabel" and not self._average):
                 self._true_positives = idist.all_reduce(self._true_positives)  # type: ignore[assignment]
                 self._positives = idist.all_reduce(self._positives)  # type: ignore[assignment]
-                self._is_reduced = True  # type: bool
+            else:
+                self._true_positives = cast(torch.Tensor, idist.all_gather(self._true_positives))
+                self._positives = cast(torch.Tensor, idist.all_gather(self._positives))
+            self._is_reduced = True  # type: bool

         result = self._true_positives / (self._positives + self.eps)

@@ -107,11 +101,6 @@ def thresholded_output_transform(output):
     as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger
     than available RAM.

-    .. warning::
-
-        In multilabel cases, if average is False, current implementation does not work with distributed computations.
-        Results are not reduced across the GPUs. Computed result corresponds to the local rank's (single GPU) result.
-

     Args:
         output_transform (callable, optional): a callable that is used to transform the
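The key behavioral change is in `compute()`: count-based results are still summed across processes with `all_reduce`, while per-sample results (multilabel with `average=False`) are now concatenated with `all_gather`, since each process scores a disjoint set of samples. Below is a minimal single-process sketch of that decision logic — not ignite's actual API: `sum()` stands in for `idist.all_reduce`, `torch.cat` for `idist.all_gather`, and the per-rank tensors are invented for illustration.

```python
import torch

def combined_result(tp_per_rank, pos_per_rank, multilabel, average, eps=1e-20):
    if not (multilabel and not average):
        # Counts mean the same thing on every process: sum them (all_reduce).
        tp, pos = sum(tp_per_rank), sum(pos_per_rank)
    else:
        # Per-sample statistics describe *different* samples on each process:
        # concatenate them instead of summing (all_gather).
        tp, pos = torch.cat(tp_per_rank), torch.cat(pos_per_rank)
    return tp / (pos + eps)

# Two simulated ranks, multilabel with average=False, 3 samples each:
tp = [torch.tensor([1.0, 2.0, 0.0]), torch.tensor([2.0, 1.0, 1.0])]
pos = [torch.tensor([2.0, 2.0, 1.0]), torch.tensor([3.0, 1.0, 2.0])]
print(combined_result(tp, pos, multilabel=True, average=False))  # 6 per-sample values
```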
5 changes: 0 additions & 5 deletions ignite/metrics/recall.py
@@ -48,11 +48,6 @@ def thresholded_output_transform(output):
     as tensors before computing a metric. This can potentially lead to a memory error if the input data is larger
     than available RAM.

-    .. warning::
-
-        In multilabel cases, if average is False, current implementation does not work with distributed computations.
-        Results are not reduced across the GPUs. Computed result corresponds to the local rank's (single GPU) result.
-

     Args:
         output_transform (callable, optional): a callable that is used to transform the
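With the warning removed, multilabel `Recall` (and `Precision`) with `average=False` is usable in a distributed setting, returning per-sample values gathered from all processes. A hedged single-process usage sketch follows; the shapes are arbitrary, and the per-sample result length is inferred from the tests in this PR (ignite's multilabel convention: `(N, C, ...)` inputs with binarized values).

```python
import torch
from ignite.metrics import Recall

# Multilabel recall without averaging: one value per sample position.
recall = Recall(average=False, is_multilabel=True)
y_pred = torch.randint(0, 2, size=(8, 5, 6))        # (N=8, C=5, extra=6), already thresholded
y = torch.randint(0, 2, size=(8, 5, 6)).long()
recall.update((y_pred, y))
res = recall.compute()
print(res.shape)  # 8 * 6 = 48 per-sample values (under ddp, gathered across processes)
```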
36 changes: 15 additions & 21 deletions tests/ignite/metrics/test_precision.py
@@ -792,7 +792,7 @@ def update(engine, i):

         engine = Engine(update)

-        pr = Precision(average=average, is_multilabel=True)
+        pr = Precision(average=average, is_multilabel=True, device=metric_device)
         pr.attach(engine, "pr")

         data = list(range(n_iters))
@@ -808,13 +808,13 @@ def update(engine, i):
         else:
             assert res == res2

+        np_y_preds = to_numpy_multilabel(y_preds)
+        np_y_true = to_numpy_multilabel(y_true)
         assert pr._type == "multilabel"
         res = res if average else res.mean().item()
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-            true_res = precision_score(
-                to_numpy_multilabel(y_true), to_numpy_multilabel(y_preds), average="samples" if average else None
-            )
-
-        assert pytest.approx(res) == true_res
+        assert precision_score(np_y_true, np_y_preds, average="samples") == pytest.approx(res)

     metric_devices = ["cpu"]
     if device.type != "xla":
@@ -823,22 +823,16 @@ def update(engine, i):
     for metric_device in metric_devices:
         _test(average=True, n_epochs=1, metric_device=metric_device)
         _test(average=True, n_epochs=2, metric_device=metric_device)
+        _test(average=False, n_epochs=1, metric_device=metric_device)
+        _test(average=False, n_epochs=2, metric_device=metric_device)

-    if idist.get_world_size() > 1:
-        with pytest.warns(
-            RuntimeWarning,
-            match="Precision/Recall metrics do not work in distributed setting when "
-            "average=False and is_multilabel=True",
-        ):
-            pr = Precision(average=False, is_multilabel=True)
-
-        y_pred = torch.randint(0, 2, size=(4, 3, 6, 8))
-        y = torch.randint(0, 2, size=(4, 3, 6, 8)).long()
-        pr.update((y_pred, y))
-        pr_compute1 = pr.compute()
-        pr_compute2 = pr.compute()
-        assert len(pr_compute1) == 4 * 6 * 8
-        assert (pr_compute1 == pr_compute2).all()
+    pr1 = Precision(is_multilabel=True, average=True)
+    pr2 = Precision(is_multilabel=True, average=False)
+    y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
+    y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
+    pr1.update((y_pred, y))
+    pr2.update((y_pred, y))
+    assert pr1.compute() == pytest.approx(pr2.compute().mean().item())


 def _test_distrib_accumulator_device(device):
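The updated test validates the metric directly against scikit-learn's `"samples"`-averaged score. For reference, here is a standalone sketch of that cross-check; the `to_numpy_multilabel` helper is reimplemented here (assumed to match the test utility's behavior) to flatten `(N, C, ...)` tensors to `(N * ..., C)`.

```python
import warnings

import torch
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import precision_score

def to_numpy_multilabel(y):
    """Flatten a (N, C, ...) tensor to a (N * ..., C) numpy array."""
    # (N, C, ...) -> (C, N * ...) -> (N * ..., C)
    return y.transpose(0, 1).reshape(y.size(1), -1).transpose(0, 1).numpy()

y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()

# Samples with no positive labels/predictions trigger UndefinedMetricWarning;
# suppress it the way the original test did.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UndefinedMetricWarning)
    score = precision_score(to_numpy_multilabel(y), to_numpy_multilabel(y_pred), average="samples")
print(score)
```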
34 changes: 14 additions & 20 deletions tests/ignite/metrics/test_recall.py
@@ -808,13 +808,13 @@ def update(engine, i):
         else:
             assert res == res2

+        np_y_preds = to_numpy_multilabel(y_preds)
+        np_y_true = to_numpy_multilabel(y_true)
         assert re._type == "multilabel"
         res = res if average else res.mean().item()
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-            true_res = recall_score(
-                to_numpy_multilabel(y_true), to_numpy_multilabel(y_preds), average="samples" if average else None
-            )
-
-        assert pytest.approx(res) == true_res
+        assert recall_score(np_y_true, np_y_preds, average="samples") == pytest.approx(res)

     metric_devices = ["cpu"]
     if device.type != "xla":
@@ -823,22 +823,16 @@ def update(engine, i):
     for metric_device in metric_devices:
         _test(average=True, n_epochs=1, metric_device=metric_device)
         _test(average=True, n_epochs=2, metric_device=metric_device)
+        _test(average=False, n_epochs=1, metric_device=metric_device)
+        _test(average=False, n_epochs=2, metric_device=metric_device)

-    if idist.get_world_size() > 1:
-        with pytest.warns(
-            RuntimeWarning,
-            match="Precision/Recall metrics do not work in distributed setting when "
-            "average=False and is_multilabel=True",
-        ):
-            re = Recall(average=False, is_multilabel=True)
-
-        y_pred = torch.randint(0, 2, size=(4, 3, 6, 8))
-        y = torch.randint(0, 2, size=(4, 3, 6, 8)).long()
-        re.update((y_pred, y))
-        re_compute1 = re.compute()
-        re_compute2 = re.compute()
-        assert len(re_compute1) == 4 * 6 * 8
-        assert (re_compute1 == re_compute2).all()
+    re1 = Recall(is_multilabel=True, average=True)
+    re2 = Recall(is_multilabel=True, average=False)
+    y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
+    y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
+    re1.update((y_pred, y))
+    re2.update((y_pred, y))
+    assert re1.compute() == pytest.approx(re2.compute().mean().item())


 def _test_distrib_accumulator_device(device):
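Under actual multi-process execution, the gathered per-sample result should grow with the world size. Below is a hedged sketch of that expectation using `ignite.distributed`'s spawn helper; it assumes a working `gloo` backend and that `compute()` concatenates per-rank results as implemented in this PR (shapes mirror the old test).

```python
import torch
import ignite.distributed as idist
from ignite.metrics import Recall

def check(local_rank):
    # Each process updates with its own local batch of 4 * 6 * 8 sample positions.
    re = Recall(average=False, is_multilabel=True)
    y_pred = torch.randint(0, 2, size=(4, 3, 6, 8))
    y = torch.randint(0, 2, size=(4, 3, 6, 8)).long()
    re.update((y_pred, y))
    res = re.compute()
    # all_gather concatenates per-rank results: world_size * 4 * 6 * 8 values.
    assert len(res) == idist.get_world_size() * 4 * 6 * 8

if __name__ == "__main__":
    idist.spawn("gloo", check, args=(), nproc_per_node=2)
```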