Skip to content

Commit 3f8daaf

Browse files
authored
Exposed EventEnum in docs and added import statement in its examples (#1345) (#1353)
1 parent 74531d9 commit 3f8daaf

File tree

5 files changed

+39
-0
lines changed

5 files changed

+39
-0
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ Examples
210210

211211
Custom events related to backward and optimizer step calls:
212212
```python
213+
from ignite.engine import EventEnum
214+
215+
213216
class BackpropEvents(EventEnum):
214217
BACKWARD_STARTED = 'backward_started'
215218
BACKWARD_COMPLETED = 'backward_completed'

docs/source/concepts.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ and be registered with :meth:`~ignite.engine.engine.Engine.register_events` in a
234234

235235
.. code-block:: python
236236
237+
from ignite.engine import EventEnum
238+
237239
class CustomEvents(EventEnum):
238240
"""
239241
Custom events defined by user

docs/source/engine.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ ignite.engine.events
9292
.. autoclass:: Events
9393
:members:
9494

95+
.. autoclass:: EventEnum
96+
9597
.. autoclass:: State
9698

9799
.. autoclass:: RemovableEventHandle

docs/source/faq.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ flexibility to the user to allow for this:
4747

4848
.. code-block:: python
4949
50+
from ignite.engine import EventEnum
51+
5052
class BackpropEvents(EventEnum):
5153
"""
5254
Events based on back propagation

ignite/engine/events.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,36 @@ def __init__(self, *args, **kwargs):
142142

143143

144144
class EventEnum(CallableEventWithFilter, Enum):
145+
"""Base class for all :class:`~ignite.engine.events.Events`. User defined custom events should also inherit
146+
this class. For example, Custom events based on the loss calculation and backward pass can be created as follows:
147+
148+
.. code-block:: python
149+
150+
from ignite.engine import EventEnum
151+
152+
class BackpropEvents(EventEnum):
153+
BACKWARD_STARTED = 'backward_started'
154+
BACKWARD_COMPLETED = 'backward_completed'
155+
OPTIM_STEP_COMPLETED = 'optim_step_completed'
156+
157+
def update(engine, batch):
158+
# ...
159+
loss = criterion(y_pred, y)
160+
engine.fire_event(BackpropEvents.BACKWARD_STARTED)
161+
loss.backward()
162+
engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)
163+
optimizer.step()
164+
engine.fire_event(BackpropEvents.OPTIM_STEP_COMPLETED)
165+
# ...
166+
167+
trainer = Engine(update)
168+
trainer.register_events(*BackpropEvents)
169+
170+
@trainer.on(BackpropEvents.BACKWARD_STARTED)
171+
def function_before_backprop(engine):
172+
# ...
173+
"""
174+
145175
pass
146176

147177

0 commit comments

Comments
 (0)