From d6ba429f2f01b6bb06fd14ee74e83460e2ae5696 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sun, 1 Nov 2020 17:12:38 +0100
Subject: [PATCH 1/9] changes

---
 pytorch_lightning/metrics/metric.py    | 47 +++++++++++++++++++-----
 tests/metrics/test_metric_lightning.py | 50 ++++++++++++++++++++++++--
 2 files changed, 86 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 3a853be0ebdd5..add194c4b46e9 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -75,9 +75,14 @@ def __init__(
         self._computed = None
         self._forward_cache = None

+        # Hook that will add metric states to state_dict
+        self._register_state_dict_hook(add_metrics_state_dict)
+
         # initialize state
-        self._reductions = {}
         self._defaults = {}
+        self._persistent = {}
+        self._reductions = {}
+
     def add_state(
         self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
@@ -133,16 +138,10 @@ def add_state(
                 "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
             )

-        if isinstance(default, torch.Tensor):
-            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-                # persistent keyword is only supported in torch >= 1.6.0
-                self.register_buffer(name, default, persistent=persistent)
-            else:
-                self.register_buffer(name, default)
-        else:
-            setattr(self, name, default)
+        setattr(self, name, default)

         self._defaults[name] = deepcopy(default)
+        self._persistent[name] = persistent
         self._reductions[name] = dist_reduce_fx

     def forward(self, *args, **kwargs):
@@ -255,3 +254,33 @@ def __setstate__(self, state):
         self.__dict__.update(state)
         self.update = self._wrap_update(self.update)
         self.compute = self._wrap_compute(self.compute)
+
+    def _apply(self, fn):
+        """ Overwrite _apply function such that we can also move metric states
+        to the correct divice when `.to`, `.cuda` ect methods are called
+        """
+        self = super()._apply(fn)
+        # Also apply fn to metric states
+        for key in self._defaults.keys():
+            current_val = getattr(self, key, None)
+            if current_val is not None and isinstance(current_val, torch.Tensor):
+                setattr(self, key, fn(current_val))
+            else:
+                setattr(self, key, [fn(cur_v) for cur_v in current_val])
+        return self
+
+    def persistant(self, mode: bool = True):
+        """ Method for post-init to change whether metric states should be saved to
+        its state_dict
+        """
+        for key in self._persistent.keys():
+            self._persistant[key] = mode
+
+
+def add_metrics_state_dict(self, state_dict, prefix, local_metadata):
+    """ Register metric states to be part of the state_dict """
+    for key in self._defaults.keys():
+        if self._persistent[key]:
+            current_val = getattr(self, key)
+            state_dict.update({key: current_val})
+    return state_dict
diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py
index 7a860ea6c16fd..7e12409e7fff8 100644
--- a/tests/metrics/test_metric_lightning.py
+++ b/tests/metrics/test_metric_lightning.py
@@ -1,3 +1,5 @@
+import pytest
+
 import torch

 from pytorch_lightning import Trainer
@@ -50,6 +52,43 @@ def training_epoch_end(self, outs):


 def test_metric_lightning_log(tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.metric_step = SumMetric()
+            self.metric_epoch = SumMetric()
+            self.sum = 0.0
+
+        def training_step(self, batch, batch_idx):
+            x = batch
+            self.metric_step(x.sum())
+            self.sum += x.sum()
+            self.log("sum_step", self.metric_step, on_epoch=True, on_step=False)
+            return {'loss': self.step(x),
+                    'data': x}
+
+        def training_epoch_end(self, outs):
+            self.log("sum_epoch", self.metric_epoch(torch.stack([o['data'] for o in outs]).sum()))
+
+    model = TestModel()
+    model.val_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    logged = trainer.logged_metrics
+    assert torch.allclose(torch.tensor(logged["sum_step"]), model.sum)
+    assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_metric_lightning_ddp(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
@@ -58,7 +97,7 @@ def __init__(self):

         def training_step(self, batch, batch_idx):
             x = batch
-            self.metric(x.sum())
+            self.metric_step(x.sum())
             self.sum += x.sum()
             self.log("sum", self.metric, on_epoch=True, on_step=False)
             return self.step(x)
@@ -73,8 +112,15 @@ def training_step(self, batch, batch_idx):
         max_epochs=1,
         log_every_n_steps=1,
         weights_summary=None,
+        gpus=2,
+        accelerator='ddp'
     )
     trainer.fit(model)

+    # Manually calculate the sum
+    manual_sum = 0.0
+    for batch in model.train_dataloader():
+        manual_sum += batch.sum()
+
     logged = trainer.logged_metrics
-    assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
+    assert torch.allclose(torch.tensor(logged["sum"]), manual_sum)

From c4bd366fb4a721dc52e8bdf3f9c919814ff2515d Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sun, 1 Nov 2020 18:23:50 +0100
Subject: [PATCH 2/9] fix spelling

---
 pytorch_lightning/metrics/metric.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index add194c4b46e9..7633fbcf0a641 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -269,12 +269,12 @@ def _apply(self, fn):
             setattr(self, key, [fn(cur_v) for cur_v in current_val])
         return self

-    def persistant(self, mode: bool = True):
+    def persistent(self, mode: bool = True):
         """ Method for post-init to change whether metric states should be saved to
         its state_dict
         """
         for key in self._persistent.keys():
-            self._persistant[key] = mode
+            self._persistent[key] = mode


 def add_metrics_state_dict(self, state_dict, prefix, local_metadata):

From 53190b7727b0d6b4b6623cede686942eedb31c74 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 3 Nov 2020 09:54:57 +0100
Subject: [PATCH 3/9] small note

---
 docs/source/metrics.rst | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 4fadfaa507168..1a3cc7bda5606 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -111,6 +111,12 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us

 It is highly recommended to re-initialize the metric per mode as shown in the examples above.

+.. note::
+
+    Metric states will by default add their internal state to the model's ``state_dict``.
+    To prevent this, the class method ``.persistent(mode)`` can be used, with mode
+    set to ``False``.
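+
+    A minimal usage sketch (``MyMetric`` is a hypothetical stand-in for any
+    custom metric subclass, not a class shipped with the library):
+
+    .. code-block:: python
+
+        metric = MyMetric()       # hypothetical custom metric
+        metric.persistent(False)  # exclude its states from the state_dict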
+

 *********************
 Implementing a Metric
 *********************

From 80e336ab28cf1f1f670bb56dd8fa3972de8fb930 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 3 Nov 2020 14:46:20 +0100
Subject: [PATCH 4/9] trying to fix ddp test

---
 tests/metrics/test_metric_lightning.py | 52 +++++++++++++++-----------
 1 file changed, 31 insertions(+), 21 deletions(-)

diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py
index 7e12409e7fff8..19dec7270b579 100644
--- a/tests/metrics/test_metric_lightning.py
+++ b/tests/metrics/test_metric_lightning.py
@@ -5,7 +5,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.metrics import Metric
 from tests.base.boring_model import BoringModel
-
+import tests.base.develop_utils as tutils

 class SumMetric(Metric):
     def __init__(self):
@@ -89,38 +89,48 @@ def training_epoch_end(self, outs):

 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_metric_lightning_ddp(tmpdir):
+    tutils.set_random_master_port()
+
+    # Dummy dataset, where sum is known
+    data=torch.arange(10)[:,None].float()
+    dataset = torch.utils.data.TensorDataset(data)
+
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
             self.metric = SumMetric()
-            self.sum = 0.0
+            self.p = torch.nn.Linear(1,1) # fake params

         def training_step(self, batch, batch_idx):
-            x = batch
-            self.metric_step(x.sum())
-            self.sum += x.sum()
-            self.log("sum", self.metric, on_epoch=True, on_step=False)
-            return self.step(x)
+            val = self.metric(batch[0])
+            self.log("sum", self.metric, on_step=False, on_epoch=True)
+            return self.p(val.view(1,1))

-    model = TestModel()
-    model.val_dataloader = None
+        def train_dataloader(self):
+            return torch.utils.data.DataLoader(
+                dataset,
+                batch_size=1,
+                sampler=torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
+            )
+
+        def configure_optimizers(self):
+            return None

+    model = TestModel()
     trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=2,
-        limit_val_batches=2,
+        gpus=2,
         max_epochs=1,
         log_every_n_steps=1,
-        weights_summary=None,
-        gpus=2,
-        accelerator='ddp'
+        accelerator='ddp',
+        progress_bar_refresh_rate=0,
+        replace_sampler_ddp=False
     )
     trainer.fit(model)

-    # Manually calculate the sum
-    manual_sum = 0.0
-    for batch in model.train_dataloader():
-        manual_sum += batch.sum()
-
-    logged = trainer.logged_metrics
-    assert torch.allclose(torch.tensor(logged["sum"]), manual_sum)
+    logged = trainer.logged_metrics
+
+    assert torch.tensor(logged["sum"]) == dataset.tensors[0].sum(), \
+        "Metrics did not accumulate correctly in ddp mode"
+
+
+

From e9e126765696201b1984b622e5a4a15640cadfaf Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Thu, 5 Nov 2020 16:17:25 +0100
Subject: [PATCH 5/9] fix ddp

---
 docs/source/metrics.rst                |  4 +--
 pytorch_lightning/metrics/metric.py    | 19 +++++------
 tests/metrics/test_metric_lightning.py | 44 +++++++++++---------------
 3 files changed, 29 insertions(+), 38 deletions(-)

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 1a3cc7bda5606..660dec028886e 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -114,8 +114,8 @@
 .. note::

     Metric states will by default add their internal state to the model's ``state_dict``.
-    To prevent this, the class method ``.persistent(mode)`` can be used, with mode
-    set to ``False``.
+    To change this after initializing the metric, the method ``.persistent(mode)`` can
+    be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.

 *********************
 Implementing a Metric
 *********************
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index be4ac1d321609..1eeb482fcf818 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -75,9 +75,6 @@ def __init__(
         self._computed = None
         self._forward_cache = None

-        # Hook that will add metric states to state_dict
-        self._register_state_dict_hook(add_metrics_state_dict)
-
         # initialize state
         self._defaults = {}
         self._persistent = {}
@@ -277,11 +274,11 @@ def persistent(self, mode: bool = True):
         for key in self._persistent.keys():
             self._persistent[key] = mode

-
-def add_metrics_state_dict(self, state_dict, prefix, local_metadata):
-    """ Register metric states to be part of the state_dict """
-    for key in self._defaults.keys():
-        if self._persistent[key]:
-            current_val = getattr(self, key)
-            state_dict.update({key: current_val})
-    return state_dict
+    def state_dict(self, *args, **kwargs):
+        # Register metric states to be part of the state_dict
+        state_dict = super().state_dict()
+        for key in self._defaults.keys():
+            if self._persistent[key]:
+                current_val = getattr(self, key)
+                state_dict.update({key: current_val})
+        return state_dict
diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py
index 19dec7270b579..7b1804d31127e 100644
--- a/tests/metrics/test_metric_lightning.py
+++ b/tests/metrics/test_metric_lightning.py
@@ -87,49 +87,43 @@ def training_epoch_end(self, outs):
     assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)


+class TestModel(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.metric = SumMetric()
+        self.p = torch.nn.Linear(1,1) # fake params
+
+    def training_step(self, batch, batch_idx):
+        val = self.metric(batch[0])
+        self.log("sum", self.metric, on_step=False, on_epoch=True)
+        return self.p(val.view(1,1))
+
+    def configure_optimizers(self):
+        return None
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_metric_lightning_ddp(tmpdir):
     tutils.set_random_master_port()

     # Dummy dataset, where sum is known
-    data=torch.arange(10)[:,None].float()
+    data = torch.arange(10)[:,None].float()
     dataset = torch.utils.data.TensorDataset(data)
-
-    class TestModel(BoringModel):
-        def __init__(self):
-            super().__init__()
-            self.metric = SumMetric()
-            self.p = torch.nn.Linear(1,1) # fake params
-
-        def training_step(self, batch, batch_idx):
-            val = self.metric(batch[0])
-            self.log("sum", self.metric, on_step=False, on_epoch=True)
-            return self.p(val.view(1,1))
-
-        def train_dataloader(self):
-            return torch.utils.data.DataLoader(
-                dataset,
-                batch_size=1,
-                sampler=torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
-            )
-
-        def configure_optimizers(self):
-            return None
+    dataloader = torch.utils.data.DataLoader(dataset)

     model = TestModel()
     trainer = Trainer(
         gpus=2,
         max_epochs=1,
         log_every_n_steps=1,
-        accelerator='ddp',
+        accelerator='ddp_spawn',
         progress_bar_refresh_rate=0,
         replace_sampler_ddp=False
     )
-    trainer.fit(model)
+    trainer.fit(model, dataloader)

     logged = trainer.logged_metrics

-    assert torch.tensor(logged["sum"]) == dataset.tensors[0].sum(), \
+    assert torch.tensor(logged["sum"]) == data.sum(), \
         "Metrics did not accumulate correctly in ddp mode"
From 01426e4247f752b76ecdd451855979b206833788 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Thu, 5 Nov 2020 19:41:18 +0100
Subject: [PATCH 6/9] fix for test

---
 tests/metrics/test_metric_lightning.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py
index 99f5a886d23ec..a3fca84768724 100644
--- a/tests/metrics/test_metric_lightning.py
+++ b/tests/metrics/test_metric_lightning.py
@@ -95,7 +95,7 @@ def __init__(self):
         self.p = torch.nn.Linear(1,1) # fake params

     def training_step(self, batch, batch_idx):
-        val = self.metric(batch[0])
+        val = self.metric(batch[0].sum())
         self.log("sum", self.metric, on_step=False, on_epoch=True)
         return self.p(val.view(1,1))

From 1bf03f4a9241bcc559c19ae2297cdc62fa429399 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 9 Nov 2020 11:36:54 +0100
Subject: [PATCH 7/9] suggestion

---
 pytorch_lightning/metrics/metric.py    |  9 ++++--
 tests/metrics/test_metric_lightning.py | 41 --------------------------
 2 files changed, 6 insertions(+), 44 deletions(-)

diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 0eefceecc5a6e..8f10c4149aa18 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -260,11 +260,14 @@ def _apply(self, fn):
         self = super()._apply(fn)
         # Also apply fn to metric states
         for key in self._defaults.keys():
-            current_val = getattr(self, key, None)
-            if current_val is not None and isinstance(current_val, torch.Tensor):
+            current_val = getattr(self, key)
+            if isinstance(current_val, torch.Tensor):
                 setattr(self, key, fn(current_val))
-            else:
+            elif isinstance(current_val, Sequence):
                 setattr(self, key, [fn(cur_v) for cur_v in current_val])
+            else:
+                raise TypeError('Expected metric state to be either a torch.Tensor '
+                                f'or a list of torch.Tensor, but encountered {current_val}')
         return self

     def persistent(self, mode: bool = True):
diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py
index a3fca84768724..a35562327d717 100644
--- a/tests/metrics/test_metric_lightning.py
+++ b/tests/metrics/test_metric_lightning.py
@@ -88,47 +88,6 @@ def training_epoch_end(self, outs):
     assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)


-class TestModel(BoringModel):
-    def __init__(self):
-        super().__init__()
-        self.metric = SumMetric()
-        self.p = torch.nn.Linear(1,1) # fake params
-
-    def training_step(self, batch, batch_idx):
-        val = self.metric(batch[0].sum())
-        self.log("sum", self.metric, on_step=False, on_epoch=True)
-        return self.p(val.view(1,1))
-
-    def configure_optimizers(self):
-        return None
-
-
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-def test_metric_lightning_ddp(tmpdir):
-    tutils.set_random_master_port()
-
-    # Dummy dataset, where sum is known
-    data = torch.arange(10)[:,None].float()
-    dataset = torch.utils.data.TensorDataset(data)
-    dataloader = torch.utils.data.DataLoader(dataset)
-
-    model = TestModel()
-    trainer = Trainer(
-        gpus=2,
-        max_epochs=1,
-        log_every_n_steps=1,
-        accelerator='ddp_spawn',
-        progress_bar_refresh_rate=0,
-        replace_sampler_ddp=False
-    )
-    trainer.fit(model, dataloader)
-
-    logged = trainer.logged_metrics
-
-    assert torch.tensor(logged["sum"]) == data.sum(), \
-        "Metrics did not accumulate correctly in ddp mode"
-
-
 def test_scriptable(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):

From 7cbf652efd73d7ce935a147afc02688e85fa9493 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 9 Nov 2020 11:41:28 +0100
Subject: [PATCH 8/9] CHANGELOG

---
 CHANGELOG.md | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1f4defbd5cc30..c27a408252c07 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,7 +30,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458))

-- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
+- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
+
+
+- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))

 ### Changed

@@ -47,6 +50,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed metrics states being overridden in ddp mode ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))

 ## [1.0.5] - 2020-11-03

From 36a397f830e6c278545cc6558213259ee713542d Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Mon, 9 Nov 2020 17:17:17 +0000
Subject: [PATCH 9/9] Update pytorch_lightning/metrics/metric.py

---
 pytorch_lightning/metrics/metric.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 555560de3545b..9fa479dfb567a 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -263,7 +263,7 @@ def __setstate__(self, state):

     def _apply(self, fn):
         """ Overwrite _apply function such that we can also move metric states
-        to the correct divice when `.to`, `.cuda` ect methods are called
+        to the correct device when `.to`, `.cuda`, etc. methods are called
         """
         self = super()._apply(fn)
         # Also apply fn to metric states
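Taken together, the series changes where metric states live: patch 1 makes them plain attributes tracked in ``self._defaults``/``self._persistent``, and patch 5 moves the persistence filter from a ``_register_state_dict_hook`` into a ``state_dict()`` override, keeping the filtering logic inside the class. A standalone approximation of that final behaviour is sketched below; ``ToyMetric`` is illustrative only, not the actual ``pytorch_lightning.metrics.Metric``::

    from copy import deepcopy

    import torch


    class ToyMetric(torch.nn.Module):
        """Illustrative approximation of the patched Metric state handling."""

        def __init__(self):
            super().__init__()
            self._defaults = {}
            self._persistent = {}

        def add_state(self, name, default, persistent=True):
            # States are plain attributes now, not registered buffers
            setattr(self, name, default)
            self._defaults[name] = deepcopy(default)
            self._persistent[name] = persistent

        def persistent(self, mode=True):
            # Post-init toggle: should states be saved to state_dict?
            for key in self._persistent:
                self._persistent[key] = mode

        def state_dict(self, *args, **kwargs):
            # Register metric states to be part of the state_dict
            state_dict = super().state_dict(*args, **kwargs)
            for key in self._defaults:
                if self._persistent[key]:
                    state_dict.update({key: getattr(self, key)})
            return state_dict


    metric = ToyMetric()
    metric.add_state("total", torch.tensor(0.0))
    assert "total" in metric.state_dict()
    metric.persistent(False)
    assert "total" not in metric.state_dict()

The two asserts mirror what the new ``persistent(mode)`` toggle is for: the same metric instance can opt its states into or out of checkpointing after construction.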