Skip to content

Commit

Permalink
Merge branch 'master' into feature/bleuscore-weights
Browse files Browse the repository at this point in the history
  • Loading branch information
stancld authored Jun 7, 2022
2 parents f6c5cd4 + c04090b commit 8d8fb19
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 1 deletion.
18 changes: 18 additions & 0 deletions tests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from tests.helpers import seed_all
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
from tests.helpers.utilities import no_warning_call
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6

seed_all(42)
Expand Down Expand Up @@ -423,3 +424,20 @@ class UnsetProperty(metric_class):
match="Torchmetrics v0.9 introduced a new argument class property called.*",
):
UnsetProperty()


@pytest.mark.parametrize("metric_class", [DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum])
def test_no_warning_on_custom_forward(metric_class):
"""If metric is using custom forward, full_state_update is irrelevant."""

class UnsetProperty(metric_class):
full_state_update = None

def forward(self, *args, **kwargs):
self.update(*args, **kwargs)

with no_warning_call(
UserWarning,
match="Torchmetrics v0.9 introduced a new argument class property called.*",
):
UnsetProperty()
40 changes: 40 additions & 0 deletions tests/helpers/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from contextlib import contextmanager
from typing import Optional, Type

import pytest


@contextmanager
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
with pytest.warns(None) as record:
yield

if match is None:
try:
w = record.pop(expected_warning)
except AssertionError:
# no warning raised
return
else:
for w in record.list:
if w.category is expected_warning and re.compile(match).search(w.message.args[0]):
break
else:
return

msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
raise AssertionError(f"{msg} was raised: {w}")
3 changes: 2 additions & 1 deletion torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torch.nn import Module

from torchmetrics.utilities import apply_to_collection, rank_zero_warn
from torchmetrics.utilities.checks import is_overridden
from torchmetrics.utilities.data import (
_flatten,
_squeeze_if_scalar,
Expand Down Expand Up @@ -127,7 +128,7 @@ def __init__(
self._is_synced = False
self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None

if self.full_state_update is None:
if self.full_state_update is None and not is_overridden("forward", self, Metric):
rank_zero_warn(
f"""Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has
not been set for this class ({self.__class__.__name__}). The property determines if `update` by
Expand Down
27 changes: 27 additions & 0 deletions torchmetrics/utilities/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from time import perf_counter
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, no_type_check
from unittest.mock import Mock

import torch
from torch import Tensor
Expand Down Expand Up @@ -723,3 +725,28 @@ class PartState(metric_class):

faster = (mean[1, -1] < mean[0, -1]).item() # if faster on average, we recommend upgrading
print(f"Recommended setting `full_state_update={not faster}`")


def is_overridden(method_name: str, instance: object, parent: object) -> bool:
"""Check if a method has been overridden by an instance compared to its parent class."""
instance_attr = getattr(instance, method_name, None)
if instance_attr is None:
return False
# `functools.wraps()` support
if hasattr(instance_attr, "__wrapped__"):
instance_attr = instance_attr.__wrapped__
# `Mock(wraps=...)` support
if isinstance(instance_attr, Mock):
# access the wrapped function
instance_attr = instance_attr._mock_wraps
# `partial` support
elif isinstance(instance_attr, partial):
instance_attr = instance_attr.func
if instance_attr is None:
return False

parent_attr = getattr(parent, method_name, None)
if parent_attr is None:
raise ValueError("The parent should define the method")

return instance_attr.__code__ != parent_attr.__code__

0 comments on commit 8d8fb19

Please sign in to comment.