import logging
import time
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for the ``__call__`` annotation only; at runtime the handler
    # just needs an object exposing ``terminate()``.
    from ignite.engine import Engine

__all__ = ["TimeLimit"]


class TimeLimit:
    """Handler that stops training once a time budget has been used up.

    Intended for computing environments where session time is limited.
    The timer starts when the handler is created, not when training starts.
    Each time the handler is called it compares the time elapsed since its
    creation with ``limit_sec`` and, once the limit is exceeded, logs a message
    and calls ``engine.terminate()``.

    Args:
        limit_sec (int, optional): Maximum time before training terminates (in seconds). Defaults to 28800.

    Raises:
        TypeError: If ``limit_sec`` is not an integer (``bool`` values are rejected as well).
        ValueError: If ``limit_sec`` is not strictly positive.

    Examples:

    .. code-block:: python

        from ignite.engine import Events
        from ignite.handlers import TimeLimit

        handler = TimeLimit()  # 8 hours of training
        trainer.add_event_handler(Events.ITERATION_COMPLETED, handler)

    .. versionadded:: 0.4.3
    """

    def __init__(self, limit_sec: int = 28800):
        # ``bool`` is a subclass of ``int``; reject it explicitly so that
        # ``TimeLimit(True)`` does not silently turn into a 1-second limit.
        if isinstance(limit_sec, bool) or not isinstance(limit_sec, int):
            raise TypeError("Argument limit_sec should be an integer.")
        if limit_sec <= 0:
            raise ValueError("Argument limit_sec should be a positive integer.")

        self.limit_sec = limit_sec
        # Wall-clock timestamp taken at construction; elapsed time is measured from here.
        self.start_time = time.time()
        self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)

    def __call__(self, engine: "Engine") -> None:
        """Terminate ``engine`` if more than ``limit_sec`` seconds have elapsed since creation."""
        elapsed_time = time.time() - self.start_time
        if elapsed_time > self.limit_sec:
            # Lazy logger arguments: the message is only formatted if INFO is enabled.
            self.logger.info("Reached the time limit: %s sec. Stop training", self.limit_sec)
            engine.terminate()