Refactor/remove double forward #984

Merged: 50 commits, May 10, 2022
1a89f1d
something is working
SkafteNicki Apr 20, 2022
d18b411
better
SkafteNicki Apr 21, 2022
ce221a3
fix tests
SkafteNicki Apr 25, 2022
335021b
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 25, 2022
969da72
docstring
SkafteNicki Apr 25, 2022
d72dfa4
update docs
SkafteNicki Apr 25, 2022
77729b7
fix
SkafteNicki Apr 25, 2022
25fd83f
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 26, 2022
f46cfa0
fix issue
SkafteNicki Apr 26, 2022
30c356a
Merge branch 'refactor/remove_double_forward' of https://github.com/P…
SkafteNicki Apr 26, 2022
b411548
fix some tests
SkafteNicki Apr 26, 2022
285cba4
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 27, 2022
4f9d75b
introduce class property
SkafteNicki Apr 28, 2022
e2cca16
docs
SkafteNicki Apr 28, 2022
279a83c
Merge branch 'refactor/remove_double_forward' of https://github.com/P…
SkafteNicki Apr 28, 2022
967bf30
changelog
SkafteNicki Apr 28, 2022
ddc237b
fix docs
SkafteNicki Apr 28, 2022
ff011b8
update docs
SkafteNicki Apr 28, 2022
2abc26e
rename and re-order
SkafteNicki Apr 28, 2022
398521a
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 28, 2022
99ff43a
fix list
SkafteNicki Apr 28, 2022
3b0c8cf
Merge branch 'refactor/remove_double_forward' of https://github.com/P…
SkafteNicki Apr 28, 2022
52a4615
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 28, 2022
0ba7c37
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 28, 2022
ff75e70
change impl
SkafteNicki May 4, 2022
274078f
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki May 4, 2022
2eff3b6
fix tests
SkafteNicki May 4, 2022
87fa28f
regression
SkafteNicki May 4, 2022
fe58dc3
fix typing
SkafteNicki May 4, 2022
0670117
fix lightning integration
SkafteNicki May 4, 2022
08299ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 4, 2022
416b057
remove from test metric
SkafteNicki May 4, 2022
90c2254
update count
SkafteNicki May 5, 2022
7058907
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki May 5, 2022
2790547
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 5, 2022
15a4616
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 5, 2022
a16b27b
Merge branch 'master' into refactor/remove_double_forward
Borda May 5, 2022
a310496
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 5, 2022
5d4977f
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 6, 2022
708c848
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki May 6, 2022
c086f2c
audio
SkafteNicki May 6, 2022
5e8ebf1
fix mean reduction
SkafteNicki May 7, 2022
2ef10a6
refactor forward
SkafteNicki May 7, 2022
00d7fe8
classification
SkafteNicki May 7, 2022
aebff09
detection
SkafteNicki May 7, 2022
fcc5461
image
SkafteNicki May 7, 2022
3c5cacf
retrieval
SkafteNicki May 7, 2022
050cf86
text
SkafteNicki May 7, 2022
b8a7e17
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 7, 2022
fc1e2c4
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 8, 2022
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -20,10 +20,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for nested metric collections ([#1003](https://github.com/PyTorchLightning/metrics/pull/1003))


### Changed
- Added class property `full_state_update` that determines whether `forward` should call `update` once or twice ([#984](https://github.com/PyTorchLightning/metrics/pull/984))


- Added support for nested metric collections ([#1003](https://github.com/PyTorchLightning/metrics/pull/1003))

-

### Changed

-

61 changes: 49 additions & 12 deletions docs/source/pages/implement.rst
@@ -1,5 +1,9 @@
.. _implement:

.. testsetup:: *

from typing import Optional

*********************
Implementing a Metric
*********************
@@ -40,6 +44,27 @@ Example implementation:
return self.correct.float() / self.total


Additionally, you may want to set the class properties `is_differentiable`, `higher_is_better` and
`full_state_update`. None of them are strictly required for the metric to work.

.. testcode::

from torchmetrics import Metric

class MyMetric(Metric):
# Set to True if the metric is differentiable else set to False
is_differentiable: Optional[bool] = None

# Set to True if the metric reaches its optimal value when maximized.
# Set to False if it reaches its optimal value when minimized.
higher_is_better: Optional[bool] = True

# Set to True if the metric requires access to the global metric state
# during 'update' for its calculations. Setting it to False indicates that all
# batch states are independent, which lets us optimize the runtime of 'forward'
full_state_update: bool = True


Internal implementation details
-------------------------------

@@ -64,18 +89,30 @@ The cache is first emptied on the next call to ``update``.

``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The ``forward()`` method achieves this by combining calls
to ``update`` and ``compute`` in the following way:

1. Calls ``update()`` to update the global metric state (for accumulation over multiple batches)
2. Caches the global state.
3. Calls ``reset()`` to clear global metric state.
4. Calls ``update()`` to update local metric state.
5. Calls ``compute()`` to calculate metric for current batch.
6. Restores the global state.

This procedure has the consequence of calling the user defined ``update`` **twice** during a single
forward call (one to update global statistics and one for getting the batch statistics).

to ``update``, ``compute`` and ``reset``. Depending on the class property ``full_state_update``, ``forward``
can behave in two ways:

1. If ``full_state_update`` is ``True``, the metric requires access to the full metric state during ``update``,
and we therefore need two calls to ``update`` to ensure that the metric is calculated correctly

1. Calls ``update()`` to update the global metric state (for accumulation over multiple batches)
2. Caches the global state.
3. Calls ``reset()`` to clear global metric state.
4. Calls ``update()`` to update local metric state.
5. Calls ``compute()`` to calculate metric for current batch.
6. Restores the global state.

2. If ``full_state_update`` is ``False`` (default), the metric state of one batch is completely independent of the state of
other batches, which means that we only need to call ``update`` once.

1. Caches the global state.
2. Calls ``reset()`` to revert the metric to its default state.
3. Calls ``update()`` to update the state with local batch statistics.
4. Calls ``compute()`` to calculate the metric for the current batch.
5. Reduces the global state and the batch state into a single state that becomes the new global state.

If implementing your own metric, we recommend trying out the metric with the ``full_state_update`` class property set to
both ``True`` and ``False``. If the results are equal, setting it to ``False`` will usually give the best performance.
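The two control flows can be sketched in plain Python. This is a simplified, hypothetical stand-in (a ``MiniMetric`` holding a single sum-reduced state), not the actual torchmetrics implementation:

```python
import torch
from torch import Tensor


class MiniMetric:
    """Hypothetical minimal stand-in for torchmetrics.Metric, for illustration only."""

    full_state_update: bool = False

    def __init__(self) -> None:
        self.total = torch.tensor(0.0)  # single sum-reduced state

    def update(self, x: Tensor) -> None:
        self.total += x.sum()

    def compute(self) -> Tensor:
        return self.total

    def reset(self) -> None:
        self.total = torch.tensor(0.0)

    def forward(self, x: Tensor) -> Tensor:
        if self.full_state_update:
            # Slow path: ``update`` is called twice.
            self.update(x)                    # 1. update global state
            global_state = self.total         # 2. cache global state
            self.reset()                      # 3. clear global state
            self.update(x)                    # 4. update local state
            batch_value = self.compute()      # 5. compute batch metric
            self.total = global_state         # 6. restore global state
        else:
            # Fast path: ``update`` is called only once.
            global_state = self.total         # 1. cache global state
            self.reset()                      # 2. reset to default state
            self.update(x)                    # 3. batch statistics only
            batch_value = self.compute()      # 4. compute batch metric
            # 5. reduce global and batch state into the new global state
            self.total = global_state + self.total
        return batch_value
```

For a sum-reduced state both paths give identical results, which is exactly the situation where setting ``full_state_update=False`` saves one ``update`` call per batch.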

---------

1 change: 0 additions & 1 deletion tests/bases/test_composition.py
@@ -26,7 +26,6 @@ def __init__(self, val_to_return):
super().__init__()
self.add_state("_num_updates", tensor(0), dist_reduce_fx="sum")
self._val_to_return = val_to_return
self._update_called = True

def update(self, *args, **kwargs) -> None:
self._num_updates += 1
4 changes: 2 additions & 2 deletions tests/helpers/testers.py
@@ -572,7 +572,7 @@ class DummyMetric(Metric):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("x", tensor(0.0), dist_reduce_fx=None)
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")

def update(self):
pass
@@ -586,7 +586,7 @@ class DummyListMetric(Metric):

def __init__(self):
super().__init__()
self.add_state("x", [], dist_reduce_fx=None)
self.add_state("x", [], dist_reduce_fx="cat")

def update(self):
pass
5 changes: 3 additions & 2 deletions torchmetrics/audio/pesq.py
@@ -72,8 +72,9 @@ class PerceptualEvaluationSpeechQuality(Metric):

sum_pesq: Tensor
total: Tensor
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True

def __init__(
self,
3 changes: 2 additions & 1 deletion torchmetrics/audio/pit.py
@@ -61,7 +61,8 @@ class PermutationInvariantTraining(Metric):
Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.
"""

is_differentiable = True
full_state_update: bool = False
is_differentiable: bool = True
sum_pit_metric: Tensor
total: Tensor

5 changes: 3 additions & 2 deletions torchmetrics/audio/sdr.py
@@ -75,8 +75,9 @@ class SignalDistortionRatio(Metric):

sum_sdr: Tensor
total: Tensor
is_differentiable = True
higher_is_better = True
full_state_update: bool = False
is_differentiable: bool = True
higher_is_better: bool = True

def __init__(
self,
5 changes: 3 additions & 2 deletions torchmetrics/audio/snr.py
@@ -60,8 +60,9 @@ class SignalNoiseRatio(Metric):
and Signal Processing (ICASSP) 2019.

"""
is_differentiable = True
higher_is_better = True
full_state_update: bool = False
is_differentiable: bool = True
higher_is_better: bool = True
sum_snr: Tensor
total: Tensor

5 changes: 3 additions & 2 deletions torchmetrics/audio/stoi.py
@@ -79,8 +79,9 @@ class ShortTimeObjectiveIntelligibility(Metric):
"""
sum_stoi: Tensor
total: Tensor
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True

def __init__(
self,
1 change: 1 addition & 0 deletions torchmetrics/classification/accuracy.py
@@ -156,6 +156,7 @@ class Accuracy(StatScores):
"""
is_differentiable = False
higher_is_better = True
full_state_update: bool = False
correct: Tensor
total: Tensor

6 changes: 4 additions & 2 deletions torchmetrics/classification/auc.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from torch import Tensor

@@ -35,7 +35,9 @@ class AUC(Metric):

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
"""
is_differentiable = False
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
x: List[Tensor]
y: List[Tensor]

5 changes: 3 additions & 2 deletions torchmetrics/classification/auroc.py
@@ -97,8 +97,9 @@ class AUROC(Metric):
tensor(0.7778)

"""
is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
preds: List[Tensor]
target: List[Tensor]

4 changes: 3 additions & 1 deletion torchmetrics/classification/avg_precision.py
@@ -78,7 +78,9 @@ class AveragePrecision(Metric):
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
"""

is_differentiable = False
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
preds: List[Tensor]
target: List[Tensor]

5 changes: 4 additions & 1 deletion torchmetrics/classification/binned_precision_recall.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -109,6 +109,9 @@ class BinnedPrecisionRecallCurve(Metric):
tensor([0.0000, 0.5000, 1.0000])]
"""

is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
TPs: Tensor
FPs: Tensor
FNs: Tensor
4 changes: 3 additions & 1 deletion torchmetrics/classification/calibration_error.py
@@ -55,8 +55,10 @@ class CalibrationError(Metric):

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
"""
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
DISTANCES = {"l1", "l2", "max"}
higher_is_better = False
confidences: List[Tensor]
accuracies: List[Tensor]

5 changes: 3 additions & 2 deletions torchmetrics/classification/cohen_kappa.py
@@ -67,8 +67,9 @@ class labels.
tensor(0.5000)

"""
is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
confmat: Tensor

def __init__(
4 changes: 3 additions & 1 deletion torchmetrics/classification/confusion_matrix.py
@@ -85,7 +85,9 @@ class ConfusionMatrix(Metric):
[[0, 1], [0, 1]]])

"""
is_differentiable = False
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
confmat: Tensor

def __init__(
6 changes: 4 additions & 2 deletions torchmetrics/classification/f_beta.py
@@ -118,6 +118,7 @@ class FBetaScore(StatScores):
tensor(0.3333)

"""
full_state_update: bool = False

def __init__(
self,
@@ -242,8 +243,9 @@ class F1Score(FBetaScore):
tensor(0.3333)
"""

is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False

def __init__(
self,
5 changes: 3 additions & 2 deletions torchmetrics/classification/hamming.py
@@ -56,8 +56,9 @@ class HammingDistance(Metric):
tensor(0.2500)

"""
is_differentiable = False
higher_is_better = False
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False
correct: Tensor
total: Tensor

5 changes: 3 additions & 2 deletions torchmetrics/classification/hinge.py
@@ -87,8 +87,9 @@ class HingeLoss(Metric):
tensor([2.2333, 1.5000, 1.2333])

"""
is_differentiable = True
higher_is_better = False
is_differentiable: bool = True
higher_is_better: bool = False
full_state_update: bool = False
measure: Tensor
total: Tensor

5 changes: 3 additions & 2 deletions torchmetrics/classification/jaccard.py
@@ -71,8 +71,9 @@ class JaccardIndex(ConfusionMatrix):
tensor(0.9660)

"""
is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False

def __init__(
self,
7 changes: 3 additions & 4 deletions torchmetrics/classification/kl_divergence.py
@@ -65,10 +65,9 @@ class KLDivergence(Metric):
tensor(0.0853)

"""
is_differentiable = True
higher_is_better = False
# TODO: cannot be used because of scripting
# measures: Union[List[Tensor], Tensor]
is_differentiable: bool = True
higher_is_better: bool = False
full_state_update: bool = False
total: Tensor

def __init__(
5 changes: 3 additions & 2 deletions torchmetrics/classification/matthews_corrcoef.py
@@ -63,8 +63,9 @@ class MatthewsCorrCoef(Metric):
tensor(0.5774)

"""
is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
confmat: Tensor

def __init__(
6 changes: 4 additions & 2 deletions torchmetrics/classification/precision_recall.py
@@ -110,6 +110,7 @@ class Precision(StatScores):
"""
is_differentiable = False
higher_is_better = True
full_state_update: bool = False

def __init__(
self,
@@ -242,8 +243,9 @@ class Recall(StatScores):
tensor(0.2500)

"""
is_differentiable = False
higher_is_better = True
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False

def __init__(
self,
4 changes: 3 additions & 1 deletion torchmetrics/classification/precision_recall_curve.py
@@ -75,7 +75,9 @@ class PrecisionRecallCurve(Metric):
[tensor(0.7500), tensor(0.7500), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor(0.0500)]
"""

is_differentiable = False
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
preds: List[Tensor]
target: List[Tensor]
