Skip to content

Commit

Permalink
Deprecate/compute on step (#792)
Browse files Browse the repository at this point in the history
* adjust code
* update docs
* fix tests

Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Feb 8, 2022
1 parent c0e4250 commit 48dc058
Show file tree
Hide file tree
Showing 74 changed files with 511 additions and 220 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Deprecated

- Deprecated argument `compute_on_step` ([#792](https://github.com/PyTorchLightning/metrics/pull/792))


### Removed

Expand Down
2 changes: 1 addition & 1 deletion docs/source/pages/implement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ The cache is first emptied on the next call to ``update``.

``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The ``forward()`` method achieves this by combining calls
to ``update`` and ``compute`` in the following way (assuming metric is initialized with ``compute_on_step=True``):
to ``update`` and ``compute`` in the following way:

1. Calls ``update()`` to update the global metric state (for accumulation over multiple batches)
2. Caches the global state.
Expand Down
20 changes: 14 additions & 6 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,32 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
from torchmetrics.classification import Accuracy
train_accuracy = Accuracy()
valid_accuracy = Accuracy(compute_on_step=False)
valid_accuracy = Accuracy()
for epoch in range(epochs):
for i, (x, y) in enumerate(train_data):
y_hat = model(x)
# training step accuracy
batch_acc = train_accuracy(y_hat, y)
print(f"Accuracy of batch {i} is {batch_acc}")
for x, y in valid_data:
y_hat = model(x)
valid_accuracy(y_hat, y)
valid_accuracy.update(y_hat, y)
# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()
# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()
# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()
# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()
print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")
# Reset metric states after each epoch
train_accuracy.reset()
valid_accuracy.reset()
.. note::

Expand Down
23 changes: 0 additions & 23 deletions tests/bases/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,26 +553,3 @@ def test_compositional_metrics_update():

assert compos.metric_a._num_updates == 3
assert compos.metric_b._num_updates == 3


@pytest.mark.parametrize("compute_on_step", [True, False])
@pytest.mark.parametrize("metric_b", [4, DummyMetric(4)])
def test_compositional_metrics_forward(compute_on_step, metric_b):
"""test forward method of compositional metrics."""
metric_a = DummyMetric(5)
metric_a.compute_on_step = compute_on_step
compos = metric_a + metric_b

assert isinstance(compos, CompositionalMetric)
for _ in range(3):
val = compos()
assert val == 9 if compute_on_step else val is None

assert isinstance(compos.metric_a, DummyMetric)
assert compos.metric_a._num_updates == 3

if isinstance(metric_b, DummyMetric):
assert isinstance(compos.metric_b, DummyMetric)
assert compos.metric_b._num_updates == 3

compos.reset()
2 changes: 1 addition & 1 deletion tests/wrappers/test_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self,
base_metric_class,
num_outputs: int = 1,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Any = None,
dist_sync_fn: Optional[Callable] = None,
Expand Down
54 changes: 36 additions & 18 deletions torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ class BaseAggregator(Metric):
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -59,7 +62,7 @@ def __init__(
fn: Union[Callable, str],
default_value: Union[Tensor, List],
nan_strategy: Union[str, float] = "error",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -120,8 +123,11 @@ class MaxMetric(BaseAggregator):
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -147,7 +153,7 @@ class MaxMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -185,8 +191,11 @@ class MinMetric(BaseAggregator):
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -212,7 +221,7 @@ class MinMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -250,8 +259,11 @@ class SumMetric(BaseAggregator):
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -277,7 +289,7 @@ class SumMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -308,8 +320,11 @@ class CatMetric(BaseAggregator):
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -335,7 +350,7 @@ class CatMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -371,8 +386,11 @@ class MeanMetric(BaseAggregator):
- a float: if a float is provided will impute any `nan` values with this value
compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -398,7 +416,7 @@ class MeanMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down
8 changes: 6 additions & 2 deletions torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ class PerceptualEvaluationSpeechQuality(Metric):
keep_same_device:
whether to move the pesq value to the device of preds
compute_on_step:
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
Expand Down Expand Up @@ -89,7 +93,7 @@ def __init__(
self,
fs: int,
mode: str,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
6 changes: 5 additions & 1 deletion torchmetrics/audio/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class PermutationInvariantTraining(Metric):
or the larger the better.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -78,7 +82,7 @@ def __init__(
self,
metric_func: Callable,
eval_func: str = "max",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
12 changes: 10 additions & 2 deletions torchmetrics/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class SignalDistortionRatio(Metric):
signals may sometimes be zero
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -110,7 +114,7 @@ def __init__(
filter_length: int = 512,
zero_mean: bool = False,
load_diag: Optional[float] = None,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down Expand Up @@ -168,6 +172,10 @@ class ScaleInvariantSignalDistortionRatio(Metric):
if to zero mean target and preds or not
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -206,7 +214,7 @@ class ScaleInvariantSignalDistortionRatio(Metric):
def __init__(
self,
zero_mean: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
12 changes: 10 additions & 2 deletions torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class SignalNoiseRatio(Metric):
if to zero mean target and preds or not
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -77,7 +81,7 @@ class SignalNoiseRatio(Metric):
def __init__(
self,
zero_mean: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down Expand Up @@ -121,6 +125,10 @@ class ScaleInvariantSignalNoiseRatio(Metric):
Args:
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -159,7 +167,7 @@ class ScaleInvariantSignalNoiseRatio(Metric):

def __init__(
self,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
6 changes: 5 additions & 1 deletion torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class ShortTimeObjectiveIntelligibility(Metric):
whether to use the extended STOI described in [4]
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -96,7 +100,7 @@ def __init__(
self,
fs: int,
extended: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
8 changes: 6 additions & 2 deletions torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ class Accuracy(StatScores):
still applies in both cases, if set.
compute_on_step:
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
Forward only calls ``update()`` and returns None if this is set to False.
.. deprecated:: v0.8
Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
Expand Down Expand Up @@ -179,7 +183,7 @@ def __init__(
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
subset_accuracy: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down
Loading

0 comments on commit 48dc058

Please sign in to comment.