Skip to content

Commit

Permalink
Fix multitask wrapper not being logged in lightning when used togethe…
Browse files Browse the repository at this point in the history
…r with collections (#2349)

* integration tests
* implementation
* better testing
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
SkafteNicki and Borda authored Feb 12, 2024
1 parent 9253717 commit c94e21a
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed `MultitaskWrapper` not being able to be logged in lightning when using metric collections ([#2349](https://github.com/Lightning-AI/torchmetrics/pull/2349))


- Fixed high memory consumption in `Perplexity` metric ([#2346](https://github.com/Lightning-AI/torchmetrics/pull/2346))


Expand Down
50 changes: 41 additions & 9 deletions src/torchmetrics/wrappers/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,49 @@ def __init__(
super().__init__()
self.task_metrics = nn.ModuleDict(task_metrics)

def items(self) -> Iterable[Tuple[str, nn.Module]]:
"""Iterate over task and task metrics."""
return self.task_metrics.items()
def items(self, flatten: bool = True) -> Iterable[Tuple[str, nn.Module]]:
"""Iterate over task and task metrics.
def keys(self) -> Iterable[str]:
"""Iterate over task names."""
return self.task_metrics.keys()
Args:
flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
If False, will iterate over the task names and the corresponding metrics.
"""
for task_name, metric in self.task_metrics.items():
if flatten and isinstance(metric, MetricCollection):
for sub_metric_name, sub_metric in metric.items():
yield f"{task_name}_{sub_metric_name}", sub_metric
else:
yield task_name, metric

def keys(self, flatten: bool = True) -> Iterable[str]:
"""Iterate over task names.
Args:
flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
If False, will iterate over the task names and the corresponding metrics.
"""
for task_name, metric in self.task_metrics.items():
if flatten and isinstance(metric, MetricCollection):
for sub_metric_name in metric:
yield f"{task_name}_{sub_metric_name}"
else:
yield task_name

def values(self) -> Iterable[nn.Module]:
"""Iterate over task metrics."""
return self.task_metrics.values()
def values(self, flatten: bool = True) -> Iterable[nn.Module]:
"""Iterate over task metrics.
Args:
flatten: If True, will iterate over all sub-metrics in the case of a MetricCollection.
If False, will iterate over the task names and the corresponding metrics.
"""
for metric in self.task_metrics.values():
if flatten and isinstance(metric, MetricCollection):
yield from metric.values()
else:
yield metric

@staticmethod
def _check_task_metrics_type(task_metrics: Dict[str, Union[Metric, MetricCollection]]) -> None:
Expand Down
28 changes: 22 additions & 6 deletions tests/integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from torchmetrics import MetricCollection
from torchmetrics.aggregation import SumMetric
from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision
from torchmetrics.regression import MeanSquaredError
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from torchmetrics.wrappers import MultitaskWrapper

from integrations.helpers import no_warning_call
Expand Down Expand Up @@ -366,22 +366,34 @@ def test_task_wrapper_lightning_logging(tmpdir):
class TestModel(BoringModel):
def __init__(self) -> None:
super().__init__()
self.metric = MultitaskWrapper({"classification": BinaryAccuracy(), "regression": MeanSquaredError()})
self.multitask = MultitaskWrapper({"classification": BinaryAccuracy(), "regression": MeanSquaredError()})
self.multitask_collection = MultitaskWrapper(
{
"classification": MetricCollection([BinaryAccuracy(), BinaryAveragePrecision()]),
"regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
}
)

self.accuracy = BinaryAccuracy()
self.mse = MeanSquaredError()

def training_step(self, batch, batch_idx):
preds = torch.rand(10)
target = torch.rand(10)
self.metric(
{"classification": preds.round(), "regression": preds},
{"classification": target.round(), "regression": target},
self.multitask(
{"classification": preds, "regression": preds},
{"classification": target.round().int(), "regression": target},
)
self.multitask_collection(
{"classification": preds, "regression": preds},
{"classification": target.round().int(), "regression": target},
)
self.accuracy(preds.round(), target.round())
self.mse(preds, target)
self.log("accuracy", self.accuracy, on_epoch=True)
self.log("mse", self.mse, on_epoch=True)
self.log_dict(self.metric, on_epoch=True)
self.log_dict(self.multitask, on_epoch=True)
self.log_dict(self.multitask_collection, on_epoch=True)
return self.step(batch)

model = TestModel()
Expand All @@ -404,6 +416,10 @@ def training_step(self, batch, batch_idx):
assert torch.allclose(logged["accuracy_epoch"], logged["classification_epoch"])
assert torch.allclose(logged["mse_step"], logged["regression_step"])
assert torch.allclose(logged["mse_epoch"], logged["regression_epoch"])
assert "regression_MeanAbsoluteError_epoch" in logged
assert "regression_MeanSquaredError_epoch" in logged
assert "classification_BinaryAccuracy_epoch" in logged
assert "classification_BinaryAveragePrecision_epoch" in logged


def test_scriptable(tmpdir):
Expand Down
49 changes: 49 additions & 0 deletions tests/unittests/wrappers/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,55 @@ def test_nested_multitask_wrapper():
assert _dict_results_same_as_individual_results(classification_results, regression_results, multitask_results)


@pytest.mark.parametrize("method", ["keys", "items", "values"])
@pytest.mark.parametrize("flatten", [True, False])
def test_key_value_items_method(method, flatten):
"""Test the keys, items, and values methods of the MultitaskWrapper."""
multitask = MultitaskWrapper(
{
"classification": MetricCollection([BinaryAccuracy(), BinaryF1Score()]),
"regression": MetricCollection([MeanSquaredError(), MeanAbsoluteError()]),
}
)
if method == "keys":
output = list(multitask.keys(flatten=flatten))
elif method == "items":
output = list(multitask.items(flatten=flatten))
elif method == "values":
output = list(multitask.values(flatten=flatten))

if flatten:
assert len(output) == 4
if method == "keys":
assert output == [
"classification_BinaryAccuracy",
"classification_BinaryF1Score",
"regression_MeanSquaredError",
"regression_MeanAbsoluteError",
]
elif method == "items":
assert output == [
("classification_BinaryAccuracy", BinaryAccuracy()),
("classification_BinaryF1Score", BinaryF1Score()),
("regression_MeanSquaredError", MeanSquaredError()),
("regression_MeanAbsoluteError", MeanAbsoluteError()),
]
elif method == "values":
assert output == [BinaryAccuracy(), BinaryF1Score(), MeanSquaredError(), MeanAbsoluteError()]
else:
assert len(output) == 2
if method == "keys":
assert output == ["classification", "regression"]
elif method == "items":
assert output[0][0] == "classification"
assert output[1][0] == "regression"
assert isinstance(output[0][1], MetricCollection)
assert isinstance(output[1][1], MetricCollection)
elif method == "values":
assert isinstance(output[0], MetricCollection)
assert isinstance(output[1], MetricCollection)


def test_clone_with_prefix_and_postfix():
"""Check that the clone method works with prefix and postfix arguments."""
multitask_metrics = MultitaskWrapper({"Classification": BinaryAccuracy(), "Regression": MeanSquaredError()})
Expand Down

0 comments on commit c94e21a

Please sign in to comment.