Skip to content

Commit

Permalink
Device and dtype properties (#462)
Browse files Browse the repository at this point in the history
* add gpu testing
* change super
* move to metric + simplify
* fix bert
* update docs
* add typing
* changelog

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

(cherry picked from commit b10dba4)
  • Loading branch information
SkafteNicki authored and Borda committed Aug 30, 2021
1 parent aa78216 commit efdf20e
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 18 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for float targets in `nDCG` metric ([#437](https://github.com/PyTorchLightning/metrics/pull/437))


- Added `device` and `dtype` properties ([#462](https://github.com/PyTorchLightning/metrics/pull/462))


### Changed


Expand All @@ -37,6 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed metric hashing ([#478](https://github.com/PyTorchLightning/metrics/pull/478))


- Fixed `BootStrapper` metrics not working on GPU ([#462](https://github.com/PyTorchLightning/metrics/pull/462))


- Fixed the semantic ordering of kernel height and width in `SSIM` metric ([#474](https://github.com/PyTorchLightning/metrics/pull/474))


Expand Down
4 changes: 4 additions & 0 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.
val3 = self.metric3['accuracy'](preds, target)
val4 = self.metric4(preds, target)

You can always check which device the metric is located on using the `.device` property.

Metrics in Dataparallel (DP) mode
=================================

Expand Down Expand Up @@ -169,6 +171,8 @@ the following limitations:
- :ref:`references/modules:SSIM` and :ref:`references/functional:ssim [func]`
- :ref:`references/modules:KLDivergence` and :ref:`references/functional:kl_divergence [func]`

You can always check the precision/dtype of the metric via the `.dtype` property.

******************
Metric Arithmetics
******************
Expand Down
7 changes: 7 additions & 0 deletions tests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,20 +253,27 @@ def __init__(self):
def test_device_and_dtype_transfer(tmpdir):
    """Check that ``device`` and ``dtype`` track the metric state through moves, casts and resets.

    NOTE(review): the test moves the metric to CUDA unconditionally; any GPU skip guard
    would have to sit on a decorator that is not visible in this view.
    """
    metric = DummyMetricSum()

    # Starting point asserted by the test: CPU placement and float32 precision.
    assert metric.x.is_cuda is False
    assert metric.device == torch.device("cpu")
    assert metric.x.dtype == torch.float32
    assert metric.dtype == torch.float32

    metric = metric.to(device="cuda")
    assert metric.x.is_cuda
    assert metric.device == torch.device("cuda")

    # Each cast must be reflected by both the state tensor and the property,
    # and must still hold after the state has been reset.
    for cast_name, expected_dtype in (("double", torch.float64), ("half", torch.float16)):
        metric = getattr(metric, cast_name)()
        assert metric.x.dtype == expected_dtype
        assert metric.dtype == expected_dtype
        metric.reset()
        assert metric.x.dtype == expected_dtype
        assert metric.dtype == expected_dtype


def test_warning_on_compute_before_update():
Expand Down
39 changes: 26 additions & 13 deletions tests/wrappers/test_bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from functools import partial

import numpy as np
import pytest
import torch
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import mean_squared_error, precision_score, recall_score
from torch import Tensor

from torchmetrics.classification import Precision, Recall
from torchmetrics.regression import MeanSquaredError
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7
from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler
Expand All @@ -36,7 +38,7 @@ def update(self, *args) -> None:
self.out = []
for idx in range(self.num_bootstraps):
size = len(args[0])
sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy)
sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy).to(self.device)
new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx)
self.metrics[idx].update(*new_args)
self.out.append(new_args)
Expand Down Expand Up @@ -70,39 +72,50 @@ def test_bootstrap_sampler(sampling_strategy):
assert found_zero, "resampling did not work because all samples were atleast sampled once"


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("sampling_strategy", ["poisson", "multinomial"])
@pytest.mark.parametrize(
"metric, sk_metric", [[Precision(average="micro"), precision_score], [Recall(average="micro"), recall_score]]
"metric, sk_metric",
[
[Precision(average="micro"), partial(precision_score, average="micro")],
[Recall(average="micro"), partial(recall_score, average="micro")],
[MeanSquaredError(), mean_squared_error],
],
)
def test_bootstrap(sampling_strategy, metric, sk_metric):
def test_bootstrap(device, sampling_strategy, metric, sk_metric):
"""Test that the different bootstraps get updated as expected and that the compute method works."""
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("Test with device='cuda' requires gpu")

_kwargs = {"base_metric": metric, "mean": True, "std": True, "raw": True, "sampling_strategy": sampling_strategy}
if _TORCH_GREATER_EQUAL_1_7:
_kwargs.update(dict(quantile=torch.tensor([0.05, 0.95])))
_kwargs.update(dict(quantile=torch.tensor([0.05, 0.95], device=device)))

bootstrapper = TestBootStrapper(**_kwargs)
bootstrapper.to(device)

collected_preds = [[] for _ in range(10)]
collected_target = [[] for _ in range(10)]
for p, t in zip(_preds, _target):
p, t = p.to(device), t.to(device)
bootstrapper.update(p, t)

for i, o in enumerate(bootstrapper.out):

collected_preds[i].append(o[0])
collected_target[i].append(o[1])

collected_preds = [torch.cat(cp) for cp in collected_preds]
collected_target = [torch.cat(ct) for ct in collected_target]
collected_preds = [torch.cat(cp).cpu() for cp in collected_preds]
collected_target = [torch.cat(ct).cpu() for ct in collected_target]

sk_scores = [sk_metric(ct, cp, average="micro") for ct, cp in zip(collected_target, collected_preds)]
sk_scores = [sk_metric(ct, cp) for ct, cp in zip(collected_target, collected_preds)]

output = bootstrapper.compute()
# quantile is only available for PyTorch v1.7 and later
if _TORCH_GREATER_EQUAL_1_7:
assert np.allclose(output["quantile"][0], np.quantile(sk_scores, 0.05))
assert np.allclose(output["quantile"][1], np.quantile(sk_scores, 0.95))
assert np.allclose(output["quantile"][0].cpu(), np.quantile(sk_scores, 0.05))
assert np.allclose(output["quantile"][1].cpu(), np.quantile(sk_scores, 0.95))

assert np.allclose(output["mean"], np.mean(sk_scores))
assert np.allclose(output["std"], np.std(sk_scores, ddof=1))
assert np.allclose(output["raw"], sk_scores)
assert np.allclose(output["mean"].cpu(), np.mean(sk_scores))
assert np.allclose(output["std"].cpu(), np.std(sk_scores, ddof=1))
assert np.allclose(output["raw"].cpu(), sk_scores)
80 changes: 78 additions & 2 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Any, Callable, Dict, Generator, List, Optional, Union

import torch
from torch import Tensor, nn
from torch import Tensor
from torch.nn import Module

from torchmetrics.utilities import apply_to_collection, rank_zero_warn
Expand All @@ -35,7 +35,7 @@ def jit_distributed_available() -> bool:
return torch.distributed.is_available() and torch.distributed.is_initialized()


class Metric(nn.Module, ABC):
class Metric(Module, ABC):
"""Base class for all metrics present in the Metrics API.
Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to
Expand Down Expand Up @@ -84,6 +84,8 @@ def __init__(
torch._C._log_api_usage_once(f"torchmetrics.metric.{self.__class__.__name__}")

self._LIGHTNING_GREATER_EQUAL_1_3 = _compare_version("pytorch_lightning", op.ge, "1.3.0")
self._dtype: Union[str, torch.dtype] = torch.get_default_dtype()
self._device = torch.device("cpu")

self.dist_sync_on_step = dist_sync_on_step
self.compute_on_step = compute_on_step
Expand Down Expand Up @@ -411,6 +413,80 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self.update: Callable = self._wrap_update(self.update) # type: ignore
self.compute: Callable = self._wrap_compute(self.compute) # type: ignore

@property
def dtype(self) -> "torch.dtype":
    """Return the dtype currently recorded for the metric (the cached ``_dtype`` attribute)."""
    recorded_dtype = self._dtype
    return recorded_dtype


@dtype.setter
def dtype(self, new_dtype: "torch.dtype") -> None:
    """Reject direct assignment; the dtype is only changed through the casting methods.

    Raises:
        RuntimeError: always, regardless of ``new_dtype``.
    """
    # Kept from the original author's note: this explicit setter is necessary to avoid infinite recursion.
    message = "Cannot set the dtype explicitly. Please use module.to(new_dtype)."
    raise RuntimeError(message)

@property
def device(self) -> "torch.device":
    """Return the device currently recorded for the metric (the cached ``_device`` attribute)."""
    recorded_device = self._device
    return recorded_device

def to(self, *args: Any, **kwargs: Any) -> "Metric":
    """Move and/or cast the parameters and buffers.

    Works like ``nn.Module.to`` but also updates the metric's ``device`` and ``dtype`` properties.

    Args:
        *args: positional arguments forwarded unchanged to ``nn.Module.to``.
        **kwargs: keyword arguments forwarded unchanged to ``nn.Module.to``.

    Returns:
        Whatever ``nn.Module.to`` returns for this metric.
    """
    # The number of values returned by ``_parse_to`` differs in PyTorch 1.5, so index into the
    # result instead of unpacking it.
    parsed = torch._C._nn._parse_to(*args, **kwargs)
    # Perform the move before recording it: if ``nn.Module.to`` raises (e.g. it rejects a
    # non floating point dtype) the cached properties keep describing the real state.
    moved = super().to(*args, **kwargs)
    self._update_properties(device=parsed[0], dtype=parsed[1])
    return moved

def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "Metric":
    """Move all model parameters and buffers to the GPU and record the new device.

    Args:
        device: if specified, all parameters will be copied to that device. ``None`` or an
            integer index is first wrapped into a ``torch.device("cuda", index=...)``.

    Returns:
        Whatever ``nn.Module.cuda`` returns for this metric.
    """
    if device is None or isinstance(device, int):
        device = torch.device("cuda", index=device)
    # Move first and record afterwards: if the transfer raises (presumably the case when CUDA
    # is unavailable) the ``device`` property must not claim the metric already lives on the GPU.
    moved = super().cuda(device=device)
    self._update_properties(device=device)
    return moved

def cpu(self) -> "Metric":
    """Move all model parameters and buffers to the CPU and record the CPU as the metric's device."""
    host_device = torch.device("cpu")
    self._update_properties(device=host_device)
    return super().cpu()

def type(self, dst_type: Union[str, torch.dtype]) -> "Metric":
    """Cast all parameters and buffers to :attr:`dst_type` and record it as the metric's dtype.

    The value is recorded exactly as given, so when a string is passed the ``dtype``
    property will hold that string.

    Args:
        dst_type (type or string): the desired type
    """
    self._update_properties(dtype=dst_type)
    converted = super().type(dst_type=dst_type)
    return converted

def float(self) -> "Metric":
    """Cast all floating point parameters and buffers to ``float`` and record ``torch.float`` as the dtype."""
    target_dtype = torch.float
    self._update_properties(dtype=target_dtype)
    return super().float()

def double(self) -> "Metric":
    """Cast all floating point parameters and buffers to ``double`` and record ``torch.double`` as the dtype."""
    target_dtype = torch.double
    self._update_properties(dtype=target_dtype)
    return super().double()

def half(self) -> "Metric":
    """Cast all floating point parameters and buffers to ``half`` and record ``torch.half`` as the dtype."""
    target_dtype = torch.half
    self._update_properties(dtype=target_dtype)
    return super().half()

def _update_properties(
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
"""Updates the internal device and or dtype attributes of the metric."""
if device is not None:
self._device = device
if dtype is not None:
self._dtype = dtype

def _apply(self, fn: Callable) -> Module:
"""Overwrite _apply function such that we can also move metric states to the correct device when `.to`,
`.cuda`, etc methods are called."""
Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
self.all_layers = all_layers
self.num_threads = num_threads
self.batch_size = batch_size
self.device = device
self.embedding_device = device
self.idf = idf
self.verbose = verbose
self.num_layers = num_layers
Expand Down Expand Up @@ -128,7 +128,7 @@ def compute(self) -> Dict:
num_layers=self.num_layers,
verbose=self.verbose,
idf=self.idf,
device=self.device,
device=self.embedding_device,
baseline_path=self.baseline_path,
batch_size=self.batch_size,
lang=self.lang,
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/wrappers/bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def update(self, *args: Any, **kwargs: Any) -> None:
size = kwargs_sizes[0]
else:
raise ValueError("None of the input contained tensors, so could not determine the sampling size")
sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy)
sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy).to(self.device)
new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx)
new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx)
self.metrics[idx].update(*new_args, **new_kwargs)
Expand Down

0 comments on commit efdf20e

Please sign in to comment.