From f2a574780652ad383a74989e32f53948984e35bb Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 2 Nov 2020 20:44:49 +0100
Subject: [PATCH] [Metrics] Detach bugfix (#4313)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* detach on buffer

* doc update

* remove file

* changelog

* suggestions

* Update docs/source/metrics.rst

Co-authored-by: Teddy Koker

* fix for 4266

* Update docs/source/metrics.rst

Co-authored-by: Adrian Wälchli

* Update CHANGELOG.md

Co-authored-by: Teddy Koker
Co-authored-by: chaton
Co-authored-by: Adrian Wälchli
Co-authored-by: Ananya Harsh Jha
Co-authored-by: Roger Shieh
Co-authored-by: Sean Naren
Co-authored-by: Rohit Gupta
---
 CHANGELOG.md                          |  2 ++
 docs/source/metrics.rst               | 14 +++++++++++++-
 pytorch_lightning/core/step_result.py |  6 +++---
 pytorch_lightning/metrics/metric.py   |  3 ++-
 4 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 84d483dd03f2c3..95417dd8aa9ae7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))
 
+- Fixed that metrics do not store computational graph for all seen data ([#4313](https://github.com/PyTorchLightning/pytorch-lightning/pull/4313))
+
 - Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
 ## [1.0.4] - 2020-10-27
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index de3cd01c33e9b7..4fadfaa5071689 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -150,6 +150,19 @@ Example implementation:
         def compute(self):
             return self.correct.float() / self.total
 
+Metrics support backpropagation if all computations involved in the metric calculation
+are differentiable. However, note that the cached state is detached from the computational
+graph and cannot be backpropagated. Not doing this would mean storing the computational
+graph for each update call, which can lead to out-of-memory errors.
+In practice this means that:
+
+.. code-block:: python
+
+    metric = MyMetric()
+    val = metric(pred, target)  # this value can be backpropagated
+    val = metric.compute()  # this value cannot be backpropagated
+
+
 **********
 Metric API
 **********
@@ -453,4 +466,3 @@ embedding_similarity [func]
 
 .. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity
     :noindex:
-
diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 650c1876d0cd0a..a8224f45e3829a 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -259,7 +259,7 @@ def get_batch_log_metrics(self, include_forked_originals=True) -> dict:
 
             if options['logger'] and options['on_step']:
                 if isinstance(self[k], Metric):
-                    result[k] = self[k]._forward_cache
+                    result[k] = self[k]._forward_cache.detach()
                 else:
                     result[k] = self[k]
 
@@ -281,7 +281,7 @@ def get_epoch_log_metrics(self) -> dict:
 
             if options['logger'] and options['on_epoch']:
                 if isinstance(self[k], Metric):
-                    result[k] = self[k].compute()
+                    result[k] = self[k].compute().detach()
                 else:
                     result[k] = self[k]
 
@@ -307,7 +307,7 @@ def get_epoch_pbar_metrics(self):
 
             if options['prog_bar'] and options['on_epoch']:
                 if isinstance(self[k], Metric):
-                    result[k] = self[k].compute()
+                    result[k] = self[k].compute().detach()
                 else:
                     result[k] = self[k]
 
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 3a853be0ebdd50..f003e0d3da72ae 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -150,7 +150,8 @@ def forward(self, *args, **kwargs):
         Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
         """
         # add current step
-        self.update(*args, **kwargs)
+        with torch.no_grad():
+            self.update(*args, **kwargs)
         self._forward_cache = None
 
         if self.compute_on_step:
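
For illustration, the behaviour this patch enforces can be reproduced with a small
standalone metric. This is only a sketch: the MeanAbsoluteError class, its buffers, and the
example tensors below are hypothetical and are not the library's actual Metric implementation.
It mirrors the two changes above: state accumulation runs under torch.no_grad() (as in the
metric.py hunk), so the cached state never holds a computational graph, while the per-batch
value returned from the forward call stays differentiable; the value derived from the
accumulated state via compute() cannot be backpropagated, matching the docs added to
metrics.rst and the .detach() calls in step_result.py.

    import torch


    class MeanAbsoluteError:
        """Toy streaming metric; structure and names are illustrative only."""

        def __init__(self):
            # Accumulated state (the "cached state" in the docs above); kept graph-free.
            self.sum_abs_error = torch.tensor(0.0)
            self.total = torch.tensor(0.0)

        def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
            # Called under no_grad in __call__, so these buffers never store a graph.
            self.sum_abs_error += (preds - target).abs().sum()
            self.total += target.numel()

        def compute(self) -> torch.Tensor:
            # Derived from detached buffers -> cannot be backpropagated.
            return self.sum_abs_error / self.total

        def __call__(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # The pattern of the metric.py change: accumulate without tracking gradients.
            with torch.no_grad():
                self.update(preds, target)
            # Per-batch value computed with grad enabled -> can be backpropagated.
            return (preds - target).abs().mean()


    preds = torch.randn(8, requires_grad=True)
    target = torch.randn(8)

    metric = MeanAbsoluteError()
    batch_val = metric(preds, target)  # differentiable, like `val = metric(pred, target)`
    batch_val.backward()               # gradients flow back to `preds`

    epoch_val = metric.compute()       # epoch_val.requires_grad is False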