From bc5f9a95439defc1ad54da0cb036a9845d2e6bee Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 19 Apr 2021 10:56:58 +0200 Subject: [PATCH] allow MetricCollection with args (#176) * allow MetricCollection with args * format * chlog --- CHANGELOG.md | 1 + tests/bases/test_average.py | 3 + tests/bases/test_collections.py | 12 ++-- torchmetrics/average.py | 4 +- .../classification/binned_precision_recall.py | 11 ++-- torchmetrics/collections.py | 60 +++++++++++++------ 6 files changed, 58 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f02b5d16dbb..53e746c0053 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Updated FBeta arguments ([#111](https://github.com/PyTorchLightning/metrics/pull/111)) - Changed `reset` method to use `detach.clone()` instead of `deepcopy` when resetting to default ([#163](https://github.com/PyTorchLightning/metrics/pull/163)) - Metrics passed as dict to `MetricCollection` will now always be in deterministic order ([#173](https://github.com/PyTorchLightning/metrics/pull/173)) +- Allowed `MetricCollection` to accept metrics passed as positional arguments ([#176](https://github.com/PyTorchLightning/metrics/pull/176)) ### Deprecated diff --git a/tests/bases/test_average.py b/tests/bases/test_average.py index 9c84caf8ddc..164e6845292 100644 --- a/tests/bases/test_average.py +++ b/tests/bases/test_average.py @@ -15,11 +15,13 @@ def average_ignore_weights(values, weights): class DefaultWeightWrapper(AverageMeter): + def update(self, values, weights): super().update(values) class ScalarWrapper(AverageMeter): + def update(self, values, weights): # torch.ravel is PyTorch 1.8 only, so use np.ravel instead values = values.cpu().numpy() @@ -37,6 +39,7 @@ def update(self, values, weights): ], ) class TestAverageMeter(MetricTester): + @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [False, True]) def 
test_average_fn(self, ddp, dist_sync_on_step, values, weights): diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index b9f8e6955b7..37c8cfbac50 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -85,23 +85,23 @@ def test_device_and_dtype_transfer_metriccollection(tmpdir): def test_metric_collection_wrong_input(tmpdir): """ Check that errors are raised on wrong input """ - m1 = DummyMetricSum() + dms = DummyMetricSum() # Not all input are metrics (list) with pytest.raises(ValueError): - _ = MetricCollection([m1, 5]) + _ = MetricCollection([dms, 5]) # Not all input are metrics (dict) with pytest.raises(ValueError): - _ = MetricCollection({'metric1': m1, 'metric2': 5}) + _ = MetricCollection({'metric1': dms, 'metric2': 5}) # Same metric passed in multiple times with pytest.raises(ValueError, match='Encountered two metrics both named *.'): - _ = MetricCollection([m1, m1]) + _ = MetricCollection([dms, dms]) # Not a list or dict passed in - with pytest.raises(ValueError, match='Unknown input to MetricCollection.'): - _ = MetricCollection(m1) + with pytest.warns(Warning, match=' which are not `Metric` so they will be ignored.'): + _ = MetricCollection(dms, [dms]) def test_metric_collection_args_kwargs(tmpdir): diff --git a/torchmetrics/average.py b/torchmetrics/average.py index 98621fa87e8..c13cb60ac64 100644 --- a/torchmetrics/average.py +++ b/torchmetrics/average.py @@ -78,9 +78,7 @@ def __init__( # TODO: need to be strings because Unions are not pickleable in Python 3.6 def update( # type: ignore - self, - value: "Union[Tensor, float]", - weight: "Union[Tensor, float]" = 1.0 + self, value: "Union[Tensor, float]", weight: "Union[Tensor, float]" = 1.0 ) -> None: """Updates the average with. 
diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index eff2bd3998c..702e182fd56 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -157,13 +157,10 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]: recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS) # Need to guarantee that last precision=1 and recall=0, similar to precision_recall_curve - precisions = torch.cat([ - precisions, torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device) - ], - dim=1) - recalls = torch.cat([recalls, - torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)], - dim=1) + t_ones = torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device) + precisions = torch.cat([precisions, t_ones], dim=1) + t_zeros = torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device) + recalls = torch.cat([recalls, t_zeros], dim=1) if self.num_classes == 1: return (precisions[0, :], recalls[0, :], self.thresholds) else: diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index abfceaa1280..80234738706 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -13,28 +13,29 @@ # limitations under the License. from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Union from torch import nn from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn class MetricCollection(nn.ModuleDict): """ - MetricCollection class can be used to chain metrics that have the same - call pattern into one single class. + MetricCollection class can be used to chain metrics that have the same call pattern into one single class. 
Args: metrics: One of the following - * list or tuple: if metrics are passed in as a list, will use the - metrics class name as key for output dict. Therefore, two metrics - of the same class cannot be chained this way. + * list or tuple (sequence): if metrics are passed in as a list or tuple, will use the metrics class name + as key for output dict. Therefore, two metrics of the same class cannot be chained this way. - * dict: if metrics are passed in as a dict, will use each key in the - dict as key for output dict. Use this format if you want to chain - together multiple of the same metric with different parameters. + * arguments: similar to passing in as a list, metrics passed in as arguments will use their metric + class name as key for the output dict. + + * dict: if metrics are passed in as a dict, will use each key in the dict as key for output dict. + Use this format if you want to chain together multiple of the same metric with different parameters. Note that the keys in the output dict will be sorted alphabetically. prefix: a string to append in front of the keys of the output dict @@ -46,6 +47,8 @@ class MetricCollection(nn.ModuleDict): If two elements in ``metrics`` have the same ``name``. ValueError: If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``. + ValueError: + If ``metrics`` is a ``dict`` and any additional_metrics are passed. Example (input as list): >>> import torch @@ -59,6 +62,12 @@ class MetricCollection(nn.ModuleDict): >>> metrics(preds, target) {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} + Example (input as arguments): + >>> metrics = MetricCollection(Accuracy(), Precision(num_classes=3, average='macro'), + ... Recall(num_classes=3, average='macro')) + >>> metrics(preds, target) + {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} + Example (input as dict): >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), ... 
'macro_recall': Recall(num_classes=3, average='macro')}) @@ -73,10 +82,31 @@ class MetricCollection(nn.ModuleDict): def __init__( self, - metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]], + metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], + *additional_metrics: Metric, prefix: Optional[str] = None, ): super().__init__() + if isinstance(metrics, Metric): + # set compatible with original type expectations + metrics = [metrics] + if isinstance(metrics, Sequence): + # prepare for optional additions + metrics = list(metrics) + remain = [] + for m in additional_metrics: + (metrics if isinstance(m, Metric) else remain).append(m) + + if remain: + rank_zero_warn( + f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored." + ) + elif additional_metrics: + raise ValueError( + f"You have passes extra arguments {additional_metrics} which are not compatible" + f" with first passed dictionary {metrics} so they will be ignored." + ) + if isinstance(metrics, dict): # Check all values are metrics # Make sure that metrics are added in deterministic order @@ -84,17 +114,13 @@ def __init__( metric = metrics[name] if not isinstance(metric, Metric): raise ValueError( - f"Value {metric} belonging to key {name}" - " is not an instance of `pl.metrics.Metric`" + f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`" ) self[name] = metric - elif isinstance(metrics, (tuple, list)): + elif isinstance(metrics, Sequence): for metric in metrics: if not isinstance(metric, Metric): - raise ValueError( - f"Input {metric} to `MetricCollection` is not a instance" - " of `pl.metrics.Metric`" - ) + raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`") name = metric.__class__.__name__ if name in self: raise ValueError(f"Encountered two metrics both named {name}")