Skip to content

Commit

Permalink
feat(wandb): let wandb cli handle runs (#4648)
Browse files Browse the repository at this point in the history
* feat(wandb): reinit handled by CLI

* fix: typo

* docs(wandb): improve formatting

* test(wandb): set wandb.run to None

* test(wandb): fix tests

* style: fix formatting

* docs(wandb): fix documentation

* Update code markup

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* docs(wandb): update CHANGELOG

* test(wandb): init called only when needed

* Update CHANGELOG.md

* try fix the test

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: edenlightning <66261195+edenlightning@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
  • Loading branch information
7 people authored Nov 23, 2020
1 parent 404af43 commit c586e5d
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))

- WandbLogger does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648))


- Renamed class metric `Fbeta` -> `FBeta` ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656))

Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class WandbLogger(LightningLoggerBase):
Example::
.. code::
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
wandb_logger = WandbLogger()
Expand Down Expand Up @@ -131,7 +133,7 @@ def experiment(self) -> Run:
os.environ['WANDB_MODE'] = 'dryrun'
self._experiment = wandb.init(
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
reinit=True, id=self._id, resume='allow', **self._kwargs)
id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run
# save checkpoints in wandb dir to upload on W&B servers
if self._log_model:
self._save_dir = self._experiment.dir
Expand Down
2 changes: 1 addition & 1 deletion tests/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,4 +369,4 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
logger = _instantiate_logger(WandbLogger, save_idr=tmpdir, prefix=prefix)
logger.log_metrics({"test": 1.0}, step=0)
wandb.init().log.assert_called_once_with({'tmp-test': 1.0}, step=0)
logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
15 changes: 12 additions & 3 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,27 @@


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_logger(wandb):
def test_wandb_logger_init(wandb):
"""Verify that basic functionality of wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here."""
logger = WandbLogger(anonymous=True, offline=True)

# test wandb.init called when there is no W&B run
wandb.run = None
logger = WandbLogger()
logger.log_metrics({'acc': 1.0})
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

# test wandb.init not called if there is a W&B run
wandb.init().log.reset_mock()
wandb.init.reset_mock()
wandb.run = wandb.init()
logger = WandbLogger()
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)

# continue training on same W&B run
# continue training on same W&B run and offset step
wandb.init().step = 3
logger.finalize('success')
logger.log_metrics({'acc': 1.0}, step=3)
Expand Down Expand Up @@ -67,6 +75,7 @@ class Experiment:
def project_name(self):
return 'the_project_name'

wandb.run = None
wandb.init.return_value = Experiment()
logger = WandbLogger(id='the_id', offline=True)

Expand Down

0 comments on commit c586e5d

Please sign in to comment.