Skip to content

Commit

Permalink
Refactor: Rename update and compute methods to _update and _compute (#…
Browse files Browse the repository at this point in the history
…840)

* update and compute
* improve tests
* change to warning
* fix docs

Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 1, 2022
1 parent ffe824a commit 99a0c6b
Show file tree
Hide file tree
Showing 79 changed files with 361 additions and 229 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Deprecated passing in `dist_sync_on_step`, `process_group`, `dist_sync_fn` direct argument ([#833](https://github.com/PyTorchLightning/metrics/pull/833))


- Moved particular metrics implementation from `update` and `compute` methods to `_update` and `_compute` ([#840](https://github.com/PyTorchLightning/metrics/pull/840))


### Removed

- Removed support for versions of Lightning lower than v1.5 ([#788](https://github.com/PyTorchLightning/metrics/pull/788))
Expand Down
4 changes: 2 additions & 2 deletions docs/source/pages/brief_intro.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ Implementing a metric
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
def _update(self, preds: torch.Tensor, target: torch.Tensor):
# update metric states
preds, target = self._input_format(preds, target)
assert preds.shape == target.shape

self.correct += torch.sum(preds == target)
self.total += target.numel()

def compute(self):
def _compute(self):
# compute final result
return self.correct.float() / self.total
29 changes: 15 additions & 14 deletions docs/source/pages/implement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
Implementing a Metric
*********************

To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and implement the following methods:
To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and
implement the following methods:

- ``__init__()``: Each state variable should be called using ``self.add_state(...)``.
- ``update()``: Any code needed to update the state given any inputs to the metric.
- ``compute()``: Computes a final value from the state of the metric.
- ``_update()``: Any code needed to update the state given any inputs to the metric.
- ``_compute()``: Computes a final value from the state of the metric.

We provide the remaining interface, such as ``reset()`` that will make sure to correctly reset all metric
states that have been added using ``add_state``. You should therefore not implement ``reset()`` yourself.
Expand All @@ -29,33 +30,33 @@ Example implementation:
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
def _update(self, preds: torch.Tensor, target: torch.Tensor):
preds, target = self._input_format(preds, target)
assert preds.shape == target.shape

self.correct += torch.sum(preds == target)
self.total += target.numel()

def compute(self):
def _compute(self):
return self.correct.float() / self.total


Internal implementation details
-------------------------------

This section briefly describes how metrics work internally. We encourage looking at the source code for more info.
Internally, TorchMetrics wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically
synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the
following internally:
Whenever the public ``update`` or ``compute`` method is called they will internally try to synchronize and reduce
metric states across multiple device before calling the actual implementation provided in the private methods
`_update()` and `_compute`. More precisely, calling ``update()`` does the following internally:

1. Clears computed cache.
2. Calls user-defined ``update()``.
2. Calls user-defined ``_update()``.

Similarly, calling ``compute()`` does the following internally:

1. Syncs metric states between processes.
2. Reduce gathered metric states.
3. Calls the user defined ``compute()`` method on the gathered metric states.
3. Calls the user defined ``_compute()`` method on the gathered metric states.
4. Cache computed result.

From a user's standpoint this has one important side-effect: computed results are cached. This means that no
Expand All @@ -76,7 +77,6 @@ to ``update`` and ``compute`` in the following way:
This procedure has the consequence of calling the user defined ``update`` **twice** during a single
forward call (one to update global statistics and one for getting the batch statistics).


---------

.. autoclass:: torchmetrics.Metric
Expand All @@ -95,7 +95,8 @@ and tests gets formatted in the following way:
metric (classification, regression, nlp etc) and ``new_metric`` is the name of the metric. In this file, there should be the
following three functions:

1. ``_new_metric_update(...)``: everything that has to do with type/shape checking and all logic required before distributed syncing need to go here.
1. ``_new_metric_update(...)``: everything that has to do with type/shape checking and all logic required before distributed
syncing need to go here.
2. ``_new_metric_compute(...)``: all remaining logic.
3. ``new_metric(...)``: essentially wraps the ``_update`` and ``_compute`` private functions into one public function that
makes up the functional interface for the metric.
Expand All @@ -109,8 +110,8 @@ and tests gets formatted in the following way:
1. Create a new module metric by subclassing ``torchmetrics.Metric``.
2. In the ``__init__`` of the module call ``self.add_state`` for as many metric states are needed for the metric to
proper accumulate metric statistics.
3. The module interface should essentially call the private ``_new_metric_update(...)`` in its `update` method and similarly the
``_new_metric_compute(...)`` function in its ``compute``. No logic should really be implemented in the module interface.
3. The module interface should essentially call the private ``_new_metric_update(...)`` in its `_update` method and similarly the
``_new_metric_compute(...)`` function in its ``_compute``. No logic should really be implemented in the module interface.
We do this to not have duplicate code to maintain.

.. note::
Expand Down
14 changes: 8 additions & 6 deletions docs/source/pages/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,18 @@ Implementing your own metric
Implementing your own metric is as easy as subclassing a :class:`torch.nn.Module`. Simply, subclass :class:`~torchmetrics.Metric` and do the following:

1. Implement ``__init__`` where you call ``self.add_state`` for every internal state that is needed for the metrics computations
2. Implement ``update`` method, where all logic that is necessary for updating metric states go
3. Implement ``compute`` method, where the final metric computations happens
2. Implement ``_update`` method, where all logic that is necessary for updating metric states go
3. Implement ``_compute`` method, where the final metric computations happens

For practical examples and more info about implementing a metric, please see this :ref:`page <implement>`.


Development Environment
~~~~~~~~~~~~~~~~~~~~~~~

TorchMetrics provides a `Devcontainer <https://code.visualstudio.com/docs/remote/containers>`_ configuration for `Visual Studio Code <https://code.visualstudio.com/>`_ to use a `Docker container <https://www.docker.com/>`_ as a pre-configured development environment.
TorchMetrics provides a `Devcontainer <https://code.visualstudio.com/docs/remote/containers>`_ configuration for
`Visual Studio Code <https://code.visualstudio.com/>`_ to use a `Docker container <https://www.docker.com/>`_ as a
pre-configured development environment.
This avoids struggles setting up a development environment and makes them reproducible and consistent.
Please follow the `installation instructions <https://code.visualstudio.com/docs/remote/containers#_installation>`_ and make yourself familiar with the `container tutorials <https://code.visualstudio.com/docs/remote/containers-tutorial>`_ if you want to use them.
In order to use GPUs, you can enable them within the ``.devcontainer/devcontainer.json`` file.
Please follow the `installation instructions <https://code.visualstudio.com/docs/remote/containers#_installation>`_
and make yourself familiar with the `container tutorials <https://code.visualstudio.com/docs/remote/containers-tutorial>`_
if you want to use them. In order to use GPUs, you can enable them within the ``.devcontainer/devcontainer.json`` file.
4 changes: 2 additions & 2 deletions integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@


class DiffMetric(SumMetric):
def update(self, value):
super().update(-value)
def _update(self, value):
super()._update(-value)


def test_metric_lightning(tmpdir):
Expand Down
16 changes: 8 additions & 8 deletions tests/bases/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,33 @@ def compare_max(values, weights):
class WrappedMinMetric(MinMetric):
"""Wrapped min metric."""

def update(self, values, weights):
def _update(self, values, weights):
"""only pass values on."""
super().update(values)
super()._update(values)


class WrappedMaxMetric(MaxMetric):
"""Wrapped max metric."""

def update(self, values, weights):
def _update(self, values, weights):
"""only pass values on."""
super().update(values)
super()._update(values)


class WrappedSumMetric(SumMetric):
"""Wrapped min metric."""

def update(self, values, weights):
def _update(self, values, weights):
"""only pass values on."""
super().update(values)
super()._update(values)


class WrappedCatMetric(CatMetric):
"""Wrapped cat metric."""

def update(self, values, weights):
def _update(self, values, weights):
"""only pass values on."""
super().update(values)
super()._update(values)


@pytest.mark.parametrize(
Expand Down
8 changes: 4 additions & 4 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,20 +258,20 @@ class DummyMetric(Metric):
def __init__(self):
super().__init__()

def update(self, *args, kwarg):
def _update(self, *args, kwarg):
print("Entered DummyMetric")

def compute(self):
def _compute(self):
return

class MyAccuracy(Metric):
def __init__(self):
super().__init__()

def update(self, preds, target, kwarg2):
def _update(self, preds, target, kwarg2):
print("Entered MyAccuracy")

def compute(self):
def _compute(self):
return

mc = MetricCollection([Accuracy(), DummyMetric()])
Expand Down
4 changes: 2 additions & 2 deletions tests/bases/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def __init__(self, val_to_return):
self._val_to_return = val_to_return
self._update_called = True

def update(self, *args, **kwargs) -> None:
def _update(self, *args, **kwargs) -> None:
self._num_updates += 1

def compute(self):
def _compute(self):
return tensor(self._val_to_return)


Expand Down
8 changes: 4 additions & 4 deletions tests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ def __init__(self):
super().__init__()
self.add_state("x", default=[], dist_reduce_fx=None)

def update(self, x):
def _update(self, x):
self.x.append(x)

def compute(self):
def _compute(self):
x = torch.cat(self.x, dim=0)
return x.sum()

Expand All @@ -141,11 +141,11 @@ def __init__(self):
self.add_state("x", torch.tensor(0), dist_reduce_fx=torch.sum)
self.add_state("c", torch.tensor(0), dist_reduce_fx=torch.sum)

def update(self, x):
def _update(self, x):
self.x += x
self.c += 1

def compute(self):
def _compute(self):
return self.x // self.c

def __repr__(self):
Expand Down
29 changes: 24 additions & 5 deletions tests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from tests.helpers import seed_all
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6

seed_all(42)
Expand All @@ -36,6 +37,24 @@ def test_error_on_wrong_input():
DummyMetric(dist_sync_fn=[2, 3])


def test_error_on_not_implemented_methods():
"""Test that error is raised if _update or _compute is not implemented."""

class TempMetric(Metric):
def _compute(self):
return None

with pytest.raises(NotImplementedError, match="Expected method `_update` to be implemented in subclass."):
TempMetric()

class TempMetric(Metric):
def _update(self):
pass

with pytest.raises(NotImplementedError, match="Expected method `_compute` to be implemented in subclass."):
TempMetric()


def test_inherit():
"""Test that metric that inherits can be instanciated."""
DummyMetric()
Expand Down Expand Up @@ -116,7 +135,7 @@ def test_reset_compute():

def test_update():
class A(DummyMetric):
def update(self, x):
def _update(self, x):
self.x += x

a = A()
Expand All @@ -132,10 +151,10 @@ def update(self, x):

def test_compute():
class A(DummyMetric):
def update(self, x):
def _update(self, x):
self.x += x

def compute(self):
def _compute(self):
return self.x

a = A()
Expand Down Expand Up @@ -182,10 +201,10 @@ class B(DummyListMetric):

def test_forward():
class A(DummyMetric):
def update(self, x):
def _update(self, x):
self.x += x

def compute(self):
def _compute(self):
return self.x

a = A()
Expand Down
18 changes: 9 additions & 9 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,10 +567,10 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("x", tensor(0.0), dist_reduce_fx=None)

def update(self):
def _update(self):
pass

def compute(self):
def _compute(self):
pass


Expand All @@ -581,29 +581,29 @@ def __init__(self):
super().__init__()
self.add_state("x", [], dist_reduce_fx=None)

def update(self):
def _update(self):
pass

def compute(self):
def _compute(self):
pass


class DummyMetricSum(DummyMetric):
def update(self, x):
def _update(self, x):
self.x += x

def compute(self):
def _compute(self):
return self.x


class DummyMetricDiff(DummyMetric):
def update(self, y):
def _update(self, y):
self.x -= y

def compute(self):
def _compute(self):
return self.x


class DummyMetricMultiOutput(DummyMetricSum):
def compute(self):
def _compute(self):
return [self.x, self.x]
Loading

0 comments on commit 99a0c6b

Please sign in to comment.