[Docs] Explain metric internals (#5899)
* correct docs

* fix levels
SkafteNicki authored Feb 16, 2021
1 parent 141316f commit 4062c62
Showing 1 changed file with 44 additions and 6 deletions.
50 changes: 44 additions & 6 deletions docs/source/extensions/metrics.rst
@@ -258,6 +258,50 @@ In practise this means that:
val = metric(pred, target) # this value can be backpropagated
val = metric.compute() # this value cannot be backpropagated
Metric API
----------

.. autoclass:: pytorch_lightning.metrics.Metric
:noindex:

Internal implementation details
-------------------------------

This section briefly describes how metrics work internally. We encourage looking at the source code for more info.
Internally, Lightning wraps the user-defined ``update()`` and ``compute()`` methods. We do this to automatically
synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the
following internally (a simplified sketch of this wrapping follows the list):

1. Clears the computed cache
2. Calls the user-defined ``update()``
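
A minimal sketch of how this wrapping could look (illustrative only; ``_computed`` is an assumed
name for the internal result cache, not necessarily the exact Lightning attribute):

.. code-block:: python

    import functools

    def wrap_update(metric, user_update):
        """Return a version of ``user_update`` that first invalidates the result cache."""

        @functools.wraps(user_update)
        def wrapped_update(*args, **kwargs):
            metric._computed = None              # 1. clear the computed cache
            return user_update(*args, **kwargs)  # 2. call the user-defined update()

        return wrapped_update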

Similarly, calling ``compute()`` does the following internally (a matching sketch follows the list):

1. Syncs metric states between processes
2. Reduces the gathered metric states
3. Calls the user-defined ``compute()`` method on the gathered metric states
4. Caches the computed result
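
A matching sketch for the ``compute()`` side, again with assumed attribute and helper names
(``_computed``, ``_sync_and_reduce``) rather than the real internals:

.. code-block:: python

    import functools

    def wrap_compute(metric, user_compute):
        """Return a version of ``user_compute`` with syncing, reduction and caching around it."""

        @functools.wraps(user_compute)
        def wrapped_compute(*args, **kwargs):
            if metric._computed is not None:     # a previous result is still cached
                return metric._computed
            metric._sync_and_reduce()            # 1.-2. sync and reduce states across processes
            metric._computed = user_compute(*args, **kwargs)  # 3. user-defined compute()
            return metric._computed              # 4. result stays cached until the next update()

        return wrapped_compute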

From a user's standpoint this has one important side effect: computed results are cached. This means that no
matter how many times ``compute`` is called one after another, it will continue to return the same result.
The cache is only emptied on the next call to ``update``.
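
The caching behaviour can be observed directly from user code, for example with the built-in
``Accuracy`` metric (the exact value returned depends on the random inputs):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy

    metric = Accuracy()
    metric.update(torch.randint(0, 2, (10,)), torch.randint(0, 2, (10,)))

    first = metric.compute()
    second = metric.compute()          # returns the cached result, no re-computation
    assert torch.equal(first, second)

    metric.update(torch.randint(0, 2, (10,)), torch.randint(0, 2, (10,)))  # empties the cache
    third = metric.compute()           # recomputed on the accumulated state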

``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The ``forward()`` method achieves this by combining calls
to ``update`` and ``compute`` in the following way (assuming the metric is initialized with ``compute_on_step=True``):

1. Calls ``update()`` to update the global metric states (for accumulation over multiple batches)
2. Caches the global state
3. Calls ``reset()`` to clear the global metric state
4. Calls ``update()`` to update the local metric state
5. Calls ``compute()`` to calculate the metric for the current batch
6. Restores the global state

This procedure has the consequence of calling the user-defined ``update`` **twice** during a single
forward call (once to update the global statistics and once to get the statistics of the current batch).
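
A rough sketch of this ``forward`` logic, assuming the accumulated states can be cached in a plain
dictionary keyed by the names registered via ``add_state`` (taken here from an internal ``_defaults``
mapping; the real implementation also handles distributed-sync options and ``compute_on_step=False``):

.. code-block:: python

    def forward(self, *args, **kwargs):
        # 1. update the global state (accumulation over multiple batches)
        self.update(*args, **kwargs)

        if self.compute_on_step:
            # 2. cache the global (accumulated) state
            cache = {name: getattr(self, name) for name in self._defaults}

            # 3. reset to the default state and 4. update with the current batch only
            self.reset()
            self.update(*args, **kwargs)

            # 5. compute the metric on the current batch alone
            batch_value = self.compute()

            # 6. restore the accumulated global state
            for name, value in cache.items():
                setattr(self, name, value)

            return batch_value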


******************
Metric Arithmetics
******************
@@ -367,12 +411,6 @@ inside your LightningModule
.. autoclass:: pytorch_lightning.metrics.MetricCollection
:noindex:

**********
Metric API
**********

.. autoclass:: pytorch_lightning.metrics.Metric
:noindex:

***************************
Class vs Functional Metrics
