Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plotting 16/n #1639

Merged
merged 16 commits into from
Mar 31, 2023
56 changes: 50 additions & 6 deletions src/torchmetrics/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, List, Union
from typing import Any, List, Optional, Sequence, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics import Metric
from torchmetrics.functional.multimodal.clip_score import _clip_score_update, _get_model_and_processor
from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PESQ_AVAILABLE, _TRANSFORMERS_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["CLIPScore.plot"]

if _TRANSFORMERS_AVAILABLE:
from transformers import CLIPModel as _CLIPModel
Expand All @@ -31,11 +36,9 @@ def _download_clip() -> None:
_CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
__doctest_skip__ = ["CLIPScore"]
__doctest_skip__ = ["CLIPScore", "CLIPScore.plot"]
else:
__doctest_skip__ = ["CLIPScore"]

from torchmetrics import Metric
__doctest_skip__ = ["CLIPScore", "CLIPScore.plot"]


class CLIPScore(Metric):
Expand Down Expand Up @@ -80,6 +83,8 @@ class CLIPScore(Metric):
full_state_update: bool = True
score: Tensor
n_samples: Tensor
plot_lower_bound = 0.0
plot_upper_bound = 100.0

def __init__(
self,
Expand Down Expand Up @@ -116,3 +121,42 @@ def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, List[str]
def compute(self) -> Tensor:
"""Compute accumulated clip score."""
return torch.max(self.score / self.n_samples, torch.zeros_like(self.score))

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.multimodal import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> metric.update(torch.randint(255, (3, 224, 224)), "a photo of a cat")
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.multimodal import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> values = [ ]
>>> for _ in range(10):
... values.append(torch.randint(255, (3, 224, 224)), "a photo of a cat")
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
48 changes: 47 additions & 1 deletion src/torchmetrics/wrappers/bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union

import torch
from torch import Tensor
from torch.nn import ModuleList

from torchmetrics.metric import Metric
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["BootStrapper.plot"]


def _bootstrap_sampler(
Expand Down Expand Up @@ -150,3 +155,44 @@ def compute(self) -> Dict[str, Tensor]:
if self.raw:
output_dict["raw"] = computed_vals
return output_dict

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import BootStrapper, MeanSquaredError
>>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
>>> metric.update(torch.randn(100,), torch.randn(100,))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import BootStrapper, MeanSquaredError
>>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
>>> values = [ ]
>>> for _ in range(3):
... values.append(metric(torch.randn(100,), torch.randn(100,)))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
50 changes: 49 additions & 1 deletion src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from torch import Tensor

from torchmetrics import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["ClasswiseWrapper.plot"]


class ClasswiseWrapper(Metric):
Expand Down Expand Up @@ -114,3 +119,46 @@ def _wrap_update(self, update: Callable) -> Callable:
def _wrap_compute(self, compute: Callable) -> Callable:
"""Overwrite to do nothing."""
return compute

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> metric.update(torch.randint(3, (20,)), torch.randint(3, (20,)))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> values = [ ]
>>> for _ in range(3):
... values.append(metric(torch.randint(3, (20,)), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
50 changes: 49 additions & 1 deletion src/torchmetrics/wrappers/minmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union

import torch
from torch import Tensor

from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MinMaxMetric.plot"]


class MinMaxMetric(Metric):
Expand Down Expand Up @@ -103,3 +108,46 @@ def _is_suitable_val(val: Union[int, float, Tensor]) -> bool:
if isinstance(val, Tensor):
return val.numel() == 1
return False

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = MinMaxMetric(BinaryAccuracy())
>>> metric.update(torch.randint(2, (20,)), torch.randint(2, (20,)))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = MinMaxMetric(BinaryAccuracy())
>>> values = [ ]
>>> for _ in range(3):
... values.append(metric(torch.randint(2, (20,)), torch.randint(2, (20,))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
56 changes: 51 additions & 5 deletions src/torchmetrics/wrappers/multioutput.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from copy import deepcopy
from typing import Any, Callable, List, Tuple
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor
from torch.nn import ModuleList

from torchmetrics import Metric
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MultioutputWrapper.plot"]


def _get_nan_indices(*tensors: Tensor) -> Tensor:
Expand Down Expand Up @@ -63,7 +68,7 @@ class MultioutputWrapper(Metric):
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = MultioutputWrapper(R2Score(), 2)
>>> r2score(preds, target)
[tensor(0.9654), tensor(0.9082)]
tensor([0.9654, 0.9082])
"""

is_differentiable = False
Expand Down Expand Up @@ -109,9 +114,9 @@ def update(self, *args: Any, **kwargs: Any) -> None:
for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs):
metric.update(*selected_args, **selected_kwargs)

def compute(self) -> List[Tensor]:
def compute(self) -> Tensor:
"""Compute metrics."""
return [m.compute() for m in self.metrics]
return torch.stack([m.compute() for m in self.metrics], 0)

@torch.jit.unused
def forward(self, *args: Any, **kwargs: Any) -> Any:
Expand All @@ -125,7 +130,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
results.append(metric(*selected_args, **selected_kwargs))
if results[0] is None:
return None
return results
return torch.stack(results, 0)

def reset(self) -> None:
"""Reset all underlying metrics."""
Expand All @@ -140,3 +145,44 @@ def _wrap_update(self, update: Callable) -> Callable:
def _wrap_compute(self, compute: Callable) -> Callable:
"""Overwrite to do nothing."""
return compute

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import MultioutputWrapper, R2Score
>>> metric = MultioutputWrapper(R2Score(), 2)
>>> metric.update(torch.randn(20, 2), torch.randn(20, 2))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import MultioutputWrapper, R2Score
>>> metric = MultioutputWrapper(R2Score(), 2)
>>> values = [ ]
>>> for _ in range(3):
... values.append(metric(torch.randn(20, 2), torch.randn(20, 2)))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
Loading