Skip to content

Commit 307ac11

Browse files
thescriptedvfdev-5
andauthored
Adding max_iters as an optional arg in Engine run (#1381)
* initial draft, adding max_iters as optional args in run * fixed typo * minor bug fixes * resolving failing tests * fixed out-of-place conditional * typo fix * updated docstring for 'run' * added initial tests * (WIP) restructured creating a new state with max_iters * updated tests & docstrings * initial draft, adding max_iters as optional args in run * fixed typo * minor bug fixes * resolving failing tests * fixed out-of-place conditional * typo fix * updated docstring for 'run' * added initial tests * (WIP) restructured creating a new state with max_iters * updated tests & docstrings * added test to check _is_done * updating engine loop condition * autopep8 fix * linting issues * fixed mypy errors * fixed formatting * minor fix & add test for larger max_iters * removed unused typechecking Co-authored-by: thescripted <thescripted@users.noreply.github.com> Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 1296c74 commit 307ac11

File tree

3 files changed

+88
-8
lines changed

3 files changed

+88
-8
lines changed

ignite/engine/engine.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22
import logging
3+
import math
34
import time
45
import warnings
56
import weakref
@@ -510,7 +511,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
510511
If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
511512
Iteration and epoch values are 0-based: the first iteration or epoch is zero.
512513
513-
This method does not remove any custom attributs added by user.
514+
This method does not remove any custom attributes added by user.
514515
515516
Args:
516517
state_dict (Mapping): a dict with parameters
@@ -557,7 +558,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
557558

558559
@staticmethod
559560
def _is_done(state: State) -> bool:
560-
return state.iteration == state.epoch_length * state.max_epochs # type: ignore[operator]
561+
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
562+
is_done_count = (
563+
state.epoch_length is not None
564+
and state.iteration >= state.epoch_length * state.max_epochs # type: ignore[operator]
565+
)
566+
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
567+
return is_done_iters or is_done_count or is_done_epochs
561568

562569
def set_data(self, data: Union[Iterable, DataLoader]) -> None:
563570
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -595,13 +602,19 @@ def switch_dataloader():
595602
self.state.dataloader = data
596603
self._dataloader_iter = iter(self.state.dataloader)
597604

598-
def run(self, data: Iterable, max_epochs: Optional[int] = None, epoch_length: Optional[int] = None,) -> State:
605+
def run(
606+
self,
607+
data: Iterable,
608+
max_epochs: Optional[int] = None,
609+
max_iters: Optional[int] = None,
610+
epoch_length: Optional[int] = None,
611+
) -> State:
599612
"""Runs the `process_function` over the passed data.
600613
601614
Engine has a state and the following logic is applied in this function:
602615
603-
- At the first call, new state is defined by `max_epochs`, `epoch_length` if provided. A timer for
604-
total and per-epoch time is initialized when Events.STARTED is handled.
616+
- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, if provided.
617+
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
605618
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
606619
provided, state is kept and used in the function.
607620
- If state is defined and engine is "done" (no iterations to run until `max_epochs`), a new state is defined.
@@ -617,6 +630,8 @@ def run(self, data: Iterable, max_epochs: Optional[int] = None, epoch_length: Op
617630
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
618631
determined as the iteration on which data iterator raises `StopIteration`.
619632
This argument should not change if run is resuming from a state.
633+
max_iters (int, optional): Number of iterations to run for.
634+
`max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
620635
621636
Returns:
622637
State: output state.
@@ -670,16 +685,27 @@ def switch_batch(engine):
670685

671686
if self.state.max_epochs is None or self._is_done(self.state):
672687
# Create new state
673-
if max_epochs is None:
674-
max_epochs = 1
675688
if epoch_length is None:
676689
epoch_length = self._get_data_length(data)
677690
if epoch_length is not None and epoch_length < 1:
678691
raise ValueError("Input data has zero size. Please provide non-empty data")
679692

693+
if max_iters is None:
694+
if max_epochs is None:
695+
max_epochs = 1
696+
else:
697+
if max_epochs is not None:
698+
raise ValueError(
699+
"Arguments max_iters and max_epochs are mutually exclusive."
700+
"Please provide only max_epochs or max_iters."
701+
)
702+
if epoch_length is not None:
703+
max_epochs = math.ceil(max_iters / epoch_length)
704+
680705
self.state.iteration = 0
681706
self.state.epoch = 0
682707
self.state.max_epochs = max_epochs
708+
self.state.max_iters = max_iters
683709
self.state.epoch_length = epoch_length
684710
self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
685711
else:
@@ -726,7 +752,7 @@ def _internal_run(self) -> State:
726752
try:
727753
start_time = time.time()
728754
self._fire_event(Events.STARTED)
729-
while self.state.epoch < self.state.max_epochs and not self.should_terminate: # type: ignore[operator]
755+
while not self._is_done(self.state) and not self.should_terminate:
730756
self.state.epoch += 1
731757
self._fire_event(Events.EPOCH_STARTED)
732758

@@ -800,6 +826,8 @@ def _run_once_on_dataset(self) -> float:
800826
if self.state.epoch_length is None:
801827
# Define epoch length and stop the epoch
802828
self.state.epoch_length = iter_counter
829+
if self.state.max_iters is not None:
830+
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
803831
break
804832

805833
# Should exit while loop if we can not iterate
@@ -839,6 +867,10 @@ def _run_once_on_dataset(self) -> float:
839867
if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
840868
break
841869

870+
if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
871+
self.should_terminate = True
872+
break
873+
842874
except Exception as e:
843875
self.logger.error("Current run is terminating due to exception: %s.", str(e))
844876
self._handle_exception(e)

ignite/engine/events.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ class State:
344344
state.dataloader # data passed to engine
345345
state.epoch_length # optional length of an epoch
346346
state.max_epochs # number of epochs to run
347+
state.max_iter # number of iterations to run
347348
state.batch # batch passed to `process_function`
348349
state.output # output of `process_function` after a single iteration
349350
state.metrics # dictionary with defined metrics if any
@@ -368,6 +369,7 @@ def __init__(self, **kwargs: Any) -> None:
368369
self.epoch = 0
369370
self.epoch_length = None # type: Optional[int]
370371
self.max_epochs = None # type: Optional[int]
372+
self.max_iters = None # type: Optional[int]
371373
self.output = None # type: Optional[int]
372374
self.batch = None # type: Optional[int]
373375
self.metrics = {} # type: Dict[str, Any]

tests/ignite/engine/test_engine.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,3 +891,49 @@ def switch_dataloader():
891891
trainer.set_data(data2)
892892

893893
trainer.run(data1, max_epochs=10)
894+
895+
896+
def test_run_with_max_iters():
897+
max_iters = 8
898+
engine = Engine(lambda e, b: 1)
899+
engine.run([0] * 20, max_iters=max_iters)
900+
assert engine.state.iteration == max_iters
901+
assert engine.state.max_iters == max_iters
902+
903+
904+
def test_run_with_max_iters_greater_than_epoch_length():
905+
max_iters = 73
906+
engine = Engine(lambda e, b: 1)
907+
engine.run([0] * 20, max_iters=max_iters)
908+
assert engine.state.iteration == max_iters
909+
910+
911+
def test_run_with_invalid_max_iters_and_max_epoch():
912+
max_iters = 12
913+
max_epochs = 2
914+
engine = Engine(lambda e, b: 1)
915+
with pytest.raises(
916+
ValueError,
917+
match=r"Arguments max_iters and max_epochs are mutually exclusive."
918+
"Please provide only max_epochs or max_iters.",
919+
):
920+
engine.run([0] * 20, max_iters=max_iters, max_epochs=max_epochs)
921+
922+
923+
def test_epoch_events_fired():
924+
max_iters = 32
925+
engine = Engine(lambda e, b: 1)
926+
927+
@engine.on(Events.EPOCH_COMPLETED)
928+
def fired_event(engine):
929+
assert engine.state.iteration % engine.state.epoch_length == 0
930+
931+
engine.run([0] * 10, max_iters=max_iters)
932+
933+
934+
def test_is_done_with_max_iters():
935+
state = State(iteration=100, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
936+
assert not Engine._is_done(state)
937+
938+
state = State(iteration=250, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
939+
assert Engine._is_done(state)

0 commit comments

Comments
 (0)