diff --git a/detectron2/engine/hooks.py b/detectron2/engine/hooks.py index fc37af0fd3..26bc045d9f 100644 --- a/detectron2/engine/hooks.py +++ b/detectron2/engine/hooks.py @@ -222,6 +222,7 @@ def __init__( val_metric: str, mode: str = "max", file_prefix: str = "model_best", + patience: int = None, ) -> None: """ Args: @@ -231,10 +232,12 @@ def __init__( mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be maximized or minimized, e.g. for "bbox/AP50" it should be "max" file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best" + patience (int): the number of evaluation cycles without improvement before early stopping """ self._logger = logging.getLogger(__name__) self._period = eval_period self._val_metric = val_metric + self._patience = patience assert mode in [ "max", "min", @@ -297,6 +300,24 @@ def after_step(self): and next_iter != self.trainer.max_iter ): self._best_checking() + + if self._patience is None or self.best_iter is None: + return + + iterations_without_improvement = (self.trainer.iter-self.best_iter) // self._period + + if(iterations_without_improvement > self._patience): + self._logger.info( + f"Early stopping triggered at iteration {self.trainer.iter} due to lack of improvement " + f"after {iterations_without_improvement} cycles." + ) + raise Exception("Early stopping triggered. Terminating training process.") + + if(iterations_without_improvement > 0): + self._logger.info( + f"No improvement detected in the last {iterations_without_improvement} evaluation cycles. " + f"{self._patience - iterations_without_improvement} cycles remain before early stopping." + ) def after_train(self): # same conditions as `EvalHook`