Skip to content

Commit

Permalink
Checkpointing interval (Lightning-AI#1272)
Browse files Browse the repository at this point in the history
* formatting

* formatting

* fix interval

* fix train loop

* fix test

* parametrize test

* Apply suggestions from code review

Co-Authored-By: Adrian Wälchli <adrian.waelchli@students.unibe.ch>

* fix calling

* flake8

* add types

Co-authored-by: Adrian Wälchli <adrian.waelchli@students.unibe.ch>
Co-authored-by: William Falcon <waf2107@columbia.edu>
  • Loading branch information
3 people authored and akarnachev committed Apr 3, 2020
1 parent 8137b70 commit 4a92684
Show file tree
Hide file tree
Showing 15 changed files with 166 additions and 298 deletions.
6 changes: 1 addition & 5 deletions pl_examples/multi_node_examples/multi_node_ddp_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@


def main(hparams):
"""
Main training routine specific for this project
:param hparams:
:return:
"""
"""Main training routine specific for this project."""
# ------------------------
# 1 INIT LIGHTNING MODEL
# ------------------------
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
r"""
Callback Base
==============
=============
Abstract base class used to build new callbacks.
"""

Expand Down
90 changes: 42 additions & 48 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_check = 0
self.epoch_last_check = None
self.prefix = prefix
self.best_k_models = {}
# {filename: monitor}
Expand Down Expand Up @@ -139,21 +139,20 @@ def check_monitor_top_k(self, current):
def format_checkpoint_name(self, epoch, metrics, ver=None):
"""Generate a filename according to the defined template.
Examples
--------
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'missing=0.ckpt'
Examples:
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'missing=0.ckpt'
"""
# check if user passed in keys to the string
groups = re.findall(r'(\{.*?)[:\}]', self.filename)
Expand Down Expand Up @@ -181,41 +180,36 @@ def on_validation_end(self, trainer, pl_module):

metrics = trainer.callback_metrics
epoch = trainer.current_epoch
self.epochs_since_last_check += 1

if self.save_top_k == 0:
# no models are saved
return
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0

filepath = self.format_checkpoint_name(epoch, metrics)
version_cnt = 0
while os.path.isfile(filepath):
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1

if self.save_top_k != -1:
current = metrics.get(self.monitor)

if current is None:
warnings.warn(
f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
else:
if self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch)
else:
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor}'
f' was not in top {self.save_top_k}')

else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
self._save_model(filepath)
if self.epoch_last_check is not None and (epoch - self.epoch_last_check) < self.period:
# fewer than `period` epochs have passed since the last check, so skip this one
return

self.epoch_last_check = epoch

filepath = self.format_checkpoint_name(epoch, metrics)
version_cnt = 0
while os.path.isfile(filepath):
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1

if self.save_top_k != -1:
current = metrics.get(self.monitor)

if current is None:
warnings.warn(f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning)
elif self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch)
elif self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')

else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
self._save_model(filepath)

def _do_check_save(self, filepath, current, epoch):
# remove kth
Expand Down
39 changes: 20 additions & 19 deletions pytorch_lightning/profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ class BaseProfiler(ABC):
"""

@abstractmethod
def start(self, action_name):
def start(self, action_name: str) -> None:
"""Defines how to start recording an action."""

@abstractmethod
def stop(self, action_name):
def stop(self, action_name: str) -> None:
"""Defines how to record the duration once an action is complete."""

@contextmanager
def profile(self, action_name):
def profile(self, action_name: str) -> None:
"""
Yields a context manager to encapsulate the scope of a profiled action.
Expand All @@ -43,7 +43,7 @@ def profile(self, action_name):
finally:
self.stop(action_name)

def profile_iterable(self, iterable, action_name):
def profile_iterable(self, iterable, action_name: str) -> None:
iterator = iter(iterable)
while True:
try:
Expand All @@ -55,7 +55,7 @@ def profile_iterable(self, iterable, action_name):
self.stop(action_name)
break

def describe(self):
def describe(self) -> None:
"""Logs a profile report after the conclusion of the training run."""
pass

Expand All @@ -69,10 +69,10 @@ class PassThroughProfiler(BaseProfiler):
def __init__(self):
pass

def start(self, action_name):
def start(self, action_name: str) -> None:
pass

def stop(self, action_name):
def stop(self, action_name: str) -> None:
pass


Expand All @@ -86,14 +86,14 @@ def __init__(self):
self.current_actions = {}
self.recorded_durations = defaultdict(list)

def start(self, action_name):
def start(self, action_name: str) -> None:
if action_name in self.current_actions:
raise ValueError(
f"Attempted to start {action_name} which has already started."
)
self.current_actions[action_name] = time.monotonic()

def stop(self, action_name):
def stop(self, action_name: str) -> None:
end_time = time.monotonic()
if action_name not in self.current_actions:
raise ValueError(
Expand All @@ -103,7 +103,7 @@ def stop(self, action_name):
duration = end_time - start_time
self.recorded_durations[action_name].append(duration)

def describe(self):
def describe(self) -> None:
output_string = "\n\nProfiler Report\n"

def log_row(action, mean, total):
Expand All @@ -126,32 +126,33 @@ class AdvancedProfiler(BaseProfiler):
verbose and you should only use this if you want very detailed reports.
"""

def __init__(self, output_filename=None, line_count_restriction=1.0):
def __init__(self, output_filename: str = None, line_count_restriction: float = 1.0):
"""
:param output_filename (str): optionally save profile results to file instead of printing
to std out when training is finished.
:param line_count_restriction (int|float): this can be used to limit the number of functions
reported for each action. either an integer (to select a count of lines),
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
Args:
output_filename: optionally save profile results to file instead of printing
to std out when training is finished.
line_count_restriction: this can be used to limit the number of functions
reported for each action. either an integer (to select a count of lines),
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
"""
self.profiled_actions = {}
self.output_filename = output_filename
self.line_count_restriction = line_count_restriction

def start(self, action_name):
def start(self, action_name: str) -> None:
if action_name not in self.profiled_actions:
self.profiled_actions[action_name] = cProfile.Profile()
self.profiled_actions[action_name].enable()

def stop(self, action_name):
def stop(self, action_name: str) -> None:
pr = self.profiled_actions.get(action_name)
if pr is None:
raise ValueError( # pragma: no-cover
f"Attempting to stop recording an action ({action_name}) which was never started."
)
pr.disable()

def describe(self):
def describe(self) -> None:
self.recorded_stats = {}
for action_name, pr in self.profiled_actions.items():
s = io.StringIO()
Expand Down
7 changes: 3 additions & 4 deletions pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,9 @@ def set_distributed_mode(self, distributed_backend, num_gpu_nodes):
self.use_ddp2 = distributed_backend == 'ddp2'

elif distributed_backend is None:
m = 'You requested multiple GPUs but did not specify a backend' \
'Trainer(distributed_backend=dp) (or ddp, ddp2)' \
'Setting distributed_backend=dp for you'
warnings.warn(m)
warnings.warn('You requested multiple GPUs but did not specify a backend, e.g.'
' Trainer(distributed_backend=dp) (or ddp, ddp2).'
' Setting distributed_backend=dp for you.')
self.use_dp = True
self.use_ddp = False
self.use_ddp2 = False
Expand Down
24 changes: 10 additions & 14 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,8 @@ def tpu_train(self, tpu_core_idx, model):
if self.precision == 16:
os.environ['XLA_USE_BF16'] = str(1)

m = f'INIT TPU local core: {self.tpu_local_core_rank}, ' \
f'global rank: {self.tpu_global_core_rank}'
log.info(m)
log.info(f'INIT TPU local core: {self.tpu_local_core_rank},'
f' global rank: {self.tpu_global_core_rank}')

# continue training routine
self.run_pretrain_routine(model)
Expand All @@ -512,12 +511,10 @@ def dp_train(self, model):
# https://github.com/NVIDIA/apex/issues/227
if self.use_dp and self.use_amp:
if self.amp_level == 'O2':
m = f"""
Amp level {self.amp_level} with DataParallel is not supported.
See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.
We recommend you switch to ddp if you want to use amp
"""
raise MisconfigurationException(m)
raise MisconfigurationException(
f'Amp level {self.amp_level} with DataParallel is not supported.'
f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
f' We recommend you switch to ddp if you want to use amp')
else:
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)

Expand Down Expand Up @@ -584,11 +581,10 @@ def sanitize_gpu_ids(gpus):
all_available_gpus = get_all_available_gpus()
for gpu in gpus:
if gpu not in all_available_gpus:
message = f"""
You requested GPUs: {gpus}
But your machine only has: {all_available_gpus}
"""
raise MisconfigurationException(message)
raise MisconfigurationException(f"""
You requested GPUs: {gpus}
But your machine only has: {all_available_gpus}
""")
return gpus


Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,9 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
def run_evaluation(self, test_mode: bool = False):
# when testing make sure user defined a test step
if test_mode and not self.is_overriden('test_step'):
m = "You called `.test()` without defining model's `.test_step()`." \
" Please define and try again"
raise MisconfigurationException(m)
raise MisconfigurationException(
"You called `.test()` without defining model's `.test_step()`."
" Please define and try again")

# Validation/Test begin callbacks
if test_mode:
Expand Down
22 changes: 9 additions & 13 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,8 @@ def __init__(
if self.fast_dev_run:
self.num_sanity_val_steps = 1
self.max_epochs = 1
m = '''
Running in fast_dev_run mode: will run a full train,
val loop using a single batch
'''
log.info(m)
log.info('Running in fast_dev_run mode: will run a full train,'
' val loop using a single batch')

# set default save path if user didn't provide one
self.default_save_path = default_save_path
Expand Down Expand Up @@ -740,22 +737,22 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da
# functions to overwrite with these implementations
if train_dataloader is not None:
if not self.is_overriden('training_step', model):
m = 'You called .fit() with a train_dataloader but did not define training_step()'
raise MisconfigurationException(m)
raise MisconfigurationException(
'You called `.fit()` with a `train_dataloader` but did not define `training_step()`')

model.train_dataloader = _PatchDataLoader(train_dataloader)

if val_dataloaders is not None:
if not self.is_overriden('validation_step', model):
m = 'You called .fit() with a val_dataloaders but did not define validation_step()'
raise MisconfigurationException(m)
raise MisconfigurationException(
'You called `.fit()` with a `val_dataloaders` but did not define `validation_step()`')

model.val_dataloader = _PatchDataLoader(val_dataloaders)

if test_dataloaders is not None:
if not self.is_overriden('test_step', model):
m = 'You called .fit() with a test_dataloaders but did not define test_step()'
raise MisconfigurationException(m)
raise MisconfigurationException(
'You called `.fit()` with a `test_dataloaders` but did not define `test_step()`')

model.test_dataloader = _PatchDataLoader(test_dataloaders)

Expand Down Expand Up @@ -856,8 +853,7 @@ def run_pretrain_routine(self, model: LightningModule):
if self.weights_summary in ['full', 'top']:
ref_model.summarize(mode=self.weights_summary)
else:
m = "weights_summary can be None, 'full' or 'top'"
raise MisconfigurationException(m)
raise MisconfigurationException("weights_summary can be None, 'full' or 'top'")

# track model now.
# if cluster resets state, the model will update with the saved weights
Expand Down
Loading

0 comments on commit 4a92684

Please sign in to comment.