From 41ac741aee8b204faf8f661174ff29511de21087 Mon Sep 17 00:00:00 2001
From: vcarpani
Date: Sun, 4 Oct 2020 09:58:54 +0200
Subject: [PATCH 1/7] Add probability mode to Accuracy.

---
 ignite/metrics/accuracy.py            | 14 ++++-
 tests/ignite/metrics/test_accuracy.py | 84 +++++++++++++++++++++++++++
 2 files changed, 97 insertions(+), 1 deletion(-)

diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py
index 7d6c939e4b5..e089cc4194b 100644
--- a/ignite/metrics/accuracy.py
+++ b/ignite/metrics/accuracy.py
@@ -1,4 +1,5 @@
 from typing import Callable, Sequence, Union
+import enum

 import torch

@@ -125,17 +126,24 @@ def thresholded_output_transform(output):
         device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
             device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
             default, CPU.
+        mode (Mode): specifies the form in which the input will be passed. This can be useful to directly compute
+            accuracy on the output of a neural networki, which often returns probabilities. By default, LABELS.
     """
+    class Mode(enum.Enum):
+        LABELS = enum.auto()
+        PROBABILITIES = enum.auto()

     def __init__(
         self,
         output_transform: Callable = lambda x: x,
         is_multilabel: bool = False,
         device: Union[str, torch.device] = torch.device("cpu"),
+        mode: Mode = Mode.LABELS,
     ):
         self._num_correct = None
         self._num_examples = None
+        self._mode = mode
         super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device)

     @reinit__is_reduced
@@ -147,9 +155,13 @@ def reset(self) -> None:
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
         self._check_shape(output)
-        self._check_type(output)
         y_pred, y = output[0].detach(), output[1].detach()

+        if self._mode == self.Mode.PROBABILITIES:
+            y_pred = torch.round(y_pred)
+
+        self._check_type([y_pred, y])
+
         if self._type == "binary":
             correct = torch.eq(y_pred.view(-1).to(y), y.view(-1))
         elif self._type == "multiclass":
diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py
index 3960a09ec7f..902dd6e2924 100644
--- a/tests/ignite/metrics/test_accuracy.py
+++ b/tests/ignite/metrics/test_accuracy.py
@@ -91,6 +91,43 @@ def _test():
     _test()


+def test_binary_input_N_probabilities():
+    # Binary accuracy on probabilities input of shape (N, 1) or (N, )
+    def _test():
+        acc = Accuracy(mode=Accuracy.Mode.PROBABILITIES)
+
+        y_pred = torch.rand(size=(10,))
+        y = torch.randint(0, 2, size=(10,)).long()
+        acc.update((y_pred, y))
+        np_y = y.numpy().ravel()
+        np_y_pred = y_pred.numpy().round().ravel()
+        assert acc._type == "binary"
+        assert isinstance(acc.compute(), float)
+        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
+
+        # Batched Updates
+        acc.reset()
+        y_pred = torch.rand(size=(100,))
+        y = torch.randint(0, 2, size=(100,)).long()
+
+        n_iters = 16
+        batch_size = y.shape[0] // n_iters + 1
+
+        for i in range(n_iters):
+            idx = i * batch_size
+            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+
+        np_y = y.numpy().ravel()
+        np_y_pred = y_pred.numpy().round().ravel()
+        assert acc._type == "binary"
+        assert isinstance(acc.compute(), float)
+        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
+
+    # check multiple random inputs as random exact occurrences are rare
+    for _ in range(10):
+        _test()
+
+
 def test_binary_input_NL():
     # Binary accuracy on input of shape (N, L)
     def _test():
@@ -138,6 +175,53 @@ def _test():
     _test()


+def test_binary_input_NL_probabilities():
+    # Binary accuracy on probabilities input of shape (N, L)
+    def _test():
+        acc = Accuracy(mode=Accuracy.Mode.PROBABILITIES)
+
+        y_pred = torch.rand(size=(10, 5))
+        y = torch.randint(0, 2, size=(10, 5)).long()
+        acc.update((y_pred, y))
+        np_y = y.numpy().ravel()
+        np_y_pred = y_pred.numpy().round().ravel()
+        assert acc._type == "binary"
+        assert isinstance(acc.compute(), float)
+        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
+
+        acc.reset()
+        y_pred = torch.rand(size=(10, 1, 5))
+        y = torch.randint(0, 2, size=(10, 1, 5)).long()
+        acc.update((y_pred, y))
+        np_y = y.numpy().ravel()
+        np_y_pred = y_pred.numpy().round().ravel()
+        assert acc._type == "binary"
+        assert isinstance(acc.compute(), float)
+        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
+
+        # Batched Updates
+        acc.reset()
+        y_pred = torch.rand(size=(100, 8))
+        y = torch.randint(0, 2, size=(100, 8)).long()
+
+        n_iters = 16
+        batch_size = y.shape[0] // n_iters + 1
+
+        for i in range(n_iters):
+            idx = i * batch_size
+            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+
+        np_y = y.numpy().ravel()
+        np_y_pred = y_pred.numpy().round().ravel()
+        assert acc._type == "binary"
+        assert isinstance(acc.compute(), float)
+        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
+
+    # check multiple random inputs as random exact occurrences are rare
+    for _ in range(10):
+        _test()
+
+
 def test_binary_input_NHW():
     # Binary accuracy on input of shape (N, H, W, ...)
     def _test():

From 5ac700d4ce94682523187c6220209437b4a3490f Mon Sep 17 00:00:00 2001
From: vcarpani
Date: Sun, 4 Oct 2020 19:56:51 +0200
Subject: [PATCH 2/7] Accuracy: add binarization threshold option.

---
 ignite/metrics/accuracy.py            | 23 +++++++-----------
 tests/ignite/metrics/test_accuracy.py | 38 +++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 14 deletions(-)

diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py
index e089cc4194b..8445d915215 100644
--- a/ignite/metrics/accuracy.py
+++ b/ignite/metrics/accuracy.py
@@ -104,19 +104,10 @@ class Accuracy(_BaseClassification):
     - `y` and `y_pred` must be in the following shape of (batch_size, num_categories, ...) and
     num_categories must be greater than 1 for multilabel cases.

-    In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of
-    predictions can be done as below:
+    In binary and multilabel cases, the elements of `y` should have 0 or 1 values, while `y_pred` can be
+    binarized using the PROBABILITIES mode, so no explicit thresholding of
+    predictions is needed.

-    .. code-block:: python
-
-        def thresholded_output_transform(output):
-            y_pred, y = output
-            y_pred = torch.round(y_pred)
-            return y_pred, y
-
-        binary_accuracy = Accuracy(thresholded_output_transform)
-
     Args:
         output_transform (callable, optional): a callable that is used to transform the
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
@@ -126,7 +117,9 @@ def thresholded_output_transform(output):
         device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
             device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
             default, CPU.
         mode (Mode): specifies the form in which the input will be passed. This can be useful to directly compute
-            accuracy on the output of a neural networki, which often returns probabilities. By default, LABELS.
+            accuracy on the output of a neural network, which often returns probabilities. By default, LABELS.
+        binarization_threshold (float): threshold for binarization of the input, in case a Mode that uses
+            binarization is used.
""" class Mode(enum.Enum): @@ -140,10 +133,12 @@ def __init__( is_multilabel: bool = False, device: Union[str, torch.device] = torch.device("cpu"), mode: Mode = Mode.LABELS, + threshold: float = 0.5, ): self._num_correct = None self._num_examples = None self._mode = mode + self._threshold = threshold super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device) @reinit__is_reduced @@ -158,7 +153,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: y_pred, y = output[0].detach(), output[1].detach() if self._mode == self.Mode.PROBABILITIES: - y_pred = torch.round(y_pred) + y_pred = (y_pred >= self._threshold).int() self._check_type([y_pred, y]) diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py index 902dd6e2924..80bd16357da 100644 --- a/tests/ignite/metrics/test_accuracy.py +++ b/tests/ignite/metrics/test_accuracy.py @@ -128,6 +128,44 @@ def _test(): _test() +def test_binary_input_N_probabilities_threshold(): + # Binary accuracy on probabilities input of shape (N, 1) or (N, ), + # with custom binarization threshold. + def _test(): + acc = Accuracy(mode=Accuracy.Mode.PROBABILITIES, threshold=0.75) + + y_pred = torch.rand(size=(10,)) + y = torch.randint(0, 2, size=(10,)).long() + acc.update((y_pred, y)) + np_y = y.numpy().ravel() + np_y_pred = (y_pred.numpy() >= 0.75).astype(int).ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # Batched Updates + acc.reset() + y_pred = torch.rand(size=(100,)) + y = torch.randint(0, 2, size=(100,)).long() + + n_iters = 16 + batch_size = y.shape[0] // n_iters + 1 + + for i in range(n_iters): + idx = i * batch_size + acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + + np_y = y.numpy().ravel() + np_y_pred = (y_pred.numpy() >= 0.75).astype(int).ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # check multiple random inputs as random exact occurencies are rare + for _ in range(10): + _test() + + def test_binary_input_NL(): # Binary accuracy on input of shape (N, L) def _test(): From 482666bd860c931a956f1d00d51a676c566e32d5 Mon Sep 17 00:00:00 2001 From: vcarpani Date: Mon, 5 Oct 2020 09:30:17 +0200 Subject: [PATCH 3/7] Maintain compatibility with python 3.5. --- ignite/metrics/accuracy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index 8445d915215..0b58003fbd3 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -124,8 +124,8 @@ class Accuracy(_BaseClassification): """ class Mode(enum.Enum): - LABELS = enum.auto() - PROBABILITIES = enum.auto() + LABELS = 0 + PROBABILITIES = 1 def __init__( self, From b5670a7e624a50be704c2a8f992472f633199823 Mon Sep 17 00:00:00 2001 From: vcarpani Date: Fri, 9 Oct 2020 08:30:22 +0200 Subject: [PATCH 4/7] Add supprt for logits input in Accuracy. 

---
 ignite/metrics/accuracy.py            |  7 ++-
 tests/ignite/metrics/test_accuracy.py | 75 +++++++++++++++++++++++++++
 2 files changed, 80 insertions(+), 2 deletions(-)

diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py
index 0b58003fbd3..2f002c7791f 100644
--- a/ignite/metrics/accuracy.py
+++ b/ignite/metrics/accuracy.py
@@ -104,8 +104,8 @@ class Accuracy(_BaseClassification):
     - `y` and `y_pred` must be in the following shape of (batch_size, num_categories, ...) and
     num_categories must be greater than 1 for multilabel cases.

-    In binary and multilabel cases, the elements of `y` should have 0 or 1 values, while `y_pred` can be
-    binarized using the PROBABILITIES mode, so no explicit thresholding of
+    In binary and multilabel cases, the elements of `y` should have 0 or 1 values, while `y_pred` can represent
+    probabilities using the PROBABILITIES mode or logits using the LOGITS mode, so no explicit thresholding of
     predictions is needed.

     Args:
@@ -125,6 +125,7 @@ class Accuracy(_BaseClassification):
     class Mode(enum.Enum):
         LABELS = 0
         PROBABILITIES = 1
+        LOGITS = 2

     def __init__(
@@ -145,6 +146,8 @@ def update(self, output: Sequence[torch.Tensor]) -> None:

         if self._mode == self.Mode.PROBABILITIES:
             y_pred = (y_pred >= self._threshold).int()
+        elif self._mode == self.Mode.LOGITS:
+            y_pred = (torch.sigmoid(y_pred) >= self._threshold).int()

         self._check_type([y_pred, y])

diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py
index 80bd16357da..90ee9c31ab9 100644
--- a/tests/ignite/metrics/test_accuracy.py
+++ b/tests/ignite/metrics/test_accuracy.py
@@ -166,6 +166,81 @@ def _test():
     _test()


+def test_binary_input_N_logits():
+    # Binary accuracy on logits input of shape (N, 1) or (N, )
+    def _test():
+        acc = Accuracy(mode=Accuracy.Mode.LOGITS)
+
+        y_pred = torch.randn(size=(10,))
+        y = torch.randint(0, 2, size=(10,)).long()
+        acc.update((y_pred, y))
+        np_y = y.numpy().ravel()
+        np_y_pred = torch.sigmoid(y_pred).numpy().round().ravel()
+        assert acc._type == "binary"
+        assert isinstance(acc.compute(), float)
+        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
+
+        # Batched Updates
+        acc.reset()
+        y_pred = torch.randn(size=(100,))
+        y = torch.randint(0, 2, size=(100,)).long()
+
+        n_iters = 16
+        batch_size = y.shape[0] // n_iters + 1
+
+        for i in range(n_iters):
+            idx = i * batch_size
+            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+
+        np_y = y.numpy().ravel()
+        np_y_pred = torch.sigmoid(y_pred).numpy().round().ravel()
+        assert acc._type == "binary"
+        assert isinstance(acc.compute(), float)
+        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
+
+    # check multiple random inputs as random exact occurrences are rare
+    for _ in range(10):
+        _test()
+
+
+def test_binary_input_N_logits_threshold():
+    # Binary accuracy on logits input of shape (N, 1) or (N, ),
+    # with custom binarization threshold.
+    def _test():
+        acc = Accuracy(mode=Accuracy.Mode.LOGITS, threshold=0.75)
+
+        y_pred = torch.randn(size=(10,))
+        y = torch.randint(0, 2, size=(10,)).long()
+        acc.update((y_pred, y))
+        np_y = y.numpy().ravel()
+        np_y_pred = (torch.sigmoid(y_pred).numpy() >= 0.75).astype(int).ravel()
+        assert acc._type == "binary"
+        assert isinstance(acc.compute(), float)
+        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
+
+        # Batched Updates
+        acc.reset()
+        y_pred = torch.randn(size=(100,))
+        y = torch.randint(0, 2, size=(100,)).long()
+
+        n_iters = 16
+        batch_size = y.shape[0] // n_iters + 1
+
+        for i in range(n_iters):
+            idx = i * batch_size
+            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+
+        np_y = y.numpy().ravel()
+        np_y_pred = (torch.sigmoid(y_pred).numpy() >= 0.75).astype(int).ravel()
+        assert acc._type == "binary"
+        assert isinstance(acc.compute(), float)
+        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
+
+    # check multiple random inputs as random exact occurrences are rare
+    for _ in range(10):
+        _test()
+
+
 def test_binary_input_NL():
     # Binary accuracy on input of shape (N, L)
     def _test():

From f296e955a384f8971861f2d4071d34bd3200dada Mon Sep 17 00:00:00 2001
From: vcarpani
Date: Fri, 9 Oct 2020 08:45:01 +0200
Subject: [PATCH 5/7] Update documentation.

---
 docs/source/metrics.rst | 34 +++++++++++++++++++++++++++++++---
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 13ef7b3347e..9fd736fe35c 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -17,6 +17,34 @@ value is then computed using the output of the engine's ``process_function``:
     metric = Accuracy()
     metric.attach(engine, "accuracy")

+If the engine's prediction output ``y_pred`` represents probability estimates, it can be binarized using the
+``Mode.PROBABILITIES``:
+.. code-block:: python
+
+    def process_function(engine, batch):
+        # ...
+        y = torch.from_numpy(np.array([0, 0, 1]))
+        y_pred = torch.from_numpy(np.array([0.1, 0.2, 0.7]))
+        return y_pred, y
+
+    engine = Engine(process_function)
+    metric = Accuracy(mode=Accuracy.Mode.PROBABILITIES)
+    metric.attach(engine, "accuracy")
+
+If the engine's prediction output ``y_pred`` represents logits, it can be binarized using the
+``Mode.LOGITS``:
+.. code-block:: python
+
+    def process_function(engine, batch):
+        # ...
+        y = torch.from_numpy(np.array([0, 0, 1]))
+        y_pred = torch.from_numpy(np.array([-2.1, 0.6, 1.7]))
+        return y_pred, y
+
+    engine = Engine(process_function)
+    metric = Accuracy(mode=Accuracy.Mode.LOGITS)
+    metric.attach(engine, "accuracy")
+
 If the engine's output is not in the format ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``, the user can
 use the ``output_transform`` argument to transform it:

@@ -41,13 +69,13 @@ use the ``output_transform`` argument to transform it:
 .. warning::

     Please, be careful when using ``lambda`` functions to setup multiple ``output_transform`` for multiple metrics
-    
+
     .. code-block:: python

         # Wrong
         # metrics_group = [Accuracy(output_transform=lambda output: output[name]) for name in names]
         # As lambda can not store `name` and all `output_transform` will use the last `name`
-    
+
         # A correct way. For example, using functools.partial
         from functools import partial

@@ -55,7 +83,7 @@ def ot_func(output, name):
             return output[name]

         metrics_group = [Accuracy(output_transform=partial(ot_func, name=name)) for name in names]
-    
+
 For more details, see `here `_
 .. Note ::

From 4c78b5103beb260bc50a7bf6ff6cc6308f9b5e8c Mon Sep 17 00:00:00 2001
From: vcarpani
Date: Fri, 9 Oct 2020 08:47:47 +0200
Subject: [PATCH 6/7] Fix coding style.

---
 ignite/metrics/accuracy.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py
index 2f002c7791f..395d14d0b12 100644
--- a/ignite/metrics/accuracy.py
+++ b/ignite/metrics/accuracy.py
@@ -1,5 +1,5 @@
-from typing import Callable, Sequence, Union
 import enum
+from typing import Callable, Sequence, Union

 import torch

@@ -122,6 +122,7 @@ class Accuracy(_BaseClassification):
         binarization_threshold (float): threshold for binarization of the input, in case a Mode that uses
             binarization is used.
     """
+
     class Mode(enum.Enum):
         LABELS = 0
         PROBABILITIES = 1

From 8592b288af16bab8f3f91c430faf0c906fcd2bda Mon Sep 17 00:00:00 2001
From: vcarpani
Date: Sun, 25 Oct 2020 10:52:09 +0100
Subject: [PATCH 7/7] Fixes after review.

---
 docs/source/metrics.rst    | 6 ++++--
 ignite/metrics/accuracy.py | 8 ++++----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 9fd736fe35c..d526ec00af4 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -18,7 +18,8 @@ value is then computed using the output of the engine's ``process_function``:
     metric.attach(engine, "accuracy")

 If the engine's prediction output ``y_pred`` represents probability estimates, it can be binarized using the
-``Mode.PROBABILITIES``:
+``Mode.PROBABILITIES`` mode, with a default threshold of 0.5:
+
 .. code-block:: python

     def process_function(engine, batch):
@@ -32,7 +33,8 @@ If the engine's prediction output ``y_pred`` represents probability estimates, it
     metric = Accuracy(mode=Accuracy.Mode.PROBABILITIES)
     metric.attach(engine, "accuracy")

 If the engine's prediction output ``y_pred`` represents logits, it can be binarized using the
-``Mode.LOGITS``:
+``Mode.LOGITS`` mode, with a default threshold of 0.5:
+
 .. code-block:: python

diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py
index 395d14d0b12..8e0f50572bc 100644
--- a/ignite/metrics/accuracy.py
+++ b/ignite/metrics/accuracy.py
@@ -118,13 +118,13 @@ class Accuracy(_BaseClassification):
             device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
             default, CPU.
         mode (Mode): specifies the form in which the input will be passed. This can be useful to directly compute
-            accuracy on the output of a neural network, which often returns probabilities. By default, LABELS.
-        binarization_threshold (float): threshold for binarization of the input, in case a Mode that uses
+            accuracy on the output of a neural network, which often returns probabilities. By default, UNCHANGED.
+        threshold (float): threshold for binarization of the input, in case a Mode that uses
             binarization is used.
     """

     class Mode(enum.Enum):
-        LABELS = 0
+        UNCHANGED = 0
         PROBABILITIES = 1
         LOGITS = 2

@@ -132,9 +132,9 @@ def __init__(
         self,
         output_transform: Callable = lambda x: x,
         is_multilabel: bool = False,
         device: Union[str, torch.device] = torch.device("cpu"),
-        mode: Mode = Mode.LABELS,
+        mode: Mode = Mode.UNCHANGED,
         threshold: float = 0.5,
     ):
         self._num_correct = None
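
The series does not ship a standalone snippet showing the final API after the LABELS-to-UNCHANGED rename, so here is a minimal usage sketch. It assumes a checkout with all seven patches applied: ``Accuracy.Mode`` and the ``threshold`` argument are introduced by this series and do not exist in released ignite.

.. code-block:: python

    # Usage sketch for the mode/threshold API added by this series.
    # Assumes ignite built from this branch; Accuracy.Mode and the
    # `threshold` argument are not part of stock ignite.
    import torch
    from ignite.metrics import Accuracy

    # Raw logits from a binary classifier head, shape (N,).
    y_pred = torch.tensor([-2.1, 0.6, 1.7])
    y = torch.tensor([0, 0, 1]).long()

    # LOGITS mode applies a sigmoid, then binarizes at `threshold`
    # (0.5 by default), so no thresholding output_transform is needed.
    acc = Accuracy(mode=Accuracy.Mode.LOGITS, threshold=0.5)
    acc.update((y_pred, y))
    # sigmoid -> [0.11, 0.65, 0.85] -> binarized [0, 1, 1]; 2 of 3 match y.
    print(acc.compute())  # 0.666...

With the default ``mode=Accuracy.Mode.UNCHANGED``, the metric behaves exactly as before this series, so existing ``output_transform``-based thresholding keeps working.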