
Add Trainer max_time argument + Callback #6823

Merged
merged 50 commits into from
Apr 16, 2021
Changes from 40 commits
Commits
50 commits
c98321b
add timer class
awaelchli Apr 4, 2021
6a39660
add simple test
awaelchli Apr 4, 2021
2ac1cb6
shorter name
awaelchli Apr 4, 2021
cdda687
trainer callback configuration
awaelchli Apr 6, 2021
a291b89
interval default to step
awaelchli Apr 6, 2021
8aa979d
handle unsupported interval choice
awaelchli Apr 6, 2021
7399e7b
handle load and save
awaelchli Apr 6, 2021
27caeff
add start time property
awaelchli Apr 7, 2021
ffba8c3
add time elapsed test
awaelchli Apr 7, 2021
f243b94
complete test
awaelchli Apr 7, 2021
6ae6654
add trainer docs
awaelchli Apr 7, 2021
b8ff17c
update docs
awaelchli Apr 7, 2021
8bd7be7
more tests
awaelchli Apr 7, 2021
89aa9a7
Merge branch 'master' into feature/timer
awaelchli Apr 7, 2021
ad4cf80
fix min steps timer test
awaelchli Apr 7, 2021
4bd2c6e
add resume test
awaelchli Apr 7, 2021
5e45883
add changelog
awaelchli Apr 7, 2021
8b95dfa
yapf + isort
awaelchli Apr 7, 2021
36f4906
update trainer docs
awaelchli Apr 7, 2021
e4dcf07
add more docs
awaelchli Apr 7, 2021
dbc5251
use enum
awaelchli Apr 8, 2021
9c59bdf
fix typo
awaelchli Apr 8, 2021
f0928aa
Merge remote-tracking branch 'origin/feature/timer' into feature/timer
awaelchli Apr 8, 2021
c65cd9f
include elapsed time in message
awaelchli Apr 8, 2021
5991bb3
broadcast instead of reduce
awaelchli Apr 8, 2021
fb5590d
add days to string representation
awaelchli Apr 8, 2021
344d43f
wip parse
awaelchli Apr 8, 2021
a182d50
Revert "wip parse"
awaelchli Apr 8, 2021
25289ea
Merge branch 'master' into feature/timer
awaelchli Apr 10, 2021
44c196d
support dict
awaelchli Apr 10, 2021
d5b9074
add dict example
awaelchli Apr 12, 2021
8e505fb
update timer docstring
awaelchli Apr 12, 2021
1519333
udpate typehint
awaelchli Apr 12, 2021
f52935e
track val/test/predict/times
awaelchli Apr 12, 2021
a89ac0c
track time for all stages
awaelchli Apr 14, 2021
c8b22fb
enum nonsense
awaelchli Apr 14, 2021
8ef429d
make duration optional
awaelchli Apr 14, 2021
4410909
fix duration=None
awaelchli Apr 14, 2021
4d4e22d
add test
awaelchli Apr 14, 2021
1f1c982
add None test
awaelchli Apr 14, 2021
baae22b
Merge branch 'master' into feature/timer
carmocca Apr 15, 2021
f8cfb23
Improve coverage
carmocca Apr 15, 2021
fb54c4b
Typo
carmocca Apr 15, 2021
48d8fa1
Refactor enum usage
carmocca Apr 15, 2021
96b2c78
Typing
carmocca Apr 15, 2021
b2ebba9
Docs
carmocca Apr 15, 2021
fe4fae0
seconds
awaelchli Apr 15, 2021
24b0d3a
fix stage key in checkpoint
awaelchli Apr 15, 2021
df4ddce
skip windows
awaelchli Apr 16, 2021
3989fc7
Update pytorch_lightning/callbacks/timer.py
awaelchli Apr 16, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -96,6 +96,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningModule.lr_schedulers()` for manual optimization ([#6567](https://github.com/PyTorchLightning/pytorch-lightning/pull/6567))


- Added `max_time` Trainer argument to limit training time ([#6823](https://github.com/PyTorchLightning/pytorch-lightning/pull/6823))


### Changed

- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
20 changes: 20 additions & 0 deletions docs/source/common/trainer.rst
@@ -987,6 +987,26 @@ Trainer will train model for at least min_steps or min_epochs (latest).
# Run at least for 100 steps (disable min_epochs)
trainer = Trainer(min_steps=100, min_epochs=0)

max_time
^^^^^^^^

Set the maximum amount of time for training. Training will be interrupted mid-epoch once the limit is reached.
For customizable options, use the :class:`~pytorch_lightning.callbacks.timer.Timer` callback (see the sketch at the end of this section).

.. testcode::

# Default (disabled)
trainer = Trainer(max_time=None)

# Stop after 12 hours of training or when reaching 10 epochs (string)
trainer = Trainer(max_time="00:12:00:00", max_epochs=10)

# Stop after 1 day and 5 hours (dict)
trainer = Trainer(max_time={"days": 1, "hours": 5})

If ``max_time`` is used together with ``min_steps`` or ``min_epochs``, the ``min_*`` requirement
always takes precedence.
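
A minimal sketch of the customizable form (illustrative only, not part of this diff; ``interval`` accepts ``step`` or ``epoch`` per the Timer callback added below):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Timer

# same 12-hour limit as above, but only checked at the end of each epoch
timer = Timer(duration="00:12:00:00", interval="epoch")
trainer = Trainer(callbacks=[timer], max_epochs=10)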

num_nodes
^^^^^^^^^

2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
@@ -23,6 +23,7 @@
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from pytorch_lightning.callbacks.timer import Timer

__all__ = [
'BackboneFinetuning',
@@ -39,4 +40,5 @@
'ProgressBarBase',
'QuantizationAwareTraining',
'StochasticWeightAveraging',
'Timer',
]
163 changes: 163 additions & 0 deletions pytorch_lightning/callbacks/timer.py
@@ -0,0 +1,163 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Timer
^^^^^
"""
import logging
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Union, Optional

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)


class Interval(LightningEnum):
step = "step"
epoch = "epoch"


class Timer(Callback):
"""
The Timer callback tracks the time spent in the training loop and interrupts the Trainer
if the given time limit is reached.

Args:
duration: A string in the format DD:HH:MM:SS (days, hours, minutes, seconds), a :class:`datetime.timedelta`,
or a dict containing key-value pairs compatible with :class:`~datetime.timedelta`.
interval: Determines whether the interruption happens at the end of an epoch or mid-epoch (after a step).
Can be either `epoch` or `step`.
verbose: Set this to ``False`` to suppress logging messages.

Raises:
MisconfigurationException:
If ``interval`` is not one of the supported choices.

Example::
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Timer

# stop training after 12 hours
timer = Timer(duration="00:12:00:00")

# or provide a datetime.timedelta
from datetime import timedelta
timer = Timer(duration=timedelta(weeks=1))

# or provide a dictionary
timer = Timer(duration=dict(weeks=4, days=2))

# force training to stop after given time limit
trainer = Trainer(callbacks=[timer])

# query training/validation/test time
timer.time_elapsed("train")
timer.start_time("validate")
timer.end_time("test")
"""

def __init__(
self,
duration: Optional[Union[str, timedelta, Dict[str, int]]] = None,
interval: str = Interval.step,
verbose: bool = True,
):
super().__init__()
if isinstance(duration, str):
dhms = duration.strip().split(":")
dhms = [int(i) for i in dhms]
duration = timedelta(days=dhms[0], hours=dhms[1], minutes=dhms[2], seconds=dhms[3])
if isinstance(duration, dict):
duration = timedelta(**duration)
if interval not in set(Interval):
raise MisconfigurationException(
f"Unsupported parameter value `Timer(interval={interval})`. Possible choices are:"
f" {', '.join(set(Interval))}"
)
self._duration = duration
self._interval = interval
self._verbose = verbose
self._start_time = defaultdict(lambda: None)
self._end_time = defaultdict(lambda: None)
self._offset = timedelta()

def start_time(self, stage: str = RunningStage.TRAINING.value) -> Optional[datetime]:
return self._start_time[stage]

def end_time(self, stage: str = RunningStage.TRAINING.value) -> Optional[datetime]:
return self._end_time[stage]

def time_elapsed(self, stage: str = RunningStage.TRAINING.value) -> timedelta:
start = self.start_time(stage)
end = self.end_time(stage)
offset = self._offset if stage == RunningStage.TRAINING else timedelta(0)
if start is None:
return offset
if end is None:
return datetime.now() - start + offset
return end - start + offset

def time_remaining(self, stage: str = RunningStage.TRAINING.value) -> Optional[timedelta]:
if self._duration is not None:
return self._duration - self.time_elapsed(stage)

def on_train_start(self, *args, **kwargs) -> None:
self._start_time.update({RunningStage.TRAINING.value: datetime.now()})

def on_train_end(self, *args, **kwargs) -> None:
self._end_time.update({RunningStage.TRAINING.value: datetime.now()})

def on_validation_start(self, *args, **kwargs) -> None:
self._start_time.update({RunningStage.VALIDATING.value: datetime.now()})

def on_validation_end(self, *args, **kwargs) -> None:
self._end_time.update({RunningStage.VALIDATING.value: datetime.now()})

def on_test_start(self, *args, **kwargs) -> None:
self._start_time.update({RunningStage.TESTING.value: datetime.now()})

def on_test_end(self, *args, **kwargs) -> None:
self._end_time.update({RunningStage.TESTING.value: datetime.now()})

def on_train_batch_end(self, trainer, *args, **kwargs) -> None:
if self._interval != Interval.step or self._duration is None:
return
self._check_time_remaining(trainer)

def on_train_epoch_end(self, trainer, *args, **kwargs) -> None:
if self._interval != Interval.epoch or self._duration is None:
return
self._check_time_remaining(trainer)

def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
return {
"time_elapsed": {k.value: self.time_elapsed(k.value) for k in RunningStage}
}

def on_load_checkpoint(self, callback_state: Dict[str, Any]):
time_elapsed = callback_state.get("time_elapsed", defaultdict(timedelta))
self._offset = time_elapsed[RunningStage.TRAINING.value]

def _check_time_remaining(self, trainer) -> None:
should_stop = self.time_elapsed() >= self._duration
should_stop = trainer.accelerator.broadcast(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop and self._verbose:
rank_zero_info(f"Time limit reached. Elapsed time is {self.time_elapsed}. Signaling Trainer to stop.")
18 changes: 17 additions & 1 deletion pytorch_lightning/trainer/connectors/callback_connector.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Union
from datetime import timedelta
from typing import List, Union, Optional, Dict

from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -35,6 +37,7 @@ def on_trainer_init(
weights_save_path,
resume_from_checkpoint,
stochastic_weight_avg,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
):
self.trainer.resume_from_checkpoint = resume_from_checkpoint

@@ -55,6 +58,8 @@
# configure swa callback
self._configure_swa_callbacks()

self._configure_timer_callback(max_time)

# init progress bar
self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

@@ -106,6 +111,17 @@ def configure_progress_bar(self, refresh_rate=None, process_position=0):

return progress_bar_callback

def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None:
if max_time is None:
return
if any(isinstance(cb, Timer) for cb in self.trainer.callbacks):
rank_zero_info(
"Ignoring `Trainer(max_time=...)`, callbacks list already contains a Timer."
)
return
timer = Timer(duration=max_time, interval="step")
self.trainer.callbacks.append(timer)

def _trainer_has_checkpoint_callbacks(self):
return len(self.trainer.checkpoint_callbacks) > 0

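
A quick sketch (illustrative only) of the precedence rule implemented by `_configure_timer_callback` above: a `Timer` already present in `callbacks` wins, and `Trainer(max_time=...)` is ignored with an info message rather than adding a second timer.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Timer

# explicit Timer: 5-hour limit, checked only at epoch boundaries
timer = Timer(duration={"hours": 5}, interval="epoch")

# max_time would normally add a step-interval Timer, but it is skipped here
# because the callbacks list already contains one
trainer = Trainer(callbacks=[timer], max_time="00:12:00:00")
assert sum(isinstance(cb, Timer) for cb in trainer.callbacks) == 1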
18 changes: 16 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -14,6 +14,7 @@
"""Trainer to automate the training."""
import logging
import warnings
from datetime import timedelta
from itertools import count
from pathlib import Path
from traceback import print_exc
@@ -109,6 +110,7 @@ def __init__(
min_epochs: Optional[int] = None,
max_steps: Optional[int] = None,
min_steps: Optional[int] = None,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
limit_train_batches: Union[int, float] = 1.0,
limit_val_batches: Union[int, float] = 1.0,
limit_test_batches: Union[int, float] = 1.0,
@@ -242,6 +244,11 @@

min_steps: Force training for at least these number of steps. Disabled by default (None).

max_time: Stop training after this amount of time has passed. Disabled by default (None).
The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
:class:`datetime.timedelta`, or a dictionary with keys that will be passed to
:class:`~datetime.timedelta`.

num_nodes: number of GPU nodes for distributed training.

num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"
@@ -333,8 +340,15 @@ def __init__(
# init callbacks
# Declare attributes to be set in callback_connector on_trainer_init
self.callback_connector.on_trainer_init(
callbacks, checkpoint_callback, progress_bar_refresh_rate, process_position, default_root_dir,
weights_save_path, resume_from_checkpoint, stochastic_weight_avg
callbacks,
checkpoint_callback,
progress_bar_refresh_rate,
process_position,
default_root_dir,
weights_save_path,
resume_from_checkpoint,
stochastic_weight_avg,
max_time,
)

# hook
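
To make the `max_time` docstring above concrete, a sketch (illustrative only) of the three accepted forms, all expressing the same limit of 1 day and 12 hours:

from datetime import timedelta

from pytorch_lightning import Trainer

# "DD:HH:MM:SS" string, timedelta, and dict are interchangeable
trainer = Trainer(max_time="01:12:00:00")                # 1 day, 12 hours
trainer = Trainer(max_time=timedelta(days=1, hours=12))  # same limit
trainer = Trainer(max_time={"days": 1, "hours": 12})     # same limit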