@@ -55,6 +55,9 @@ class EpochWise(MetricUsage):
5555 - :meth:`~ignite.metrics.Metric.started` on every ``EPOCH_STARTED`` (See :class:`~ignite.engine.events.Events`).
5656 - :meth:`~ignite.metrics.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
5757 - :meth:`~ignite.metrics.Metric.completed` on every ``EPOCH_COMPLETED``.
58+
59+ Attributes:
60+ usage_name (str): usage name string
5861 """
5962
6063 usage_name = "epoch_wise"
@@ -76,6 +79,9 @@ class BatchWise(MetricUsage):
7679 - :meth:`~ignite.metrics.Metric.started` on every ``ITERATION_STARTED`` (See :class:`~ignite.engine.events.Events`).
7780 - :meth:`~ignite.metrics.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
7881 - :meth:`~ignite.metrics.Metric.completed` on every ``ITERATION_COMPLETED``.
82+
83+ Attributes:
84+ usage_name (str): usage name string
7985 """
8086
8187 usage_name = "batch_wise"
@@ -125,9 +131,68 @@ class Metric(metaclass=ABCMeta):
125131 device (str or torch.device): specifies which device updates are accumulated on. Setting the
126132 metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
127133 non-blocking. By default, CPU.
134+
135+ Attributes:
136+ required_output_keys (tuple): dictionary defines required keys to be found in ``engine.state.output`` if the
137+ latter is a dictionary. Default, ``("y_pred", "y")``. This is useful with custom metrics that can require
138+ other arguments than predictions ``y_pred`` and targets ``y``. See notes below for an example.
139+
140+ Note:
141+
142+ Let's implement a custom metric that requires ``y_pred``, ``y`` and ``x`` as input for ``update`` function.
143+ In the example below we show how to setup standard metric like Accuracy and the custom metric using by an
144+ ``evaluator`` created with :meth:`~ignite.engine.create_supervised_evaluator` method.
145+
146+ .. code-block:: python
147+
148+ # https://discuss.pytorch.org/t/how-access-inputs-in-custom-ignite-metric/91221/5
149+
150+ import torch
151+ import torch.nn as nn
152+
153+ from ignite.metrics import Metric, Accuracy
154+ from ignite.engine import create_supervised_evaluator
155+
156+ class CustomMetric(Metric):
157+
158+ required_output_keys = ("y_pred", "y", "x")
159+
160+ def __init__(self, *args, **kwargs):
161+ super().__init__(*args, **kwargs)
162+
163+ def update(self, output):
164+ y_pred, y, x = output
165+ # ...
166+
167+ def reset(self):
168+ # ...
169+ pass
170+
171+ def compute(self):
172+ # ...
173+ pass
174+
175+ model = ...
176+
177+ metrics = {
178+ "Accuracy": Accuracy(),
179+ "CustomMetric": CustomMetric()
180+ }
181+
182+ evaluator = create_supervised_evaluator(
183+ model,
184+ metrics=metrics,
185+ output_transform=lambda x, y, y_pred: {"x": x, "y": y, "y_pred": y_pred}
186+ )
187+
188+ res = evaluator.run(data)
189+
128190 """
129191
130- _required_output_keys = ("y_pred" , "y" )
192+ # public class attribute
193+ required_output_keys = ("y_pred" , "y" )
194+ # for backward compatibility
195+ _required_output_keys = required_output_keys
131196
132197 def __init__ (
133198 self , output_transform : Callable = lambda x : x , device : Union [str , torch .device ] = torch .device ("cpu" ),
@@ -211,18 +276,18 @@ def iteration_completed(self, engine: Engine) -> None:
211276
212277 output = self ._output_transform (engine .state .output )
213278 if isinstance (output , Mapping ):
214- if self ._required_output_keys is None :
279+ if self .required_output_keys is None :
215280 raise TypeError (
216281 "Transformed engine output for {} metric should be a tuple/list, but given {}" .format (
217282 self .__class__ .__name__ , type (output )
218283 )
219284 )
220- if not all ([k in output for k in self ._required_output_keys ]):
285+ if not all ([k in output for k in self .required_output_keys ]):
221286 raise ValueError (
222287 "When transformed engine's output is a mapping, "
223- "it should contain {} keys, but given {}" .format (self ._required_output_keys , list (output .keys ()))
288+ "it should contain {} keys, but given {}" .format (self .required_output_keys , list (output .keys ()))
224289 )
225- output = tuple (output [k ] for k in self ._required_output_keys )
290+ output = tuple (output [k ] for k in self .required_output_keys )
226291 self .update (output )
227292
228293 def completed (self , engine : Engine , name : str ) -> None :
@@ -265,7 +330,8 @@ def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = Epo
265330 engine (Engine): the engine to which the metric must be attached
266331 name (str): the name of the metric to attach
267332 usage (str or MetricUsage, optional): the usage of the metric. Valid string values should be
268- 'EpochWise.usage_name' (default) or 'BatchWise.usage_name'.
333+ :attr:`ignite.metrics.EpochWise.usage_name` (default) or
334+ :attr:`ignite.metrics.BatchWise.usage_name`.
269335
270336 Example:
271337
0 commit comments