Add probability to Accuracy #1354

Closed
wants to merge 7 commits into from
34 changes: 31 additions & 3 deletions docs/source/metrics.rst
@@ -17,6 +17,34 @@ value is then computed using the output of the engine's ``process_function``:
metric = Accuracy()
metric.attach(engine, "accuracy")

If the engine's prediction output ``y_pred`` represents probability estimates, it can be binarized using
``Accuracy.Mode.PROBABILITIES``:

.. code-block:: python

def process_function(engine, batch):
    # ...
    y = torch.from_numpy(np.array([0, 0, 1]))
    y_pred = torch.from_numpy(np.array([0.1, 0.2, 0.7]))
    return y_pred, y

engine = Engine(process_function)
metric = Accuracy(mode=Accuracy.Mode.PROBABILITIES)
metric.attach(engine, "accuracy")

If the engine's prediction output ``y_pred`` represents logits, it can be binarized using
``Accuracy.Mode.LOGITS``:

.. code-block:: python

def process_function(engine, batch):
    # ...
    y = torch.from_numpy(np.array([0, 0, 1]))
    y_pred = torch.from_numpy(np.array([-2.1, 0.6, 1.7]))
    return y_pred, y

engine = Engine(process_function)
metric = Accuracy(mode=Accuracy.Mode.LOGITS)
metric.attach(engine, "accuracy")

If the engine's output is not in the format ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``, the user can
use the ``output_transform`` argument to transform it:
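
A minimal sketch of such a transform, assuming ``process_function`` returns a dict with the
(hypothetical) keys ``'y_pred'`` and ``'y'``:

.. code-block:: python

def output_transform(output):
    # `output` is whatever `process_function` returned; extract the pair
    y_pred = output['y_pred']
    y = output['y']
    return y_pred, y

metric = Accuracy(output_transform=output_transform)
metric.attach(engine, "accuracy")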

@@ -41,21 +69,21 @@ use the ``output_transform`` argument to transform it:
.. warning::

Please be careful when using ``lambda`` functions to set up multiple ``output_transform`` functions for multiple metrics:

.. code-block:: python

# Wrong
# metrics_group = [Accuracy(output_transform=lambda output: output[name]) for name in names]
# The lambda captures `name` by reference, so every `output_transform` will use the last `name`

# A correct way. For example, using functools.partial
from functools import partial

def ot_func(output, name):
    return output[name]

metrics_group = [Accuracy(output_transform=partial(ot_func, name=name)) for name in names]

For more details, see `here <https://discuss.pytorch.org/t/evaluate-multiple-models-with-one-evaluator-results-weird-metrics/96695>`_

.. Note ::
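For context, reading an attached metric back after a run looks like the following minimal sketch (the
``mode`` argument is the one introduced by this PR; the data iterable is invented for illustration):

import torch
from ignite.engine import Engine
from ignite.metrics import Accuracy

def process_function(engine, batch):
    # each batch is already a (y_pred, y) pair in this sketch
    y_pred, y = batch
    return y_pred, y

engine = Engine(process_function)
metric = Accuracy(mode=Accuracy.Mode.PROBABILITIES)
metric.attach(engine, "accuracy")

data = [(torch.tensor([0.1, 0.2, 0.7]), torch.tensor([0, 0, 1]))]
state = engine.run(data)
print(state.metrics["accuracy"])  # 1.0 here: [0.1, 0.2, 0.7] binarizes to [0, 0, 1]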
35 changes: 23 additions & 12 deletions ignite/metrics/accuracy.py
@@ -1,3 +1,4 @@
import enum
from typing import Callable, Sequence, Union

import torch
@@ -103,19 +104,10 @@ class Accuracy(_BaseClassification):
- `y` and `y_pred` must have the shape (batch_size, num_categories, ...), and num_categories must be
greater than 1 for multilabel cases.

In binary and multilabel cases, the elements of `y` should have 0 or 1 values, while `y_pred` can represent
probabilities (using ``Mode.PROBABILITIES``) or logits (using ``Mode.LOGITS``). Alternatively, thresholding of
predictions can be done manually with an ``output_transform``:

.. code-block:: python

def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

binary_accuracy = Accuracy(thresholded_output_transform)


Args:
output_transform (callable, optional): a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
@@ -125,17 +117,30 @@ def thresholded_output_transform(output):
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
mode (Mode): specifies the form in which the input will be passed. This can be useful to compute
accuracy directly on the output of a neural network, which often returns probabilities or logits.
By default, LABELS.
threshold (float): threshold for binarization of the input, used when a Mode that performs
binarization is selected. By default, 0.5.

"""

class Mode(enum.Enum):
    LABELS = 0
    PROBABILITIES = 1
    LOGITS = 2

def __init__(
    self,
    output_transform: Callable = lambda x: x,
    is_multilabel: bool = False,
    device: Union[str, torch.device] = torch.device("cpu"),
    mode: Mode = Mode.LABELS,
Collaborator:

I'm not sure about the correctness of the name LABELS for the multi-class case, where we require passing probas or logits and then take the argmax.

Contributor (author):

What about UNCHANGED or RAW_INPUT?

Collaborator:

Yes, something like RAW_INPUT could work. I'm wondering whether we could generalize these new options to all possible inputs and metric types: binary, multiclass, multilabel...

    threshold: float = 0.5,
):
    self._num_correct = None
    self._num_examples = None
    self._mode = mode
    self._threshold = threshold
    super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device)

@reinit__is_reduced
@@ -147,9 +152,15 @@ def reset(self) -> None:
@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
    self._check_shape(output)
    y_pred, y = output[0].detach(), output[1].detach()

    if self._mode == self.Mode.PROBABILITIES:
        y_pred = (y_pred >= self._threshold).int()
    elif self._mode == self.Mode.LOGITS:
        y_pred = (torch.sigmoid(y_pred) >= self._threshold).int()

    self._check_type([y_pred, y])

    if self._type == "binary":
        correct = torch.eq(y_pred.view(-1).to(y), y.view(-1))
    elif self._type == "multiclass":
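To make the new branches in ``update`` concrete: under this PR's branch, LOGITS mode should match
manually thresholding ``sigmoid(logits)``. A small self-check sketch, assuming this PR's branch:

import torch
from ignite.metrics import Accuracy

y = torch.tensor([0, 0, 1, 1])
logits = torch.tensor([-2.1, 0.6, 1.7, -0.3])

acc = Accuracy(mode=Accuracy.Mode.LOGITS, threshold=0.5)
acc.update((logits, y))

# manual path: binarize the sigmoid probabilities, then compare with ground truth
manual = ((torch.sigmoid(logits) >= 0.5).int() == y).float().mean().item()
assert abs(acc.compute() - manual) < 1e-7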
197 changes: 197 additions & 0 deletions tests/ignite/metrics/test_accuracy.py
@@ -91,6 +91,156 @@ def _test():
        _test()


def test_binary_input_N_probabilities():
    # Binary accuracy on probabilities input of shape (N, 1) or (N, )
    def _test():
        acc = Accuracy(mode=Accuracy.Mode.PROBABILITIES)

        y_pred = torch.rand(size=(10,))
        y = torch.randint(0, 2, size=(10,)).long()
        acc.update((y_pred, y))
        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().round().ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # Batched Updates
        acc.reset()
        y_pred = torch.rand(size=(100,))
        y = torch.randint(0, 2, size=(100,)).long()

        n_iters = 16
        batch_size = y.shape[0] // n_iters + 1

        for i in range(n_iters):
            idx = i * batch_size
            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().round().ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

    # check multiple random inputs as random exact occurrences are rare
    for _ in range(10):
        _test()


def test_binary_input_N_probabilities_threshold():
    # Binary accuracy on probabilities input of shape (N, 1) or (N, ),
    # with custom binarization threshold.
    def _test():
        acc = Accuracy(mode=Accuracy.Mode.PROBABILITIES, threshold=0.75)

        y_pred = torch.rand(size=(10,))
        y = torch.randint(0, 2, size=(10,)).long()
        acc.update((y_pred, y))
        np_y = y.numpy().ravel()
        np_y_pred = (y_pred.numpy() >= 0.75).astype(int).ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # Batched Updates
        acc.reset()
        y_pred = torch.rand(size=(100,))
        y = torch.randint(0, 2, size=(100,)).long()

        n_iters = 16
        batch_size = y.shape[0] // n_iters + 1

        for i in range(n_iters):
            idx = i * batch_size
            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        np_y = y.numpy().ravel()
        np_y_pred = (y_pred.numpy() >= 0.75).astype(int).ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

    # check multiple random inputs as random exact occurrences are rare
    for _ in range(10):
        _test()


def test_binary_input_N_logits():
    # Binary accuracy on logits input of shape (N, 1) or (N, )
    def _test():
        acc = Accuracy(mode=Accuracy.Mode.LOGITS)

        y_pred = torch.randn(size=(10,))
        y = torch.randint(0, 2, size=(10,)).long()
        acc.update((y_pred, y))
        np_y = y.numpy().ravel()
        np_y_pred = torch.sigmoid(y_pred).numpy().round().ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # Batched Updates
        acc.reset()
        y_pred = torch.randn(size=(100,))
        y = torch.randint(0, 2, size=(100,)).long()

        n_iters = 16
        batch_size = y.shape[0] // n_iters + 1

        for i in range(n_iters):
            idx = i * batch_size
            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        np_y = y.numpy().ravel()
        np_y_pred = torch.sigmoid(y_pred).numpy().round().ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

    # check multiple random inputs as random exact occurrences are rare
    for _ in range(10):
        _test()


def test_binary_input_N_logits_threshold():
    # Binary accuracy on logits input of shape (N, 1) or (N, ),
    # with custom binarization threshold.
    def _test():
        acc = Accuracy(mode=Accuracy.Mode.LOGITS, threshold=0.75)

        y_pred = torch.randn(size=(10,))
        y = torch.randint(0, 2, size=(10,)).long()
        acc.update((y_pred, y))
        np_y = y.numpy().ravel()
        np_y_pred = (torch.sigmoid(y_pred).numpy() >= 0.75).astype(int).ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # Batched Updates
        acc.reset()
        y_pred = torch.randn(size=(100,))
        y = torch.randint(0, 2, size=(100,)).long()

        n_iters = 16
        batch_size = y.shape[0] // n_iters + 1

        for i in range(n_iters):
            idx = i * batch_size
            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        np_y = y.numpy().ravel()
        np_y_pred = (torch.sigmoid(y_pred).numpy() >= 0.75).astype(int).ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

    # check multiple random inputs as random exact occurrences are rare
    for _ in range(10):
        _test()


def test_binary_input_NL():
    # Binary accuracy on input of shape (N, L)
    def _test():
@@ -138,6 +288,53 @@ def _test():
        _test()


def test_binary_input_NL_probabilities():
    # Binary accuracy on probabilities input of shape (N, L)
    def _test():
        acc = Accuracy(mode=Accuracy.Mode.PROBABILITIES)

        y_pred = torch.rand(size=(10, 5))
        y = torch.randint(0, 2, size=(10, 5)).long()
        acc.update((y_pred, y))
        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().round().ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        acc.reset()
        y_pred = torch.rand(size=(10, 1, 5))
        y = torch.randint(0, 2, size=(10, 1, 5)).long()
        acc.update((y_pred, y))
        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().round().ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

        # Batched Updates
        acc.reset()
        y_pred = torch.rand(size=(100, 8))
        y = torch.randint(0, 2, size=(100, 8)).long()

        n_iters = 16
        batch_size = y.shape[0] // n_iters + 1

        for i in range(n_iters):
            idx = i * batch_size
            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        np_y = y.numpy().ravel()
        np_y_pred = y_pred.numpy().round().ravel()
        assert acc._type == "binary"
        assert isinstance(acc.compute(), float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

    # check multiple random inputs as random exact occurrences are rare
    for _ in range(10):
        _test()


def test_binary_input_NHW():
    # Binary accuracy on input of shape (N, H, W, ...)
    def _test():