Fix log_dir property (#5537)
* fix and update tests

* update with ModelCheckpoint

* chlog

* wip wandb fix

* all fixed

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
4 people authored Feb 2, 2021
1 parent 69ea388 commit 793fe73
Showing 10 changed files with 129 additions and 94 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `num_classes` argument in F1 metric ([#5663](https://github.com/PyTorchLightning/pytorch-lightning/pull/5663))


- Fixed `log_dir` property ([#5537](https://github.com/PyTorchLightning/pytorch-lightning/pull/5537))


- Fixed a race condition in `ModelCheckpoint` when checking if a checkpoint file exists ([#5144](https://github.com/PyTorchLightning/pytorch-lightning/pull/5144))

- Remove unnecessary intermediate layers in Dockerfiles ([#5697](https://github.com/PyTorchLightning/pytorch-lightning/pull/5697))
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -194,7 +194,7 @@ def on_pretrain_routine_start(self, trainer, pl_module):
"""
When pretrain routine starts we build the ckpt dir on the fly
"""
self.__resolve_ckpt_dir(trainer, pl_module)
self.__resolve_ckpt_dir(trainer)
self.save_function = trainer.save_checkpoint

def on_validation_end(self, trainer, pl_module):
@@ -448,7 +448,7 @@ def format_checkpoint_name(
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

def __resolve_ckpt_dir(self, trainer, pl_module):
def __resolve_ckpt_dir(self, trainer):
"""
Determines model checkpoint save directory at runtime. References attributes from the
trainer's logger to determine where to save checkpoints.
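
Note: the `pl_module` argument dropped from `__resolve_ckpt_dir` was unused; the checkpoint directory is derived from the trainer alone. A rough sketch of that kind of resolution, simplified for illustration and not the library's exact code (the fallbacks shown are assumptions):

import os

def resolve_ckpt_dir(trainer):
    # Illustrative only: build a checkpoint directory from trainer state.
    if trainer.logger is not None:
        save_dir = trainer.logger.save_dir or trainer.default_root_dir
        version = trainer.logger.version
        version = version if isinstance(version, str) else f"version_{version}"
        return os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
    return os.path.join(trainer.default_root_dir, "checkpoints")
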
15 changes: 13 additions & 2 deletions pytorch_lightning/loggers/wandb.py
@@ -23,7 +23,8 @@
import torch.nn as nn

from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only, _module_available
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warning_utils import WarningCache

_WANDB_AVAILABLE = _module_available("wandb")
@@ -98,6 +99,14 @@ def __init__(
if wandb is None:
raise ImportError('You want to use `wandb` logger which is not installed yet,' # pragma: no-cover
' install it with `pip install wandb`.')

if offline and log_model:
raise MisconfigurationException(
f'Providing log_model={log_model} and offline={offline} is an invalid configuration'
' since model checkpoints cannot be uploaded in offline mode.\n'
'Hint: Set `offline=False` to log your model.'
)

super().__init__()
self._name = name
self._save_dir = save_dir
@@ -141,10 +150,12 @@ def experiment(self) -> Run:
self._experiment = wandb.init(
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run

# offset logging step when resuming a run
self._step_offset = self._experiment.step

# save checkpoints in wandb dir to upload on W&B servers
if self._log_model:
if self._save_dir is None:
self._save_dir = self._experiment.dir
return self._experiment

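
Note: two behavioural changes land in `WandbLogger` here. First, `offline=True` combined with `log_model=True` is now rejected at construction time, since checkpoints cannot be uploaded without network access. A minimal sketch of hitting the new check (the save_dir path is a placeholder):

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    WandbLogger(save_dir="./wandb_logs", offline=True, log_model=True)
except MisconfigurationException as err:
    print(err)  # hint: set offline=False to log your model

Second, when no `save_dir` is given, the logger now falls back to the wandb run directory (`experiment.dir`) once the experiment is created, so checkpoints end up somewhere wandb can upload from.
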
16 changes: 5 additions & 11 deletions pytorch_lightning/trainer/properties.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from argparse import ArgumentParser, Namespace
import inspect
import os
from abc import ABC
from argparse import ArgumentParser, Namespace
from typing import cast, List, Optional, Type, TypeVar, Union

from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -63,16 +63,10 @@ class TrainerProperties(ABC):

@property
def log_dir(self):
if self.checkpoint_callback is not None:
dirpath = self.checkpoint_callback.dirpath
dirpath = os.path.split(dirpath)[0]
elif self.logger is not None:
if isinstance(self.logger, TensorBoardLogger):
dirpath = self.logger.log_dir
else:
dirpath = self.logger.save_dir
if self.logger is None:
dirpath = self.default_root_dir
else:
dirpath = self._default_root_dir
dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')

if self.accelerator_backend is not None:
dirpath = self.accelerator_backend.broadcast(dirpath)
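
Note: `Trainer.log_dir` no longer inspects `ModelCheckpoint.dirpath`, which is not resolved until training starts; it now depends only on the attached logger (or `default_root_dir` when there is none), so it returns the same value before and after `fit`. A minimal illustration with placeholder paths:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

trainer = Trainer(logger=TensorBoardLogger(save_dir="tb_logs"))
print(trainer.log_dir)  # mirrors trainer.logger.log_dir, e.g. tb_logs/default/version_0

trainer = Trainer(logger=False, default_root_dir="runs")
print(trainer.log_dir)  # no logger attached: falls back to default_root_dir

For non-TensorBoard loggers the property reads `logger.save_dir` instead, and the resolved path is still broadcast across processes when an accelerator backend is active.
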
21 changes: 4 additions & 17 deletions tests/loggers/test_comet.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import patch, DEFAULT
from unittest.mock import DEFAULT, patch

import pytest

@@ -74,7 +74,7 @@ def test_comet_logger_online(comet):
@patch('pytorch_lightning.loggers.comet.comet_ml')
def test_comet_logger_no_api_key_given(comet):
""" Test that CometLogger fails to initialize if both api key and save_dir are missing. """
with pytest.raises(MisconfigurationException):
with pytest.raises(MisconfigurationException, match='requires either api_key or save_dir'):
comet.config.get_api_key.return_value = None
CometLogger(workspace='dummy-test', project_name='general')

@@ -89,13 +89,10 @@ def test_comet_logger_experiment_name(comet):
# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
logger = CometLogger(api_key=api_key, experiment_name=experiment_name,)

assert logger._experiment is None

_ = logger.experiment

comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)

comet_experiment().set_name.assert_called_once_with(experiment_name)


@@ -118,13 +115,10 @@ def save_os_environ(*args, **kwargs):
with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}):
with patch('pytorch_lightning.loggers.comet.CometExperiment', side_effect=save_os_environ) as comet_experiment:
logger = CometLogger(api_key=api_key)

assert logger.version == experiment_key

assert logger._experiment is None

_ = logger.experiment

comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)

assert instantation_environ["COMET_EXPERIMENT_KEY"] == experiment_key
@@ -156,10 +150,12 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch

model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
assert trainer.log_dir == logger.save_dir
trainer.fit(model)

assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
assert trainer.log_dir == logger.save_dir


@patch('pytorch_lightning.loggers.comet.comet_ml')
@@ -170,11 +166,8 @@ def test_comet_name_default(comet):

with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key)

assert logger._experiment is None

assert logger.name == "comet-default"

assert logger._experiment is None


@@ -187,11 +180,8 @@ def test_comet_name_project_name(comet):

with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, project_name=project_name)

assert logger._experiment is None

assert logger.name == project_name

assert logger._experiment is None


@@ -205,14 +195,11 @@ def test_comet_version_without_experiment(comet):

with patch('pytorch_lightning.loggers.comet.CometExperiment'):
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)

assert logger._experiment is None

first_version = logger.version
assert first_version is not None

assert logger.version == first_version

assert logger._experiment is None

_ = logger.experiment
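
Note: the logger tests now assert `trainer.log_dir` both before and after `trainer.fit`. The old property read `checkpoint_callback.dirpath`, which may still be unset before training starts, so its value depended on when it was queried; checking on both sides of the fit call pins the fixed behaviour, roughly:

trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
assert trainer.log_dir == logger.save_dir  # already correct before any checkpoint dir exists
trainer.fit(model)
assert trainer.log_dir == logger.save_dir  # and unchanged after training

The same pattern repeats in the MLflow, Neptune, TensorBoard, and W&B tests below (with `logger.log_dir` for TensorBoard and `None` for the offline Neptune run).
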
5 changes: 3 additions & 2 deletions tests/loggers/test_mlflow.py
@@ -13,11 +13,10 @@
# limitations under the License.
import importlib.util
import os

from unittest import mock
from unittest.mock import MagicMock
import pytest

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import _MLFLOW_AVAILABLE, MLFlowLogger
@@ -113,9 +112,11 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
limit_train_batches=1,
limit_val_batches=3,
)
assert trainer.log_dir == logger.save_dir
trainer.fit(model)
assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=0.ckpt'}
assert trainer.log_dir == logger.save_dir


def test_mlflow_logger_dirs_creation(tmpdir):
4 changes: 3 additions & 1 deletion tests/loggers/test_neptune.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch

import torch

@@ -114,7 +114,9 @@ def _run_training(logger):
limit_train_batches=0.05,
logger=logger,
)
assert trainer.log_dir is None
trainer.fit(model)
assert trainer.log_dir is None
return logger

logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True))
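
Note: the Neptune test expects `None` on both sides of `fit` because `NeptuneLogger` does not write to a local directory, so its `save_dir` is `None` and the rewritten property forwards that value instead of deriving a path from the checkpoint callback. A small sketch (assumes the neptune client is installed):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger

logger = NeptuneLogger(offline_mode=True)
assert logger.save_dir is None   # Neptune keeps experiment data server-side (or in memory offline)
trainer = Trainer(logger=logger)
assert trainer.log_dir is None   # log_dir forwards the logger's save_dir unchanged
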
4 changes: 3 additions & 1 deletion tests/loggers/test_tensorboard.py
@@ -35,9 +35,11 @@ def test_tensorboard_hparams_reload(tmpdir):
model = EvalModelTemplate()

trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
assert trainer.log_dir == trainer.logger.log_dir
trainer.fit(model)

folder_path = trainer.logger.log_dir
assert trainer.log_dir == trainer.logger.log_dir
folder_path = trainer.log_dir

# make sure yaml is there
with open(os.path.join(folder_path, "hparams.yaml")) as file:
22 changes: 18 additions & 4 deletions tests/loggers/test_wandb.py
@@ -13,13 +13,16 @@
# limitations under the License.
import os
import pickle
from unittest import mock
from argparse import ArgumentParser
import types
from argparse import ArgumentParser
from unittest import mock

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from tests.base import EvalModelTemplate, BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel, EvalModelTemplate


def get_warnings(recwarn):
@@ -94,6 +97,7 @@ class Experiment:
""" """
id = 'the_id'
step = 0
dir = 'wandb'

def project_name(self):
return 'the_project_name'
@@ -109,6 +113,7 @@ def project_name(self):
)
# Access the experiment to ensure it's created
assert trainer.logger.experiment, 'missing experiment'
assert trainer.log_dir == logger.save_dir
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)

@@ -147,11 +152,13 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):

version = logger.version
model = EvalModelTemplate()
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3, log_every_n_steps=1)
assert trainer.log_dir == logger.save_dir
trainer.fit(model)

assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
assert trainer.log_dir == logger.save_dir


def test_wandb_sanitize_callable_params(tmpdir):
Expand Down Expand Up @@ -182,3 +189,10 @@ def wrapper_something():
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == "<lambda>"


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_logger_offline_log_model(wandb, tmpdir):
""" Test that log_model=True raises an error in offline mode """
with pytest.raises(MisconfigurationException, match='checkpoints cannot be uploaded in offline mode'):
logger = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True)