Add warning when wandb.run already exists (Lightning-AI#8714)
Co-authored-by: thomas chaton <thomas@grid.ai>
2 people authored and four4fish committed Aug 16, 2021
1 parent c5255e5 commit 2123593
Showing 3 changed files with 17 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added DeepSpeed collate checkpoint utility function ([#8701](https://github.com/PyTorchLightning/pytorch-lightning/pull/8701))
 
 
+- Added a warning to `WandbLogger` when reusing a wandb run ([#8714](https://github.com/PyTorchLightning/pytorch-lightning/pull/8714))
+
+
 - Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662))
 
 
9 changes: 8 additions & 1 deletion pytorch_lightning/loggers/wandb.py
@@ -183,7 +183,14 @@ def experiment(self) -> Run:
         if self._experiment is None:
             if self._offline:
                 os.environ["WANDB_MODE"] = "dryrun"
-            self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run
+            if wandb.run is None:
+                self._experiment = wandb.init(**self._wandb_init)
+            else:
+                warning_cache.warn(
+                    "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
+                    " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
+                )
+                self._experiment = wandb.run
 
         # define default x-axis (for latest wandb versions)
         if getattr(self._experiment, "define_metric", None):
6 changes: 6 additions & 0 deletions tests/loggers/test_wandb.py
@@ -18,6 +18,7 @@
 
 import pytest
 
+import pytorch_lightning
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -51,8 +52,13 @@ def test_wandb_logger_init(wandb):
     wandb.init.reset_mock()
     wandb.run = wandb.init()
     logger = WandbLogger()
+
     # verify default resume value
     assert logger._wandb_init["resume"] == "allow"
+
+    _ = logger.experiment
+    assert any("There is a wandb run already in progress" in w for w in pytorch_lightning.loggers.wandb.warning_cache)
+
     logger.log_metrics({"acc": 1.0}, step=3)
     wandb.init.assert_called_once()
     wandb.init().log.assert_called_once_with({"acc": 1.0, "trainer/global_step": 3})
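The new test assertion scans the module-level `warning_cache` for the message. Outside Lightning's test suite, the same once-only behavior can be verified with the standard library alone; this sketch uses `warnings.catch_warnings(record=True)` and a hypothetical `warn_once` helper backed by a plain `set`.

```python
import warnings


def warn_once(cache: set, message: str) -> None:
    """Hypothetical helper mimicking WarningCache.warn: skip already-seen messages."""
    if message not in cache:
        cache.add(message)
        warnings.warn(message)


cache = set()
with warnings.catch_warnings(record=True) as captured:
    warnings.simplefilter("always")  # record every emitted warning, not just the first
    warn_once(cache, "There is a wandb run already in progress")
    warn_once(cache, "There is a wandb run already in progress")

# despite two calls, the cache ensures the warning is emitted exactly once
assert len(captured) == 1
assert "already in progress" in str(captured[0].message)
```

Recording warnings this way avoids depending on the interpreter's default "warn once per location" filter, which is what `simplefilter("always")` overrides.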
