
Commit d4f7252

Added required_output_keys public attribute (1289) (#1291)

* Fixes #1289 - Promoted _required_output_keys to be public so users can override it.
* Updated docs

1 parent 564e541 commit d4f7252
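
In short, the previously private _required_output_keys class attribute becomes the public required_output_keys, with an alias kept so existing code does not break. A minimal sketch of what this exposes, with names taken from the diffs below (not part of the commit itself):

    from ignite.metrics import Accuracy

    acc = Accuracy()
    assert acc.required_output_keys == ("y_pred", "y")            # new public attribute
    assert acc._required_output_keys == acc.required_output_keys  # backward-compatible alias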

File tree

5 files changed: +123 −12 lines changed

  ignite/metrics/accumulation.py
  ignite/metrics/loss.py
  ignite/metrics/metric.py
  ignite/metrics/running_average.py
  tests/ignite/metrics/test_metric.py

ignite/metrics/accumulation.py

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ class VariableAccumulation(Metric):

     """

-    _required_output_keys = None
+    required_output_keys = None

     def __init__(
         self,
ignite/metrics/loss.py

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ class Loss(Metric):

     """

-    _required_output_keys = None
+    required_output_keys = None

     def __init__(
         self,

ignite/metrics/metric.py

Lines changed: 72 additions & 6 deletions

@@ -55,6 +55,9 @@ class EpochWise(MetricUsage):
     - :meth:`~ignite.metrics.Metric.started` on every ``EPOCH_STARTED`` (See :class:`~ignite.engine.events.Events`).
     - :meth:`~ignite.metrics.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
     - :meth:`~ignite.metrics.Metric.completed` on every ``EPOCH_COMPLETED``.
+
+    Attributes:
+        usage_name (str): usage name string
     """

     usage_name = "epoch_wise"

@@ -76,6 +79,9 @@ class BatchWise(MetricUsage):
     - :meth:`~ignite.metrics.Metric.started` on every ``ITERATION_STARTED`` (See :class:`~ignite.engine.events.Events`).
     - :meth:`~ignite.metrics.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
     - :meth:`~ignite.metrics.Metric.completed` on every ``ITERATION_COMPLETED``.
+
+    Attributes:
+        usage_name (str): usage name string
     """

     usage_name = "batch_wise"

@@ -125,9 +131,68 @@ class Metric(metaclass=ABCMeta):
         device (str or torch.device): specifies which device updates are accumulated on. Setting the
             metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
             non-blocking. By default, CPU.
+
+    Attributes:
+        required_output_keys (tuple): keys required to be found in ``engine.state.output`` if the
+            latter is a dictionary. Default, ``("y_pred", "y")``. This is useful with custom metrics that
+            require arguments other than the predictions ``y_pred`` and targets ``y``. See the note below
+            for an example.
+
+    Note:
+
+        Let's implement a custom metric that requires ``y_pred``, ``y`` and ``x`` as input to its ``update``
+        function. The example below shows how to set up a standard metric like ``Accuracy`` alongside the
+        custom metric on an ``evaluator`` created with :meth:`~ignite.engine.create_supervised_evaluator`.
+
+        .. code-block:: python
+
+            # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5
+
+            import torch
+            import torch.nn as nn
+
+            from ignite.metrics import Metric, Accuracy
+            from ignite.engine import create_supervised_evaluator
+
+            class CustomMetric(Metric):
+
+                required_output_keys = ("y_pred", "y", "x")
+
+                def __init__(self, *args, **kwargs):
+                    super().__init__(*args, **kwargs)
+
+                def update(self, output):
+                    y_pred, y, x = output
+                    # ...
+
+                def reset(self):
+                    # ...
+                    pass
+
+                def compute(self):
+                    # ...
+                    pass
+
+            model = ...
+
+            metrics = {
+                "Accuracy": Accuracy(),
+                "CustomMetric": CustomMetric()
+            }
+
+            evaluator = create_supervised_evaluator(
+                model,
+                metrics=metrics,
+                output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred}
+            )
+
+            res = evaluator.run(data)
+
     """

-    _required_output_keys = ("y_pred", "y")
+    # public class attribute
+    required_output_keys = ("y_pred", "y")
+    # for backward compatibility
+    _required_output_keys = required_output_keys

     def __init__(
         self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"),

@@ -211,18 +276,18 @@ def iteration_completed(self, engine: Engine) -> None:

         output = self._output_transform(engine.state.output)
         if isinstance(output, Mapping):
-            if self._required_output_keys is None:
+            if self.required_output_keys is None:
                 raise TypeError(
                     "Transformed engine output for {} metric should be a tuple/list, but given {}".format(
                         self.__class__.__name__, type(output)
                     )
                 )
-            if not all([k in output for k in self._required_output_keys]):
+            if not all([k in output for k in self.required_output_keys]):
                 raise ValueError(
                     "When transformed engine's output is a mapping, "
-                    "it should contain {} keys, but given {}".format(self._required_output_keys, list(output.keys()))
+                    "it should contain {} keys, but given {}".format(self.required_output_keys, list(output.keys()))
                 )
-            output = tuple(output[k] for k in self._required_output_keys)
+            output = tuple(output[k] for k in self.required_output_keys)
         self.update(output)

     def completed(self, engine: Engine, name: str) -> None:

@@ -265,7 +330,8 @@ def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = EpochWise()):
             engine (Engine): the engine to which the metric must be attached
             name (str): the name of the metric to attach
             usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be
-                'EpochWise.usage_name' (default) or 'BatchWise.usage_name'.
+                :attr:`ignite.metrics.EpochWise.usage_name` (default) or
+                :attr:`ignite.metrics.BatchWise.usage_name`.

         Example:
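
The attach docstring fix above points at the usage_name strings documented in the EpochWise and BatchWise hunks. A short hedged sketch of passing one of those strings (assumes an already-built engine; not part of this commit):

    from ignite.metrics import Accuracy
    from ignite.metrics.metric import BatchWise

    acc = Accuracy()
    # "batch_wise": started/iteration_completed/completed all fire per iteration
    acc.attach(engine, "acc", usage=BatchWise.usage_name)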

ignite/metrics/running_average.py

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ def log_running_avg_metrics(engine):

     """

-    _required_output_keys = None
+    required_output_keys = None

     def __init__(
         self,
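
VariableAccumulation, Loss, and RunningAverage all set required_output_keys = None, which, per the iteration_completed hunk above, makes a mapping output raise TypeError; their output_transform must produce a tuple/list instead. A minimal sketch under that assumption:

    import torch.nn as nn
    from ignite.metrics import Loss

    # Loss sets required_output_keys = None, so a dict output is rejected;
    # unpack the dict into a (y_pred, y) tuple in output_transform instead.
    loss_metric = Loss(nn.CrossEntropyLoss(), output_transform=lambda out: (out["y_pred"], out["y"]))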

tests/ignite/metrics/test_metric.py

Lines changed: 48 additions & 3 deletions

@@ -67,7 +67,7 @@ def test_output_as_mapping_wrong_keys():

 def test_output_as_mapping_keys_is_none():
     class DummyMetric(Metric):
-        _required_output_keys = None
+        required_output_keys = None

         def reset(self):
             pass

@@ -79,7 +79,7 @@ def update(self, output):
             pass

     metric = DummyMetric()
-    assert metric._required_output_keys is None
+    assert metric.required_output_keys is None
     state = State(output=({"y1": 0, "y2": 1}))
     engine = MagicMock(state=state)

@@ -318,7 +318,7 @@ def process_function(*args, **kwargs):

 def test_detach():
     class DummyMetric(Metric):
-        _required_output_keys = None
+        required_output_keys = None

         def reset(self):
             pass

@@ -793,3 +793,48 @@ def _():
     assert bfm[0] == 1

     engine.run([0, 1, 2, 3], max_epochs=10)
+
+
+def test_override_required_output_keys():
+    # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5
+    import torch.nn as nn
+
+    from ignite.engine import create_supervised_evaluator
+
+    counter = [0]
+
+    class CustomMetric(Metric):
+        required_output_keys = ("y_pred", "y", "x")
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def update(self, output):
+            y_pred, y, x = output
+            assert y_pred.shape == (4, 3)
+            assert y.shape == (4,)
+            assert x.shape == (4, 10)
+            assert x.equal(data[counter[0]][0])
+            assert y.equal(data[counter[0]][1])
+            counter[0] += 1
+
+        def reset(self):
+            pass
+
+        def compute(self):
+            pass
+
+    model = nn.Linear(10, 3)
+
+    metrics = {"Precision": Precision(), "CustomMetric": CustomMetric()}
+
+    evaluator = create_supervised_evaluator(
+        model, metrics=metrics, output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred}
+    )
+
+    data = [
+        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
+        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
+        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
+    ]
+    evaluator.run(data)
