diff --git a/CHANGELOG.md b/CHANGELOG.md index e604a92d2145c..494abb29d01fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903)) +- WandbLogger does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648)) + - Renamed class metric `Fbeta` -> `FBeta` ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656)) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 436f900c669c4..20ecb8fe40d19 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -59,6 +59,8 @@ class WandbLogger(LightningLoggerBase): Example:: + .. code:: + from pytorch_lightning.loggers import WandbLogger from pytorch_lightning import Trainer wandb_logger = WandbLogger() @@ -131,7 +133,7 @@ def experiment(self) -> Run: os.environ['WANDB_MODE'] = 'dryrun' self._experiment = wandb.init( name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous, - reinit=True, id=self._id, resume='allow', **self._kwargs) + id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run # save checkpoints in wandb dir to upload on W&B servers if self._log_model: self._save_dir = self._experiment.dir diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index d1e87332ccc71..386b8f1e23ea9 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -369,4 +369,4 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb: logger = _instantiate_logger(WandbLogger, save_idr=tmpdir, prefix=prefix) logger.log_metrics({"test": 1.0}, step=0) - wandb.init().log.assert_called_once_with({'tmp-test': 1.0}, step=0) + logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 468ca819f91b1..33211e6492d91 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -23,19 +23,27 @@ @mock.patch('pytorch_lightning.loggers.wandb.wandb') -def test_wandb_logger(wandb): +def test_wandb_logger_init(wandb): """Verify that basic functionality of wandb logger works. Wandb doesn't work well with pytest so we have to mock it out here.""" - logger = WandbLogger(anonymous=True, offline=True) + # test wandb.init called when there is no W&B run + wandb.run = None + logger = WandbLogger() logger.log_metrics({'acc': 1.0}) + wandb.init.assert_called_once() wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None) + # test wandb.init not called if there is a W&B run wandb.init().log.reset_mock() + wandb.init.reset_mock() + wandb.run = wandb.init() + logger = WandbLogger() logger.log_metrics({'acc': 1.0}, step=3) + wandb.init.assert_called_once() wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3) - # continue training on same W&B run + # continue training on same W&B run and offset step wandb.init().step = 3 logger.finalize('success') logger.log_metrics({'acc': 1.0}, step=3) @@ -67,6 +75,7 @@ class Experiment: def project_name(self): return 'the_project_name' + wandb.run = None wandb.init.return_value = Experiment() logger = WandbLogger(id='the_id', offline=True)