Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate/compute on step #792

Merged
merged 20 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Deprecated

- Deprecated method `compute_on_step` ([#792](https://github.com/PyTorchLightning/metrics/pull/792))


### Removed

Expand Down
2 changes: 1 addition & 1 deletion docs/source/pages/implement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ The cache is first emptied on the next call to ``update``.

``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The ``forward()`` method achieves this by combining calls
to ``update`` and ``compute`` in the following way (assuming metric is initialized with ``compute_on_step=True``):
to ``update`` and ``compute`` in the following way:

1. Calls ``update()`` to update the global metric state (for accumulation over multiple batches)
2. Caches the global state.
Expand Down
20 changes: 14 additions & 6 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,32 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
from torchmetrics.classification import Accuracy

train_accuracy = Accuracy()
valid_accuracy = Accuracy(compute_on_step=False)
valid_accuracy = Accuracy()

for epoch in range(epochs):
for x, y in train_data:
y_hat = model(x)

# training step accuracy
batch_acc = train_accuracy(y_hat, y)
print(f"Accuracy of batch{i} is {batch_acc}")

for x, y in valid_data:
y_hat = model(x)
valid_accuracy(y_hat, y)
valid_accuracy.update(y_hat, y)

# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()
# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()

# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()
# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()

print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")

# Reset metric states after each epoch
train_accuracy.reset()
valid_accuracy.reset()

.. note::

Expand Down
23 changes: 0 additions & 23 deletions tests/bases/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,26 +553,3 @@ def test_compositional_metrics_update():

assert compos.metric_a._num_updates == 3
assert compos.metric_b._num_updates == 3


@pytest.mark.parametrize("compute_on_step", [True, False])
@pytest.mark.parametrize("metric_b", [4, DummyMetric(4)])
def test_compositional_metrics_forward(compute_on_step, metric_b):
"""test forward method of compositional metrics."""
metric_a = DummyMetric(5)
metric_a.compute_on_step = compute_on_step
compos = metric_a + metric_b

assert isinstance(compos, CompositionalMetric)
for _ in range(3):
val = compos()
assert val == 9 if compute_on_step else val is None

assert isinstance(compos.metric_a, DummyMetric)
assert compos.metric_a._num_updates == 3

if isinstance(metric_b, DummyMetric):
assert isinstance(compos.metric_b, DummyMetric)
assert compos.metric_b._num_updates == 3

compos.reset()
2 changes: 1 addition & 1 deletion tests/wrappers/test_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self,
base_metric_class,
num_outputs: int = 1,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Any = None,
dist_sync_fn: Optional[Callable] = None,
Expand Down
54 changes: 36 additions & 18 deletions torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ class BaseAggregator(Metric):
- a float: if a float is provided will impude any `nan` values with this value

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -59,7 +62,7 @@ def __init__(
fn: Union[Callable, str],
default_value: Union[Tensor, List],
nan_strategy: Union[str, float] = "error",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -120,8 +123,11 @@ class MaxMetric(BaseAggregator):
- a float: if a float is provided will impude any `nan` values with this value

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -147,7 +153,7 @@ class MaxMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -185,8 +191,11 @@ class MinMetric(BaseAggregator):
- a float: if a float is provided will impude any `nan` values with this value

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -212,7 +221,7 @@ class MinMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -250,8 +259,11 @@ class SumMetric(BaseAggregator):
- a float: if a float is provided will impude any `nan` values with this value

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -277,7 +289,7 @@ class SumMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -308,8 +320,11 @@ class CatMetric(BaseAggregator):
- a float: if a float is provided will impude any `nan` values with this value

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -335,7 +350,7 @@ class CatMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -371,8 +386,11 @@ class MeanMetric(BaseAggregator):
- a float: if a float is provided will impude any `nan` values with this value

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -398,7 +416,7 @@ class MeanMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down
8 changes: 6 additions & 2 deletions torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ class PerceptualEvaluationSpeechQuality(Metric):
keep_same_device:
whether to move the pesq value to the device of preds
compute_on_step:
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
Expand Down Expand Up @@ -89,7 +93,7 @@ def __init__(
self,
fs: int,
mode: str,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
6 changes: 5 additions & 1 deletion torchmetrics/audio/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class PermutationInvariantTraining(Metric):
or the larger the better.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -78,7 +82,7 @@ def __init__(
self,
metric_func: Callable,
eval_func: str = "max",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
12 changes: 10 additions & 2 deletions torchmetrics/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class SignalDistortionRatio(Metric):
signals may sometimes be zero
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -110,7 +114,7 @@ def __init__(
filter_length: int = 512,
zero_mean: bool = False,
load_diag: Optional[float] = None,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down Expand Up @@ -168,6 +172,10 @@ class ScaleInvariantSignalDistortionRatio(Metric):
if to zero mean target and preds or not
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -206,7 +214,7 @@ class ScaleInvariantSignalDistortionRatio(Metric):
def __init__(
self,
zero_mean: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
12 changes: 10 additions & 2 deletions torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class SignalNoiseRatio(Metric):
if to zero mean target and preds or not
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -77,7 +81,7 @@ class SignalNoiseRatio(Metric):
def __init__(
self,
zero_mean: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down Expand Up @@ -121,6 +125,10 @@ class ScaleInvariantSignalNoiseRatio(Metric):
Args:
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -159,7 +167,7 @@ class ScaleInvariantSignalNoiseRatio(Metric):

def __init__(
self,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
6 changes: 5 additions & 1 deletion torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class ShortTimeObjectiveIntelligibility(Metric):
whether to use the extended STOI described in [4]
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -96,7 +100,7 @@ def __init__(
self,
fs: int,
extended: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
8 changes: 6 additions & 2 deletions torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ class Accuracy(StatScores):
still applies in both cases, if set.

compute_on_step:
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
Forward only calls ``update()`` and returns None if this is set to False.

.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
Expand Down Expand Up @@ -179,7 +183,7 @@ def __init__(
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
subset_accuracy: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down
Loading