diff --git a/CHANGELOG.md b/CHANGELOG.md
index e2aeb69ee3108..7e49a59c79f94 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
 
+- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
+
+
 ### Changed
 
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
@@ -49,6 +52,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed feature-lack in hpc load ([#4526](https://github.com/PyTorchLightning/pytorch-lightning/pull/4526))
 
+
+- Fixed metrics states being overridden in ddp mode ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
+
+
 ## [1.0.5] - 2020-11-03
 
 ### Added
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index e41fdfa7d1c74..d80b35f91abd1 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -131,6 +131,12 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
 
     It is highly recommended to re-initialize the metric per mode as shown in the examples above.
 
+.. note::
+
+    Metric states will as default add their internal state to the models ``state_dict``.
+    To change this after initializing the metric the method ``.persistent(mode)`` can
+    be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
+
 *********************
 Implementing a Metric
 *********************
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 1a568bab37209..9fa479dfb567a 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -81,8 +81,9 @@ def __init__(
         self._forward_cache = None
 
         # initialize state
-        self._reductions = {}
         self._defaults = {}
+        self._persistent = {}
+        self._reductions = {}
 
     def add_state(
         self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
@@ -138,16 +139,10 @@ def add_state(
                 "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
             )
 
-        if isinstance(default, torch.Tensor):
-            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-                # persistent keyword is only supported in torch >= 1.6.0
-                self.register_buffer(name, default, persistent=persistent)
-            else:
-                self.register_buffer(name, default)
-        else:
-            setattr(self, name, default)
+        setattr(self, name, default)
 
         self._defaults[name] = deepcopy(default)
+        self._persistent[name] = persistent
         self._reductions[name] = dist_reduce_fx
 
     @torch.jit.unused
@@ -265,3 +260,36 @@ def __setstate__(self, state):
         self.__dict__.update(state)
         self.update = self._wrap_update(self.update)
         self.compute = self._wrap_compute(self.compute)
+
+    def _apply(self, fn):
+        """ Overwrite _apply function such that we can also move metric states
+        to the correct device when `.to`, `.cuda`, etc methods are called
+        """
+        self = super()._apply(fn)
+        # Also apply fn to metric states
+        for key in self._defaults.keys():
+            current_val = getattr(self, key)
+            if isinstance(current_val, torch.Tensor):
+                setattr(self, key, fn(current_val))
+            elif isinstance(current_val, Sequence):
+                setattr(self, key, [fn(cur_v) for cur_v in current_val])
+            else:
+                raise TypeError('Expected metric state to be either a torch.Tensor'
+                                f'or a list of torch.Tensor, but encountered {current_val}')
+        return self
+
+    def persistent(self, mode: bool = True):
+        """ Method for post-init to change if metric states should be saved to
+        its state_dict
+        """
+        for key in self._persistent.keys():
+            self._persistent[key] = mode
+
+    def state_dict(self, *args, **kwargs):
+        # Register metric states to be part of the state_dict
+        state_dict = super().state_dict()
+        for key in self._defaults.keys():
+            if self._persistent[key]:
+                current_val = getattr(self, key)
+                state_dict.update({key: current_val})
+        return state_dict
diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py
index 3c6938734be10..a35562327d717 100644
--- a/tests/metrics/test_metric_lightning.py
+++ b/tests/metrics/test_metric_lightning.py
@@ -1,9 +1,11 @@
-import os
+import pytest
 import torch
+
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.metrics import Metric
 from tests.base.boring_model import BoringModel
+import tests.base.develop_utils as tutils
 
 
 class SumMetric(Metric):
@@ -54,15 +56,19 @@ def test_metric_lightning_log(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
-            self.metric = SumMetric()
+            self.metric_step = SumMetric()
+            self.metric_epoch = SumMetric()
             self.sum = 0.0
 
         def training_step(self, batch, batch_idx):
             x = batch
-            self.metric(x.sum())
+            self.metric_step(x.sum())
             self.sum += x.sum()
-            self.log("sum", self.metric, on_epoch=True, on_step=False)
-            return self.step(x)
+            self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
+            return {'loss': self.step(x), 'data': x}
+
+        def training_epoch_end(self, outs):
+            self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))
 
     model = TestModel()
     model.val_dataloader = None
@@ -78,7 +84,8 @@ def training_step(self, batch, batch_idx):
     trainer.fit(model)
 
     logged = trainer.logged_metrics
-    assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
+    assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
+    assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
 
 
 def test_scriptable(tmpdir):
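
Taken together, the `metric.py` changes mean metric states are no longer registered as buffers; the overridden `state_dict` adds them explicitly whenever their `persistent` flag is set, and `persistent(mode)` can flip that flag after construction. Below is a minimal usage sketch, not part of the diff: `MySum` and the printed expectations are illustrative assumptions, not code from this PR.

import torch
from pytorch_lightning.metrics import Metric


class MySum(Metric):
    """Hypothetical minimal metric, used only to illustrate the new behaviour."""

    def __init__(self):
        super().__init__()
        # persistent=True (the default) marks the state for inclusion in state_dict
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x):
        self.total += x

    def compute(self):
        return self.total


metric = MySum()
metric(torch.tensor(2.0))

# With the state_dict override, persistent states are included by default
print(metric.state_dict())   # expected to contain the "total" state

# Post-init toggle introduced in this PR: disable persistence for all states
metric.persistent(False)
print(metric.state_dict())   # "total" should no longer appear

# With the _apply override, states are also expected to follow device moves,
# e.g. metric.to("cuda") when a GPU is available.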