Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate/compute on step #792

Merged
merged 20 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/pages/implement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ The cache is first emptied on the next call to ``update``.

``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The ``forward()`` method achieves this by combining calls
to ``update`` and ``compute`` in the following way (assuming metric is initialized with ``compute_on_step=True``):
to ``update`` and ``compute`` in the following way:

1. Calls ``update()`` to update the global metric state (for accumulation over multiple batches)
2. Caches the global state.
Expand Down
20 changes: 14 additions & 6 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,32 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
from torchmetrics.classification import Accuracy

train_accuracy = Accuracy()
valid_accuracy = Accuracy(compute_on_step=False)
valid_accuracy = Accuracy()

for epoch in range(epochs):
for i, (x, y) in enumerate(train_data):
y_hat = model(x)

# training step accuracy
batch_acc = train_accuracy(y_hat, y)
print(f"Accuracy of batch {i} is {batch_acc}")

for x, y in valid_data:
y_hat = model(x)
valid_accuracy(y_hat, y)
valid_accuracy.update(y_hat, y)

# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()
# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()

# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()
# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()

print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")

# Reset metric states after each epoch
train_accuracy.reset()
valid_accuracy.reset()

.. note::

Expand Down
2 changes: 1 addition & 1 deletion tests/wrappers/test_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self,
base_metric_class,
num_outputs: int = 1,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Any = None,
dist_sync_fn: Optional[Callable] = None,
Expand Down
30 changes: 18 additions & 12 deletions torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class BaseAggregator(Metric):

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -59,7 +60,7 @@ def __init__(
fn: Union[Callable, str],
default_value: Union[Tensor, List],
nan_strategy: Union[str, float] = "error",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -121,7 +122,8 @@ class MaxMetric(BaseAggregator):

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -147,7 +149,7 @@ class MaxMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -186,7 +188,8 @@ class MinMetric(BaseAggregator):

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -212,7 +215,7 @@ class MinMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -251,7 +254,8 @@ class SumMetric(BaseAggregator):

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -277,7 +281,7 @@ class SumMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -309,7 +313,8 @@ class CatMetric(BaseAggregator):

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -335,7 +340,7 @@ class CatMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down Expand Up @@ -372,7 +377,8 @@ class MeanMetric(BaseAggregator):

compute_on_step:
Forward only calls ``update()`` and returns None if this is
set to False.
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -398,7 +404,7 @@ class MeanMetric(BaseAggregator):
def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down
8 changes: 5 additions & 3 deletions torchmetrics/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ class PerceptualEvaluationSpeechQuality(Metric):
keep_same_device:
whether to move the pesq value to the device of preds
compute_on_step:
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
Forward only calls ``update()`` and returns None if this is
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
Expand Down Expand Up @@ -91,7 +93,7 @@ def __init__(
self,
fs: int,
mode: str,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down Expand Up @@ -160,7 +162,7 @@ def __init__(
self,
fs: int,
mode: str,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
8 changes: 5 additions & 3 deletions torchmetrics/audio/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class PermutationInvariantTraining(Metric):
the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better
or the larger the better.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
Forward only calls ``update()`` and returns None if this is
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -80,7 +82,7 @@ def __init__(
self,
metric_func: Callable,
eval_func: str = "max",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down Expand Up @@ -143,7 +145,7 @@ def __init__(
self,
metric_func: Callable,
eval_func: str = "max",
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
14 changes: 9 additions & 5 deletions torchmetrics/audio/sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ class SignalDistortionRatio(Metric):
This can help stabilize the metric in the case where some of the reference
signals may sometimes be zero
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
Forward only calls ``update()`` and returns None if this is
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -112,7 +114,7 @@ def __init__(
filter_length: int = 512,
zero_mean: bool = False,
load_diag: Optional[float] = None,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down Expand Up @@ -183,7 +185,7 @@ def __init__(
filter_length: int = 512,
zero_mean: bool = False,
load_diag: Optional[float] = None,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down Expand Up @@ -213,7 +215,9 @@ class ScaleInvariantSignalDistortionRatio(Metric):
zero_mean:
if to zero mean target and preds or not
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
Forward only calls ``update()`` and returns None if this is
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -252,7 +256,7 @@ class ScaleInvariantSignalDistortionRatio(Metric):
def __init__(
self,
zero_mean: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/audio/si_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class SI_SDR(ScaleInvariantSignalDistortionRatio):
def __init__(
self,
zero_mean: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/audio/si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SI_SNR(ScaleInvariantSignalNoiseRatio):
@deprecated(target=ScaleInvariantSignalNoiseRatio, deprecated_in="0.7", remove_in="0.8", stream=_future_warning)
def __init__(
self,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
14 changes: 9 additions & 5 deletions torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class SignalNoiseRatio(Metric):
zero_mean:
if to zero mean target and preds or not
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
Forward only calls ``update()`` and returns None if this is
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -79,7 +81,7 @@ class SignalNoiseRatio(Metric):
def __init__(
self,
zero_mean: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down Expand Up @@ -130,7 +132,7 @@ class SNR(SignalNoiseRatio):
def __init__(
self,
zero_mean: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand All @@ -148,7 +150,9 @@ class ScaleInvariantSignalNoiseRatio(Metric):

Args:
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
Forward only calls ``update()`` and returns None if this is
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -187,7 +191,7 @@ class ScaleInvariantSignalNoiseRatio(Metric):

def __init__(
self,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
8 changes: 5 additions & 3 deletions torchmetrics/audio/stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ class ShortTimeObjectiveIntelligibility(Metric):
extended:
whether to use the extended STOI described in [4]
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
Forward only calls ``update()`` and returns None if this is
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand Down Expand Up @@ -98,7 +100,7 @@ def __init__(
self,
fs: int,
extended: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down Expand Up @@ -160,7 +162,7 @@ def __init__(
self,
fs: int,
extended: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
Expand Down
6 changes: 4 additions & 2 deletions torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ class Accuracy(StatScores):
still applies in both cases, if set.

compute_on_step:
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
Forward only calls ``update()`` and returns None if this is
set to False. This argument has no effect anymore. Deprecated in v0.8 and
will be removed in v0.9.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
Expand Down Expand Up @@ -179,7 +181,7 @@ def __init__(
top_k: Optional[int] = None,
multiclass: Optional[bool] = None,
subset_accuracy: bool = False,
compute_on_step: bool = True,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
Expand Down
Loading