From 7e7094990921a9948e0685b33104f608d112cccb Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 14 Sep 2020 22:11:05 +0000 Subject: [PATCH 1/2] Fixes #1289 - Promoted _required_output_keys to be public as user would like to override it. --- ignite/metrics/accumulation.py | 2 +- ignite/metrics/loss.py | 2 +- ignite/metrics/metric.py | 66 ++++++++++++++++++++++++++--- ignite/metrics/running_average.py | 2 +- tests/ignite/metrics/test_metric.py | 51 ++++++++++++++++++++-- 5 files changed, 112 insertions(+), 11 deletions(-) diff --git a/ignite/metrics/accumulation.py b/ignite/metrics/accumulation.py index 926e7816bae2..62708ec53a78 100644 --- a/ignite/metrics/accumulation.py +++ b/ignite/metrics/accumulation.py @@ -37,7 +37,7 @@ class VariableAccumulation(Metric): """ - _required_output_keys = None + required_output_keys = None def __init__( self, diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index c667cf3ade52..8fc3aaba3002 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -32,7 +32,7 @@ class Loss(Metric): """ - _required_output_keys = None + required_output_keys = None def __init__( self, diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 228a89d2d1bf..272c2184e72d 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -125,9 +125,65 @@ class Metric(metaclass=ABCMeta): device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + + Class Attributes: + required_output_keys (dict): dictionary defines required keys to be found in ``engine.state.output`` if the + latter is a dictionary. This is useful with custom metrics that can require other arguments than + predictions ``y_pred`` and targets ``y``. See notes below for an example. + + Note: + + .. code-block:: python + + # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5 + # Let's implement a custom metric that requires ``y_pred``, ``y`` and ``x`` + + import torch + import torch.nn as nn + + from ignite.metrics import Metric, Accuracy + from ignite.engine import create_supervised_evaluator + + class CustomMetric(Metric): + + required_output_keys = ("y_pred", "y", "x") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def update(self, output): + y_pred, y, x = output + # ... + + def reset(self): + # ... + pass + + def compute(self): + # ... + pass + + model = ... 
+ + metrics = { + "Accuracy": Accuracy(), + "CustomMetric": CustomMetric() + } + + evaluator = create_supervised_evaluator( + model, + metrics=metrics, + output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred} + ) + + res = evaluator.run(data) + """ - _required_output_keys = ("y_pred", "y") + # public class attribute + required_output_keys = ("y_pred", "y") + # for backward compatibility + _required_output_keys = required_output_keys def __init__( self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), @@ -211,18 +267,18 @@ def iteration_completed(self, engine: Engine) -> None: output = self._output_transform(engine.state.output) if isinstance(output, Mapping): - if self._required_output_keys is None: + if self.required_output_keys is None: raise TypeError( "Transformed engine output for {} metric should be a tuple/list, but given {}".format( self.__class__.__name__, type(output) ) ) - if not all([k in output for k in self._required_output_keys]): + if not all([k in output for k in self.required_output_keys]): raise ValueError( "When transformed engine's output is a mapping, " - "it should contain {} keys, but given {}".format(self._required_output_keys, list(output.keys())) + "it should contain {} keys, but given {}".format(self.required_output_keys, list(output.keys())) ) - output = tuple(output[k] for k in self._required_output_keys) + output = tuple(output[k] for k in self.required_output_keys) self.update(output) def completed(self, engine: Engine, name: str) -> None: diff --git a/ignite/metrics/running_average.py b/ignite/metrics/running_average.py index 0fa1216c7940..419743ec15a5 100644 --- a/ignite/metrics/running_average.py +++ b/ignite/metrics/running_average.py @@ -44,7 +44,7 @@ def log_running_avg_metrics(engine): """ - _required_output_keys = None + required_output_keys = None def __init__( self, diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py index 31395b99ec98..99317b4addfd 100644 --- a/tests/ignite/metrics/test_metric.py +++ b/tests/ignite/metrics/test_metric.py @@ -67,7 +67,7 @@ def test_output_as_mapping_wrong_keys(): def test_output_as_mapping_keys_is_none(): class DummyMetric(Metric): - _required_output_keys = None + required_output_keys = None def reset(self): pass @@ -79,7 +79,7 @@ def update(self, output): pass metric = DummyMetric() - assert metric._required_output_keys is None + assert metric.required_output_keys is None state = State(output=({"y1": 0, "y2": 1})) engine = MagicMock(state=state) @@ -318,7 +318,7 @@ def process_function(*args, **kwargs): def test_detach(): class DummyMetric(Metric): - _required_output_keys = None + required_output_keys = None def reset(self): pass @@ -794,3 +794,48 @@ def _(): assert bfm[0] == 1 engine.run([0, 1, 2, 3], max_epochs=10) + + +def test_override_required_output_keys(): + # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5 + import torch.nn as nn + + from ignite.engine import create_supervised_evaluator + + counter = [0] + + class CustomMetric(Metric): + required_output_keys = ("y_pred", "y", "x") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def update(self, output): + y_pred, y, x = output + assert y_pred.shape == (4, 3) + assert y.shape == (4,) + assert x.shape == (4, 10) + assert x.equal(data[counter[0]][0]) + assert y.equal(data[counter[0]][1]) + counter[0] += 1 + + def reset(self): + pass + + def compute(self): + pass + + model = nn.Linear(10, 3) + + 
metrics = {"Precision": Precision(), "CustomMetric": CustomMetric()} + + evaluator = create_supervised_evaluator( + model, metrics=metrics, output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred} + ) + + data = [ + (torch.rand(4, 10), torch.randint(0, 3, size=(4,))), + (torch.rand(4, 10), torch.randint(0, 3, size=(4,))), + (torch.rand(4, 10), torch.randint(0, 3, size=(4,))), + ] + evaluator.run(data) From 1d1f46cc20b5b0fe7f0076bdeb509330b1e115d4 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 14 Sep 2020 22:35:25 +0000 Subject: [PATCH 2/2] Updated docs --- ignite/metrics/metric.py | 82 ++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 36 deletions(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 272c2184e72d..49989fa47944 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -55,6 +55,9 @@ class EpochWise(MetricUsage): - :meth:`~ignite.metrics.Metric.started` on every ``EPOCH_STARTED`` (See :class:`~ignite.engine.events.Events`). - :meth:`~ignite.metrics.Metric.iteration_completed` on every ``ITERATION_COMPLETED``. - :meth:`~ignite.metrics.Metric.completed` on every ``EPOCH_COMPLETED``. + + Attributes: + usage_name (str): usage name string """ usage_name = "epoch_wise" @@ -76,6 +79,9 @@ class BatchWise(MetricUsage): - :meth:`~ignite.metrics.Metric.started` on every ``ITERATION_STARTED`` (See :class:`~ignite.engine.events.Events`). - :meth:`~ignite.metrics.Metric.iteration_completed` on every ``ITERATION_COMPLETED``. - :meth:`~ignite.metrics.Metric.completed` on every ``ITERATION_COMPLETED``. + + Attributes: + usage_name (str): usage name string """ usage_name = "batch_wise" @@ -126,57 +132,60 @@ class Metric(metaclass=ABCMeta): metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. - Class Attributes: - required_output_keys (dict): dictionary defines required keys to be found in ``engine.state.output`` if the - latter is a dictionary. This is useful with custom metrics that can require other arguments than - predictions ``y_pred`` and targets ``y``. See notes below for an example. + Attributes: + required_output_keys (tuple): dictionary defines required keys to be found in ``engine.state.output`` if the + latter is a dictionary. Default, ``("y_pred", "y")``. This is useful with custom metrics that can require + other arguments than predictions ``y_pred`` and targets ``y``. See notes below for an example. Note: - .. code-block:: python + Let's implement a custom metric that requires ``y_pred``, ``y`` and ``x`` as input for ``update`` function. + In the example below we show how to setup standard metric like Accuracy and the custom metric using by an + ``evaluator`` created with :meth:`~ignite.engine.create_supervised_evaluator` method. - # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5 - # Let's implement a custom metric that requires ``y_pred``, ``y`` and ``x`` + .. 
code-block:: python - import torch - import torch.nn as nn + # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5 - from ignite.metrics import Metric, Accuracy - from ignite.engine import create_supervised_evaluator + import torch + import torch.nn as nn - class CustomMetric(Metric): + from ignite.metrics import Metric, Accuracy + from ignite.engine import create_supervised_evaluator - required_output_keys = ("y_pred", "y", "x") + class CustomMetric(Metric): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + required_output_keys = ("y_pred", "y", "x") - def update(self, output): - y_pred, y, x = output - # ... + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def reset(self): - # ... - pass + def update(self, output): + y_pred, y, x = output + # ... - def compute(self): - # ... - pass + def reset(self): + # ... + pass - model = ... + def compute(self): + # ... + pass - metrics = { - "Accuracy": Accuracy(), - "CustomMetric": CustomMetric() - } + model = ... - evaluator = create_supervised_evaluator( - model, - metrics=metrics, - output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred} - ) + metrics = { + "Accuracy": Accuracy(), + "CustomMetric": CustomMetric() + } + + evaluator = create_supervised_evaluator( + model, + metrics=metrics, + output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred} + ) - res = evaluator.run(data) + res = evaluator.run(data) """ @@ -321,7 +330,8 @@ def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = Epo engine (Engine): the engine to which the metric must be attached name (str): the name of the metric to attach usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be - 'EpochWise.usage_name' (default) or 'BatchWise.usage_name'. + :attr:`ignite.metrics.EpochWise.usage_name` (default) or + :attr:`ignite.metrics.BatchWise.usage_name`. Example:
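
Below is a minimal, self-contained sketch of how downstream code could rely on the promoted ``required_output_keys`` attribute once these patches are applied. It is not part of the patches themselves: the metric name ``MeanInputNorm``, the process function, and the toy data are illustrative assumptions, chosen so the example runs with a bare ``Engine`` instead of ``create_supervised_evaluator``.

    import torch

    from ignite.engine import Engine
    from ignite.metrics import Metric


    class MeanInputNorm(Metric):
        # Override the now-public attribute so that "x" is also extracted
        # from the mapping returned by the engine's process function.
        required_output_keys = ("y_pred", "y", "x")

        def reset(self):
            self._sum = 0.0
            self._num_examples = 0

        def update(self, output):
            # Metric.iteration_completed passes a tuple ordered like
            # required_output_keys: (y_pred, y, x).
            y_pred, y, x = output
            self._sum += x.norm(dim=1).sum().item()
            self._num_examples += x.shape[0]

        def compute(self):
            return self._sum / self._num_examples


    def process_function(engine, batch):
        x, y = batch
        # Return a mapping: each attached metric picks out the keys it declared.
        return {"x": x, "y": y, "y_pred": torch.rand(x.shape[0], 3)}


    engine = Engine(process_function)
    MeanInputNorm().attach(engine, "mean_input_norm")

    data = [(torch.rand(4, 10), torch.randint(0, 3, size=(4,))) for _ in range(3)]
    state = engine.run(data, max_epochs=1)
    print(state.metrics["mean_input_norm"])

Note that the order of ``required_output_keys`` determines the unpacking order in ``update``, since the patched ``iteration_completed`` builds the argument as ``tuple(output[k] for k in self.required_output_keys)``.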