From d9cac4ce207474b09a85be0b5964d44951159ebd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 23 Jul 2020 23:19:51 +0200
Subject: [PATCH 01/21] fix weights_save path and drop ckpt_path

---
 .../callbacks/model_checkpoint.py            | 10 +++----
 pytorch_lightning/trainer/callback_config.py |  4 ---
 pytorch_lightning/trainer/deprecated_api.py  | 15 ++++++++++
 pytorch_lightning/trainer/trainer.py         | 28 +++++++++++++++----
 tests/callbacks/test_callbacks.py            |  2 +-
 5 files changed, 44 insertions(+), 15 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index b6a92efc53321..9d12e510322e6 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -242,10 +242,11 @@ def on_train_start(self, trainer, pl_module):
             self.filename = '{epoch}'
 
         if trainer.logger is not None:
-            # weights_save_path overrides anything
-            save_dir = (getattr(trainer, 'weights_save_path', None)
-                        or getattr(trainer.logger, 'save_dir', None)
-                        or trainer.default_root_dir)
+            if trainer.weights_save_path != trainer.default_root_dir:
+                # the user has changed weights_save_path, it overrides anything
+                save_dir = trainer.weights_save_path
+            else:
+                save_dir = trainer.logger.save_dir or trainer.default_root_dir
 
             version = trainer.logger.version if isinstance(
                 trainer.logger.version, str) else f'version_{trainer.logger.version}'
@@ -263,7 +264,6 @@ def on_train_start(self, trainer, pl_module):
 
         assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
         os.makedirs(self.dirpath, exist_ok=True)
-        trainer.ckpt_path = ckpt_path
         trainer.weights_save_path = ckpt_path
 
     @rank_zero_only
diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py
index b65dc37ef8b1b..47d02fe87f265 100644
--- a/pytorch_lightning/trainer/callback_config.py
+++ b/pytorch_lightning/trainer/callback_config.py
@@ -53,10 +53,6 @@ def configure_checkpoint_callback(self, checkpoint_callback):
         if checkpoint_callback:
             checkpoint_callback.save_function = self.save_checkpoint
 
-        # if weights_save_path is still none here, set to current working dir
-        if self.weights_save_path is None:
-            self.weights_save_path = self.default_root_dir
-
         return checkpoint_callback
 
     def configure_early_stopping(self, early_stop_callback):
diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py
index 6cb160f1d26b2..86ad6e645ce31 100644
--- a/pytorch_lightning/trainer/deprecated_api.py
+++ b/pytorch_lightning/trainer/deprecated_api.py
@@ -45,6 +45,7 @@ class TrainerDeprecatedAPITillVer0_10(ABC):
     limit_test_batches: Union[int, float]
     limit_train_batches: Union[int, float]
     overfit_batches: Union[int, float]
+    weights_save_path: str
 
     def __init__(self):
         super().__init__()  # mixin calls super too
@@ -118,3 +119,17 @@ def proc_rank(self, rank):
         rank_zero_warn("Attribute `proc_rank` is now set by `global_rank` since v0.8.0"
                        " and this method will be removed in v0.10.0", DeprecationWarning)
         self.global_rank = rank
+
+    @property
+    def ckpt_path(self) -> str:
+        """Back compatibility, will be removed in v0.10.0"""
+        rank_zero_warn("Attribute `ckpt_path` is now set by `weights_save_path` since v0.9.0"
+                       " and this method will be removed in v0.10.0", DeprecationWarning)
+        return self.weights_save_path
+
+    @ckpt_path.setter
+    def ckpt_path(self, path: str):
+        """Back compatibility, will be removed in v0.10.0"""
+        rank_zero_warn("Attribute `ckpt_path` is now set by `weights_save_path` since v0.9.0"
+                       " and this method will be removed in v0.10.0", DeprecationWarning)
+        self.weights_save_path = path
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index bee332f831cce..7d7aaeea26e19 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -419,10 +419,7 @@ def __init__(
         self.should_stop = False
         self.running_sanity_check = False
 
-        # set default save path if user didn't provide one
-        if default_root_dir is None:
-            default_root_dir = os.getcwd()
-        self.default_root_dir = default_root_dir
+        self._default_root_dir = default_root_dir
 
         # init callbacks
         self.callbacks = callbacks or []
@@ -436,7 +433,7 @@ def __init__(
         # configure checkpoint callback
         # it is important that this is the last callback to run
         # pass through the required args to figure out defaults
-        self.weights_save_path = weights_save_path
+        self._weights_save_path = weights_save_path
         checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
         if checkpoint_callback:
             self.callbacks.append(checkpoint_callback)
@@ -894,6 +891,27 @@ def enable_validation(self) -> bool:
         val_loop_enabled = (self.is_overridden('validation_step') and self.limit_val_batches > 0)
         return val_loop_enabled or self.fast_dev_run
 
+    @property
+    def default_root_dir(self) -> str:
+        """ set default save path if user didn't provide one """
+        path = self._default_root_dir or os.getcwd()
+        path = os.path.normpath(path)
+        return path
+
+    @default_root_dir.setter
+    def default_root_dir(self, path: str):
+        self._default_root_dir = path
+
+    @property
+    def weights_save_path(self) -> str:
+        path = self._weights_save_path or self.default_root_dir
+        path = os.path.normpath(path)
+        return path
+
+    @weights_save_path.setter
+    def weights_save_path(self, path: str):
+        self._weights_save_path = path
+
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index b1034ef7d7f28..469b829c26daa 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -321,7 +321,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
     )
     trainer.fit(model)
 
-    # These should be different if the dirpath has be overridden
+    # These should be different if the dirpath has been overridden
     assert trainer.ckpt_path != trainer.default_root_dir

From 924a997fd319bde7b5889c9d672f2e125195a231 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 23 Jul 2020 23:50:04 +0200
Subject: [PATCH 02/21] add tests

---
 tests/loggers/test_comet.py       | 3 +++
 tests/loggers/test_mlflow.py      | 3 +++
 tests/loggers/test_tensorboard.py | 3 +++
 tests/loggers/test_wandb.py       | 4 ++++
 4 files changed, 13 insertions(+)

diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py
index a89840163fe7a..9e491015ef580 100644
--- a/tests/loggers/test_comet.py
+++ b/tests/loggers/test_comet.py
@@ -104,4 +104,7 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
     trainer.fit(model)
 
     assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / 'test' / version / 'checkpoints')
+    # save_dir must be a subpath of weights_save_path
+    assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) ==
+            os.path.join('test', version, 'checkpoints'))
     assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py
index ec9bc8db332a4..ca0dcda714e9e 100644
--- a/tests/loggers/test_mlflow.py
+++ b/tests/loggers/test_mlflow.py
@@ -38,4 +38,7 @@ def test_mlflow_logger_dirs_creation(tmpdir):
     assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
     assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
     assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / exp_id / run_id / 'checkpoints')
+    # save_dir must be a subpath of weights_save_path
+    assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) ==
+            os.path.join(exp_id, run_id, 'checkpoints'))
     assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py
index 44009a2ddf658..da9fd20af6eda 100644
--- a/tests/loggers/test_tensorboard.py
+++ b/tests/loggers/test_tensorboard.py
@@ -32,6 +32,9 @@ def test_tensorboard_hparams_reload(tmpdir):
 
     # verify artifacts
     assert len(os.listdir(os.path.join(folder_path, 'checkpoints'))) == 1
+    # save_dir must be a subpath of weights_save_path
+    assert (os.path.relpath(trainer.weights_save_path, trainer.logger.save_dir) ==
+            os.path.join('lightning_logs', 'version_0', 'checkpoints'))
 
     # # # verify tb logs
     # event_acc = EventAccumulator(folder_path)
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 57b0aff311264..5a942c43087e3 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -1,5 +1,6 @@
 import os
 import pickle
+from pathlib import Path, PurePath
 from unittest import mock
 
 from pytorch_lightning import Trainer
@@ -95,4 +96,7 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
     trainer.fit(model)
 
     assert trainer.ckpt_path == trainer.weights_save_path == str(tmpdir / 'project' / version / 'checkpoints')
+    # save_dir must be a subpath of weights_save_path
+    assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) ==
+            os.path.join('project', version, 'checkpoints'))
     assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}

From e0d2c5ac3138999eefb17c6a3ee46b29cd708bc8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 23 Jul 2020 23:52:15 +0200
Subject: [PATCH 03/21] unused import

---
 tests/loggers/test_wandb.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 5a942c43087e3..13c6cf57d3152 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -1,6 +1,5 @@
 import os
 import pickle
-from pathlib import Path, PurePath
 from unittest import mock
 
 from pytorch_lightning import Trainer

From 7583661213b5580bc15fa86a91fbeca6229cf068 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 24 Jul 2020 00:45:28 +0200
Subject: [PATCH 04/21] update docs

---
 .../callbacks/model_checkpoint.py            | 19 +++++++++++++++----
 pytorch_lightning/trainer/__init__.py        |  4 ++--
 pytorch_lightning/trainer/callback_config.py |  6 ------
 pytorch_lightning/trainer/trainer.py         | 11 ++++++++++-
 4 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 9d12e510322e6..710912ab66504 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -41,8 +41,10 @@ class ModelCheckpoint(Callback):
                 ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
                 ... )
 
-            Can also be set to `None`, then it will be set to default location during trainer construction.
+            Can also be set to `None`, then it will be set to the location
+            specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s
+            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` or
+            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments.
 
         monitor: quantity to monitor.
         verbose: verbosity mode. Default: ``False``.
@@ -233,8 +235,17 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
     @rank_zero_only
     def on_train_start(self, trainer, pl_module):
         """
-        Determine model checkpoint save directory at runtime. References attributes from the
-        Trainer's logger to determine where to save checkpoints.
+        Determines model checkpoint save directory at runtime. References attributes from the
+        trainer's logger to determine where to save checkpoints.
+        The base path for saving weights is set in this priority:
+
+        1. The checkpoint callback's ``filepath`` (if passed in)
+        2. The trainer's ``weights_save_path``, if the user provided it
+        3. The logger's ``save_dir``, if the trainer has a logger
+        4. The trainer's ``default_root_dir``
+
+        The base path gets extended with logger name and version (if these are available)
+        and subfolder "checkpoints".
         """
         if self.dirpath is not None:
             return  # short circuit
diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py
index b8f02a36c6806..7e188ab97492c 100644
--- a/pytorch_lightning/trainer/__init__.py
+++ b/pytorch_lightning/trainer/__init__.py
@@ -291,12 +291,12 @@ def on_train_end(self, trainer, pl_module):
 )
 
 default_root_dir
-^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^
 
 Default path for logs and weights when no logger or
 :class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed.
 On certain clusters you might want to separate where logs and checkpoints
-are stored. If you don't then use this method for convenience.
+are stored. If you don't, then use this argument for convenience.
 
 Example::

diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py
index 47d02fe87f265..8600449d86a94 100644
--- a/pytorch_lightning/trainer/callback_config.py
+++ b/pytorch_lightning/trainer/callback_config.py
@@ -33,12 +33,6 @@ def is_overridden(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
     def configure_checkpoint_callback(self, checkpoint_callback):
-        """
-        Weight path set in this priority:
-        Checkpoint_callback's path (if passed in).
-        User provided weights_saved_path
-        Otherwise use os.getcwd()
-        """
         if checkpoint_callback is True:
             # when no val step is defined, use 'loss' otherwise 'val_loss'
             train_step_only = not self.is_overridden('validation_step')
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 7d7aaeea26e19..31eceba20586d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -893,7 +893,11 @@ def enable_validation(self) -> bool:
 
     @property
     def default_root_dir(self) -> str:
-        """ set default save path if user didn't provide one """
+        """
+        The default location to save artifacts of loggers, checkpoints etc.
+        It is used as a fallback if logger or checkpoint callback do not define specific save paths.
+        Defaults to ``os.getcwd()``.
+ """ path = self._default_root_dir or os.getcwd() path = os.path.normpath(path) return path @@ -904,6 +908,11 @@ def default_root_dir(self, path: str): @property def weights_save_path(self) -> str: + """ + The default location to save weights (checkpoints), e.g., when the + :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a save location. + It defaults to :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir`. + """ path = self._weights_save_path or self.default_root_dir path = os.path.normpath(path) return path From 9f60af3ef5b66a5acc7611dc7daf67a6c76ac054 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 00:57:27 +0200 Subject: [PATCH 05/21] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ed0abbcfa801..ccbb69d23570f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated Trainer attribute `ckpt_path`, which will now be set by `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681)) ### Removed @@ -29,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657)) +- Fixed `save_dir` in loggers getting ignored by default value of `weights_save_path` when user did not specify `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681)) + ## [0.8.5] - 2020-07-09 ### Added From f91b0f76effe86ee798422db39863ca89880a574 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 00:59:46 +0200 Subject: [PATCH 06/21] pep8 --- tests/loggers/test_comet.py | 4 ++-- tests/loggers/test_mlflow.py | 4 ++-- tests/loggers/test_tensorboard.py | 4 ++-- tests/loggers/test_wandb.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 9e491015ef580..e9f8aba130074 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -105,6 +105,6 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch): assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / 'test' / version / 'checkpoints') # save_dir must be a subpath of weights_save_path - assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) == - os.path.join('test', version, 'checkpoints')) + assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) + == os.path.join('test', version, 'checkpoints')) assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'} diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index ca0dcda714e9e..2e704cbf4d164 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -39,6 +39,6 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys() assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / exp_id / run_id / 'checkpoints') # save_dir must be a subpath of weights_save_path - assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) == - os.path.join(exp_id, run_id, 'checkpoints')) + assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) + == os.path.join(exp_id, run_id, 'checkpoints')) assert 
set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'} diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index da9fd20af6eda..4237eb8442cda 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -33,8 +33,8 @@ def test_tensorboard_hparams_reload(tmpdir): # verify artifacts assert len(os.listdir(os.path.join(folder_path, 'checkpoints'))) == 1 # save_dir must be a subpath of weights_save_path - assert (os.path.relpath(trainer.weights_save_path, trainer.logger.save_dir) == - os.path.join('lightning_logs', 'version_0', 'checkpoints')) + assert (os.path.relpath(trainer.weights_save_path, trainer.logger.save_dir) + == os.path.join('lightning_logs', 'version_0', 'checkpoints')) # # # verify tb logs # event_acc = EventAccumulator(folder_path) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 13c6cf57d3152..62f2d62665633 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -96,6 +96,6 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): assert trainer.ckpt_path == trainer.weights_save_path == str(tmpdir / 'project' / version / 'checkpoints') # save_dir must be a subpath of weights_save_path - assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) == - os.path.join('project', version, 'checkpoints')) + assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) + == os.path.join('project', version, 'checkpoints')) assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'} From 280745813639e1405ab877515a11707244a4610f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 01:03:54 +0200 Subject: [PATCH 07/21] fix horovod test --- tests/models/data/horovod/train_default_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index a71d7b576ca1f..7138021e8e7e9 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -63,7 +63,6 @@ def run_test_from_config(trainer_options): if trainer.global_rank > 0: # on higher ranks the checkpoint location is unknown # we want to test checkpointing on rank 0 only - assert not hasattr(trainer, 'ckpt_path') assert not trainer.checkpoint_callback.best_model_path return From e72779475f34c81fb620dd2e28fa431bd8c12de8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 01:22:11 +0200 Subject: [PATCH 08/21] make backward compatible --- pytorch_lightning/trainer/deprecated_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 86ad6e645ce31..88bbbb4bf2aae 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -46,6 +46,7 @@ class TrainerDeprecatedAPITillVer0_10(ABC): limit_train_batches: Union[int, float] overfit_batches: Union[int, float] weights_save_path: str + is_global_zero: bool def __init__(self): super().__init__() # mixin calls super too @@ -125,7 +126,7 @@ def ckpt_path(self) -> str: """Back compatibility, will be removed in v0.10.0""" rank_zero_warn("Attribute `ckpt_path` is now set by `weights_save_path` since v0.9.0" " and this method will be removed in v0.10.0", DeprecationWarning) - return self.weights_save_path + return self.weights_save_path if self.is_global_zero else None @ckpt_path.setter def ckpt_path(self, path: 
str): From c8dfffe13489d7ea482736ed0ae50ca017703351 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 05:38:31 +0200 Subject: [PATCH 09/21] perform same test for all loggers --- tests/loggers/test_all.py | 54 +++++++++++++++++++++++++++++++ tests/loggers/test_comet.py | 3 -- tests/loggers/test_mlflow.py | 3 -- tests/loggers/test_tensorboard.py | 3 -- tests/loggers/test_wandb.py | 3 -- 5 files changed, 54 insertions(+), 12 deletions(-) diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index b64119078c6dd..38bb4753f7398 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -1,5 +1,6 @@ import atexit import inspect +import os import pickle import platform from unittest import mock @@ -82,6 +83,59 @@ def log_metrics(self, metrics, step): (1, ['epoch', 'test_acc', 'test_loss'])] +@pytest.mark.parametrize("logger_class", [ + TensorBoardLogger, + CometLogger, + MLFlowLogger, + TestTubeLogger, + WandbLogger, +]) +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_loggers_save_dir_and_weights_save_path(wandb, tmpdir, monkeypatch, logger_class): + """ Test the combinations of save_dir, weights_save_path and default_root_dir. """ + if logger_class == CometLogger: + # prevent comet logger from trying to print at exit, since + # pytest's stdout/stderr redirection breaks it + monkeypatch.setattr(atexit, 'register', lambda _: None) + + class TestLogger(logger_class): + # for this test it does not matter what these attributes are + # so we standardize them to make testing easier + @property + def version(self): + return 'version' + + @property + def name(self): + return 'name' + + model = EvalModelTemplate() + trainer_args = dict( + default_root_dir=tmpdir, + max_steps=1, + ) + + # no weights_save_path given + save_dir = tmpdir / 'logs' + weights_save_path = None + logger = TestLogger(**_get_logger_args(TestLogger, save_dir)) + trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path) + trainer.fit(model) + expected_weights_path = os.path.join(logger.save_dir, 'name', 'version', 'checkpoints') + assert trainer.weights_save_path == expected_weights_path + assert trainer.default_root_dir == tmpdir + + # with weights_save_path given, the logger path should not relate with checkpoint dir + save_dir = tmpdir / 'logs' + weights_save_path = tmpdir / 'weights' + logger = TestLogger(**_get_logger_args(TestLogger, save_dir)) + trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path) + trainer.fit(model) + expected_weights_path = weights_save_path / 'name' / 'version' / 'checkpoints' + assert trainer.weights_save_path == expected_weights_path + assert trainer.default_root_dir == tmpdir + + @pytest.mark.parametrize("logger_class", [ TensorBoardLogger, CometLogger, diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index e9f8aba130074..a89840163fe7a 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -104,7 +104,4 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch): trainer.fit(model) assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / 'test' / version / 'checkpoints') - # save_dir must be a subpath of weights_save_path - assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) - == os.path.join('test', version, 'checkpoints')) assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'} diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 2e704cbf4d164..ec9bc8db332a4 
100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -38,7 +38,4 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics') assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys() assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / exp_id / run_id / 'checkpoints') - # save_dir must be a subpath of weights_save_path - assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) - == os.path.join(exp_id, run_id, 'checkpoints')) assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'} diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 4237eb8442cda..44009a2ddf658 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -32,9 +32,6 @@ def test_tensorboard_hparams_reload(tmpdir): # verify artifacts assert len(os.listdir(os.path.join(folder_path, 'checkpoints'))) == 1 - # save_dir must be a subpath of weights_save_path - assert (os.path.relpath(trainer.weights_save_path, trainer.logger.save_dir) - == os.path.join('lightning_logs', 'version_0', 'checkpoints')) # # # verify tb logs # event_acc = EventAccumulator(folder_path) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 62f2d62665633..57b0aff311264 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -95,7 +95,4 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): trainer.fit(model) assert trainer.ckpt_path == trainer.weights_save_path == str(tmpdir / 'project' / version / 'checkpoints') - # save_dir must be a subpath of weights_save_path - assert (os.path.relpath(trainer.weights_save_path, logger.save_dir) - == os.path.join('project', version, 'checkpoints')) assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'} From a18e7f2cc7afa6ca16803ec2d42915e1a455dcd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 05:58:01 +0200 Subject: [PATCH 10/21] fix for when logger=False and weights_save_path is set --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- tests/loggers/test_all.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 710912ab66504..9a814ccb188e5 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -268,7 +268,7 @@ def on_train_start(self, trainer, pl_module): "checkpoints" ) else: - ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints") + ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") self.dirpath = ckpt_path diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 38bb4753f7398..1683405656919 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -125,7 +125,7 @@ def name(self): assert trainer.weights_save_path == expected_weights_path assert trainer.default_root_dir == tmpdir - # with weights_save_path given, the logger path should not relate with checkpoint dir + # with weights_save_path given, the logger path and checkpoint path should be different save_dir = tmpdir / 'logs' weights_save_path = tmpdir / 'weights' logger = TestLogger(**_get_logger_args(TestLogger, save_dir)) @@ -135,6 +135,12 @@ def name(self): assert trainer.weights_save_path == expected_weights_path assert trainer.default_root_dir == tmpdir + # no logger given + trainer = Trainer(**trainer_args, 
logger=False, weights_save_path=(tmpdir / 'foo')) + trainer.fit(model) + assert trainer.weights_save_path == tmpdir / 'foo' / 'checkpoints' + assert trainer.default_root_dir == tmpdir + @pytest.mark.parametrize("logger_class", [ TensorBoardLogger, From 2112b5c52bfb508ebd8639317cbb0fe7ed25e53f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 06:01:23 +0200 Subject: [PATCH 11/21] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ccbb69d23570f..83beb5b680fb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `save_dir` in loggers getting ignored by default value of `weights_save_path` when user did not specify `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681)) +- Fixed `weights_save_path` getting ignored when `logger=False` is passed to Trainer ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681)) + ## [0.8.5] - 2020-07-09 ### Added From 3ebd95204312d9925865dd13a0feecf3eeaa9fa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 06:13:24 +0200 Subject: [PATCH 12/21] update docs --- pytorch_lightning/callbacks/model_checkpoint.py | 5 +++-- pytorch_lightning/trainer/trainer.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9a814ccb188e5..d3029b48dd235 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -41,10 +41,11 @@ class ModelCheckpoint(Callback): ... filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}' ... ) - Can also be set to `None`, then it will be set to the location + By default, filepath is `None` and will be set at runtime to the location specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` or - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments. + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments, + and if the Trainer uses a logger, the path will also contain logger name and version. monitor: quantity to monitor. verbose: verbosity mode. Default: ``False``. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 31eceba20586d..4ba34e2af7e35 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -909,8 +909,8 @@ def default_root_dir(self, path: str): @property def weights_save_path(self) -> str: """ - The default location to save weights (checkpoints), e.g., when the - :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a save location. + The default root location to save weights (checkpoints), e.g., when the + :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path. It defaults to :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir`. 
""" path = self._weights_save_path or self.default_root_dir From 047ed63a22a78e6374fa5ee1e5bff12aee46fac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 06:25:02 +0200 Subject: [PATCH 13/21] update tests --- tests/callbacks/test_callbacks.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 469b829c26daa..b2eef1fc578e7 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -27,9 +27,7 @@ def validation_epoch_end(self, outputs): overfit_batches=0.20, max_epochs=20, ) - result = trainer.fit(model) - print(trainer.current_epoch) - + trainer.fit(model) assert trainer.current_epoch == 5, 'early_stopping failed' From 5dbee1eec2d3c4bf59b9e5efd654c363b2059377 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 06:52:11 +0200 Subject: [PATCH 14/21] do not set save dir dynamically --- pytorch_lightning/callbacks/model_checkpoint.py | 2 -- tests/loggers/test_all.py | 15 +++++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d3029b48dd235..bfade6f024ba8 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -274,9 +274,7 @@ def on_train_start(self, trainer, pl_module): self.dirpath = ckpt_path assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0' - os.makedirs(self.dirpath, exist_ok=True) - trainer.weights_save_path = ckpt_path @rank_zero_only def on_validation_end(self, trainer, pl_module): diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 1683405656919..e3d6202d05932 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -121,8 +121,8 @@ def name(self): logger = TestLogger(**_get_logger_args(TestLogger, save_dir)) trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path) trainer.fit(model) - expected_weights_path = os.path.join(logger.save_dir, 'name', 'version', 'checkpoints') - assert trainer.weights_save_path == expected_weights_path + assert trainer.weights_save_path == trainer.default_root_dir + assert trainer.checkpoint_callback.dirpath == os.path.join(logger.save_dir, 'name', 'version', 'checkpoints') assert trainer.default_root_dir == tmpdir # with weights_save_path given, the logger path and checkpoint path should be different @@ -131,14 +131,17 @@ def name(self): logger = TestLogger(**_get_logger_args(TestLogger, save_dir)) trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path) trainer.fit(model) - expected_weights_path = weights_save_path / 'name' / 'version' / 'checkpoints' - assert trainer.weights_save_path == expected_weights_path + assert trainer.weights_save_path == weights_save_path + assert trainer.logger.save_dir == save_dir + assert trainer.checkpoint_callback.dirpath == weights_save_path / 'name' / 'version' / 'checkpoints' assert trainer.default_root_dir == tmpdir # no logger given - trainer = Trainer(**trainer_args, logger=False, weights_save_path=(tmpdir / 'foo')) + weights_save_path = tmpdir / 'weights' + trainer = Trainer(**trainer_args, logger=False, weights_save_path=weights_save_path) trainer.fit(model) - assert trainer.weights_save_path == tmpdir / 'foo' / 'checkpoints' + assert trainer.weights_save_path == weights_save_path + assert 
trainer.checkpoint_callback.dirpath == weights_save_path / 'checkpoints' assert trainer.default_root_dir == tmpdir From cb5636beb1af8f7134f07c1d3bd6b0f72f329499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 06:54:22 +0200 Subject: [PATCH 15/21] remove duplicate test --- tests/callbacks/test_callbacks.py | 42 ------------------------ tests/callbacks/test_model_checkpoint.py | 13 ++++---- 2 files changed, 6 insertions(+), 49 deletions(-) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index b1034ef7d7f28..c2174737cef9b 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -303,45 +303,3 @@ def test_pickling(tmpdir): assert vars(early_stopping) == vars(early_stopping_loaded) assert vars(ckpt) == vars(ckpt_loaded) - - -@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2]) -def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): - """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ - tutils.reset_seed() - model = EvalModelTemplate() - - checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k) - - trainer = Trainer( - default_root_dir=tmpdir, - checkpoint_callback=checkpoint, - overfit_batches=0.20, - max_epochs=2, - ) - trainer.fit(model) - - # These should be different if the dirpath has be overridden - assert trainer.ckpt_path != trainer.default_root_dir - - -@pytest.mark.parametrize( - 'logger_version,expected', - [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')], -) -def test_model_checkpoint_path(tmpdir, logger_version, expected): - """Test that "version_" prefix is only added when logger's version is an integer""" - tutils.reset_seed() - model = EvalModelTemplate() - logger = TensorBoardLogger(str(tmpdir), version=logger_version) - - trainer = Trainer( - default_root_dir=tmpdir, - overfit_batches=0.2, - max_epochs=2, - logger=logger, - ) - trainer.fit(model) - - ckpt_version = Path(trainer.ckpt_path).parent.name - assert ckpt_version == expected diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 7257dc3874a2a..697121eabddf9 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -15,9 +15,7 @@ @pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2]) def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): - """ - Test that None in checkpoint callback is valid and that chkp_path is set correctly - """ + """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ tutils.reset_seed() model = EvalModelTemplate() @@ -26,8 +24,8 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint, - overfit_pct=0.20, - max_epochs=(save_top_k + 2), + overfit_batches=0.20, + max_epochs=2, ) trainer.fit(model) @@ -47,8 +45,8 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected): trainer = Trainer( default_root_dir=tmpdir, - overfit_pct=0.2, - max_epochs=5, + overfit_batches=0.2, + max_epochs=2, logger=logger, ) trainer.fit(model) @@ -57,6 +55,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected): assert ckpt_version == expected + def test_pickling(tmpdir): ckpt = ModelCheckpoint(tmpdir) From dd3ee408f81bf2d858bb4682d199b3bd7490705c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 07:00:32 +0200 Subject: [PATCH 16/21] remove 
duplicated tests --- tests/callbacks/test_callbacks.py | 71 ------------------------ tests/callbacks/test_early_stopping.py | 46 +++++++++++++++ tests/callbacks/test_model_checkpoint.py | 1 - 3 files changed, 46 insertions(+), 72 deletions(-) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index c2174737cef9b..d10965524394b 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -1,38 +1,8 @@ -from pathlib import Path - -import pytest -import torch - -import tests.base.develop_utils as tutils from pytorch_lightning import Callback from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger from tests.base import EvalModelTemplate -def test_early_stopping_functionality(tmpdir): - - class CurrentModel(EvalModelTemplate): - def validation_epoch_end(self, outputs): - losses = [8, 4, 2, 3, 4, 5, 8, 10] - val_loss = losses[self.current_epoch] - return {'val_loss': torch.tensor(val_loss)} - - model = CurrentModel() - - trainer = Trainer( - default_root_dir=tmpdir, - early_stop_callback=True, - overfit_batches=0.20, - max_epochs=20, - ) - result = trainer.fit(model) - print(trainer.current_epoch) - - assert trainer.current_epoch == 5, 'early_stopping failed' - - def test_trainer_callback_system(tmpdir): """Test the callback system.""" @@ -262,44 +232,3 @@ def on_test_end(self, trainer, pl_module): assert not test_callback.on_validation_end_called assert not test_callback.on_validation_batch_end_called assert not test_callback.on_validation_batch_start_called - - -def test_early_stopping_no_val_step(tmpdir): - """Test that early stopping callback falls back to training metrics when no validation defined.""" - - class CurrentModel(EvalModelTemplate): - def training_step(self, *args, **kwargs): - output = super().training_step(*args, **kwargs) - output.update({'my_train_metric': output['loss']}) # could be anything else - return output - - model = CurrentModel() - model.validation_step = None - model.val_dataloader = None - - stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1) - trainer = Trainer( - default_root_dir=tmpdir, - early_stop_callback=stopping, - overfit_batches=0.20, - max_epochs=2, - ) - result = trainer.fit(model) - - assert result == 1, 'training failed to complete' - assert trainer.current_epoch < trainer.max_epochs - - -def test_pickling(tmpdir): - import pickle - early_stopping = EarlyStopping() - ckpt = ModelCheckpoint(tmpdir) - - early_stopping_pickled = pickle.dumps(early_stopping) - ckpt_pickled = pickle.dumps(ckpt) - - early_stopping_loaded = pickle.loads(early_stopping_pickled) - ckpt_loaded = pickle.loads(ckpt_pickled) - - assert vars(early_stopping) == vars(early_stopping_loaded) - assert vars(ckpt) == vars(ckpt_loaded) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 2ba434af26dbb..17ca3bb2210f3 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -130,3 +130,49 @@ def test_pickling(tmpdir): early_stopping_pickled = cloudpickle.dumps(early_stopping) early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) + + +def test_early_stopping_no_val_step(tmpdir): + """Test that early stopping callback falls back to training metrics when no validation defined.""" + + class CurrentModel(EvalModelTemplate): + def 
training_step(self, *args, **kwargs): + output = super().training_step(*args, **kwargs) + output.update({'my_train_metric': output['loss']}) # could be anything else + return output + + model = CurrentModel() + model.validation_step = None + model.val_dataloader = None + + stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1) + trainer = Trainer( + default_root_dir=tmpdir, + early_stop_callback=stopping, + overfit_batches=0.20, + max_epochs=2, + ) + result = trainer.fit(model) + + assert result == 1, 'training failed to complete' + assert trainer.current_epoch < trainer.max_epochs + + +def test_early_stopping_functionality(tmpdir): + + class CurrentModel(EvalModelTemplate): + def validation_epoch_end(self, outputs): + losses = [8, 4, 2, 3, 4, 5, 8, 10] + val_loss = losses[self.current_epoch] + return {'val_loss': torch.tensor(val_loss)} + + model = CurrentModel() + + trainer = Trainer( + default_root_dir=tmpdir, + early_stop_callback=True, + overfit_batches=0.20, + max_epochs=20, + ) + trainer.fit(model) + assert trainer.current_epoch == 5, 'early_stopping failed' diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 697121eabddf9..bb575494c3148 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -55,7 +55,6 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected): assert ckpt_version == expected - def test_pickling(tmpdir): ckpt = ModelCheckpoint(tmpdir) From 7b87d5c83cf08ed455ca85afb1cfc5f773681326 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Jul 2020 08:52:21 +0200 Subject: [PATCH 17/21] update tests --- pytorch_lightning/trainer/trainer.py | 8 -------- tests/callbacks/test_model_checkpoint.py | 8 +++----- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4ba34e2af7e35..ea1ee5daf7cb6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -902,10 +902,6 @@ def default_root_dir(self) -> str: path = os.path.normpath(path) return path - @default_root_dir.setter - def default_root_dir(self, path: str): - self._default_root_dir = path - @property def weights_save_path(self) -> str: """ @@ -917,10 +913,6 @@ def weights_save_path(self) -> str: path = os.path.normpath(path) return path - @weights_save_path.setter - def weights_save_path(self, path: str): - self._weights_save_path = path - # ----------------------------- # MODEL TRAINING # ----------------------------- diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 7257dc3874a2a..3d05c68f462e9 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -16,7 +16,7 @@ @pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2]) def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): """ - Test that None in checkpoint callback is valid and that chkp_path is set correctly + Test that None in checkpoint callback is valid and that dirpath is set correctly """ tutils.reset_seed() model = EvalModelTemplate() @@ -30,9 +30,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): max_epochs=(save_top_k + 2), ) trainer.fit(model) - - # These should be different if the dirpath has be overridden - assert trainer.ckpt_path != trainer.default_root_dir + assert checkpoint.dirpath == tmpdir / trainer.logger.name / f'version_0' / 'checkpoints' 
 
 
 @pytest.mark.parametrize(
     'logger_version,expected',
     [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
 )
 def test_model_checkpoint_path(tmpdir, logger_version, expected):
     """Test that "version_" prefix is only added when logger's version is an integer"""
     tutils.reset_seed()
     model = EvalModelTemplate()
     logger = TensorBoardLogger(str(tmpdir), version=logger_version)
 
     trainer = Trainer(
         default_root_dir=tmpdir,
         overfit_batches=0.2,
         max_epochs=2,
         logger=logger,
     )
     trainer.fit(model)
 
-    ckpt_version = Path(trainer.ckpt_path).parent.name
+    ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name
     assert ckpt_version == expected

From 5c5bcc9d0bac3f358a376f11bf89be55cf1386cf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 24 Jul 2020 08:55:44 +0200
Subject: [PATCH 18/21] update tests

---
 tests/callbacks/test_callbacks.py        | 6 ++----
 tests/callbacks/test_model_checkpoint.py | 8 ++++----
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index b2eef1fc578e7..82d8960863590 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -318,9 +318,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
         max_epochs=2,
     )
     trainer.fit(model)
-
-    # These should be different if the dirpath has been overridden
-    assert trainer.ckpt_path != trainer.default_root_dir
+    assert checkpoint.dirpath == tmpdir / trainer.logger.name / f'version_0' / 'checkpoints'
 
 
 @pytest.mark.parametrize(
     'logger_version,expected',
     [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
 )
@@ -341,5 +339,5 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
     )
     trainer.fit(model)
 
-    ckpt_version = Path(trainer.ckpt_path).parent.name
+    ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name
     assert ckpt_version == expected
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index 3d05c68f462e9..86c8bb49d99a2 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -26,8 +26,8 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
     trainer = Trainer(
         default_root_dir=tmpdir,
         checkpoint_callback=checkpoint,
-        overfit_pct=0.20,
-        max_epochs=(save_top_k + 2),
+        overfit_batches=0.20,
+        max_epochs=2,
     )
     trainer.fit(model)
     assert checkpoint.dirpath == tmpdir / trainer.logger.name / f'version_0' / 'checkpoints'
@@ -45,8 +45,8 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
 
     trainer = Trainer(
         default_root_dir=tmpdir,
-        overfit_pct=0.2,
-        max_epochs=5,
+        overfit_batches=0.2,
+        max_epochs=2,
         logger=logger,
     )
     trainer.fit(model)

From 1f17b1a6ffbb02750b0cb9c6bf2cff87ce53c3fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 24 Jul 2020 09:04:45 +0200
Subject: [PATCH 19/21] remove remaining ckpt_path references

---
 tests/loggers/test_comet.py  | 4 ++--
 tests/loggers/test_mlflow.py | 4 ++--
 tests/loggers/test_wandb.py  | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py
index a89840163fe7a..a3ba883a65ae3 100644
--- a/tests/loggers/test_comet.py
+++ b/tests/loggers/test_comet.py
@@ -103,5 +103,5 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
     trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
     trainer.fit(model)
 
-    assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / 'test' / version / 'checkpoints')
-    assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
+    assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints')
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py
index ec9bc8db332a4..31b580f33f6d4 100644
--- a/tests/loggers/test_mlflow.py
+++ b/tests/loggers/test_mlflow.py
@@ -37,5 +37,5 @@ def test_mlflow_logger_dirs_creation(tmpdir):
     assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
     assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
     assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
-    assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / exp_id / run_id / 'checkpoints')
-    assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
+    assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints')
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 57b0aff311264..9907ad9d087a2 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -94,5 +94,5 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
     trainer.fit(model)
 
-    assert trainer.ckpt_path == trainer.weights_save_path == str(tmpdir / 'project' / version / 'checkpoints')
-    assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
+    assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}

From f6930ebad060b43a1b4aee8199adbc10ed576bf2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 25 Jul 2020 19:52:34 +0200
Subject: [PATCH 20/21] move defaults to init as suggested by @Borda

---
 pytorch_lightning/trainer/trainer.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index cdebbd2ac74c2..0ece0c26d93d9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -206,7 +206,8 @@ def __init__(
 
         callbacks: Add a list of callbacks.
 
-        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed
+        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
+            Default: ``os.getcwd()``.
 
         gradient_clip_val: 0 means don't clip.
 
@@ -333,6 +334,7 @@ def __init__(
         weights_save_path: Where to save weights if specified. Will override default_root_dir
             for checkpoints only. Use this if for whatever reason you need the checkpoints
            stored in a different place than the logs written in `default_root_dir`.
+            Defaults to `default_root_dir`.
 
         amp_level: The optimization level to use (O1, O2, etc...).
 
@@ -419,7 +421,8 @@ def __init__(
         self.should_stop = False
         self.running_sanity_check = False
 
-        self._default_root_dir = default_root_dir
+        self._default_root_dir = default_root_dir or os.getcwd()
+        self._weights_save_path = weights_save_path or self._default_root_dir
 
         # init callbacks
         self.callbacks = callbacks or []
@@ -436,7 +438,6 @@ def __init__(
         # configure checkpoint callback
         # it is important that this is the last callback to run
         # pass through the required args to figure out defaults
-        self._weights_save_path = weights_save_path
         checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
         if checkpoint_callback:
             self.callbacks.append(checkpoint_callback)
@@ -896,22 +898,16 @@ def enable_validation(self) -> bool:
 
     @property
     def default_root_dir(self) -> str:
         """
         The default location to save artifacts of loggers, checkpoints etc.
         It is used as a fallback if logger or checkpoint callback do not define specific save paths.
-        Defaults to ``os.getcwd()``.
""" - path = self._default_root_dir or os.getcwd() - path = os.path.normpath(path) - return path + return os.path.normpath(self._default_root_dir) @property def weights_save_path(self) -> str: """ The default root location to save weights (checkpoints), e.g., when the :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path. - It defaults to :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir`. """ - path = self._weights_save_path or self.default_root_dir - path = os.path.normpath(path) - return path + return os.path.normpath(self._weights_save_path) # ----------------------------- # MODEL TRAINING From 9c44f18521cad4e47419ebf971ca83e0a6b6681b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 25 Jul 2020 20:09:46 +0200 Subject: [PATCH 21/21] test deprecation --- pytorch_lightning/trainer/deprecated_api.py | 5 +++-- tests/test_deprecated.py | 4 ++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 88bbbb4bf2aae..f38ff6a486874 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -45,8 +45,9 @@ class TrainerDeprecatedAPITillVer0_10(ABC): limit_test_batches: Union[int, float] limit_train_batches: Union[int, float] overfit_batches: Union[int, float] - weights_save_path: str is_global_zero: bool + _weights_save_path: str + weights_save_path: str def __init__(self): super().__init__() # mixin calls super too @@ -133,4 +134,4 @@ def ckpt_path(self, path: str): """Back compatibility, will be removed in v0.10.0""" rank_zero_warn("Attribute `ckpt_path` is now set by `weights_save_path` since v0.9.0" " and this method will be removed in v0.10.0", DeprecationWarning) - self.weights_save_path = path + self._weights_save_path = path diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index e6eb86e42c1fc..9d4d69faa30bf 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -50,6 +50,10 @@ def test_tbd_remove_in_v0_10_0_trainer(): with pytest.deprecated_call(match='will be removed in v0.10.0'): assert trainer.proc_rank == trainer.global_rank + with pytest.deprecated_call(match='will be removed in v0.10.0'): + trainer.ckpt_path = 'foo' + assert trainer.ckpt_path == trainer.weights_save_path == 'foo' + def test_tbd_remove_in_v0_9_0_trainer(): # test show_progress_bar set by progress_bar_refresh_rate