[Metrics] Detach bugfix (#4313)
* detach on buffer

* doc update

* remove file

* changelog

* suggestions

* Update docs/source/metrics.rst

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* fix for 4266

* Update docs/source/metrics.rst

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update CHANGELOG.md

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
8 people committed Nov 21, 2020
1 parent 1a64e6a commit f2a5747
Showing 4 changed files with 20 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))

- Fixed that metrics do not store computational graph for all seen data ([#4313](https://github.com/PyTorchLightning/pytorch-lightning/pull/4313))

- Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))

## [1.0.4] - 2020-10-27
14 changes: 13 additions & 1 deletion docs/source/metrics.rst
@@ -150,6 +150,19 @@ Example implementation:
    def compute(self):
        return self.correct.float() / self.total

Metrics support backpropagation if all computations involved in the metric
calculation are differentiable. However, note that the cached state is detached
from the computational graph and cannot be backpropagated. Not doing this would
mean storing the computational graph for each update call, which can lead to
out-of-memory errors. In practice this means that:

.. code-block:: python

    metric = MyMetric()
    val = metric(pred, target)  # this value can be backpropagated
    val = metric.compute()  # this value cannot be backpropagated
**********
Metric API
**********
@@ -453,4 +466,3 @@ embedding_similarity [func]

.. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity
:noindex:
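The detach behaviour described in the docs addition above can be illustrated without Lightning at all. The sketch below uses a hypothetical minimal metric built on plain `torch` (not the real `Metric` base class): the per-step value returned by the forward-style call keeps its autograd graph, while the accumulated state is updated under `torch.no_grad()` and therefore cannot be backpropagated.

```python
import torch


class SumMetric:
    """Hypothetical minimal metric sketching the detach semantics."""

    def __init__(self):
        self.total = torch.tensor(0.0)

    def __call__(self, x):
        # Forward-style call: the value computed on the current batch
        # still carries the graph and can be backpropagated.
        val = x.sum()
        # The state update runs without gradient tracking, so the
        # accumulated buffer holds no computational graph.
        with torch.no_grad():
            self.total = self.total + x.sum()
        return val

    def compute(self):
        # Accumulated value: detached, cannot be backpropagated.
        return self.total


x = torch.ones(3, requires_grad=True)
m = SumMetric()
step_val = m(x)
print(step_val.requires_grad)     # True: per-step value keeps the graph
print(m.compute().requires_grad)  # False: cached state is detached
```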

6 changes: 3 additions & 3 deletions pytorch_lightning/core/step_result.py
@@ -259,7 +259,7 @@ def get_batch_log_metrics(self, include_forked_originals=True) -> dict:

        if options['logger'] and options['on_step']:
            if isinstance(self[k], Metric):
-                result[k] = self[k]._forward_cache
+                result[k] = self[k]._forward_cache.detach()
            else:
                result[k] = self[k]

@@ -281,7 +281,7 @@ def get_epoch_log_metrics(self) -> dict:

        if options['logger'] and options['on_epoch']:
            if isinstance(self[k], Metric):
-                result[k] = self[k].compute()
+                result[k] = self[k].compute().detach()
            else:
                result[k] = self[k]

@@ -307,7 +307,7 @@ def get_epoch_pbar_metrics(self):

        if options['prog_bar'] and options['on_epoch']:
            if isinstance(self[k], Metric):
-                result[k] = self[k].compute()
+                result[k] = self[k].compute().detach()
            else:
                result[k] = self[k]

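The three `step_result.py` hunks above all apply the same fix: detach the tensor before handing it to the logger, so that logging a metric no longer retains its autograd graph. A standalone sketch of what `.detach()` does, using plain `torch` with illustrative names:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
loss = (w * 3.0) ** 2             # part of a computational graph
logged = loss.detach()            # same value, but no graph attached

print(loss.requires_grad)         # True
print(logged.requires_grad)       # False
print(torch.equal(loss, logged))  # True: the value is unchanged
```

Detaching before logging means the graph built during the step can be freed as soon as backpropagation finishes, instead of being kept alive by every logged value.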
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/metric.py
@@ -150,7 +150,8 @@ def forward(self, *args, **kwargs):
        Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
        """
        # add current step
-        self.update(*args, **kwargs)
+        with torch.no_grad():
+            self.update(*args, **kwargs)
        self._forward_cache = None

        if self.compute_on_step:
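The `metric.py` hunk wraps `update()` in `torch.no_grad()`, so the state accumulated across steps never records a graph. A quick illustration of that context manager's effect in plain `torch`:

```python
import torch

x = torch.tensor(1.0, requires_grad=True)

with torch.no_grad():
    y = x * 2.0  # built inside no_grad: no graph is recorded

z = x * 2.0      # built normally: graph is recorded

print(y.requires_grad)  # False
print(z.requires_grad)  # True
```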
