11import functools
22import logging
3+ import math
34import time
45import warnings
56import weakref
@@ -510,7 +511,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
510511 If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
511512 Iteration and epoch values are 0-based: the first iteration or epoch is zero.
512513
513- This method does not remove any custom attributs added by user.
514+ This method does not remove any custom attributes added by user.
514515
515516 Args:
516517 state_dict (Mapping): a dict with parameters
@@ -557,7 +558,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
557558
558559 @staticmethod
559560 def _is_done (state : State ) -> bool :
560- return state .iteration == state .epoch_length * state .max_epochs # type: ignore[operator]
561+ is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
562+ is_done_count = (
563+ state .epoch_length is not None
564+ and state .iteration >= state .epoch_length * state .max_epochs # type: ignore[operator]
565+ )
566+ is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
567+ return is_done_iters or is_done_count or is_done_epochs
561568
562569 def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
563570 """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -595,13 +602,19 @@ def switch_dataloader():
595602 self .state .dataloader = data
596603 self ._dataloader_iter = iter (self .state .dataloader )
597604
598- def run (self , data : Iterable , max_epochs : Optional [int ] = None , epoch_length : Optional [int ] = None ,) -> State :
605+ def run (
606+ self ,
607+ data : Iterable ,
608+ max_epochs : Optional [int ] = None ,
609+ max_iters : Optional [int ] = None ,
610+ epoch_length : Optional [int ] = None ,
611+ ) -> State :
599612 """Runs the `process_function` over the passed data.
600613
601614 Engine has a state and the following logic is applied in this function:
602615
603- - At the first call, new state is defined by `max_epochs`, `epoch_length` if provided. A timer for
604- total and per-epoch time is initialized when Events.STARTED is handled.
616+ - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, if provided.
617+ A timer for total and per-epoch time is initialized when Events.STARTED is handled.
605618 - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
606619 provided, state is kept and used in the function.
607620 - If state is defined and engine is "done" (no iterations to run until `max_epochs`), a new state is defined.
@@ -617,6 +630,8 @@ def run(self, data: Iterable, max_epochs: Optional[int] = None, epoch_length: Op
617630 `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
618631 determined as the iteration on which data iterator raises `StopIteration`.
619632 This argument should not change if run is resuming from a state.
633+ max_iters (int, optional): Number of iterations to run for.
634+ `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
620635
621636 Returns:
622637 State: output state.
@@ -670,16 +685,27 @@ def switch_batch(engine):
670685
671686 if self .state .max_epochs is None or self ._is_done (self .state ):
672687 # Create new state
673- if max_epochs is None :
674- max_epochs = 1
675688 if epoch_length is None :
676689 epoch_length = self ._get_data_length (data )
677690 if epoch_length is not None and epoch_length < 1 :
678691 raise ValueError ("Input data has zero size. Please provide non-empty data" )
679692
693+ if max_iters is None :
694+ if max_epochs is None :
695+ max_epochs = 1
696+ else :
697+ if max_epochs is not None :
698+ raise ValueError (
699+ "Arguments max_iters and max_epochs are mutually exclusive."
700+ "Please provide only max_epochs or max_iters."
701+ )
702+ if epoch_length is not None :
703+ max_epochs = math .ceil (max_iters / epoch_length )
704+
680705 self .state .iteration = 0
681706 self .state .epoch = 0
682707 self .state .max_epochs = max_epochs
708+ self .state .max_iters = max_iters
683709 self .state .epoch_length = epoch_length
684710 self .logger .info ("Engine run starting with max_epochs={}." .format (max_epochs ))
685711 else :
@@ -726,7 +752,7 @@ def _internal_run(self) -> State:
726752 try :
727753 start_time = time .time ()
728754 self ._fire_event (Events .STARTED )
729- while self .state . epoch < self .state . max_epochs and not self .should_terminate : # type: ignore[operator]
755+ while not self ._is_done ( self .state ) and not self .should_terminate :
730756 self .state .epoch += 1
731757 self ._fire_event (Events .EPOCH_STARTED )
732758
@@ -800,6 +826,8 @@ def _run_once_on_dataset(self) -> float:
800826 if self .state .epoch_length is None :
801827 # Define epoch length and stop the epoch
802828 self .state .epoch_length = iter_counter
829+ if self .state .max_iters is not None :
830+ self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
803831 break
804832
805833 # Should exit while loop if we can not iterate
@@ -839,6 +867,10 @@ def _run_once_on_dataset(self) -> float:
839867 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
840868 break
841869
870+ if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
871+ self .should_terminate = True
872+ break
873+
842874 except Exception as e :
843875 self .logger .error ("Current run is terminating due to exception: %s." , str (e ))
844876 self ._handle_exception (e )
0 commit comments