From 10c461222e057aa4de2dffb3b11369c2f39fe7ae Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 21 Mar 2023 12:33:31 +0100 Subject: [PATCH 01/10] multimodal --- src/torchmetrics/multimodal/clip_score.py | 56 ++++++++++++++++++++--- 1 file changed, 50 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index a67ca8e923b..fb95ecb1cec 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, List, Union +from typing import Any, List, Optional, Sequence, Union import torch from torch import Tensor from typing_extensions import Literal +from torchmetrics import Metric from torchmetrics.functional.multimodal.clip_score import _clip_score_update, _get_model_and_processor from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout -from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PESQ_AVAILABLE, _TRANSFORMERS_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["CLIPScore.plot"] if _TRANSFORMERS_AVAILABLE: from transformers import CLIPModel as _CLIPModel @@ -31,11 +36,9 @@ def _download_clip() -> None: _CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip): - __doctest_skip__ = ["CLIPScore"] + __doctest_skip__ = ["CLIPScore", "CLIPScore.plot"] else: - __doctest_skip__ = ["CLIPScore"] - -from torchmetrics import Metric + __doctest_skip__ = ["CLIPScore", "CLIPScore.plot"] class CLIPScore(Metric): @@ -80,6 +83,8 @@ class CLIPScore(Metric): full_state_update: bool = True score: Tensor n_samples: Tensor + plot_lower_bound = 0.0 + plot_upper_bound = 100.0 def __init__( self, @@ -116,3 +121,42 @@ def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, List[str] def compute(self) -> Tensor: """Compute accumulated clip score.""" return torch.max(self.score / self.n_samples, torch.zeros_like(self.score)) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.multimodal import CLIPScore + >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16") + >>> metric.update(torch.randint(255, (3, 224, 224)), "a photo of a cat") + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.multimodal import CLIPScore + >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16") + >>> values = [ ] + >>> for _ in range(10): + ... values.append(torch.randint(255, (3, 224, 224)), "a photo of a cat") + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) From 27228dcd00e4f5fff58ca6e8e2c28da01d6525d6 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 21 Mar 2023 12:34:05 +0100 Subject: [PATCH 02/10] wrappers --- src/torchmetrics/wrappers/bootstrapping.py | 48 ++++++++++++++++++- src/torchmetrics/wrappers/classwise.py | 50 ++++++++++++++++++- src/torchmetrics/wrappers/minmax.py | 50 ++++++++++++++++++- src/torchmetrics/wrappers/multioutput.py | 56 ++++++++++++++++++++-- src/torchmetrics/wrappers/tracker.py | 47 +++++++++++++++++- 5 files changed, 242 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/wrappers/bootstrapping.py b/src/torchmetrics/wrappers/bootstrapping.py index 1118ec0877b..e7ca461c0a0 100644 --- a/src/torchmetrics/wrappers/bootstrapping.py +++ b/src/torchmetrics/wrappers/bootstrapping.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Sequence, Union import torch from torch import Tensor @@ -20,6 +20,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import apply_to_collection +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["BootStrapper.plot"] def _bootstrap_sampler( @@ -150,3 +155,44 @@ def compute(self) -> Dict[str, Tensor]: if self.raw: output_dict["raw"] = computed_vals return output_dict + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import BootStrapper, MeanSquaredError + >>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20) + >>> metric.update(torch.randn(100,), torch.randn(100,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import BootStrapper, MeanSquaredError + >>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20) + >>> values = [ ] + >>> for _ in range(3): + ... values.append(metric(torch.randn(100,), torch.randn(100,))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index a39ca136598..8e9fefe8601 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Sequence, Union from torch import Tensor from torchmetrics import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["ClasswiseWrapper.plot"] class ClasswiseWrapper(Metric): @@ -114,3 +119,46 @@ def _wrap_update(self, update: Callable) -> Callable: def _wrap_compute(self, compute: Callable) -> Callable: """Overwrite to do nothing.""" return compute + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import ClasswiseWrapper + >>> from torchmetrics.classification import MulticlassAccuracy + >>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)) + >>> metric.update(torch.randint(3, (20,)), torch.randint(3, (20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import ClasswiseWrapper + >>> from torchmetrics.classification import MulticlassAccuracy + >>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)) + >>> values = [ ] + >>> for _ in range(3): + ... values.append(metric(torch.randint(3, (20,)), torch.randint(3, (20,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/wrappers/minmax.py b/src/torchmetrics/wrappers/minmax.py index 1d7bc811a94..da89e917c12 100644 --- a/src/torchmetrics/wrappers/minmax.py +++ b/src/torchmetrics/wrappers/minmax.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Sequence, Union import torch from torch import Tensor from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MinMaxMetric.plot"] class MinMaxMetric(Metric): @@ -103,3 +108,46 @@ def _is_suitable_val(val: Union[int, float, Tensor]) -> bool: if isinstance(val, Tensor): return val.numel() == 1 return False + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import MinMaxMetric + >>> from torchmetrics.classification import BinaryAccuracy + >>> metric = MinMaxMetric(BinaryAccuracy()) + >>> metric.update(torch.randint(2, (20,)), torch.randint(2, (20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import MinMaxMetric + >>> from torchmetrics.classification import BinaryAccuracy + >>> metric = MinMaxMetric(BinaryAccuracy()) + >>> values = [ ] + >>> for _ in range(3): + ... values.append(metric(torch.randint(2, (20,)), torch.randint(2, (20,)))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/wrappers/multioutput.py b/src/torchmetrics/wrappers/multioutput.py index 7f2ce8d2cf3..464fb3aa979 100644 --- a/src/torchmetrics/wrappers/multioutput.py +++ b/src/torchmetrics/wrappers/multioutput.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -7,6 +7,11 @@ from torchmetrics import Metric from torchmetrics.utilities import apply_to_collection +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MultioutputWrapper.plot"] def _get_nan_indices(*tensors: Tensor) -> Tensor: @@ -63,7 +68,7 @@ class MultioutputWrapper(Metric): >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) >>> r2score = MultioutputWrapper(R2Score(), 2) >>> r2score(preds, target) - [tensor(0.9654), tensor(0.9082)] + tensor([0.9654, 0.9082]) """ is_differentiable = False @@ -109,9 +114,9 @@ def update(self, *args: Any, **kwargs: Any) -> None: for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs): metric.update(*selected_args, **selected_kwargs) - def compute(self) -> List[Tensor]: + def compute(self) -> Tensor: """Compute metrics.""" - return [m.compute() for m in self.metrics] + return torch.stack([m.compute() for m in self.metrics], 0) @torch.jit.unused def forward(self, *args: Any, **kwargs: Any) -> Any: @@ -125,7 +130,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: results.append(metric(*selected_args, **selected_kwargs)) if results[0] is None: return None - return results + return torch.stack(results, 0) def reset(self) -> None: """Reset all underlying metrics.""" @@ -140,3 +145,44 @@ def _wrap_update(self, update: Callable) -> Callable: def _wrap_compute(self, compute: Callable) -> Callable: """Overwrite to do nothing.""" return compute + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import MultioutputWrapper, R2Score + >>> metric = MultioutputWrapper(R2Score(), 2) + >>> metric.update(torch.randn(20, 2), torch.randn(20, 2)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import MultioutputWrapper, R2Score + >>> metric = MultioutputWrapper(R2Score(), 2) + >>> values = [ ] + >>> for _ in range(3): + ... values.append(metric(torch.randn(20, 2), torch.randn(20, 2))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index 7b61deca835..84fabf476ef 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -20,8 +20,13 @@ from torchmetrics.collections import MetricCollection from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val from torchmetrics.utilities.prints import rank_zero_warn +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["MetricTracker.plot"] + class MetricTracker(ModuleList): """A wrapper class that can help keeping track of a metric or metric collection over time. @@ -259,3 +264,43 @@ def _check_for_increment(self, method: str) -> None: """Check that a metric that can be updated/used for computations has been intialized.""" if not self._increment_called: raise ValueError(f"`{method}` cannot be called before `.increment()` has been called.") + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import MetricTracker + >>> from torchmetrics.classification import BinaryAccuracy + >>> tracker = MetricTracker(BinaryAccuracy()) + >>> for epoch in range(5): + ... tracker.increment() + ... for batch_idx in range(5): + ... tracker.update(torch.randint(2, (10,)), torch.randint(2, (10,))) + >>> fig_, ax_ = tracker.plot() # plot all epochs + + """ + val = val if val is not None else [val for val in self.compute_all()] + fig, ax = plot_single_or_multi_val( + val, + ax=ax, + name=self.__class__.__name__, + ) + return fig, ax From 92e8afce2ab951e48c59416ebaa3e9c39983bd36 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 21 Mar 2023 12:35:05 +0100 Subject: [PATCH 03/10] tests --- tests/unittests/multimodal/test_clip_score.py | 12 +++++++ tests/unittests/utilities/test_plot.py | 31 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index d6eb757fd39..cf6f33e8c3b 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -14,6 +14,8 @@ from collections import namedtuple from functools import partial +import matplotlib +import matplotlib.pyplot as plt import pytest import torch from transformers import CLIPModel as _CLIPModel @@ -115,3 +117,13 @@ def test_error_on_wrong_image_format(self, input, model_name_or_path): ValueError, match="Expected all images to be 3d but found image that has either more or less" ): metric(torch.randint(255, (64, 64)), "28-year-old chef found dead in San Francisco mall") + + @skip_on_connection_issues() + def test_plot_method(self, input, model_name_or_path): + """Test the plot method of CLIPScore seperately in this file due to the skipping conditions.""" + metric = CLIPScore(model_name_or_path=model_name_or_path) + preds, target = input + metric.update(preds[0], target[0]) + fig, ax = metric.plot() + assert isinstance(fig, plt.Figure) + assert isinstance(ax, matplotlib.axes.Axes) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index dc4c742a87b..64a0655c40f 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -117,6 +117,7 @@ RetrievalRecallAtFixedPrecision, RetrievalRPrecision, ) +from torchmetrics.wrappers import BootStrapper, ClasswiseWrapper, MetricTracker, MinMaxMetric, MultioutputWrapper _rand_input = lambda: torch.rand(10) _binary_randint_input = lambda: torch.randint(2, (10,)) @@ -422,6 +423,24 @@ pytest.param(SymmetricMeanAbsolutePercentageError, _rand_input, _rand_input, id="symmetric mape"), pytest.param(TweedieDevianceScore, _rand_input, _rand_input, id="tweedie deviance score"), pytest.param(WeightedMeanAbsolutePercentageError, _rand_input, _rand_input, id="weighted mape"), + pytest.param( + partial(BootStrapper, base_metric=BinaryAccuracy()), _rand_input, _binary_randint_input, id="bootstrapper" + ), + pytest.param( + partial(ClasswiseWrapper, metric=MulticlassAccuracy(num_classes=3, average=None)), + _multiclass_randn_input, + _multiclass_randint_input, + id="classwise wrapper", + ), + pytest.param( + partial(MinMaxMetric, base_metric=BinaryAccuracy()), _rand_input, _binary_randint_input, id="minmax wrapper" + ), + pytest.param( + partial(MultioutputWrapper, base_metric=MeanSquaredError(), num_outputs=3), + _multilabel_rand_input, + _multilabel_rand_input, + id="multioutput wrapper", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5]) @@ -602,3 +621,15 @@ def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_label cond1 = isinstance(axs, matplotlib.axes.Axes) cond2 = isinstance(axs, np.ndarray) and all(isinstance(a, matplotlib.axes.Axes) for a in axs) assert cond1 or cond2 + + +def test_tracker_plotter(): + """Test tracker that uses specialized plot function.""" + tracker = MetricTracker(BinaryAccuracy()) + for _ in range(5): + tracker.increment() + for _ in range(5): + tracker.update(torch.randint(2, (10,)), torch.randint(2, (10,))) + fig, ax = tracker.plot() # plot all epochs + assert isinstance(fig, plt.Figure) + assert isinstance(ax, matplotlib.axes.Axes) From 90f5320b67b9f143cc828337635104589ddf21a2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 21 Mar 2023 19:58:14 +0100 Subject: [PATCH 04/10] Apply suggestions from code review --- tests/unittests/utilities/test_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index beb8e37cc87..342a4ca7be4 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -460,7 +460,7 @@ partial(MultioutputWrapper, base_metric=MeanSquaredError(), num_outputs=3), _multilabel_rand_input, _multilabel_rand_input, - id="multioutput wrapper", + id="multioutput wrapper",) pytest.param(Dice, _multiclass_randint_input, _multiclass_randint_input, id="dice"), pytest.param( partial(MulticlassExactMatch, num_classes=3), From bb775950ca707e2bf0cb9c7ca845e894c84efe7d Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 21 Mar 2023 19:58:47 +0100 Subject: [PATCH 05/10] typo --- tests/unittests/utilities/test_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 342a4ca7be4..cd0dfae6ce7 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -460,7 +460,7 @@ partial(MultioutputWrapper, base_metric=MeanSquaredError(), num_outputs=3), _multilabel_rand_input, _multilabel_rand_input, - id="multioutput wrapper",) + id="multioutput wrapper"), pytest.param(Dice, _multiclass_randint_input, _multiclass_randint_input, id="dice"), pytest.param( partial(MulticlassExactMatch, num_classes=3), From ef240daeabd3f95fb96f4ce679bcc0b189456b27 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Mar 2023 18:59:30 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/utilities/test_plot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index cd0dfae6ce7..94f8fa2aa2d 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -460,7 +460,8 @@ partial(MultioutputWrapper, base_metric=MeanSquaredError(), num_outputs=3), _multilabel_rand_input, _multilabel_rand_input, - id="multioutput wrapper"), + id="multioutput wrapper", + ), pytest.param(Dice, _multiclass_randint_input, _multiclass_randint_input, id="dice"), pytest.param( partial(MulticlassExactMatch, num_classes=3), From a84a0a12719e6b5e74f61842f6e928a8efe5f91c Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 24 Mar 2023 09:19:39 +0100 Subject: [PATCH 07/10] requirements --- requirements/docs.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/docs.txt b/requirements/docs.txt index 1209e6d16b0..5131098333f 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -17,3 +17,4 @@ sphinx-copybutton>=0.3 -r audio.txt -r detection.txt -r image.txt +-r multimodal.txt From 9340e4bef0f35a2de80c3cd9384fe16688b767a4 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 28 Mar 2023 17:00:34 +0200 Subject: [PATCH 08/10] fix --- src/torchmetrics/multimodal/clip_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index fb95ecb1cec..91a8174dc6d 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -156,7 +156,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16") >>> values = [ ] >>> for _ in range(10): - ... values.append(torch.randint(255, (3, 224, 224)), "a photo of a cat") + ... values.append(metric(torch.randint(255, (3, 224, 224)), "a photo of a cat")) >>> fig_, ax_ = metric.plot(values) """ return self._plot(val, ax) From 8bc6d808617abb35783e60568f04d1e93c6086e8 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 29 Mar 2023 08:09:24 +0200 Subject: [PATCH 09/10] fix --- tests/unittests/wrappers/test_tracker.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/unittests/wrappers/test_tracker.py b/tests/unittests/wrappers/test_tracker.py index 8a82a577b1a..d7b2bdfc866 100644 --- a/tests/unittests/wrappers/test_tracker.py +++ b/tests/unittests/wrappers/test_tracker.py @@ -197,7 +197,7 @@ def test_best_metric_for_not_well_defined_metric_collection(base_metric): "mae": MultioutputWrapper(MeanAbsoluteError(), num_outputs=2), } ), - list, + dict, ), ], ) @@ -212,5 +212,11 @@ def test_metric_tracker_and_collection_multioutput(input_to_tracker, assert_type all_res = tracker.compute_all() assert isinstance(all_res, assert_type) best_metric, which_epoch = tracker.best_metric(return_step=True) - assert best_metric is None - assert which_epoch is None + if isinstance(best_metric, dict): + for v in best_metric.values(): + assert v is None + for v in which_epoch.values(): + assert v is None + else: + assert best_metric is None + assert which_epoch is None From ffa165ece782e90d4e647bcb3f9d69cfb3560f2a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 29 Mar 2023 08:10:44 +0200 Subject: [PATCH 10/10] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index eaab96afb0b..f872b0f0077 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1623](https://github.com/Lightning-AI/metrics/pull/1623), [#1638](https://github.com/Lightning-AI/metrics/pull/1638), [#1631](https://github.com/Lightning-AI/metrics/pull/1631), + [#1639](https://github.com/Lightning-AI/metrics/pull/1639), )