Skip to content

Commit

Permalink
allow MetricCollection with args (#176)
Browse files Browse the repository at this point in the history
* allow MetricCollection with args

* format

* chlog
  • Loading branch information
Borda authored Apr 19, 2021
1 parent 6116c45 commit bc5f9a9
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Updated FBeta arguments ([#111](https://github.com/PyTorchLightning/metrics/pull/111))
- Changed `reset` method to use `detach.clone()` instead of `deepcopy` when resetting to default ([#163](https://github.com/PyTorchLightning/metrics/pull/163))
- Metrics passed as dict to `MetricCollection` will now always be in deterministic order ([#173](https://github.com/PyTorchLightning/metrics/pull/173))
- Allowed `MetricCollection` pass metrics as arguments ([#176](https://github.com/PyTorchLightning/metrics/pull/176))

### Deprecated

Expand Down
3 changes: 3 additions & 0 deletions tests/bases/test_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ def average_ignore_weights(values, weights):


class DefaultWeightWrapper(AverageMeter):

def update(self, values, weights):
super().update(values)


class ScalarWrapper(AverageMeter):

def update(self, values, weights):
# torch.ravel is PyTorch 1.8 only, so use np.ravel instead
values = values.cpu().numpy()
Expand All @@ -37,6 +39,7 @@ def update(self, values, weights):
],
)
class TestAverageMeter(MetricTester):

@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_average_fn(self, ddp, dist_sync_on_step, values, weights):
Expand Down
12 changes: 6 additions & 6 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,23 @@ def test_device_and_dtype_transfer_metriccollection(tmpdir):

def test_metric_collection_wrong_input(tmpdir):
""" Check that errors are raised on wrong input """
m1 = DummyMetricSum()
dms = DummyMetricSum()

# Not all input are metrics (list)
with pytest.raises(ValueError):
_ = MetricCollection([m1, 5])
_ = MetricCollection([dms, 5])

# Not all input are metrics (dict)
with pytest.raises(ValueError):
_ = MetricCollection({'metric1': m1, 'metric2': 5})
_ = MetricCollection({'metric1': dms, 'metric2': 5})

# Same metric passed in multiple times
with pytest.raises(ValueError, match='Encountered two metrics both named *.'):
_ = MetricCollection([m1, m1])
_ = MetricCollection([dms, dms])

# Not a list or dict passed in
with pytest.raises(ValueError, match='Unknown input to MetricCollection.'):
_ = MetricCollection(m1)
with pytest.warns(Warning, match=' which are not `Metric` so they will be ignored.'):
_ = MetricCollection(dms, [dms])


def test_metric_collection_args_kwargs(tmpdir):
Expand Down
4 changes: 1 addition & 3 deletions torchmetrics/average.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def __init__(

# TODO: need to be strings because Unions are not pickleable in Python 3.6
def update( # type: ignore
self,
value: "Union[Tensor, float]",
weight: "Union[Tensor, float]" = 1.0
self, value: "Union[Tensor, float]", weight: "Union[Tensor, float]" = 1.0
) -> None:
"""Updates the average with.
Expand Down
11 changes: 4 additions & 7 deletions torchmetrics/classification/binned_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,10 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS)

# Need to guarantee that last precision=1 and recall=0, similar to precision_recall_curve
precisions = torch.cat([
precisions, torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device)
],
dim=1)
recalls = torch.cat([recalls,
torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)],
dim=1)
t_ones = torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device)
precisions = torch.cat([precisions, t_ones], dim=1)
t_zeros = torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)
recalls = torch.cat([recalls, t_zeros], dim=1)
if self.num_classes == 1:
return (precisions[0, :], recalls[0, :], self.thresholds)
else:
Expand Down
60 changes: 43 additions & 17 deletions torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,29 @@
# limitations under the License.

from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Union

from torch import nn

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn


class MetricCollection(nn.ModuleDict):
"""
MetricCollection class can be used to chain metrics that have the same
call pattern into one single class.
MetricCollection class can be used to chain metrics that have the same call pattern into one single class.
Args:
metrics: One of the following
* list or tuple: if metrics are passed in as a list, will use the
metrics class name as key for output dict. Therefore, two metrics
of the same class cannot be chained this way.
* list or tuple (sequence): if metrics are passed in as a list or tuple, will use the metrics class name
as key for output dict. Therefore, two metrics of the same class cannot be chained this way.
* dict: if metrics are passed in as a dict, will use each key in the
dict as key for output dict. Use this format if you want to chain
together multiple of the same metric with different parameters.
* arguments: similar to passing in as a list, metrics passed in as arguments will use their metric
class name as key for the output dict.
* dict: if metrics are passed in as a dict, will use each key in the dict as key for output dict.
Use this format if you want to chain together multiple of the same metric with different parameters.
Note that the keys in the output dict will be sorted alphabetically.
prefix: a string to append in front of the keys of the output dict
Expand All @@ -46,6 +47,8 @@ class MetricCollection(nn.ModuleDict):
If two elements in ``metrics`` have the same ``name``.
ValueError:
If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``.
ValueError:
If ``metrics`` is is ``dict`` and passed any additional_metrics.
Example (input as list):
>>> import torch
Expand All @@ -59,6 +62,12 @@ class MetricCollection(nn.ModuleDict):
>>> metrics(preds, target)
{'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}
Example (input as arguments):
>>> metrics = MetricCollection(Accuracy(), Precision(num_classes=3, average='macro'),
... Recall(num_classes=3, average='macro'))
>>> metrics(preds, target)
{'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)}
Example (input as dict):
>>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'),
... 'macro_recall': Recall(num_classes=3, average='macro')})
Expand All @@ -73,28 +82,45 @@ class MetricCollection(nn.ModuleDict):

def __init__(
self,
metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]],
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],
*additional_metrics: Metric,
prefix: Optional[str] = None,
):
super().__init__()
if isinstance(metrics, Metric):
# set compatible with original type expectations
metrics = [metrics]
if isinstance(metrics, Sequence):
# prepare for optional additions
metrics = list(metrics)
remain = []
for m in additional_metrics:
(metrics if isinstance(m, Metric) else remain).append(m)

if remain:
rank_zero_warn(
f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored."
)
elif additional_metrics:
raise ValueError(
f"You have passes extra arguments {additional_metrics} which are not compatible"
f" with first passed dictionary {metrics} so they will be ignored."
)

if isinstance(metrics, dict):
# Check all values are metrics
# Make sure that metrics are added in deterministic order
for name in sorted(metrics.keys()):
metric = metrics[name]
if not isinstance(metric, Metric):
raise ValueError(
f"Value {metric} belonging to key {name}"
" is not an instance of `pl.metrics.Metric`"
f"Value {metric} belonging to key {name} is not an instance of `pl.metrics.Metric`"
)
self[name] = metric
elif isinstance(metrics, (tuple, list)):
elif isinstance(metrics, Sequence):
for metric in metrics:
if not isinstance(metric, Metric):
raise ValueError(
f"Input {metric} to `MetricCollection` is not a instance"
" of `pl.metrics.Metric`"
)
raise ValueError(f"Input {metric} to `MetricCollection` is not a instance of `pl.metrics.Metric`")
name = metric.__class__.__name__
if name in self:
raise ValueError(f"Encountered two metrics both named {name}")
Expand Down

0 comments on commit bc5f9a9

Please sign in to comment.