Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Remove the requirement for double update in forward #612

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Optimize the Metric forward method to perform only 1 update ([#612](https://github.com/PyTorchLightning/metrics/pull/612))

### Deprecated

Expand Down
83 changes: 77 additions & 6 deletions tests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from tests.helpers import seed_all
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
from torchmetrics import Metric
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6

seed_all(42)
Expand Down Expand Up @@ -169,14 +170,52 @@ class B(DummyListMetric):
assert hash(b1) != hash(b2)


def test_forward():
class A(DummyMetric):
def update(self, x):
self.x += x
class DummyMetricReduce(Metric):
name = "Dummy"

def compute(self):
return self.x
def __init__(self):
super().__init__()
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")

def update(self, x):
x = torch.tensor(x)
self.x += x

def compute(self):
return self.x


class DummyListMetricReduce(Metric):
    """Test metric that collects every input in a list state and sums them.

    The ``x`` state starts as an empty list and is registered with the
    ``"cat"`` reduction.
    """

    name = "DummyList"

    def __init__(self):
        super().__init__()
        self.add_state("x", [], dist_reduce_fx="cat")

    def update(self, x):
        """Wrap ``x`` in a tensor and append it to the ``x`` state."""
        wrapped = torch.tensor(x)
        self.x.append(wrapped)

    def compute(self):
        """Return the sum of everything stored in the ``x`` state."""
        collected = torch.tensor(self.x)
        return collected.sum()


class CustomDummyMetric(Metric):
    """Test metric with a single tensor state ``x`` and no explicit reduction.

    ``add_state`` is called without ``dist_reduce_fx``, so whatever default
    the base class applies is used.
    """

    name = "Dummy"

    def __init__(self):
        super().__init__()
        self.add_state("x", tensor(0.0))

    def update(self, x):
        """Convert ``x`` to a tensor and add it in place to the ``x`` state."""
        self.x += torch.tensor(x)

    def compute(self):
        """Return the sum over the ``x`` state."""
        return self.x.sum()


@pytest.mark.parametrize("A", [CustomDummyMetric, DummyListMetricReduce, DummyMetricReduce])
def test_forward(A):
a = A()
assert a(5) == 5
assert a._forward_cache == 5
Expand All @@ -186,6 +225,38 @@ def compute(self):

assert a.compute() == 13

assert a(1) == 1
assert a._forward_cache == 1

assert a.compute() == 14


class BatchMetric(Metric):
    """Test metric with one tensor state and one list state, both unreduced.

    Both states are registered with ``dist_reduce_fx=None``.
    """

    def __init__(self):
        super().__init__()
        self.add_state("x", tensor(0.0), dist_reduce_fx=None)
        self.add_state("y", [], dist_reduce_fx=None)

    def update(self, x, y):
        """Add ``x`` to the tensor state and append ``y`` to the list state."""
        # out-of-place add: the state is rebound to the result instead of
        # being modified in place
        self.x = self.x + x
        self.y.append(y)

    def compute(self):
        """Return the sum of the tensor state plus the sum of every list entry."""
        list_total = sum(entry.sum() for entry in self.y)
        return self.x.sum() + list_total


@pytest.mark.parametrize("A", [BatchMetric])
def test_forward_batch(A):
    """Check per-call ``forward`` results and the accumulated ``compute`` result.

    Each forward call receives the same 2x2 tensor of ones twice and is
    expected to return 8 for that call alone; after two calls ``compute``
    is expected to return 16.
    """
    metric = A()
    ones = torch.ones(2, 2)

    # two identical forward calls, each checked against the batch-only value
    for _ in range(2):
        assert metric(ones, ones) == 8
        assert metric._forward_cache == 8

    assert metric.compute() == 16


def test_pickle(tmpdir):
# doesn't test for DDP
Expand Down
52 changes: 40 additions & 12 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,33 +198,61 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
"HINT: Did you forget to call ``unsync`` ?."
)

with torch.no_grad():
self.update(*args, **kwargs)
accumulated_state = {}
update_called = False
if self.compute_on_step:
update_called = self._update_called
# the `accumulated_state` should be captured only if an update has already been performed.
# Otherwise, `accumulated_state` would be the default states.
if update_called:
accumulated_state = {attr: getattr(self, attr) for attr in self._defaults.keys()}
self.reset()

self.update(*args, **kwargs)

if self.compute_on_step:
self._to_sync = self.dist_sync_on_step
# skip restore cache operation from compute as cache is stored below.
self._should_unsync = False

# save context before switch
cache = {attr: getattr(self, attr) for attr in self._defaults}

# call reset, update, compute, on single batch
self.reset()
self.update(*args, **kwargs)
self._forward_cache = self.compute()

# restore context
for attr, val in cache.items():
setattr(self, attr, val)
self._is_synced = False
if update_called:
batch_state = {attr: getattr(self, attr) for attr in self._reductions}
with torch.no_grad():
self._reduce_states([batch_state, accumulated_state])

self._is_synced = False
self._should_unsync = True
self._to_sync = True
self._computed = None

return self._forward_cache

def _reduce_states(self, states: List[Dict[str, Union[list, Tensor]]]) -> None:
    """Merge a list of metric states into this metric, one attribute at a time.

    For every attribute registered in ``self._reductions`` the matching values
    are gathered from ``states``, combined into a single container, passed
    through the attribute's reduction function (when there is one) and written
    back onto ``self`` with ``setattr``.

    Args:
        states: List of metric states, each indexable by attribute name.

    Raises:
        TypeError: If a registered reduction is neither callable nor ``None``.
    """
    for attr, reduction_fn in self._reductions.items():
        gathered = [state[attr] for state in states]
        head = gathered[0]

        # the type of the first value alone decides how the values are combined
        if isinstance(head, Tensor):
            gathered = dim_zero_cat(gathered) if head.dim() == 0 else torch.stack(gathered)
        elif isinstance(head, list):
            gathered = _flatten(gathered)

        if reduction_fn is not None and not callable(reduction_fn):
            raise TypeError("reduction_fn must be callable or None")

        # NOTE(review): the attribute is overwritten with whatever the
        # reduction returns, so its type may differ from the original state
        # type - confirm callers expect that.
        if reduction_fn is None:
            reduced = gathered
        else:
            reduced = reduction_fn(gathered)
        setattr(self, attr, reduced)

def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
input_dict = {attr: getattr(self, attr) for attr in self._reductions}

Expand Down