Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed issue when DATALOADER_STOP_ITERATION event is triggered when engine.run(data=None, ...) #3217

Merged
merged 2 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,15 +1037,23 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:

while True:
self.state.batch = self.state.output = None

try:
# Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
self._fire_event(Events.GET_BATCH_STARTED)
yield from self._maybe_terminate_or_interrupt()
# We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
# if no data was provided to engine.run(data=None, ...)
if self.state.dataloader is not None:
self._fire_event(Events.GET_BATCH_STARTED)
yield from self._maybe_terminate_or_interrupt()

self.state.batch = next(self._dataloader_iter)
self._fire_event(Events.GET_BATCH_COMPLETED)
yield from self._maybe_terminate_or_interrupt()

# We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
# if no data was provided to engine.run(data=None, ...)
if self.state.dataloader is not None:
self._fire_event(Events.GET_BATCH_COMPLETED)
yield from self._maybe_terminate_or_interrupt()

iter_counter += 1
should_exit = False
Expand Down Expand Up @@ -1074,8 +1082,11 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
)
break

self._fire_event(Events.DATALOADER_STOP_ITERATION)
yield from self._maybe_terminate_or_interrupt()
# We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
# if no data was provided to engine.run(data=None, ...)
if self.state.dataloader is not None:
self._fire_event(Events.DATALOADER_STOP_ITERATION)
yield from self._maybe_terminate_or_interrupt()

self._setup_dataloader_iter()
should_exit = True
Expand Down Expand Up @@ -1198,12 +1209,18 @@ def _run_once_on_dataset_legacy(self) -> float:
try:
# Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
self._fire_event(Events.GET_BATCH_STARTED)
self._maybe_terminate_legacy()
# We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
# if no data was provided to engine.run(data=None, ...)
if self.state.dataloader is not None:
self._fire_event(Events.GET_BATCH_STARTED)
self._maybe_terminate_legacy()

self.state.batch = next(self._dataloader_iter)
self._fire_event(Events.GET_BATCH_COMPLETED)
self._maybe_terminate_legacy()
# We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
# if no data was provided to engine.run(data=None, ...)
if self.state.dataloader is not None:
self._fire_event(Events.GET_BATCH_COMPLETED)
self._maybe_terminate_legacy()

iter_counter += 1
should_exit = False
Expand Down Expand Up @@ -1232,8 +1249,11 @@ def _run_once_on_dataset_legacy(self) -> float:
)
break

self._fire_event(Events.DATALOADER_STOP_ITERATION)
self._maybe_terminate_legacy()
# We should not trigger GET_BATCH_STARTED, GET_BATCH_COMPLETED, DATALOADER_STOP_ITERATION events
# if no data was provided to engine.run(data=None, ...)
if self.state.dataloader is not None:
self._fire_event(Events.DATALOADER_STOP_ITERATION)
self._maybe_terminate_legacy()

self._setup_dataloader_iter()
should_exit = True
Expand Down
25 changes: 22 additions & 3 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,8 +601,8 @@ def _test_check_triggered_events(self, data, max_epochs, epoch_length, exp_iter_
Events.EPOCH_COMPLETED: max_epochs,
Events.ITERATION_STARTED: max_epochs * epoch_length,
Events.ITERATION_COMPLETED: max_epochs * epoch_length,
Events.GET_BATCH_STARTED: max_epochs * epoch_length,
Events.GET_BATCH_COMPLETED: max_epochs * epoch_length,
Events.GET_BATCH_STARTED: max_epochs * epoch_length if data is not None else 0,
Events.GET_BATCH_COMPLETED: max_epochs * epoch_length if data is not None else 0,
Events.DATALOADER_STOP_ITERATION: (max_epochs - 1) if exp_iter_stops is None else exp_iter_stops,
}

Expand All @@ -617,7 +617,7 @@ def _test_run_check_triggered_events(self):
self._test_check_triggered_events(
list(range(100)), max_epochs=5, epoch_length=150, exp_iter_stops=150 * 5 // 100
)
self._test_check_triggered_events(None, max_epochs=5, epoch_length=150)
self._test_check_triggered_events(None, max_epochs=5, epoch_length=150, exp_iter_stops=0)

def test_run_check_triggered_events_list(self):
self._test_run_check_triggered_events()
Expand Down Expand Up @@ -1146,6 +1146,25 @@ def train_step(engine, batch):
assert trainer.state.epoch == 20
assert trainer.state.dataloader is None

def test_engine_no_data_events(self):
    # Regression test for https://github.com/pytorch/ignite/issues/3190:
    # when run() is started without data, none of the batch-fetching
    # events (GET_BATCH_*, DATALOADER_STOP_ITERATION) may fire.
    n_epochs = 4
    source = range(10)

    def step_fn(engine, _):
        # No data was handed to run(), so the engine must keep
        # state.dataloader unset for the whole run.
        assert engine.state.dataloader is None

    trainer = Engine(step_fn)
    trainer.state.dataiter = iter(source)

    @trainer.on(Events.DATALOADER_STOP_ITERATION)
    @trainer.on(Events.GET_BATCH_STARTED)
    @trainer.on(Events.GET_BATCH_COMPLETED)
    def forbidden_event_handler():
        # Reaching this handler means a batch-related event was fired.
        assert False, trainer.last_event_name

    trainer.run(max_epochs=n_epochs, epoch_length=4)

@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
def test_engine_run_resume(self, data, epoch_length):
# https://github.com/pytorch/ignite/wiki/Roadmap#runresume-logic-improvements
Expand Down
Loading