Refactor/remove double forward #984

Merged: 50 commits from refactor/remove_double_forward into master on May 10, 2022
Changes shown are from 11 of the 50 commits.

Commits (50)
1a89f1d
something is working
SkafteNicki Apr 20, 2022
d18b411
better
SkafteNicki Apr 21, 2022
ce221a3
fix tests
SkafteNicki Apr 25, 2022
335021b
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 25, 2022
969da72
docstring
SkafteNicki Apr 25, 2022
d72dfa4
update docs
SkafteNicki Apr 25, 2022
77729b7
fix
SkafteNicki Apr 25, 2022
25fd83f
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 26, 2022
f46cfa0
fix issue
SkafteNicki Apr 26, 2022
30c356a
Merge branch 'refactor/remove_double_forward' of https://github.com/P…
SkafteNicki Apr 26, 2022
b411548
fix some tests
SkafteNicki Apr 26, 2022
285cba4
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 27, 2022
4f9d75b
introduce class property
SkafteNicki Apr 28, 2022
e2cca16
docs
SkafteNicki Apr 28, 2022
279a83c
Merge branch 'refactor/remove_double_forward' of https://github.com/P…
SkafteNicki Apr 28, 2022
967bf30
changelog
SkafteNicki Apr 28, 2022
ddc237b
fix docs
SkafteNicki Apr 28, 2022
ff011b8
update docs
SkafteNicki Apr 28, 2022
2abc26e
rename and re-order
SkafteNicki Apr 28, 2022
398521a
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 28, 2022
99ff43a
fix list
SkafteNicki Apr 28, 2022
3b0c8cf
Merge branch 'refactor/remove_double_forward' of https://github.com/P…
SkafteNicki Apr 28, 2022
52a4615
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 28, 2022
0ba7c37
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 28, 2022
ff75e70
change impl
SkafteNicki May 4, 2022
274078f
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki May 4, 2022
2eff3b6
fix tests
SkafteNicki May 4, 2022
87fa28f
regression
SkafteNicki May 4, 2022
fe58dc3
fix typing
SkafteNicki May 4, 2022
0670117
fix lightning integration
SkafteNicki May 4, 2022
08299ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 4, 2022
416b057
remove from test metric
SkafteNicki May 4, 2022
90c2254
update count
SkafteNicki May 5, 2022
7058907
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki May 5, 2022
2790547
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 5, 2022
15a4616
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 5, 2022
a16b27b
Merge branch 'master' into refactor/remove_double_forward
Borda May 5, 2022
a310496
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 5, 2022
5d4977f
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 6, 2022
708c848
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki May 6, 2022
c086f2c
audio
SkafteNicki May 6, 2022
5e8ebf1
fix mean reduction
SkafteNicki May 7, 2022
2ef10a6
refactor forward
SkafteNicki May 7, 2022
00d7fe8
classification
SkafteNicki May 7, 2022
aebff09
detection
SkafteNicki May 7, 2022
fcc5461
image
SkafteNicki May 7, 2022
3c5cacf
retrieval
SkafteNicki May 7, 2022
050cf86
text
SkafteNicki May 7, 2022
b8a7e17
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 7, 2022
fc1e2c4
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 8, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Reimplemented the `signal_distortion_ratio` metric, which removed the absolute requirement of `fast-bss-eval` ([#964](https://github.com/PyTorchLightning/metrics/pull/964))


-
- Changed `forward` method to only call `update` a single time instead of twice ([#984](https://github.com/PyTorchLightning/metrics/pull/984))


### Deprecated
17 changes: 6 additions & 11 deletions docs/source/pages/implement.rst
@@ -64,18 +64,13 @@ The cache is first emptied on the next call to ``update``.

``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The ``forward()`` method achieves this by combining calls
to ``update`` and ``compute`` in the following way:

1. Calls ``update()`` to update the global metric state (for accumulation over multiple batches)
2. Caches the global state.
3. Calls ``reset()`` to clear global metric state.
4. Calls ``update()`` to update local metric state.
5. Calls ``compute()`` to calculate metric for current batch.
6. Restores the global state.

This procedure has the consequence of calling the user defined ``update`` **twice** during a single
forward call (one to update global statistics and one for getting the batch statistics).
to ``update``, ``compute`` and ``reset`` in the following way:

1. Caches the global state.
2. Calls ``reset`` to revert the metric to its default state.
3. Calls ``update`` to update the state with the local batch statistics.
4. Calls ``compute`` to calculate the metric for the current batch.
5. Reduces the global state and the batch state into a single state that becomes the new global state (see the sketch below).

---------

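The new forward flow documented above can be sketched end to end with a toy metric. This is a minimal illustration only; the `MeanValue` class and its states are hypothetical and not part of this PR:

```python
import torch
from torchmetrics import Metric


class MeanValue(Metric):
    """Toy metric: running mean of all elements seen so far."""

    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.total += x.sum()
        self.count += x.numel()

    def compute(self) -> torch.Tensor:
        return self.total / self.count


metric = MeanValue()

# Each forward call now runs update() exactly once on the batch: it caches the
# global state, resets, updates/computes on the batch, then reduces the batch
# state back into the cached global state.
batch_1 = metric(torch.tensor([1.0, 2.0, 3.0]))  # tensor(2.)  for this batch only
batch_2 = metric(torch.tensor([5.0, 7.0]))       # tensor(6.)  for this batch only
overall = metric.compute()                       # tensor(3.6) accumulated over both batches
```

Before this PR, the same `forward` call would have run the user-defined `update` twice: once for the global state and once more for the batch statistics.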
4 changes: 2 additions & 2 deletions tests/helpers/testers.py
@@ -569,7 +569,7 @@ class DummyMetric(Metric):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("x", tensor(0.0), dist_reduce_fx=None)
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")

def update(self):
pass
@@ -583,7 +583,7 @@ class DummyListMetric(Metric):

def __init__(self):
super().__init__()
self.add_state("x", [], dist_reduce_fx=None)
self.add_state("x", [], dist_reduce_fx="cat")

def update(self):
pass
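The dummy metrics above now register their states with explicit reductions because the refactored `forward` reuses `dist_reduce_fx` to merge the cached global state with the fresh batch state. A minimal sketch under that assumption; the `SumAndHistory` metric below is illustrative and not part of the test suite:

```python
import torch
from torchmetrics import Metric


class SumAndHistory(Metric):
    """Illustrative metric with one summed state and one concatenated list state."""

    def __init__(self):
        super().__init__()
        self.add_state("running_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("values", default=[], dist_reduce_fx="cat")

    def update(self, x: torch.Tensor) -> None:
        self.running_sum += x.sum()
        self.values.append(x)

    def compute(self) -> torch.Tensor:
        # The list state may already be concatenated into a single tensor.
        values = torch.cat(self.values) if isinstance(self.values, list) else self.values
        return self.running_sum / values.numel()


metric = SumAndHistory()
metric(torch.ones(4))            # forward: updates both states exactly once
metric(2 * torch.ones(4))        # second batch is merged into the global state
assert metric.compute() == 1.5   # (4 + 8) / 8
```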
4 changes: 2 additions & 2 deletions torchmetrics/classification/roc.py
@@ -114,8 +114,8 @@ def __init__(
self.num_classes = num_classes
self.pos_label = pos_label

self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

rank_zero_warn(
"Metric `ROC` will save all targets and predictions in buffer."
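For `ROC` the change means the buffered `preds`/`target` lists are concatenated rather than left unreduced when the global and batch states are merged. A usage sketch, assuming the binary `ROC` interface of the torchmetrics 0.8 series (shapes and batch count are illustrative):

```python
import torch
from torchmetrics import ROC

roc = ROC(pos_label=1)
for _ in range(3):                                   # a few toy batches
    preds = torch.rand(8)                            # predicted probabilities
    target = torch.randint(0, 2, (8,))               # binary labels
    batch_fpr, batch_tpr, batch_thr = roc(preds, target)  # forward: this batch only

fpr, tpr, thresholds = roc.compute()                 # curve over all accumulated batches
```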
2 changes: 1 addition & 1 deletion torchmetrics/classification/stat_scores.py
@@ -151,7 +151,7 @@ def __init__(
raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")

default: Callable = lambda: []
reduce_fn: Optional[str] = None
reduce_fn: Optional[str] = "cat"
if mdmc_reduce != "samplewise" and reduce != "samples":
if reduce == "micro":
zeros_shape = []
4 changes: 2 additions & 2 deletions torchmetrics/image/psnr.py
@@ -83,8 +83,8 @@ def __init__(
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
else:
self.add_state("sum_squared_error", default=[])
self.add_state("total", default=[])
self.add_state("sum_squared_error", default=[], dist_reduce_fx="cat")
self.add_state("total", default=[], dist_reduce_fx="cat")

if data_range is None:
if dim is not None:
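The PSNR states follow the same pattern: with `dim` set, the per-sample error terms are stored in lists that now carry an explicit `"cat"` reduction. A usage sketch, assuming the `PeakSignalNoiseRatio` interface of the 0.8 series (image sizes are illustrative):

```python
import torch
from torchmetrics import PeakSignalNoiseRatio

psnr = PeakSignalNoiseRatio(data_range=1.0)
preds = torch.rand(4, 3, 32, 32)    # a batch of predicted images in [0, 1]
target = torch.rand(4, 3, 32, 32)   # the corresponding ground-truth images

batch_psnr = psnr(preds, target)    # forward: PSNR of this batch only
epoch_psnr = psnr.compute()         # PSNR accumulated over every forward call so far
```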
62 changes: 43 additions & 19 deletions torchmetrics/metric.py
@@ -197,41 +197,38 @@ def add_state(

@torch.jit.unused
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Automatically calls ``update()``.
"""``forward`` serves the dual purpose of both computing the metric on the current batch of inputs but also
add the batch statistics to the overall accumululating metric state.

Returns the metric value over inputs if ``compute_on_step`` is True.
Input arguments are the exact same as those of the corresponding ``update`` method. The returned output is the
exact same as the output of ``compute``.
"""
# add current step
# check if states are already synced
if self._is_synced:
raise TorchMetricsUserError(
"The Metric shouldn't be synced when performing ``update``. "
"The Metric shouldn't be synced when performing ``forward``. "
"HINT: Did you forget to call ``unsync`` ?."
)
# store global state and reset to default
global_state = {attr: getattr(self, attr) for attr in self._defaults.keys()}
self.reset()

# global accumulation
self.update(*args, **kwargs)

self._to_sync = self.dist_sync_on_step # type: ignore
# skip restore cache operation from compute as cache is stored below.
# local synchronization settings
self._to_sync = self.dist_sync_on_step
self._should_unsync = False
# skip computing on cpu for the batch
_temp_compute_on_cpu = self.compute_on_cpu
self.compute_on_cpu = False

# save context before switch
cache = {attr: getattr(self, attr) for attr in self._defaults}

# call reset, update, compute, on single batch
self._enable_grad = True # allow grads for batch computation
self.reset()

# calculate batch state and compute batch value
self.update(*args, **kwargs)
self._forward_cache = self.compute()

# reduce batch and global state
self._reduce_state(global_state)

# restore context
for attr, val in cache.items():
setattr(self, attr, val)
self._is_synced = False

self._should_unsync = True
self._to_sync = True
self._computed = None
@@ -240,6 +237,33 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:

return self._forward_cache

def _reduce_state(self, state_to_reduce: Dict[str, Any]) -> None:
for attr in self._defaults.keys():
current_state = getattr(self, attr)
incoming_state = state_to_reduce[attr]
reduce_fn = self._reductions[attr]
if reduce_fn == dim_zero_sum:
reduced = current_state + incoming_state
elif reduce_fn == dim_zero_mean:
reduced = (current_state + incoming_state) / 2.0
elif reduce_fn == dim_zero_max:
reduced = torch.max(current_state, incoming_state)
elif reduce_fn == dim_zero_min:
reduced = torch.min(current_state, incoming_state)
elif reduce_fn == dim_zero_cat: # or (reduce_fn is None and isinstance(current_state, list)):
reduced = incoming_state + current_state
elif reduce_fn is None and isinstance(current_state, Tensor):
if incoming_state.ndim > 0: # TODO: figure out why this works
reduced = torch.cat([current_state.expand(1), incoming_state], dim=0)
else:
reduced = torch.stack([current_state, incoming_state], dim=0)
elif reduce_fn is None and isinstance(current_state, list):
reduced = _flatten([current_state, incoming_state])
else:
reduced = reduce_fn(torch.stack([current_state, incoming_state])) # type: ignore

setattr(self, attr, reduced)

def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
input_dict = {attr: getattr(self, attr) for attr in self._reductions}

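Conceptually, `_reduce_state` merges the cached global state with the state produced by the current batch, using the same reduction that `add_state` registered. A standalone sketch of those merge semantics (illustration only, not the exact helper from `metric.py`):

```python
import torch


def merge_states(global_state, batch_state, reduce_fx: str):
    """Merge a cached global state with a fresh batch state (illustrative)."""
    if reduce_fx == "sum":
        return global_state + batch_state           # running totals add up
    if reduce_fx == "mean":
        return (global_state + batch_state) / 2.0   # unweighted average of the two states
    if reduce_fx == "max":
        return torch.max(global_state, batch_state)
    if reduce_fx == "min":
        return torch.min(global_state, batch_state)
    if reduce_fx == "cat":
        return global_state + batch_state           # list concatenation, global entries first
    raise ValueError(f"Unsupported reduction: {reduce_fx}")


# After a forward call: the cached global total absorbs the batch total.
new_total = merge_states(torch.tensor(10.0), torch.tensor(4.0), "sum")  # tensor(14.)
new_preds = merge_states([torch.rand(8)], [torch.rand(8)], "cat")       # list of two tensors
```

Note that the `mean` branch in the diff averages the two states with equal weight, so batches of different sizes are not weighted by their element counts.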