diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e1d4b1d6c983..06656907ef442 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197)) +- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931)) + + ## [1.2.1] - 2021-02-23 ### Fixed diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 68c7e228cc14a..285388d6c6765 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -26,6 +26,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache +warning_cache = WarningCache() + _WANDB_AVAILABLE = _module_available("wandb") try: @@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase): project: The name of the project to which this run will belong. log_model: Save checkpoints in wandb dir to upload on W&B servers. prefix: A string to put at the beginning of metric keys. - sync_step: Sync Trainer step with wandb step. experiment: WandB experiment object. Automatically set when creating a run. \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by :func:`wandb.init` can be passed as keyword arguments in this logger. @@ -98,7 +99,7 @@ def __init__( log_model: Optional[bool] = False, experiment=None, prefix: Optional[str] = '', - sync_step: Optional[bool] = True, + sync_step: Optional[bool] = None, **kwargs ): if wandb is None: @@ -114,6 +115,12 @@ def __init__( 'Hint: Set `offline=False` to log your model.' ) + if sync_step is not None: + warning_cache.warn( + "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5." + " Metrics are now logged separately and automatically synchronized.", DeprecationWarning + ) + super().__init__() self._name = name self._save_dir = save_dir @@ -123,12 +130,8 @@ def __init__( self._project = project self._log_model = log_model self._prefix = prefix - self._sync_step = sync_step self._experiment = experiment self._kwargs = kwargs - # logging multiple Trainer on a single W&B run (k-fold, resuming, etc) - self._step_offset = 0 - self.warning_cache = WarningCache() def __getstate__(self): state = self.__dict__.copy() @@ -165,12 +168,15 @@ def experiment(self) -> Run: **self._kwargs ) if wandb.run is None else wandb.run - # offset logging step when resuming a run - self._step_offset = self._experiment.step - # save checkpoints in wandb dir to upload on W&B servers if self._save_dir is None: self._save_dir = self._experiment.dir + + # define default x-axis (for latest wandb versions) + if getattr(self._experiment, "define_metric", None): + self._experiment.define_metric("trainer/global_step") + self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True) + return self._experiment def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100): @@ -188,15 +194,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0' metrics = self._add_prefix(metrics) - if self._sync_step and step is not None and step + self._step_offset < self.experiment.step: - self.warning_cache.warn( - 'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`' - ' or try logging with `commit=False` when calling manually `wandb.log`.' - ) - if self._sync_step: - self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None) - elif step is not None: - self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)}) + if step is not None: + self.experiment.log({**metrics, 'trainer/global_step': step}) else: self.experiment.log(metrics) @@ -216,10 +215,6 @@ def version(self) -> Optional[str]: @rank_zero_only def finalize(self, status: str) -> None: - # offset future training logged on same W&B run - if self._experiment is not None: - self._step_offset = self._experiment.step - # upload all checkpoints from saving dir if self._log_model: wandb.save(os.path.join(self.save_dir, "*.ckpt")) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index e87fb5c2ebbb2..415f1d040ba70 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -13,13 +13,22 @@ # limitations under the License. """Test deprecated functionality which will be removed in v1.5.0""" +from unittest import mock + import pytest from pytorch_lightning import Trainer, Callback +from pytorch_lightning.loggers import WandbLogger from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_v1_5_0_wandb_unused_sync_step(tmpdir): + with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"): + WandbLogger(sync_step=True) + + def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir): class OldSignature(Callback): def on_save_checkpoint(self, trainer, pl_module): # noqa diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 02721ba436743..c80dddde2774c 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -404,4 +404,4 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch): wandb.run = None wandb.init().step = 0 logger.log_metrics({"test": 1.0}, step=0) - logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0) + logger.experiment.log.assert_called_once_with({'tmp-test': 1.0, 'trainer/global_step': 0}) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index e5b9b891b88c1..0eefb9625ddc7 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -41,22 +41,7 @@ def test_wandb_logger_init(wandb, recwarn): logger = WandbLogger() logger.log_metrics({'acc': 1.0}) wandb.init.assert_called_once() - wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None) - - # test sync_step functionality - wandb.init().log.reset_mock() - wandb.init.reset_mock() - wandb.run = None - wandb.init().step = 0 - logger = WandbLogger(sync_step=False) - logger.log_metrics({'acc': 1.0}) wandb.init().log.assert_called_once_with({'acc': 1.0}) - wandb.init().log.reset_mock() - logger.log_metrics({'acc': 1.0}, step=3) - wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3}) - - # mock wandb step - wandb.init().step = 0 # test wandb.init not called if there is a W&B run wandb.init().log.reset_mock() @@ -65,13 +50,12 @@ def test_wandb_logger_init(wandb, recwarn): logger = WandbLogger() logger.log_metrics({'acc': 1.0}, step=3) wandb.init.assert_called_once() - wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3) + wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3}) # continue training on same W&B run and offset step - wandb.init().step = 3 logger.finalize('success') - logger.log_metrics({'acc': 1.0}, step=3) - wandb.init().log.assert_called_with({'acc': 1.0}, step=6) + logger.log_metrics({'acc': 1.0}, step=6) + wandb.init().log.assert_called_with({'acc': 1.0, 'trainer/global_step': 6}) # log hyper parameters logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]}) @@ -88,17 +72,6 @@ def test_wandb_logger_init(wandb, recwarn): logger.watch('model', 'log', 10) wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10) - # verify warning for logging at a previous step - assert 'Trying to log at a previous step' not in get_warnings(recwarn) - # current step from wandb should be 6 (last logged step) - logger.experiment.step = 6 - # logging at step 2 should raise a warning (step_offset is still 3) - logger.log_metrics({'acc': 1.0}, step=2) - assert 'Trying to log at a previous step' in get_warnings(recwarn) - # logging again at step 2 should not display again the same warning - logger.log_metrics({'acc': 1.0}, step=2) - assert 'Trying to log at a previous step' not in get_warnings(recwarn) - assert logger.name == wandb.init().project_name() assert logger.version == wandb.init().id