Skip to content

Commit

Permalink
Fix Collection kwargs filtering (#707)
Browse files Browse the repository at this point in the history
* Fix Collection kwargs filtering
* Update test_collections.py
* Update CHANGELOG.md

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 4, 2022
1 parent 2e58596 commit 8dd7d91
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed MetricCollection kwargs filtering when no `kwargs` are present in update signature ([#707](https://github.com/PyTorchLightning/metrics/pull/707))



## [0.6.2] - 2021-12-15
Expand Down
29 changes: 29 additions & 0 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import DummyMetricDiff, DummyMetricSum
from torchmetrics import Metric
from torchmetrics.classification import Accuracy
from torchmetrics.metric_collections import MetricCollection

seed_all(42)
Expand Down Expand Up @@ -249,3 +251,30 @@ def test_collection_check_arg():

with pytest.raises(ValueError, match="Expected input `postfix` to be a string, but got"):
MetricCollection._check_arg(1, "postfix")


def test_collection_filtering():
class DummyMetric(Metric):
def __init__(self):
super().__init__()

def update(self, *args, kwarg):
print("Entered DummyMetric")

def compute(self):
return

class MyAccuracy(Metric):
def __init__(self):
super().__init__()

def update(self, preds, target, kwarg2):
print("Entered MyAccuracy")

def compute(self):
return

mc = MetricCollection([Accuracy(), DummyMetric()])
mc2 = MetricCollection([MyAccuracy(), DummyMetric()])
mc(torch.tensor([0, 1]), torch.tensor([0, 1]), kwarg="kwarg")
mc2(torch.tensor([0, 1]), torch.tensor([0, 1]), kwarg="kwarg", kwarg2="kwarg2")
10 changes: 8 additions & 2 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,14 @@ def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
k: v for k, v in kwargs.items() if (k in _sign_params.keys() and _sign_params[k].kind not in _params)
}

# if no kwargs filtered, return al kwargs as default
if not filtered_kwargs:
exists_var_keyword = any([v.kind == inspect.Parameter.VAR_KEYWORD for v in _sign_params.values()])
# if no kwargs filtered, return all kwargs as default
if not filtered_kwargs and not exists_var_keyword:
# no kwargs in update signature -> don't return any kwargs
filtered_kwargs = {}
elif exists_var_keyword:
# kwargs found in update signature -> return all kwargs to be sure to not omit any.
# filtering logic is likely implemented within the update call.
filtered_kwargs = kwargs
return filtered_kwargs

Expand Down

0 comments on commit 8dd7d91

Please sign in to comment.