Skip to content

Commit

Permalink
Updated tests and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Aug 22, 2022
1 parent 54e2406 commit 20ee6a9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
22 changes: 18 additions & 4 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,14 +438,28 @@ def fire_event(self, event_name: Any) -> None:
return self._fire_event(event_name)

def terminate(self) -> None:
"""Sends terminate signal to the engine, so that it terminates completely the run after
the current iteration."""
"""Sends terminate signal to the engine, so that it terminates completely the run. The run is
terminated after the event on which ``terminate`` method was called. The following events are triggered:
- ...
- Terminating event
- :attr:`~ignite.engine.events.Events.TERMINATE`
- :attr:`~ignite.engine.events.Events.COMPLETED`
"""
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
self.should_terminate = True

def terminate_epoch(self) -> None:
"""Sends terminate signal to the engine, so that it terminates the current epoch
after the current iteration."""
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
continues from the next epoch. The following events are triggered:
- ...
- Event on which ``terminate_epoch`` method is called
- :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`
- :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
- :attr:`~ignite.engine.events.Events.EPOCH_STARTED`
- ...
"""
self.logger.info(
"Terminate current epoch is signaled. "
"Current epoch iteration will stop after current iteration is finished."
Expand Down
19 changes: 14 additions & 5 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def end_of_epoch_handler(engine):


@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration(data, epoch_length):
def test_terminate_at_start_of_epoch(data, epoch_length):
max_epochs = 5
epoch_to_terminate_on = 3
real_epoch_length = epoch_length if data is None else len(data)
Expand Down Expand Up @@ -293,7 +293,7 @@ def start_of_iteration_handler(engine):
def test_terminate_epoch_events_sequence(terminate_epoch_event, i):
engine = RecordedEngine(MagicMock(return_value=1))
data = range(10)
max_epochs = 5
max_epochs = 3

# TODO: Bug: Events.GET_BATCH_STARTED(once=12) is called twice !
# prevent call_terminate_epoch to be called twice
Expand All @@ -310,14 +310,23 @@ def call_terminate_epoch():
def check_previous_events(iter_counter):
e = i // len(data) + 1

print("engine.called_events:", engine.called_events)

assert engine.called_events[0] == (0, 0, Events.STARTED)
assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
assert engine.called_events[-2] == (e, i, terminate_epoch_event)
assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH)

@engine.on(Events.EPOCH_COMPLETED)
def check_previous_events2():
e = i // len(data) + 1
if e == engine.state.epoch and i == engine.state.iteration:
assert engine.called_events[-3] == (e, i, terminate_epoch_event)
assert engine.called_events[-2] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
assert engine.called_events[-1] == (e, i, Events.EPOCH_COMPLETED)

engine.run(data, max_epochs=max_epochs)

assert engine.state.epoch == max_epochs
assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data)


def _create_mock_data_loader(epochs, batches_per_epoch):
batches = [MagicMock()] * batches_per_epoch
Expand Down

0 comments on commit 20ee6a9

Please sign in to comment.