
Compositional metrics #5464

Merged
merged 36 commits into from
Jan 26, 2021
68d9768 implement compositional metrics (justusschock, Jan 11, 2021)
73c9009 implement composition functions for metrics (justusschock, Jan 11, 2021)
e585156 test compositions (justusschock, Jan 11, 2021)
3dc10c3 docs (justusschock, Jan 11, 2021)
40b2faf pytest (justusschock, Jan 11, 2021)
ebe724d pep8 (justusschock, Jan 11, 2021)
2d6f778 fix argument resolution (justusschock, Jan 11, 2021)
3edbca6 return all kwargs if filtering not possible (justusschock, Jan 12, 2021)
6f6cda3 fix typo (justusschock, Jan 12, 2021)
f61942c implement hashing (justusschock, Jan 12, 2021)
edaf065 Update pytorch_lightning/metrics/compositional.py (justusschock, Jan 13, 2021)
92d9617 Update docs/source/metrics.rst (justusschock, Jan 13, 2021)
b6fc554 add representation (justusschock, Jan 13, 2021)
57be2eb Apply suggestions from code review (SkafteNicki, Jan 13, 2021)
908d3a8 Update docs/source/metrics.rst (justusschock, Jan 18, 2021)
946a062 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 24, 2021)
049ebd7 chlog (Borda, Jan 24, 2021)
f9f7625 flake8 (Borda, Jan 24, 2021)
573bc56 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 24, 2021)
3033212 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 24, 2021)
2b26fa4 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 24, 2021)
bdebabc Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 24, 2021)
b963d35 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 25, 2021)
121063c Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 25, 2021)
00ac4ec Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 25, 2021)
23b38f9 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 25, 2021)
465ea95 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 25, 2021)
c66ea8e Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 25, 2021)
7cc6171 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 26, 2021)
a337401 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 26, 2021)
ca41b67 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 26, 2021)
06d1ba4 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 26, 2021)
f9f40b7 fix doctest (SkafteNicki, Jan 26, 2021)
7582850 Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 26, 2021)
c3e1878 Merge branch 'release/1.2-dev' into compositional_metric (SkafteNicki, Jan 26, 2021)
4a5c6be Merge branch 'release/1.2-dev' into compositional_metric (mergify[bot], Jan 26, 2021)
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))


- Added compositional metrics ([#5464](https://github.com/PyTorchLightning/pytorch-lightning/pull/5464))


### Changed

45 changes: 45 additions & 0 deletions docs/source/metrics.rst
@@ -255,6 +255,51 @@ In practice this means that:
val = metric(pred, target) # this value can be backpropagated
val = metric.compute() # this value cannot be backpropagated

******************
Metric Arithmetics
******************

Metrics support most of Python's built-in operators for arithmetic, logic, and bitwise operations.

For example, for a metric that should return the sum of two different metrics, implementing a new metric is unnecessary overhead.
It can now be done with:

.. code-block:: python

first_metric = MyFirstMetric()
second_metric = MySecondMetric()

new_metric = first_metric + second_metric

``new_metric.update(*args, **kwargs)`` now calls the ``update`` methods of both ``first_metric`` and ``second_metric``. It forwards all positional arguments, but
forwards only those keyword arguments that appear in the respective metric's ``update`` signature.

Similarly, ``new_metric.compute()`` now calls ``compute`` on both ``first_metric`` and ``second_metric`` and adds up the results.
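
The keyword filtering described above can be sketched with the standard library alone. ``filter_kwargs``, ``update_a``, and ``update_b`` below are illustrative stand-ins (the real filtering lives in a private ``Metric`` helper), assuming the filter simply inspects each ``update`` signature:

```python
import inspect

def filter_kwargs(fn, **kwargs):
    """Keep only the keyword arguments that fn's signature accepts."""
    params = inspect.signature(fn).parameters
    # if fn itself accepts **kwargs, everything can be forwarded
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs
    return {k: v for k, v in kwargs.items() if k in params}

# two hypothetical update signatures
def update_a(preds, target):
    pass

def update_b(preds, target, weights=None):
    pass

shared = {"target": [1, 0], "weights": [0.5, 0.5]}
print(filter_kwargs(update_a, **shared))  # {'target': [1, 0]}
print(filter_kwargs(update_b, **shared))  # {'target': [1, 0], 'weights': [0.5, 0.5]}
```

This is why a composed metric can receive the union of both metrics' keyword arguments without either ``update`` raising a ``TypeError``.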

This pattern is implemented for the following operators (with ``a`` being a metric and ``b`` being a metric, tensor, integer, or float):

* Addition (``a + b``)
* Bitwise AND (``a & b``)
* Equality (``a == b``)
* Floor Division (``a // b``)
* Greater Equal (``a >= b``)
* Greater (``a > b``)
* Less Equal (``a <= b``)
* Less (``a < b``)
* Matrix Multiplication (``a @ b``)
* Modulo (``a % b``)
* Multiplication (``a * b``)
* Inequality (``a != b``)
* Bitwise OR (``a | b``)
* Power (``a ** b``)
* Subtraction (``a - b``)
* True Division (``a / b``)
* Bitwise XOR (``a ^ b``)
* Absolute Value (``abs(a)``)
* Inversion (``~a``)
* Negative Value (``neg(a)``)
* Positive Value (``pos(a)``)
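
The mechanism behind this operator list can be sketched in plain Python: each overloaded operator returns a small composition object that records the operation and defers it until ``compute``. ``TinyMetric`` and ``TinyComposition`` below are illustrative stand-ins, not Lightning classes, and only ``+`` and ``*`` are wired up here:

```python
import operator

class TinyMetric:
    """Minimal stand-in for a Metric: accumulates a running sum."""

    def __init__(self):
        self.total = 0.0

    def update(self, value):
        self.total += value

    def compute(self):
        return self.total

    # operators return a composition instead of a value, deferring evaluation
    def __add__(self, other):
        return TinyComposition(operator.add, self, other)

    def __mul__(self, other):
        return TinyComposition(operator.mul, self, other)

class TinyComposition(TinyMetric):
    def __init__(self, op, a, b):
        self.op, self.a, self.b = op, a, b

    def update(self, value):
        # fan updates out to both child metrics (constants are skipped)
        for m in (self.a, self.b):
            if isinstance(m, TinyMetric):
                m.update(value)

    def compute(self):
        val_a = self.a.compute() if isinstance(self.a, TinyMetric) else self.a
        val_b = self.b.compute() if isinstance(self.b, TinyMetric) else self.b
        return self.op(val_a, val_b)

first, second = TinyMetric(), TinyMetric()
combined = first + second * 2  # builds a small expression tree
combined.update(3.0)
print(combined.compute())  # 9.0, i.e. 3.0 + 3.0 * 2
```

Because evaluation is deferred, a single ``update`` call keeps both child metrics in sync, and the arithmetic is applied only when the final result is requested.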

****************
MetricCollection
****************
92 changes: 92 additions & 0 deletions pytorch_lightning/metrics/compositional.py
@@ -0,0 +1,92 @@
from typing import Callable, Union

import torch

from pytorch_lightning.metrics.metric import Metric


class CompositionalMetric(Metric):
    """Composition of two metrics with a specific operator
    which will be executed upon metric's compute
    """

    def __init__(
        self,
        operator: Callable,
        metric_a: Union[Metric, int, float, torch.Tensor],
        metric_b: Union[Metric, int, float, torch.Tensor, None],
    ):
        """
        Args:
            operator: the operator taking in one (if metric_b is None)
                or two arguments. Will be applied to outputs of metric_a.compute()
                and (optionally if metric_b is not None) metric_b.compute()
            metric_a: first metric whose compute() result is the first argument of operator
            metric_b: second metric whose compute() result is the second argument of operator.
                For operators taking in only one input, this should be None
        """
        super().__init__()

        self.op = operator

        if isinstance(metric_a, torch.Tensor):
            self.register_buffer("metric_a", metric_a)
        else:
            self.metric_a = metric_a

        if isinstance(metric_b, torch.Tensor):
            self.register_buffer("metric_b", metric_b)
        else:
            self.metric_b = metric_b

    def _sync_dist(self, dist_sync_fn=None):
        # No syncing required here. syncing will be done in metric_a and metric_b
        pass

    def update(self, *args, **kwargs):
        if isinstance(self.metric_a, Metric):
            self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))

        if isinstance(self.metric_b, Metric):
            self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))

    def compute(self):

        # also some parsing for kwargs?
        if isinstance(self.metric_a, Metric):
            val_a = self.metric_a.compute()
        else:
            val_a = self.metric_a

        if isinstance(self.metric_b, Metric):
            val_b = self.metric_b.compute()
        else:
            val_b = self.metric_b

        if val_b is None:
            return self.op(val_a)

        return self.op(val_a, val_b)

    def reset(self):
        if isinstance(self.metric_a, Metric):
            self.metric_a.reset()

        if isinstance(self.metric_b, Metric):
            self.metric_b.reset()

    def persistent(self, mode: bool = False):
        if isinstance(self.metric_a, Metric):
            self.metric_a.persistent(mode=mode)
        if isinstance(self.metric_b, Metric):
            self.metric_b.persistent(mode=mode)

    def __repr__(self):
        repr_str = (
            self.__class__.__name__
            + f"(\n  {self.op.__name__}(\n    {repr(self.metric_a)},\n    {repr(self.metric_b)}\n  )\n)"
        )

        return repr_str
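
The branching in ``compute`` above (one argument for unary operators, two otherwise) can be exercised with a small standard-library sketch; ``Const`` and ``compositional_compute`` are illustrative stand-ins for a metric and for the method's logic, not part of the library:

```python
import operator

class Const:
    """Stand-in "metric": any object with a compute() method works here."""

    def __init__(self, value):
        self.value = value

    def compute(self):
        return self.value

def compositional_compute(op, metric_a, metric_b=None):
    """Mirror of CompositionalMetric.compute(): unwrap metrics, then apply op."""
    val_a = metric_a.compute() if hasattr(metric_a, "compute") else metric_a
    val_b = metric_b.compute() if hasattr(metric_b, "compute") else metric_b
    if val_b is None:  # unary operators (abs, neg, ...) take a single input
        return op(val_a)
    return op(val_a, val_b)

print(compositional_compute(operator.neg, Const(-4.0)))        # 4.0
print(compositional_compute(operator.truediv, Const(6.0), 2))  # 3.0
```

The same unwrapping step is what lets plain numbers and tensors sit on either side of an operator next to real metrics.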