diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index ab9e137b59f8..3bf935eea6d3 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -438,14 +438,28 @@ def fire_event(self, event_name: Any) -> None: return self._fire_event(event_name) def terminate(self) -> None: - """Sends terminate signal to the engine, so that it terminates completely the run after - the current iteration.""" + """Sends terminate signal to the engine, so that it terminates completely the run. The run is + terminated after the event on which ``terminate`` method was called. The following events are triggered: + + - ... + - Terminating event + - :attr:`~ignite.engine.events.Events.TERMINATE` + - :attr:`~ignite.engine.events.Events.COMPLETED` + """ self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.") self.should_terminate = True def terminate_epoch(self) -> None: - """Sends terminate signal to the engine, so that it terminates the current epoch - after the current iteration.""" + """Sends terminate signal to the engine, so that it terminates the current epoch. The run + continues from the next epoch. The following events are triggered: + + - ... + - Event on which ``terminate_epoch`` method is called + - :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH` + - :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED` + - :attr:`~ignite.engine.events.Events.EPOCH_STARTED` + - ... + """ self.logger.info( "Terminate current epoch is signaled. " "Current epoch iteration will stop after current iteration is finished." diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index b038a3ce7cae..8d7d21f686f2 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -104,7 +104,7 @@ def end_of_epoch_handler(engine): @pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)]) -def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration(data, epoch_length): +def test_terminate_at_start_of_epoch(data, epoch_length): max_epochs = 5 epoch_to_terminate_on = 3 real_epoch_length = epoch_length if data is None else len(data) @@ -293,7 +293,7 @@ def start_of_iteration_handler(engine): def test_terminate_epoch_events_sequence(terminate_epoch_event, i): engine = RecordedEngine(MagicMock(return_value=1)) data = range(10) - max_epochs = 5 + max_epochs = 3 # TODO: Bug: Events.GET_BATCH_STARTED(once=12) is called twice ! # prevent call_terminate_epoch to be called twice @@ -310,14 +310,23 @@ def call_terminate_epoch(): def check_previous_events(iter_counter): e = i // len(data) + 1 - print("engine.called_events:", engine.called_events) - assert engine.called_events[0] == (0, 0, Events.STARTED) - assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH) assert engine.called_events[-2] == (e, i, terminate_epoch_event) + assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH) + + @engine.on(Events.EPOCH_COMPLETED) + def check_previous_events2(): + e = i // len(data) + 1 + if e == engine.state.epoch and i == engine.state.iteration: + assert engine.called_events[-3] == (e, i, terminate_epoch_event) + assert engine.called_events[-2] == (e, i, Events.TERMINATE_SINGLE_EPOCH) + assert engine.called_events[-1] == (e, i, Events.EPOCH_COMPLETED) engine.run(data, max_epochs=max_epochs) + assert engine.state.epoch == max_epochs + assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data) + def _create_mock_data_loader(epochs, batches_per_epoch): batches = [MagicMock()] * batches_per_epoch